ezpz.examples.fsdp⚓︎
FSDP training example on MNIST/OpenImages/ImageNet-style datasets.
Launch with:
ezpz launch -m ezpz.examples.fsdp --dataset MNIST --batch-size 128
Help output (python3 -m ezpz.examples.fsdp --help):
usage: fsdp.py [-h] [--num-workers N]
[--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}]
[--batch-size N] [--dtype D] [--test-batch-size N] [--epochs N]
[--lr LR] [--gamma M] [--seed S] [--save-model]
[--data-prefix DATA_PREFIX]
PyTorch MNIST Example using FSDP
options:
-h, --help show this help message and exit
--num-workers N number of data loading workers (default: 0)
--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}
Dataset to use (default: MNIST)
--batch-size N input batch size for training (default: 64)
--dtype D Datatype for training (default=bf16).
--test-batch-size N input batch size for testing (default: 1000)
--epochs N number of epochs to train (default: 10)
--lr LR learning rate (default: 1e-3)
--gamma M Learning rate step gamma (default: 0.7)
--seed S random seed (default: None)
--save-model For Saving the current Model
--data-prefix DATA_PREFIX
data directory prefix
Net
⚓︎
Bases: Module
Simple CNN classifier used in the FSDP example.
Source code in src/ezpz/examples/fsdp.py
class Net(nn.Module):
    """Small convolutional classifier used in the FSDP example.

    Two valid (no-padding) 3x3 convolutions, a 2x2 max pool, dropout,
    and two fully connected layers producing per-class log-probabilities.
    """

    def __init__(
        self,
        num_classes: int = 10,
        img_size: int = 28,
        conv1_channels: int = 32,
        conv2_channels: int = 64,
        fc_dim: int = 128,
    ):
        """Build the convolutional and fully connected layers.

        Args:
            num_classes: Number of output classes for the classifier.
            img_size: Side length of the (square) input images.
            conv1_channels: Output channels of the first conv layer.
            conv2_channels: Output channels of the second conv layer.
            fc_dim: Width of the first fully connected layer.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        flat_features = self._feature_size(img_size, conv2_channels)
        self.fc1 = nn.Linear(flat_features, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    @staticmethod
    def _feature_size(img_size: int, conv2_channels: int) -> int:
        # Each valid 3x3 conv shrinks the spatial size by 2 (two convs
        # total -4); the 2x2 max pool then halves it (floor division).
        pooled = (img_size - 4) // 2
        return conv2_channels * pooled * pooled

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.conv2(h))
        h = self.dropout1(F.max_pool2d(h, 2))
        h = torch.flatten(h, 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        # Log-probabilities pair with F.nll_loss during training.
        return F.log_softmax(self.fc2(h), dim=1)
__init__(num_classes=10, img_size=28, conv1_channels=32, conv2_channels=64, fc_dim=128)
⚓︎
Initialize convolutional and fully connected layers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
num_classes
|
int
|
Number of output classes for the classifier. |
10
|
img_size
|
int
|
Input image size (assumes square inputs). |
28
|
conv1_channels
|
int
|
Number of output channels for conv1. |
32
|
conv2_channels
|
int
|
Number of output channels for conv2. |
64
|
fc_dim
|
int
|
Hidden dimension for the first fully connected layer. |
128
|
Source code in src/ezpz/examples/fsdp.py
def __init__(
    self,
    num_classes: int = 10,
    img_size: int = 28,
    conv1_channels: int = 32,
    conv2_channels: int = 64,
    fc_dim: int = 128,
):
    """Initialize convolutional and fully connected layers.
    Args:
        num_classes: Number of output classes for the classifier.
        img_size: Input image size (assumes square inputs).
        conv1_channels: Number of output channels for conv1.
        conv2_channels: Number of output channels for conv2.
        fc_dim: Hidden dimension for the first fully connected layer.
    """
    super().__init__()
    # Valid (no-padding) 3x3 convolutions, stride 1; input has 1 channel.
    self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
    self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
    # Dropout applied after pooling (0.25) and after fc1 (0.5) in forward().
    self.dropout1 = nn.Dropout(0.25)
    self.dropout2 = nn.Dropout(0.5)
    # Flattened feature count after conv1 -> conv2 -> 2x2 max-pool.
    feature_size = self._feature_size(img_size, conv2_channels)
    self.fc1 = nn.Linear(feature_size, fc_dim)
    self.fc2 = nn.Linear(fc_dim, num_classes)
forward(x)
⚓︎
Compute logits for input images.
Source code in src/ezpz/examples/fsdp.py
def forward(self, x):
    """Compute logits for input images."""
    x = self.conv1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = F.relu(x)
    # 2x2 max pooling halves the spatial dimensions.
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    # Flatten all non-batch dimensions for the linear layers.
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc2(x)
    # Log-probabilities; pairs with F.nll_loss used in train()/test().
    output = F.log_softmax(x, dim=1)
    return output
fsdp_main(args)
⚓︎
Main training loop orchestrating data, model, and logging.
Source code in src/ezpz/examples/fsdp.py
def fsdp_main(args: argparse.Namespace) -> None:
    """Main training loop orchestrating data, model, and logging.

    Args:
        args: Parsed CLI arguments (see ``parse_args``).
    """
    rank = ezpz.setup_torch(seed=args.seed)
    # Rank 0 picks the timestamp and broadcasts it so every rank writes
    # into the same output directory.
    START_TIME = ezpz.get_timestamp() if ezpz.get_rank() == 0 else None
    START_TIME = ezpz.dist.broadcast(START_TIME, root=0)
    # BUG FIX: `fp` was only bound inside the `rank == 0` branch and
    # `fname` was never defined at all, raising NameError when building
    # `outdir` below. Bind both unconditionally.
    fp = Path(__file__)
    fname = fp.stem
    if rank == 0:
        run = ezpz.setup_wandb(project_name=f"ezpz.{fp.parent.stem}.{fp.stem}")
        if run is not None and wandb is not None and run is wandb.run:
            run.config.update({**vars(args)})
            run.config.update({"ezpz.dist": {**ezpz.get_dist_info()}})
    data = get_data(args)
    ezpz.dist.barrier()
    train_loader = data["train"]["loader"]
    test_loader = data["test"]["loader"]
    tmp = prepare_model_optimizer_and_scheduler(args)
    model = tmp["model"]
    optimizer = tmp["optimizer"]
    scheduler = tmp["scheduler"]
    # Each rank writes its own metrics file under a shared run directory.
    outdir = Path(os.getcwd()).joinpath("outputs", fname, START_TIME)
    metrics_path = outdir.joinpath(f"metrics-{rank}.jsonl")
    outdir.mkdir(parents=True, exist_ok=True)
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=(rank == 0),
        jsonl_path=metrics_path,
        distributed_history=(
            1 < ezpz.get_world_size() <= 384  # and not config.pytorch_profiler
        ),
    )
    start = time.perf_counter()
    for epoch in range(1, args.epochs + 1):
        train_metrics = train(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            sampler=data["train"]["sampler"],
        )
        test_metrics = test(model, test_loader)
        scheduler.step()
        logger.info(history.update({**train_metrics, **test_metrics}))
    # BUG FIX: the loop runs exactly args.epochs epochs; the summary
    # previously reported args.epochs + 1.
    logger.info(
        " ".join(
            [
                f"{args.epochs} epochs took",
                f"{time.perf_counter() - start:.1f}s",
            ]
        )
    )
    ezpz.dist.barrier()
    if args.save_model:
        ezpz.dist.barrier()  # wait for slowpokes
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")
    if rank == 0:
        dataset = history.finalize(run_name="ezpz-fsdp", dataset_fname="train")
        logger.info(f"{dataset=}")
get_data(args)
⚓︎
Load train/test datasets according to args.dataset.
Source code in src/ezpz/examples/fsdp.py
def get_data(args: argparse.Namespace) -> dict:
    """Load train/test datasets according to ``args.dataset``.

    Args:
        args: Parsed CLI arguments; uses ``dataset``, ``data_prefix``,
            ``batch_size``, ``test_batch_size`` and ``num_workers``.

    Returns:
        Dict with ``train`` / ``test`` entries holding loaders and samplers.

    Raises:
        ValueError: If ``args.dataset`` is not a supported dataset name.
    """
    # Default the data directory to a per-dataset cache under the CWD.
    data_prefix_fallback = Path(os.getcwd()).joinpath(
        ".cache", "ezpz", "data", args.dataset.lower()
    )
    data_prefix = args.data_prefix or data_prefix_fallback
    # Keyword arguments shared by every dataset loader.
    loader_kwargs = dict(
        outdir=Path(data_prefix),
        train_batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        pin_memory=True,
        num_workers=args.num_workers,
    )
    if args.dataset == "MNIST":
        from ezpz.data.vision import get_mnist

        return get_mnist(**loader_kwargs)
    if args.dataset == "ImageNet1k":
        from ezpz.data.vision import get_imagenet1k

        return get_imagenet1k(**loader_kwargs)
    if args.dataset == "OpenImages":
        from ezpz.data.vision import get_openimages

        return get_openimages(shuffle=False, **loader_kwargs)
    if args.dataset == "ImageNet":
        from ezpz.data.vision import get_imagenet

        return get_imagenet(shuffle=False, **loader_kwargs)
    raise ValueError(f"Unsupported dataset: {args.dataset}")
parse_args(argv=None)
⚓︎
CLI parser for the FSDP example.
Source code in src/ezpz/examples/fsdp.py
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """CLI parser for the FSDP example.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.

    Returns:
        Parsed namespace, with any ``--model`` preset applied on top.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example using FSDP"
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        metavar="N",
        # BUG FIX: help previously claimed the default was 4.
        help="number of data loading workers (default: 0)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="MNIST",
        choices=["MNIST", "OpenImages", "ImageNet", "ImageNet1k"],
        help="Dataset to use (default: MNIST)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        choices=sorted(MODEL_PRESETS.keys()),
        help="Model size preset (overrides conv/fc defaults)",
    )
    parser.add_argument(
        "--conv1-channels",
        type=int,
        default=32,
        metavar="N",
        help="Number of output channels in conv1",
    )
    parser.add_argument(
        "--conv2-channels",
        type=int,
        default=64,
        metavar="N",
        help="Number of output channels in conv2",
    )
    parser.add_argument(
        "--fc-dim",
        type=int,
        default=128,
        metavar="N",
        help="Hidden dimension for the first linear layer",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bf16",
        metavar="D",
        help="Datatype for training (default=bf16).",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        metavar="LR",
        help="learning rate (default: 1e-3)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        metavar="S",
        # BUG FIX: help previously claimed the default was 1.
        help="random seed (default: None)",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument(
        "--data-prefix",
        type=str,
        required=False,
        default=None,
        help="data directory prefix",
    )
    args = parser.parse_args(argv)
    # Apply any --model preset on top of the explicitly-passed flags.
    apply_model_preset(args, argv)
    return args
prepare_model_optimizer_and_scheduler(args)
⚓︎
Create the FSDP-wrapped model, optimizer, and LR scheduler.
Source code in src/ezpz/examples/fsdp.py
def prepare_model_optimizer_and_scheduler(args: argparse.Namespace) -> dict:
    """Create the FSDP-wrapped model, optimizer, and LR scheduler.

    Args:
        args: Parsed CLI arguments; uses ``dataset``, ``conv1_channels``,
            ``conv2_channels``, ``fc_dim``, ``dtype``, ``lr`` and ``gamma``.

    Returns:
        Dict with ``model``, ``optimizer`` and ``scheduler`` entries.

    Raises:
        ValueError: If ``args.dataset`` or ``args.dtype`` is unsupported.
    """
    device_type = ezpz.dist.get_torch_device_type()
    device = (
        torch.device("cpu")
        if device_type == "cpu"
        else torch.device(f"{device_type}:{ezpz.dist.get_local_rank()}")
    )
    # (num_classes, img_size) per supported dataset.
    dataset_shapes = {
        "MNIST": (10, 28),
        "OpenImages": (600, 224),
        "ImageNet": (1000, 224),
        "ImageNet1k": (1000, 224),
    }
    try:
        num_classes, img_size = dataset_shapes[args.dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {args.dataset}") from None
    model = Net(
        num_classes=num_classes,
        img_size=img_size,
        conv1_channels=args.conv1_channels,
        conv2_channels=args.conv2_channels,
        fc_dim=args.fc_dim,
    ).to(device)
    logger.info(f"\n{summarize_model(model, verbose=False, depth=2)}")
    dtypes = {
        "fp16": torch.float16,
        "float16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    try:
        dtype = dtypes[args.dtype]
    except KeyError:
        # BUG FIX: previously raised a bare KeyError for unknown dtypes;
        # raise a descriptive ValueError consistent with the dataset check.
        raise ValueError(f"Unsupported dtype: {args.dtype}") from None
    model = FSDP(
        model,
        device_id=device,
        mixed_precision=MixedPrecision(
            param_dtype=dtype,
            cast_forward_inputs=True,
        ),
    )
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    logger.info(f"{model=}")
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    return {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
    }
test(model, test_loader)
⚓︎
Evaluate model on validation data and gather metrics.
Source code in src/ezpz/examples/fsdp.py
def test(model, test_loader):
    """Evaluate model on validation data and gather metrics."""
    device_type = ezpz.dist.get_torch_device_type()
    if device_type == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(f"{device_type}:{ezpz.dist.get_local_rank()}")
    model.eval()
    # totals[0]: summed NLL loss, totals[1]: correct predictions,
    # totals[2]: number of samples seen on this rank.
    totals = torch.zeros(3).to(device)
    with torch.no_grad():
        for batch, target in test_loader:
            batch, target = batch.to(device), target.to(device)
            logits = model(batch)
            totals[0] += F.nll_loss(logits, target, reduction="sum")
            # Index of the max log-probability is the predicted class.
            predictions = logits.argmax(dim=1, keepdim=True)
            totals[1] += predictions.eq(target.view_as(predictions)).sum()
            totals[2] += len(batch)
    # Aggregate loss/correct/count across all ranks before averaging.
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)  # type:ignore
    return {
        "test_loss": totals[0] / totals[2],
        "test_acc": 100.0 * totals[1] / totals[2],
    }
train(model, train_loader, optimizer, epoch, sampler=None)
⚓︎
One epoch of training and loss aggregation across ranks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model
|
Module | DistributedDataParallel | FullyShardedDataParallel
|
Wrapped model (DDP/FSDP). |
required |
train_loader
|
DataLoader
|
Dataloader for training set. |
required |
optimizer
|
Optimizer
|
Optimizer instance. |
required |
epoch
|
int
|
Current epoch index. |
required |
sampler
|
DistributedSampler | None
|
Optional distributed sampler to set epoch. |
None
|
Returns:
| Type | Description |
|---|---|
dict
|
Dict with epoch, wall-clock duration, and averaged train loss. |
Source code in src/ezpz/examples/fsdp.py
def train(
    model: nn.Module | DistributedDataParallel | FSDP,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    sampler: DistributedSampler | None = None,
) -> dict:
    """One epoch of training and loss aggregation across ranks.

    Args:
        model: Wrapped model (DDP/FSDP).
        train_loader: Dataloader for training set.
        optimizer: Optimizer instance.
        epoch: Current epoch index.
        sampler: Optional distributed sampler to set epoch.

    Returns:
        Dict with epoch, wall-clock duration, and averaged train loss.
    """
    device_type = ezpz.dist.get_torch_device_type()
    device = (
        torch.device("cpu")
        if device_type == "cpu"
        else torch.device(f"{device_type}:{ezpz.dist.get_local_rank()}")
    )
    model.train()
    # ddp_loss[0] accumulates the summed loss; ddp_loss[1] counts samples.
    ddp_loss = torch.zeros(2).to(device)
    if sampler:
        # Reshuffle shard assignment deterministically each epoch.
        sampler.set_epoch(epoch)
    ezpz.dist.synchronize()
    t0 = time.perf_counter()
    # BUG FIX: removed a stray `batch, target = next(iter(train_loader))`
    # that fetched and immediately discarded one batch per epoch.
    for batch, target in train_loader:
        batch, target = batch.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(batch)
        loss = F.nll_loss(output, target, reduction="sum")
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(batch)
    ezpz.dist.synchronize()
    t1 = time.perf_counter()
    # Sum loss and sample counts across all ranks for a global average.
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)  # type:ignore
    return {
        "epoch": epoch,
        "dt": t1 - t0,
        "train_loss": ddp_loss[0] / ddp_loss[1],
    }