ezpz.examples.fsdp_tp
ezpz/examples/fsdp_tp.py
2D tensor/sequence parallel + FSDP training demo on a Llama-style model.
Sam Foreman 2025-09-08
Modified from: https://pytorch.org/tutorials/intermediate/TP_tutorial.html
This script tests 2D parallelism, combining Tensor/Sequence Parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on an example Llama2 model. It demonstrates an end-to-end working flow: forward pass, backward pass, and optimization.
Fully Sharded Data Parallel and Tensor Parallel are applied along separate parallel dimensions: Data Parallel ("dp") across hosts, Tensor Parallel ("tp") within each host.
We use a simple diagram to illustrate below:
+-----+-----+-----+-----+
|  0  |  1  |  2  |  3  |
+-----+-----+-----+-----+
|  4  |  5  |  6  |  7  |
+-----+-----+-----+-----+
|  8  |  9  | 10  | 11  |
+-----+-----+-----+-----+
+----------+  +------------+  +----------+  +------------+
| Host 1   |  | Host 2     |  |          |  | Host N     |
| 8 GPUs   |  | 8 GPUs     |  |          |  | 8 GPUs     |
|          |  |            |  |   ...    |  |            |
| (TP)     |  | (TP)       |  |          |  | (TP)       |
|[0,1,..,7]|  | [8,9..,15] |  |          |  | [8N-8,8N-7 |
|          |  |            |  |          |  |  .., 8N-1] |
+----------+  +------------+  +----------+  +------------+
- FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
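This layout corresponds directly to a 2D DeviceMesh. Below is a minimal sketch of building one with PyTorch (assuming torch.distributed is initialized across num_hosts * 8 ranks; num_hosts is a hypothetical stand-in for N):

from torch.distributed.device_mesh import init_device_mesh

num_hosts = 4  # hypothetical N from the diagram above
mesh = init_device_mesh("cuda", (num_hosts, 8), mesh_dim_names=("dp", "tp"))
dp_mesh = mesh["dp"]  # spans hosts: e.g. ranks [0, 8, ..., 8N-8]
tp_mesh = mesh["tp"]  # spans GPUs within a host: e.g. ranks [0..7] on Host 1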
Launch with:
ezpz launch -m ezpz.examples.fsdp_tp --tp 2 --batch-size 8
Help output (python3 -m ezpz.examples.fsdp_tp --help):
usage: fsdp_tp.py [-h] [--dim DIM] [--n-layers N_LAYERS] [--n-heads N_HEADS]
[--n-kv-heads N_KV_HEADS] [--multiple-of MULTIPLE_OF]
[--ffn-dim-multiplier FFN_DIM_MULTIPLIER]
[--norm-eps NORM_EPS] [--vocab-size VOCAB_SIZE]
[--seq-length SEQ_LENGTH] [--lr LR] [--epochs EPOCHS]
[--batch-size BATCH_SIZE]
[--test-batch-size TEST_BATCH_SIZE]
[--num-workers NUM_WORKERS] [--seed SEED] [--tp TP]
[--sharding-strategy SHARDING_STRATEGY]
[--max-grad-norm MAX_GRAD_NORM] [--outdir OUTDIR]
[--dataset DATASET] [--tokenizer_name TOKENIZER_NAME]
[--model_name_or_path MODEL_NAME_OR_PATH]
[--hf-split HF_SPLIT] [--hf-text-column HF_TEXT_COLUMN]
[--hf-limit HF_LIMIT] [--seq-len SEQ_LEN]
[--max-seq-len MAX_SEQ_LEN] [--depth-init DEPTH_INIT]
[--fp32]
2D Parallel Training
options:
-h, --help show this help message and exit
--dim DIM
--n-layers N_LAYERS
--n-heads N_HEADS
--n-kv-heads N_KV_HEADS
--multiple-of MULTIPLE_OF
--ffn-dim-multiplier FFN_DIM_MULTIPLIER
--norm-eps NORM_EPS
--vocab-size VOCAB_SIZE
--seq-length SEQ_LENGTH
--lr LR
--epochs EPOCHS
--batch-size BATCH_SIZE
--test-batch-size TEST_BATCH_SIZE
--num-workers NUM_WORKERS
--seed SEED
--tp TP
--sharding-strategy SHARDING_STRATEGY
--max-grad-norm MAX_GRAD_NORM
--outdir OUTDIR
--dataset DATASET
--tokenizer_name TOKENIZER_NAME
--model_name_or_path MODEL_NAME_OR_PATH
--hf-split HF_SPLIT, --hf_split HF_SPLIT
Dataset split to load.
--hf-text-column HF_TEXT_COLUMN, --hf_text_column HF_TEXT_COLUMN
Column containing raw text in the dataset.
--hf-limit HF_LIMIT, --hf_limit HF_LIMIT
Number of rows to sample from the HF dataset for quick
experiments.
--seq-len SEQ_LEN
--max-seq-len MAX_SEQ_LEN
--depth-init DEPTH_INIT
--fp32 Disable mixed precision (use fp32) for debugging NaNs.
The remaining comments outline the parallel layout used to combine TP/SP with FSDP.
main(args)
Entrypoint to set up distributed context and dispatch training.
Source code in src/ezpz/examples/fsdp_tp.py
def main(args: argparse.Namespace) -> int:
"""Entrypoint to set up distributed context and dispatch training."""
rank = ezpz.dist.setup_torch(tensor_parallel_size=args.tp, seed=args.seed)
if rank == 0:
outdir = args.outdir if args.outdir is not None else OUTDIR
else:
outdir = None
outdir = ezpz.dist.broadcast(outdir, root=0)
logger.info(f"Using {outdir=}")
train(args=args, outdir=outdir)
return 0
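The rank-0-decides-then-broadcast pattern above ensures every rank writes to the same output directory. The same idiom works for any picklable object; a minimal sketch (assuming an ezpz distributed context has already been set up; run_id is a hypothetical value):

import ezpz

rank = ezpz.dist.get_rank()
run_id = "2025-09-08-120000" if rank == 0 else None  # only rank 0 decides
run_id = ezpz.dist.broadcast(run_id, root=0)  # all ranks now hold rank 0's value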
parallelize(model, device_mesh, mixed_precision, sharding_strategy=None, device_id=None)
Wrap the model with tensor-parallel and FSDP sharding strategies.
Source code in src/ezpz/examples/fsdp_tp.py
def parallelize(
model: nn.Module,
device_mesh: DeviceMesh,
mixed_precision: Optional[MixedPrecision],
sharding_strategy: Optional[ShardingStrategy | str] = None,
device_id: Optional[torch.device] = None,
) -> nn.Module:
"""Wrap the model with tensor-parallel and FSDP sharding strategies."""
tp_mesh = device_mesh["tp"]
dp_mesh = device_mesh["dp"]
if isinstance(sharding_strategy, str):
sharding_strategy = SHARDING_STRATEGIES.get(sharding_strategy, None)
model.init_weights() # type: ignore
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
# use DTensor as the output
# use_local_output=False,
),
},
)
assert isinstance(model.layers, Iterable)
for _, transformer_block in enumerate(model.layers):
layer_tp_plan = {
"attention_norm": SequenceParallel(),
"attention": PrepareModuleInput(
input_layouts=(Shard(1), None), # type:ignore
desired_input_layouts=(Replicate(), None), # type:ignore
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": PrepareModuleInput(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w3": ColwiseParallel(),
}
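        # each TP rank keeps only its shard of the attention heads,
        # so divide the module's head counts by the TP degree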
attn_layer = transformer_block.attention # type: ignore
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
parallelize_module(
module=transformer_block, # type: ignore
device_mesh=tp_mesh,
parallelize_plan=layer_tp_plan,
)
# from torch.distributed.fsdp import fully_shard
# ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
# ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
# ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
# ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
# ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
sharded_model = FSDP(
model,
mixed_precision=mixed_precision,
device_mesh=dp_mesh,
sharding_strategy=sharding_strategy,
device_id=device_id,
)
logger.info(f"Model after parallelization:\n{sharded_model=}\n")
return sharded_model
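The per-block plan above pairs ColwiseParallel projections (wq/wk/wv, w1/w3) with RowwiseParallel output projections (wo, w2), the classic Megatron-style split. As a smaller, self-contained illustration of the same API (a sketch; the toy module and sizes are hypothetical):

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# assumes 2 ranks, e.g. launched via `torchrun --nproc_per_node=2`
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
toy = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
# "0" (the first Linear) is split along output features, "2" along input
# features, so the intermediate activation stays sharded between them
toy = parallelize_module(toy, tp_mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})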
parse_args(argv=None)
CLI parser for 2D parallel (TP/SP + FSDP) training.
Source code in src/ezpz/examples/fsdp_tp.py
def parse_args(argv: Optional[list[str]] = None):
"""CLI parser for 2D parallel (TP/SP + FSDP) training."""
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(description="2D Parallel Training")
parser.add_argument("--dim", type=int, default=256)
parser.add_argument("--n-layers", type=int, default=32)
parser.add_argument("--n-heads", type=int, default=32)
parser.add_argument("--n-kv-heads", type=int, default=4)
parser.add_argument("--multiple-of", type=int, default=360)
parser.add_argument("--ffn-dim-multiplier", type=float, default=None)
parser.add_argument("--norm-eps", type=float, default=1e-5)
parser.add_argument("--vocab-size", type=int, default=32_000)
parser.add_argument("--seq-length", type=int, default=2048)
parser.add_argument("--lr", type=float, default=3e-3)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument(
"--model",
type=str,
default=None,
choices=sorted(MODEL_PRESETS.keys()),
help="Model size preset (overrides dim/layer defaults)",
)
parser.add_argument("--test-batch-size", type=int, default=1000)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--tp", type=int, default=2)
parser.add_argument("--sharding-strategy", type=str, default="full_shard")
parser.add_argument("--max-grad-norm", type=float, default=1.0)
parser.add_argument("--outdir", type=str, default="outputs/fsdp_tp")
# parser.add_argument('--dataset', type=str, default='random')
parser.add_argument(
"--dataset", type=str, default="eliplutchok/fineweb-small-sample"
)
parser.add_argument(
"--tokenizer_name", type=str, default="meta-llama/llama-2-7b-hf"
)
parser.add_argument(
"--model_name_or_path",
type=str,
default=None,
)
parser.add_argument(
"--hf-split",
"--hf_split",
type=str,
default="train",
help="Dataset split to load.",
)
parser.add_argument(
"--hf-text-column",
"--hf_text_column",
type=str,
default="text",
help="Column containing raw text in the dataset.",
)
parser.add_argument(
"--hf-limit",
"--hf_limit",
type=int,
default=512,
help="Number of rows to sample from the HF dataset for quick experiments.",
)
# parser.add_argument('--max_batch_size', type=int, default=None)
parser.add_argument(
"--seq-len", type=int, default=int(os.environ.get("SEQ_LEN", 1024))
)
parser.add_argument("--max-seq-len", type=int, default=32768)
parser.add_argument("--depth-init", type=bool, default=True)
parser.add_argument(
"--fp32",
action="store_true",
help="Disable mixed precision (use fp32) for debugging NaNs.",
)
# max_batch_size: int = 32
# max_seq_len: int = 32768
# depth_init: bool = True
args = parser.parse_args(argv)
apply_model_preset(args, argv)
return args
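Because parse_args takes an explicit argv list, configurations can also be constructed programmatically, e.g. for a quick smoke test (a minimal sketch):

args = parse_args(["--tp", "2", "--dataset", "random", "--epochs", "1"])
assert args.tp == 2 and args.dataset == "random"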
train(args, outdir)
Run TP/SP + FSDP training and optionally log metrics.
Source code in src/ezpz/examples/fsdp_tp.py
def train(
args: argparse.Namespace,
outdir: Path | str | os.PathLike,
) -> int:
"""Run TP/SP + FSDP training and optionally log metrics."""
world_size = ezpz.dist.get_world_size()
assert world_size % args.tp == 0, "WORLD_SIZE must be divisible by TP"
dpsize = world_size // args.tp
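    # e.g. WORLD_SIZE=8 with --tp 2 -> dpsize=4, i.e. a (4, 2) mesh: 4-way FSDP x 2-way TP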
device_mesh = init_device_mesh(
str(ezpz.get_torch_device()),
(dpsize, args.tp),
mesh_dim_names=("dp", "tp"),
)
logger.info(f"Device mesh created:\n{device_mesh=}")
hf_dataset = None
hf_tokenizer = None
if args.dataset.lower() not in {"mnist", "random"}:
from ezpz.data.hf import get_hf_text_dataset
seed = int(os.environ.get("EZPZ_HF_SAMPLE_SEED", "1337"))
hf_dataset, hf_tokenizer = get_hf_text_dataset(
dataset_name=args.dataset,
split=args.hf_split,
text_column=args.hf_text_column,
tokenizer_name=args.tokenizer_name,
seq_len=args.seq_len,
limit=args.hf_limit,
seed=seed,
)
if hf_tokenizer.vocab_size != args.vocab_size:
logger.warning(
"Overriding vocab_size from %s to tokenizer vocab_size=%s",
args.vocab_size,
hf_tokenizer.vocab_size,
)
args.vocab_size = hf_tokenizer.vocab_size
config = ModelArgs(
dim=args.dim,
n_layers=args.n_layers,
n_heads=args.n_heads,
n_kv_heads=args.n_kv_heads,
batch_size=args.batch_size,
vocab_size=args.vocab_size,
multiple_of=args.multiple_of,
)
logger.info(f"config:\n{config}")
metrics_every = int(os.environ.get("EZPZ_METRICS_EVERY", "1"))
track_logits = os.environ.get("EZPZ_TRACK_LOGITS", "0") == "1"
track_hist = os.environ.get("EZPZ_TRACK_HIST", "0") == "1"
track_act_hist = os.environ.get("EZPZ_TRACK_ACT_HIST", "1") == "1"
hist_bins = int(os.environ.get("EZPZ_HIST_BINS", "64"))
hist_samples = int(os.environ.get("EZPZ_HIST_SAMPLES", "20000"))
dataset_tag = args.dataset.lower().replace("/", "_")
if ezpz.get_rank() == 0 and not os.environ.get("WANDB_DISABLED", False):
run = ezpz.dist.setup_wandb(project_name=WBPROJ_NAME)
if wandb is not None:
assert run is not None and run is wandb.run
from dataclasses import asdict
wandb.config.update(ezpz.get_dist_info())
wandb.config.update(asdict(config)) # type:ignore
device_type = ezpz.dist.get_torch_device_type()
device = (
torch.device("cpu")
if device_type == "cpu"
else torch.device(f"{device_type}:{ezpz.get_local_rank()}")
)
model = Transformer.from_model_args(config)
mstr = summarize_model(
model,
verbose=False,
depth=2,
# input_size=(
# torch.tensor((int(args.batch_size), int(args.seq_length))).to(
# torch.long
# )
# ).shape,
)
logger.info(f"\n{mstr}")
model.to(device)
mp_config: Optional[MixedPrecision] = None
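    # unless --fp32 is given, run parameters in bf16 while reducing gradients in fp32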
if not args.fp32:
mp_config = MixedPrecision(
param_dtype=torch.bfloat16,
cast_forward_inputs=True,
reduce_dtype=torch.float32,
)
model = parallelize(
model,
device_mesh,
mp_config,
sharding_strategy=args.sharding_strategy,
device_id=device,
)
base_model = model
if not hasattr(base_model, "layers"):
base_model = getattr(model, "_fsdp_wrapped_module", model)
act_activations: dict[str, torch.Tensor] = {}
act_handles: list[torch.utils.hooks.RemovableHandle] = []
if track_hist and track_act_hist and ezpz.get_rank() == 0:
hist_layers_spec = os.environ.get(
"EZPZ_HIST_LAYERS", f"0,{config.n_layers - 1}"
)
layer_ids = _parse_hist_layers(hist_layers_spec, config.n_layers)
act_activations, act_handles = _register_activation_hooks(
base_model, layer_ids
)
logger.info(f"Creating optimizer=AdamW with lr={args.lr}")
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, foreach=True)
# reuse device for input placement
tp_group = device_mesh.get_group("tp")
if args.dataset.lower() == "mnist":
data_prefix = Path(os.getcwd()).joinpath(
".cache", "ezpz", "data", f"{args.dataset.lower()}"
)
from ezpz.data.vision import get_mnist
from ezpz.data.distributed import TPBroadcastDataLoader
data = get_mnist(
outdir=Path(data_prefix),
train_batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
num_replicas=dpsize,
rank=device_mesh.get_local_rank("dp"),
pin_memory=True,
num_workers=args.num_workers,
)
dataset = data["dataset"]
sampler = data["sampler"]
dataloader = data["dataloader"]
if args.tp > 1:
dataloader = TPBroadcastDataLoader(dataloader, tp_group)
elif args.dataset.lower() == "random":
from ezpz.data.distributed import get_random_dataset_fsdp_tp
data = get_random_dataset_fsdp_tp(
batch_size=args.batch_size,
vocab_size=args.vocab_size,
seq_length=args.seq_length,
dp_group=device_mesh.get_group("dp"),
tp_group=tp_group,
broadcast_within_tp=True,
drop_last=True,
)
dataset = data["dataset"]
sampler = data["sampler"]
dataloader = data["dataloader"]
# if args.dataset.lower() != "random":
else:
from ezpz.data.distributed import TPBroadcastDataLoader
assert hf_dataset is not None
dataset = hf_dataset
sampler = (
DistributedSampler(
dataset=dataset,
num_replicas=dpsize,
rank=device_mesh.get_local_rank("dp"),
)
if ezpz.get_world_size() > 1
else None
)
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=args.batch_size,
shuffle=(sampler is None),
drop_last=False,
)
if args.tp > 1:
dataloader = TPBroadcastDataLoader(dataloader, tp_group)
# ezpz.breakpoint(0)
logger.info("Starting 2D training...")
model.train()
# outdir = Path(args.outdir).joinpath(ezpz.utils.get_timestamp())
metrics_path = Path(outdir).joinpath(
f"metrics-{ezpz.dist.get_rank()}.jsonl"
)
Path(outdir).mkdir(parents=True, exist_ok=True)
history = ezpz.history.History(
report_dir=outdir,
report_enabled=True,
jsonl_path=metrics_path,
jsonl_overwrite=True,
distributed_history=(
1 < world_size <= 384 # and not config.pytorch_profiler
),
)
# For TP, input needs to be the same across all TP ranks.
# while for SP, input can be different across all ranks
# We will use dp_rank for setting the random seed
# to mimic the behavior of the dataloader
# x = torch.tensor((args.batch_size, args.seq_len))
x = torch.tensor(0)
global_step = 0
for epoch in range(args.epochs):
if sampler is not None:
sampler.set_epoch(epoch)
for idx, batch in enumerate(dataloader):
ezpz.dist.synchronize()
t0 = perf_counter()
attn_mask = None
if isinstance(batch, dict) and "input_ids" in batch:
x = batch["input_ids"]
attn_mask = batch.get("attention_mask")
else:
x = batch
assert isinstance(x, torch.Tensor)
x = x.to(device)
x = x.to(torch.long)
if args.dataset == "random":
inp = x[:, :-1]
labels = x[:, 1:]
else:
inp = x[:, :-1]
labels = x[:, 1:]
inp = inp.to(device)
labels = labels.to(device)
if attn_mask is not None:
attn_mask = attn_mask.to(device)
pred = model(inp)
local_seq_len = pred.shape[1]
if labels.shape[1] != local_seq_len:
labels = _slice_for_sequence_parallel(labels, local_seq_len)
if attn_mask is not None:
if attn_mask.shape[1] > 1:
attn_labels = attn_mask[:, 1:]
else:
attn_labels = attn_mask
if attn_labels.shape[1] != local_seq_len:
attn_labels = _slice_for_sequence_parallel(
attn_labels, local_seq_len
)
labels = labels.clone()
labels[attn_labels == 0] = -100
pad_id = getattr(dataset, "pad_id", None)
if pad_id is not None:
labels = labels.clone()
labels[labels == int(pad_id)] = -100
ezpz.dist.synchronize()
t1 = perf_counter()
tp_mod = getattr(ezpz, "tp", None)
tp_rank = (
getattr(tp_mod, "get_tensor_parallel_rank", lambda: 0)()
if tp_mod is not None
else 0
)
if epoch == 0 and idx == 0:
pred_finite = torch.isfinite(pred)
pred_nonfinite = int((~pred_finite).sum().item())
pred_max = float(pred.abs().max().item())
logger.info(
"pred_stats rank=%s tp=%s shape=%s nonfinite=%s max_abs=%s",
ezpz.get_rank(),
tp_rank,
tuple(pred.shape),
pred_nonfinite,
f"{pred_max:.6f}",
)
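            # flatten (batch, local_seq, vocab) logits and labels;
            # ignore_index=-100 skips masked/padded positions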
loss = F.cross_entropy(
pred.reshape(-1, pred.size(-1)),
labels.reshape(-1),
ignore_index=-100,
)
if epoch == 0 and idx == 0:
valid_labels = int((labels != -100).sum().item())
logger.info(
"loss_inputs rank=%s tp=%s local_seq_len=%s labels=%s valid_labels=%s",
ezpz.get_rank(),
tp_rank,
local_seq_len,
tuple(labels.shape),
valid_labels,
)
# loss = F.cross_entropy(
# pred.flatten(0, 1),
# labels.flatten(0, 1),
# )
# loss = output.loss
optimizer.zero_grad(set_to_none=True)
loss.backward()
grad_norm_preclip = None
if args.max_grad_norm > 0:
grad_norm_preclip = torch.nn.utils.clip_grad_norm_(
model.parameters(), args.max_grad_norm
)
optimizer.step()
ezpz.dist.synchronize()
t2 = perf_counter()
global_step += 1
metrics: dict[str, object] = {
"train/iter": global_step,
"train/epoch": epoch,
"train/bidx": idx,
"train/loss": loss.item(),
"train/dt": t2 - t0,
"train/dtf": t1 - t0,
"train/dtb": t2 - t1,
}
if grad_norm_preclip is not None:
metrics["grad/norm_preclip"] = float(grad_norm_preclip)
if global_step % max(metrics_every, 1) == 0:
metrics.update(_collect_param_grad_stats(model, device))
metrics["opt/iter"] = (global_step,)
metrics["opt/lr"] = float(optimizer.param_groups[0]["lr"])
metrics["input/iter"] = (global_step,)
metrics["input/max"] = float(x.max().item())
metrics["input/min"] = float(x.min().item())
metrics["labels/valid"] = float((labels != -100).sum().item())
if track_logits:
pred_finite = torch.isfinite(pred)
metrics["logits/nonfinite"] = float(
(~pred_finite).sum().item()
)
metrics["logits/max_abs"] = float(pred.abs().max().item())
if track_hist and ezpz.get_rank() == 0:
logits_sample = _sample_tensor_values(pred, hist_samples)
if logits_sample is not None:
logits_hist = _histogram_dict(logits_sample, hist_bins)
if logits_hist is not None:
metrics[f"hist/{dataset_tag}/logits"] = logits_hist
layer_grad_norms = _collect_layer_grad_norms(base_model)
if layer_grad_norms:
layer_grad_hist = _histogram_dict(
torch.tensor(layer_grad_norms), hist_bins
)
if layer_grad_hist is not None:
metrics[
f"hist/{dataset_tag}/grad_norm_per_layer"
] = layer_grad_hist
if track_act_hist and act_activations:
for act_key, act_tensor in act_activations.items():
act_sample = _sample_tensor_values(
act_tensor, hist_samples
)
act_hist = _histogram_dict(act_sample, hist_bins)
if act_hist is not None:
metrics[
f"hist/{dataset_tag}/activations/{act_key}"
] = act_hist
_wandb_log_histograms(
metrics, step=global_step, enabled=track_hist
)
history.update(metrics, summarize=False)
history.log_metrics(
metrics,
logger=logger,
debug_prefixes=("hist/",),
include_summary=True,
rank0_only_summary=True,
)
if epoch == 0 and idx == 0:
logger.info(f"{x.shape}")
if act_handles:
for handle in act_handles:
handle.remove()
ezpz.dist.barrier()
logger.info("Finished 2D training")
if ezpz.get_rank() == 0:
dataset = history.finalize(
run_name=WBRUN_NAME,
dataset_fname="train",
warmup=0.1,
)
logger.info(f"{dataset=}")
return 0
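To make the label-masking logic in the loop above concrete, here is a self-contained sketch of a next-token loss with padding excluded via ignore_index (all tensors are hypothetical stand-ins for x, attention_mask, and model(inp)):

import torch
import torch.nn.functional as F

B, T, V = 2, 8, 100
x = torch.randint(0, V, (B, T))                # token ids
attn_mask = torch.ones(B, T, dtype=torch.long)
attn_mask[1, -3:] = 0                          # last 3 tokens of sample 1 are padding
logits = torch.randn(B, T - 1, V)              # stand-in for model(x[:, :-1])

labels = x[:, 1:].clone()
labels[attn_mask[:, 1:] == 0] = -100           # masked positions contribute no loss
loss = F.cross_entropy(
    logits.reshape(-1, V), labels.reshape(-1), ignore_index=-100
)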