# 🏃‍♂️ Quick Start
Everything you need to get started: install, write a script, launch it, track metrics, and a complete API cheat sheet with before/after diffs.
For a complete runnable example with terminal output, see the Reference.
## 📦 Install
**Editable install for development**
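A minimal sketch of an editable install; the repository URL is an assumption here, so adjust it to your fork or local checkout:

```shell
# clone the repository and install it in editable (development) mode;
# the GitHub URL is assumed -- substitute your fork or a local path
git clone https://github.com/saforem2/ezpz
cd ezpz
python3 -m pip install -e .
```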
**Try without installing via `uv run`**

If you already have a Python environment with `torch` and `mpi4py` installed, you can try ezpz without installing it:
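For example (a sketch; the assumption here is that the package is resolvable under the name `ezpz` — point `--with` at the GitHub repository instead if it is not published on PyPI):

```shell
# run the built-in smoke test in a throwaway uv environment;
# the package source for --with is an assumption -- adjust as needed
uv run --with ezpz -- ezpz test
```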
**Verify: `ezpz test`**

After installing, run a quick smoke test to verify distributed functionality and device detection:
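```shell
ezpz test
```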
This trains a simple MLP on MNIST using DDP and reports timing metrics. See the W&B Report for example output.
**[Optional] Shell Environment and Setup**

**Deprecation Notice**: I plan to deprecate `utils.sh` in favor of a uv-native approach. This shell script was originally developed for personal use, and I don't plan to officially support it in the long term.

- `ezpz/bin/utils.sh`: Shell script containing a collection of functions that I've accumulated over time and found to be useful. To use these, source the file directly from the command line.
- `savejobenv`: Shell function that will save relevant job-specific environment variables to a file for later use.
- Running from a PBS job: `savejobenv` will automatically save the relevant environment variables to a file `~/.pbsenv`, which can later be sourced via `source ~/.pbsenv` from, e.g., another terminal.
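Putting those pieces together, a typical PBS workflow looks like this (a sketch; paths follow the descriptions above):

```shell
# source the helper functions from the repository checkout
source ezpz/bin/utils.sh

# from inside a PBS job: snapshot job-specific environment variables
savejobenv

# later, e.g. from another terminal, restore them:
source ~/.pbsenv
```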
## 🚂 Distributed Training Script
```python
import ezpz
import torch

rank = ezpz.setup_torch()  # auto-detects device + backend, returns global rank
device = ezpz.get_torch_device()

model = torch.nn.Linear(128, 64).to(device)
model = ezpz.wrap_model(
    model,
    use_fsdp=True,  # False will use DDP
    dtype=torch.bfloat16,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(100):
    x = torch.randn(32, 128, device=device)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

ezpz.cleanup()
```
That's it — ezpz detects the available backend, initializes the process
group, wraps your model in FSDP/DDP, and assigns each rank to the correct
device.
**`ezpz.synchronize()`**

Use `ezpz.synchronize()` instead of `torch.cuda.synchronize()` to get correct timing measurements on any backend (CUDA, XPU, MPS, CPU).
## 🚀 Launch It
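In its simplest form (a sketch — `train.py` is a placeholder for your own script):

```shell
# launch your script across all available GPUs/ranks;
# train.py is a hypothetical stand-in for your training script
ezpz launch -- python3 train.py
```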
This single command works everywhere because the launcher detects the active job scheduler automatically:
| Environment | What ezpz launch runs |
|---|---|
| PBS job (Polaris, Aurora, Sunspot) | mpiexec with hostfile from $PBS_NODEFILE |
| SLURM job (Frontier, Perlmutter) | srun with SLURM topology |
| No scheduler (laptop / workstation) | mpirun -np <ngpus> fallback |
**Flag aliases**

`-n`, `-np`, and `--nproc` are all aliases for the same flag (number of processes). Similarly, `-ppn` and `--nproc_per_node` are aliases.
**Advanced Launcher Examples**

To pass arguments through to the launcher[^1]:

```bash
$ ezpz launch -- python3 -m ezpz.examples.fsdp

# pass --line-buffer through to mpiexec:
$ ezpz launch --line-buffer -- python3 \
    -m ezpz.examples.vit --compile --fsdp

# create and use a custom hostfile
$ head -n 2 "${PBS_NODEFILE}" > hostfile0-2
$ ezpz launch --hostfile hostfile0-2 -- python3 \
    -m ezpz.examples.fsdp_tp

# use explicit np/ppn/nhosts
$ ezpz launch \
    -np 4 \
    -ppn 2 \
    --nhosts 2 \
    --hostfile hostfile0-2 \
    -- \
    python3 -m ezpz.examples.diffusion

# forward the PYTHONPATH environment variable
$ ezpz launch -x PYTHONPATH=/tmp/.venv/bin:${PYTHONPATH} \
    -- \
    python3 -m ezpz.examples.fsdp
```
For pass-through launcher flags, custom hostfiles, and advanced usage, see the Reference launcher section and the CLI reference.
## 🛠️ API Cheat Sheet
Each ezpz component can be used independently — pick only what you need.
### Setup & Distributed Init

```diff
- import os, torch.distributed as dist
- dist.init_process_group(backend="nccl", ...)
- rank = int(os.environ["RANK"])
- local_rank = int(os.environ["LOCAL_RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
+ import ezpz
+ rank = ezpz.setup_torch()  # returns global rank
+ local_rank = ezpz.get_local_rank()
+ world_size = ezpz.get_world_size()
```
### Device Management

```diff
- device = torch.device("cuda")
- model.to("cuda")
- batch = batch.to("cuda")
+ device = ezpz.get_torch_device()  # cuda, xpu, mps, or cpu
+ model.to(device)
+ batch = batch.to(device)
```
### Model Wrapping

```diff
- from torch.nn.parallel import DistributedDataParallel as DDP
- model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+ model = ezpz.wrap_model(model)                 # DDP (default)
+ model = ezpz.wrap_model(model, use_fsdp=True)  # FSDP
```
### Training Loop

```diff
  for step, batch in enumerate(dataloader):
-     batch = batch.to("cuda")
+     batch = batch.to(ezpz.get_torch_device())
      t0 = time.perf_counter()
      loss = train_step(...)
-     torch.cuda.synchronize()
+     ezpz.synchronize()
      dt = time.perf_counter() - t0
```
### Metric Tracking

```python
import ezpz

logger = ezpz.get_logger(__name__)
ezpz.setup_wandb(project_name="my-project")  # optional
history = ezpz.History()

for step in range(100):
    loss = train_step(...)
    logger.info(history.update({"step": step, "loss": loss.item()}))

history.finalize(outdir="./outputs")  # saves dataset + plots
```
For full details on History, see the
Metric Tracking guide.
## 📊 Track Metrics

Use the built-in `History` class to record, save, and plot training metrics:

```python
import ezpz
from ezpz.history import History

logger = ezpz.get_logger(__name__)

# Optional: enable W&B logging before creating History
ezpz.setup_wandb(project_name="my-project")

history = History()
for step in range(100):
    # ... training loop ...
    metrics = {"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}
    # update() returns a summary string suitable for logging
    summary = history.update(metrics)
    logger.info(summary)  # e.g. "loss=0.123456 lr=0.001000"

history.finalize(outdir="./outputs")  # saves dataset + generates plots
```
History automatically computes distributed statistics (min, max, mean, std)
across all ranks for every recorded metric — no extra code needed on worker
ranks.
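To make the cross-rank summary concrete, here is a plain-Python sketch of the statistics computed for each metric. This is illustrative only: the function name is made up, and the real implementation aggregates values across ranks via `torch.distributed` rather than taking a list.

```python
import statistics


def summarize_across_ranks(values: list[float]) -> dict[str, float]:
    """Illustrative stand-in for History's per-metric summary:
    given one value per rank, compute min/max/mean/std."""
    return {
        "min": min(values),
        "max": max(values),
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),
    }


# e.g. the same step's loss as observed by 4 ranks
print(summarize_across_ranks([0.9, 1.1, 1.0, 1.0]))
```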
**What `finalize` produces**

Calling `history.finalize()` writes a summary dataset and generates loss curves and other plots, ready for inspection or inclusion in reports. See the Reference complete example for sample output with terminal plots.
For the full History API — distributed aggregation, environment variables,
StopWatch, and more — see the Metric Tracking guide.
## 🔗 Next Steps

- Reference — complete runnable example with terminal output
- Metric Tracking — full `History` guide: distributed stats, W&B, plots
- Examples — end-to-end training scripts (FSDP, ViT, Diffusion, etc.)
- CLI Reference — full `ezpz launch` usage and flags
- Configuration — environment variables and config dataclasses
- Architecture — how `ezpz` works under the hood
[^1]: This will be `srun` if a Slurm scheduler is detected, `mpirun`/`mpiexec` otherwise.