ezpz.examples.fsdp⚓︎
- See ezpz/examples/fsdp.py
- For the full CLI flag reference (with current defaults), see the
  Help section in the FSDP example walkthrough, or run
  python3 -m ezpz.examples.fsdp --help.
FSDP training example on MNIST/OpenImages/ImageNet-style datasets.
Launch with::
ezpz launch -m ezpz.examples.fsdp --dataset MNIST --batch-size 128
Run python3 -m ezpz.examples.fsdp --help for the full list of
flags and their current defaults.
Net
⚓︎
Bases: Module
Simple CNN classifier used in the FSDP example.
Source code in src/ezpz/examples/fsdp.py
class Net(nn.Module):
    """Simple CNN classifier used in the FSDP example.

    Two unpadded 3x3 convolutions (stride 1) are followed by a 2x2 max pool,
    dropout, and two fully connected layers. ``conv1`` takes a single input
    channel, and ``forward`` returns per-class log-probabilities.
    """

    def __init__(
        self,
        num_classes: int = 10,
        img_size: int = 28,
        conv1_channels: int = 32,
        conv2_channels: int = 64,
        fc_dim: int = 128,
    ):
        """Initialize convolutional and fully connected layers.

        Args:
            num_classes: Number of output classes for the classifier.
            img_size: Input image size (assumes square inputs).
            conv1_channels: Number of output channels for conv1.
            conv2_channels: Number of output channels for conv2.
            fc_dim: Hidden dimension for the first fully connected layer.

        Raises:
            ValueError: If ``img_size`` is too small for the conv/pool stack.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        feature_size = self._feature_size(img_size, conv2_channels)
        self.fc1 = nn.Linear(feature_size, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    @staticmethod
    def _feature_size(img_size: int, conv2_channels: int) -> int:
        """Return the flattened feature count that feeds ``fc1``.

        Args:
            img_size: Side length of the square input image.
            conv2_channels: Number of channels produced by conv2.

        Returns:
            ``conv2_channels * pooled_size ** 2`` where ``pooled_size`` is the
            side length left after both convolutions and the 2x2 max pool.

        Raises:
            ValueError: If ``img_size`` is below the minimum the layer stack
                can process.
        """
        # Each unpadded 3x3 conv trims 2 pixels and the 2x2 pool needs at
        # least a 2x2 map, so anything under 6 cannot pass through forward().
        # Without this check a negative intermediate size would be squared
        # into a bogus positive feature count.
        min_img_size = 6
        if img_size < min_img_size:
            raise ValueError(
                f"img_size must be >= {min_img_size} for two 3x3 convs and "
                f"a 2x2 max pool, got {img_size}"
            )
        conv1_size = img_size - 2
        conv2_size = conv1_size - 2
        pooled_size = conv2_size // 2
        return conv2_channels * pooled_size * pooled_size

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the leading dimension and flatten the rest for the linear head.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
__init__(num_classes=10, img_size=28, conv1_channels=32, conv2_channels=64, fc_dim=128)
⚓︎
Initialize convolutional and fully connected layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of output classes for the classifier. | 10 |
| img_size | int | Input image size (assumes square inputs). | 28 |
| conv1_channels | int | Number of output channels for conv1. | 32 |
| conv2_channels | int | Number of output channels for conv2. | 64 |
| fc_dim | int | Hidden dimension for the first fully connected layer. | 128 |
Source code in src/ezpz/examples/fsdp.py
def __init__(
self,
num_classes: int = 10,
img_size: int = 28,
conv1_channels: int = 32,
conv2_channels: int = 64,
fc_dim: int = 128,
):
"""Initialize convolutional and fully connected layers.
Args:
num_classes: Number of output classes for the classifier.
img_size: Input image size (assumes square inputs).
conv1_channels: Number of output channels for conv1.
conv2_channels: Number of output channels for conv2.
fc_dim: Hidden dimension for the first fully connected layer.
"""
super().__init__()
self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
feature_size = self._feature_size(img_size, conv2_channels)
self.fc1 = nn.Linear(feature_size, fc_dim)
self.fc2 = nn.Linear(fc_dim, num_classes)
forward(x)
⚓︎
Compute per-class log-probabilities (log-softmax of the logits) for input images.
Source code in src/ezpz/examples/fsdp.py
def forward(self, x):
"""Compute logits for input images."""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
fsdp_main(args)
⚓︎
Main training loop orchestrating data, model, and logging.
Source code in src/ezpz/examples/fsdp.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def fsdp_main(args: argparse.Namespace) -> None:
    """Run the FSDP example end to end: setup, per-epoch train/eval, reporting.

    Args:
        args: Parsed CLI namespace. This function reads ``seed``, ``epochs``
            and ``save_model`` directly and forwards the whole namespace to
            ``get_data`` and ``prepare_model_optimizer_and_scheduler``.
    """
    t0 = time.perf_counter()
    rank = ezpz.setup_torch(seed=args.seed)
    t_setup = time.perf_counter()

    data = get_data(args)
    ezpz.distributed.barrier()
    train_loader = data["train"]["loader"]
    test_loader = data["test"]["loader"]

    tmp = prepare_model_optimizer_and_scheduler(args)
    model = tmp["model"]
    optimizer = tmp["optimizer"]
    scheduler = tmp["scheduler"]
    # `or 0` also covers a stored estimate of None, which would otherwise
    # make the `> 0` comparison below raise.
    _model_flops = tmp.get("model_flops") or 0

    outdir = get_example_outdir(WBPROJ_NAME)
    logger.info("Outputs will be saved to %s", outdir)
    metrics_path = outdir.joinpath(f"metrics-{rank}.jsonl")
    outdir.mkdir(parents=True, exist_ok=True)
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=(rank == 0),
        jsonl_path=metrics_path,
        project_name=WBPROJ_NAME,
        config={"args": vars(args), **ezpz.get_dist_info()},
        # Only enabled for multi-rank runs of at most 384 ranks.
        distributed_history=(1 < ezpz.get_world_size() <= 384),
    )

    start = time.perf_counter()
    for epoch in range(1, args.epochs + 1):
        train_metrics = train(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            sampler=data["train"]["sampler"],
        )
        test_metrics = test(model, test_loader)
        scheduler.step()
        merged = {**train_metrics, **test_metrics}
        if _model_flops > 0:
            # FSDP epoch loop reports per-epoch averages, so MFU here
            # is averaged over the whole epoch (epoch_dt / num_batches),
            # not per-step. Smooths out warmup spikes but obscures
            # straggler effects compared to the per-step MFU other
            # examples report.
            dt_step = merged.get("dt_per_step", 0.0)
            if dt_step > 0:
                merged["tflops"] = _model_flops / dt_step / 1e12
                merged["mfu"] = compute_mfu(_model_flops, dt_step)
        # Device memory: per-EPOCH peak (we don't reset between batches),
        # so mem_peak_* here captures the high-water mark across the
        # whole training+eval epoch.
        merged |= ezpz.get_memory_metrics()
        logger.info(history.update(merged))
    train_end = time.perf_counter()

    # The loop above runs exactly `args.epochs` iterations (1..epochs).
    logger.info("%d epochs took %.1fs", args.epochs, train_end - start)

    timings = {
        "main/setup_torch": t_setup - t0,
        "main/train": train_end - start,
        "main/total": train_end - t0,
        "timings/training_start": start - t0,
        "timings/train_duration": train_end - start,
        "timings/end-to-end": train_end - t0,
    }
    logger.info("Timings: %s", timings)
    history.tracker.log(timings)

    # Wait for every rank to finish before saving and finalizing.
    ezpz.distributed.barrier()
    if args.save_model:
        # No second barrier is needed here: the one directly above has just
        # run. state_dict() is called on every rank; only rank 0 writes.
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")
    if rank == 0:
        # The returned dataset is not needed here, so it is discarded.
        history.finalize(
            run_name=WBPROJ_NAME,
            dataset_fname="train",
        )
get_data(args)
⚓︎
Load train/test datasets according to args.dataset.
Source code in src/ezpz/examples/fsdp.py
def get_data(args: argparse.Namespace) -> dict:
    """Build the train/test data bundle for the dataset named in ``args``.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``data_prefix``,
            ``batch_size``, ``test_batch_size`` and ``num_workers``.

    Returns:
        Whatever the selected ``ezpz.data.vision`` builder returns.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Default cache location: <cwd>/.cache/ezpz/data/<dataset-lowercase>.
    default_root = Path(os.getcwd()).joinpath(
        ".cache", "ezpz", "data", args.dataset.lower()
    )
    data_root = args.data_prefix or default_root

    # Builders are imported lazily so only the selected dataset's code loads.
    name = args.dataset
    extra_kwargs = {}
    if name == "MNIST":
        from ezpz.data.vision import get_mnist as build_dataset
    elif name == "ImageNet1k":
        from ezpz.data.vision import get_imagenet1k as build_dataset
    elif name == "OpenImages":
        from ezpz.data.vision import get_openimages as build_dataset

        extra_kwargs = {"shuffle": False}
    elif name == "ImageNet":
        from ezpz.data.vision import get_imagenet as build_dataset

        extra_kwargs = {"shuffle": False}
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    return build_dataset(
        outdir=Path(data_root),
        train_batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
        **extra_kwargs,
    )
parse_args(argv=None)
⚓︎
CLI parser for the FSDP example.
Source code in src/ezpz/examples/fsdp.py
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the FSDP example.

    Args:
        argv: Argument list to parse; ``sys.argv[1:]`` is used when ``None``.

    Returns:
        The parsed namespace after ``apply_model_preset`` has been applied.
    """
    cli_args = sys.argv[1:] if argv is None else argv

    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example using FSDP",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = [
        (
            "--num-workers",
            {
                "type": int,
                "default": 0,
                "metavar": "N",
                "help": "number of data loading workers",
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "default": "MNIST",
                "choices": ["MNIST", "OpenImages", "ImageNet", "ImageNet1k"],
                "help": "Dataset to use",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 64,
                "metavar": "N",
                "help": "input batch size for training",
            },
        ),
        (
            "--model",
            {
                "type": str,
                "default": None,
                "choices": sorted(MODEL_PRESETS.keys()),
                "help": "Model size preset (overrides conv/fc defaults)",
            },
        ),
        (
            "--conv1-channels",
            {
                "type": int,
                "default": 32,
                "metavar": "N",
                "help": "Number of output channels in conv1",
            },
        ),
        (
            "--conv2-channels",
            {
                "type": int,
                "default": 64,
                "metavar": "N",
                "help": "Number of output channels in conv2",
            },
        ),
        (
            "--fc-dim",
            {
                "type": int,
                "default": 128,
                "metavar": "N",
                "help": "Hidden dimension for the first linear layer",
            },
        ),
        (
            "--dtype",
            {
                "type": str,
                "default": "bf16",
                "metavar": "D",
                "help": "Datatype for training",
            },
        ),
        (
            "--test-batch-size",
            {
                "type": int,
                "default": 1000,
                "metavar": "N",
                "help": "input batch size for testing",
            },
        ),
        (
            "--epochs",
            {
                "type": int,
                "default": 10,
                "metavar": "N",
                "help": "number of epochs to train",
            },
        ),
        (
            "--lr",
            {
                "type": float,
                "default": 1e-3,
                "metavar": "LR",
                "help": "learning rate",
            },
        ),
        (
            "--gamma",
            {
                "type": float,
                "default": 0.7,
                "metavar": "M",
                "help": "Learning rate step gamma",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": None,
                "metavar": "S",
                "help": "random seed",
            },
        ),
        (
            "--save-model",
            {
                "action": "store_true",
                "default": False,
                "help": "For Saving the current Model",
            },
        ),
        (
            "--data-prefix",
            {
                "type": str,
                "required": False,
                "default": None,
                "help": "data directory prefix",
            },
        ),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args(cli_args)
    # The raw argument list is passed along with the parsed namespace.
    apply_model_preset(parsed, cli_args)
    return parsed
prepare_model_optimizer_and_scheduler(args)
⚓︎
Create the FSDP-wrapped model, optimizer, and LR scheduler.
Source code in src/ezpz/examples/fsdp.py
def prepare_model_optimizer_and_scheduler(args: argparse.Namespace) -> dict:
    """Build the FSDP-wrapped ``Net``, its AdamW optimizer, and a StepLR scheduler.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``conv1_channels``,
            ``conv2_channels``, ``fc_dim``, ``batch_size``, ``dtype``, ``lr``
            and ``gamma``.

    Returns:
        Dict with keys ``"model"``, ``"optimizer"``, ``"scheduler"`` and
        ``"model_flops"`` (the result of ``try_estimate`` on the unwrapped
        model).

    Raises:
        ValueError: If ``args.dataset`` is not supported.
        KeyError: If ``args.dtype`` is not one of the known dtype names.
    """
    backend = ezpz.distributed.get_torch_device_type()
    if backend == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(f"{backend}:{ezpz.distributed.get_local_rank()}")

    # dataset name -> (number of classes, square image side length)
    dataset_shapes = {
        "MNIST": (10, 28),
        "OpenImages": (600, 224),
        "ImageNet": (1000, 224),
        "ImageNet1k": (1000, 224),
    }
    if args.dataset not in dataset_shapes:
        raise ValueError(f"Unsupported dataset: {args.dataset}")
    num_classes, img_size = dataset_shapes[args.dataset]

    net = Net(
        num_classes=num_classes,
        img_size=img_size,
        conv1_channels=args.conv1_channels,
        conv2_channels=args.conv2_channels,
        fc_dim=args.fc_dim,
    )
    model = net.to(device)
    summary = summarize_model(model, verbose=False, depth=2)
    logger.info(f"\n{summary}")

    # The estimate uses the plain model, before it is wrapped in FSDP.
    input_shape = (args.batch_size, 1, img_size, img_size)
    flops_estimate = try_estimate(model, input_shape)

    precision_by_name = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
    }
    param_dtype = precision_by_name[args.dtype]
    precision_policy = MixedPrecision(
        param_dtype=param_dtype,
        cast_forward_inputs=True,
    )
    model = FSDP(model, device_id=device, mixed_precision=precision_policy)

    # The optimizer is built from the wrapped model's parameters.
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    logger.info(f"{model=}")
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    return {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "model_flops": flops_estimate,
    }
test(model, test_loader)
⚓︎
Evaluate model on validation data and gather metrics.
Source code in src/ezpz/examples/fsdp.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def test(model, test_loader):
    """Evaluate ``model`` over ``test_loader`` and reduce the metrics across ranks.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(batch, target)`` pairs.

    Returns:
        Dict with ``"test_loss"`` (summed NLL loss divided by sample count)
        and ``"test_acc"`` (percentage correct), both computed from totals
        summed across ranks.
    """
    backend = ezpz.distributed.get_torch_device_type()
    if backend == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(f"{backend}:{ezpz.distributed.get_local_rank()}")

    model.eval()

    # Slots of the accumulator tensor that is all-reduced below.
    loss_idx, correct_idx, count_idx = 0, 1, 2
    totals = torch.zeros(3).to(device)

    with torch.no_grad():
        for batch, target in test_loader:
            batch = batch.to(device)
            target = target.to(device)
            output = model(batch)
            totals[loss_idx] += F.nll_loss(output, target, reduction="sum")
            # Predicted class is the index of the largest log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            totals[correct_idx] += pred.eq(target.view_as(pred)).sum()
            totals[count_idx] += len(batch)

    dist.all_reduce(totals, op=dist.ReduceOp.SUM)  # type:ignore

    sample_count = totals[count_idx]
    return {
        "test_loss": totals[loss_idx] / sample_count,
        "test_acc": 100.0 * totals[correct_idx] / sample_count,
    }
train(model, train_loader, optimizer, epoch, sampler=None)
⚓︎
One epoch of training and loss aggregation across ranks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module \| DistributedDataParallel \| FullyShardedDataParallel | Wrapped model (DDP/FSDP). | required |
| train_loader | DataLoader | Dataloader for training set. | required |
| optimizer | Optimizer | Optimizer instance. | required |
| epoch | int | Current epoch index. | required |
| sampler | DistributedSampler \| None | Optional distributed sampler to set epoch. | None |
Returns:
| Type | Description |
|---|---|
| dict | Dict with epoch, wall-clock duration, mean time per step, and averaged train loss. |
Source code in src/ezpz/examples/fsdp.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
    model: nn.Module | DistributedDataParallel | FSDP,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    sampler: DistributedSampler | None = None,
) -> dict:
    """One epoch of training and loss aggregation across ranks.

    Args:
        model: Wrapped model (DDP/FSDP).
        train_loader: Dataloader for training set.
        optimizer: Optimizer instance.
        epoch: Current epoch index.
        sampler: Optional distributed sampler to set epoch.

    Returns:
        Dict with ``"epoch"``, ``"dt"`` (wall-clock seconds for the epoch on
        this rank), ``"dt_per_step"`` (``dt`` divided by the number of
        batches, or by 1 if there were none) and ``"train_loss"`` (summed
        NLL loss divided by sample count, both totals summed across ranks).
    """
    device_type = ezpz.distributed.get_torch_device_type()
    device = (
        torch.device("cpu")
        if device_type == "cpu"
        else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
    )
    model.train()
    # ddp_loss[0] accumulates the summed loss, ddp_loss[1] the sample count.
    ddp_loss = torch.zeros(2).to(device)
    # Explicit None check: an object that defines __len__ is falsy when
    # empty, which would silently skip set_epoch().
    if sampler is not None:
        sampler.set_epoch(epoch)
    ezpz.distributed.synchronize()
    t0 = time.perf_counter()
    num_batches = 0
    # Iterate the loader exactly once. A throw-away `next(iter(loader))`
    # here would fetch and discard a batch inside the timed region and raise
    # StopIteration on an empty loader.
    for batch, target in train_loader:
        batch, target = batch.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(batch)
        loss = F.nll_loss(output, target, reduction="sum")
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(batch)
        num_batches += 1
    ezpz.distributed.synchronize()
    t1 = time.perf_counter()
    epoch_dt = t1 - t0
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)  # type:ignore
    return {
        "epoch": epoch,
        "dt": epoch_dt,
        # max(..., 1) avoids a division by zero when no batch was processed.
        "dt_per_step": epoch_dt / max(num_batches, 1),
        "train_loss": ddp_loss[0] / ddp_loss[1],
    }