Minimal Training (Synthetic Data)⚓︎
Deprecated — use ezpz.examples.test instead
This example is deprecated and will be removed in a future release.
New users should start with ezpz.examples.test, which
covers the same concepts with better defaults.
The walkthrough below is preserved for reference only.
The simplest ezpz example — trains an MLP to reconstruct random inputs using env-var configuration. No dataset downloads required.
Key API Functions
- setup_torch() — Initialize distributed training
- wrap_model_for_ddp() — Wrap model for DDP
- History — Track and finalize metrics
- get_logger() — Rank-aware logging
See:
Source⚓︎
src/ezpz/examples/minimal.py
"""Minimal synthetic training loop for testing distributed setup and logging.
This example builds a tiny MLP that learns to reconstruct random inputs.
Launch it with:
ezpz launch -m ezpz.examples.minimal
Running ``python3 -m ezpz.examples.minimal --help`` prints:
usage: ezpz.examples.minimal --help
(Set env vars such as PRINT_FREQ=100 TRAIN_ITERS=1000 INPUT_SIZE=128 OUTPUT_SIZE=128 LAYER_SIZES=\"128,256,128\" before calling ezpz launch)
"""
import os
import time
from pathlib import Path
import torch
import ezpz
from ezpz.examples import get_example_outdir
from ezpz.flops import compute_mfu, try_estimate
logger = ezpz.get_logger(__name__)
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
outdir: os.PathLike | str,
) -> ezpz.History:
"""Run a synthetic training loop on random data.
Each step draws a fresh random batch, feeds it through the model, and
minimizes the summed squared difference between output and input.
Loop controls are read from the environment: BATCH_SIZE, TRAIN_ITERS,
WARMUP_ITERS, LOG_FREQ and PRINT_FREQ.
Args:
model: Model to train (wrapped or unwrapped).
optimizer: Optimizer configured for the model.
outdir: Directory handed to ``ezpz.History`` as its output and report
directory; ``metrics.jsonl`` is placed inside it.
Returns:
Training history with timing and loss metrics.
"""
# Look through a DDP wrapper so the layer attributes can be read directly.
unwrapped_model = (
model.module
if isinstance(model, torch.nn.parallel.DistributedDataParallel)
else model
)
metrics_path = Path(outdir).joinpath("metrics.jsonl")
# distributed_history is enabled only for world sizes in the range (1, 384].
history = ezpz.History(
project_name=os.environ.get("PROJECT_NAME", "ezpz.examples.minimal"),
config={"model": str(unwrapped_model), "outdir": str(outdir)},
outdir=outdir,
report_dir=outdir,
report_enabled=True,
jsonl_path=metrics_path,
jsonl_overwrite=True,
distributed_history=(1 < ezpz.get_world_size() <= 384),
)
device_type = ezpz.get_torch_device_type()
# Batch dtype and width are taken from the model's first layer.
dtype = unwrapped_model.layers[0].weight.dtype
bsize = int(os.environ.get("BATCH_SIZE", 64))
isize = unwrapped_model.layers[0].in_features
# FLOPs estimate computed once up front; reused for the tflops/mfu metrics.
_model_flops = try_estimate(model, (bsize, isize))
warmup = int(os.environ.get("WARMUP_ITERS", 10))
log_freq = int(os.environ.get("LOG_FREQ", 1))
print_freq = int(os.environ.get("PRINT_FREQ", 10))
model.train()
# Holds the latest summary returned by history.update(); empty until the
# first logged step.
summary = ""
for step in range(int(os.environ.get("TRAIN_ITERS", 500))):
# Autocast uses the same dtype as the model's first-layer weights.
with torch.autocast(
device_type=device_type,
dtype=dtype,
):
t0 = time.perf_counter()
# Random batch created with the model dtype, then moved to the device.
x = torch.rand((bsize, isize), dtype=dtype).to(device_type)
y = model(x)
# Summed squared reconstruction error: the input doubles as the target.
loss = ((y - x) ** 2).sum()
# dtf: time elapsed since t0; t1 starts the backward/optimizer timer.
dtf = (t1 := time.perf_counter()) - t0
loss.backward()
optimizer.step()
optimizer.zero_grad()
# dtb: time spent in backward(), optimizer.step() and zero_grad().
dtb = time.perf_counter() - t1
# Record metrics every log_freq steps, but only once step exceeds warmup.
if step % log_freq == 0 and step > warmup:
metrics = {
"iter": step,
"loss": loss.item(),
"dt": dtf + dtb,
"dtf": dtf,
"dtb": dtb,
}
# Throughput metrics are added only when a positive FLOPs estimate
# and a positive step time are available.
if _model_flops > 0:
dt_total = dtf + dtb
if dt_total > 0:
metrics["tflops"] = _model_flops / dt_total / 1e12
metrics["mfu"] = compute_mfu(_model_flops, dt_total)
# Device memory: empty on CPU/MPS, 4 keys on CUDA/XPU.
metrics |= ezpz.get_memory_metrics()
summary = history.update(
metrics
)
# Emit the most recent summary every print_freq steps (same warmup gate).
if step % print_freq == 0 and step > warmup:
logger.info(summary)
return history
@ezpz.timeitlogit(rank=ezpz.get_rank())
def setup():
    """Initialize distributed runtime, model, and optimizer.

    Configuration is read from environment variables:

    - ``SEED`` (default ``0``): seed passed to ``ezpz.setup_torch``.
    - ``INPUT_SIZE`` / ``OUTPUT_SIZE`` (default ``128``): model dimensions.
    - ``LAYER_SIZES``: comma-separated layer sizes.
    - ``DTYPE`` (default ``bfloat16``): name of a ``torch`` dtype, with or
      without the ``torch.`` prefix (e.g. ``float32``, ``torch.float16``).

    Returns:
        A ``(model, optimizer)`` tuple. The optimizer is ``torch.optim.Adam``
        over the model parameters; the model is passed through
        ``ezpz.distributed.wrap_model_for_ddp`` when the world size is
        greater than one.

    Raises:
        ValueError: If ``DTYPE`` is set but does not name a ``torch.dtype``.
    """
    ezpz.setup_torch(seed=int(os.environ.get("SEED", 0)))
    device_type = ezpz.get_torch_device_type()
    from ezpz.models.minimal import SequentialLinearNet

    model = SequentialLinearNet(
        input_dim=int(os.environ.get("INPUT_SIZE", 128)),
        output_dim=int(os.environ.get("OUTPUT_SIZE", 128)),
        sizes=[
            int(x)
            for x in os.environ.get(
                "LAYER_SIZES", "256,512,1024,2048,1024,512,256,128"
            ).split(",")
        ],
    )
    model.to(device_type)

    # Environment values are always strings, and ``Module.to("bfloat16")``
    # would be parsed as a *device* string and fail. Resolve the name to a
    # real ``torch.dtype`` before casting.
    dtype_name = os.environ.get("DTYPE", "").strip()
    if not dtype_name:
        dtype = torch.bfloat16
    else:
        dtype = getattr(torch, dtype_name.removeprefix("torch."), None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(
                f"DTYPE={dtype_name!r} is not the name of a torch dtype"
            )
    model.to(dtype)

    # The summary is best-effort: a failure is logged and setup continues.
    try:
        from ezpz.utils import model_summary

        model_summary(model)
    except Exception:
        logger.exception("Failed to summarize model")
    logger.info(f"{model=}")

    # The optimizer is built from the unwrapped model's parameters, before
    # any DDP wrapping.
    optimizer = torch.optim.Adam(model.parameters())
    if ezpz.get_world_size() > 1:
        model = ezpz.distributed.wrap_model_for_ddp(model)
    return model, optimizer
@ezpz.timeitlogit(rank=ezpz.get_rank())
def main():
    """Run the minimal example end to end: setup, train, report timings.

    Wall-clock timings for setup and training are logged and forwarded to
    the history's tracker; rank 0 then finalizes the history into the
    example's output directory.
    """
    started = time.perf_counter()
    model, optimizer = setup()
    setup_done = time.perf_counter()

    module_name = "ezpz.examples.minimal"
    outdir = get_example_outdir(module_name)
    logger.info("Outputs will be saved to %s", outdir)

    train_began = time.perf_counter()
    history = train(model, optimizer, outdir)
    train_done = time.perf_counter()

    train_seconds = train_done - train_began
    total_seconds = train_done - started
    timings = {
        "main/setup": setup_done - started,
        "main/train": train_seconds,
        "main/total": total_seconds,
        "timings/training_start": train_began - started,
        "timings/train_duration": train_seconds,
        "timings/end-to-end": total_seconds,
    }
    logger.info("Timings: %s", timings)

    # Every key sent to the tracker lives under "timings/"; keys that
    # already carry the prefix are passed through unchanged.
    prefixed = {}
    for key, seconds in timings.items():
        name = key if key.startswith("timings/") else f"timings/{key}"
        prefixed[name] = seconds
    history.tracker.log(prefixed)

    if ezpz.get_rank() != 0:
        return
    # Only rank 0 finalizes. The returned dataset is not needed here
    # (it is logged by finalize()).
    history.finalize(
        outdir=outdir,
        run_name=module_name,
        dataset_fname="train",
        verbose=False,
    )
if __name__ == "__main__":
    import sys

    # There is no argument parser: configuration is via environment
    # variables, so --help / -h only prints an example invocation.
    if len(sys.argv) > 1 and sys.argv[1] in ("--help", "-h"):
        example_invocation = " ".join(
            [
                # PRINT_FREQ (not PRINT_ITERS) is the variable train() reads.
                "PRINT_FREQ=100",
                "TRAIN_ITERS=1000",
                "INPUT_SIZE=128",
                "OUTPUT_SIZE=128",
                # Plain double quotes only: nested single quotes would end
                # up inside the value and break the int() parsing in setup().
                'LAYER_SIZES="128,256,128"',
                "ezpz-launch",
                "-m ezpz.examples.minimal",
            ]
        )
        print("\n".join(["Usage: ", example_invocation]))
        # sys.exit is always available; the bare exit() helper is injected
        # by the site module and is meant for interactive use.
        sys.exit(0)
    else:
        main()
Code Walkthrough⚓︎
Imports and Logger
Standard imports plus ezpz for distributed training utilities. The rank-aware logger ensures only rank 0 prints by default.
train()
The @ezpz.timeitlogit decorator logs wall-clock time for the entire function. Inside, the model is unwrapped if DDP-wrapped, and an ezpz.History is created to track metrics to a JSONL file.
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
outdir: os.PathLike | str,
) -> ezpz.History:
"""Run a synthetic training loop on random data.
Args:
model: Model to train (wrapped or unwrapped).
optimizer: Optimizer configured for the model.
outdir: Directory handed to ``ezpz.History`` as its output and report
directory; ``metrics.jsonl`` is placed inside it.
Returns:
Training history with timing and loss metrics.
"""
# Look through a DDP wrapper so the layer attributes can be read directly.
unwrapped_model = (
model.module
if isinstance(model, torch.nn.parallel.DistributedDataParallel)
else model
)
metrics_path = Path(outdir).joinpath("metrics.jsonl")
# distributed_history is enabled only for world sizes in the range (1, 384].
history = ezpz.History(
project_name=os.environ.get("PROJECT_NAME", "ezpz.examples.minimal"),
config={"model": str(unwrapped_model), "outdir": str(outdir)},
outdir=outdir,
report_dir=outdir,
report_enabled=True,
jsonl_path=metrics_path,
jsonl_overwrite=True,
distributed_history=(1 < ezpz.get_world_size() <= 384),
)
device_type = ezpz.get_torch_device_type()
# Batch dtype and width are taken from the model's first layer.
dtype = unwrapped_model.layers[0].weight.dtype
bsize = int(os.environ.get("BATCH_SIZE", 64))
isize = unwrapped_model.layers[0].in_features
# FLOPs estimate computed once up front for a (bsize, isize) input.
_model_flops = try_estimate(model, (bsize, isize))
# Loop controls, each read from the environment with a default.
warmup = int(os.environ.get("WARMUP_ITERS", 10))
log_freq = int(os.environ.get("LOG_FREQ", 1))
print_freq = int(os.environ.get("PRINT_FREQ", 10))
model.train()
The training loop generates random input, computes a reconstruction loss, and records forward/backward timings separately. Metrics are logged via history.update() after a warmup period.
# Holds the latest summary returned by history.update(); empty until the
# first logged step.
summary = ""
for step in range(int(os.environ.get("TRAIN_ITERS", 500))):
with torch.autocast(
device_type=device_type,
dtype=dtype,
):
t0 = time.perf_counter()
# Random batch created with the given dtype, then moved to the device.
x = torch.rand((bsize, isize), dtype=dtype).to(device_type)
y = model(x)
# Summed squared reconstruction error: the input doubles as the target.
loss = ((y - x) ** 2).sum()
# dtf: time elapsed since t0; t1 starts the backward/optimizer timer.
dtf = (t1 := time.perf_counter()) - t0
loss.backward()
optimizer.step()
optimizer.zero_grad()
# dtb: time spent in backward(), optimizer.step() and zero_grad().
dtb = time.perf_counter() - t1
# Record metrics every log_freq steps, but only once step exceeds warmup.
if step % log_freq == 0 and step > warmup:
metrics = {
"iter": step,
"loss": loss.item(),
"dt": dtf + dtb,
"dtf": dtf,
"dtb": dtb,
}
# Throughput metrics are added only when a positive FLOPs estimate
# and a positive step time are available.
if _model_flops > 0:
dt_total = dtf + dtb
if dt_total > 0:
metrics["tflops"] = _model_flops / dt_total / 1e12
metrics["mfu"] = compute_mfu(_model_flops, dt_total)
# Device memory: empty on CPU/MPS, 4 keys on CUDA/XPU.
metrics |= ezpz.get_memory_metrics()
summary = history.update(
metrics
)
# Emit the most recent summary every print_freq steps (same warmup gate).
if step % print_freq == 0 and step > warmup:
logger.info(summary)
return history
setup()
Initializes the distributed backend via ezpz.setup_torch(), optionally sets up W&B on rank 0, builds a SequentialLinearNet with env-var-driven dimensions, and wraps the model for DDP when running multi-GPU.
@ezpz.timeitlogit(rank=ezpz.get_rank())
def setup():
    """Initialize distributed runtime, optional W&B, model, and optimizer.

    Configuration is read from environment variables:

    - ``SEED`` (default ``0``): seed passed to ``ezpz.setup_torch``.
    - ``WANDB_DISABLED``: any non-empty value skips W&B initialization.
    - ``PROJECT_NAME``: project name passed to ``ezpz.setup_wandb``.
    - ``INPUT_SIZE`` / ``OUTPUT_SIZE`` (default ``128``): model dimensions.
    - ``LAYER_SIZES``: comma-separated layer sizes.
    - ``DTYPE`` (default ``bfloat16``): name of a ``torch`` dtype, with or
      without the ``torch.`` prefix (e.g. ``float32``, ``torch.float16``).

    Returns:
        A ``(model, optimizer)`` tuple. The optimizer is ``torch.optim.Adam``
        over the model parameters; the model is passed through
        ``ezpz.distributed.wrap_model_for_ddp`` when the world size is
        greater than one.

    Raises:
        ValueError: If ``DTYPE`` is set but does not name a ``torch.dtype``.
    """
    rank = ezpz.setup_torch(seed=int(os.environ.get("SEED", 0)))

    # W&B is best-effort and only attempted on rank 0: a failure is logged
    # and the run continues without it.
    if os.environ.get("WANDB_DISABLED", False):
        logger.info("WANDB_DISABLED is set, not initializing wandb")
    elif rank == 0:
        try:
            _ = ezpz.setup_wandb(
                project_name=os.environ.get("PROJECT_NAME", "ezpz.examples.minimal")
            )
        except Exception:
            logger.exception("Failed to initialize wandb, continuing without it")

    device_type = ezpz.get_torch_device_type()
    from ezpz.models.minimal import SequentialLinearNet

    model = SequentialLinearNet(
        input_dim=int(os.environ.get("INPUT_SIZE", 128)),
        output_dim=int(os.environ.get("OUTPUT_SIZE", 128)),
        sizes=[
            int(x)
            for x in os.environ.get(
                "LAYER_SIZES", "256,512,1024,2048,1024,512,256,128"
            ).split(",")
        ],
    )
    model.to(device_type)

    # Environment values are always strings, and ``Module.to("bfloat16")``
    # would be parsed as a *device* string and fail. Resolve the name to a
    # real ``torch.dtype`` before casting.
    dtype_name = os.environ.get("DTYPE", "").strip()
    if not dtype_name:
        dtype = torch.bfloat16
    else:
        dtype = getattr(torch, dtype_name.removeprefix("torch."), None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(
                f"DTYPE={dtype_name!r} is not the name of a torch dtype"
            )
    model.to(dtype)

    # The summary is best-effort: a failure is logged and setup continues.
    try:
        from ezpz.utils import model_summary

        model_summary(model)
    except Exception:
        logger.exception("Failed to summarize model")
    logger.info(f"{model=}")

    # The optimizer is built from the unwrapped model's parameters, before
    # any DDP wrapping.
    optimizer = torch.optim.Adam(model.parameters())
    if ezpz.get_world_size() > 1:
        model = ezpz.distributed.wrap_model_for_ddp(model)
    return model, optimizer
main()
Orchestrates the full run: calls setup(), runs train(), then finalizes the history on rank 0 to persist metrics. Timing breakdowns are logged and optionally sent to W&B.
@ezpz.timeitlogit(rank=ezpz.get_rank())
def main():
    """Run the minimal example end to end: setup, train, finalize, report.

    Rank 0 finalizes the history into the example's output directory.
    Wall-clock timings are then logged and, when a W&B run is active,
    forwarded to W&B as well.
    """
    started = time.perf_counter()
    model, optimizer = setup()
    setup_done = time.perf_counter()

    module_name = "ezpz.examples.minimal"
    outdir = get_example_outdir(module_name)
    logger.info("Outputs will be saved to %s", outdir)

    train_began = time.perf_counter()
    history = train(model, optimizer, outdir)
    train_done = time.perf_counter()

    # Only rank 0 finalizes the history.
    if ezpz.get_rank() == 0:
        dataset = history.finalize(
            outdir=outdir,
            run_name=module_name,
            dataset_fname="train",
            verbose=False,
        )
        logger.info(f"{dataset=}")

    train_seconds = train_done - train_began
    total_seconds = train_done - started
    timings = {
        "main/setup": setup_done - started,
        "main/train": train_seconds,
        "main/total": total_seconds,
        "timings/training_start": train_began - started,
        "timings/train_duration": train_seconds,
        "timings/end-to-end": total_seconds,
    }
    logger.info("Timings: %s", timings)

    # Forwarding to W&B is best-effort: a missing package or any logging
    # error is downgraded to a debug message.
    try:
        import wandb

        active_run = getattr(wandb, "run", None)
        if active_run is not None:
            # Every key lives under "timings/"; keys that already carry
            # the prefix are passed through unchanged.
            payload = {}
            for key, seconds in timings.items():
                name = key if key.startswith("timings/") else f"timings/{key}"
                payload[name] = seconds
            wandb.log(payload)
    except Exception:
        logger.debug("Skipping wandb timings log")
__main__ Guard
Prints a usage message on --help, otherwise calls main().
if __name__ == "__main__":
    import sys

    # There is no argument parser: configuration is via environment
    # variables, so --help / -h only prints an example invocation.
    if len(sys.argv) > 1 and sys.argv[1] in ("--help", "-h"):
        example_invocation = " ".join(
            [
                # PRINT_FREQ (not PRINT_ITERS) is the variable train() reads.
                "PRINT_FREQ=100",
                "TRAIN_ITERS=1000",
                "INPUT_SIZE=128",
                "OUTPUT_SIZE=128",
                # Plain double quotes only: nested single quotes would end
                # up inside the value and break the int() parsing in setup().
                'LAYER_SIZES="128,256,128"',
                "ezpz-launch",
                "-m ezpz.examples.minimal",
            ]
        )
        print("\n".join(["Usage: ", example_invocation]))
        # sys.exit is always available; the bare exit() helper is injected
        # by the site module and is meant for interactive use.
        sys.exit(0)
    else:
        main()
MFU Tracking⚓︎
This example reports per-step TFLOPS and MFU (Model FLOPS
Utilization) alongside loss/timing metrics. The model FLOPS are
counted once at startup via try_estimate,
and compute_mfu divides by the device's peak BF16 throughput
(see ezpz.flops for details).
from ezpz.flops import compute_mfu, try_estimate
_model_flops = try_estimate(model, (bsize, isize))
# ...
metrics["tflops"] = _model_flops / dt / 1e12
metrics["mfu"] = compute_mfu(_model_flops, dt)
Configuration⚓︎
All configuration is via environment variables:
| Variable | Default | Description |
|---|---|---|
| `TRAIN_ITERS` | `500` | Number of training iterations |
| `BATCH_SIZE` | `64` | Batch size |
| `INPUT_SIZE` | `128` | Input dimension |
| `OUTPUT_SIZE` | `128` | Output dimension |
| `LAYER_SIZES` | `256,512,...,128` | Comma-separated hidden layer sizes |
| `DTYPE` | `bfloat16` | Model dtype |
| `SEED` | `0` | Random seed |
| `LOG_FREQ` | `1` | Log metrics every N steps |
| `PRINT_FREQ` | `10` | Print summary every N steps |
| `WARMUP_ITERS` | `10` | Steps to skip before recording metrics |
Example Output (Sunspot, 2 nodes × 12 ranks = 24 total)⚓︎
[2026-04-29 18:03:57][I][ezpz/distributed:1536:_setup_ddp] init_process_group: master_addr=x1921c6s0b0n0, master_port=53741, world_size=24, rank=0, backend=xccl, timeout=1:00:00
[2026-04-29 18:04:27][I][examples/minimal:150:main] Outputs will be saved to /tmp/outputs/ezpz.examples.minimal/2026-04-29-180355
[2026-04-29 18:04:37][I][examples/minimal:102:train] iter=20 loss=708.216064 dt=0.004312 dtf=0.000920 dtb=0.003391 tflops=0.496617 mfu=0.166545 ...
[2026-04-29 18:04:38][I][examples/minimal:102:train] iter=40 loss=676.927185 dt=0.004310 dtf=0.000922 dtb=0.003388 tflops=0.496832 mfu=0.166616 ...
[2026-04-29 18:04:38][I][examples/minimal:102:train] iter=60 loss=687.406372 dt=0.004286 dtf=0.000918 dtb=0.003368 tflops=0.499541 mfu=0.167525 ...
...
[2026-04-29 18:04:39][I][examples/minimal:102:train] iter=180 loss=676.073486 dt=0.004343 dtf=0.000921 dtb=0.003422 tflops=0.492990 mfu=0.165328 ...
Notice each step takes ~4 ms (forward dtf=0.9 ms + backward dtb=3.4 ms),
giving ~0.5 TFLOPS per device — about 0.17% MFU on PVC. This is a tiny
synthetic MLP, so MFU is dominated by collective overhead and kernel
launch latency rather than compute.