ezpz.slurm⚓︎
- See ezpz/
slurm.py
ezpz/slurm.py
SLURM scheduler utilities: job discovery, nodefile generation, and launch command construction.
Prefers SLURM_JOB_ID / SLURM_NODELIST environment variables (always
set inside a SLURM allocation) and falls back to sacct / scontrol
shell commands only when the env vars are absent (e.g. on a login node).
build_launch_cmd(ngpus=None, nhosts=None, ngpu_per_host=None, hostfile=None, cpu_bind=None)
⚓︎
Build an srun command to launch a distributed job on SLURM.
Resolution order for node count:
- Explicit nhosts argument.
- Inferred from ngpus // ngpu_per_host when ngpus is given.
- Line count from hostfile.
- SLURM_NNODES env var.
- Active job discovery via sacct/scontrol (slow fallback).
Parameters⚓︎
ngpus : int, optional
Total number of tasks (GPUs). If None, computed as
nhosts * ngpu_per_host.
nhosts : int, optional
Number of nodes. If None, inferred from hostfile, env vars,
or the active job.
ngpu_per_host : int, optional
GPUs per node. If None, detected via
ezpz.get_gpus_per_node().
hostfile : path-like, optional
Path to a hostfile (one hostname per line).
cpu_bind : str, optional
CPU binding policy (e.g. "verbose,list:0-7"). Passed as
--cpu-bind=<value> to srun when provided.
Source code in src/ezpz/slurm.py
def build_launch_cmd(
    ngpus: Optional[int] = None,
    nhosts: Optional[int] = None,
    ngpu_per_host: Optional[int] = None,
    hostfile: Optional[Union[str, Path, os.PathLike]] = None,
    cpu_bind: Optional[str] = None,
) -> str:
    """Build an ``srun`` command to launch a distributed job on SLURM.

    Resolution order for node count:

    1. Explicit *nhosts* argument.
    2. ``ngpus // ngpu_per_host`` when *ngpus* is given (matches the PBS
       launcher, so ``-n 2 -ppn 2`` means 1 node).
    3. Line count from *hostfile*.
    4. ``SLURM_NNODES`` env var.
    5. Active job discovery via ``sacct`` / ``scontrol`` (slow fallback).

    Parameters
    ----------
    ngpus : int, optional
        Total number of tasks (GPUs). If ``None``, computed as
        ``nhosts * ngpu_per_host``.
    nhosts : int, optional
        Number of nodes. If ``None``, inferred from *ngpus*, *hostfile*,
        env vars, or the active job.
    ngpu_per_host : int, optional
        GPUs per node. If ``None``, detected via
        ``ezpz.get_gpus_per_node()``.
    hostfile : path-like, optional
        Path to a hostfile (one hostname per line).
    cpu_bind : str, optional
        CPU binding policy (e.g. ``"verbose,list:0-7"``). Passed as
        ``--cpu-bind=<value>`` to ``srun`` when provided.

    Raises
    ------
    ValueError
        If *ngpus* is not divisible by *ngpu_per_host*, if no active job
        can be discovered, or if the resolved task count is not positive.
    """
    if ngpu_per_host is None:
        ngpu_per_host = ezpz.get_gpus_per_node()
    if nhosts is not None:
        num_nodes = nhosts
    elif ngpus is not None and ngpu_per_host is not None and ngpu_per_host > 0:
        # Infer node count from total GPUs and per-node count, matching the
        # PBS launcher logic so `-n 2 -ppn 2` means 1 node, not max.
        if ngpus % ngpu_per_host != 0:
            raise ValueError(
                f"`ngpus` must be divisible by `ngpu_per_host`: "
                f"ngpus={ngpus}, ngpu_per_host={ngpu_per_host}"
            )
        num_nodes = ngpus // ngpu_per_host
    elif hostfile is not None and Path(hostfile).is_file():
        with open(hostfile) as f:
            num_nodes = len([ln for ln in f if ln.strip()])
    else:
        # Try SLURM_NNODES env var before expensive sacct/scontrol calls.
        slurm_nnodes = os.environ.get("SLURM_NNODES")
        if slurm_nnodes is not None:
            num_nodes = int(slurm_nnodes)
        else:
            running_jobid = get_slurm_jobid_of_active_job()
            if running_jobid is None:
                raise ValueError(
                    "No running SLURM job found for current user."
                )
            nodelist = get_nodelist_from_slurm_jobid(running_jobid)
            if not nodelist:
                raise ValueError(
                    f"No nodelist found for jobid {running_jobid}"
                )
            num_nodes = len(nodelist)
    total_gpus = ngpus if ngpus is not None else num_nodes * ngpu_per_host
    if total_gpus <= 0:
        raise ValueError(
            f"Total tasks must be positive, got {total_gpus} "
            f"(num_nodes={num_nodes}, ngpu_per_host={ngpu_per_host}). "
            f"On CPU-only machines, pass ngpus explicitly."
        )
    cmd = f"srun -u --verbose -N{num_nodes} -n{total_gpus}"
    if ngpu_per_host > 0:
        cmd += f" --gpus-per-node={ngpu_per_host}"
    if cpu_bind is not None:
        cmd += f" --cpu-bind={cpu_bind}"
    return cmd
get_nodelist_from_slurm_jobid(jobid)
⚓︎
Get the expanded nodelist for jobid.
Checks SLURM_NODELIST first (instant, no subprocess). Falls back
to scontrol show job <jobid> when the env var is absent.
Source code in src/ezpz/slurm.py
def get_nodelist_from_slurm_jobid(jobid: str | int) -> list[str]:
    """Get the expanded nodelist for *jobid*.

    Checks ``SLURM_NODELIST`` first (instant, no subprocess). Falls back
    to ``scontrol show job <jobid>`` when the env var is absent.

    Raises
    ------
    ValueError
        If no usable ``NodeList=`` entry appears in ``scontrol`` output.
    """
    # Fast path: env var is set inside every SLURM allocation.
    nodelist_str = os.environ.get("SLURM_NODELIST")
    if nodelist_str:
        return _expand_slurm_nodelist(nodelist_str)
    # Slow path: query scontrol for an arbitrary job ID.
    try:
        from sh import scontrol  # type:ignore
    except Exception as e:
        logger.error("Error importing sh.scontrol: %s", e)
        # Bare `raise` preserves the original traceback (vs. `raise e`).
        raise
    try:
        output = scontrol("show", "job", str(jobid)).split("\n")
        # scontrol can return multiple NodeList= lines; skip "(null)" entries
        # which appear for batch job wrappers before the real allocation.
        best_match: str | None = None
        for line in output:
            m = re.search(r"NodeList=([^\s]+)", line)
            if m and m.group(1) != "(null)":
                best_match = m.group(1)
                break
        if not best_match:
            raise ValueError(
                f"NodeList not found (or all entries are (null)) "
                f"in scontrol output for job {jobid}"
            )
        return _expand_slurm_nodelist(best_match)
    except Exception as e:
        # Lazy %-style logging, consistent with the import-failure branch.
        logger.error("Error getting nodelist for job %s: %s", jobid, e)
        raise
get_slurm_jobid_of_active_job()
⚓︎
Get the job ID of the currently active SLURM job.
Checks SLURM_JOB_ID / SLURM_JOBID env vars first (instant).
Falls back to sacct + hostname matching only when the env vars
are absent.
Source code in src/ezpz/slurm.py
def get_slurm_jobid_of_active_job() -> str | None:
"""Get the job ID of the currently active SLURM job.
Checks ``SLURM_JOB_ID`` / ``SLURM_JOBID`` env vars first (instant).
Falls back to ``sacct`` + hostname matching only when the env vars
are absent.
"""
# Fast path: env var is always set inside a SLURM allocation.
jobid = os.environ.get("SLURM_JOB_ID") or os.environ.get("SLURM_JOBID")
if jobid:
return str(jobid)
# Slow path: query sacct and match hostname.
import socket
hostname = socket.getfqdn().split("-")[0]
running_jobs = get_slurm_running_jobs_for_user()
for jid, nodelist in running_jobs.items():
logger.info(f"Checking jobid {jid} for hostname {hostname}...")
if hostname in nodelist:
logger.info(f"Found {hostname} in nodelist for {jid}")
return str(jid)
return None
get_slurm_nodefile_from_jobid(jobid)
⚓︎
Write a nodefile for jobid and return its path.
The file is written to ./nodefile-<jobid> in the current
working directory with one hostname per line.
Source code in src/ezpz/slurm.py
def get_slurm_nodefile_from_jobid(jobid: int | str) -> str:
    """Write a nodefile for *jobid* and return its path.

    The file is written to ``./nodefile-<jobid>`` in the current
    working directory with one hostname per line.

    Raises
    ------
    ValueError
        If *jobid* is None.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently let a None jobid through.
    if jobid is None:
        raise ValueError("Job ID must be provided.")
    nodelist = get_nodelist_from_slurm_jobid(jobid)
    nodefile = Path(os.getcwd()).joinpath(f"nodefile-{jobid}")
    logger.info(f"Writing {nodelist} to {nodefile}")
    with nodefile.open("w") as f:
        for hn in nodelist:
            f.write(f"{hn}\n")
    # `resolve()` already yields an absolute path; `.absolute()` was redundant.
    return nodefile.resolve().as_posix()
get_slurm_nodefile_of_active_job()
⚓︎
Get the nodefile of the currently active job.
get_slurm_running_jobs()
⚓︎
Get the running jobs from sacct.
Returns a deduplicated list of base job IDs (strings) that are
currently in RUNNING state, or None if sacct is
unavailable.
Source code in src/ezpz/slurm.py
def get_slurm_running_jobs() -> list[str] | None:
    """Get the running jobs from ``sacct``.

    Returns a deduplicated list of base job IDs (strings) that are
    currently in ``RUNNING`` state, or ``None`` if ``sacct`` is
    unavailable.
    """
    try:
        from sh import sacct  # type:ignore
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        logger.warning("sacct unavailable (sh package not installed)")
        return None
    try:
        # Each RUNNING row starts with a job-step ID like "12345.batch";
        # keep only the base job ID and deduplicate via the set.
        return list(
            {
                line.replace(".", " ").split(" ")[0]
                for line in sacct().split("\n")
                if " RUNNING " in line
            }
        )
    except Exception as e:
        logger.error("Error getting running jobs from sacct: %s", e)
        return None
get_slurm_running_jobs_for_user()
⚓︎
Get all running jobs for the current user.
Returns a dict mapping job-ID strings to their expanded nodelists.
Source code in src/ezpz/slurm.py
def get_slurm_running_jobs_for_user() -> dict[str, list[str]]:
    """Get all running jobs for the current user.

    Returns a dict mapping job-ID strings to their expanded nodelists.
    """
    active = get_slurm_running_jobs()
    # `sacct` unavailable -> no discoverable jobs.
    if active is None:
        return {}
    return {jid: get_nodelist_from_slurm_jobid(jid) for jid in active}