Train CNN with FSDP on MNISTβοΈ
Key API Functions
setup_torch()β Initialize distributed trainingwrap_model()β Wrap model for FSDP (withuse_fsdp=True)TrainConfigβ Training configuration
See:
- π examples/FSDP
- π src/ezpz/examples/fsdp.py
SourceβοΈ
src/ezpz/examples/fsdp.py
"""FSDP training example on MNIST/OpenImages/ImageNet-style datasets.
Launch with:
ezpz launch -m ezpz.examples.fsdp --dataset MNIST --batch-size 128
Help output (``python3 -m ezpz.examples.fsdp --help``):
usage: fsdp.py [-h] [--num-workers N]
[--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}]
[--batch-size N] [--dtype D] [--test-batch-size N] [--epochs N]
[--lr LR] [--gamma M] [--seed S] [--save-model]
[--data-prefix DATA_PREFIX]
PyTorch MNIST Example using FSDP
options:
-h, --help show this help message and exit
--num-workers N number of data loading workers (default: 4)
--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}
Dataset to use (default: MNIST)
--batch-size N input batch size for training (default: 64)
--dtype D Datatype for training (default=bf16).
--test-batch-size N input batch size for testing (default: 1000)
--epochs N number of epochs to train (default: 10)
--lr LR learning rate (default: 1e-3)
--gamma M Learning rate step gamma (default: 0.7)
--seed S random seed (default: 1)
--save-model For Saving the current Model
--data-prefix DATA_PREFIX
data directory prefix
"""
# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler, DataLoader
import argparse
import os
from pathlib import Path
import sys
import time
import ezpz
# from ezpz.history import WANDB_DISABLED
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ezpz.models import summarize_model
from ezpz.examples import get_example_outdir
logger = ezpz.get_logger(__name__)
try:
import wandb
except Exception:
wandb = None # type: ignore
fp = Path(__file__)
fname = f"{fp.parent.stem}.{fp.stem}"
WBPROJ_NAME = f"ezpz.{fp.parent.stem}.{fp.stem}"
OUTPUT_DIR = Path(os.getcwd()).joinpath("outputs", fname)
MODEL_PRESETS = {
"debug": {
"conv1_channels": 8,
"conv2_channels": 16,
"fc_dim": 64,
},
"small": {
"conv1_channels": 16,
"conv2_channels": 32,
"fc_dim": 128,
},
"medium": {
"conv1_channels": 32,
"conv2_channels": 64,
"fc_dim": 256,
},
"large": {
"conv1_channels": 64,
"conv2_channels": 128,
"fc_dim": 512,
},
}
MODEL_PRESET_FLAGS = {
"conv1_channels": ["--conv1-channels"],
"conv2_channels": ["--conv2-channels"],
"fc_dim": ["--fc-dim"],
}
class Net(nn.Module):
"""Simple CNN classifier used in the FSDP example."""
def __init__(
self,
num_classes: int = 10,
img_size: int = 28,
conv1_channels: int = 32,
conv2_channels: int = 64,
fc_dim: int = 128,
):
"""Initialize convolutional and fully connected layers.
Args:
num_classes: Number of output classes for the classifier.
img_size: Input image size (assumes square inputs).
conv1_channels: Number of output channels for conv1.
conv2_channels: Number of output channels for conv2.
fc_dim: Hidden dimension for the first fully connected layer.
"""
super().__init__()
self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
feature_size = self._feature_size(img_size, conv2_channels)
self.fc1 = nn.Linear(feature_size, fc_dim)
self.fc2 = nn.Linear(fc_dim, num_classes)
@staticmethod
def _feature_size(img_size: int, conv2_channels: int) -> int:
conv1_size = img_size - 2
conv2_size = conv1_size - 2
pooled_size = conv2_size // 2
return conv2_channels * pooled_size * pooled_size
def forward(self, x):
"""Compute logits for input images."""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
model: nn.Module | DistributedDataParallel | FSDP,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
epoch: int,
sampler: DistributedSampler | None = None,
) -> dict:
"""One epoch of training and loss aggregation across ranks.
Args:
model: Wrapped model (DDP/FSDP).
train_loader: Dataloader for training set.
optimizer: Optimizer instance.
epoch: Current epoch index.
sampler: Optional distributed sampler to set epoch.
Returns:
Dict with epoch, wall-clock duration, and averaged train loss.
"""
device_type = ezpz.distributed.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
)
model.train()
ddp_loss = torch.zeros(2).to(device)
if sampler:
sampler.set_epoch(epoch)
ezpz.distributed.synchronize()
t0 = time.perf_counter()
batch, target = next(iter(train_loader))
for _, (batch, target) in enumerate(train_loader):
batch, target = batch.to(device), target.to(device)
optimizer.zero_grad()
output = model(batch)
loss = F.nll_loss(output, target, reduction="sum")
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(batch)
ezpz.distributed.synchronize()
t1 = time.perf_counter()
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) # type:ignore
return {
"epoch": epoch,
"dt": t1 - t0,
"train_loss": ddp_loss[0] / ddp_loss[1],
}
@ezpz.timeitlogit(rank=ezpz.get_rank())
def test(model, test_loader):
"""Evaluate model on validation data and gather metrics."""
device_type = ezpz.distributed.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
)
model.eval()
# correct = 0
ddp_loss = torch.zeros(3).to(device)
with torch.no_grad():
for batch, target in test_loader:
batch, target = batch.to(device), target.to(device)
output = model(batch)
ddp_loss[0] += F.nll_loss(output, target, reduction="sum")
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
ddp_loss[1] += pred.eq(target.view_as(pred)).sum()
ddp_loss[2] += len(batch)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) # type:ignore
test_loss = ddp_loss[0] / ddp_loss[2]
return {
"test_loss": test_loss,
"test_acc": 100.0 * ddp_loss[1] / ddp_loss[2],
}
def prepare_model_optimizer_and_scheduler(args: argparse.Namespace) -> dict:
"""Create the FSDP-wrapped model, optimizer, and LR scheduler."""
device_type = ezpz.distributed.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
)
if args.dataset == "MNIST":
num_classes = 10
img_size = 28
elif args.dataset == "OpenImages":
num_classes = 600
img_size = 224
elif args.dataset == "ImageNet":
num_classes = 1000
img_size = 224
elif args.dataset == "ImageNet1k":
num_classes = 1000
img_size = 224
else:
raise ValueError(f"Unsupported dataset: {args.dataset}")
model = Net(
num_classes=num_classes,
img_size=img_size,
conv1_channels=args.conv1_channels,
conv2_channels=args.conv2_channels,
fc_dim=args.fc_dim,
).to(device)
logger.info(f"\n{summarize_model(model, verbose=False, depth=2)}")
dtypes = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
"fp32": torch.float32,
}
dtype = dtypes[args.dtype]
model = FSDP(
model,
device_id=device,
mixed_precision=MixedPrecision(
param_dtype=dtype,
cast_forward_inputs=True,
),
)
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
logger.info(f"{model=}")
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
return {
"model": model,
"optimizer": optimizer,
"scheduler": scheduler,
}
def get_data(args: argparse.Namespace) -> dict:
"""Load train/test datasets according to args.dataset."""
# data_prefix_fallback = Path(os.getcwd()).joinpath(
# ".cache", "ezpz", "data", f"{args.dataset.lower()}"
# )
# data_prefix = args.data_prefix or data_prefix_fallback
# if args.dataset == "MNIST":
# from ezpz.data.vision import get_mnist
#
# data = get_mnist(
# outdir=Path(data_prefix),
# train_batch_size=args.batch_size,
# test_batch_size=args.test_batch_size,
# pin_memory=True,
# num_workers=args.num_workers,
# )
# else:
# raise ValueError(f"Unsupported dataset: {args.dataset}")
data_prefix_fallback = Path(os.getcwd()).joinpath(
".cache", "ezpz", "data", f"{args.dataset.lower()}"
)
data_prefix = args.data_prefix or data_prefix_fallback
if args.dataset == "MNIST":
from ezpz.data.vision import get_mnist
data = get_mnist(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
pin_memory=True,
num_workers=args.num_workers,
)
elif args.dataset == "ImageNet1k":
from ezpz.data.vision import get_imagenet1k
data = get_imagenet1k(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
pin_memory=True,
num_workers=args.num_workers,
)
elif args.dataset == "OpenImages":
from ezpz.data.vision import get_openimages
data = get_openimages(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
shuffle=False,
pin_memory=True,
num_workers=args.num_workers,
)
elif args.dataset == "ImageNet":
from ezpz.data.vision import get_imagenet
data = get_imagenet(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
shuffle=False,
pin_memory=True,
num_workers=args.num_workers,
)
else:
raise ValueError(f"Unsupported dataset: {args.dataset}")
return data
@ezpz.timeitlogit(rank=ezpz.get_rank())
def fsdp_main(args: argparse.Namespace) -> None:
"""Main training loop orchestrating data, model, and logging."""
t0 = time.perf_counter()
rank = ezpz.setup_torch(seed=args.seed)
t_setup = time.perf_counter()
if rank == 0:
# try:
fp = Path(__file__)
run = ezpz.setup_wandb(project_name=f"ezpz.{fp.parent.stem}.{fp.stem}")
if run is not None and wandb is not None and run is wandb.run:
run.config.update({"args": {**vars(args)}})
run.config.update({"ezpz.dist": {**ezpz.get_dist_info()}})
data = get_data(args)
ezpz.distributed.barrier()
train_loader = data["train"]["loader"]
test_loader = data["test"]["loader"]
tmp = prepare_model_optimizer_and_scheduler(args)
model = tmp["model"]
optimizer = tmp["optimizer"]
scheduler = tmp["scheduler"]
outdir = get_example_outdir(WBPROJ_NAME)
logger.info("Outputs will be saved to %s", outdir)
metrics_path = outdir.joinpath(f"metrics-{rank}.jsonl")
outdir.mkdir(parents=True, exist_ok=True)
history = ezpz.history.History(
report_dir=outdir,
report_enabled=(rank == 0),
jsonl_path=metrics_path,
# jsonl_overwrite=True,
distributed_history=(
1 < ezpz.get_world_size() <= 384 # and not config.pytorch_profiler
),
)
start = time.perf_counter()
for epoch in range(1, args.epochs + 1):
train_metrics = train(
model=model,
train_loader=train_loader,
optimizer=optimizer,
epoch=epoch,
sampler=data["train"]["sampler"],
)
test_metrics = test(model, test_loader)
scheduler.step()
logger.info(history.update({**train_metrics, **test_metrics}))
train_end = time.perf_counter()
logger.info(
" ".join(
[
f"{args.epochs + 1} epochs took",
f"{train_end - start:.1f}s",
]
)
)
timings = {
"main/setup_torch": t_setup - t0,
"main/train": train_end - start,
"main/total": train_end - t0,
"timings/training_start": start - t0,
"timings/train_duration": train_end - start,
"timings/end-to-end": train_end - t0,
}
logger.info("Timings: %s", timings)
if wandb is not None and getattr(wandb, "run", None) is not None:
try:
wandb.log(
{
(f"timings/{k}" if not k.startswith("timings/") else k): v
for k, v in timings.items()
}
)
except Exception:
logger.warning("Failed to log timings to wandb")
ezpz.distributed.barrier()
if args.save_model:
ezpz.distributed.barrier() # wait for slowpokes
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")
if rank == 0:
dataset = history.finalize(
run_name=WBPROJ_NAME,
dataset_fname="train",
)
logger.info(f"{dataset=}")
def _arg_provided(argv: list[str], flags: list[str]) -> bool:
return any(flag in argv for flag in flags)
def apply_model_preset(args: argparse.Namespace, argv: list[str]) -> None:
if args.model is None:
return
preset = MODEL_PRESETS[args.model]
for field_name, value in preset.items():
flags = MODEL_PRESET_FLAGS.get(field_name, [])
if not _arg_provided(argv, flags):
setattr(args, field_name, value)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
"""CLI parser for the FSDP example."""
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(
description="PyTorch MNIST Example using FSDP"
)
parser.add_argument(
"--num-workers",
type=int,
default=0,
metavar="N",
help="number of data loading workers (default: 4)",
)
parser.add_argument(
"--dataset",
type=str,
default="MNIST",
choices=["MNIST", "OpenImages", "ImageNet", "ImageNet1k"],
help="Dataset to use (default: MNIST)",
)
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument(
"--model",
type=str,
default=None,
choices=sorted(MODEL_PRESETS.keys()),
help="Model size preset (overrides conv/fc defaults)",
)
parser.add_argument(
"--conv1-channels",
type=int,
default=32,
metavar="N",
help="Number of output channels in conv1",
)
parser.add_argument(
"--conv2-channels",
type=int,
default=64,
metavar="N",
help="Number of output channels in conv2",
)
parser.add_argument(
"--fc-dim",
type=int,
default=128,
metavar="N",
help="Hidden dimension for the first linear layer",
)
parser.add_argument(
"--dtype",
type=str,
default="bf16",
metavar="D",
help="Datatype for training (default=bf16).",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=1000,
metavar="N",
help="input batch size for testing (default: 1000)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--lr",
type=float,
default=1e-3,
metavar="LR",
help="learning rate (default: 1e-3)",
)
parser.add_argument(
"--gamma",
type=float,
default=0.7,
metavar="M",
help="Learning rate step gamma (default: 0.7)",
)
parser.add_argument(
"--seed",
type=int,
default=None,
metavar="S",
help="random seed (default: 1)",
)
parser.add_argument(
"--save-model",
action="store_true",
default=False,
help="For Saving the current Model",
)
parser.add_argument(
"--data-prefix",
type=str,
required=False,
default=None,
help="data directory prefix",
)
args = parser.parse_args(argv)
apply_model_preset(args, argv)
return args
if __name__ == "__main__":
args = parse_args()
fsdp_main(args=args)
ezpz.cleanup()
Code WalkthroughβοΈ
Imports
The FSDP and MixedPrecision imports enable fully-sharded data parallelism
with optional half-precision compute, which is the core distribution strategy
this example demonstrates. ezpz replaces the manual init_process_group /
device-selection boilerplate so the same script works on CUDA, XPU, and MPS
without changes. summarize_model and get_example_outdir are convenience
helpers for logging parameter counts and writing outputs to a timestamped
directory.
# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler, DataLoader
import argparse
import os
from pathlib import Path
import sys
import time
import ezpz
# from ezpz.history import WANDB_DISABLED
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ezpz.models import summarize_model
from ezpz.examples import get_example_outdir
logger = ezpz.get_logger(__name__)
try:
import wandb
except Exception:
wandb = None # type: ignore
Model Presets
Named presets (debug, small, medium, large) let users scale the CNN
architecture from the command line with --model <preset>. Each preset
bundles conv1_channels, conv2_channels, and fc_dim so you can quickly
compare FSDP overhead at different model sizes without manually tuning
individual flags. Any CLI flag the user passes explicitly overrides the
preset value.
MODEL_PRESETS = {
"debug": {
"conv1_channels": 8,
"conv2_channels": 16,
"fc_dim": 64,
},
"small": {
"conv1_channels": 16,
"conv2_channels": 32,
"fc_dim": 128,
},
"medium": {
"conv1_channels": 32,
"conv2_channels": 64,
"fc_dim": 256,
},
"large": {
"conv1_channels": 64,
"conv2_channels": 128,
"fc_dim": 512,
},
}
Net -- CNN Architecture
A two-layer convolutional network with dropout and two fully connected
layers. _feature_size computes the flattened dimension after convolutions
and pooling so the first linear layer is sized correctly.
class Net(nn.Module):
"""Simple CNN classifier used in the FSDP example."""
def __init__(
self,
num_classes: int = 10,
img_size: int = 28,
conv1_channels: int = 32,
conv2_channels: int = 64,
fc_dim: int = 128,
):
"""Initialize convolutional and fully connected layers.
Args:
num_classes: Number of output classes for the classifier.
img_size: Input image size (assumes square inputs).
conv1_channels: Number of output channels for conv1.
conv2_channels: Number of output channels for conv2.
fc_dim: Hidden dimension for the first fully connected layer.
"""
super().__init__()
self.conv1 = nn.Conv2d(1, conv1_channels, 3, 1)
self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
feature_size = self._feature_size(img_size, conv2_channels)
self.fc1 = nn.Linear(feature_size, fc_dim)
self.fc2 = nn.Linear(fc_dim, num_classes)
@staticmethod
def _feature_size(img_size: int, conv2_channels: int) -> int:
conv1_size = img_size - 2
conv2_size = conv1_size - 2
pooled_size = conv2_size // 2
return conv2_channels * pooled_size * pooled_size
def forward(self, x):
"""Compute logits for input images."""
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
train -- Single-Epoch Training
Runs one training epoch, accumulating loss across batches. After the loop,
dist.all_reduce sums the loss and sample count across all ranks so every
worker sees the global average.
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
model: nn.Module | DistributedDataParallel | FSDP,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
epoch: int,
sampler: DistributedSampler | None = None,
) -> dict:
"""One epoch of training and loss aggregation across ranks.
Args:
model: Wrapped model (DDP/FSDP).
train_loader: Dataloader for training set.
optimizer: Optimizer instance.
epoch: Current epoch index.
sampler: Optional distributed sampler to set epoch.
Returns:
Dict with epoch, wall-clock duration, and averaged train loss.
"""
device_type = ezpz.distributed.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
)
model.train()
ddp_loss = torch.zeros(2).to(device)
if sampler:
sampler.set_epoch(epoch)
ezpz.distributed.synchronize()
t0 = time.perf_counter()
batch, target = next(iter(train_loader))
for _, (batch, target) in enumerate(train_loader):
batch, target = batch.to(device), target.to(device)
optimizer.zero_grad()
output = model(batch)
loss = F.nll_loss(output, target, reduction="sum")
loss.backward()
optimizer.step()
ddp_loss[0] += loss.item()
ddp_loss[1] += len(batch)
ezpz.distributed.synchronize()
t1 = time.perf_counter()
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) # type:ignore
return {
"epoch": epoch,
"dt": t1 - t0,
"train_loss": ddp_loss[0] / ddp_loss[1],
}
test -- Evaluation
Evaluates the model on validation data with gradients disabled. Tracks loss, correct predictions, and total samples, then all-reduces across ranks.
@ezpz.timeitlogit(rank=ezpz.get_rank())
def test(model, test_loader):
"""Evaluate model on validation data and gather metrics."""
device_type = ezpz.distributed.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
)
model.eval()
# correct = 0
ddp_loss = torch.zeros(3).to(device)
with torch.no_grad():
for batch, target in test_loader:
batch, target = batch.to(device), target.to(device)
output = model(batch)
ddp_loss[0] += F.nll_loss(output, target, reduction="sum")
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
ddp_loss[1] += pred.eq(target.view_as(pred)).sum()
ddp_loss[2] += len(batch)
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM) # type:ignore
test_loss = ddp_loss[0] / ddp_loss[2]
return {
"test_loss": test_loss,
"test_acc": 100.0 * ddp_loss[1] / ddp_loss[2],
}
prepare_model_optimizer_and_scheduler -- FSDP Wrapping
Creates the Net model, wraps it with FullyShardedDataParallel using
mixed-precision settings, and returns the model, optimizer, and LR
scheduler.
def prepare_model_optimizer_and_scheduler(args: argparse.Namespace) -> dict:
"""Create the FSDP-wrapped model, optimizer, and LR scheduler."""
device_type = ezpz.distributed.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.distributed.get_local_rank()}")
)
if args.dataset == "MNIST":
num_classes = 10
img_size = 28
elif args.dataset == "OpenImages":
num_classes = 600
img_size = 224
elif args.dataset == "ImageNet":
num_classes = 1000
img_size = 224
elif args.dataset == "ImageNet1k":
num_classes = 1000
img_size = 224
else:
raise ValueError(f"Unsupported dataset: {args.dataset}")
model = Net(
num_classes=num_classes,
img_size=img_size,
conv1_channels=args.conv1_channels,
conv2_channels=args.conv2_channels,
fc_dim=args.fc_dim,
).to(device)
logger.info(f"\n{summarize_model(model, verbose=False, depth=2)}")
dtypes = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"bfloat16": torch.bfloat16,
"fp32": torch.float32,
}
dtype = dtypes[args.dtype]
model = FSDP(
model,
device_id=device,
mixed_precision=MixedPrecision(
param_dtype=dtype,
cast_forward_inputs=True,
),
)
optimizer = optim.AdamW(model.parameters(), lr=args.lr)
logger.info(f"{model=}")
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
return {
"model": model,
"optimizer": optimizer,
"scheduler": scheduler,
}
get_data -- Data Loading
Dispatches to dataset-specific loaders (get_mnist, get_imagenet1k,
get_openimages, get_imagenet) from ezpz.data.vision based on the
--dataset flag.
def get_data(args: argparse.Namespace) -> dict:
"""Load train/test datasets according to args.dataset."""
data_prefix_fallback = Path(os.getcwd()).joinpath(
".cache", "ezpz", "data", f"{args.dataset.lower()}"
)
data_prefix = args.data_prefix or data_prefix_fallback
if args.dataset == "MNIST":
from ezpz.data.vision import get_mnist
data = get_mnist(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
pin_memory=True,
num_workers=args.num_workers,
)
elif args.dataset == "ImageNet1k":
from ezpz.data.vision import get_imagenet1k
data = get_imagenet1k(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
pin_memory=True,
num_workers=args.num_workers,
)
elif args.dataset == "OpenImages":
from ezpz.data.vision import get_openimages
data = get_openimages(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
shuffle=False,
pin_memory=True,
num_workers=args.num_workers,
)
elif args.dataset == "ImageNet":
from ezpz.data.vision import get_imagenet
data = get_imagenet(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
shuffle=False,
pin_memory=True,
num_workers=args.num_workers,
)
else:
raise ValueError(f"Unsupported dataset: {args.dataset}")
return data
fsdp_main -- Main Function
Orchestrates the full training run: initializes distributed training with
ezpz.setup_torch, optionally sets up Weights & Biases logging, loads
data, prepares the FSDP-wrapped model, and runs the epoch loop.
@ezpz.timeitlogit(rank=ezpz.get_rank())
def fsdp_main(args: argparse.Namespace) -> None:
"""Main training loop orchestrating data, model, and logging."""
t0 = time.perf_counter()
rank = ezpz.setup_torch(seed=args.seed)
t_setup = time.perf_counter()
if rank == 0:
# try:
fp = Path(__file__)
run = ezpz.setup_wandb(project_name=f"ezpz.{fp.parent.stem}.{fp.stem}")
if run is not None and wandb is not None and run is wandb.run:
run.config.update({"args": {**vars(args)}})
run.config.update({"ezpz.dist": {**ezpz.get_dist_info()}})
data = get_data(args)
ezpz.distributed.barrier()
train_loader = data["train"]["loader"]
test_loader = data["test"]["loader"]
tmp = prepare_model_optimizer_and_scheduler(args)
model = tmp["model"]
optimizer = tmp["optimizer"]
scheduler = tmp["scheduler"]
An ezpz.history.History object tracks per-epoch metrics and optionally
writes them to JSONL. The epoch loop calls train, test, and
scheduler.step each iteration.
outdir = get_example_outdir(WBPROJ_NAME)
logger.info("Outputs will be saved to %s", outdir)
metrics_path = outdir.joinpath(f"metrics-{rank}.jsonl")
outdir.mkdir(parents=True, exist_ok=True)
history = ezpz.history.History(
report_dir=outdir,
report_enabled=(rank == 0),
jsonl_path=metrics_path,
# jsonl_overwrite=True,
distributed_history=(
1 < ezpz.get_world_size() <= 384 # and not config.pytorch_profiler
),
)
start = time.perf_counter()
for epoch in range(1, args.epochs + 1):
train_metrics = train(
model=model,
train_loader=train_loader,
optimizer=optimizer,
epoch=epoch,
sampler=data["train"]["sampler"],
)
test_metrics = test(model, test_loader)
scheduler.step()
logger.info(history.update({**train_metrics, **test_metrics}))
After training completes, timings are logged (and optionally sent to W&B),
the model checkpoint is saved if --save-model was passed, and
history.finalize writes the final report on rank 0.
train_end = time.perf_counter()
logger.info(
" ".join(
[
f"{args.epochs + 1} epochs took",
f"{train_end - start:.1f}s",
]
)
)
timings = {
"main/setup_torch": t_setup - t0,
"main/train": train_end - start,
"main/total": train_end - t0,
"timings/training_start": start - t0,
"timings/train_duration": train_end - start,
"timings/end-to-end": train_end - t0,
}
logger.info("Timings: %s", timings)
if wandb is not None and getattr(wandb, "run", None) is not None:
try:
wandb.log(
{
(f"timings/{k}" if not k.startswith("timings/") else k): v
for k, v in timings.items()
}
)
except Exception:
logger.warning("Failed to log timings to wandb")
ezpz.distributed.barrier()
if args.save_model:
ezpz.distributed.barrier() # wait for slowpokes
states = model.state_dict()
if rank == 0:
torch.save(states, "mnist_cnn.pt")
if rank == 0:
dataset = history.finalize(
run_name=WBPROJ_NAME,
dataset_fname="train",
)
logger.info(f"{dataset=}")
Entrypoint
Parses CLI arguments, runs fsdp_main, and calls ezpz.cleanup() to tear
down the process group.
HelpβοΈ
--help
$ python3 -m ezpz.examples.fsdp --help
usage: fsdp.py [-h] [--num-workers N]
[--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}]
[--batch-size N] [--dtype D] [--test-batch-size N] [--epochs N]
[--lr LR] [--gamma M] [--seed S] [--save-model]
[--data-prefix DATA_PREFIX]
PyTorch MNIST Example using FSDP
options:
-h, --help show this help message and exit
--num-workers N number of data loading workers (default: 4)
--dataset {MNIST,OpenImages,ImageNet,ImageNet1k}
Dataset to use (default: MNIST)
--batch-size N input batch size for training (default: 64)
--dtype D Datatype for training (default=bf16).
--test-batch-size N input batch size for testing (default: 1000)
--epochs N number of epochs to train (default: 10)
--lr LR learning rate (default: 1e-3)
--gamma M Learning rate step gamma (default: 0.7)
--seed S random seed (default: 1)
--save-model For Saving the current Model
--data-prefix DATA_PREFIX
data directory prefix
OutputβοΈ
Output on Sunspot
$ ezpz launch python3 -m ezpz.examples.fsdp
[2025-12-31 12:21:21,523041][I][ezpz/launch:396:launch] ----[π ezpz.launch][started][2025-12-31-122121]----
[2025-12-31 12:21:22,375537][I][ezpz/launch:416:launch] Job ID: 12458339
[2025-12-31 12:21:22,376302][I][ezpz/launch:417:launch] nodelist: ['x1921c0s3b0n0', 'x1921c0s7b0n0']
[2025-12-31 12:21:22,376691][I][ezpz/launch:418:launch] hostfile: /var/spool/pbs/aux/12458339.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov
[2025-12-31 12:21:22,377360][I][ezpz/pbs:264:get_pbs_launch_cmd] β
Using [24/24] GPUs [2 hosts] x [12 GPU/host]
[2025-12-31 12:21:22,378079][I][ezpz/launch:367:build_executable] Building command to execute by piecing together:
[2025-12-31 12:21:22,378474][I][ezpz/launch:368:build_executable] (1.) launch_cmd: mpiexec --envall --np=24 --ppn=12 --hostfile=/var/spool/pbs/aux/12458339.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov --no-vni --cpu-bind=verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96
[2025-12-31 12:21:22,379293][I][ezpz/launch:369:build_executable] (2.) cmd_to_launch: python3 -m ezpz.examples.fsdp
[2025-12-31 12:21:22,380037][I][ezpz/launch:433:launch] Took: 1.45 seconds to build command.
[2025-12-31 12:21:22,380393][I][ezpz/launch:436:launch] Executing:
mpiexec
--envall
--np=24
--ppn=12
--hostfile=/var/spool/pbs/aux/12458339.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov
--no-vni
--cpu-bind=verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96
python3
-m
ezpz.examples.fsdp
[2025-12-31 12:21:22,381628][I][ezpz/launch:443:launch] Execution started @ 2025-12-31-122122...
[2025-12-31 12:21:22,382071][I][ezpz/launch:139:run_command] Running command:
mpiexec --envall --np=24 --ppn=12 --hostfile=/var/spool/pbs/aux/12458339.sunspot-pbs-0001.head.cm.sunspot.alcf.anl.gov --no-vni --cpu-bind=verbose,list:2-4:10-12:18-20:26-28:34-36:42-44:54-56:62-64:70-72:78-80:86-88:94-96 python3 -m ezpz.examples.fsdp
cpubind:list x1921c0s7b0n0 pid 111174 rank 12 0: mask 0x1c
cpubind:list x1921c0s7b0n0 pid 111175 rank 13 1: mask 0x1c00
cpubind:list x1921c0s7b0n0 pid 111176 rank 14 2: mask 0x1c0000
cpubind:list x1921c0s7b0n0 pid 111177 rank 15 3: mask 0x1c000000
cpubind:list x1921c0s7b0n0 pid 111178 rank 16 4: mask 0x1c00000000
cpubind:list x1921c0s7b0n0 pid 111179 rank 17 5: mask 0x1c0000000000
cpubind:list x1921c0s7b0n0 pid 111180 rank 18 6: mask 0x1c0000000000000
cpubind:list x1921c0s7b0n0 pid 111181 rank 19 7: mask 0x1c000000000000000
cpubind:list x1921c0s7b0n0 pid 111182 rank 20 8: mask 0x1c00000000000000000
cpubind:list x1921c0s7b0n0 pid 111183 rank 21 9: mask 0x1c0000000000000000000
cpubind:list x1921c0s7b0n0 pid 111184 rank 22 10: mask 0x1c000000000000000000000
cpubind:list x1921c0s7b0n0 pid 111185 rank 23 11: mask 0x1c00000000000000000000000
cpubind:list x1921c0s3b0n0 pid 107043 rank 0 0: mask 0x1c
cpubind:list x1921c0s3b0n0 pid 107044 rank 1 1: mask 0x1c00
cpubind:list x1921c0s3b0n0 pid 107045 rank 2 2: mask 0x1c0000
cpubind:list x1921c0s3b0n0 pid 107046 rank 3 3: mask 0x1c000000
cpubind:list x1921c0s3b0n0 pid 107047 rank 4 4: mask 0x1c00000000
cpubind:list x1921c0s3b0n0 pid 107048 rank 5 5: mask 0x1c0000000000
cpubind:list x1921c0s3b0n0 pid 107049 rank 6 6: mask 0x1c0000000000000
cpubind:list x1921c0s3b0n0 pid 107050 rank 7 7: mask 0x1c000000000000000
cpubind:list x1921c0s3b0n0 pid 107051 rank 8 8: mask 0x1c00000000000000000
cpubind:list x1921c0s3b0n0 pid 107052 rank 9 9: mask 0x1c0000000000000000000
cpubind:list x1921c0s3b0n0 pid 107053 rank 10 10: mask 0x1c000000000000000000000
cpubind:list x1921c0s3b0n0 pid 107054 rank 11 11: mask 0x1c00000000000000000000000
[2025-12-31 12:21:26,964250][I][ezpz/dist:1501:setup_torch_distributed] Using torch_{device,backend}= {xpu, xccl}
[2025-12-31 12:21:26,967037][I][ezpz/dist:1366:setup_torch_DDP] Caught MASTER_PORT=41625 from environment!
[2025-12-31 12:21:26,967795][I][ezpz/dist:1382:setup_torch_DDP] Using torch.distributed.init_process_group with
- master_addr='x1921c0s3b0n0'
- master_port='41625'
- world_size=24
- rank=0
- local_rank=0
- timeout=datetime.timedelta(seconds=3600)
- backend='xccl'
[2025-12-31 12:21:26,968707][I][ezpz/dist:1014:init_process_group] Calling torch.distributed.init_process_group_with: rank=0 world_size=24 backend=xccl
[2025-12-31 12:21:27,619965][I][ezpz/dist:1727:setup_torch] Using device='xpu' with backend='xccl' + 'xccl' for distributed training.
[2025-12-31 12:21:27,620787][W][ezpz/dist:544:print_dist_setup] Using [24 / 24] available "xpu" devices !!
[2025-12-31 12:21:27,621230][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=0/1][rank=00/23][local_rank=00/11]
[2025-12-31 12:21:27,620421][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=1/1][rank=01/23][local_rank=01/11]
[2025-12-31 12:21:27,620452][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=0/1][rank=02/23][local_rank=02/11]
[2025-12-31 12:21:27,620445][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=0/1][rank=04/23][local_rank=04/11]
[2025-12-31 12:21:27,620450][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=1/1][rank=05/23][local_rank=05/11]
[2025-12-31 12:21:27,620418][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=0/1][rank=06/23][local_rank=06/11]
[2025-12-31 12:21:27,620439][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=1/1][rank=07/23][local_rank=07/11]
[2025-12-31 12:21:27,620431][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=0/1][rank=08/23][local_rank=08/11]
[2025-12-31 12:21:27,620400][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=1/1][rank=09/23][local_rank=09/11]
[2025-12-31 12:21:27,620398][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=0/1][rank=10/23][local_rank=10/11]
[2025-12-31 12:21:27,620433][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=1/1][rank=11/23][local_rank=11/11]
[2025-12-31 12:21:27,620451][I][ezpz/dist:1774:setup_torch] ['x1921c0s3b0n0'][device='xpu'][node=1/1][rank=03/23][local_rank=03/11]
[2025-12-31 12:21:27,620523][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=0/1][rank=12/23][local_rank=00/11]
[2025-12-31 12:21:27,620546][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=1/1][rank=13/23][local_rank=01/11]
[2025-12-31 12:21:27,620556][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=0/1][rank=14/23][local_rank=02/11]
[2025-12-31 12:21:27,620557][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=0/1][rank=16/23][local_rank=04/11]
[2025-12-31 12:21:27,620568][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=1/1][rank=15/23][local_rank=03/11]
[2025-12-31 12:21:27,620557][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=1/1][rank=17/23][local_rank=05/11]
[2025-12-31 12:21:27,620575][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=1/1][rank=19/23][local_rank=07/11]
[2025-12-31 12:21:27,620556][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=0/1][rank=20/23][local_rank=08/11]
[2025-12-31 12:21:27,620560][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=1/1][rank=21/23][local_rank=09/11]
[2025-12-31 12:21:27,620578][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=0/1][rank=22/23][local_rank=10/11]
[2025-12-31 12:21:27,620579][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=1/1][rank=23/23][local_rank=11/11]
[2025-12-31 12:21:27,620579][I][ezpz/dist:1774:setup_torch] ['x1921c0s7b0n0'][device='xpu'][node=0/1][rank=18/23][local_rank=06/11]
[2025-12-31 12:21:28,206982][I][ezpz/dist:2039:setup_wandb] Setting up wandb from rank=0
[2025-12-31 12:21:28,207580][I][ezpz/dist:2040:setup_wandb] Using WB_PROJECT=ezpz.examples.fsdp
wandb: Currently logged in as: foremans (aurora_gpt) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.23.1
wandb: Run data is saved locally in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/wandb/run-20251231_122128-11cqdt05
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run vivid-glade-86
wandb: View project at https://wandb.ai/aurora_gpt/ezpz.examples.fsdp
wandb: View run at https://wandb.ai/aurora_gpt/ezpz.examples.fsdp/runs/11cqdt05
[2025-12-31 12:21:29,790902][I][ezpz/dist:2069:setup_wandb] wandb.run=[vivid-glade-86](https://wandb.ai/aurora_gpt/ezpz.examples.fsdp/runs/11cqdt05)
[2025-12-31 12:21:29,796125][I][ezpz/dist:2112:setup_wandb] Running on machine='SunSpot'
[2025-12-31 12:21:30,092593][I][examples/fsdp:196:prepare_model_optimizer_and_scheduler]
=================================================================
Layer (type:depth-idx) Param #
=================================================================
Net --
ββConv2d: 1-1 320
ββConv2d: 1-2 18,496
ββDropout: 1-3 --
ββDropout: 1-4 --
ββLinear: 1-5 1,179,776
ββLinear: 1-6 1,290
=================================================================
Total params: 1,199,882
Trainable params: 1,199,882
Non-trainable params: 0
=================================================================
[2025-12-31 12:21:30,134352][I][examples/fsdp:212:prepare_model_optimizer_and_scheduler] model=FullyShardedDataParallel(
(_fsdp_wrapped_module): Net(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(dropout1): Dropout(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): Linear(in_features=9216, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
)
[2025-12-31 12:21:30,173375][I][ezpz/history:220:__init__] Using History with distributed_history=True
2025:12:31-12:21:30:(107043) |CCL_WARN| value of CCL_OP_SYNC changed to be 1 (default:0)
2025:12:31-12:21:30:(107043) |CCL_WARN| value of CCL_PROCESS_LAUNCHER changed to be pmix (default:hydra)
[2025-12-31 12:21:55,502783][I][examples/fsdp:340:fsdp_main] epoch=1 dt=12.487221 train_loss=0.596659 test_loss=0.143485 test_acc=95.563553 dt/mean=11.990577 dt/max=12.487222 dt/min=11.897395 dt/std=0.119125 train_loss/mean=0.596659 train_loss/max=0.596659 train_loss/min=0.596659 train_loss/std=0.000173 test_loss/mean=0.143485 test_loss/max=0.143485 test_loss/min=0.143485 test_loss/std=0.000000 test_acc/mean=95.563560 test_acc/max=95.563553 test_acc/min=95.563553 test_acc/std=0.000000
[2025-12-31 12:21:55,911549][I][examples/fsdp:340:fsdp_main] epoch=2 dt=0.361235 train_loss=0.174450 test_loss=0.080361 test_acc=97.511993 dt/mean=0.365279 dt/max=0.373996 dt/min=0.355496 dt/std=0.005433 train_loss/mean=0.174450 train_loss/max=0.174450 train_loss/min=0.174450 train_loss/std=0.000000 test_loss/mean=0.080361 test_loss/max=0.080361 test_loss/min=0.080361 test_loss/std=0.000022 test_acc/mean=97.511993 test_acc/max=97.511993 test_acc/min=97.511993 test_acc/std=0.000000
[2025-12-31 12:21:56,308947][I][examples/fsdp:340:fsdp_main] epoch=3 dt=0.359641 train_loss=0.120487 test_loss=0.060764 test_acc=98.021584 dt/mean=0.358203 dt/max=0.361614 dt/min=0.353194 dt/std=0.002922 train_loss/mean=0.120487 train_loss/max=0.120487 train_loss/min=0.120487 train_loss/std=0.000000 test_loss/mean=0.060764 test_loss/max=0.060764 test_loss/min=0.060764 test_loss/std=0.000015 test_acc/mean=98.021591 test_acc/max=98.021584 test_acc/min=98.021584 test_acc/std=0.000000
[2025-12-31 12:21:56,703145][I][examples/fsdp:340:fsdp_main] epoch=4 dt=0.356608 train_loss=0.098917 test_loss=0.052346 test_acc=98.301361 dt/mean=0.356618 dt/max=0.359070 dt/min=0.353434 dt/std=0.001995 train_loss/mean=0.098917 train_loss/max=0.098917 train_loss/min=0.098917 train_loss/std=0.000000 test_loss/mean=0.052346 test_loss/max=0.052346 test_loss/min=0.052346 test_loss/std=0.000000 test_acc/mean=98.301361 test_acc/max=98.301361 test_acc/min=98.301361 test_acc/std=0.031250
[2025-12-31 12:21:57,100230][I][examples/fsdp:340:fsdp_main] epoch=5 dt=0.357687 train_loss=0.085740 test_loss=0.047243 test_acc=98.441246 dt/mean=0.356900 dt/max=0.360295 dt/min=0.352879 dt/std=0.002699 train_loss/mean=0.085740 train_loss/max=0.085740 train_loss/min=0.085740 train_loss/std=0.000000 test_loss/mean=0.047243 test_loss/max=0.047243 test_loss/min=0.047243 test_loss/std=0.000000 test_acc/mean=98.441246 test_acc/max=98.441246 test_acc/min=98.441246 test_acc/std=0.000000
[2025-12-31 12:21:57,497234][I][examples/fsdp:340:fsdp_main] epoch=6 dt=0.357410 train_loss=0.080569 test_loss=0.044845 test_acc=98.471222 dt/mean=0.356574 dt/max=0.359746 dt/min=0.353584 dt/std=0.002156 train_loss/mean=0.080569 train_loss/max=0.080569 train_loss/min=0.080569 train_loss/std=0.000000 test_loss/mean=0.044845 test_loss/max=0.044845 test_loss/min=0.044845 test_loss/std=0.000015 test_acc/mean=98.471222 test_acc/max=98.471222 test_acc/min=98.471222 test_acc/std=0.000000
[2025-12-31 12:21:57,893327][I][examples/fsdp:340:fsdp_main] epoch=7 dt=0.355675 train_loss=0.075174 test_loss=0.043703 test_acc=98.481216 dt/mean=0.356044 dt/max=0.358311 dt/min=0.353675 dt/std=0.001370 train_loss/mean=0.075174 train_loss/max=0.075174 train_loss/min=0.075174 train_loss/std=0.000022 test_loss/mean=0.043703 test_loss/max=0.043703 test_loss/min=0.043703 test_loss/std=0.000011 test_acc/mean=98.481224 test_acc/max=98.481216 test_acc/min=98.481216 test_acc/std=0.000000
[2025-12-31 12:21:58,292161][I][examples/fsdp:340:fsdp_main] epoch=8 dt=0.358490 train_loss=0.073104 test_loss=0.041848 test_acc=98.551163 dt/mean=0.359055 dt/max=0.362143 dt/min=0.355792 dt/std=0.001879 train_loss/mean=0.073104 train_loss/max=0.073104 train_loss/min=0.073104 train_loss/std=0.000022 test_loss/mean=0.041848 test_loss/max=0.041848 test_loss/min=0.041848 test_loss/std=0.000000 test_acc/mean=98.551170 test_acc/max=98.551163 test_acc/min=98.551163 test_acc/std=0.000000
[2025-12-31 12:21:58,692175][I][examples/fsdp:340:fsdp_main] epoch=9 dt=0.359963 train_loss=0.069403 test_loss=0.041198 test_acc=98.571144 dt/mean=0.360091 dt/max=0.363091 dt/min=0.356911 dt/std=0.001945 train_loss/mean=0.069403 train_loss/max=0.069403 train_loss/min=0.069403 train_loss/std=0.000022 test_loss/mean=0.041198 test_loss/max=0.041198 test_loss/min=0.041198 test_loss/std=0.000011 test_acc/mean=98.571152 test_acc/max=98.571144 test_acc/min=98.571144 test_acc/std=0.000000
[2025-12-31 12:21:59,091674][I][examples/fsdp:340:fsdp_main] epoch=10 dt=0.358637 train_loss=0.068348 test_loss=0.041941 test_acc=98.571144 dt/mean=0.358994 dt/max=0.361870 dt/min=0.356423 dt/std=0.001696 train_loss/mean=0.068348 train_loss/max=0.068348 train_loss/min=0.068348 train_loss/std=0.000000 test_loss/mean=0.041941 test_loss/max=0.041941 test_loss/min=0.041941 test_loss/std=0.000000 test_acc/mean=98.571152 test_acc/max=98.571144 test_acc/min=98.571144 test_acc/std=0.000000
[2025-12-31 12:21:59,093446][I][examples/fsdp:342:fsdp_main] 11 epochs took 28.9s
[2025-12-31 12:21:59,124624][I][ezpz/history:2385:finalize] Saving plots to /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/mplot (matplotlib) and /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot (tplot)
dt dt/min
ββββββββββββββββββββββββββββββββββββ ββββββββββββββββββββββββββββββββββββ
12.5β€β β11.9β€- β
10.5β€β β 8.0β€ - β
βββ β 4.2β€ - β
8.4β€ β β 0.4β€ -------------------------------β
6.4β€ β β ββ¬ββββββββ¬βββββββββ¬ββββββββ¬ββββββββ¬β
4.4β€ β β 1.0 3.2 5.5 7.8 10.0
β β βdt/min iter
2.4β€ ββ β dt/std
0.4β€ ββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββ
ββ¬ββββββββ¬βββββββββ¬ββββββββ¬ββββββββ¬β0.119β€* β
1.0 3.2 5.5 7.8 10.0 0.099β€ * β
dt iter 0.060β€ * β
dt/mean 0.041β€ * β
ββββββββββββββββββββββββββββββββββββ0.001β€ *****************************β
12.0β€Β· β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
10.1β€Β· β 1.0 3.2 5.5 7.8 10.0
βΒ· βdt/std iter
8.1β€ Β· β dt/max
6.2β€ Β· β ββββββββββββββββββββββββββββββββββββ
β Β· β12.5β€+ β
4.2β€ Β· β10.5β€ + β
2.3β€ Β· β 6.4β€ + β
β Β· β 4.4β€ + β
0.4β€ Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·β 0.4β€ ++++++++++++++++++++++++++++++β
ββ¬ββββββββ¬βββββββββ¬ββββββββ¬ββββββββ¬β ββ¬ββββββββ¬βββββββββ¬ββββββββ¬ββββββββ¬β
1.0 3.2 5.5 7.8 10.0 1.0 3.2 5.5 7.8 10.0
dt/mean iter dt/max iter
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/dt.txt
ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
12.5β€ ++ dt/max β
β -- dt/min β
β Β·Β· dt/mean β
β ββ dt β
10.5β€ β β
β β β
β ββ β
β β β
8.4β€ β β
β β β
β β β
β β β
6.4β€ β β
β β β
β β β
β β β
β β β
4.4β€ β β
β β β
β ββ β
β β β
2.4β€ β β
β ββ β
β β β
β β β
0.4β€ ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
ββ¬ββββββββββββββββββ¬βββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬β
1.0 3.2 5.5 7.8 10.0
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/dt_summary.txt
dt/mean hist dt/max hist
βββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββ
9.0β€ββββ β9.0β€ββββ β
7.5β€ββββ β7.5β€ββββ β
βββββ β βββββ β
6.0β€ββββ β6.0β€ββββ β
4.5β€ββββ β4.5β€ββββ β
βββββ β βββββ β
3.0β€ββββ β3.0β€ββββ β
1.5β€ββββ β1.5β€ββββ β
βββββ βββββ βββββ βββββ
0.0β€βββ βββββ0.0β€βββ βββββ
ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β
-0.2 3.0 6.2 9.3 12.5 -0.2 3.1 6.4 9.7 13.0
dt/min hist dt/std hist
βββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββ
9.0β€ββββ β9.0β€ββββ β
βββββ β βββββ β
7.5β€ββββ β7.5β€ββββ β
6.0β€ββββ β6.0β€ββββ β
βββββ β βββββ β
4.5β€ββββ β4.5β€ββββ β
βββββ β βββββ β
3.0β€ββββ β3.0β€ββββ β
1.5β€ββββ β1.5β€ββββ β
βββββ βββββ βββββ βββββ
0.0β€βββ βββββ0.0β€βββ βββββ
ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β
-0.2 3.0 6.1 9.3 12.4 -0.004 0.028 0.060 0.092 0.124
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/dt_hist.txt
test_acc test_acc/min
βββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββ
98.57β€ ββββββββββββββββββββ98.57β€ ----------------------β
98.07β€ ββββββββ β97.57β€ ------- β
β ββ β96.57β€ -- β
97.57β€ ββ β95.56β€-- β
97.07β€ β β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
96.57β€ ββ β 1.0 3.2 5.5 7.8 10.0
β β βtest_acc/min iter
96.06β€β β test_acc/std
95.56β€β β ββββββββββββββββββββββββββββββββββ
ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β0.0312β€ * β
1.0 3.2 5.5 7.8 10.0 0.0260β€ * * β
test_acc iter 0.0156β€ * * β
test_acc/mean 0.0104β€ * * β
βββββββββββββββββββββββββββββββββββ0.0000β€******** ******************β
98.57β€ Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·β ββ¬ββββββββ¬ββββββββ¬βββββββ¬ββββββββ¬β
98.07β€ Β·Β·Β· β 1.0 3.2 5.5 7.8 10.0
β Β·Β·Β·Β· βtest_acc/std iter
97.57β€ Β·Β·Β· β test_acc/max
97.07β€ Β· β βββββββββββββββββββββββββββββββββββ
β Β· β98.57β€ ++++++++++++++++++++++β
96.57β€ Β· β98.07β€ +++++++ β
96.06β€ Β· β97.07β€ + β
βΒ· β96.57β€ + β
95.56β€Β· β95.56β€++ β
ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
1.0 3.2 5.5 7.8 10.0 1.0 3.2 5.5 7.8 10.0
test_acc/mean iter test_acc/max iter
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/test_acc.txt
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
98.57β€ ++ test_acc/max ββββββββββββββββββββββ
β -- test_acc/min ββββββββββββββββββββββββΒ·Β·Β· β
β Β·Β· test_acc/mean ββββββΒ·Β· β
β ββ test_acc ββββ β
98.07β€ ββββΒ·Β· β
β ββΒ·Β·Β· β
β βββΒ· β
β ββΒ· β
97.57β€ ββΒ· β
β ββΒ· β
β β β
β ββ β
97.07β€ β β
β ββ β
β β β
β ββ β
β β β
96.57β€ ββ β
β β β
β ββ β
β β β
96.06β€ ββ β
β β β
βββ β
ββ β
95.56β€β β
ββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬β
1.0 3.2 5.5 7.8 10.0
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/test_acc_summary.txt
test_acc/mean hist test_acc/max hist
βββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββ
7.0β€ βββββ7.0β€ βββββ
5.8β€ βββββ5.8β€ βββββ
β βββββ β βββββ
4.7β€ βββββ4.7β€ βββββ
3.5β€ βββββ3.5β€ βββββ
β βββββ β βββββ
2.3β€ βββββ2.3β€ βββββ
1.2β€ βββββ1.2β€ βββββ
βββββ ββββ ββββββββ βββββ ββββ ββββββββ
0.0β€βββ βββ ββββββββ0.0β€βββ βββ ββββββββ
ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β
95.4 96.2 97.1 97.9 98.7 95.4 96.2 97.1 97.9 98.7
test_acc/min hist test_acc/std hist
βββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββ
7.0β€ βββββ9.0β€ββββ β
β βββββ βββββ β
5.8β€ βββββ7.5β€ββββ β
4.7β€ βββββ6.0β€ββββ β
β βββββ βββββ β
3.5β€ βββββ4.5β€ββββ β
β βββββ βββββ β
2.3β€ βββββ3.0β€ββββ β
1.2β€ βββββ1.5β€ββββ β
βββββ ββββ ββββββββ βββββ βββββ
0.0β€βββ βββ ββββββββ0.0β€βββ βββββ
ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β ββ¬βββββββββ¬ββββββββ¬βββββββββ¬βββββββββ
95.4 96.2 97.1 97.9 98.7 -0.0014 0.0071 0.0156 0.0241
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/test_acc_hist.txt
test_loss test_loss/min
βββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββ
0.143β€β β0.143β€- β
0.126β€ββ β0.109β€ -- β
β β β0.075β€ ----- β
0.109β€ β β0.041β€ -------------------------β
0.092β€ ββ β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
0.075β€ ββ β 1.0 3.2 5.5 7.8 10.0
β βββ βtest_loss/min iter
0.058β€ ββββββ β test_loss/std
0.041β€ βββββββββββββββββββββββ βββββββββββββββββββββββββββββββ
ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β0.0000216β€ * β
1.0 3.2 5.5 7.8 10.0 0.0000180β€ * *** * β
test_loss iter 0.0000108β€ * * * *** * β
test_loss/mean 0.0000072β€* * * * * * β
βββββββββββββββββββββββββββββββββββ0.0000000β€* ***** *** **β
0.143β€Β· β ββ¬βββββββ¬βββββββ¬βββββββ¬βββββββ¬β
0.126β€Β· β 1.0 3.2 5.5 7.8 10.0
β Β· βtest_loss/std iter
0.109β€ Β· β test_loss/max
0.092β€ Β· β βββββββββββββββββββββββββββββββββββ
β Β· β0.143β€+ β
0.075β€ Β· β0.126β€ ++ β
0.058β€ Β·Β·Β· β0.092β€ ++ β
β Β·Β·Β·Β·Β·Β·Β· β0.075β€ +++ β
0.041β€ Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·β0.041β€ +++++++++++++++++++++++++β
ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
1.0 3.2 5.5 7.8 10.0 1.0 3.2 5.5 7.8 10.0
test_loss/mean iter test_loss/max iter
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/test_loss.txt
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
0.143β€ ++ test_loss/max β
β -- test_loss/min β
β Β·Β· test_loss/mean β
β ββ test_loss β
0.126β€ β β
β β β
β β β
β β β
0.109β€ β β
β β β
β β β
β β β
0.092β€ β β
β β β
β β β
β ββ β
β ββ β
0.075β€ ββ β
β ββ β
β ββ β
β βββ β
0.058β€ βββββ β
β βββββ β
β ββββββ β
β ββββββββββββΒ·Β·Β·Β·Β·Β·Β·Β· β
0.041β€ βββββββββββββββββββββββββββββββββ
ββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬β
1.0 3.2 5.5 7.8 10.0
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/test_loss_summary.txt
test_loss/mean hist test_loss/max hist
βββββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββββ
6β€ββββ β6β€ββββ β
5β€ββββ β5β€ββββ β
βββββ β βββββ β
4β€ββββ β4β€ββββ β
3β€ββββ β3β€ββββ β
βββββ β βββββ β
2β€ββββββββ β2β€ββββββββ β
1β€ββββββββ ββββ βββββ1β€ββββββββ ββββ βββββ
βββββββββ ββββ βββββ βββββββββ ββββ βββββ
0β€βββββββ ββββ βββββ0β€βββββββ ββββ βββββ
ββ¬βββββββββ¬βββββββββ¬βββββββββ¬βββββββββ¬β ββ¬βββββββββ¬βββββββββ¬βββββββββ¬βββββββββ¬β
0.037 0.064 0.092 0.120 0.148 0.037 0.064 0.092 0.120 0.148
test_loss/min hist test_loss/std hist
βββββββββββββββββββββββββββββββββββββββ ββββββββββββββββββββββββββββββββββββ
6β€ββββ β5.00β€ββββ β
βββββ β βββββ β
5β€ββββ β4.17β€ββββ β
4β€ββββ β3.33β€ββββ β
βββββ β βββββ β
3β€ββββ β2.50β€ββββ β
βββββ β βββββ ββββ βββ β
2β€ββββββββ β1.67β€ββββ ββββ βββ β
1β€ββββββββ ββββ βββββ0.83β€ββββ ββββ βββ βββββ
βββββββββ ββββ βββββ βββββ ββββ βββ βββββ
0β€βββββββ ββββ βββββ0.00β€βββ βββ βββ βββββ
ββ¬βββββββββ¬βββββββββ¬βββββββββ¬βββββββββ¬β ββ¬βββββββββββββββββ¬ββββββββ¬βββββββββ
0.037 0.064 0.092 0.120 0.148 -0.0000010 0.0000108 0.0000167
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/test_loss_hist.txt
train_loss train_loss/min
βββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββ
0.597β€β β0.597β€- β
0.509β€β β0.421β€ -- β
β β β0.244β€ -- β
0.421β€ β β0.068β€ ----------------------------β
0.333β€ β β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
0.244β€ β β 1.0 3.2 5.5 7.8 10.0
β β βtrain_loss/min iter
0.156β€ βββββ β train_loss/std
0.068β€ βββββββββββββββββββββββββββ ββββββββββββββββββββββββββββββββ
ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β0.000173β€* β
1.0 3.2 5.5 7.8 10.0 0.000144β€* β
train_loss iter 0.000086β€ * β
train_loss/mean 0.000058β€ * ******** β
βββββββββββββββββββββββββββββββββββ0.000000β€ **************** ***β
0.597β€Β· β ββ¬βββββββ¬ββββββββ¬βββββββ¬βββββββ¬β
0.509β€Β· β 1.0 3.2 5.5 7.8 10.0
β Β· βtrain_loss/std iter
0.421β€ Β· β train_loss/max
0.333β€ Β· β βββββββββββββββββββββββββββββββββββ
β Β· β0.597β€+ β
0.244β€ Β· β0.509β€ + β
0.156β€ Β· β0.333β€ + β
β Β·Β·Β·Β·Β·Β·Β· β0.244β€ ++ β
0.068β€ Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·Β·β0.068β€ ++++++++++++++++++++++++++++β
ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β ββ¬ββββββββ¬ββββββββ¬ββββββββ¬ββββββββ¬β
1.0 3.2 5.5 7.8 10.0 1.0 3.2 5.5 7.8 10.0
train_loss/mean iter train_loss/max iter
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/train_loss.txt
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
0.597β€ ++ train_loss/max β
β -- train_loss/min β
β Β·Β· train_loss/mean β
β ββ train_loss β
0.509β€ β β
β ββ β
β β β
β β β
0.421β€ β β
β β β
β ββ β
β β β
0.333β€ β β
β β β
β β β
β ββ β
β β β
0.244β€ β β
β β β
β β β
β ββ β
0.156β€ ββββ β
β ββββ β
β βββββββ β
β Β·Β·Β·βββββββββββββΒ·Β·Β·Β·Β·Β·Β·Β· β
0.068β€ ββββββββββββββββββββββββββββββββββββββββββ
ββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬ββββββββββββββββββ¬β
1.0 3.2 5.5 7.8 10.0
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/train_loss_summary.txt
train_loss/mean hist train_loss/max hist
βββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββ
8.0β€ββββ β8.0β€ββββ β
6.7β€ββββ β6.7β€ββββ β
βββββ β βββββ β
5.3β€ββββ β5.3β€ββββ β
4.0β€ββββ β4.0β€ββββ β
βββββ β βββββ β
2.7β€ββββ β2.7β€ββββ β
1.3β€ββββ β1.3β€ββββ β
βββββ ββββ βββββ βββββ ββββ βββββ
0.0β€βββ βββ βββββ0.0β€βββ βββ βββββ
ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β
0.04 0.19 0.33 0.48 0.62 0.04 0.19 0.33 0.48 0.62
train_loss/min hist train_loss/std hist
βββββββββββββββββββββββββββββββββββββ βββββββββββββββββββββββββββββββββββββββ
8.0β€ββββ β6β€ββββ β
βββββ β βββββ β
6.7β€ββββ β5β€ββββ β
5.3β€ββββ β4β€ββββ β
βββββ β βββββ β
4.0β€ββββ β3β€ββββββββ β
βββββ β βββββββββ β
2.7β€ββββ β2β€ββββββββ β
1.3β€ββββ β1β€ββββββββ βββββ
βββββ ββββ βββββ βββββββββ βββββ
0.0β€βββ βββ βββββ0β€βββββββ βββββ
ββ¬βββββββββ¬ββββββββ¬βββββββββ¬ββββββββ¬β ββ¬ββββββββββββββββββ¬βββββββββ¬ββββββββββ
0.04 0.19 0.33 0.48 0.62 -0.000008 0.000086 0.000133
text saved in /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/plots/tplot/train_loss_hist.txt
[2025-12-31 12:22:03,182749][W][ezpz/history:2320:save_dataset] Unable to save dataset to W&B, skipping!
[2025-12-31 12:22:03,184704][I][utils/__init__:651:dataset_to_h5pyfile] Saving dataset to: /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/train_dataset.h5
[2025-12-31 12:22:03,196685][I][ezpz/history:2433:finalize] Saving history report to /lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/outputs/ezpz-fsdp/2025-12-31-122159/report.md
[2025-12-31 12:22:03,202017][I][examples/fsdp:360:fsdp_main] dataset=<xarray.Dataset> Size: 2kB
Dimensions: (draw: 10)
Coordinates:
* draw (draw) int64 80B 0 1 2 3 4 5 6 7 8 9
Data variables: (12/25)
epoch (draw) int64 80B 1 2 3 4 5 6 7 8 9 10
dt (draw) float64 80B 12.49 0.3612 0.3596 ... 0.36 0.3586
train_loss (draw) float32 40B 0.5967 0.1744 0.1205 ... 0.0694 0.06835
test_loss (draw) float32 40B 0.1435 0.08036 ... 0.0412 0.04194
test_acc (draw) float32 40B 95.56 97.51 98.02 ... 98.55 98.57 98.57
epoch_mean (draw) float64 80B 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0
... ...
test_loss_min (draw) float64 80B 0.1435 0.08036 ... 0.0412 0.04194
test_loss_std (draw) float64 80B 0.0 2.158e-05 ... 1.079e-05 0.0
test_acc_mean (draw) float64 80B 95.56 97.51 98.02 ... 98.55 98.57 98.57
test_acc_max (draw) float64 80B 95.56 97.51 98.02 ... 98.55 98.57 98.57
test_acc_min (draw) float64 80B 95.56 97.51 98.02 ... 98.55 98.57 98.57
test_acc_std (draw) float64 80B 0.0 0.0 0.0 0.03125 ... 0.0 0.0 0.0 0.0
[2025-12-31 12:22:03,205311][I][examples/fsdp:452:<module>] Took 36.24 seconds
wandb:
wandb: π View run vivid-glade-86 at:
wandb: Find logs at: ../../../../../../lus/tegu/projects/datascience/foremans/projects/saforem2/ezpz/wandb/run-20251231_122128-11cqdt05/logs
[2025-12-31 12:22:04,704632][I][ezpz/launch:447:launch] ----[π ezpz.launch][stop][2025-12-31-122204]----
[2025-12-31 12:22:04,705324][I][ezpz/launch:448:launch] Execution finished with 0.
[2025-12-31 12:22:04,705724][I][ezpz/launch:449:launch] Executing finished in 42.32 seconds.
[2025-12-31 12:22:04,706075][I][ezpz/launch:450:launch] Took 42.32 seconds to run. Exiting.