ezpz.utils.yeet_env
Distribute a Python environment (or any directory / .tar.gz archive) to worker nodes via parallel rsync.
Default behavior: auto-detect the active Python environment and rsync
it to /tmp/<env-name>/ on every node in the current job allocation.
Usage::
# Rsync the active venv to all worker nodes:
ezpz yeet-env
# Rsync a specific path:
ezpz yeet-env --src /path/to/env
# Custom destination:
ezpz yeet-env --dst /local/fast/storage/myenv
# Preview without syncing:
ezpz yeet-env --dry-run
main(argv=None)
CLI entry point.
parse_args(argv=None)
Parse ezpz yeet command-line arguments.
Source code in src/ezpz/utils/yeet_env.py
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse ezpz yeet command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``args.src`` holds the source path given via
        ``--src`` or the positional ``SRC`` (``None`` when neither was
        given, meaning "auto-detect the active environment").

    Exits via ``parser.error`` (status 2) on unrecognized arguments or when
    both ``--src`` and a positional ``SRC`` are supplied.
    """
    parser = argparse.ArgumentParser(
        prog="ezpz yeet",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=(
            "Distribute files (envs, models, datasets, etc.) to worker nodes "
            "via parallel rsync. By default (no args), rsyncs the active "
            "venv/conda env to /tmp/<env-name>/ on all nodes in the current "
            "job allocation. Pass any path to yeet arbitrary content."
        ),
    )
    parser.add_argument(
        "src_positional",
        nargs="?",
        default=None,
        metavar="SRC",
        help=(
            "Source path (positional shorthand for --src). Mutually "
            "exclusive with --src. May be a directory OR a .tar.gz/.tgz "
            "file."
        ),
    )
    parser.add_argument(
        "--src",
        type=str,
        default=None,
        help=(
            "Source path (defaults to the active venv/conda env). May "
            "be a directory OR a .tar.gz/.tgz file — in the latter "
            "case the tarball is copied to /tmp/ and extracted there, "
            "skipping the create step that --compress does."
        ),
    )
    parser.add_argument(
        "--dst",
        type=str,
        default=None,
        help="Destination path on worker nodes (defaults to /tmp/<env-name>/).",
    )
    parser.add_argument(
        "--hostfile",
        type=str,
        default=None,
        help="Hostfile to read node list from (auto-detected from scheduler when omitted).",
    )
    parser.add_argument(
        "--copy",
        action="store_true",
        help=(
            "Use 'cp -a' instead of rsync for the local copy "
            "(Lustre → /tmp/). Faster for initial copies of large "
            "environments with many small files. Remote node "
            "distribution still uses rsync."
        ),
    )
    parser.add_argument(
        "--compress",
        action="store_true",
        help=(
            "Create a .tar.gz archive, copy it to /tmp/, then extract. "
            "Reduces Lustre I/O from millions of small-file reads to "
            "one sequential read. Remote distribution still uses rsync."
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be synced without doing it.",
    )
    args, unknown = parser.parse_known_args(argv)

    # Tolerate a stray "yeet-env"/"yeet" token leaked from old entry
    # points. Because SRC is an optional positional, such a token is
    # captured as SRC (not reported in ``unknown``) whenever it is the
    # first bare argument, so it has to be dropped from there as well —
    # unless a path with that name really exists, in which case the user
    # meant it as the source.
    stray_tokens = ("yeet", "yeet-env")
    if (
        args.src_positional in stray_tokens
        and not Path(args.src_positional).exists()
    ):
        args.src_positional = None
    leftover = [a for a in unknown if a not in stray_tokens]
    # With the SRC slot empty, a single leftover bare token is the source
    # path that the stray token (or an option preceding it) displaced.
    if (
        args.src_positional is None
        and len(leftover) == 1
        and not leftover[0].startswith("-")
    ):
        args.src_positional = leftover.pop()
    if leftover:
        parser.error(
            f"unrecognized arguments: {' '.join(leftover)}"
        )
    # Mutex: positional SRC and --src can't both be set.
    if args.src_positional is not None and args.src is not None:
        parser.error(
            "--src and positional SRC are mutually exclusive; "
            "pick one"
        )
    if args.src is None:
        args.src = args.src_positional
    return args
pick_source(source_active, max_per_source, *, rng=None)
Pick a least-loaded source under the per-source cap.
Ties are broken with the supplied rng (defaults to a fresh
random.Random() per call) so the greedy fan-out actually fans
out — without randomization, dict.items() order would always
pick the same source first, pinning all early traffic to one node
and defeating the tree distribution.
Returns None if every source is at capacity (the caller should
wait for an in-flight sync to finish, freeing a slot).
Pulled out of run() so tests can exercise the algorithm directly instead of reconstructing it inline.
Source code in src/ezpz/utils/yeet_env.py
def pick_source(
source_active: dict[str, int],
max_per_source: int,
*,
rng: random.Random | None = None,
) -> str | None:
"""Pick a least-loaded source under the per-source cap.
Ties are broken with the supplied ``rng`` (defaults to a fresh
``random.Random()`` per call) so the greedy fan-out actually fans
out — without randomization, ``dict.items()`` order would always
pick the same source first, pinning all early traffic to one node
and defeating the tree distribution.
Returns ``None`` if every source is at capacity (the caller should
wait for an in-flight sync to finish, freeing a slot).
Pulled out of run() so tests can exercise the algorithm directly
instead of reconstructing it inline.
"""
candidates = [
s for s, count in source_active.items() if count < max_per_source
]
if not candidates:
return None
min_count = min(source_active[s] for s in candidates)
least_loaded = [s for s in candidates if source_active[s] == min_count]
if rng is None:
rng = random.Random()
return rng.choice(least_loaded)
run(argv=None)
Main entry point for yeet-env.
Source code in src/ezpz/utils/yeet_env.py
def run(argv: Optional[Sequence[str]] = None) -> int:
    """Main entry point for yeet-env.

    Resolves the source (``--src`` / positional ``SRC``, or the active
    environment), stages it into ``dst`` on the current node (via rsync,
    ``cp -a``, a tar.gz round-trip, or extraction of a pre-built tarball),
    then distributes the staged copy to every remote node. Distribution is
    a greedy tree of rsyncs in which each finished node becomes an
    additional source.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` on success, after ``--dry-run``, or when there is nothing to
        sync. ``1`` when the source path is missing, no worker nodes are
        found, the local staging step fails, or any node fails to sync.
    """
    args = parse_args(argv)
    # ── Resolve source ──────────────────────────────────────────────
    if args.src is not None:
        src = Path(args.src).resolve()
        if not src.exists():
            logger.error("Source path does not exist: %s", src)
            return 1
    else:
        src = _detect_env_source()
        # Convenience hint: if a fresher same-named tarball exists
        # next to the env or in cwd, point it out — tarball broadcast
        # is ~10× faster at scale than per-file rsync. Don't auto-pick
        # it; explicit is safer (the tarball might be stale).
        _suggest_tarball_if_present(src)
    # If the source is a .tar.gz / .tgz file, treat it as a pre-built
    # archive: skip the "tar create" step and just copy + extract.
    src_is_tarball = src.is_file() and src.name.endswith((".tar.gz", ".tgz"))
    env_name = src.name
    if src_is_tarball:
        # Strip the archive suffix to derive the destination name.
        for suffix in (".tar.gz", ".tgz"):
            if env_name.endswith(suffix):
                env_name = env_name[: -len(suffix)]
                break
    # ── Resolve destination ─────────────────────────────────────────
    dst = Path(args.dst) if args.dst is not None else Path("/tmp") / env_name
    # ── Discover nodes ──────────────────────────────────────────────
    nodes = _get_worker_nodes(hostfile=args.hostfile)
    current = _get_current_hostname()
    if not nodes:
        logger.error("No worker nodes found.")
        return 1
    # /tmp is node-local, and every remote rsync below reads from ``dst``
    # on the current node, so the source must be staged into ``dst``
    # first unless it already *is* ``dst``. A tarball always has to be
    # extracted. A bare ``startswith("/tmp")`` test is not sufficient: it
    # also matches e.g. ``/tmpfs/...``, and a source elsewhere under /tmp
    # (or a --dst that differs from it) would leave ``dst`` missing.
    needs_local_copy = src_is_tarball or src != dst.resolve()
    # Filter out the current node — also handle the HSN variant
    # (current node may appear as "node01" while nodes contain "node01-hsn0").
    current_variants = {current, current + "-hsn0", current.removesuffix("-hsn0")}
    remote_nodes = [n for n in nodes if n not in current_variants]
    # ── Print summary ───────────────────────────────────────────────
    env_size = _get_env_size(src)
    total_nodes = (1 if needs_local_copy else 0) + len(remote_nodes)
    logger.info("Source: %s (%s)", src, env_size)
    logger.info("Target: %s/ on %d node(s)", dst, total_nodes)
    if needs_local_copy:
        logger.info(" local: %s (rsync to %s/)", current, dst)
    if remote_nodes:
        if len(remote_nodes) <= 6:
            logger.info(" remote: %s", ", ".join(remote_nodes))
        else:
            shown = ', '.join(remote_nodes[:3])
            logger.info(" remote: %s, ... (%d nodes)", shown, len(remote_nodes))
    if args.dry_run:
        logger.info("[dry-run] No files transferred.")
        return 0
    if total_nodes == 0:
        logger.info("Nothing to sync (source is already in %s).", dst)
        return 0
    # ── Sync ────────────────────────────────────────────────────────
    #
    # Greedy tree distribution: instead of all N nodes pulling from the
    # source (which saturates the source node's NIC), each completed
    # node immediately becomes a source for others. A single thread
    # pool runs for the entire sync — as soon as any rsync finishes,
    # new rsyncs are submitted using the newly-available source.
    #
    # Each source has a concurrency cap (MAX_PER_SOURCE) so no single
    # node is overwhelmed. The tree grows organically: the first node
    # seeds a few targets, each of those fans out to more, etc.
    MAX_PER_SOURCE = 8  # max concurrent outbound rsyncs per source node
    all_nodes: list[str] = []
    if needs_local_copy:
        all_nodes.append(current)
    all_nodes.extend(remote_nodes)
    total = len(all_nodes)
    progress = _AggregateProgress(total_nodes=total)
    results: list[tuple[str, float, int]] = []
    logger.info("Syncing (%d nodes)...", total)
    t0 = time.perf_counter()
    # Step 1: copy source to local /tmp/ and patch paths ONCE.
    # All subsequent rsyncs distribute the already-patched copy.
    if needs_local_copy:
        _local_t0 = time.perf_counter()

        def _spinner(label: str) -> None:
            """Reusable spinner that shows label + elapsed time."""
            if not _IS_TTY:
                return
            elapsed = time.perf_counter() - _local_t0
            frames = _AggregateProgress._FRAMES
            idx = int(elapsed * 2) % len(frames)
            sys.stdout.write(f"\r\033[K {frames[idx]} {label} [{elapsed:.0f}s]")
            sys.stdout.flush()

        def _run_with_spinner(cmd: list[str], label: str) -> tuple[int, str]:
            """Run *cmd* to completion while animating the spinner.

            Returns ``(returncode, decoded stderr)``. ``communicate`` with
            a timeout keeps draining stderr while waiting; waiting with
            ``poll()`` and reading stderr only after exit can deadlock once
            the child fills the pipe buffer (e.g. tar warning per file).
            """
            _spinner(label)
            proc = subprocess.Popen(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
            while True:
                try:
                    _, err = proc.communicate(timeout=0.5)
                except subprocess.TimeoutExpired:
                    # Retrying communicate() does not lose buffered output.
                    _spinner(label)
                else:
                    break
            return proc.returncode, (err or b"").decode(errors="replace")

        if src_is_tarball:
            # Source is already a .tar.gz/.tgz: copy it to /tmp/ and
            # extract there. Skips the "tar create" step that --compress
            # does. Useful when you already have a pre-built tarball
            # (e.g. from `ezpz tar-env`) on a shared filesystem.
            method = "tar.gz (pre-built)"
            local_tarball = Path("/tmp") / src.name
            # A tarball that already sits in /tmp is extracted in place:
            # copying it onto itself would fail, and the failure cleanup
            # would then delete the caller's only copy.
            copy_tarball = local_tarball.resolve() != src
            print()
            try:
                tb_size_gb = src.stat().st_size / (1024**3)
            except OSError:
                tb_size_gb = 0.0
            # Step 1: copy tarball Lustre → /tmp/
            if copy_tarball:
                cp_rc, cp_err = _run_with_spinner(
                    ["cp", str(src), str(local_tarball)],
                    f"copying {src.name} ({tb_size_gb:.1f}G) → /tmp/",
                )
            else:
                cp_rc, cp_err = 0, ""
            if cp_rc != 0:
                logger.warning("cp tarball failed: %s", cp_err.strip())
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = cp_rc
                # Partial copy may have left a truncated tarball behind
                _cleanup_path(local_tarball)
            elif dst.exists() and not _safe_rmtree(dst):
                # Existing dst could not be removed, so extraction never
                # starts; still drop our local tarball copy.
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = 1
                if copy_tarball:
                    _cleanup_path(local_tarball)
            else:
                # Step 2: extract into dst
                dst.mkdir(parents=True, exist_ok=True)
                local_rc, tar_err = _run_with_spinner(
                    ["tar", "-xzf", str(local_tarball),
                     "--strip-components=1", "-C", str(dst)],
                    f"extracting {local_tarball.name} → {dst}/",
                )
                local_elapsed = time.perf_counter() - _local_t0
                if local_rc != 0:
                    logger.warning("tar extract failed: %s", tar_err.strip())
                    # Half-extracted dst is unusable — drop it.
                    # We just wrote it ourselves, so skip the
                    # /tmp-only safety guard.
                    _remove_partial_dst(dst)
                # Clean up our local tarball copy whether extraction
                # succeeded or failed (never the caller's own file).
                if copy_tarball:
                    _cleanup_path(local_tarball)
            if _IS_TTY:
                sys.stdout.write("\r\033[K")
        elif args.compress:
            # tar.gz: compress on Lustre (sequential write), copy one
            # file to /tmp/ (sequential read), extract locally.
            # Much less Lustre metadata pressure than per-file rsync/cp.
            method = "tar.gz"
            tarball = Path(f"/tmp/{env_name}.tar.gz")
            print()
            # Step 1: create archive from source on Lustre
            create_rc, create_err = _run_with_spinner(
                [
                    "tar", "-czf", str(tarball),
                    "--exclude=__pycache__",
                    "-C", str(src.parent), src.name,
                ],
                f"tar -czf {tarball.name} (compressing)",
            )
            if create_rc != 0:
                logger.warning("tar create failed: %s", create_err.strip())
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = create_rc
                # Partial archive: don't ship a corrupt tarball.
                _cleanup_path(tarball)
            else:
                # Show tarball size
                try:
                    tb_size = tarball.stat().st_size / (1024**3)
                except OSError:
                    pass
                else:
                    _spinner(f"tar.gz: {tb_size:.1f}G")
                # Step 2: extract into dst
                if dst.exists() and not _safe_rmtree(dst):
                    local_elapsed = time.perf_counter() - _local_t0
                    local_rc = 1
                    _cleanup_path(tarball)
                else:
                    dst.mkdir(parents=True, exist_ok=True)
                    local_rc, tar_err = _run_with_spinner(
                        ["tar", "-xzf", str(tarball),
                         "--strip-components=1", "-C", str(dst)],
                        f"extracting {tarball.name} → {dst}/",
                    )
                    local_elapsed = time.perf_counter() - _local_t0
                    if local_rc != 0:
                        logger.warning("tar extract failed: %s", tar_err.strip())
                        # Half-extracted dst is unusable — drop it
                        _remove_partial_dst(dst)
                    _cleanup_path(tarball)
            if _IS_TTY:
                sys.stdout.write("\r\033[K")
        elif args.copy:
            # cp -a: faster than rsync for large venvs on parallel
            # filesystems (sequential directory walk vs per-file stat).
            method = "cp"
            print()
            _spinner("cp -a → /tmp/")
            dst.parent.mkdir(parents=True, exist_ok=True)
            if dst.exists() and not _safe_rmtree(dst):
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = 1
            else:
                local_rc, cp_err = _run_with_spinner(
                    ["cp", "-a", str(src), str(dst)],
                    "cp -a → /tmp/",
                )
                local_elapsed = time.perf_counter() - _local_t0
                if local_rc != 0:
                    logger.warning("cp failed (exit %d): %s", local_rc, cp_err.strip())
                    # Half-copied dst is unusable — drop it
                    _remove_partial_dst(dst)
            if _IS_TTY:
                sys.stdout.write("\r\033[K")
        else:
            method = "rsync"

            def _local_progress(pct: str = "", speed: str = "", eta: str = "") -> None:
                # ``eta`` is accepted for callback-signature compatibility
                # but is not displayed.
                if not _IS_TTY:
                    return
                elapsed = time.perf_counter() - _local_t0
                parts = ["Copying to local /tmp/"]
                if pct:
                    parts.append(pct)
                if speed:
                    parts.append(speed)
                parts.append(f"[{elapsed:.0f}s]")
                msg = " " + " ".join(parts)
                sys.stdout.write(f"\r\033[K{msg}")
                sys.stdout.flush()

            print()
            _local_progress()
            _, local_elapsed, local_rc = _rsync_to_node(
                src, dst, current, local=True,
                progress_callback=_local_progress,
            )
            if _IS_TTY:
                sys.stdout.write("\r\033[K")
        if local_rc == 0:
            # Patching only applies to venvs — it's a no-op for other
            # sources, but skip explicitly to avoid noise.
            # NOTE(review): for a tarball source ``src`` is the archive
            # path, not the original venv dir — confirm the helper copes.
            if (dst / "bin" / "activate").exists():
                _patch_venv_paths_local(dst, src)
            print(f" \u2713 {current} (local, {method}) \u2014 {local_elapsed:.1f}s")
            results.append((current, local_elapsed, local_rc))
        else:
            print(f" \u2717 {current} (local, {method}) \u2014 FAILED")
            results.append((current, local_elapsed, local_rc))
            # No valid local copy — abort, don't distribute a broken env
            logger.error("Local copy failed — aborting distribution.")
            return 1
    remaining = [n for n in all_nodes if n != current]
    # Track per-source active rsync count to enforce MAX_PER_SOURCE.
    source_active: dict[str, int] = {current: 0}
    source_lock = threading.Lock()
    # Function-scoped RNG so we don't disturb the global random state
    # of any caller that has seeded it for their own reasons.
    pick_rng = random.Random()

    def _submit_work(pool: ThreadPoolExecutor, futures: dict) -> None:  # type: ignore[type-arg]
        """Submit as many rsyncs as sources allow."""
        while remaining:
            with source_lock:
                src_node = pick_source(
                    source_active, MAX_PER_SOURCE, rng=pick_rng,
                )
                if src_node is None:
                    break  # all sources at capacity — wait for completions
                target = remaining.pop(0)
                source_active[src_node] += 1
            remote_src = None if src_node == current else src_node
            fut = pool.submit(
                _rsync_to_node,
                dst,  # always rsync from the staged copy at dst
                dst,
                target,
                from_node=remote_src,
                progress_callback=progress.update,
            )
            futures[fut] = (target, src_node)

    # Step 2: greedy fan-out using a single persistent pool.
    with ThreadPoolExecutor(max_workers=min(total, 128)) as pool:
        futures: dict = {}
        _submit_work(pool, futures)
        while futures:
            # Wait for the next completion
            fut = next(as_completed(futures))
            target, src_used = futures.pop(fut)
            try:
                n, elapsed, rc = fut.result()
            except Exception:
                # One crashing worker marks its node as failed instead of
                # aborting distribution to every other node.
                logger.exception("Sync to %s raised an exception", target)
                n, elapsed, rc = target, 0.0, 1
            with source_lock:
                source_active[src_used] -= 1
                if rc == 0:
                    # This node now has the data — register as a source
                    source_active[n] = 0
            label = f"{n} (local)" if n == current else n
            progress.mark_done(label, elapsed, rc)
            results.append((n, elapsed, rc))
            # Submit more work now that a source slot freed up
            # (and possibly a new source appeared)
            _submit_work(pool, futures)
    progress.clear()
    total_elapsed = time.perf_counter() - t0
    failed = sum(1 for _, _, rc in results if rc != 0)
    if failed:
        logger.warning("%d/%d node(s) failed!", failed, total_nodes)
    else:
        logger.info("Done in %.1fs", total_elapsed)
    # ── Guidance ────────────────────────────────────────────────────
    # Detect from `dst` so tarball sources (where the directory only
    # exists after extraction) are classified correctly.
    is_venv = (dst / "bin" / "activate").exists()
    is_conda = (dst / "conda-meta").is_dir()
    has_bin = (dst / "bin").is_dir()
    print()
    if is_venv:
        print(" To use this environment:")
        print(" deactivate 2>/dev/null")
        print(f" source {dst}/bin/activate")
        print()
        print(" Then launch your training (from a shared filesystem path):")
        print(" cd /path/to/your/project")
        print(" ezpz launch python3 -m your_app.train")
        print()
        print(" Note: /tmp is node-local. Make sure your working directory")
        print(" is on a shared filesystem (e.g. Lustre) before launching,")
        print(" so all ranks can access data and outputs.")
    elif is_conda:
        print(" To use this environment:")
        print(" conda deactivate")
        print(f" conda activate {dst}")
        print()
        print(" Then launch your training (from a shared filesystem path):")
        print(" cd /path/to/your/project")
        print(" ezpz launch python3 -m your_app.train")
        print()
        print(" Note: /tmp is node-local. Make sure your working directory")
        print(" is on a shared filesystem (e.g. Lustre) before launching,")
        print(" so all ranks can access data and outputs.")
    else:
        successful = len(results) - failed
        print(f" Synced to {dst}/ on {successful} node(s).")
        if has_bin:
            print(" (looks like a tool directory — add to PATH if needed:")
            print(f" export PATH={dst}/bin:$PATH)")
        print()
        print(" Note: /tmp is node-local. Reference the synced path on each")
        print(f" worker (e.g. {dst}) — the shared-filesystem source path will")
        print(" not see writes from worker nodes.")
    return 1 if failed else 0