Skip to content

ezpz.examples.fsdp⚓︎

FSDP training example on MNIST/OpenImages/ImageNet-style datasets.

Launch with:

ezpz launch -m ezpz.examples.fsdp --dataset MNIST --batch-size 128

Help output (python3 -m ezpz.examples.fsdp --help):

usage: fsdp.py [-h] [--num-workers N]
               [--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}]
               [--batch-size N] [--model MODEL] [--conv1-channels N]
               [--conv2-channels N] [--fc-dim N] [--dtype D]
               [--test-batch-size N] [--epochs N] [--lr LR] [--gamma M]
               [--seed S] [--save-model] [--data-prefix DATA_PREFIX]

PyTorch MNIST Example using FSDP

options:
  -h, --help            show this help message and exit
  --num-workers N       number of data loading workers (default: 0)
  --dataset {MNIST,OpenImages,ImageNet,ImageNet1k}
                        Dataset to use (default: MNIST)
  --batch-size N        input batch size for training (default: 64)
  --model MODEL         Model size preset (overrides conv/fc defaults)
  --conv1-channels N    Number of output channels in conv1 (default: 32)
  --conv2-channels N    Number of output channels in conv2 (default: 64)
  --fc-dim N            Hidden dimension for the first linear layer
                        (default: 128)
  --dtype D             Datatype for training (default=bf16).
  --test-batch-size N   input batch size for testing (default: 1000)
  --epochs N            number of epochs to train (default: 10)
  --lr LR               learning rate (default: 1e-3)
  --gamma M             Learning rate step gamma (default: 0.7)
  --seed S              random seed (default: None)
  --save-model          For Saving the current Model
  --data-prefix DATA_PREFIX
                        data directory prefix

Net ⚓︎

Bases: Module

Simple CNN classifier used in the FSDP example.

Source code in src/ezpz/examples/fsdp.py
class Net(nn.Module):
    """Simple CNN classifier used in the FSDP example.

    Two 3x3 valid (no-padding) convolutions, a 2x2 max-pool, dropout,
    and two fully connected layers, ending in ``log_softmax`` so the
    output pairs with ``F.nll_loss``.
    """

    def __init__(
        self,
        num_classes: int = 10,
        img_size: int = 28,
        conv1_channels: int = 32,
        conv2_channels: int = 64,
        fc_dim: int = 128,
        in_channels: int = 1,
    ):
        """Initialize convolutional and fully connected layers.

        Args:
            num_classes: Number of output classes for the classifier.
            img_size: Input image size (assumes square inputs).
            conv1_channels: Number of output channels for conv1.
            conv2_channels: Number of output channels for conv2.
            fc_dim: Hidden dimension for the first fully connected layer.
            in_channels: Number of input image channels (1 for grayscale
                MNIST, 3 for RGB datasets). Defaults to 1 for backward
                compatibility with existing callers.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, conv1_channels, 3, 1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        feature_size = self._feature_size(img_size, conv2_channels)
        self.fc1 = nn.Linear(feature_size, fc_dim)
        self.fc2 = nn.Linear(fc_dim, num_classes)

    @staticmethod
    def _feature_size(img_size: int, conv2_channels: int) -> int:
        """Flattened feature count after conv1, conv2, and the 2x2 pool."""
        conv1_size = img_size - 2  # 3x3 conv, stride 1, no padding
        conv2_size = conv1_size - 2
        pooled_size = conv2_size // 2  # 2x2 max-pool
        return conv2_channels * pooled_size * pooled_size

    def forward(self, x):
        """Compute log-probabilities (log_softmax logits) for input images."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

__init__(num_classes=10, img_size=28, conv1_channels=32, conv2_channels=64, fc_dim=128) ⚓︎

Initialize convolutional and fully connected layers.

Parameters:

Name Type Description Default
num_classes int

Number of output classes for the classifier.

10
img_size int

Input image size (assumes square inputs).

28
conv1_channels int

Number of output channels for conv1.

32
conv2_channels int

Number of output channels for conv2.

64
fc_dim int

Hidden dimension for the first fully connected layer.

128
Source code in src/ezpz/examples/fsdp.py
def __init__(
    self,
    num_classes: int = 10,
    img_size: int = 28,
    conv1_channels: int = 32,
    conv2_channels: int = 64,
    fc_dim: int = 128,
):
    """Build the network's layers.

    Args:
        num_classes: Number of output classes for the classifier.
        img_size: Side length of the (square) input images.
        conv1_channels: Output channels of the first convolution.
        conv2_channels: Output channels of the second convolution.
        fc_dim: Width of the first fully connected layer.
    """
    super().__init__()
    # Two 3x3, stride-1, no-padding convolutions (classic MNIST CNN).
    self.conv1 = nn.Conv2d(1, conv1_channels, kernel_size=3, stride=1)
    self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, kernel_size=3, stride=1)
    self.dropout1 = nn.Dropout(p=0.25)
    self.dropout2 = nn.Dropout(p=0.5)
    # Flattened size after both convs and the 2x2 max-pool in forward().
    flat_features = self._feature_size(img_size, conv2_channels)
    self.fc1 = nn.Linear(flat_features, fc_dim)
    self.fc2 = nn.Linear(fc_dim, num_classes)

forward(x) ⚓︎

Compute logits for input images.

Source code in src/ezpz/examples/fsdp.py
def forward(self, x):
    """Return per-class log-probabilities for a batch of images."""
    # Convolutional feature extractor: conv -> relu, twice, then pool.
    h = F.relu(self.conv1(x))
    h = F.relu(self.conv2(h))
    h = F.max_pool2d(h, 2)
    h = self.dropout1(h)
    # Classifier head on the flattened features.
    h = torch.flatten(h, 1)
    h = F.relu(self.fc1(h))
    h = self.dropout2(h)
    logits = self.fc2(h)
    return F.log_softmax(logits, dim=1)

fsdp_main(args) ⚓︎

Main training loop orchestrating data, model, and logging.

Source code in src/ezpz/examples/fsdp.py
def fsdp_main(args: argparse.Namespace) -> None:
    """Main training loop orchestrating data, model, and logging.

    Sets up distributed torch (and, on rank 0, wandb), builds the
    dataloaders and FSDP-wrapped model, runs ``args.epochs`` train/test
    epochs, and optionally saves the model weights from rank 0.

    Args:
        args: Parsed CLI namespace from ``parse_args``.
    """
    rank = ezpz.setup_torch(seed=args.seed)
    # Broadcast rank 0's timestamp so every rank writes into the same outdir.
    START_TIME = ezpz.get_timestamp() if ezpz.get_rank() == 0 else None
    START_TIME = ezpz.dist.broadcast(START_TIME, root=0)
    if rank == 0:
        fp = Path(__file__)
        run = ezpz.setup_wandb(project_name=f"ezpz.{fp.parent.stem}.{fp.stem}")
        if run is not None and wandb is not None and run is wandb.run:
            run.config.update({**vars(args)})
            run.config.update({"ezpz.dist": {**ezpz.get_dist_info()}})

    data = get_data(args)
    ezpz.dist.barrier()
    train_loader = data["train"]["loader"]
    test_loader = data["test"]["loader"]

    tmp = prepare_model_optimizer_and_scheduler(args)
    model = tmp["model"]
    optimizer = tmp["optimizer"]
    scheduler = tmp["scheduler"]

    # NOTE(review): ``fname`` is not defined anywhere in this function;
    # presumably a module-level constant — confirm it exists at module scope.
    outdir = Path(os.getcwd()).joinpath("outputs", fname, START_TIME)
    metrics_path = outdir.joinpath(f"metrics-{rank}.jsonl")
    outdir.mkdir(parents=True, exist_ok=True)
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=(rank == 0),
        jsonl_path=metrics_path,
        # Only aggregate across ranks for modest world sizes.
        distributed_history=(1 < ezpz.get_world_size() <= 384),
    )
    start = time.perf_counter()
    for epoch in range(1, args.epochs + 1):
        train_metrics = train(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            sampler=data["train"]["sampler"],
        )
        test_metrics = test(model, test_loader)
        scheduler.step()
        logger.info(history.update({**train_metrics, **test_metrics}))

    # The loop above runs exactly ``args.epochs`` epochs (1..epochs
    # inclusive); the old message reported ``args.epochs + 1``.
    logger.info(
        " ".join(
            [
                f"{args.epochs} epochs took",
                f"{time.perf_counter() - start:.1f}s",
            ]
        )
    )
    ezpz.dist.barrier()

    if args.save_model:
        ezpz.dist.barrier()  # wait for slowpokes
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    if rank == 0:
        dataset = history.finalize(run_name="ezpz-fsdp", dataset_fname="train")
        logger.info(f"{dataset=}")

get_data(args) ⚓︎

Load train/test datasets according to args.dataset.

Source code in src/ezpz/examples/fsdp.py
def get_data(args: argparse.Namespace) -> dict:
    """Load train/test datasets according to ``args.dataset``.

    Args:
        args: Parsed CLI namespace; reads ``dataset``, ``data_prefix``,
            ``batch_size``, ``test_batch_size``, and ``num_workers``.

    Returns:
        Dict (from ``ezpz.data.vision``) with "train" and "test" entries,
        each containing at least a "loader" (and, for train, a "sampler").

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """
    # Default to a per-dataset cache directory under the current working dir.
    data_prefix_fallback = Path(os.getcwd()).joinpath(
        ".cache", "ezpz", "data", f"{args.dataset.lower()}"
    )
    data_prefix = args.data_prefix or data_prefix_fallback
    if args.dataset == "MNIST":
        from ezpz.data.vision import get_mnist

        data = get_mnist(
            outdir=Path(data_prefix),
            train_batch_size=args.batch_size,
            test_batch_size=args.test_batch_size,
            pin_memory=True,
            num_workers=args.num_workers,
        )
    elif args.dataset == "ImageNet1k":
        from ezpz.data.vision import get_imagenet1k

        data = get_imagenet1k(
            outdir=Path(data_prefix),
            train_batch_size=args.batch_size,
            test_batch_size=args.test_batch_size,
            pin_memory=True,
            num_workers=args.num_workers,
        )
    elif args.dataset == "OpenImages":
        from ezpz.data.vision import get_openimages

        data = get_openimages(
            outdir=Path(data_prefix),
            train_batch_size=args.batch_size,
            test_batch_size=args.test_batch_size,
            # shuffle=False here but not for MNIST/ImageNet1k — presumably a
            # constraint of these loaders; confirm in ezpz.data.vision.
            shuffle=False,
            pin_memory=True,
            num_workers=args.num_workers,
        )
    elif args.dataset == "ImageNet":
        from ezpz.data.vision import get_imagenet

        data = get_imagenet(
            outdir=Path(data_prefix),
            train_batch_size=args.batch_size,
            test_batch_size=args.test_batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=args.num_workers,
        )
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    return data

parse_args(argv=None) ⚓︎

CLI parser for the FSDP example.

Source code in src/ezpz/examples/fsdp.py
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """CLI parser for the FSDP example.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.

    Returns:
        Parsed namespace, after ``apply_model_preset`` has applied any
        ``--model`` size-preset overrides.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(
        description="PyTorch MNIST Example using FSDP"
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        metavar="N",
        # Help text used to claim "(default: 4)"; the real default is 0.
        help="number of data loading workers (default: 0)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="MNIST",
        choices=["MNIST", "OpenImages", "ImageNet", "ImageNet1k"],
        help="Dataset to use (default: MNIST)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        choices=sorted(MODEL_PRESETS.keys()),
        help="Model size preset (overrides conv/fc defaults)",
    )
    parser.add_argument(
        "--conv1-channels",
        type=int,
        default=32,
        metavar="N",
        help="Number of output channels in conv1",
    )
    parser.add_argument(
        "--conv2-channels",
        type=int,
        default=64,
        metavar="N",
        help="Number of output channels in conv2",
    )
    parser.add_argument(
        "--fc-dim",
        type=int,
        default=128,
        metavar="N",
        help="Hidden dimension for the first linear layer",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bf16",
        metavar="D",
        help="Datatype for training (default=bf16).",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        metavar="LR",
        help="learning rate (default: 1e-3)",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        metavar="S",
        # Help text used to claim "(default: 1)"; the real default is None.
        help="random seed (default: None)",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument(
        "--data-prefix",
        type=str,
        required=False,
        default=None,
        help="data directory prefix",
    )
    args = parser.parse_args(argv)
    # Applies --model preset overrides to args in place.
    apply_model_preset(args, argv)
    return args

prepare_model_optimizer_and_scheduler(args) ⚓︎

Create the FSDP-wrapped model, optimizer, and LR scheduler.

Source code in src/ezpz/examples/fsdp.py
def prepare_model_optimizer_and_scheduler(args: argparse.Namespace) -> dict:
    """Create the FSDP-wrapped model, optimizer, and LR scheduler.

    Args:
        args: Parsed CLI namespace; reads ``dataset``, ``dtype``,
            ``conv1_channels``, ``conv2_channels``, ``fc_dim``, ``lr``,
            and ``gamma``.

    Returns:
        Dict with "model" (FSDP-wrapped), "optimizer" (AdamW), and
        "scheduler" (StepLR).

    Raises:
        ValueError: If ``args.dataset`` or ``args.dtype`` is unsupported.
    """
    device_type = ezpz.dist.get_torch_device_type()
    device = (
        torch.device("cpu")
        if device_type == "cpu"
        else torch.device(f"{device_type}:{ezpz.dist.get_local_rank()}")
    )
    # (num_classes, img_size) per supported dataset.
    dataset_specs = {
        "MNIST": (10, 28),
        "OpenImages": (600, 224),
        "ImageNet": (1000, 224),
        "ImageNet1k": (1000, 224),
    }
    try:
        num_classes, img_size = dataset_specs[args.dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {args.dataset}") from None
    # NOTE(review): Net's first conv takes a single input channel, while the
    # 224px datasets are presumably RGB — confirm the upstream transforms
    # produce single-channel tensors for those datasets.
    model = Net(
        num_classes=num_classes,
        img_size=img_size,
        conv1_channels=args.conv1_channels,
        conv2_channels=args.conv2_channels,
        fc_dim=args.fc_dim,
    ).to(device)
    logger.info(f"\n{summarize_model(model, verbose=False, depth=2)}")
    dtypes = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
    }
    dtype = dtypes.get(args.dtype)
    if dtype is None:
        # Fail with a clear message instead of a bare KeyError, matching the
        # ValueError style used for unsupported datasets above.
        raise ValueError(
            f"Unsupported dtype: {args.dtype} (choose from {sorted(dtypes)})"
        )
    model = FSDP(
        model,
        device_id=device,
        mixed_precision=MixedPrecision(
            param_dtype=dtype,
            cast_forward_inputs=True,
        ),
    )
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    logger.info(f"{model=}")
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    return {
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler,
    }

test(model, test_loader) ⚓︎

Evaluate model on validation data and gather metrics.

Source code in src/ezpz/examples/fsdp.py
def test(model, test_loader):
    """Evaluate model on validation data and gather metrics."""
    device_type = ezpz.dist.get_torch_device_type()
    if device_type == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(f"{device_type}:{ezpz.dist.get_local_rank()}")
    model.eval()
    # [summed NLL, n correct, n examples], accumulated locally then reduced.
    totals = torch.zeros(3).to(device)
    with torch.no_grad():
        for batch, target in test_loader:
            batch = batch.to(device)
            target = target.to(device)
            logits = model(batch)
            totals[0] += F.nll_loss(logits, target, reduction="sum")
            # Predicted class = index of the max log-probability.
            predictions = logits.argmax(dim=1, keepdim=True)
            totals[1] += predictions.eq(target.view_as(predictions)).sum()
            totals[2] += len(batch)

    # Sum per-rank totals so every rank reports global metrics.
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)  # type:ignore

    return {
        "test_loss": totals[0] / totals[2],
        "test_acc": 100.0 * totals[1] / totals[2],
    }

train(model, train_loader, optimizer, epoch, sampler=None) ⚓︎

One epoch of training and loss aggregation across ranks.

Parameters:

Name Type Description Default
model Module | DistributedDataParallel | FullyShardedDataParallel

Wrapped model (DDP/FSDP).

required
train_loader DataLoader

Dataloader for training set.

required
optimizer Optimizer

Optimizer instance.

required
epoch int

Current epoch index.

required
sampler DistributedSampler | None

Optional distributed sampler to set epoch.

None

Returns:

Type Description
dict

Dict with epoch, wall-clock duration, and averaged train loss.

Source code in src/ezpz/examples/fsdp.py
def train(
    model: nn.Module | DistributedDataParallel | FSDP,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    sampler: DistributedSampler | None = None,
) -> dict:
    """One epoch of training and loss aggregation across ranks.

    Args:
        model: Wrapped model (DDP/FSDP).
        train_loader: Dataloader for training set.
        optimizer: Optimizer instance.
        epoch: Current epoch index.
        sampler: Optional distributed sampler to set epoch.

    Returns:
        Dict with epoch, wall-clock duration, and averaged train loss.
    """
    device_type = ezpz.dist.get_torch_device_type()
    device = (
        torch.device("cpu")
        if device_type == "cpu"
        else torch.device(f"{device_type}:{ezpz.dist.get_local_rank()}")
    )
    model.train()
    # [summed loss, n examples] on-device for a single all_reduce at the end.
    ddp_loss = torch.zeros(2).to(device)
    if sampler:
        # Reshuffle deterministically per epoch across ranks.
        sampler.set_epoch(epoch)
    ezpz.dist.synchronize()
    t0 = time.perf_counter()
    # (Removed a stray ``next(iter(train_loader))`` that fetched and
    # immediately discarded an extra batch before the loop.)
    for batch, target in train_loader:
        batch, target = batch.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(batch)
        loss = F.nll_loss(output, target, reduction="sum")
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(batch)
    ezpz.dist.synchronize()
    t1 = time.perf_counter()
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)  # type:ignore
    return {
        "epoch": epoch,
        "dt": t1 - t0,
        "train_loss": ddp_loss[0] / ddp_loss[1],
    }