ezpz.distributed⚓︎

The core distributed training module. Provides initialization, device/backend detection, collective operations, and model wrapping for distributed PyTorch training across any supported hardware.

Setup⚓︎

The primary entry point is setup_torch(), which bootstraps the entire distributed environment in a single call:

import ezpz

rank = ezpz.setup_torch(seed=42)

Advanced Parameters⚓︎

setup_torch() accepts keyword-only parameters for multi-dimensional parallelism:

rank = ezpz.setup_torch(
    seed=42,
    tensor_parallel_size=2,    # Split model across 2 GPUs per TP group
    pipeline_parallel_size=1,  # No pipeline parallelism
    context_parallel_size=1,   # No context parallelism
)

Parameter                    Default  Description
port                         None     Rendezvous port (auto-detected if not set)
seed                         None     Random seed for reproducibility
timeout                      None     DDP init timeout in seconds (env TORCH_DDP_TIMEOUT, default 3600)
verbose                      False    Enable verbose logging during setup
tensor_parallel_size*        1        Number of ranks per tensor-parallel group
pipeline_parallel_size*      1        Number of pipeline stages
context_parallel_size*       1        Number of context-parallel ranks
tensor_parallel_backend*     None     Backend for TP groups (auto-detected)
pipeline_parallel_backend*   None     Backend for PP groups (auto-detected)
context_parallel_backend*    None     Backend for CP groups (auto-detected)
data_parallel_backend*       None     Backend for DP groups (auto-detected)
device_id*                   None     Override local device index

Parameters marked with * are keyword-only.
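
As a rough guide to how these sizes combine, the sketch below assumes the common convention that the world size factors into data-, tensor-, pipeline-, and context-parallel degrees. The helper data_parallel_size is hypothetical and not part of ezpz; how ezpz.tp actually lays out its groups is not shown on this page.

```python
def data_parallel_size(
    world_size: int,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
) -> int:
    """Hypothetical helper: number of data-parallel replicas.

    Assumes world_size == DP * TP * PP * CP, which is the usual
    convention but is an assumption here, not a documented ezpz rule.
    """
    model_parallel = (
        tensor_parallel_size * pipeline_parallel_size * context_parallel_size
    )
    if model_parallel < 1 or world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"TP*PP*CP={model_parallel}"
        )
    return world_size // model_parallel


# 8 ranks with tensor_parallel_size=2 leave 4 data-parallel replicas.
print(data_parallel_size(8, tensor_parallel_size=2))
```

Under this assumption, a size combination that does not divide the world size evenly is rejected up front rather than producing ragged groups.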

Model Wrapping⚓︎

High-level API⚓︎

wrap_model() selects the appropriate wrapping strategy based on the arguments:

model = MyModel().to(ezpz.get_torch_device())

# FSDP wrapping (default)
model = ezpz.wrap_model(model, use_fsdp=True, dtype="bf16")

# DDP wrapping
model = ezpz.wrap_model(model, use_fsdp=False)

Explicit DDP⚓︎

model = ezpz.wrap_model_for_ddp(model)

FSDP (v1)⚓︎

model = ezpz.wrap_model_for_fsdp(model, dtype="bfloat16")

FSDP2⚓︎

FSDP2 uses torch.distributed._composable.fsdp.fully_shard for per-layer sharding with optional device mesh support:

model = ezpz.wrap_model_for_fsdp2(
    model,
    dtype="bf16",
    device_mesh=my_mesh,  # optional DeviceMesh for multi-dim parallelism
)

Hostfile Helpers⚓︎

Functions for managing hostfiles used by MPI launchers:

# Read nodes from an existing hostfile
nodes = ezpz.get_nodes_from_hostfile("/path/to/hostfile")

# Find or create a hostfile with fallback logic
hostfile = ezpz.get_hostfile_with_fallback()

# Write the current hostname to a hostfile
ezpz.write_localhost_to_hostfile("/tmp/hostfile")

# Write a list of hosts to a hostfile
ezpz.write_hostfile_from_list_of_hosts(
    ["node1", "node2", "node3"],
    "/tmp/hostfile"
)

Dtype Map⚓︎

The TORCH_DTYPES_MAP dictionary maps string dtype names to torch.dtype objects:

>>> ezpz.TORCH_DTYPES_MAP["bf16"]
torch.bfloat16
>>> ezpz.TORCH_DTYPES_MAP["fp32"]
torch.float32

Supported keys: "bf16", "bfloat16", "fp16", "float16", "half", "fp32", "float32".
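
The keys above are aliases for three dtypes. The sketch below shows one way to normalise an alias to its canonical name without importing torch. The grouping of "half" with float16 follows the PyTorch convention (torch.half is torch.float16), and canonical_dtype_name is a hypothetical helper, not an ezpz function.

```python
# Assumption: alias groups inferred from the supported keys listed above.
_CANONICAL_DTYPE_NAMES = {
    "bf16": "bfloat16",
    "bfloat16": "bfloat16",
    "fp16": "float16",
    "float16": "float16",
    "half": "float16",
    "fp32": "float32",
    "float32": "float32",
}


def canonical_dtype_name(name: str) -> str:
    """Map a short dtype alias to its canonical torch attribute name."""
    key = name.strip().lower()
    if key not in _CANONICAL_DTYPE_NAMES:
        raise ValueError(
            f"Unsupported dtype {name!r}; "
            f"expected one of {sorted(_CANONICAL_DTYPE_NAMES)}"
        )
    return _CANONICAL_DTYPE_NAMES[key]


print(canonical_dtype_name("bf16"))
print(canonical_dtype_name("half"))
```

With torch installed, getattr(torch, canonical_dtype_name("bf16")) would then resolve to torch.bfloat16.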

Simplified distributed training primitives for ezpz.

This module is a clean rewrite of ezpz.dist. It preserves the same public API surface that the rest of the codebase relies on while eliminating module-level side effects, dead code, redundant aliases, and tangled responsibilities.

Design principles:

  • No side effects on import -- no env vars mutated, no devices set, no wandb probed until the caller explicitly asks for it.
  • Lazy heavy imports -- mpi4py, torch, optional deps are imported inside the functions that need them so that import ezpz.distributed stays fast.
  • Single responsibility per function -- no 200-line monoliths.
  • Flat public API -- every symbol listed in __all__ is a first-class citizen; everything else is prefixed with _.

TORCH_DTYPES_MAP = {'bf16': None, 'bfloat16': None, 'fp16': None, 'float16': None, 'half': None, 'fp32': None, 'float32': None} module-attribute ⚓︎

Mapping of short dtype names to torch.dtype objects.

Populated lazily on first access via _ensure_dtype_map.

all_reduce(obj, op=None, implementation=None) ⚓︎

All-reduce obj across all ranks.

Parameters:

Name            Type        Description                                 Default
obj             Any         Numeric value to reduce.                    required
op              Any         Reduction operation (defaults to MPI.SUM).  None
implementation  str | None  "mpi" (default) or "torch".                 None

Returns:

Type  Description
Any   The reduced value.

Source code in src/ezpz/distributed.py
def all_reduce(
    obj: Any,
    op: Any = None,
    implementation: str | None = None,
) -> Any:
    """All-reduce *obj* across all ranks.

    Args:
        obj: Numeric value to reduce.
        op: Reduction operation (defaults to ``MPI.SUM``).
        implementation: ``"mpi"`` (default) or ``"torch"``.

    Returns:
        The reduced value.
    """
    impl = (implementation or "mpi").lower()
    if impl == "mpi":
        from mpi4py import MPI

        op = MPI.SUM if op is None else op
        return _get_mpi_comm().allreduce(obj, op=op)
    if impl in {"torch", "pytorch", "pt"}:
        import torch
        import torch.distributed as tdist

        op = tdist.ReduceOp.SUM if op is None else op
        tensor = torch.tensor(obj)
        tdist.all_reduce(tensor, op=op)
        return tensor.item()
    raise ValueError(
        f"Unsupported all_reduce implementation: {implementation}"
    )
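
To make the semantics concrete, here is a pure-Python sketch (no MPI or torch required) of what every rank holds after an all-reduce. simulate_all_reduce is a hypothetical stand-in for illustration only; it is not how ezpz performs the reduction.

```python
import operator
from functools import reduce
from typing import Any, Callable, Sequence


def simulate_all_reduce(
    per_rank_values: Sequence[Any],
    op: Callable[[Any, Any], Any] = operator.add,
) -> list:
    """Return the value each rank would hold after an all-reduce.

    per_rank_values[i] is the local contribution of rank i. Every rank
    ends up with the same reduced result.
    """
    reduced = reduce(op, per_rank_values)
    return [reduced for _ in per_rank_values]


# Four ranks each contribute a local loss; all of them receive the sum.
losses = [0.5, 0.25, 0.125, 0.125]
result = simulate_all_reduce(losses)
mean_loss = result[0] / len(losses)
print(result, mean_loss)
```

In a real job each rank would pass its own local value to ezpz.distributed.all_reduce and divide by the world size to obtain the same mean on every rank.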

barrier(group=None, implementation=None) ⚓︎

Barrier across all ranks.

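Tries MPI first (fast) and falls back to torch.distributed.barrier if the MPI call raises. The torch path is skipped silently when torch.distributed is not initialized.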

Parameters:

Name            Type        Description                                Default
group           Any         Optional torch.distributed process group.  None
implementation  str | None  Force "mpi" or "torch".                    None

Source code in src/ezpz/distributed.py
def barrier(
    group: Any = None,
    implementation: str | None = None,
) -> None:
    """Barrier across all ranks.

    Tries MPI first (fast), falls back to ``torch.distributed.barrier``.

    Args:
        group: Optional ``torch.distributed`` process group.
        implementation: Force ``"mpi"`` or ``"torch"``.
    """
    if implementation is not None and implementation.lower() not in {
        "mpi",
        "torch",
    }:
        raise ValueError(
            f"Unsupported barrier implementation: {implementation}"
        )
    if implementation is None or implementation.lower() == "mpi":
        try:
            _get_mpi_comm().barrier()
            return
        except Exception:
            if get_rank() == 0:
                logger.warning(
                    "MPI barrier failed; falling back to torch.distributed.barrier"
                )
    # torch fallback
    import torch.distributed

    if torch.distributed.is_initialized():
        kwargs: dict[str, Any] = {}
        if group is not None:
            kwargs["group"] = group
        torch.distributed.barrier(**kwargs)

broadcast(obj, root=0) ⚓︎

Broadcast a picklable obj from root to all ranks via MPI.

Parameters:

Name  Type  Description            Default
obj   Any   Payload to broadcast.  required
root  int   Originating rank.      0

Returns:

Type  Description
Any   The broadcast payload on every rank.

Source code in src/ezpz/distributed.py
def broadcast(obj: Any, root: int = 0) -> Any:
    """Broadcast a picklable *obj* from *root* to all ranks via MPI.

    Args:
        obj: Payload to broadcast.
        root: Originating rank.

    Returns:
        The broadcast payload on every rank.
    """
    return _get_mpi_comm().bcast(obj, root=root)

cleanup() ⚓︎

Destroy the torch.distributed process group if active.

Source code in src/ezpz/distributed.py
def cleanup() -> None:
    """Destroy the ``torch.distributed`` process group if active."""
    import torch.distributed

    try:
        import wandb  # noqa: F811

        if wandb.run is not None:
            logger.info("wandb.run=[%s](%s)", wandb.run.name, wandb.run.url)
    except Exception:
        pass
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

get_cpus_per_node() ⚓︎

Return the number of CPUs available on the local node.

Source code in src/ezpz/distributed.py
def get_cpus_per_node() -> int:
    """Return the number of CPUs available on the local node."""
    return os.cpu_count() or 1

get_device_properties(device=None) ⚓︎

Return device properties as a dictionary.

Parameters:

Name    Type        Description                                  Default
device  int | None  Device index. Defaults to get_local_rank().  None

Source code in src/ezpz/distributed.py
def get_device_properties(device: int | None = None) -> dict[str, Any]:
    """Return device properties as a dictionary.

    Args:
        device: Device index.  Defaults to ``get_local_rank()``.
    """
    import torch

    device_type = get_torch_device_type()
    idx = device if device is not None else get_local_rank()
    if device_type == "cuda":
        props = torch.cuda.get_device_properties(idx)
        # The CUDA device-properties object exposes ``total_memory``.
        return {"name": props.name, "total_memory": props.total_memory}
    if device_type == "xpu" and hasattr(torch, "xpu"):
        props = torch.xpu.get_device_properties(idx)
        return {
            "name": props.name,
            "total_memory": getattr(props, "total_memory", -1),
        }
    return {"name": device_type, "total_memory": -1}

get_dist_info(verbose=None, hostfile=None) ⚓︎

Return a dictionary summarising the distributed environment.

Parameters:

Name      Type                   Description                                    Default
verbose   bool | None            If True, log the info dict as formatted JSON.  None
hostfile  str | PathLike | None  Explicit hostfile path.                        None

Source code in src/ezpz/distributed.py
def get_dist_info(
    verbose: bool | None = None,
    hostfile: str | os.PathLike | None = None,
) -> dict[str, Any]:
    """Return a dictionary summarising the distributed environment.

    Args:
        verbose: If ``True``, log the info dict as formatted JSON.
        hostfile: Explicit hostfile path.
    """
    import sys

    from ezpz.configs import get_scheduler

    hfp = (
        get_hostfile_with_fallback(hostfile)
        if hostfile is None
        else Path(hostfile)
    )
    if hfp is not None and Path(hfp).is_file():
        hosts = get_nodes_from_hostfile(hfp)
        hostfile_path = Path(hfp).resolve().as_posix()
    else:
        hosts = [get_hostname()]
        hostfile_path = str(hfp) if hfp is not None else ""
    num_nodes = len(hosts)
    gpus = get_gpus_per_node()
    info: dict[str, Any] = {}
    info.update(
        {
            "DEVICE": get_torch_device(),
            "DEVICE_ID": f"{get_torch_device()}:{get_local_rank()}",
            "DISTRIBUTED_BACKEND": get_torch_backend(),
            "GPUS_PER_NODE": gpus,
            "HOSTS": str(hosts),
            "HOSTFILE": hostfile_path,
            "HOSTNAME": get_hostname(),
            "LOCAL_RANK": get_local_rank(),
            "MACHINE": get_machine(),
            "NUM_NODES": num_nodes,
            "NGPUS": num_nodes * gpus,
            "NGPUS_AVAILABLE": get_world_size_total(),
            "NODE_ID": get_node_index(),
            "RANK": get_rank(),
            "SCHEDULER": get_scheduler(),
            "WORLD_SIZE_TOTAL": get_world_size_total(),
            "WORLD_SIZE_IN_USE": get_world_size_in_use(),
            "world_size": get_world_size(),
            "EZPZ_RUN_COMMAND": os.environ.get(
                "EZPZ_RUN_COMMAND", sys.argv[0]
            ),
        }
    )
    if verbose:
        import json

        logger.info("DistInfo=%s", json.dumps(info, indent=4, sort_keys=True))
    return info

get_gpus_per_node() ⚓︎

Return the number of accelerators on the local node.

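Prefers environment variables (NGPU_PER_HOST, LOCAL_WORLD_SIZE, PMI_LOCAL_SIZE, SLURM_NTASKS_PER_NODE, checked in that order), then falls back to torch.{cuda,xpu}.device_count(). On MPS it returns the number of MPI ranks in use, and it returns 0 when no accelerator is found.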

Source code in src/ezpz/distributed.py
def get_gpus_per_node() -> int:
    """Return the number of accelerators on the local node.

    Prefers environment variables (``NGPU_PER_HOST``, ``LOCAL_WORLD_SIZE``,
    ``PMI_LOCAL_SIZE``, ``SLURM_NTASKS_PER_NODE``) then falls back to
    ``torch.{cuda,xpu}.device_count()``.
    """
    for var in (
        "NGPU_PER_HOST",
        "LOCAL_WORLD_SIZE",
        "PMI_LOCAL_SIZE",
        "SLURM_NTASKS_PER_NODE",
    ):
        val = os.environ.get(var)
        if val is not None:
            return int(val)
    import torch

    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    if torch.backends.mps.is_available():
        return get_world_size_in_use()
    return 0
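
The environment-variable precedence above amounts to "first variable that is set wins". The sketch below shows that lookup against a plain mapping so it can be tried without touching the real environment; first_int_from_env is a hypothetical helper name, not an ezpz function.

```python
from typing import Mapping, Optional, Sequence

# Same variables, in the same order, as the listing above.
GPU_COUNT_VARS = (
    "NGPU_PER_HOST",
    "LOCAL_WORLD_SIZE",
    "PMI_LOCAL_SIZE",
    "SLURM_NTASKS_PER_NODE",
)


def first_int_from_env(
    names: Sequence[str],
    environ: Mapping[str, str],
) -> Optional[int]:
    """Return the integer value of the first variable in names that is set."""
    for name in names:
        value = environ.get(name)
        if value is not None:
            return int(value)
    return None


# LOCAL_WORLD_SIZE comes before SLURM_NTASKS_PER_NODE, so it wins.
fake_env = {"SLURM_NTASKS_PER_NODE": "8", "LOCAL_WORLD_SIZE": "4"}
print(first_int_from_env(GPU_COUNT_VARS, fake_env))
```

One consequence worth keeping in mind: a stale LOCAL_WORLD_SIZE left over in the shell would take priority over the scheduler-provided value.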

get_hostfile_with_fallback(hostfile=None) ⚓︎

Locate (or create) a usable hostfile.

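Under SLURM, the hostfile returned by _make_hostfile_from_slurm() is used and any explicit hostfile argument is ignored. Otherwise the lookup order is: the explicit hostfile (if it exists), the PBS_NODEFILE and HOSTFILE environment variables, then the PBS nodefile reported by ezpz.pbs. As a last resort, writes localhost to a file named hostfile in the current directory.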

Parameters:

Name      Type                   Description                              Default
hostfile  str | PathLike | None  Explicit path; auto-detected when None.  None

Returns:

Type  Description
Path  Path to the hostfile.

Source code in src/ezpz/distributed.py
def get_hostfile_with_fallback(
    hostfile: str | os.PathLike | None = None,
) -> Path:
    """Locate (or create) a usable hostfile.

    Checks PBS, SLURM, and environment variables.  As a last resort,
    writes ``localhost`` to a file in the current directory.

    Args:
        hostfile: Explicit path; auto-detected when ``None``.

    Returns:
        :class:`Path` to the hostfile.
    """
    from ezpz.configs import get_scheduler

    scheduler = get_scheduler()

    if scheduler.lower() == "slurm":
        return _make_hostfile_from_slurm()

    if hostfile is not None:
        hfp = Path(hostfile)
        if hfp.is_file():
            return hfp

    # Try standard env vars
    for var in ("PBS_NODEFILE", "HOSTFILE"):
        val = os.environ.get(var)
        if val and Path(val).is_file():
            return Path(val)

    # PBS without env var
    if scheduler == "PBS":
        try:
            import ezpz.pbs

            nodefile = ezpz.pbs.get_pbs_nodefile()
            if nodefile is not None:
                return Path(nodefile)
        except Exception:
            pass

    # Fallback: write localhost
    hfp = Path(os.getcwd()) / "hostfile"
    hfp.touch(exist_ok=True)
    write_localhost_to_hostfile(hfp)
    return hfp

get_hostname() ⚓︎

Return the hostname of the current machine (lowercased).

Source code in src/ezpz/distributed.py
def get_hostname() -> str:
    """Return the hostname of the current machine (lowercased)."""
    try:
        name = socket.gethostname()
        if name:
            try:
                return socket.gethostbyaddr(name)[0].strip().lower()
            except OSError:
                return name.strip().lower()
    except Exception:
        pass
    for var in ("HOSTNAME", "HOST"):
        val = os.environ.get(var)
        if val:
            return val.strip().lower()
    try:
        import platform

        node = platform.node()
        if node:
            return node.strip().lower()
    except Exception:
        pass
    return "localhost"

get_local_rank() ⚓︎

Return the local rank (GPU index) of the current process on its node.

The value is resolved from well-known environment variables set by common MPI implementations and job schedulers. If none are set, the function returns 0 for a single-process run and otherwise falls back to get_rank() % get_gpus_per_node().

Returns:

Type Description
int

Local rank (0-indexed within the node).

Source code in src/ezpz/distributed.py
def get_local_rank() -> int:
    """Return the local rank (GPU index) of the current process on its node.

    The value is resolved from well-known environment variables set by
    common MPI implementations and job schedulers.  If none are set the
    function falls back to ``get_rank() % get_gpus_per_node()``.

    Returns:
        Local rank (0-indexed within the node).
    """
    _ENV_VARS = (
        "LOCAL_RANK",
        "PMI_LOCAL_RANK",
        "OMPI_COMM_WORLD_LOCAL_RANK",
        "MPI_LOCALRANKID",
        "MPICH_LOCALRANKID",
        "SLURM_LOCAL_ID",
    )
    for var in _ENV_VARS:
        val = os.environ.get(var)
        if val is not None:
            return int(val)
    ws = get_world_size()
    if ws <= 1:
        return 0
    gpn = get_gpus_per_node()
    return get_rank() % gpn if gpn > 0 else 0

get_machine(hostname=None) ⚓︎

Identify the ALCF / HPC machine from its hostname prefix.

Parameters:

Name      Type        Description                         Default
hostname  str | None  Override; auto-detected when None.  None

Returns:

Type  Description
str   A human-readable machine name (e.g. "Polaris", "Aurora").

Source code in src/ezpz/distributed.py
def get_machine(hostname: str | None = None) -> str:
    """Identify the ALCF / HPC machine from its hostname prefix.

    Args:
        hostname: Override; auto-detected when ``None``.

    Returns:
        A human-readable machine name (e.g. ``"Polaris"``, ``"Aurora"``).
    """
    if hostname is None:
        hostname = get_hostname()
    _PREFIX_MAP = (
        ("frontier", "Frontier"),
        ("sophia", "Sophia"),
        ("theta", "ThetaGPU"),
        ("x1", "SunSpot"),
        ("x4", "Aurora"),
        ("login", "Perlmutter"),
        ("nid", "Perlmutter"),
    )
    for prefix, name in _PREFIX_MAP:
        if hostname.startswith(prefix):
            return name
    if hostname.startswith("x3"):
        return "Sirius" if "sirius" in hostname else "Polaris"
    return hostname

get_node_index() ⚓︎

Return the index of the current node (rank // gpus_per_node).

Source code in src/ezpz/distributed.py
def get_node_index() -> int:
    """Return the index of the current node (``rank // gpus_per_node``)."""
    gpn = get_gpus_per_node()
    return get_rank() // gpn if gpn > 0 else 0
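
The fallback formulas for the node index (rank // gpus_per_node) and the local rank (rank % gpus_per_node) are the two halves of one integer division. The sketch below assumes ranks are placed contiguously node by node, which is what those formulas imply; locate_rank is a hypothetical name used only for illustration.

```python
from typing import Tuple


def locate_rank(rank: int, gpus_per_node: int) -> Tuple[int, int]:
    """Return (node_index, local_rank) for a global rank.

    Mirrors the guards in the listings above: with no accelerators
    detected, both values are 0.
    """
    if gpus_per_node <= 0:
        return (0, 0)
    return divmod(rank, gpus_per_node)


# With 4 accelerators per node, global rank 5 is local rank 1 on node 1.
print(locate_rank(5, 4))
```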

get_nodes_from_hostfile(hostfile) ⚓︎

Read hostnames from hostfile, one per line.

Parameters:

Name      Type            Description            Default
hostfile  str | PathLike  Path to the hostfile.  required

Returns:

Type       Description
list[str]  List of hostnames.

Source code in src/ezpz/distributed.py
def get_nodes_from_hostfile(hostfile: str | os.PathLike) -> list[str]:
    """Read hostnames from *hostfile*, one per line.

    Args:
        hostfile: Path to the hostfile.

    Returns:
        List of hostnames.
    """
    fpath = Path(hostfile)
    if not fpath.is_file():
        return [get_hostname()]
    with fpath.open("r") as f:
        return [line.rstrip("\n") for line in f if line.strip()]
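
The hostfile format implied by this function is one hostname per line, with blank lines skipped. The self-contained sketch below reproduces that parsing on a temporary file; read_hosts is a hypothetical local copy of the logic, and the hostnames are made up.

```python
import tempfile
from pathlib import Path
from typing import List


def read_hosts(path: Path) -> List[str]:
    """Read one hostname per line, skipping blank lines."""
    with path.open("r") as f:
        return [line.rstrip("\n") for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    hostfile = Path(tmp) / "hostfile"
    # Note the blank line and the fully qualified name.
    hostfile.write_text("node1\n\nnode2.example.org\nnode3\n")
    hosts = read_hosts(hostfile)

# Short names, as get_num_nodes() derives them with split(".")[0].
short_names = [h.split(".")[0] for h in hosts]
print(hosts)
print(short_names)
```

Only the trailing newline is stripped, so trailing spaces on a line would be kept as part of the hostname.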

get_num_nodes(hostfile=None) ⚓︎

Return the number of nodes in the current allocation.

Checks SLURM_NNODES first, then counts lines in hostfile.

Parameters:

Name      Type                   Description                                       Default
hostfile  str | PathLike | None  Explicit path; resolved automatically when None.  None

Source code in src/ezpz/distributed.py
def get_num_nodes(hostfile: str | os.PathLike | None = None) -> int:
    """Return the number of nodes in the current allocation.

    Checks ``SLURM_NNODES`` first, then counts lines in *hostfile*.

    Args:
        hostfile: Explicit path; resolved automatically when ``None``.
    """
    slurm_nnodes = os.environ.get("SLURM_NNODES")
    if slurm_nnodes is not None:
        return int(slurm_nnodes)
    hfp = get_hostfile_with_fallback(hostfile)
    hosts = [h.split(".")[0] for h in get_nodes_from_hostfile(hfp)]
    return len(hosts)

get_rank() ⚓︎

Return the global MPI rank of the current process.

The value is resolved from well-known environment variables set by common MPI implementations and job schedulers. If none are set, the function falls back to querying the MPI communicator.

Returns:

Type Description
int

Global rank (0-indexed).

Source code in src/ezpz/distributed.py
def get_rank() -> int:
    """Return the global MPI rank of the current process.

    The value is resolved from well-known environment variables set by
    common MPI implementations and job schedulers.  If none are set the
    function falls back to querying the MPI communicator.

    Returns:
        Global rank (0-indexed).
    """
    _ENV_VARS = (
        "RANK",
        "PMI_RANK",
        "OMPI_COMM_WORLD_RANK",
        "SLURM_PROCID",
    )
    for var in _ENV_VARS:
        val = os.environ.get(var)
        if val is not None:
            return int(val)
    return int(_get_mpi_comm().Get_rank())

get_torch_backend() ⚓︎

Return the appropriate torch.distributed backend string.

Checks TORCH_BACKEND env, then probes hardware availability to select nccl / xccl / ccl / gloo.

Source code in src/ezpz/distributed.py
def get_torch_backend() -> str:
    """Return the appropriate ``torch.distributed`` backend string.

    Checks ``TORCH_BACKEND`` env, then probes hardware availability to
    select ``nccl`` / ``xccl`` / ``gloo``.
    """
    env = os.environ.get("TORCH_BACKEND")
    if env is not None:
        return env
    import torch

    if torch.cuda.is_available() and torch.distributed.is_backend_available(
        "nccl"
    ):
        return "nccl"
    if torch.xpu.is_available():
        if torch.distributed.is_backend_available("xccl"):
            return "xccl"
        return "ccl"
    return "gloo"
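
The selection logic reads as a small decision table. The sketch below restates it over plain booleans so each branch can be checked without any hardware; choose_backend and its parameters are hypothetical names, and the availability flags stand in for the torch probes.

```python
from typing import Optional


def choose_backend(
    *,
    env_override: Optional[str] = None,
    cuda_available: bool = False,
    nccl_available: bool = False,
    xpu_available: bool = False,
    xccl_available: bool = False,
) -> str:
    """Pick a backend string using the same precedence as the listing above."""
    if env_override is not None:
        return env_override  # TORCH_BACKEND always wins
    if cuda_available and nccl_available:
        return "nccl"
    if xpu_available:
        return "xccl" if xccl_available else "ccl"
    return "gloo"  # CPU-only, or CUDA without NCCL


print(choose_backend(cuda_available=True, nccl_available=True))
print(choose_backend(xpu_available=True))
print(choose_backend())
```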

get_torch_device(*, device_type=None, as_torch_device=None) ⚓︎

Return the current accelerator device.

Parameters:

Name             Type         Description                                                        Default
device_type      str | None   Force a specific device type.                                      None
as_torch_device  bool | None  If True, return a torch.device object instead of a plain string.  None

Source code in src/ezpz/distributed.py
def get_torch_device(
    *,
    device_type: str | None = None,
    as_torch_device: bool | None = None,
) -> str | torch.device:
    """Return the current accelerator device.

    Args:
        device_type: Force a specific device type.
        as_torch_device: If ``True``, return a :class:`torch.device` object
            instead of a plain string.
    """
    import torch

    env = os.environ.get("TORCH_DEVICE")
    if env:
        normalized = env.strip().lower()
        base = normalized.split(":", 1)[0]
        if base in _SUPPORTED_DEVICE_TYPES:
            return torch.device(normalized) if as_torch_device else normalized
    dt = device_type if device_type is not None else get_torch_device_type()
    return torch.device(dt) if as_torch_device else dt

get_torch_device_type(device_type=None) ⚓︎

Return the accelerator type as a string ("cuda", "xpu", …).

Respects the TORCH_DEVICE environment variable when set. Otherwise it probes for xpu, cuda, and mps in that order and falls back to "cpu".

Parameters:

Name         Type        Description                                       Default
device_type  str | None  Override value; returned as-is after validation.  None

Source code in src/ezpz/distributed.py
def get_torch_device_type(device_type: str | None = None) -> str:
    """Return the accelerator type as a string (``"cuda"``, ``"xpu"``, …).

    Respects the ``TORCH_DEVICE`` environment variable when set.

    Args:
        device_type: Override value; returned as-is after validation.
    """
    if device_type is not None:
        if device_type not in _SUPPORTED_DEVICE_TYPES:
            raise ValueError(
                f"Unsupported device_type={device_type!r}; "
                f"expected one of {sorted(_SUPPORTED_DEVICE_TYPES)}"
            )
        return device_type
    env = os.environ.get("TORCH_DEVICE")
    if env:
        base = env.strip().lower().split(":", 1)[0]
        if base in _SUPPORTED_DEVICE_TYPES:
            return base
        logger.warning(
            "Ignoring unsupported TORCH_DEVICE=%s; expected one of %s",
            env,
            sorted(_SUPPORTED_DEVICE_TYPES),
        )
    import torch

    if torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
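
The handling of TORCH_DEVICE values such as "cuda:1" comes down to trimming, lowercasing, and keeping the part before the colon. The sketch below isolates that step. The contents of SUPPORTED_DEVICE_TYPES are an assumption, since _SUPPORTED_DEVICE_TYPES is not shown on this page, and device_type_from_env is a hypothetical helper.

```python
from typing import Optional

# Assumption: the real _SUPPORTED_DEVICE_TYPES is not shown on this page.
SUPPORTED_DEVICE_TYPES = frozenset({"cpu", "cuda", "mps", "xpu"})


def device_type_from_env(value: Optional[str]) -> Optional[str]:
    """Extract the base device type from a TORCH_DEVICE-style string.

    Returns None when the value is empty or names an unsupported type,
    in which case a caller would move on to hardware auto-detection.
    """
    if not value:
        return None
    base = value.strip().lower().split(":", 1)[0]
    return base if base in SUPPORTED_DEVICE_TYPES else None


print(device_type_from_env(" CUDA:1 "))
print(device_type_from_env("tpu"))
```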

get_world_size(*, total=False, in_use=False) ⚓︎

Return the distributed world size.

Parameters:

Name    Type  Description                                                                          Default
total   bool  If True, return the total available accelerator count (num_nodes * gpus_per_node).  False
in_use  bool  If True, return MPI.COMM_WORLD.Get_size().                                          False

Returns:

Type  Description
int   World size as an integer.

Source code in src/ezpz/distributed.py
def get_world_size(
    *,
    total: bool = False,
    in_use: bool = False,
) -> int:
    """Return the distributed world size.

    Args:
        total: If ``True``, return the *total available* accelerator count
            (``num_nodes * gpus_per_node``).
        in_use: If ``True``, return ``MPI.COMM_WORLD.Get_size()``.

    Returns:
        World size as an integer.
    """
    if total:
        return get_world_size_total()
    if in_use:
        return get_world_size_in_use()
    return int(_get_mpi_comm().Get_size())

get_world_size_in_use() ⚓︎

Return the number of MPI ranks currently participating.

Source code in src/ezpz/distributed.py
def get_world_size_in_use() -> int:
    """Return the number of MPI ranks currently participating."""
    return int(_get_mpi_comm().Get_size())

get_world_size_total() ⚓︎

Return max(get_gpus_per_node(), 1) * get_num_nodes().

Source code in src/ezpz/distributed.py
def get_world_size_total() -> int:
    """Return ``max(get_gpus_per_node(), 1) * get_num_nodes()``."""
    return max(get_gpus_per_node(), 1) * get_num_nodes()
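
A short worked example of the formula, written as a standalone function so it runs without MPI. world_size_total here is a hypothetical local copy that takes its inputs as arguments instead of probing the environment.

```python
def world_size_total(gpus_per_node: int, num_nodes: int) -> int:
    """max(gpus_per_node, 1) * num_nodes, as in the listing above."""
    return max(gpus_per_node, 1) * num_nodes


# Two nodes with four accelerators each: 8 slots in total.
print(world_size_total(gpus_per_node=4, num_nodes=2))

# CPU-only nodes report 0 accelerators; the max(..., 1) guard counts
# each such node as a single slot instead of yielding 0.
print(world_size_total(gpus_per_node=0, num_nodes=3))
```

The total can exceed the number of ranks actually launched, which is the gap that print_dist_setup reports on rank 0.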

log_dict_as_bulleted_list(d, name=None) ⚓︎

Log a dictionary as a bulleted list.

Parameters:

Name  Type        Description                     Default
d     dict        Dictionary to format.           required
name  str | None  Optional label for the header.  None

Source code in src/ezpz/distributed.py
def log_dict_as_bulleted_list(d: dict, name: str | None = None) -> None:
    """Log a dictionary as a bulleted list.

    Args:
        d: Dictionary to format.
        name: Optional label for the header.
    """
    tag = name or getattr(d, "__qualname__", "dict")
    lines = [f"[{tag}]:"] + [f"  - {k}={v}" for k, v in d.items()]
    logger.info("\n\n%s\n", "\n".join(lines))
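
To show what the logged text looks like, the sketch below builds the same bulleted string and returns it instead of logging it. format_bulleted is a hypothetical helper, and it uses a fixed "dict" default label for simplicity.

```python
from typing import Any, Mapping, Optional


def format_bulleted(d: Mapping[str, Any], name: Optional[str] = None) -> str:
    """Build the header line plus one '  - key=value' line per item."""
    tag = name or "dict"
    lines = [f"[{tag}]:"] + [f"  - {k}={v}" for k, v in d.items()]
    return "\n".join(lines)


print(format_bulleted({"lr": 0.001, "epochs": 3}, name="config"))
```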

print_dist_setup(hostfile=None, display=True) ⚓︎

Build (and optionally log) a one-line-per-rank summary string.

Parameters:

Name      Type                   Description                                  Default
hostfile  str | PathLike | None  Explicit hostfile path.                      None
display   bool                   If True, emit the string via logger.info.   True

Returns:

Type  Description
str   The formatted summary string.

Source code in src/ezpz/distributed.py
def print_dist_setup(
    hostfile: str | os.PathLike | None = None,
    display: bool = True,
) -> str:
    """Build (and optionally log) a one-line-per-rank summary string.

    Args:
        hostfile: Explicit hostfile path.
        display: If ``True``, emit the string via :func:`logger.info`.

    Returns:
        The formatted summary string.
    """
    rank = get_rank()
    world_size = get_world_size(in_use=True)
    local_rank = get_local_rank()
    gpn = max(get_gpus_per_node(), 1)
    num_nodes = max(world_size // gpn, 1)
    node = get_node_index()
    device = get_torch_device_type()
    hn = socket.gethostname()

    rw = len(str(max(0, world_size - 1)))
    lw = len(str(max(0, gpn - 1)))
    nnw = len(str(max(0, num_nodes - 1)))
    nw = nnw

    parts = [
        f"['{hn}']",
        f"[{device=}]",
        f"[node={node:>0{nw}d}/{num_nodes - 1:<0{nnw}d}]",
        f"[local_rank={local_rank:>0{lw}d}/{gpn - 1:<0{lw}d}]",
        f"[rank={rank:>0{rw}d}/{world_size - 1:<0{rw}d}]",
    ]
    dist_str = "".join(parts)
    if display:
        logger.info(dist_str)
    if rank == 0:
        wst = get_world_size(total=True)
        logger.warning(
            'Using [%d / %d] available "%s" devices !!',
            world_size,
            wst,
            device,
        )
    return dist_str
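
For a sense of the summary's shape, the sketch below rebuilds a similar line from explicit arguments. It is a simplified illustration: it pads with str.zfill rather than the format specifiers used above, takes all values as parameters, and format_dist_summary is a hypothetical name. The hostname is made up.

```python
def format_dist_summary(
    hostname: str,
    device: str,
    rank: int,
    world_size: int,
    local_rank: int,
    gpus_per_node: int,
) -> str:
    """Build a one-line summary with zero-padded, fixed-width indices."""
    gpn = max(gpus_per_node, 1)
    num_nodes = max(world_size // gpn, 1)
    node = rank // gpn
    # Pad each index to the width of the largest value it can take.
    rw = len(str(max(0, world_size - 1)))
    lw = len(str(max(0, gpn - 1)))
    nw = len(str(max(0, num_nodes - 1)))
    return (
        f"['{hostname}']"
        f"[device={device!r}]"
        f"[node={str(node).zfill(nw)}/{num_nodes - 1}]"
        f"[local_rank={str(local_rank).zfill(lw)}/{gpn - 1}]"
        f"[rank={str(rank).zfill(rw)}/{world_size - 1}]"
    )


line = format_dist_summary(
    "x3005", "cuda", rank=5, world_size=16, local_rank=1, gpus_per_node=4
)
print(line)
```

Padding every index to a common width keeps the per-rank lines aligned when many ranks log at once.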

query_environment() ⚓︎

Return {world_size, rank, local_rank} from env vars or MPI.

Source code in src/ezpz/distributed.py
def query_environment() -> dict[str, int]:
    """Return ``{world_size, rank, local_rank}`` from env vars or MPI."""
    ws = os.environ.get("WORLD_SIZE")
    r = os.environ.get("RANK")
    lr = os.environ.get("LOCAL_RANK")
    if ws is not None and r is not None and lr is not None:
        return {"world_size": int(ws), "rank": int(r), "local_rank": int(lr)}
    return {
        "world_size": get_world_size(),
        "rank": get_rank(),
        "local_rank": get_local_rank(),
    }

seed_everything(seed) ⚓︎

Seed Python, NumPy, and PyTorch RNGs for reproducibility.

Parameters:

Name  Type  Description       Default
seed  int   The random seed.  required

Source code in src/ezpz/distributed.py
def seed_everything(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility.

    Args:
        seed: The random seed.
    """
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.xpu.is_available():
        torch.xpu.manual_seed(seed)

setup_torch(port=None, seed=None, timeout=None, verbose=False, *, tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, device_id=None) ⚓︎

Initialise torch.distributed and return the global rank.

This is the main entry point. It:

  1. Uses MPI to discover rank / world_size / master_addr / master_port.
  2. Calls torch.distributed.init_process_group.
  3. Sets the local CUDA/XPU device.
  4. Optionally seeds RNGs and initialises tensor parallelism.
  5. Prints a one-line-per-rank summary.

Parameters:

Name                       Type              Description                                                 Default
port                       str | int | None  Fallback master port (rank 0 picks a free port otherwise).  None
seed                       int | None        If given, call seed_everything with a rank-aware seed.      None
timeout                    str | int | None  init_process_group timeout in seconds (default 3600).       None
verbose                    bool              Print verbose dist info on rank 0.                          False
tensor_parallel_size       int               TP degree (default 1 = disabled).                           1
pipeline_parallel_size     int               PP degree (default 1 = disabled).                           1
context_parallel_size      int               CP degree (default 1 = disabled).                           1
tensor_parallel_backend    str | None        Override backend for TP group.                              None
pipeline_parallel_backend  str | None        Override backend for PP group.                              None
context_parallel_backend   str | None        Override backend for CP group.                              None
data_parallel_backend      str | None        Override backend for DP group.                              None
device_id                  int | None        Explicit device ordinal for init_process_group.             None

Returns:

Type  Description
int   The global rank of this process.
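
The timeout parameter is resolved in three steps: an explicit argument wins, then the TORCH_DDP_TIMEOUT environment variable, then the 3600-second default. The sketch below isolates that resolution; resolve_timeout is a hypothetical helper, and it accepts a mapping so the behaviour can be checked without modifying the real environment.

```python
import os
from datetime import timedelta
from typing import Mapping, Optional, Union


def resolve_timeout(
    timeout: Optional[Union[str, int]] = None,
    environ: Mapping[str, str] = os.environ,
) -> timedelta:
    """Explicit argument, else TORCH_DDP_TIMEOUT, else 3600 seconds."""
    seconds = (
        int(timeout)
        if timeout is not None
        else int(environ.get("TORCH_DDP_TIMEOUT", 3600))
    )
    return timedelta(seconds=seconds)


print(resolve_timeout("120", {}))
print(resolve_timeout(None, {"TORCH_DDP_TIMEOUT": "600"}))
print(resolve_timeout(None, {}))
```

Because both strings and integers are accepted, a value read straight from a command line or config file can be passed through unchanged.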

Source code in src/ezpz/distributed.py
def setup_torch(
    port: str | int | None = None,
    seed: int | None = None,
    timeout: str | int | None = None,
    verbose: bool = False,
    *,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    context_parallel_size: int = 1,
    tensor_parallel_backend: str | None = None,
    pipeline_parallel_backend: str | None = None,
    context_parallel_backend: str | None = None,
    data_parallel_backend: str | None = None,
    device_id: int | None = None,
) -> int:
    """Initialise ``torch.distributed`` and return the global rank.

    This is the main entry point.  It:

    1. Uses MPI to discover rank / world_size / master_addr / master_port.
    2. Calls ``torch.distributed.init_process_group``.
    3. Sets the local CUDA/XPU device.
    4. Optionally seeds RNGs and initialises tensor parallelism.
    5. Prints a one-line-per-rank summary.

    Args:
        port: Fallback master port (rank 0 picks a free port otherwise).
        seed: If given, call :func:`seed_everything` with a rank-aware seed.
        timeout: ``init_process_group`` timeout in seconds (default 3600).
        verbose: Print verbose dist info on rank 0.
        tensor_parallel_size: TP degree (default 1 = disabled).
        pipeline_parallel_size: PP degree (default 1 = disabled).
        context_parallel_size: CP degree (default 1 = disabled).
        tensor_parallel_backend: Override backend for TP group.
        pipeline_parallel_backend: Override backend for PP group.
        context_parallel_backend: Override backend for CP group.
        data_parallel_backend: Override backend for DP group.
        device_id: Explicit device ordinal for ``init_process_group``.

    Returns:
        The global rank of this process.
    """
    import torch

    device_type = get_torch_device_type()
    backend = get_torch_backend()
    timeout_s = (
        int(timeout)
        if timeout is not None
        else int(os.environ.get("TORCH_DDP_TIMEOUT", 3600))
    )

    # -- Single-device fast path --
    ws_env = os.environ.get("WORLD_SIZE")
    if ws_env is not None and ws_env == "1":
        if get_rank() == 0:
            logger.info(
                "Running on a single %s, not initialising torch.distributed!",
                device_type,
            )
        _set_env_vars(rank=0, local_rank=0, world_size=1)
        return 0

    # -- Multi-device init --
    dsetup = _setup_ddp(
        port=str(port) if port is not None else "1234",
        timeout=timedelta(seconds=timeout_s),
        backend=backend,
        device_id=device_id,
    )

    rank = dsetup["rank"]
    world_size = dsetup["world_size"]
    local_rank = dsetup["local_rank"]

    # Set local device
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.set_device(local_rank)

    _set_env_vars(rank=rank, local_rank=local_rank, world_size=world_size)

    # -- Tensor / pipeline / context parallelism --
    if (
        tensor_parallel_size > 1
        or pipeline_parallel_size > 1
        or context_parallel_size > 1
    ):
        import ezpz.tp

        ezpz.tp.initialize_tensor_parallel(
            tensor_parallel_size=tensor_parallel_size,
            pipeline_parallel_size=pipeline_parallel_size,
            context_parallel_size=context_parallel_size,
            tensor_parallel_backend=tensor_parallel_backend,
            pipeline_parallel_backend=pipeline_parallel_backend,
            context_parallel_backend=context_parallel_backend,
            data_parallel_backend=data_parallel_backend,
            timeout=timedelta(seconds=float(timeout_s)),
        )

    # -- Seed --
    if seed is not None:
        if rank == 0:
            logger.warning("Manually specifying seed=%d", seed)
        seed_everything(seed * (rank + 1) * (local_rank + 1))

    # -- Diagnostics --
    if rank == 0:
        _ = get_dist_info(verbose=verbose)
        logger.info(
            "Using device=%s with backend=%s for distributed training.",
            device_type,
            backend,
        )

    if world_size > 1:
        barrier()

    logger.info(print_dist_setup(display=False))
    barrier()
    return rank
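
The rank-aware seed in the listing above is `seed * (rank + 1) * (local_rank + 1)`. The sketch below reproduces only that arithmetic. The function name `rank_aware_seed` and the layout assumption `local_rank = rank % devices_per_node` are illustrative and not taken from ezpz.

```python
def rank_aware_seed(seed, rank, local_rank):
    """Same arithmetic as the seed expression in setup_torch()."""
    return seed * (rank + 1) * (local_rank + 1)


# Assumed layout for illustration: 2 nodes with 2 devices each.
devices_per_node = 2
seeds = [
    rank_aware_seed(42, rank, rank % devices_per_node)
    for rank in range(4)
]
print(seeds)

# The product does not guarantee distinct seeds: with 3 devices per node,
# rank 1 (local rank 1) and rank 3 (local rank 0) both get 4 * seed.
collision = rank_aware_seed(42, 1, 1) == rank_aware_seed(42, 3, 0)
print(collision)
```

Note also that `seed=0` yields 0 on every rank, so a non-zero base seed is needed if ranks should differ.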

setup_wandb(project_name=None, entity=None, config=None, outdir=None, project=None, dir=None, id=None, name=None, notes=None, tags=None, config_exclude_keys=None, config_include_keys=None, allow_val_change=None, group=None, job_type=None, mode=None, force=False, reinit=None, resume=None, resume_from=None, fork_from=None, save_code=None, init_timeout=None, start_method=None, tensorboard=None, sync_tensorboard=None, monitor_gym=None, settings=None, **kwargs) ⚓︎

Initialise a wandb run (only rank 0 logs; other ranks get disabled mode).

Most parameters are forwarded directly to wandb.init. See the wandb docs (https://docs.wandb.ai/ref/python/init/) for details.

Returns:

Type Description
Any

The wandb.run object, or None if wandb is unavailable.
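
The project name is resolved in a fixed order: the `project` argument, then `project_name`, then the environment variables WB_PROJECT, WANDB_PROJECT and WB_PROJECT_NAME, and finally a name derived from the calling file. A rough stand-alone sketch of that order follows. `resolve_project`, its `env` mapping and its `caller_file` argument are hypothetical stand-ins for `os.environ` and the frame inspection done in the real function.

```python
from pathlib import PurePosixPath


def resolve_project(project=None, project_name=None, env=None, caller_file=None):
    """Hypothetical helper mirroring the lookup order used by setup_wandb()."""
    env = {} if env is None else env
    resolved = project or project_name
    if resolved is None:
        resolved = env.get(
            "WB_PROJECT",
            env.get("WANDB_PROJECT", env.get("WB_PROJECT_NAME")),
        )
    if resolved is None and caller_file is not None:
        # Fallback: "<parent directory>.<file stem>" of the calling script.
        fp = PurePosixPath(caller_file)
        resolved = f"{fp.parent.stem}.{fp.stem}"
    return resolved


explicit = resolve_project(project="a", project_name="b")
from_env = resolve_project(env={"WANDB_PROJECT": "x", "WB_PROJECT": "y"})
from_file = resolve_project(caller_file="/home/user/myproj/train.py")
print(explicit, from_env, from_file)
```

WB_PROJECT takes priority over WANDB_PROJECT when both are set, which may be surprising if only the standard wandb variable is expected to apply.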

Source code in src/ezpz/distributed.py
def setup_wandb(
    project_name: str | None = None,
    entity: str | None = None,
    config: dict[str, Any] | None = None,
    outdir: str | os.PathLike | None = None,
    project: str | None = None,
    dir: str | os.PathLike | None = None,
    id: str | None = None,
    name: str | None = None,
    notes: str | None = None,
    tags: Sequence[str] | None = None,
    config_exclude_keys: list[str] | None = None,
    config_include_keys: list[str] | None = None,
    allow_val_change: bool | None = None,
    group: str | None = None,
    job_type: str | None = None,
    mode: Literal["online", "offline", "disabled", "shared"] | None = None,
    force: bool = False,
    reinit: bool | str | None = None,
    resume: bool | str | None = None,
    resume_from: str | None = None,
    fork_from: str | None = None,
    save_code: bool | None = None,
    init_timeout: int | float | None = None,
    start_method: Literal["fork", "spawn", "thread", "process"] | None = None,
    tensorboard: bool | None = None,
    sync_tensorboard: bool | None = None,
    monitor_gym: bool | None = None,
    settings: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Initialise a wandb run (rank 0 only logs, others get disabled mode).

    Most parameters are forwarded directly to :func:`wandb.init`.  See
    the `wandb docs <https://docs.wandb.ai/ref/python/init/>`_ for
    details.

    Returns:
        The :obj:`wandb.run` object, or ``None`` if wandb is unavailable.
    """
    import wandb

    if not verify_wandb():
        logger.warning("verify_wandb() failed; not initialising run")
        return None

    rank = get_rank()
    outdir_str = Path(outdir).as_posix() if outdir else os.getcwd()

    # Resolve project name
    _project = project or project_name
    if _project is None:
        _project = os.environ.get(
            "WB_PROJECT",
            os.environ.get("WANDB_PROJECT", os.environ.get("WB_PROJECT_NAME")),
        )
    if _project is None:
        import sys

        frame = sys._getframe().f_back
        if frame is not None:
            fp = Path(frame.f_code.co_filename)
            _project = f"{fp.parent.stem}.{fp.stem}"

    # Resolve mode
    _mode = _resolve_wandb_mode(mode)

    logger.info("Setting up wandb from rank=%d", rank)
    logger.info("Using WB_PROJECT=%s", _project)

    try:
        run = wandb.init(
            entity=entity,
            project=_project,
            dir=str(dir) if dir is not None else outdir_str,
            id=id,
            name=name,
            notes=notes,
            tags=tags,
            config_exclude_keys=config_exclude_keys,
            config_include_keys=config_include_keys,
            allow_val_change=allow_val_change,
            group=group,
            job_type=job_type,
            mode=_mode,
            force=force,
            reinit=reinit,
            resume=resume,
            resume_from=resume_from,
            fork_from=fork_from,
            save_code=save_code,
            tensorboard=tensorboard if tensorboard is not None else False,
            sync_tensorboard=sync_tensorboard
            if sync_tensorboard is not None
            else False,
            monitor_gym=monitor_gym,
            settings=(
                settings
                if settings is not None
                else wandb.Settings(
                    init_timeout=init_timeout
                    if init_timeout is not None
                    else 60,
                    start_method=start_method
                    if start_method is not None
                    else "fork",
                )
            ),
            **kwargs,
        )
        if run is not None:
            logger.info("wandb.run=[%s](%s)", run.name, run.url)
            import torch

            import ezpz

            run.config.update(
                {
                    "DIST_INFO": get_dist_info(),
                    "hostname": get_hostname(),
                    "pytorch_backend": get_torch_backend(),
                    "torch_version": torch.__version__,
                    "world_size": get_world_size(),
                    "ezpz_version": ezpz.__version__,
                    "machine": get_machine(),
                    "working_directory": os.getcwd(),
                }
            )
            if config is not None:
                run.config.update({"config": config})
        return wandb.run
    except Exception as exc:
        logger.exception("wandb.init() failed from rank=%d: %s", rank, exc)
        logger.warning("Continuing without wandb logging.")
        return None

synchronize(device=None) ⚓︎

Block until all work on the given accelerator has finished.

Parameters:

Name Type Description Default
device str | int | None

Device specifier; auto-detected when None.

None

Source code in src/ezpz/distributed.py
def synchronize(device: str | int | None = None) -> None:
    """Block until all work on the given accelerator has finished.

    Args:
        device: Device specifier; auto-detected when ``None``.
    """
    import torch

    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
    elif torch.xpu.is_available():
        torch.xpu.synchronize(device)
    elif torch.backends.mps.is_available():
        torch.mps.synchronize()

timeitlogit(rank=None, record=True, verbose=False, prefix=None) ⚓︎

Decorator factory to time a function, optionally logging to wandb.

Parameters:

Name Type Description Default
rank int | None

Rank whose logger emits messages (defaults to get_rank()).

None
record bool

Whether to wandb.log the timing.

True
verbose bool

Whether to log to stdout on the selected rank.

False
prefix str | None

Metric prefix for wandb (default "timeit").

None

Returns:

Type Description
Callable

A decorator that wraps the target function.

Example:

@timeitlogit(rank=0, verbose=True)
def train_step(batch): ...

Source code in src/ezpz/distributed.py
def timeitlogit(
    rank: int | None = None,
    record: bool = True,
    verbose: bool = False,
    prefix: str | None = None,
) -> Callable:
    """Decorator factory to time a function, optionally logging to wandb.

    Args:
        rank: Rank whose logger emits messages (defaults to :func:`get_rank`).
        record: Whether to ``wandb.log`` the timing.
        verbose: Whether to log to stdout on the selected rank.
        prefix: Metric prefix for wandb (default ``"timeit"``).

    Returns:
        A decorator that wraps the target function.

    Example::

        @timeitlogit(rank=0, verbose=True)
        def train_step(batch): ...
    """
    _rank = rank if rank is not None else get_rank()
    _prefix = prefix or "timeit"

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            t0 = time.perf_counter()
            result = func(*args, **kwargs)
            dt = time.perf_counter() - t0
            fname = getattr(
                func, "__qualname__", getattr(func, "__name__", "unknown")
            )
            if record:
                try:
                    import wandb

                    if wandb.run is not None:
                        wandb.log({f"{_prefix}/{fname}": dt}, commit=False)
                except Exception:
                    pass
            if verbose and _rank == 0:
                arg_str = ", ".join(map(str, args))
                kw_str = ", ".join(f"{k}={v}" for k, v in kwargs.items())
                inner = ", ".join(filter(None, [arg_str, kw_str]))
                logger.info("%s(%s) took %.4f s", fname, inner, dt)
            return result

        return wrapper

    return decorator
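
The decorator-factory pattern used above can be tried without wandb. In this minimal sketch the names `timed` and `TIMINGS` are hypothetical: a plain dictionary stands in for `wandb.log`, and `__name__` is used for the metric key where the real function prefers `__qualname__`.

```python
import time
from functools import wraps

TIMINGS = {}  # hypothetical stand-in for wandb.log


def timed(prefix="timeit"):
    """Return a decorator that records wall-clock time under '<prefix>/<name>'."""

    def decorator(func):
        @wraps(func)  # keeps the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            t0 = time.perf_counter()
            result = func(*args, **kwargs)
            TIMINGS[f"{prefix}/{func.__name__}"] = time.perf_counter() - t0
            return result

        return wrapper

    return decorator


@timed(prefix="demo")
def add(a, b):
    return a + b


total = add(2, 3)
print(total, sorted(TIMINGS))
```

The outer call (`timed(prefix="demo")`) fixes the configuration once, and the returned decorator wraps the function, which is why the parentheses are required even when no arguments are passed.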

verify_wandb() ⚓︎

Return True if wandb is importable, enabled, and authenticated.

Source code in src/ezpz/distributed.py
def verify_wandb() -> bool:
    """Return ``True`` if wandb is importable, enabled, and authenticated."""
    rank = get_rank()
    try:
        import wandb
    except Exception:
        if rank == 0:
            logger.warning(
                "Unable to import wandb; install with `pip install wandb`"
            )
        return False
    if os.environ.get("WANDB_DISABLED"):
        return False
    wm = os.environ.get("WANDB_MODE", "").lower()
    if wm == "disabled":
        return False
    if (
        wandb.api.api_key is not None
        or os.environ.get("WANDB_API_KEY") is not None
    ):
        return True
    # Last resort: check ~/.netrc
    try:
        import netrc as _netrc

        netrc_path = Path(os.path.expanduser("~/.netrc"))
        if netrc_path.is_file():
            auth = _netrc.netrc(netrc_path).authenticators("api.wandb.ai")
            return bool(auth)
    except Exception:
        pass
    return False
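
Two of the checks above can be exercised in isolation: the environment gate and the ~/.netrc lookup. The sketch below uses only the standard library. `env_disables_wandb` and `netrc_has_host` are hypothetical helper names, and the credentials written to the temporary file are placeholders.

```python
import netrc
import os
import tempfile


def env_disables_wandb(env):
    """Any non-empty WANDB_DISABLED value, or WANDB_MODE=disabled, turns wandb off."""
    if env.get("WANDB_DISABLED"):
        return True
    return env.get("WANDB_MODE", "").lower() == "disabled"


def netrc_has_host(path, host="api.wandb.ai"):
    """Return True if the netrc file at `path` has an entry for `host`."""
    try:
        return bool(netrc.netrc(path).authenticators(host))
    except (OSError, netrc.NetrcParseError):
        return False


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "netrc")
    with open(path, "w") as fh:
        fh.write("machine api.wandb.ai login user password not-a-real-key\n")
    found = netrc_has_host(path)
    missing = netrc_has_host(path, host="example.com")

print(found, missing, env_disables_wandb({"WANDB_DISABLED": "false"}))
```

Because the gate tests only for a non-empty string, `WANDB_DISABLED=false` still disables wandb. Unset the variable to re-enable it.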

wrap_model(model, use_fsdp=True, dtype='bfloat16', device_id=None, device_mesh=None) ⚓︎

Wrap model with DDP or FSDP for distributed training.

Parameters:

Name Type Description Default
model Module

Model to wrap.

required
use_fsdp bool

Use FSDP when True, DDP when False.

True
dtype str

Mixed-precision parameter dtype for FSDP (e.g. "bf16").

'bfloat16'
device_id device | int | None

Explicit device ordinal for FSDP.

None
device_mesh Any

Optional torch.distributed.device_mesh.DeviceMesh.

None

Returns:

Type Description
Module

The wrapped model. If world_size <= 1 the original model is returned unchanged.
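
The branch order inside wrap_model() determines which wrapper is applied. The sketch below restates that decision tree as a pure function so it can be checked without PyTorch. `choose_wrapper` and its string return values are hypothetical labels, not part of ezpz.

```python
def choose_wrapper(world_size, use_fsdp=True, device_type="cuda", has_device_mesh=False):
    """Hypothetical restatement of the branching in wrap_model()."""
    if world_size <= 1:
        return "unwrapped"  # single process: the model is returned as-is
    if not use_fsdp:
        return "ddp"
    if device_type in ("cpu", "mps"):
        return "ddp"  # FSDP is not used on these devices; DDP is the fallback
    if has_device_mesh:
        return "fsdp2"  # a device mesh selects the fully_shard path
    return "fsdp"


cases = [
    choose_wrapper(1),
    choose_wrapper(8, use_fsdp=False),
    choose_wrapper(8, device_type="cpu"),
    choose_wrapper(8, has_device_mesh=True),
    choose_wrapper(8),
]
print(cases)
```

Two consequences follow from the order: passing `device_mesh` is what switches wrap_model() to FSDP2, and `device_id` only takes effect on the FSDP (v1) path.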

Source code in src/ezpz/distributed.py
def wrap_model(
    model: torch.nn.Module,
    use_fsdp: bool = True,
    dtype: str = "bfloat16",
    device_id: torch.device | int | None = None,
    device_mesh: Any = None,
) -> torch.nn.Module:
    """Wrap *model* with DDP or FSDP for distributed training.

    Args:
        model: Model to wrap.
        use_fsdp: Use FSDP when ``True``, DDP when ``False``.
        dtype: Mixed-precision parameter dtype for FSDP (e.g. ``"bf16"``).
        device_id: Explicit device ordinal for FSDP.
        device_mesh: Optional :class:`torch.distributed.device_mesh.DeviceMesh`.

    Returns:
        The wrapped model.  If ``world_size <= 1`` the original model is
        returned unchanged.
    """
    import torch

    ws = get_world_size()
    if ws <= 1:
        logger.warning(
            "%s requested but world_size=%d; returning unwrapped model.",
            "FSDP" if use_fsdp else "DDP",
            ws,
        )
        return model
    if get_rank() == 0:
        logger.info("Wrapping model with %s", "fsdp" if use_fsdp else "ddp")
    if use_fsdp:
        device_type = get_torch_device_type()
        if device_type in ("cpu", "mps"):
            logger.warning(
                "FSDP is not supported on %s devices; falling back to DDP.",
                device_type,
            )
            return wrap_model_for_ddp(model)
        if device_mesh is not None:
            return _wrap_fsdp2(
                model, dtype=dtype, device_mesh=device_mesh,
            )
        # Only build a device from an integer ordinal; a torch.device passes through.
        if isinstance(device_id, int):
            device_id = torch.device(device_type, device_id)
        return _wrap_fsdp(model, dtype=dtype, device_id=device_id)
    return wrap_model_for_ddp(model)

wrap_model_for_ddp(model) ⚓︎

Wrap model with torch.nn.parallel.DistributedDataParallel (DDP).

Parameters:

Name Type Description Default
model Module

Model to wrap (should already be on the correct device).

required

Returns:

Type Description
Module

A DDP-wrapped model.

Source code in src/ezpz/distributed.py
def wrap_model_for_ddp(model: torch.nn.Module) -> torch.nn.Module:
    """Wrap *model* with :class:`~torch.nn.parallel.DistributedDataParallel`.

    Args:
        model: Model to wrap (should already be on the correct device).

    Returns:
        A DDP-wrapped model.
    """
    from torch.nn.parallel import DistributedDataParallel as DDP

    device_type = get_torch_device_type()
    local_rank = get_local_rank()
    if device_type in {"cuda", "xpu"}:
        return DDP(model, device_ids=[local_rank])
    return DDP(model)

wrap_model_for_fsdp(model, dtype='bfloat16', device_id=None, **kwargs) ⚓︎

Wrap model with FSDP 1 and mixed precision.

Parameters:

Name Type Description Default
model Module

Model to wrap (should already be on the correct device).

required
dtype str

Mixed-precision parameter dtype (e.g. "bf16").

'bfloat16'
device_id int | None

Explicit device ordinal for FSDP.

None
**kwargs Any

Extra keyword arguments forwarded to the FSDP constructor.

{}

Returns:

Type Description
Module

An FSDP-wrapped model.

Source code in src/ezpz/distributed.py
def wrap_model_for_fsdp(
    model: torch.nn.Module,
    dtype: str = "bfloat16",
    device_id: int | None = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """Wrap *model* with FSDP 1 and mixed precision.

    Args:
        model: Model to wrap (should already be on the correct device).
        dtype: Mixed-precision parameter dtype (e.g. ``"bf16"``).
        device_id: Explicit device ordinal for FSDP.
        **kwargs: Extra keyword arguments forwarded to the FSDP constructor.

    Returns:
        An FSDP-wrapped model.
    """
    return _wrap_fsdp(model, dtype=dtype, device_id=device_id, **kwargs)

wrap_model_for_fsdp2(model, dtype='bfloat16', device_mesh=None, **kwargs) ⚓︎

Wrap model with FSDP2 (per-module fully_shard).

Note: Experimental. The FSDP2 API is subject to change in future PyTorch releases.

Parameters:

Name Type Description Default
model Module

Model to wrap.

required
dtype str

Mixed-precision parameter dtype (e.g. "bf16").

'bfloat16'
device_mesh Any

Optional torch.distributed.device_mesh.DeviceMesh.

None
**kwargs Any

Extra keyword arguments forwarded to fully_shard.

{}

Returns:

Type Description
Module

The model after applying fully_shard to every sub-module.

Source code in src/ezpz/distributed.py
def wrap_model_for_fsdp2(
    model: torch.nn.Module,
    dtype: str = "bfloat16",
    device_mesh: Any = None,
    **kwargs: Any,
) -> torch.nn.Module:
    """Wrap *model* with FSDP2 (per-module ``fully_shard``).

    .. note:: **Experimental** -- the FSDP2 API is subject to change in
       future PyTorch releases.

    Args:
        model: Model to wrap.
        dtype: Mixed-precision parameter dtype (e.g. ``"bf16"``).
        device_mesh: Optional :class:`torch.distributed.device_mesh.DeviceMesh`.
        **kwargs: Extra keyword arguments forwarded to ``fully_shard``.

    Returns:
        The model after applying ``fully_shard`` to every sub-module.
    """
    return _wrap_fsdp2(model, dtype=dtype, device_mesh=device_mesh, **kwargs)

write_hostfile_from_list_of_hosts(hosts, hostfile) ⚓︎

Write a hostfile from a list of hostnames.

Parameters:

Name Type Description Default
hosts Sequence[str]

Sequence of hostnames to write.

required
hostfile str | PathLike

Path to write to.

required

Source code in src/ezpz/distributed.py
def write_hostfile_from_list_of_hosts(
    hosts: Sequence[str],
    hostfile: str | os.PathLike,
) -> Path:
    """Write a hostfile from a list of hostnames.

    Args:
        hosts: Sequence of hostnames to write.
        hostfile: Path to write to.

    Returns:
        The path of the written hostfile.
    """
    hfp = Path(hostfile)
    hfp.parent.mkdir(parents=True, exist_ok=True)
    hfp.write_text("\n".join(hosts) + "\n")
    return hfp
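
A quick round trip shows the resulting file format: one hostname per line with a trailing newline, and missing parent directories created on demand. The sketch re-implements the three lines above under the hypothetical name `write_hostfile` so that it runs without ezpz installed. The hostnames are made up.

```python
import tempfile
from pathlib import Path


def write_hostfile(hosts, hostfile):
    """Stand-alone copy of the logic shown above, for illustration only."""
    hfp = Path(hostfile)
    hfp.parent.mkdir(parents=True, exist_ok=True)
    hfp.write_text("\n".join(hosts) + "\n")
    return hfp


with tempfile.TemporaryDirectory() as tmp:
    # The "run" subdirectory does not exist yet; it is created by the helper.
    written = write_hostfile(["node01", "node02"], Path(tmp) / "run" / "hostfile")
    content = written.read_text()
    read_back = content.splitlines()

print(repr(content), read_back)
```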

write_localhost_to_hostfile(hostfile) ⚓︎

Write the current hostname to hostfile (rank 0 only).

Parameters:

Name Type Description Default
hostfile str | PathLike

Path to write to.

required

Source code in src/ezpz/distributed.py
def write_localhost_to_hostfile(hostfile: str | os.PathLike) -> None:
    """Write the current hostname to *hostfile* (rank 0 only).

    Args:
        hostfile: Path to write to.
    """
    if get_rank() == 0:
        hn = get_hostname()
        Path(hostfile).write_text(hn)