ezpz.distributed⚓︎
The core distributed training module. Provides initialization, device/backend detection, collective operations, and model wrapping for distributed PyTorch training across any supported hardware.
Setup⚓︎
The primary entry point is `setup_torch()`, which bootstraps the entire
distributed environment in a single call and returns the global rank of the process: `rank = ezpz.setup_torch()`.
Advanced Parameters⚓︎
setup_torch() accepts keyword-only parameters for multi-dimensional
parallelism:
```python
import ezpz

rank = ezpz.setup_torch(
    seed=42,
    tensor_parallel_size=2,  # Split model across 2 GPUs per TP group
    pipeline_parallel_size=1,  # No pipeline parallelism
    context_parallel_size=1,  # No context parallelism
)
```
| Parameter | Default | Description |
|---|---|---|
| `port` | `None` | Rendezvous port (auto-detected if not set) |
| `seed` | `None` | Random seed for reproducibility |
| `timeout` | `None` | DDP init timeout in seconds (env `TORCH_DDP_TIMEOUT`, default 3600) |
| `verbose` | `False` | Enable verbose logging during setup |
| `tensor_parallel_size`\* | `1` | Number of ranks per tensor-parallel group |
| `pipeline_parallel_size`\* | `1` | Number of pipeline stages |
| `context_parallel_size`\* | `1` | Number of context-parallel ranks |
| `tensor_parallel_backend`\* | `None` | Backend for TP groups (auto-detected) |
| `pipeline_parallel_backend`\* | `None` | Backend for PP groups (auto-detected) |
| `context_parallel_backend`\* | `None` | Backend for CP groups (auto-detected) |
| `data_parallel_backend`\* | `None` | Backend for DP groups (auto-detected) |
| `device_id`\* | `None` | Override the per-rank device index (defaults to `LOCAL_RANK`). When set, this is the device the process group binds to (`init_process_group(device_id=...)`) and the device `setup_torch` activates before init. On XPU this binding is load-bearing; see Multi-dimensional DeviceMesh for why. |
Parameters marked with * are keyword-only.
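The parallel sizes have to fit inside the world size. As a rough sketch, assuming the usual convention that the ranks left over after tensor, pipeline, and context parallelism form the data-parallel dimension (this page does not spell out ezpz's exact rule), the data-parallel degree is `world_size // (tp * pp * cp)`. `infer_dp_size` is a hypothetical helper, not an ezpz function.

```python
def infer_dp_size(
    world_size: int,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
) -> int:
    """Hypothetical helper: data-parallel degree left after TP/PP/CP.

    Assumes world_size == dp * tp * pp * cp (a common convention,
    not confirmed for ezpz).
    """
    sizes = (tensor_parallel_size, pipeline_parallel_size, context_parallel_size)
    if min(sizes) < 1:
        raise ValueError("parallel sizes must be >= 1")
    model_parallel = sizes[0] * sizes[1] * sizes[2]
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"tp*pp*cp={model_parallel}"
        )
    return world_size // model_parallel


# 8 ranks with tensor_parallel_size=2 leaves 4 data-parallel replicas.
print(infer_dp_size(8, tensor_parallel_size=2))  # 4
```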
Model Wrapping⚓︎
High-level API⚓︎
wrap_model() selects the appropriate wrapping strategy based on the arguments:
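```python
import ezpz
rank = ezpz.setup_torch()
model = MyModel().to(ezpz.get_torch_device())  # MyModel: placeholder nn.Module
use_fsdp = True
if use_fsdp:  # FSDP wrapping (default)
    model = ezpz.wrap_model(model, use_fsdp=True, dtype="bf16")
else:  # DDP wrapping
    model = ezpz.wrap_model(model, use_fsdp=False)
```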
Explicit DDP⚓︎
FSDP (v1)⚓︎
FSDP2⚓︎
FSDP2 uses torch.distributed._composable.fsdp.fully_shard for per-layer
sharding with optional device mesh support:
```python
model = ezpz.wrap_model_for_fsdp2(
    model,
    dtype="bf16",
    device_mesh=my_mesh,  # optional DeviceMesh for multi-dim parallelism
)
```
Multi-dimensional DeviceMesh (XPU-safe)⚓︎
Since v0.18.4: `ezpz.init_device_mesh_safe()` was added in PR #149; see troubleshooting for the underlying xccl `split_group` limitation it works around.
Building a DeviceMesh directly on Aurora/Sunspot (xccl) requires a small
workaround. When the default process group is device-bound, torch's
`DeviceMesh._init_one_process_group` prefers the `split_group` path, but the current
xccl backend reports `supports_splitting=False` and raises `RuntimeError: No backend for the parent process group or its backend does not support splitting`.
`ezpz.init_device_mesh_safe()` is a drop-in replacement for `torch.distributed.init_device_mesh`.
It temporarily clears the default process group's `bound_device_id` for the duration of the call so torch takes the
`new_group(ranks, ...)` fallback (which xccl supports), then restores the
binding so FSDP2's per-device PG resolution still works. It is a no-op on non-XPU devices
(CUDA/NCCL supports `split_group` natively) and when no default process group exists yet.
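```python
import ezpz

rank = ezpz.setup_torch()
world_size = ezpz.get_world_size()
# 1D mesh
mesh = ezpz.init_device_mesh_safe("xpu", (world_size,))
# 2D (dp, tp) mesh; see ezpz.examples.fsdp_tp
tp_size = 2  # example value; must divide world_size
dp_size = world_size // tp_size
mesh = ezpz.init_device_mesh_safe(
    str(ezpz.get_torch_device()),
    (dp_size, tp_size),
    mesh_dim_names=("dp", "tp"),
)
```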
ezpz.wrap_model's auto-created 1D mesh and ezpz.examples.fsdp_tp both
route through this helper, so callers using those paths get the workaround
for free. Reach for init_device_mesh_safe directly when you're building
your own mesh (TP, PP, CP, EP, 2D/3D combinations).
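The save/clear/restore pattern behind `init_device_mesh_safe` can be exercised without torch. In this sketch `FakeProcessGroup` and `with_unbound_default_pg` are hypothetical stand-ins for the default process group and the helper itself, and `build_mesh` plays the role of `torch.distributed.init_device_mesh`. The part worth copying is that `try`/`finally` restores `bound_device_id` even when mesh construction raises.

```python
from typing import Any, Callable, Optional


class FakeProcessGroup:
    """Hypothetical stand-in for torch's default process group."""

    def __init__(self, bound_device_id: Optional[int]) -> None:
        self.bound_device_id = bound_device_id


def with_unbound_default_pg(
    pg: FakeProcessGroup,
    build_mesh: Callable[[], Any],
    device_type: str,
) -> Any:
    """Hypothetical helper: run build_mesh() with the binding cleared.

    Only XPU takes the workaround, and only when the PG is device-bound.
    """
    saved = pg.bound_device_id
    needs_workaround = device_type == "xpu" and saved is not None
    if needs_workaround:
        pg.bound_device_id = None  # steer torch toward new_group(...)
    try:
        return build_mesh()
    finally:
        if needs_workaround:
            pg.bound_device_id = saved  # restore binding for FSDP2


pg = FakeProcessGroup(bound_device_id=0)
# build_mesh reports the binding it sees while the mesh is being built.
seen = with_unbound_default_pg(pg, lambda: pg.bound_device_id, "xpu")
print(seen, pg.bound_device_id)  # None 0
```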
Hostfile Helpers⚓︎
Functions for managing hostfiles used by MPI launchers:
```python
import ezpz

# Read nodes from an existing hostfile
nodes = ezpz.get_nodes_from_hostfile("/path/to/hostfile")
# Find or create a hostfile with fallback logic
hostfile = ezpz.get_hostfile_with_fallback()
# Write the current hostname to a hostfile
ezpz.write_localhost_to_hostfile("/tmp/hostfile")
# Write a list of hosts to a hostfile
ezpz.write_hostfile_from_list_of_hosts(
    ["node1", "node2", "node3"],
    "/tmp/hostfile",
)
```
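The hostfile format these helpers work with is one hostname per line, and `get_nodes_from_hostfile` strips trailing newlines and skips blank lines. A standard-library-only sketch of the write/read round trip, using hypothetical `write_hosts`/`read_hosts` helpers rather than the ezpz functions themselves:

```python
import tempfile
from pathlib import Path
from typing import List


def write_hosts(hosts: List[str], path: Path) -> None:
    """Hypothetical helper: write one hostname per line."""
    path.write_text("".join(f"{h}\n" for h in hosts))


def read_hosts(path: Path) -> List[str]:
    """Hypothetical helper: read hostnames the way get_nodes_from_hostfile
    does (strip the trailing newline, skip blank lines)."""
    with path.open("r") as f:
        return [line.rstrip("\n") for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    hostfile = Path(tmp) / "hostfile"
    write_hosts(["node1", "node2", "node3"], hostfile)
    # A stray blank line at the end is ignored on read.
    with hostfile.open("a") as f:
        f.write("\n")
    nodes = read_hosts(hostfile)

print(nodes)  # ['node1', 'node2', 'node3']
```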
Dtype Map⚓︎
The `TORCH_DTYPES_MAP` dictionary maps string dtype names to `torch.dtype`
objects.
Supported keys: `"bf16"`, `"bfloat16"`, `"fp16"`, `"float16"`, `"half"`,
`"fp32"`, `"float32"`.
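The seven keys collapse onto three dtypes. A torch-free sketch of that alias table, assuming the standard correspondence (`bf16` to `torch.bfloat16`, `fp16`/`half` to `torch.float16`, `fp32` to `torch.float32`). `DTYPE_ALIASES` and `canonical_dtype_name` are hypothetical; with torch installed, `getattr(torch, canonical_dtype_name(name))` would yield the dtype object.

```python
# Hypothetical alias table mirroring the keys of TORCH_DTYPES_MAP.
DTYPE_ALIASES = {
    "bf16": "bfloat16",
    "bfloat16": "bfloat16",
    "fp16": "float16",
    "float16": "float16",
    "half": "float16",
    "fp32": "float32",
    "float32": "float32",
}


def canonical_dtype_name(name: str) -> str:
    """Hypothetical helper: map a dtype string to its torch attribute name."""
    if name not in DTYPE_ALIASES:
        raise KeyError(
            f"Unsupported dtype {name!r}; expected one of {sorted(DTYPE_ALIASES)}"
        )
    return DTYPE_ALIASES[name]


print(canonical_dtype_name("half"))  # float16
```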
Simplified distributed training primitives for ezpz.
This module is a clean rewrite of `ezpz.dist`. It preserves the
same public API surface that the rest of the codebase relies on while
eliminating module-level side effects, dead code, redundant aliases, and
tangled responsibilities.
Design principles:
- No side effects on import -- no env vars mutated, no devices set, no wandb probed until the caller explicitly asks for it.
- Lazy heavy imports -- `mpi4py`, `torch`, and optional deps are imported inside the functions that need them so that `import ezpz.distributed` stays fast.
- Single responsibility per function -- no 200-line monoliths.
- Flat public API -- every symbol listed in `__all__` is a first-class citizen; everything else is prefixed with `_`.
FSDP_SHARDING_STRATEGIES = {'full-shard': True, 'shard-grad-op': False, 'no-shard': None, 'hybrid-shard': 'hybrid'} (module attribute)⚓︎
Map CLI sharding strategy names to reshard_after_forward values.
Used by `wrap_model` and the example CLI parsers:
```text
--fsdp-sharding-strategy full-shard     # reshard_after_forward=True (ZeRO-3)
--fsdp-sharding-strategy shard-grad-op  # reshard_after_forward=False (ZeRO-2)
--fsdp-sharding-strategy hybrid-shard   # reshard to intra-node size
--fsdp-sharding-strategy no-shard       # fall back to DDP
```
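A self-contained sketch of how a CLI parser could wire `--fsdp-sharding-strategy` to these values. It copies the mapping and mimics `resolve_fsdp_strategy`, with one explicit assumption: `gpus_per_node` is passed in rather than detected, so `hybrid-shard` resolves to that integer. The parser and `resolve` are hypothetical, not the ezpz example CLIs.

```python
import argparse
from typing import Optional, Union

# Copied from ezpz.distributed.FSDP_SHARDING_STRATEGIES.
FSDP_SHARDING_STRATEGIES = {
    "full-shard": True,
    "shard-grad-op": False,
    "no-shard": None,
    "hybrid-shard": "hybrid",
}


def resolve(strategy: str, gpus_per_node: int) -> Optional[Union[bool, int]]:
    """Hypothetical mimic of resolve_fsdp_strategy with injected gpus_per_node."""
    val = FSDP_SHARDING_STRATEGIES[strategy]
    if val == "hybrid":
        return gpus_per_node  # reshard within a node
    return val


parser = argparse.ArgumentParser()
parser.add_argument(
    "--fsdp-sharding-strategy",
    choices=list(FSDP_SHARDING_STRATEGIES),
    default="full-shard",
)
args = parser.parse_args(["--fsdp-sharding-strategy", "shard-grad-op"])
reshard = resolve(args.fsdp_sharding_strategy, gpus_per_node=4)
print(reshard)  # False (ZeRO-2 style)
```

A `None` result (from `no-shard`) is the caller's signal to wrap with DDP instead of FSDP.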
TORCH_DTYPES_MAP = {'bf16': None, 'bfloat16': None, 'fp16': None, 'float16': None, 'half': None, 'fp32': None, 'float32': None} (module attribute)⚓︎
Mapping of short dtype names to `torch.dtype` objects.
Populated lazily on first access via `_ensure_dtype_map()`.
all_reduce(obj, op=None, implementation=None)⚓︎
All-reduce obj across all ranks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obj` | `Any` | Numeric value to reduce. | required |
| `op` | `Any` | Reduction operation (defaults to `MPI.SUM`, or `ReduceOp.SUM` with the torch implementation). | `None` |
| `implementation` | `str \| None` | `"mpi"` (default) or `"torch"` (`"pytorch"` and `"pt"` are also accepted). | `None` |
Returns:
| Type | Description |
|---|---|
| `Any` | The reduced value. |
Source code in src/ezpz/distributed.py
```python
def all_reduce(
    obj: Any,
    op: Any = None,
    implementation: str | None = None,
) -> Any:
    """All-reduce *obj* across all ranks.

    Args:
        obj: Numeric value to reduce.
        op: Reduction operation (defaults to ``MPI.SUM``).
        implementation: ``"mpi"`` (default) or ``"torch"``.

    Returns:
        The reduced value.
    """
    impl = (implementation or "mpi").lower()
    if impl == "mpi":
        from mpi4py import MPI

        op = MPI.SUM if op is None else op
        return _get_mpi_comm().allreduce(obj, op=op)
    if impl in {"torch", "pytorch", "pt"}:
        import torch
        import torch.distributed as tdist

        op = tdist.ReduceOp.SUM if op is None else op
        tensor = torch.tensor(obj)
        tdist.all_reduce(tensor, op=op)
        return tensor.item()
    raise ValueError(
        f"Unsupported all_reduce implementation: {implementation}"
    )
```
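barrier(group=None, implementation=None)⚓︎
Barrier across all ranks.
Tries MPI first (fast), falls back to `torch.distributed.barrier`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `group` | `Any` | Optional `torch.distributed` process group (used only by the torch path). | `None` |
| `implementation` | `str \| None` | Force `"mpi"` or `"torch"`. With `"mpi"`, a failed MPI barrier still falls back to torch. | `None` |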
Source code in src/ezpz/distributed.py
```python
def barrier(
    group: Any = None,
    implementation: str | None = None,
) -> None:
    """Barrier across all ranks.

    Tries MPI first (fast), falls back to ``torch.distributed.barrier``.

    Args:
        group: Optional ``torch.distributed`` process group.
        implementation: Force ``"mpi"`` or ``"torch"``.
    """
    if implementation is not None and implementation.lower() not in {
        "mpi",
        "torch",
    }:
        raise ValueError(
            f"Unsupported barrier implementation: {implementation}"
        )
    if implementation is None or implementation.lower() == "mpi":
        try:
            _get_mpi_comm().barrier()
            return
        except Exception:
            if get_rank() == 0:
                logger.warning(
                    "MPI barrier failed; falling back to torch.distributed.barrier"
                )
    # torch fallback
    import torch.distributed

    if torch.distributed.is_initialized():
        kwargs: dict[str, Any] = {}
        if group is not None:
            kwargs["group"] = group
        torch.distributed.barrier(**kwargs)
```
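The control flow of `barrier` (validate, try MPI, fall back to torch) can be isolated with injected callables. In this sketch `barrier_with_fallback` is hypothetical, and `mpi_barrier`/`torch_barrier` stand in for `_get_mpi_comm().barrier()` and `torch.distributed.barrier()`; the `is_initialized()` check and the rank-0 warning are omitted. Note that an explicit `"mpi"` still falls through to torch when the MPI call raises.

```python
from typing import Callable, Optional


def barrier_with_fallback(
    mpi_barrier: Callable[[], None],
    torch_barrier: Callable[[], None],
    implementation: Optional[str] = None,
) -> str:
    """Hypothetical helper: return which path completed, "mpi" or "torch"."""
    impl = implementation.lower() if implementation is not None else None
    if impl is not None and impl not in {"mpi", "torch"}:
        raise ValueError(f"Unsupported barrier implementation: {implementation}")
    if impl in (None, "mpi"):
        try:
            mpi_barrier()
            return "mpi"
        except Exception:
            pass  # ezpz logs a warning on rank 0 here
    torch_barrier()
    return "torch"


def broken_mpi() -> None:
    raise RuntimeError("MPI not available")


print(barrier_with_fallback(lambda: None, lambda: None))  # mpi
print(barrier_with_fallback(broken_mpi, lambda: None))  # torch
```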
broadcast(obj, root=0)⚓︎
cleanup()⚓︎
Destroy the `torch.distributed` process group if active. On rank 0 it also logs the active wandb run name and URL, if there is one.
Source code in src/ezpz/distributed.py
```python
def cleanup() -> None:
    """Destroy the ``torch.distributed`` process group if active."""
    import torch.distributed

    if get_rank() == 0 and verify_wandb():
        try:
            import wandb  # noqa: F811

            if wandb.run is not None and not getattr(wandb.run, "disabled", False):
                logger.info("wandb.run=[%s](%s)", wandb.run.name, wandb.run.url)
        except Exception:
            pass
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```
get_cpus_per_node()⚓︎
get_device_properties(device=None)⚓︎
Return device properties (`name`, and `total_memory` in bytes or -1 when unavailable) as a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `int \| None` | Device index. Defaults to `get_local_rank()`. | `None` |
Source code in src/ezpz/distributed.py
```python
def get_device_properties(device: int | None = None) -> dict[str, Any]:
    """Return device properties as a dictionary.

    Args:
        device: Device index. Defaults to ``get_local_rank()``.
    """
    import torch

    device_type = get_torch_device_type()
    idx = device if device is not None else get_local_rank()
    if device_type == "cuda":
        props = torch.cuda.get_device_properties(idx)
        # torch.cuda.DeviceProperties uses `total_memory` (in bytes),
        # not `total_mem`. The old name silently AttributeError'd to a
        # plain `props.total_mem` call that doesn't exist.
        return {"name": props.name, "total_memory": props.total_memory}
    if device_type == "xpu" and hasattr(torch, "xpu"):
        props = torch.xpu.get_device_properties(idx)
        return {
            "name": props.name,
            "total_memory": getattr(props, "total_memory", -1),
        }
    return {"name": device_type, "total_memory": -1}
```
get_dist_info(verbose=None, hostfile=None)⚓︎
Return a dictionary summarising the distributed environment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `verbose` | `bool \| None` | If `True`, log the info dict as formatted JSON. | `None` |
| `hostfile` | `str \| PathLike \| None` | Explicit hostfile path. | `None` |
Source code in src/ezpz/distributed.py
```python
def get_dist_info(
    verbose: bool | None = None,
    hostfile: str | os.PathLike | None = None,
) -> dict[str, Any]:
    """Return a dictionary summarising the distributed environment.

    Args:
        verbose: If ``True``, log the info dict as formatted JSON.
        hostfile: Explicit hostfile path.
    """
    import sys

    from ezpz.configs import get_scheduler

    hfp = (
        get_hostfile_with_fallback(hostfile)
        if hostfile is None
        else Path(hostfile)
    )
    if hfp is not None and Path(hfp).is_file():
        hosts = get_nodes_from_hostfile(hfp)
        hostfile_path = Path(hfp).resolve().as_posix()
    else:
        hosts = [get_hostname()]
        hostfile_path = str(hfp) if hfp is not None else ""
    num_nodes = len(hosts)
    gpus = get_gpus_per_node()
    info: dict[str, Any] = {}
    info.update(
        {
            "DEVICE": get_torch_device(),
            "DEVICE_ID": f"{get_torch_device()}:{get_local_rank()}",
            "DISTRIBUTED_BACKEND": get_torch_backend(),
            "GPUS_PER_NODE": gpus,
            "HOSTS": str(hosts),
            "HOSTFILE": hostfile_path,
            "HOSTNAME": get_hostname(),
            "LOCAL_RANK": get_local_rank(),
            "MACHINE": get_machine(),
            "NUM_NODES": num_nodes,
            "NGPUS": num_nodes * gpus,
            "NGPUS_AVAILABLE": get_world_size_total(),
            "NODE_ID": get_node_index(),
            "RANK": get_rank(),
            "SCHEDULER": get_scheduler(),
            "WORLD_SIZE_TOTAL": get_world_size_total(),
            "WORLD_SIZE_IN_USE": get_world_size_in_use(),
            "world_size": get_world_size(),
            "EZPZ_RUN_COMMAND": os.environ.get(
                "EZPZ_RUN_COMMAND", sys.argv[0]
            ),
        }
    )
    if verbose:
        import json

        logger.info("DistInfo=%s", json.dumps(info, indent=4, sort_keys=True))
    return info
```
get_gpus_per_node()⚓︎
Return the number of accelerators on the local node.
Prefers environment variables (`NGPU_PER_HOST`, `LOCAL_WORLD_SIZE`,
`PMI_LOCAL_SIZE`, `SLURM_NTASKS_PER_NODE`), then falls back to
`torch.{cuda,xpu}.device_count()`. On MPS it returns the in-use world size, and with no accelerator it returns 0.
Source code in src/ezpz/distributed.py
```python
def get_gpus_per_node() -> int:
    """Return the number of accelerators on the local node.

    Prefers environment variables (``NGPU_PER_HOST``, ``LOCAL_WORLD_SIZE``,
    ``PMI_LOCAL_SIZE``, ``SLURM_NTASKS_PER_NODE``) then falls back to
    ``torch.{cuda,xpu}.device_count()``.
    """
    for var in (
        "NGPU_PER_HOST",
        "LOCAL_WORLD_SIZE",
        "PMI_LOCAL_SIZE",
        "SLURM_NTASKS_PER_NODE",
    ):
        val = os.environ.get(var)
        if val is not None and val != "":
            return int(val)
    import torch

    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    if torch.backends.mps.is_available():
        return get_world_size_in_use()
    return 0
```
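The environment-variable precedence can be checked without any hardware by passing the environment in as a mapping. `gpus_per_node_from_env` is a hypothetical, torch-free reduction of `get_gpus_per_node`: it returns `None` where the real function would go on to query torch.

```python
from typing import Mapping, Optional

# Same variables, same order as get_gpus_per_node.
_GPU_ENV_VARS = (
    "NGPU_PER_HOST",
    "LOCAL_WORLD_SIZE",
    "PMI_LOCAL_SIZE",
    "SLURM_NTASKS_PER_NODE",
)


def gpus_per_node_from_env(env: Mapping[str, str]) -> Optional[int]:
    """Hypothetical helper: first non-empty variable wins; None means 'ask torch'."""
    for var in _GPU_ENV_VARS:
        val = env.get(var)
        if val is not None and val != "":
            return int(val)
    return None


env = {"LOCAL_WORLD_SIZE": "4", "PMI_LOCAL_SIZE": "12"}
print(gpus_per_node_from_env(env))  # 4, since LOCAL_WORLD_SIZE is checked first
```

Apart from `NGPU_PER_HOST`, these variables count processes launched per node rather than physical devices, so the result reflects the launch layout.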
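get_hostfile_with_fallback(hostfile=None)⚓︎
Locate (or create) a usable hostfile.
Checks SLURM, an explicit path, the `PBS_NODEFILE`/`HOSTFILE` environment variables, and PBS. Under SLURM the hostfile is always generated from the allocation, even when an explicit path is given. As a last resort,
writes the current hostname to `hostfile` in the current working directory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hostfile` | `str \| PathLike \| None` | Explicit path; auto-detected when `None`. | `None` |
Returns:
| Type | Description |
|---|---|
| `Path` | `Path` to the hostfile. |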
Source code in src/ezpz/distributed.py
```python
def get_hostfile_with_fallback(
    hostfile: str | os.PathLike | None = None,
) -> Path:
    """Locate (or create) a usable hostfile.

    Checks PBS, SLURM, and environment variables. As a last resort,
    writes ``localhost`` to a file in the current directory.

    Args:
        hostfile: Explicit path; auto-detected when ``None``.

    Returns:
        :class:`Path` to the hostfile.
    """
    from ezpz.configs import get_scheduler

    scheduler = get_scheduler()
    if scheduler.lower() == "slurm":
        return _make_hostfile_from_slurm()
    if hostfile is not None:
        hfp = Path(hostfile)
        if hfp.is_file():
            return hfp
    # Try standard env vars
    for var in ("PBS_NODEFILE", "HOSTFILE"):
        val = os.environ.get(var)
        if val and Path(val).is_file():
            return Path(val)
    # PBS without env var
    if scheduler == "PBS":
        try:
            import ezpz.pbs

            nodefile = ezpz.pbs.get_pbs_nodefile()
            if nodefile is not None:
                return Path(nodefile)
        except Exception:
            pass
    # Fallback: write localhost
    hfp = Path(os.getcwd()) / "hostfile"
    hfp.touch(exist_ok=True)
    write_localhost_to_hostfile(hfp)
    return hfp
```
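get_hostname()⚓︎
Return the hostname of the current machine (lowercased), preferring the name from a reverse lookup and falling back to `HOSTNAME`/`HOST`, `platform.node()`, and finally `"localhost"`.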
Source code in src/ezpz/distributed.py
```python
def get_hostname() -> str:
    """Return the hostname of the current machine (lowercased)."""
    try:
        name = socket.gethostname()
        if name:
            try:
                return socket.gethostbyaddr(name)[0].strip().lower()
            except OSError:
                return name.strip().lower()
    except Exception:
        pass
    for var in ("HOSTNAME", "HOST"):
        val = os.environ.get(var)
        if val is not None and val != "":
            return val.strip().lower()
    try:
        import platform

        node = platform.node()
        if node:
            return node.strip().lower()
    except Exception:
        pass
    return "localhost"
```
get_local_rank()⚓︎
Return the local rank (GPU index) of the current process on its node.
The value is resolved from well-known environment variables set by
common MPI implementations and job schedulers. If none are set, the
function falls back to `get_rank() % get_gpus_per_node()` (or 0 for a single-process run).
Returns:
| Type | Description |
|---|---|
| `int` | Local rank (0-indexed within the node). |
Source code in src/ezpz/distributed.py
```python
def get_local_rank() -> int:
    """Return the local rank (GPU index) of the current process on its node.

    The value is resolved from well-known environment variables set by
    common MPI implementations and job schedulers. If none are set the
    function falls back to ``get_rank() % get_gpus_per_node()``.

    Returns:
        Local rank (0-indexed within the node).
    """
    _ENV_VARS = (
        "LOCAL_RANK",
        "PMI_LOCAL_RANK",
        "OMPI_COMM_WORLD_LOCAL_RANK",
        "MPI_LOCALRANKID",
        "MPICH_LOCALRANKID",
        "SLURM_LOCAL_ID",
    )
    for var in _ENV_VARS:
        val = os.environ.get(var)
        if val is not None and val != "":
            return int(val)
    ws = get_world_size()
    if ws <= 1:
        return 0
    gpn = get_gpus_per_node()
    return get_rank() % gpn if gpn > 0 else 0
```
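When no launcher variable is set, the local rank is derived from the global rank. This sketch reproduces the fallback arithmetic with explicit inputs (`local_rank_fallback` is hypothetical) and assumes ranks are packed node by node, so rank `r` on nodes with `g` accelerators lands on local index `r % g`.

```python
def local_rank_fallback(rank: int, world_size: int, gpus_per_node: int) -> int:
    """Hypothetical helper: get_local_rank's fallback when no env var is set."""
    if world_size <= 1:
        return 0
    return rank % gpus_per_node if gpus_per_node > 0 else 0


# 2 nodes x 4 GPUs: global ranks 0..7 map to local ranks 0..3 twice.
print([local_rank_fallback(r, 8, 4) for r in range(8)])
# [0, 1, 2, 3, 0, 1, 2, 3]
```

If a launcher placed ranks round-robin across nodes instead, the modulo guess would not match the physical layout, which is why the launcher-provided variables are worth setting when available.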
get_machine(hostname=None)⚓︎
Identify the ALCF / HPC machine from its hostname prefix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hostname` | `str \| None` | Override; auto-detected when `None`. | `None` |
Returns:
| Type | Description |
|---|---|
| `str` | A human-readable machine name (e.g. `"Polaris"`, `"Aurora"`), or the hostname itself when no prefix matches. |
Source code in src/ezpz/distributed.py
```python
def get_machine(hostname: str | None = None) -> str:
    """Identify the ALCF / HPC machine from its hostname prefix.

    Args:
        hostname: Override; auto-detected when ``None``.

    Returns:
        A human-readable machine name (e.g. ``"Polaris"``, ``"Aurora"``).
    """
    if hostname is None:
        hostname = get_hostname()
    _PREFIX_MAP = (
        ("frontier", "Frontier"),
        ("sophia", "Sophia"),
        ("theta", "ThetaGPU"),
        ("sunspot", "SunSpot"),
        ("x1", "SunSpot"),
        ("aurora", "Aurora"),
        ("x4", "Aurora"),
        ("login", "Perlmutter"),
        ("nid", "Perlmutter"),
    )
    for prefix, name in _PREFIX_MAP:
        if hostname.startswith(prefix):
            return name
    if hostname.startswith("x3"):
        return "Sirius" if "sirius" in hostname else "Polaris"
    return hostname
```
get_node_index()⚓︎
get_nodes_from_hostfile(hostfile)⚓︎
Read hostnames from `hostfile`, one per line, skipping blank lines. If the file does not exist, returns `[get_hostname()]`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hostfile` | `str \| PathLike` | Path to the hostfile. | required |
Returns:
| Type | Description |
|---|---|
| `list[str]` | List of hostnames. |
Source code in src/ezpz/distributed.py
```python
def get_nodes_from_hostfile(hostfile: str | os.PathLike) -> list[str]:
    """Read hostnames from *hostfile*, one per line.

    Args:
        hostfile: Path to the hostfile.

    Returns:
        List of hostnames.
    """
    fpath = Path(hostfile)
    if not fpath.is_file():
        return [get_hostname()]
    with fpath.open("r") as f:
        return [line.rstrip("\n") for line in f if line.strip()]
```
get_num_nodes(hostfile=None)⚓︎
Return the number of nodes in the current allocation.
Checks `SLURM_NNODES` first, then counts lines in `hostfile`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hostfile` | `str \| PathLike \| None` | Explicit path; resolved automatically when `None`. | `None` |
Source code in src/ezpz/distributed.py
```python
def get_num_nodes(hostfile: str | os.PathLike | None = None) -> int:
    """Return the number of nodes in the current allocation.

    Checks ``SLURM_NNODES`` first, then counts lines in *hostfile*.

    Args:
        hostfile: Explicit path; resolved automatically when ``None``.
    """
    slurm_nnodes = os.environ.get("SLURM_NNODES")
    if slurm_nnodes is not None:
        return int(slurm_nnodes)
    hfp = get_hostfile_with_fallback(hostfile)
    hosts = [h.split(".")[0] for h in get_nodes_from_hostfile(hfp)]
    return len(hosts)
```
get_rank()⚓︎
Return the global MPI rank of the current process.
The value is resolved from well-known environment variables set by common MPI implementations and job schedulers. If none are set, the function falls back to querying the MPI communicator, and returns 0 if that fails.
Returns:
| Type | Description |
|---|---|
| `int` | Global rank (0-indexed). |
Source code in src/ezpz/distributed.py
```python
def get_rank() -> int:
    """Return the global MPI rank of the current process.

    The value is resolved from well-known environment variables set by
    common MPI implementations and job schedulers. If none are set the
    function falls back to querying the MPI communicator.

    Returns:
        Global rank (0-indexed).
    """
    _ENV_VARS = (
        "RANK",
        "PMI_RANK",
        "OMPI_COMM_WORLD_RANK",
        "SLURM_PROCID",
    )
    for var in _ENV_VARS:
        val = os.environ.get(var)
        if val is not None and val != "":
            return int(val)
    try:
        return int(_get_mpi_comm().Get_rank())
    except Exception:
        return 0
```
get_torch_backend()⚓︎
Return the appropriate `torch.distributed` backend string.
Checks the `TORCH_BACKEND` env var, then probes hardware availability to
select `nccl`, `xccl` (or `ccl` on XPU when xccl is unavailable), or `gloo`.
Source code in src/ezpz/distributed.py
```python
def get_torch_backend() -> str:
    """Return the appropriate ``torch.distributed`` backend string.

    Checks ``TORCH_BACKEND`` env, then probes hardware availability to
    select ``nccl`` / ``xccl`` / ``gloo``.
    """
    env = os.environ.get("TORCH_BACKEND")
    if env is not None:
        return env
    import torch

    if torch.cuda.is_available() and torch.distributed.is_backend_available(
        "nccl"
    ):
        return "nccl"
    if torch.xpu.is_available():
        if torch.distributed.is_backend_available("xccl"):
            return "xccl"
        return "ccl"
    return "gloo"
```
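The selection order can be written as a pure function of the probes. `pick_backend` is hypothetical; its boolean arguments stand in for `torch.cuda.is_available()`, `torch.xpu.is_available()`, and `torch.distributed.is_backend_available(...)`. One consequence worth noticing: a CUDA machine whose torch build lacks NCCL falls through to `gloo`.

```python
from typing import Optional


def pick_backend(
    env_backend: Optional[str],
    cuda_available: bool,
    nccl_available: bool,
    xpu_available: bool,
    xccl_available: bool,
) -> str:
    """Hypothetical mirror of get_torch_backend's precedence."""
    if env_backend is not None:
        return env_backend  # TORCH_BACKEND wins unconditionally
    if cuda_available and nccl_available:
        return "nccl"
    if xpu_available:
        # Prefer xccl; otherwise return the "ccl" backend name.
        return "xccl" if xccl_available else "ccl"
    return "gloo"


print(pick_backend(None, True, False, False, False))  # gloo
print(pick_backend(None, False, False, True, True))  # xccl
```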
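get_torch_device(*, device_type=None, as_torch_device=None)⚓︎
Return the current accelerator device.
A supported `TORCH_DEVICE` value (which may include an index, e.g. `cuda:1`) takes precedence over `device_type`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device_type` | `str \| None` | Force a specific device type. | `None` |
| `as_torch_device` | `bool \| None` | If `True`, return a `torch.device` object instead of a plain string. | `None` |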
Source code in src/ezpz/distributed.py
```python
def get_torch_device(
    *,
    device_type: str | None = None,
    as_torch_device: bool | None = None,
) -> str | torch.device:
    """Return the current accelerator device.

    Args:
        device_type: Force a specific device type.
        as_torch_device: If ``True``, return a :class:`torch.device` object
            instead of a plain string.
    """
    import torch

    env = os.environ.get("TORCH_DEVICE")
    if env:
        normalized = env.strip().lower()
        base = normalized.split(":", 1)[0]
        if base in _SUPPORTED_DEVICE_TYPES:
            return torch.device(normalized) if as_torch_device else normalized
    dt = device_type if device_type is not None else get_torch_device_type()
    return torch.device(dt) if as_torch_device else dt
```
get_torch_device_type(device_type=None)⚓︎
Return the accelerator type as a string ("cuda", "xpu", …).
Respects the `TORCH_DEVICE` environment variable when set; otherwise probes hardware in the order XPU, CUDA, MPS, and falls back to CPU.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device_type` | `str \| None` | Override value; returned as-is after validation. | `None` |
Source code in src/ezpz/distributed.py
```python
def get_torch_device_type(device_type: str | None = None) -> str:
    """Return the accelerator type as a string (``"cuda"``, ``"xpu"``, …).

    Respects the ``TORCH_DEVICE`` environment variable when set.

    Args:
        device_type: Override value; returned as-is after validation.
    """
    if device_type is not None:
        if device_type not in _SUPPORTED_DEVICE_TYPES:
            raise ValueError(
                f"Unsupported device_type={device_type!r}; "
                f"expected one of {sorted(_SUPPORTED_DEVICE_TYPES)}"
            )
        return device_type
    env = os.environ.get("TORCH_DEVICE")
    if env:
        base = env.strip().lower().split(":", 1)[0]
        if base in _SUPPORTED_DEVICE_TYPES:
            return base
        logger.warning(
            "Ignoring unsupported TORCH_DEVICE=%s; expected one of %s",
            env,
            sorted(_SUPPORTED_DEVICE_TYPES),
        )
    import torch

    if torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```
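`get_torch_device_type` resolves in a fixed order: explicit argument, then `TORCH_DEVICE`, then hardware probes with XPU ahead of CUDA. A torch-free sketch that takes the available device types as a list; `resolve_device_type` and `SUPPORTED` are hypothetical, and because `_SUPPORTED_DEVICE_TYPES` is not shown on this page, the contents of `SUPPORTED` are an assumption.

```python
from typing import Mapping, Optional, Sequence

# Assumed contents; ezpz's _SUPPORTED_DEVICE_TYPES is not shown here.
SUPPORTED = {"cpu", "cuda", "mps", "xpu"}
_PROBE_ORDER = ("xpu", "cuda", "mps")  # same order as get_torch_device_type


def resolve_device_type(
    device_type: Optional[str],
    env: Mapping[str, str],
    available: Sequence[str],
) -> str:
    """Hypothetical mirror of get_torch_device_type's precedence."""
    if device_type is not None:
        if device_type not in SUPPORTED:
            raise ValueError(f"Unsupported device_type={device_type!r}")
        return device_type
    raw = env.get("TORCH_DEVICE")
    if raw:
        base = raw.strip().lower().split(":", 1)[0]  # "cuda:1" -> "cuda"
        if base in SUPPORTED:
            return base
        # Unsupported values are ignored (ezpz logs a warning).
    for candidate in _PROBE_ORDER:
        if candidate in available:
            return candidate
    return "cpu"


print(resolve_device_type(None, {"TORCH_DEVICE": "CUDA:1"}, ["xpu"]))  # cuda
print(resolve_device_type(None, {}, ["cuda", "xpu"]))  # xpu
```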
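get_world_size(*, total=False, in_use=False)⚓︎
Return the distributed world size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `total` | `bool` | If `True`, return the total available accelerator count (`num_nodes * gpus_per_node`). | `False` |
| `in_use` | `bool` | If `True`, return `MPI.COMM_WORLD.Get_size()`. | `False` |
Returns:
| Type | Description |
|---|---|
| `int` | World size as an integer. |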
Source code in src/ezpz/distributed.py
```python
def get_world_size(
    *,
    total: bool = False,
    in_use: bool = False,
) -> int:
    """Return the distributed world size.

    Args:
        total: If ``True``, return the *total available* accelerator count
            (``num_nodes * gpus_per_node``).
        in_use: If ``True``, return ``MPI.COMM_WORLD.Get_size()``.

    Returns:
        World size as an integer.
    """
    if total:
        return get_world_size_total()
    if in_use:
        return get_world_size_in_use()
    _ENV_VARS = (
        "WORLD_SIZE",
        "PMI_SIZE",
        "OMPI_COMM_WORLD_SIZE",
        "SLURM_NTASKS",
    )
    for var in _ENV_VARS:
        val = os.environ.get(var)
        if val is not None and val != "":
            return int(val)
    try:
        return int(_get_mpi_comm().Get_size())
    except Exception:
        return 1
```
get_world_size_in_use()⚓︎
get_world_size_total()⚓︎
init_device_mesh_safe(device_type, mesh_shape, *, mesh_dim_names=None)⚓︎
Drop-in replacement for torch.distributed.init_device_mesh.
Works around xccl's missing ProcessGroup.split_group support.
For FSDP2 to route foreach_all_gather correctly on XPU,
_setup_ddp binds the default PG to a device by passing
device_id= to init_process_group. Torch then sees
default_group.bound_device_id is not None and, inside
DeviceMesh._init_one_process_group, prefers
split_group(parent_pg, ...) over new_group(ranks, ...).
On the current xccl backend parent_backend.supports_splitting
is False, so split_group raises
`RuntimeError: No backend for the parent process group or its backend does not support splitting`.
We temporarily clear default_group.bound_device_id for the
duration of the init_device_mesh call so torch takes the
new_group fallback (which xccl supports), then restore it so
FSDP2's per-device PG resolution still works.
No-op on non-xpu devices and when no default PG exists yet.
Source code in src/ezpz/distributed.py
```python
def init_device_mesh_safe(
    device_type: str,
    mesh_shape: tuple[int, ...],
    *,
    mesh_dim_names: tuple[str, ...] | None = None,
) -> Any:
    """Drop-in replacement for ``torch.distributed.init_device_mesh``.

    Works around xccl's missing ``ProcessGroup.split_group`` support.

    For FSDP2 to route ``foreach_all_gather`` correctly on XPU,
    ``_setup_ddp`` binds the default PG to a device by passing
    ``device_id=`` to ``init_process_group``. Torch then sees
    ``default_group.bound_device_id is not None`` and, inside
    ``DeviceMesh._init_one_process_group``, prefers
    ``split_group(parent_pg, ...)`` over ``new_group(ranks, ...)``.
    On the current xccl backend ``parent_backend.supports_splitting``
    is ``False``, so ``split_group`` raises:

        RuntimeError: No backend for the parent process group or its
        backend does not support splitting

    We temporarily clear ``default_group.bound_device_id`` for the
    duration of the ``init_device_mesh`` call so torch takes the
    ``new_group`` fallback (which xccl supports), then restore it so
    FSDP2's per-device PG resolution still works.

    No-op on non-xpu devices and when no default PG exists yet.
    """
    import torch
    from torch.distributed.device_mesh import init_device_mesh as _imd

    default_pg = None
    saved: Any = None
    needs_workaround = False
    if device_type == "xpu" and torch.distributed.is_initialized():
        try:
            default_pg = torch.distributed.distributed_c10d._get_default_group()
            saved = getattr(default_pg, "bound_device_id", None)
            if saved is not None:
                default_pg.bound_device_id = None  # type: ignore[attr-defined]
                needs_workaround = True
        except (AttributeError, RuntimeError):
            needs_workaround = False
    try:
        return _imd(device_type, mesh_shape, mesh_dim_names=mesh_dim_names)
    finally:
        if needs_workaround and default_pg is not None:
            try:
                default_pg.bound_device_id = saved  # type: ignore[attr-defined]
            except AttributeError:
                pass
```
log_dict_as_bulleted_list(d, name=None)⚓︎
Log a dictionary as a bulleted list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `d` | `dict` | Dictionary to format. | required |
| `name` | `str \| None` | Optional label for the header. | `None` |
Source code in src/ezpz/distributed.py
```python
def log_dict_as_bulleted_list(d: dict, name: str | None = None) -> None:
    """Log a dictionary as a bulleted list.

    Args:
        d: Dictionary to format.
        name: Optional label for the header.
    """
    tag = name or getattr(d, "__qualname__", "dict")
    lines = [f"[{tag}]:"] + [f" - {k}={v}" for k, v in d.items()]
    logger.info("\n\n%s\n", "\n".join(lines))
```
print_dist_setup(hostfile=None, display=True)⚓︎
Build (and optionally log) a one-line-per-rank summary string.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hostfile` | `str \| PathLike \| None` | Explicit hostfile path (not used by the current implementation). | `None` |
| `display` | `bool` | If `True`, emit the string via `logger.info`. | `True` |
Returns:
| Type | Description |
|---|---|
| `str` | The formatted summary string. |
Source code in src/ezpz/distributed.py
```python
def print_dist_setup(
    hostfile: str | os.PathLike | None = None,
    display: bool = True,
) -> str:
    """Build (and optionally log) a one-line-per-rank summary string.

    Args:
        hostfile: Explicit hostfile path.
        display: If ``True``, emit the string via :func:`logger.info`.

    Returns:
        The formatted summary string.
    """
    rank = get_rank()
    world_size = get_world_size(in_use=True)
    local_rank = get_local_rank()
    gpn = max(get_gpus_per_node(), 1)
    num_nodes = max(world_size // gpn, 1)
    node = get_node_index()
    device = get_torch_device_type()
    hn = socket.gethostname()
    rw = len(str(max(0, world_size - 1)))
    lw = len(str(max(0, gpn - 1)))
    nnw = len(str(max(0, num_nodes - 1)))
    nw = nnw
    parts = [
        f"['{hn}']",
        f"[{device=}]",
        f"[node={node:>0{nw}d}/{num_nodes - 1:<0{nnw}d}]",
        f"[local_rank={local_rank:>0{lw}d}/{gpn - 1:<0{lw}d}]",
        f"[rank={rank:>0{rw}d}/{world_size - 1:<0{rw}d}]",
    ]
    dist_str = "".join(parts)
    if display:
        logger.info(dist_str)
    if rank == 0:
        wst = get_world_size(total=True)
        logger.warning(
            'Using [%d / %d] available "%s" devices !!',
            world_size,
            wst,
            device,
        )
    return dist_str
```
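query_environment()⚓︎
Return a dict with `world_size`, `rank`, and `local_rank`, read from `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` when all three are set, and otherwise from the `get_*` helpers (env vars or MPI).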
Source code in src/ezpz/distributed.py
```python
def query_environment() -> dict[str, int]:
    """Return ``{world_size, rank, local_rank}`` from env vars or MPI."""
    ws = os.environ.get("WORLD_SIZE")
    r = os.environ.get("RANK")
    lr = os.environ.get("LOCAL_RANK")
    if ws is not None and r is not None and lr is not None:
        return {"world_size": int(ws), "rank": int(r), "local_rank": int(lr)}
    return {
        "world_size": get_world_size(),
        "rank": get_rank(),
        "local_rank": get_local_rank(),
    }
```
resolve_fsdp_strategy(strategy)⚓︎
Convert a CLI sharding strategy name to a `reshard_after_forward` value.
Returns `None` when the strategy is `"no-shard"` (caller should
use DDP instead). `"hybrid-shard"` resolves to `get_gpus_per_node()`, and unknown names raise `ValueError`.
Source code in src/ezpz/distributed.py
```python
def resolve_fsdp_strategy(
    strategy: str,
) -> bool | int | None:
    """Convert a CLI sharding strategy name to a ``reshard_after_forward`` value.

    Returns ``None`` when the strategy is ``"no-shard"`` (caller should
    use DDP instead).
    """
    if strategy not in FSDP_SHARDING_STRATEGIES:
        raise ValueError(
            f"Unknown FSDP sharding strategy {strategy!r}. "
            f"Choose from: {', '.join(FSDP_SHARDING_STRATEGIES)}"
        )
    val = FSDP_SHARDING_STRATEGIES[strategy]
    if val == "hybrid":
        return get_gpus_per_node()
    return val
```
seed_everything(seed)⚓︎
Seed Python, NumPy, and PyTorch RNGs for reproducibility.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seed` | `int` | The random seed. | required |
Source code in src/ezpz/distributed.py
```python
def seed_everything(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility.

    Args:
        seed: The random seed.
    """
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.xpu.is_available():
        torch.xpu.manual_seed(seed)
```
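What seeding buys you, shown with only the standard library: reseeding `random` with the same value replays the same sequence. The NumPy and torch generators that `seed_everything` also seeds behave the same way but are left out so the sketch runs without them; `draw` is a hypothetical helper. One caveat about the `PYTHONHASHSEED` line: Python reads that variable at interpreter startup, so setting it at runtime affects child processes that inherit the environment, not hash randomization in the current process.

```python
import random
from typing import List


def draw(seed: int, n: int = 3) -> List[float]:
    """Hypothetical helper: reseed, then draw n floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


a = draw(42)
b = draw(42)
c = draw(43)
print(a == b, a == c)  # True False
```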
setup_mlflow(project_name=None, config=None, outdir=None, **kwargs)⚓︎
Initialise an MLflow run (rank 0 only logs, others return None).
Convenience wrapper around `ezpz.tracker.MLflowBackend` that
mirrors `setup_wandb`. Handles dotenv loading, auth, experiment
name resolution, and system-param logging automatically.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `project_name` | `str \| None` | MLflow experiment name. Falls back to `MLFLOW_EXPERIMENT_NAME`, then wandb project env vars, then a script-derived default. | `None` |
| `config` | `dict[str, Any] \| None` | Run-level config dict logged as MLflow params. | `None` |
| `outdir` | `str \| PathLike \| None` | Artifact output directory. | `None` |
| `**kwargs` | `Any` | Forwarded to `mlflow.start_run`. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Any` | The `mlflow.ActiveRun` object, or `None` if MLflow is unavailable or the current rank is not 0. |
Source code in src/ezpz/distributed.py
```python
def setup_mlflow(
    project_name: str | None = None,
    config: dict[str, Any] | None = None,
    outdir: str | os.PathLike | None = None,
    **kwargs: Any,
) -> Any:
    """Initialise an MLflow run (rank 0 only logs, others return ``None``).

    Convenience wrapper around :class:`~ezpz.tracker.MLflowBackend` that
    mirrors :func:`setup_wandb`. Handles dotenv loading, auth, experiment
    name resolution, and system-param logging automatically.

    Args:
        project_name: MLflow experiment name. Falls back to
            ``MLFLOW_EXPERIMENT_NAME``, then wandb project env vars,
            then a script-derived default.
        config: Run-level config dict logged as MLflow params.
        outdir: Artifact output directory.
        **kwargs: Forwarded to ``mlflow.start_run``.

    Returns:
        The ``mlflow.ActiveRun`` object, or ``None`` if MLflow is
        unavailable or the current rank is not 0.
    """
    try:
        from ezpz.tracker import MLflowBackend

        backend = MLflowBackend(
            project_name=project_name,
            config=config,
            outdir=outdir,
            **kwargs,
        )
        return backend.run
    except ImportError:
        logger.warning("mlflow is not installed; skipping MLflow setup.")
        return None
    except Exception as exc:
        logger.warning("setup_mlflow() failed: %s — continuing without MLflow", exc)
        return None
```
setup_torch(port=None, seed=None, timeout=None, verbose=False, *, tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, device_id=None)
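Initialise torch.distributed and return the global rank.
This is the main entry point. It:
1. Uses MPI to discover rank / world_size / master_addr / master_port.
2. Calls torch.distributed.init_process_group.
3. Sets the local CUDA/XPU device. (The device is activated before init_process_group so that the process group binds to it, and is re-activated afterwards only if the resolved device changed.)
4. Optionally seeds RNGs and initialises tensor, pipeline, and context parallelism (when any of the corresponding sizes is greater than 1).
5. Prints a one-line-per-rank summary.
If the WORLD_SIZE environment variable is "1", torch.distributed is not initialised and 0 is returned immediately.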
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| port | str \| int \| None | Fallback master port (rank 0 picks a free port otherwise). | None |
| seed | int \| None | If given, call seed_everything with a rank-aware seed (seed * (rank + 1) * (local_rank + 1)). | None |
| timeout | str \| int \| None | init_process_group timeout in seconds. Falls back to the TORCH_DDP_TIMEOUT environment variable, then to 3600. | None |
| verbose | bool | Print verbose dist info on rank 0. | False |
| tensor_parallel_size | int | TP degree (default 1 = disabled). | 1 |
| pipeline_parallel_size | int | PP degree (default 1 = disabled). | 1 |
| context_parallel_size | int | CP degree (default 1 = disabled). | 1 |
| tensor_parallel_backend | str \| None | Override backend for TP group. | None |
| pipeline_parallel_backend | str \| None | Override backend for PP group. | None |
| context_parallel_backend | str \| None | Override backend for CP group. | None |
| data_parallel_backend | str \| None | Override backend for DP group. | None |
| device_id | int \| None | Explicit device ordinal for init_process_group; also the device activated before init. Defaults to the local rank. | None |
Returns:
| Type | Description |
|---|---|
| int | The global rank of this process. |
Source code in src/ezpz/distributed.py
def setup_torch(
    port: str | int | None = None,
    seed: int | None = None,
    timeout: str | int | None = None,
    verbose: bool = False,
    *,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: str | None = None,
    pipeline_parallel_backend: str | None = None,
    context_parallel_backend: str | None = None,
    data_parallel_backend: str | None = None,
    device_id: int | None = None,
) -> int:
    """Initialise ``torch.distributed`` and return the global rank.
    This is the main entry point. It:
    1. Uses MPI to discover rank / world_size / master_addr / master_port.
    2. Calls ``torch.distributed.init_process_group``.
    3. Sets the local CUDA/XPU device.
    4. Optionally seeds RNGs and initialises tensor parallelism.
    5. Prints a one-line-per-rank summary.
    Args:
        port: Fallback master port (rank 0 picks a free port otherwise).
        seed: If given, call :func:`seed_everything` with a rank-aware seed.
        timeout: ``init_process_group`` timeout in seconds (default 3600).
        verbose: Print verbose dist info on rank 0.
        tensor_parallel_size: TP degree (default 1 = disabled).
        pipeline_parallel_size: PP degree (default 1 = disabled).
        context_parallel_size: CP degree (default 1 = disabled).
        tensor_parallel_backend: Override backend for TP group.
        pipeline_parallel_backend: Override backend for PP group.
        context_parallel_backend: Override backend for CP group.
        data_parallel_backend: Override backend for DP group.
        device_id: Explicit device ordinal for ``init_process_group``.
    Returns:
        The global rank of this process.
    """
    import torch
    device_type = get_torch_device_type()
    backend = get_torch_backend()
    timeout_s = (
        int(timeout)
        if timeout is not None
        else int(os.environ.get("TORCH_DDP_TIMEOUT", 3600))
    )
    # -- Single-device fast path --
    ws_env = os.environ.get("WORLD_SIZE")
    if ws_env is not None and ws_env == "1":
        if get_rank() == 0:
            logger.info(
                "Running on a single %s, not initialising torch.distributed!",
                device_type,
            )
        _set_env_vars(rank=0, local_rank=0, world_size=1)
        return 0
    # -- Multi-device init --
    #
    # Set the per-process local device BEFORE init_process_group so
    # that whatever "current device" the process group binds to is
    # actually the same device. Without this, every rank's PG would
    # bind to the default device (e.g. xpu:0 on Aurora) at
    # construction time; later `set_device(...)` switches the current
    # device but the PG remains stuck on xpu:0. xccl/foreach_all_gather
    # then routes collectives to two different XPU queues and they
    # never meet up — FSDP2 deadlocks on the very first
    # all_gather_into_tensor.
    #
    # CUDA tends to mask this by being more forgiving about
    # current-device-at-init-time, but on XPU it's load-bearing.
    #
    # IMPORTANT: the device we set here must match whatever `_setup_ddp`
    # will bind via `device_id=` in init_process_group. That's the
    # caller-provided `device_id` if given, otherwise `local_rank`.
    # Using `LOCAL_RANK` unconditionally would reintroduce the wrong-
    # device hang on XPU whenever a caller passes an explicit
    # `device_id` (e.g. `setup_torch(device_id=2)` on local_rank 5).
    pre_local_rank = get_local_rank()
    pre_device_index = device_id if device_id is not None else pre_local_rank
    _set_local_device(get_torch_device_type(), pre_device_index)
    dsetup = _setup_ddp(
        port=str(port) if port is not None else "1234",
        timeout=timedelta(seconds=timeout_s),
        backend=backend,
        device_id=device_id,
    )
    rank = dsetup["rank"]
    world_size = dsetup["world_size"]
    local_rank = dsetup["local_rank"]
    # Re-set in case `_setup_ddp` resolved a different device than we
    # pre-bound (rare — happens if `get_local_rank()`'s pre-init guess
    # differs from what `_setup_ddp` saw in env vars). Compare against
    # the device we ACTUALLY set above, not raw `LOCAL_RANK` — those
    # diverge whenever `device_id` was passed.
    post_device_index = device_id if device_id is not None else local_rank
    if post_device_index != pre_device_index:
        _set_local_device(get_torch_device_type(), post_device_index)
    _set_env_vars(rank=rank, local_rank=local_rank, world_size=world_size)
    # -- Tensor / pipeline / context parallelism --
    if (
        tensor_parallel_size > 1
        or pipeline_parallel_size > 1
        or context_parallel_size > 1
    ):
        import ezpz.tp
        ezpz.tp.initialize_tensor_parallel(
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            context_parallel_size=context_parallel_size,
            tensor_parallel_backend=tensor_parallel_backend,
            pipeline_parallel_backend=pipeline_parallel_backend,
            context_parallel_backend=context_parallel_backend,
            data_parallel_backend=data_parallel_backend,
            timeout=timedelta(seconds=float(timeout_s)),
        )
    # -- Seed --
    if seed is not None:
        if rank == 0:
            logger.warning("Manually specifying seed=%d", seed)
        seed_everything(seed * (rank + 1) * (local_rank + 1))
    # -- Diagnostics --
    if rank == 0:
        _ = get_dist_info(verbose=verbose)
        logger.info(
            "Using device=%s with backend=%s for distributed training.",
            device_type,
            backend,
        )
    if world_size > 1:
        barrier()
        logger.info(print_dist_setup(display=False))
        barrier()
    _configure_rank_warnings(rank)
    return rank
setup_wandb(project_name=None, entity=None, config=None, outdir=None, project=None, dir=None, id=None, name=None, notes=None, tags=None, config_exclude_keys=None, config_include_keys=None, allow_val_change=None, group=None, job_type=None, mode=None, force=False, reinit=None, resume=None, resume_from=None, fork_from=None, save_code=None, init_timeout=None, start_method=None, tensorboard=None, sync_tensorboard=None, monitor_gym=None, settings=None, **kwargs)
Initialise a wandb run (rank 0 only; non-zero ranks return None).
Most parameters are forwarded directly to wandb.init. See
the wandb docs (https://docs.wandb.ai/ref/python/init/) for
details. If neither project nor project_name is given, the project
name is taken from the WB_PROJECT, WANDB_PROJECT, or WB_PROJECT_NAME
environment variable (in that order), and otherwise derived as
<parent directory>.<script name> from the calling file.
Returns:
| Type | Description |
|---|---|
| Any | The wandb.run object, None if called from a non-zero rank, or None if wandb is unavailable (verify_wandb() fails) or wandb.init() raises. |
Note:
Returns None on every rank other than 0. Previously,
non-zero ranks got a mode="disabled" wandb.run back; that
still meant verify_wandb(), wandb.init(), and a full
run.config.update() ran on every rank, which on a 96-rank job
produced 96 dummy runs and a wall of "Setting up wandb from
rank=N" log spam. Callers that need a no-op tracker on
non-zero ranks should test for None and use
ezpz.tracker.NullTracker (or simply ignore the return value;
log() calls against None should be guarded by the
caller anyway).
Source code in src/ezpz/distributed.py
def setup_wandb(
    project_name: str | None = None,
    entity: str | None = None,
    config: dict[str, Any] | None = None,
    outdir: str | os.PathLike | None = None,
    project: str | None = None,
    dir: str | os.PathLike | None = None,
    id: str | None = None,
    name: str | None = None,
    notes: str | None = None,
    tags: Sequence[str] | None = None,
    config_exclude_keys: list[str] | None = None,
    config_include_keys: list[str] | None = None,
    allow_val_change: bool | None = None,
    group: str | None = None,
    job_type: str | None = None,
    mode: Literal["online", "offline", "disabled", "shared"] | None = None,
    force: bool = False,
    reinit: bool | str | None = None,
    resume: bool | str | None = None,
    resume_from: str | None = None,
    fork_from: str | None = None,
    save_code: bool | None = None,
    init_timeout: int | float | None = None,
    start_method: Literal["fork", "spawn", "thread", "process"] | None = None,
    tensorboard: bool | None = None,
    sync_tensorboard: bool | None = None,
    monitor_gym: bool | None = None,
    settings: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Initialise a wandb run (rank 0 only — non-zero ranks return ``None``).
    Most parameters are forwarded directly to :func:`wandb.init`. See
    the `wandb docs <https://docs.wandb.ai/ref/python/init/>`_ for
    details.
    Returns:
        The :obj:`wandb.run` object, ``None`` if called from a non-zero
        rank, or ``None`` if wandb is unavailable.
    .. note::
        Returns ``None`` on every rank other than 0. Previously
        non-zero ranks got a ``mode="disabled"`` wandb.run back — that
        still meant verify_wandb(), wandb.init(), and full
        run.config.update() ran on every rank, which on a 96-rank job
        produced 96 dummy runs and a wall of "Setting up wandb from
        rank=N" log spam. Callers that need a no-op tracker on
        non-zero ranks should test for ``None`` and use
        :class:`ezpz.tracker.NullTracker` (or just ignore the return —
        ``log()`` calls against ``None`` should be guarded by the
        caller anyway).
    """
    # Hard rank gate: non-zero ranks skip all wandb work entirely.
    # verify_wandb() and wandb.init() each take real time and produce
    # log spam; on a 96-rank job that's 95x wasted work + a 96x
    # multiplier on every log line in this function.
    rank = get_rank()
    if rank != 0:
        return None
    import wandb
    if not verify_wandb():
        logger.warning("verify_wandb() failed; not initialising run")
        return None
    outdir_str = Path(outdir).as_posix() if outdir else os.getcwd()
    # Resolve project name
    _project = project or project_name
    if _project is None:
        _project = os.environ.get(
            "WB_PROJECT",
            os.environ.get("WANDB_PROJECT", os.environ.get("WB_PROJECT_NAME")),
        )
    if _project is None:
        import sys
        frame = sys._getframe().f_back
        if frame is not None:
            fp = Path(frame.f_code.co_filename)
            _project = f"{fp.parent.stem}.{fp.stem}"
    # Resolve mode
    _mode = _resolve_wandb_mode(mode)
    logger.info("Setting up wandb from rank=%d", rank)
    logger.info("Using WB_PROJECT=%s", _project)
    try:
        run = wandb.init(
            entity=entity,
            project=_project,
            dir=str(dir) if dir is not None else outdir_str,
            id=id,
            name=name,
            notes=notes,
            tags=tags,
            config_exclude_keys=config_exclude_keys,
            config_include_keys=config_include_keys,
            allow_val_change=allow_val_change,
            group=group,
            job_type=job_type,
            mode=_mode,
            force=force,
            reinit=reinit,
            resume=resume,
            resume_from=resume_from,
            fork_from=fork_from,
            save_code=save_code,
            tensorboard=tensorboard if tensorboard is not None else False,
            sync_tensorboard=sync_tensorboard
            if sync_tensorboard is not None
            else False,
            monitor_gym=monitor_gym,
            settings=(
                settings
                if settings is not None
                else _build_wandb_settings(
                    wandb=wandb,
                    init_timeout=init_timeout,
                    start_method=start_method,
                )
            ),
            **kwargs,
        )
        if run is not None:
            logger.info("wandb.run=[%s](%s)", run.name, run.url)
            import sys # noqa: PLC0415
            import torch # noqa: PLC0415
            import ezpz # noqa: PLC0415
            from ezpz.configs import get_scheduler # noqa: PLC0415
            now = datetime.datetime.now()
            # Best-effort resolution of the active scheduler jobid.
            # ezpz.launch.get_active_jobid imports ezpz.pbs / ezpz.slurm
            # lazily and returns None when no job is detected. Wrapped
            # in try/except because the launch module pulls a non-
            # trivial chain on first import — any failure here should
            # NOT block the wandb run from being created.
            jobid: str | None = None
            try:
                from ezpz.launch import (
                    get_active_jobid,
                ) # noqa: PLC0415
                jobid = get_active_jobid()
            except Exception:
                pass
            # num_nodes / gpus_per_node have their own getters that
            # already swallow failures internally (return 1 / fall
            # back). Safe to call directly.
            run.config.update(
                {
                    # "DIST_INFO": get_dist_info(),
                    # --- existing fields (unchanged) ---
                    "hostname": get_hostname(),
                    "pytorch_backend": get_torch_backend(),
                    "torch_version": torch.__version__,
                    "world_size": get_world_size(),
                    "ezpz_version": ezpz.__version__,
                    "machine": get_machine(),
                    "working_directory": os.getcwd(),
                    "year": now.year,
                    "month": now.month,
                    "day": now.day,
                    "tstamp": now.isoformat(),
                    # --- new dimensions for filtering / grouping ---
                    # Pivot from a wandb run → the cluster job that
                    # ran it. None when not inside a PBS/SLURM job.
                    "jobid": jobid,
                    # "pbs" / "slurm" / "" — useful when you have
                    # runs from both systems in the same project.
                    "scheduler": get_scheduler(),
                    # Distinct from world_size: same world_size can be
                    # 8x8 or 4x16, lets you separate the two.
                    "num_nodes": get_num_nodes(),
                    "ranks_per_node": get_gpus_per_node(),
                    # "cuda" / "xpu" / "mps" / "cpu" — at-a-glance
                    # distinction between Aurora vs NVIDIA vs CPU runs.
                    "device_type": get_torch_device_type(),
                    # --- debugging / postmortems ---
                    "python_version": sys.version.split()[0],
                    # None when ezpz is a pip-install, not a git
                    # checkout. Disambiguates dev branches sharing the
                    # same ezpz_version.
                    "ezpz_git_sha": _get_ezpz_git_sha(),
                }
            )
            if config is not None:
                run.config.update({"config": config})
        return wandb.run
    except Exception as exc:
        logger.exception("wandb.init() failed from rank=%d: %s", rank, exc)
        logger.warning("Continuing without wandb logging.")
        return None
synchronize(device=None)
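Block until all work on the given accelerator has finished.
Backends are checked in the order CUDA, XPU, MPS; the call is a no-op when none is available.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| device | str \| int \| None | Device specifier; the current device is used when None. Ignored on MPS. | None |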
Source code in src/ezpz/distributed.py
def synchronize(device: str | int | None = None) -> None:
    """Block until all work on the given accelerator has finished.
    Args:
        device: Device specifier; auto-detected when ``None``.
    """
    import torch
    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
    elif torch.xpu.is_available():
        torch.xpu.synchronize(device)
    elif torch.backends.mps.is_available():
        torch.mps.synchronize()
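Accelerator kernels are launched asynchronously, so wall-clock timing of device work is only meaningful if the device is synchronized before each timestamp is taken. The sketch below shows that pattern with the synchronize step injected as a callable (in real code one would pass synchronize from this module). timed_call is a hypothetical helper, and the recording lambda only exists to keep the snippet runnable without an accelerator.

```python
import time
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def timed_call(
    fn: Callable[..., T], *args: Any, sync: Callable[[], None], **kwargs: Any
) -> tuple[T, float]:
    """Run ``fn`` and return (result, elapsed seconds).

    ``sync`` is called before starting and before stopping the clock so
    that queued asynchronous device work is included in the measurement.
    """
    sync()  # drain work queued before this call
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    sync()  # wait for the work that ``fn`` enqueued
    return result, time.perf_counter() - t0


calls: list[str] = []
result, dt = timed_call(sum, [1, 2, 3], sync=lambda: calls.append("sync"))
```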
timeitlogit(rank=None, record=True, verbose=False, prefix=None)
Decorator factory to time a function, optionally logging to wandb.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rank | int \| None | Rank value that gates verbose logging: messages are emitted only when it is 0. Defaults to get_rank(), so by default only global rank 0 logs. | None |
| record | bool | Whether to wandb.log the timing (key "<prefix>/<qualified name>", commit=False, only when a wandb run is active). | True |
| verbose | bool | Whether to log to stdout on the selected rank. | False |
| prefix | str \| None | Metric prefix for wandb (default "timeit"). | None |
Returns:
| Type | Description |
|---|---|
| Callable | A decorator that wraps the target function. |
Example:
@timeitlogit(rank=0, verbose=True)
def train_step(batch): ...
Source code in src/ezpz/distributed.py
def timeitlogit(
    rank: int | None = None,
    record: bool = True,
    verbose: bool = False,
    prefix: str | None = None,
) -> Callable:
    """Decorator factory to time a function, optionally logging to wandb.
    Args:
        rank: Rank whose logger emits messages (defaults to :func:`get_rank`).
        record: Whether to ``wandb.log`` the timing.
        verbose: Whether to log to stdout on the selected rank.
        prefix: Metric prefix for wandb (default ``"timeit"``).
    Returns:
        A decorator that wraps the target function.
    Example::
        @timeitlogit(rank=0, verbose=True)
        def train_step(batch): ...
    """
    _rank = rank if rank is not None else get_rank()
    _prefix = prefix or "timeit"
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            t0 = time.perf_counter()
            result = func(*args, **kwargs)
            dt = time.perf_counter() - t0
            fname = getattr(
                func, "__qualname__", getattr(func, "__name__", "unknown")
            )
            if record:
                try:
                    import wandb
                    if wandb.run is not None:
                        wandb.log({f"{_prefix}/{fname}": dt}, commit=False)
                except Exception:
                    pass
            if verbose and _rank == 0:
                arg_str = ", ".join(map(str, args))
                kw_str = ", ".join(f"{k}={v}" for k, v in kwargs.items())
                inner = ", ".join(filter(None, [arg_str, kw_str]))
                logger.info("%s(%s) took %.4f s", fname, inner, dt)
            return result
        return wrapper
    return decorator
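To see what timeitlogit records without a wandb run, the decorator can be reduced to its core: time the call with time.perf_counter, key the result by "<prefix>/<qualified name>", and hand it to a sink. In this sketch the sink is a plain dict instead of wandb.log(..., commit=False), and timeit_to is a hypothetical name; the rank gate and verbose logging are left out.

```python
import time
from collections.abc import Callable
from functools import wraps
from typing import Any


def timeit_to(sink: dict[str, float], prefix: str = "timeit") -> Callable:
    """Hypothetical reduction of timeitlogit that writes timings to ``sink``."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            t0 = time.perf_counter()
            result = func(*args, **kwargs)
            dt = time.perf_counter() - t0
            # Same key scheme as timeitlogit: "<prefix>/<qualified name>".
            fname = getattr(func, "__qualname__", getattr(func, "__name__", "unknown"))
            sink[f"{prefix}/{fname}"] = dt
            return result

        return wrapper

    return decorator


timings: dict[str, float] = {}


@timeit_to(timings)
def train_step(batch: list[int]) -> int:
    return sum(batch)


train_step([1, 2, 3])
```

Because timeitlogit passes commit=False, each timing is attached to the wandb step that the next committing wandb.log call closes, rather than starting a step of its own.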
verify_wandb()
Return True if wandb is importable, enabled, and usable.
"Usable" means one of:
1. WANDB_MODE=offline is set (case-insensitive). Offline runs don't need network access or credentials; they write to a local wandb/offline-run-* directory and can be synced later. This is the documented compute-node workflow (see docs/troubleshooting.md and doctor.py's permissive handling of offline mode).
2. An API key is reachable ($WANDB_API_KEY, wandb.api.api_key, or a ~/.netrc entry for api.wandb.ai).
Returns False when wandb is uninstalled, WANDB_DISABLED is set to any non-empty value, WANDB_MODE=disabled, or neither offline mode nor any credential source is configured.
Source code in src/ezpz/distributed.py
def verify_wandb() -> bool:
    """Return ``True`` if wandb is importable, enabled, and usable.
    "Usable" means one of:
    1. ``WANDB_MODE=offline`` is set — offline runs don't need
       network or credentials; they write to a local
       ``wandb/offline-run-*`` directory and can be synced later.
       This is the documented compute-node workflow (see
       ``docs/troubleshooting.md`` + ``doctor.py``'s permissive
       handling of offline mode).
    2. An API key is reachable (``$WANDB_API_KEY``, ``wandb.api.api_key``,
       or a ``~/.netrc`` entry for ``api.wandb.ai``).
    Returns ``False`` when wandb is uninstalled, ``WANDB_DISABLED`` is
    truthy, ``WANDB_MODE=disabled``, or neither offline mode nor any
    credential source is configured.
    """
    rank = get_rank()
    try:
        import wandb
    except Exception:
        if rank == 0:
            logger.warning(
                "Unable to import wandb; install with `pip install wandb`"
            )
        return False
    if os.environ.get("WANDB_DISABLED"):
        return False
    wm = os.environ.get("WANDB_MODE", "").lower()
    if wm == "disabled":
        return False
    # Offline mode is fine without credentials — wandb.init(mode="offline")
    # writes locally and never touches the network. Return True so callers
    # (esp. setup_wandb) don't silently no-op on compute nodes that use
    # the offline-then-sync workflow.
    if wm == "offline":
        return True
    if (
        wandb.api.api_key is not None
        or os.environ.get("WANDB_API_KEY") is not None
    ):
        return True
    # Last resort: check ~/.netrc
    try:
        import netrc as _netrc
        netrc_path = Path(os.path.expanduser("~/.netrc"))
        if netrc_path.is_file():
            auth = _netrc.netrc(netrc_path).authenticators("api.wandb.ai")
            return bool(auth)
    except Exception:
        pass
    return False
wrap_model(model, use_fsdp=True, dtype='bfloat16', device_mesh=None, reshard_after_forward=True, device_id=None)
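Wrap model with DDP or FSDP for distributed training.
With use_fsdp=True the model is sharded with FSDP2 (fully_shard). FSDP is not supported on CPU or MPS devices; there, wrap_model logs a warning and falls back to DDP.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Model to wrap. | required |
| use_fsdp | bool | Use FSDP when True, DDP when False. | True |
| dtype | str | Mixed-precision parameter dtype for FSDP (e.g. "bf16"). | 'bfloat16' |
| device_id | device \| int \| None | Explicit device ordinal for FSDP (legacy FSDP1 only). wrap_model itself always takes the FSDP2 path and does not use this argument. | None |
| device_mesh | Any | Optional torch.distributed.device_mesh.DeviceMesh. When None, a 1-D mesh over all world_size ranks is created. | None |
| reshard_after_forward | bool \| int | Controls parameter lifetime after forward: True (default) reshards after forward (FULL_SHARD / ZeRO-3); False keeps parameters unsharded (SHARD_GRAD_OP / ZeRO-2); an int reshards to that world size (HYBRID_SHARD). | True |
Returns:
| Type | Description |
|---|---|
| Module | The wrapped model. If world_size <= 1, the original model is returned unchanged. |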
Source code in src/ezpz/distributed.py
def wrap_model(
    model: torch.nn.Module,
    use_fsdp: bool = True,
    dtype: str = "bfloat16",
    device_mesh: Any = None,
    reshard_after_forward: bool | int = True,
    device_id: torch.device | int | None = None,
) -> torch.nn.Module:
    """Wrap *model* with DDP or FSDP for distributed training.
    Args:
        model: Model to wrap.
        use_fsdp: Use FSDP when ``True``, DDP when ``False``.
        dtype: Mixed-precision parameter dtype for FSDP (e.g. ``"bf16"``).
        device_id: Explicit device ordinal for FSDP (legacy FSDP1 only).
        device_mesh: Optional :class:`torch.distributed.device_mesh.DeviceMesh`.
        reshard_after_forward: Controls parameter lifetime after forward:
            - ``True`` (default): reshard after forward (FULL_SHARD / ZeRO-3).
            - ``False``: keep unsharded (SHARD_GRAD_OP / ZeRO-2).
            - ``int``: reshard to this world-size (HYBRID_SHARD).
    Returns:
        The wrapped model. If ``world_size <= 1`` the original model is
        returned unchanged.
    """
    import torch
    ws = get_world_size()
    if ws <= 1:
        logger.warning(
            "%s requested but world_size=%d; returning unwrapped model.",
            "FSDP" if use_fsdp else "DDP",
            ws,
        )
        return model
    if get_rank() == 0:
        logger.info("Wrapping model with %s", "fsdp" if use_fsdp else "ddp")
    if not use_fsdp:
        return wrap_model_for_ddp(model)
    device_type = get_torch_device_type()
    if device_type in ("cpu", "mps"):
        logger.warning(
            "FSDP is not supported on %s devices; falling back to DDP.",
            device_type,
        )
        return wrap_model_for_ddp(model)
    # Auto-create a 1D DeviceMesh when none is provided so FSDP2
    # (fully_shard) is the default sharding strategy.
    if device_mesh is None:
        device_mesh = init_device_mesh_safe(device_type, (ws,))
    return _wrap_fsdp2(
        model,
        dtype=dtype,
        device_mesh=device_mesh,
        reshard_after_forward=reshard_after_forward,
    )
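The control flow of wrap_model reduces to a short decision tree: no wrapping for a single process, DDP when FSDP is not requested or the device is CPU/MPS, and FSDP2 over a 1-D mesh of shape (world_size,) otherwise. The function below restates that tree and returns a label instead of a wrapped module; choose_wrapping is a hypothetical name used only for illustration.

```python
def choose_wrapping(world_size: int, use_fsdp: bool, device_type: str) -> str:
    """Hypothetical restatement of wrap_model's dispatch logic."""
    if world_size <= 1:
        return "unwrapped"  # nothing to shard or all-reduce across
    if not use_fsdp:
        return "ddp"
    if device_type in ("cpu", "mps"):
        return "ddp"  # FSDP unsupported here, so wrap_model falls back
    # Default: FSDP2 (fully_shard) on a 1-D DeviceMesh of shape
    # (world_size,), unless the caller supplies its own device_mesh.
    return "fsdp2"
```

In particular, requesting FSDP on a CPU-only multi-process run yields DDP (with a warning), and device_id plays no part in any branch.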
wrap_model_for_ddp(model)
Wrap model with torch.nn.parallel.DistributedDataParallel. On CUDA and XPU, the local rank is passed as device_ids=[local_rank]; on other devices DDP is constructed without device_ids.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Model to wrap (should already be on the correct device). | required |
Returns:
| Type | Description |
|---|---|
| Module | A DDP-wrapped model. |
Source code in src/ezpz/distributed.py
def wrap_model_for_ddp(model: torch.nn.Module) -> torch.nn.Module:
    """Wrap *model* with :class:`~torch.nn.parallel.DistributedDataParallel`.
    Args:
        model: Model to wrap (should already be on the correct device).
    Returns:
        A DDP-wrapped model.
    """
    from torch.nn.parallel import DistributedDataParallel as DDP
    device_type = get_torch_device_type()
    local_rank = get_local_rank()
    if device_type in {"cuda", "xpu"}:
        return DDP(model, device_ids=[local_rank])
    return DDP(model)
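The only device-dependent choice in wrap_model_for_ddp is whether device_ids is passed. The keyword arguments it would hand to DistributedDataParallel can be sketched as follows; ddp_kwargs is a hypothetical helper, not part of ezpz.

```python
def ddp_kwargs(device_type: str, local_rank: int) -> dict[str, list[int]]:
    """Hypothetical: kwargs wrap_model_for_ddp passes to DistributedDataParallel."""
    if device_type in {"cuda", "xpu"}:
        # One device per process, identified by the local rank.
        return {"device_ids": [local_rank]}
    return {}  # CPU/MPS and others: DDP is built without device_ids
```

Since the ordinal comes from get_local_rank(), not from any device_id passed to setup_torch, the two should agree when this wrapper is used.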
wrap_model_for_fsdp(model, dtype='bfloat16', device_id=None, **kwargs)
Wrap model with FSDP1 and mixed precision.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Model to wrap (should already be on the correct device). | required |
| dtype | str | Mixed-precision parameter dtype (e.g. "bf16"). | 'bfloat16' |
| device_id | int \| None | Explicit device ordinal for FSDP. | None |
| **kwargs | Any | Extra keyword arguments forwarded to the FSDP constructor. | {} |
Returns:
| Type | Description |
|---|---|
| Module | An FSDP-wrapped model. |
Source code in src/ezpz/distributed.py
def wrap_model_for_fsdp(
    model: torch.nn.Module,
    dtype: str = "bfloat16",
    device_id: int | None = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """Wrap *model* with FSDP 1 and mixed precision.
    Args:
        model: Model to wrap (should already be on the correct device).
        dtype: Mixed-precision parameter dtype (e.g. ``"bf16"``).
        device_id: Explicit device ordinal for FSDP.
        **kwargs: Extra keyword arguments forwarded to the FSDP constructor.
    Returns:
        An FSDP-wrapped model.
    """
    return _wrap_fsdp(model, dtype=dtype, device_id=device_id, **kwargs)
wrap_model_for_fsdp2(model, dtype='bfloat16', device_mesh=None, **kwargs)
Wrap model with FSDP2 (per-module fully_shard).
Note: Experimental. The FSDP2 API is subject to change in future PyTorch releases.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | Model to wrap. | required |
| dtype | str | Mixed-precision parameter dtype (e.g. "bf16"). | 'bfloat16' |
| device_mesh | Any | Optional torch.distributed.device_mesh.DeviceMesh. | None |
| **kwargs | Any | Extra keyword arguments forwarded to fully_shard. | {} |
Returns:
| Type | Description |
|---|---|
| Module | The model after applying fully_shard to every sub-module. |
Source code in src/ezpz/distributed.py
def wrap_model_for_fsdp2(
    model: torch.nn.Module,
    dtype: str = "bfloat16",
    device_mesh: Any = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """Wrap *model* with FSDP2 (per-module ``fully_shard``).
    .. note:: **Experimental** -- the FSDP2 API is subject to change in
        future PyTorch releases.
    Args:
        model: Model to wrap.
        dtype: Mixed-precision parameter dtype (e.g. ``"bf16"``).
        device_mesh: Optional :class:`torch.distributed.device_mesh.DeviceMesh`.
        **kwargs: Extra keyword arguments forwarded to ``fully_shard``.
    Returns:
        The model after applying ``fully_shard`` to every sub-module.
    """
    return _wrap_fsdp2(model, dtype=dtype, device_mesh=device_mesh, **kwargs)
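In the usual FSDP2 pattern, fully_shard is applied to submodules (for example transformer blocks) before their parents, finishing with the root module. Which sub-modules ezpz actually selects is not shown on this page, so the sketch below only illustrates a children-first (post-order) traversal on a toy tree. Node and apply_bottom_up are hypothetical, and the shard callback stands in for torch.distributed.fsdp.fully_shard.

```python
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module with child modules."""

    name: str
    children: list[Node] = field(default_factory=list)


def apply_bottom_up(node: Node, shard: Callable[[Node], None]) -> None:
    """Call ``shard`` on every descendant before its parent (post-order)."""
    for child in node.children:
        apply_bottom_up(child, shard)
    shard(node)  # the root is handled last


model = Node(
    "model",
    [Node("embed"), Node("blocks", [Node("block0"), Node("block1")]), Node("head")],
)
order: list[str] = []
apply_bottom_up(model, lambda n: order.append(n.name))
```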
write_hostfile_from_list_of_hosts(hosts, hostfile)
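Write a hostfile from a list of hostnames.
Hostnames are written one per line with a trailing newline, overwriting any existing file; missing parent directories are created.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| hosts | Sequence[str] | Sequence of hostnames to write. | required |
| hostfile | str \| PathLike | Path to write to. | required |
Returns:
| Type | Description |
|---|---|
| Path | Path of the written hostfile. |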
Source code in src/ezpz/distributed.py
def write_hostfile_from_list_of_hosts(
    hosts: Sequence[str],
    hostfile: str | os.PathLike,
) -> Path:
    """Write a hostfile from a list of hostnames.
    Args:
        hosts: Sequence of hostnames to write.
        hostfile: Path to write to.
    """
    hfp = Path(hostfile)
    hfp.parent.mkdir(parents=True, exist_ok=True)
    hfp.write_text("\n".join(hosts) + "\n")
    return hfp