Skip to content

ezpz.slurm⚓︎

ezpz/slurm.py

SLURM scheduler utilities: job discovery, nodefile generation, and launch command construction.

Prefers SLURM_JOB_ID / SLURM_NODELIST environment variables (always set inside a SLURM allocation) and falls back to sacct / scontrol shell commands only when the env vars are absent (e.g. on a login node).

build_launch_cmd(ngpus=None, nhosts=None, ngpu_per_host=None, hostfile=None, cpu_bind=None) ⚓︎

Build an srun command to launch a distributed job on SLURM.

Resolution order for node count
  1. Explicit nhosts argument.
  2. Inferred from ngpus / ngpu_per_host (ngpus // ngpu_per_host).
  3. Line count from hostfile.
  4. SLURM_NNODES env var.
  5. Active job discovery via sacct / scontrol (slow fallback).

Parameters⚓︎

ngpus : int, optional
    Total number of tasks (GPUs). If None, computed as nhosts * ngpu_per_host.
nhosts : int, optional
    Number of nodes. If None, inferred from hostfile, env vars, or the active job.
ngpu_per_host : int, optional
    GPUs per node. If None, detected via ezpz.get_gpus_per_node().
hostfile : path-like, optional
    Path to a hostfile (one hostname per line).
cpu_bind : str, optional
    CPU binding policy (e.g. "verbose,list:0-7"). Passed as --cpu-bind=<value> to srun when provided.

Source code in src/ezpz/slurm.py
def build_launch_cmd(
    ngpus: Optional[int] = None,
    nhosts: Optional[int] = None,
    ngpu_per_host: Optional[int] = None,
    hostfile: Optional[Union[str, Path, os.PathLike]] = None,
    cpu_bind: Optional[str] = None,
) -> str:
    """Build an ``srun`` command to launch a distributed job on SLURM.

    Resolution order for node count:
      1. Explicit *nhosts* argument.
      2. Inferred from *ngpus* / *ngpu_per_host* (``ngpus // ngpu_per_host``).
      3. Line count from *hostfile*.
      4. ``SLURM_NNODES`` env var.
      5. Active job discovery via ``sacct`` / ``scontrol`` (slow fallback).

    Parameters
    ----------
    ngpus : int, optional
        Total number of tasks (GPUs). If ``None``, computed as
        ``nhosts * ngpu_per_host``.
    nhosts : int, optional
        Number of nodes. If ``None``, inferred from *hostfile*, env vars,
        or the active job.
    ngpu_per_host : int, optional
        GPUs per node. If ``None``, detected via
        ``ezpz.get_gpus_per_node()``.
    hostfile : path-like, optional
        Path to a hostfile (one hostname per line).
    cpu_bind : str, optional
        CPU binding policy (e.g. ``"verbose,list:0-7"``).  Passed as
        ``--cpu-bind=<value>`` to ``srun`` when provided.

    Raises
    ------
    ValueError
        If *ngpus* is not divisible by *ngpu_per_host*, if no running
        SLURM job (or nodelist) can be discovered, if the task count
        cannot be determined, or if it resolves to a non-positive number.
    """
    if ngpu_per_host is None:
        # NOTE(review): may still be None if detection fails on this
        # machine; guarded explicitly below instead of raising TypeError.
        ngpu_per_host = ezpz.get_gpus_per_node()

    if nhosts is not None:
        num_nodes = nhosts
    elif ngpus is not None and ngpu_per_host is not None and ngpu_per_host > 0:
        # Infer node count from total GPUs and per-node count, matching the
        # PBS launcher logic so `-n 2 -ppn 2` means 1 node, not max.
        if ngpus % ngpu_per_host != 0:
            raise ValueError(
                f"`ngpus` must be divisible by `ngpu_per_host`: "
                f"ngpus={ngpus}, ngpu_per_host={ngpu_per_host}"
            )
        num_nodes = ngpus // ngpu_per_host
    elif hostfile is not None and Path(hostfile).is_file():
        # One hostname per line; ignore blank lines.
        with open(hostfile) as f:
            num_nodes = len([ln for ln in f if ln.strip()])
    else:
        # Try SLURM_NNODES env var before expensive sacct/scontrol calls.
        slurm_nnodes = os.environ.get("SLURM_NNODES")
        if slurm_nnodes is not None:
            num_nodes = int(slurm_nnodes)
        else:
            running_jobid = get_slurm_jobid_of_active_job()
            if running_jobid is None:
                raise ValueError(
                    "No running SLURM job found for current user."
                )
            nodelist = get_nodelist_from_slurm_jobid(running_jobid)
            if not nodelist:
                raise ValueError(
                    f"No nodelist found for jobid {running_jobid}"
                )
            num_nodes = len(nodelist)

    if ngpus is not None:
        total_gpus = ngpus
    elif ngpu_per_host is not None:
        total_gpus = num_nodes * ngpu_per_host
    else:
        # GPU detection returned None and no explicit count was given;
        # previously this path raised an opaque TypeError.
        raise ValueError(
            "Could not determine task count: `ngpus` not given and GPU "
            "detection returned None. Pass `ngpus` explicitly."
        )

    if total_gpus <= 0:
        raise ValueError(
            f"Total tasks must be positive, got {total_gpus} "
            f"(num_nodes={num_nodes}, ngpu_per_host={ngpu_per_host}). "
            f"On CPU-only machines, pass ngpus explicitly."
        )

    cmd = f"srun -u --verbose -N{num_nodes} -n{total_gpus}"
    # Only request GPUs when a positive per-node count is known.
    if ngpu_per_host is not None and ngpu_per_host > 0:
        cmd += f" --gpus-per-node={ngpu_per_host}"
    if cpu_bind is not None:
        cmd += f" --cpu-bind={cpu_bind}"
    return cmd

get_nodelist_from_slurm_jobid(jobid) ⚓︎

Get the expanded nodelist for jobid.

Checks SLURM_NODELIST first (instant, no subprocess). Falls back to scontrol show job <jobid> when the env var is absent.

Source code in src/ezpz/slurm.py
def get_nodelist_from_slurm_jobid(jobid: str | int) -> list[str]:
    """Return the expanded list of hostnames allocated to *jobid*.

    The ``SLURM_NODELIST`` env var is consulted first (set inside every
    allocation; no subprocess needed).  When absent, the list is parsed
    out of ``scontrol show job <jobid>``.
    """
    # Fast path: inside an allocation the env var is already populated.
    env_nodelist = os.environ.get("SLURM_NODELIST")
    if env_nodelist:
        return _expand_slurm_nodelist(env_nodelist)

    # Slow path: shell out to scontrol for an arbitrary job ID.
    try:
        from sh import scontrol  # type:ignore
    except Exception as e:
        logger.error("Error importing sh.scontrol: %s", e)
        raise e

    try:
        lines = scontrol("show", "job", str(jobid)).split("\n")
        # scontrol can return multiple NodeList= lines; skip "(null)" entries
        # which appear for batch job wrappers before the real allocation.
        found: str | None = None
        for line in lines:
            hit = re.search(r"NodeList=([^\s]+)", line)
            if hit is None:
                continue
            candidate = hit.group(1)
            if candidate != "(null)":
                found = candidate
                break
        if not found:
            raise ValueError(
                f"NodeList not found (or all entries are (null)) "
                f"in scontrol output for job {jobid}"
            )
        return _expand_slurm_nodelist(found)
    except Exception as e:
        logger.error(f"Error getting nodelist for job {jobid}: {e}")
        raise e

get_slurm_jobid_of_active_job() ⚓︎

Get the job ID of the currently active SLURM job.

Checks SLURM_JOB_ID / SLURM_JOBID env vars first (instant). Falls back to sacct + hostname matching only when the env vars are absent.

Source code in src/ezpz/slurm.py
def get_slurm_jobid_of_active_job() -> str | None:
    """Get the job ID of the currently active SLURM job.

    Checks ``SLURM_JOB_ID`` / ``SLURM_JOBID`` env vars first (instant).
    Falls back to ``sacct`` + hostname matching only when the env vars
    are absent.
    """
    # Fast path: env var is always set inside a SLURM allocation.
    jobid = os.environ.get("SLURM_JOB_ID") or os.environ.get("SLURM_JOBID")
    if jobid:
        return str(jobid)

    # Slow path: query sacct and match hostname.
    import socket

    hostname = socket.getfqdn().split("-")[0]
    running_jobs = get_slurm_running_jobs_for_user()
    for jid, nodelist in running_jobs.items():
        logger.info(f"Checking jobid {jid} for hostname {hostname}...")
        if hostname in nodelist:
            logger.info(f"Found {hostname} in nodelist for {jid}")
            return str(jid)
    return None

get_slurm_nodefile_from_jobid(jobid) ⚓︎

Write a nodefile for jobid and return its path.

The file is written to ./nodefile-<jobid> in the current working directory with one hostname per line.

Source code in src/ezpz/slurm.py
def get_slurm_nodefile_from_jobid(jobid: int | str) -> str:
    """Write a nodefile for *jobid* and return its absolute path.

    The file is written to ``./nodefile-<jobid>`` in the current
    working directory with one hostname per line.

    Raises
    ------
    ValueError
        If *jobid* is ``None``.  (Previously an ``assert``, which is
        silently stripped under ``python -O``.)
    """
    if jobid is None:
        raise ValueError("Job ID must be provided.")
    nodelist = get_nodelist_from_slurm_jobid(jobid)
    nodefile = Path.cwd().joinpath(f"nodefile-{jobid}")
    logger.info(f"Writing {nodelist} to {nodefile}")
    with nodefile.open("w") as f:
        # One hostname per line; batch the writes.
        f.writelines(f"{hn}\n" for hn in nodelist)

    # resolve() already yields an absolute path.
    return nodefile.resolve().as_posix()

get_slurm_nodefile_of_active_job() ⚓︎

Get the nodefile of the currently active job.

Source code in src/ezpz/slurm.py
def get_slurm_nodefile_of_active_job() -> str | None:
    """Write and return the nodefile path for the active job, or ``None``
    when no active SLURM job can be found."""
    active_jobid = get_slurm_jobid_of_active_job()
    if active_jobid is None:
        return None
    return get_slurm_nodefile_from_jobid(active_jobid)

get_slurm_running_jobs() ⚓︎

Get the running jobs from sacct.

Returns a deduplicated list of base job IDs (strings) that are currently in RUNNING state, or None if sacct is unavailable.

Source code in src/ezpz/slurm.py
def get_slurm_running_jobs() -> list[str] | None:
    """Get the running jobs from ``sacct``.

    Returns a deduplicated list of base job IDs (strings) that are
    currently in ``RUNNING`` state, or ``None`` if ``sacct`` is
    unavailable.
    """
    try:
        from sh import sacct  # type:ignore
    except (ImportError, ModuleNotFoundError):
        logger.warning("sacct unavailable (sh package not installed)")
        return None

    try:
        running_lines = (
            line for line in sacct().split("\n") if " RUNNING " in line
        )
        # First whitespace-or-dot-delimited token is the base job ID
        # (strips ".batch"/".step" suffixes and trailing columns).
        unique_ids = {
            line.replace(".", " ").split(" ")[0] for line in running_lines
        }
        return list(unique_ids)
    except Exception as e:
        logger.error("Error getting running jobs from sacct: %s", e)
        return None

get_slurm_running_jobs_for_user() ⚓︎

Get all running jobs for the current user.

Returns a dict mapping job-ID strings to their expanded nodelists.

Source code in src/ezpz/slurm.py
def get_slurm_running_jobs_for_user() -> dict[str, list[str]]:
    """Get all running jobs for the current user.

    Returns a dict mapping job-ID strings to their expanded nodelists
    (empty when no running jobs are found or ``sacct`` is unavailable).
    """
    running = get_slurm_running_jobs()
    if running is None:
        return {}
    return {jid: get_nodelist_from_slurm_jobid(jid) for jid in running}