ezpz.data.distributed
ezpz/data/distributed.py
TPBroadcastDataLoader
Wrapper that ensures only TP leader samples/loads, then broadcasts each batch to other TP ranks.
Source code in src/ezpz/data/distributed.py
```python
class TPBroadcastDataLoader:
    """
    Wrapper that ensures only TP leader samples/loads, then broadcasts
    each batch to other TP ranks.
    """

    def __init__(
        self, dl: DataLoader, tp_group: torch.distributed.ProcessGroup
    ):
        self.dl = dl
        self.tp_group = tp_group
        self.leader = _tp_is_leader(tp_group)

    def __iter__(self) -> Iterator:
        it: Iterable = iter(self.dl) if self.leader else range(len(self.dl))
        # Non-leaders iterate a dummy range to keep step counts aligned
        for maybe_batch in it:
            batch = maybe_batch if self.leader else None
            batch = _broadcast_batch(batch, self.tp_group)
            yield batch

    def __len__(self) -> int:
        return len(self.dl)
```
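For context, here is a minimal usage sketch. The TP-group construction (every 2 consecutive ranks forming one TP group), the toy dataset, and the import path are assumptions for illustration only; in practice you would pass whatever TP process group your parallelism setup already provides.

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset

from ezpz.data.distributed import TPBroadcastDataLoader  # assumed import path

# Assumed layout for this sketch: every 2 consecutive ranks form one TP group.
dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
tp_size = 2  # assumed TP degree
tp_groups = [
    dist.new_group(list(range(start, start + tp_size)))
    for start in range(0, world, tp_size)
]
tp_group = tp_groups[rank // tp_size]

# Any DataLoader works; a toy TensorDataset stands in for real data here.
loader = DataLoader(TensorDataset(torch.arange(64).view(16, 4)), batch_size=4)
loader = TPBroadcastDataLoader(loader, tp_group)  # only the TP leader reads data

for batch in loader:
    ...  # every rank in the TP group receives the same broadcast batch
```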
get_random_dataset_fsdp_tp(batch_size, vocab_size, seq_length, *, num_workers=0, pin_memory=True, dp_group=None, tp_group=None, broadcast_within_tp=False, drop_last=True, seed=1337)
Build dataset/sampler/dataloader for FSDP (DP) + Tensor Parallel (TP).
Key idea:
- Shard the dataset ONLY across the DP group (FSDP replica group).
- Optionally broadcast each batch within TP so only TP-leader does I/O.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dp_group` | `Optional[ProcessGroup]` | Process group that defines FSDP data-parallel replicas. | `None` |
| `tp_group` | `Optional[ProcessGroup]` | Process group that defines the tensor parallel group. | `None` |
| `broadcast_within_tp` | `bool` | If True, the TP leader loads and broadcasts batches. | `False` |
| `drop_last` | `bool` | Prefer True for static shapes across DP replicas. | `True` |
| `seed` | `int` | Base seed for shuffling (add the epoch to this per epoch). | `1337` |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | dict with 'dataset', 'sampler', 'dataloader' |
Source code in src/ezpz/data/distributed.py
```python
def get_random_dataset_fsdp_tp(
    batch_size: int,
    vocab_size: int,
    seq_length: int,
    *,
    num_workers: int = 0,
    pin_memory: bool = True,
    dp_group: Optional[torch.distributed.ProcessGroup] = None,
    tp_group: Optional[torch.distributed.ProcessGroup] = None,
    broadcast_within_tp: bool = False,
    drop_last: bool = True,
    seed: int = 1337,
) -> Dict[str, Any]:
    """
    Build dataset/sampler/dataloader for FSDP (DP) + Tensor Parallel (TP).

    Key idea:
    - Shard the dataset ONLY across the **DP group** (FSDP replica group).
    - Optionally broadcast each batch within TP so only TP-leader does I/O.

    Args:
        dp_group: Process group that defines FSDP data-parallel replicas.
        tp_group: Process group that defines tensor parallel group.
        broadcast_within_tp: If True, TP leader loads and broadcasts batches.
        drop_last: Prefer True for static shapes across DP replicas.
        seed: Base seed for shuffling (per-epoch add epoch to this).

    Returns:
        dict with 'dataset', 'sampler', 'dataloader'
    """
    from ezpz.data.text import RandomTokenDataset

    dset = RandomTokenDataset(vocab_size=vocab_size, seq_length=seq_length)

    use_dist = _is_dist()
    sampler = None
    if use_dist:
        # Determine DP rank/world_size; TP is ignored by the sampler.
        dp_rank, dp_world = _rank_ws(dp_group)
        # Important: num_replicas/rank are DP-based, not global.
        sampler = DistributedSampler(
            dset,
            num_replicas=dp_world,
            rank=dp_rank,
            shuffle=True,
            drop_last=drop_last,
            seed=seed,
        )

    dl = DataLoader(
        dset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=(sampler is None),  # never shuffle when a sampler is provided
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=(num_workers > 0),
    )

    if use_dist and broadcast_within_tp and tp_group is not None:
        dl = TPBroadcastDataLoader(dl, tp_group)

    return {
        "dataset": dset,
        "sampler": sampler,
        "dataloader": dl,
    }
```
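To show how the returned dict might be consumed, here is a minimal training-loop sketch. The 2-D device-mesh layout (4 DP replicas × 2 TP ranks) and all sizes are assumptions for illustration; the per-epoch `set_epoch` call is the standard `DistributedSampler` pattern that matches the `seed` note above.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from ezpz.data.distributed import get_random_dataset_fsdp_tp  # assumed import path

# Assumed 2-D layout: 4 DP replicas x 2 TP ranks (8 processes total).
dist.init_process_group(backend="nccl")
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

data = get_random_dataset_fsdp_tp(
    batch_size=8,
    vocab_size=32_000,
    seq_length=1024,
    dp_group=mesh.get_group("dp"),
    tp_group=mesh.get_group("tp"),
    broadcast_within_tp=True,  # TP leader loads, then broadcasts within its group
)

for epoch in range(3):
    if data["sampler"] is not None:
        # Reshuffle each epoch; the sampler combines this with the base seed.
        data["sampler"].set_epoch(epoch)
    for batch in data["dataloader"]:
        ...  # forward/backward with the FSDP + TP wrapped model
```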