Skip to content

ezpz.utils.yeet_env

Distribute a Python environment — or any directory or pre-built .tar.gz/.tgz archive — to worker nodes via parallel rsync.

Default behavior: auto-detect the active Python environment and rsync it to /tmp/<env-name>/ on every node in the current job allocation.

Usage::

# Rsync the active venv to all worker nodes:
ezpz yeet

# Rsync a specific path:
ezpz yeet --src /path/to/env

# Custom destination:
ezpz yeet --dst /local/fast/storage/myenv

# Preview without syncing:
ezpz yeet --dry-run

main(argv=None)

CLI entry point.

Source code in src/ezpz/utils/yeet_env.py
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Console-script entry point; delegates to :func:`run`.

    When launched through the legacy ``ezpz-yeet-env`` executable, a
    deprecation notice is written to stderr before delegating. The
    exit code is whatever :func:`run` returns.
    """
    launched_via_legacy_script = Path(sys.argv[0]).name == "ezpz-yeet-env"
    if launched_via_legacy_script:
        notice = (
            "ezpz-yeet-env is deprecated; use 'ezpz yeet' as a drop-in "
            "replacement"
        )
        print(notice, file=sys.stderr)
    return run(argv)

parse_args(argv=None)

Parse ezpz yeet command-line arguments.

Source code in src/ezpz/utils/yeet_env.py
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse ``ezpz yeet`` command-line arguments.

    Args:
        argv: Argument tokens to parse, excluding the program name.
            ``None`` means ``sys.argv[1:]`` (argparse's own default).

    Returns:
        The parsed namespace. ``args.src`` is populated from the
        positional ``SRC`` when ``--src`` was not given.

    Raises:
        SystemExit: via ``parser.error`` on unrecognized arguments, when
            both ``--src`` and positional ``SRC`` are given, or when a
            ``--min-success-*`` value is out of range.
    """
    parser = argparse.ArgumentParser(
        prog="ezpz yeet",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=(
            "Distribute files (envs, models, datasets, etc.) to worker nodes "
            "via parallel rsync. By default (no args), rsyncs the active "
            "venv/conda env to /tmp/<env-name>/ on all nodes in the current "
            "job allocation. Pass any path to yeet arbitrary content."
        ),
    )
    parser.add_argument(
        "src_positional",
        nargs="?",
        default=None,
        metavar="SRC",
        help=(
            "Source path (positional shorthand for --src). Mutually "
            "exclusive with --src. May be a directory OR a .tar.gz/.tgz "
            "file."
        ),
    )
    parser.add_argument(
        "--src",
        type=str,
        default=None,
        help=(
            "Source path (defaults to the active venv/conda env). May "
            "be a directory OR a .tar.gz/.tgz file — in the latter "
            "case the tarball is copied to /tmp/ and extracted there, "
            "skipping the create step that --compress does."
        ),
    )
    parser.add_argument(
        "--dst",
        type=str,
        default=None,
        help="Destination path on worker nodes (defaults to /tmp/<env-name>/).",
    )
    parser.add_argument(
        "--hostfile",
        type=str,
        default=None,
        help="Hostfile to read node list from (auto-detected from scheduler when omitted).",
    )
    parser.add_argument(
        "--copy",
        action="store_true",
        help=(
            "Use 'cp -a' instead of rsync for the local copy "
            "(Lustre → /tmp/). Faster for initial copies of large "
            "environments with many small files. Remote node "
            "distribution still uses rsync."
        ),
    )
    parser.add_argument(
        "--compress",
        action="store_true",
        help=(
            "Create a .tar.gz archive, copy it to /tmp/, then extract. "
            "Reduces Lustre I/O from millions of small-file reads to "
            "one sequential read. Remote distribution still uses rsync."
        ),
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be synced without doing it.",
    )

    # Proceed-with-spares: when set, yeet returns 0 even if some
    # nodes failed, as long as the success count meets the threshold.
    # Useful at scale where 1-2 nodes always fail with permanent SSH
    # issues and the user has spare allocation. The failed-nodes list
    # is written to $dst/.ezpz-yeet-failed-nodes.txt so downstream
    # tooling (training scripts, ezpz launch wrappers) can read +
    # exclude those hosts. Mutually exclusive flags.
    threshold = parser.add_mutually_exclusive_group()
    threshold.add_argument(
        "--min-success-nodes",
        type=int,
        default=None,
        metavar="N",
        help=(
            "Return success (rc=0) as long as at least N nodes "
            "received the rsync successfully, even if some "
            "failed. The failed-node list is written to "
            "$dst/.ezpz-yeet-failed-nodes.txt for downstream "
            "consumers. Mutually exclusive with "
            "--min-success-fraction."
        ),
    )
    threshold.add_argument(
        "--min-success-fraction",
        type=float,
        default=None,
        metavar="F",
        help=(
            "Same as --min-success-nodes but expressed as a "
            "fraction of the total node count (e.g. 0.95 = at "
            "least 95%% of nodes must succeed). Computed against "
            "the full node list including the local node. "
            "Mutually exclusive with --min-success-nodes."
        ),
    )

    stray_tokens = ("yeet", "yeet-env")
    tokens = list(sys.argv[1:] if argv is None else argv)
    # Old entry points may leak their subcommand name as the FIRST
    # token. Left in place, argparse binds it to the optional
    # positional SRC (it never reaches `unknown`), and the run later
    # fails with "Source path does not exist". Drop it up front —
    # unless a path with that name really exists, in which case the
    # user meant it as the source.
    if tokens and tokens[0] in stray_tokens and not Path(tokens[0]).exists():
        tokens = tokens[1:]

    args, unknown = parser.parse_known_args(tokens)
    if unknown:
        # Tolerate any further stray "yeet"/"yeet-env" tokens leaked
        # from old entry points; reject anything else.
        leftover = [a for a in unknown if a not in stray_tokens]
        if leftover:
            parser.error(
                f"unrecognized arguments: {' '.join(leftover)}"
            )
    # Mutex: positional SRC and --src can't both be set.
    if args.src_positional is not None and args.src is not None:
        parser.error(
            "--src and positional SRC are mutually exclusive; "
            "pick one"
        )
    if args.src is None:
        args.src = args.src_positional

    # Validate the threshold args (argparse can't easily do range
    # validation for floats).
    if args.min_success_nodes is not None and args.min_success_nodes < 1:
        parser.error(
            f"--min-success-nodes must be >= 1, got {args.min_success_nodes}"
        )
    if args.min_success_fraction is not None and not (
        0.0 < args.min_success_fraction <= 1.0
    ):
        parser.error(
            f"--min-success-fraction must be in (0, 1], got "
            f"{args.min_success_fraction}"
        )

    return args

pick_source(source_active, max_per_source, *, rng=None)

Pick a least-loaded source under the per-source cap.

Ties are broken with the supplied rng (defaults to a fresh random.Random() per call) so the greedy fan-out actually fans out — without randomization, dict.items() order would always pick the same source first, pinning all early traffic to one node and defeating the tree distribution.

Returns None if every source is at capacity (the caller should wait for an in-flight sync to finish, freeing a slot).

Pulled out of run() so tests can exercise the algorithm directly instead of reconstructing it inline.

Source code in src/ezpz/utils/yeet_env.py
def pick_source(
    source_active: dict[str, int],
    max_per_source: int,
    *,
    rng: random.Random | None = None,
) -> str | None:
    """Pick a least-loaded source under the per-source cap.

    Ties are broken with the supplied ``rng`` (defaults to a fresh
    ``random.Random()`` per call) so the greedy fan-out actually fans
    out — without randomization, ``dict.items()`` order would always
    pick the same source first, pinning all early traffic to one node
    and defeating the tree distribution.

    Returns ``None`` if every source is at capacity (the caller should
    wait for an in-flight sync to finish, freeing a slot).

    Pulled out of run() so tests can exercise the algorithm directly
    instead of reconstructing it inline.
    """
    candidates = [
        s for s, count in source_active.items() if count < max_per_source
    ]
    if not candidates:
        return None
    min_count = min(source_active[s] for s in candidates)
    least_loaded = [s for s in candidates if source_active[s] == min_count]
    if rng is None:
        rng = random.Random()
    return rng.choice(least_loaded)

run(argv=None)

Main entry point for yeet-env.

Source code in src/ezpz/utils/yeet_env.py
def run(argv: Optional[Sequence[str]] = None) -> int:
    """Main entry point for yeet-env: distribute a source to every node.

    Steps:
      1. Resolve the source (``--src`` / positional ``SRC``, else the
         active env) and the destination (``--dst``, else
         ``/tmp/<env-name>``).
      2. Unless ``_needs_local_copy(src)`` says otherwise, populate
         ``dst`` on the current node (rsync, ``cp -a``, a ``--compress``
         tarball, or a pre-built ``.tar.gz``/``.tgz``) and, when
         ``dst/bin/activate`` exists, patch venv paths.
      3. Fan out to the remaining nodes with a greedy tree of rsyncs in
         which each finished node becomes a new source; failed targets
         are requeued up to ``_DEFAULT_RSYNC_RETRIES`` times.
      4. Evaluate the ``--min-success-*`` threshold, clear/write the
         failed-nodes sentinel under ``dst``, and print usage guidance.

    Args:
        argv: Command-line tokens; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit code: 0 on success, 1 on failure (see the rc notes
        at the end of the function).
    """
    args = parse_args(argv)

    # ── Resolve source ──────────────────────────────────────────────
    if args.src is not None:
        src = Path(args.src).resolve()
        if not src.exists():
            logger.error("Source path does not exist: %s", src)
            return 1
    else:
        src = _detect_env_source()
        # Convenience hint: if a fresher same-named tarball exists
        # next to the env or in cwd, point it out — tarball broadcast
        # is ~10× faster at scale than per-file rsync. Don't auto-pick
        # it; explicit is safer (the tarball might be stale).
        _suggest_tarball_if_present(src)

    # If the source is a .tar.gz / .tgz file, treat it as a pre-built
    # archive: skip the "tar create" step and just copy + extract.
    tarball_suffixes = (".tar.gz", ".tgz")
    src_is_tarball = src.is_file() and src.name.endswith(tarball_suffixes)
    env_name = src.name
    if src_is_tarball:
        # Strip the archive suffix to derive the destination name.
        for suffix in tarball_suffixes:
            if env_name.endswith(suffix):
                env_name = env_name[: -len(suffix)]
                break

    # ── Resolve destination ─────────────────────────────────────────
    dst = Path(args.dst) if args.dst is not None else Path("/tmp") / env_name

    # ── Discover nodes ──────────────────────────────────────────────
    nodes = _get_worker_nodes(hostfile=args.hostfile)
    current = _get_current_hostname()

    if not nodes:
        logger.error("No worker nodes found.")
        return 1

    # Copy locally first (current node), then rsync to remote nodes.
    # The local copy is needed because /tmp is node-local — but if
    # src is ALREADY under /tmp (e.g. user pre-copied or this is a
    # second invocation), skip the local step and fan out from src
    # directly. Extracted as a helper so tests can override the
    # platform-dependent /tmp prefix (pytest tmp_path lands under
    # /tmp on Linux runners but /var/folders on macOS — which
    # silently changes which branch the tests exercise).
    needs_local_copy = _needs_local_copy(src)
    # Filter out the current node — also handle the HSN variant
    # (current node may appear as "node01" while nodes contain "node01-hsn0").
    current_variants = {current, current + "-hsn0", current.removesuffix("-hsn0")}
    remote_nodes = [n for n in nodes if n not in current_variants]

    # ── Print summary ───────────────────────────────────────────────
    env_size = _get_env_size(src)
    total_nodes = (1 if needs_local_copy else 0) + len(remote_nodes)
    logger.info("Source: %s (%s)", src, env_size)
    logger.info("Target: %s/ on %d node(s)", dst, total_nodes)
    if needs_local_copy:
        logger.info("  local:  %s (rsync to %s/)", current, dst)
    if remote_nodes:
        if len(remote_nodes) <= 6:
            logger.info("  remote: %s", ", ".join(remote_nodes))
        else:
            shown = ", ".join(remote_nodes[:3])
            logger.info("  remote: %s, ... (%d nodes)", shown, len(remote_nodes))
    if args.dry_run:
        logger.info("[dry-run] No files transferred.")
        return 0

    if total_nodes == 0:
        logger.info("Nothing to sync (source is already in %s).", dst)
        return 0

    # ── Sync ────────────────────────────────────────────────────────
    #
    # Greedy tree distribution: instead of all N nodes pulling from the
    # source (which saturates the source node's NIC), each completed
    # node immediately becomes a source for others.  A single thread
    # pool runs for the entire sync — as soon as any rsync finishes,
    # new rsyncs are submitted using the newly-available source.
    #
    # Each source has a concurrency cap (MAX_PER_SOURCE) so no single
    # node is overwhelmed.  The tree grows organically: the first node
    # seeds a few targets, each of those fans out to more, etc.

    MAX_PER_SOURCE = 8  # max concurrent outbound rsyncs per source node

    all_nodes: list[str] = []
    if needs_local_copy:
        all_nodes.append(current)
    all_nodes.extend(remote_nodes)

    total = len(all_nodes)
    progress = _AggregateProgress(total_nodes=total)
    results: list[tuple[str, float, int]] = []

    logger.info("Syncing (%d nodes)...", total)
    t0 = time.perf_counter()

    # Step 1: copy source to local /tmp/ and patch paths ONCE.
    # All subsequent rsyncs distribute the already-patched copy.
    if needs_local_copy:
        _local_t0 = time.perf_counter()

        def _spinner(label: str) -> None:
            """Redraw a one-line spinner with *label* + elapsed time (TTY only)."""
            if not _IS_TTY:
                return
            elapsed = time.perf_counter() - _local_t0
            frames = _AggregateProgress._FRAMES
            idx = int(elapsed * 2) % len(frames)
            sys.stdout.write(f"\r\033[K    {frames[idx]} {label}  [{elapsed:.0f}s]")
            sys.stdout.flush()

        def _wait(proc: "subprocess.Popen[bytes]", label: str) -> tuple[int, str]:
            """Wait for *proc* to exit, redrawing the spinner every 0.5s.

            Returns ``(returncode, decoded_stderr)``. Uses
            ``communicate(timeout=...)`` so stderr is drained while we
            wait: polling a child whose stderr is a PIPE without ever
            reading it can deadlock once the child fills the OS pipe
            buffer (e.g. tar printing many per-file warnings). Retrying
            ``communicate`` after ``TimeoutExpired`` loses no output.
            """
            while True:
                try:
                    _, err = proc.communicate(timeout=0.5)
                except subprocess.TimeoutExpired:
                    _spinner(label)
                    continue
                return proc.returncode, (err or b"").decode(errors="replace")

        if src_is_tarball:
            # Source is already a .tar.gz/.tgz: copy it to /tmp/ and
            # extract there. Skips the "tar create" step that --compress
            # does. Useful when you already have a pre-built tarball
            # (e.g. from `ezpz tar-env`) on a shared filesystem.
            # NOTE(review): --strip-components=1 assumes the archive has
            # a single top-level directory — confirm for foreign tarballs.
            method = "tar.gz (pre-built)"
            local_tarball = Path("/tmp") / src.name
            print()
            try:
                tb_size_gb = src.stat().st_size / (1024**3)
            except OSError:
                tb_size_gb = 0.0

            # Step 1: copy tarball Lustre → /tmp/
            copy_label = f"copying {src.name} ({tb_size_gb:.1f}G) → /tmp/"
            _spinner(copy_label)
            cp_proc = subprocess.Popen(
                ["cp", str(src), str(local_tarball)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
            cp_rc, cp_err = _wait(cp_proc, copy_label)
            if cp_rc != 0:
                logger.warning("cp tarball failed: %s", cp_err.strip())
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = cp_rc
                # Partial copy may have left a truncated tarball behind
                _cleanup_path(local_tarball)
            elif dst.exists() and not _safe_rmtree(dst):
                # Step 2 can't start: the stale dst couldn't be removed.
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = 1
                # We never started extracting, but the local tarball
                # still needs cleanup.
                _cleanup_path(local_tarball)
            else:
                # Step 2: extract into dst
                dst.mkdir(parents=True, exist_ok=True)
                extract_label = f"extracting {local_tarball.name} → {dst}/"
                _spinner(extract_label)
                tar_extract = subprocess.Popen(
                    ["tar", "-xzf", str(local_tarball),
                     "--strip-components=1", "-C", str(dst)],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.PIPE,
                )
                local_rc, extract_err = _wait(tar_extract, extract_label)
                local_elapsed = time.perf_counter() - _local_t0
                if local_rc != 0:
                    logger.warning("tar extract failed: %s", extract_err.strip())
                    # Half-extracted dst is unusable — drop it. We just
                    # wrote it ourselves, so skip the /tmp-only guard.
                    _remove_partial_dst(dst)
                # Always clean up the local tarball copy whether
                # extraction succeeded or failed.
                _cleanup_path(local_tarball)
            if _IS_TTY:
                sys.stdout.write("\r\033[K")

        elif args.compress:
            # tar.gz: tar reads src in one sequential pass over the
            # shared filesystem and writes a single archive straight to
            # /tmp/, which is then extracted locally. Much less Lustre
            # metadata pressure than per-file rsync/cp.
            method = "tar.gz"
            tarball = Path(f"/tmp/{env_name}.tar.gz")
            print()

            # Step 1: create the archive from the source
            create_label = f"tar -czf {tarball.name} (compressing)"
            _spinner(create_label)
            tar_create = subprocess.Popen(
                [
                    "tar", "-czf", str(tarball),
                    "--exclude=__pycache__",
                    "-C", str(src.parent), src.name,
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
            create_rc, create_err = _wait(tar_create, create_label)
            if create_rc != 0:
                logger.warning("tar create failed: %s", create_err.strip())
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = create_rc
                # Partial archive: don't ship a corrupt tarball.
                _cleanup_path(tarball)
            else:
                # Show tarball size
                try:
                    tb_size = tarball.stat().st_size / (1024**3)
                    _spinner(f"tar.gz: {tb_size:.1f}G")
                except OSError:
                    pass

                # Step 2: extract into dst
                if dst.exists() and not _safe_rmtree(dst):
                    local_elapsed = time.perf_counter() - _local_t0
                    local_rc = 1
                    _cleanup_path(tarball)
                else:
                    dst.mkdir(parents=True, exist_ok=True)
                    extract_label = f"extracting {tarball.name} → {dst}/"
                    _spinner(extract_label)
                    tar_extract = subprocess.Popen(
                        ["tar", "-xzf", str(tarball),
                         "--strip-components=1", "-C", str(dst)],
                        stdout=subprocess.DEVNULL,
                        stderr=subprocess.PIPE,
                    )
                    local_rc, extract_err = _wait(tar_extract, extract_label)
                    local_elapsed = time.perf_counter() - _local_t0
                    if local_rc != 0:
                        logger.warning("tar extract failed: %s", extract_err.strip())
                        # Half-extracted dst is unusable — drop it
                        _remove_partial_dst(dst)
                    _cleanup_path(tarball)
            if _IS_TTY:
                sys.stdout.write("\r\033[K")

        elif args.copy:
            # cp -a: faster than rsync for large venvs on parallel
            # filesystems (sequential directory walk vs per-file stat).
            method = "cp"
            print()
            _spinner("cp -a → /tmp/")
            dst.parent.mkdir(parents=True, exist_ok=True)
            if dst.exists() and not _safe_rmtree(dst):
                local_elapsed = time.perf_counter() - _local_t0
                local_rc = 1
            else:
                cp_proc = subprocess.Popen(
                    ["cp", "-a", str(src), str(dst)],
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.PIPE,
                )
                local_rc, cp_err = _wait(cp_proc, "cp -a → /tmp/")
                local_elapsed = time.perf_counter() - _local_t0
                if local_rc != 0:
                    logger.warning("cp failed (exit %d): %s", local_rc, cp_err.strip())
                    # Half-copied dst is unusable — drop it
                    _remove_partial_dst(dst)
            if _IS_TTY:
                sys.stdout.write("\r\033[K")

        else:
            method = "rsync"

            def _local_progress(pct: str = "", speed: str = "", eta: str = "") -> None:
                """Render rsync's local-copy progress on one TTY line."""
                if not _IS_TTY:
                    return
                elapsed = time.perf_counter() - _local_t0
                parts = ["Copying to local /tmp/"]
                if pct:
                    parts.append(pct)
                if speed:
                    parts.append(speed)
                parts.append(f"[{elapsed:.0f}s]")
                msg = "    " + "  ".join(parts)
                sys.stdout.write(f"\r\033[K{msg}")
                sys.stdout.flush()

            print()
            _local_progress()
            _, local_elapsed, local_rc = _rsync_to_node(
                src, dst, current, local=True,
                progress_callback=_local_progress,
            )
            if _IS_TTY:
                sys.stdout.write("\r\033[K")

        if local_rc == 0:
            # Patching only applies to venvs — it's a no-op for other
            # sources, but skip explicitly to avoid noise.
            if (dst / "bin" / "activate").exists():
                _patch_venv_paths_local(dst, src)
            print(f"    ✓ {current} (local, {method}) — {local_elapsed:.1f}s")
            results.append((current, local_elapsed, local_rc))
        else:
            print(f"    ✗ {current} (local, {method}) — FAILED")
            results.append((current, local_elapsed, local_rc))
            # No valid local copy — abort, don't distribute a broken env
            logger.error("Local copy failed — aborting distribution.")
            return 1

    remaining = [n for n in all_nodes if n != current]

    # Path holding the data on the CURRENT node. Normally that is the
    # freshly populated dst. When the local step was skipped because a
    # directory src already lives under /tmp, the data is still at src
    # — which differs from dst whenever --dst was given explicitly, so
    # seeding from dst would ship a missing/empty path. Nodes that
    # received the data always hold it at dst.
    seed_path = dst if (needs_local_copy or src_is_tarball) else src

    # Track per-source active rsync count to enforce MAX_PER_SOURCE.
    source_active: dict[str, int] = {current: 0}
    source_lock = threading.Lock()
    # Function-scoped RNG so we don't disturb the global random state
    # of any caller that has seeded it for their own reasons.
    pick_rng = random.Random()

    # Per-target retry bookkeeping. `retries_done[n]` counts how many
    # *retries* have been performed against node `n` so far. The
    # initial attempt is NOT a retry, so this is 0 before the first
    # try, becomes 1 after the first failure (when we requeue),
    # becomes 2 after the second failure, etc. Capped at
    # _DEFAULT_RSYNC_RETRIES so the total attempt count for any node
    # is at most _DEFAULT_RSYNC_RETRIES + 1.
    retries_done: dict[str, int] = {n: 0 for n in remaining}
    # Accumulate per-target elapsed time across attempts so the final
    # reported wall-clock isn't just the last attempt's duration.
    elapsed_total: dict[str, float] = {n: 0.0 for n in remaining}

    def _submit_work(pool: ThreadPoolExecutor, futures: dict) -> None:  # type: ignore[type-arg]
        """Submit as many rsyncs as source capacity allows."""
        while remaining:
            with source_lock:
                src_node = pick_source(
                    source_active, MAX_PER_SOURCE, rng=pick_rng,
                )
                if src_node is None:
                    break  # all sources at capacity — wait for completions
                target = remaining.pop(0)
                source_active[src_node] += 1

            if src_node == current:
                remote_src, from_path = None, seed_path
            else:
                remote_src, from_path = src_node, dst
            fut = pool.submit(
                _rsync_to_node,
                from_path,
                dst,
                target,
                from_node=remote_src,
                progress_callback=progress.update,
            )
            futures[fut] = (target, src_node)

    # Step 2: greedy fan-out using a single persistent pool.
    with ThreadPoolExecutor(max_workers=min(total, 128)) as pool:
        futures: dict = {}
        _submit_work(pool, futures)

        while futures:
            # Wait for the next completion
            fut = next(as_completed(futures))
            n, elapsed, rc = fut.result()
            _, src_used = futures.pop(fut)
            elapsed_total[n] = elapsed_total.get(n, 0.0) + elapsed

            with source_lock:
                source_active[src_used] -= 1
                if rc == 0:
                    # This node now has the data — register as a source
                    source_active[n] = 0

            if rc != 0 and retries_done[n] < _DEFAULT_RSYNC_RETRIES:
                # Transient failure — requeue for another attempt.
                # Don't mark_done yet (final outcome still pending) and
                # don't append to results yet. The node was NOT
                # registered as a source above (rc != 0).
                retries_done[n] += 1
                # After the increment, retries_done[n] equals the
                # 1-based number of the attempt that just failed.
                logger.warning(
                    "rsync to %s failed (attempt %d/%d, rc=%d) — "
                    "requeueing for retry",
                    n,
                    retries_done[n],
                    _DEFAULT_RSYNC_RETRIES + 1,
                    rc,
                )
                with source_lock:
                    remaining.append(n)
                _submit_work(pool, futures)
                continue

            label = f"{n} (local)" if n == current else n
            progress.mark_done(label, elapsed_total[n], rc)
            results.append((n, elapsed_total[n], rc))

            # Submit more work now that a source slot freed up
            # (and possibly a new source appeared)
            _submit_work(pool, futures)

    progress.clear()

    total_elapsed = time.perf_counter() - t0
    failed_nodes = [n for n, _, rc in results if rc != 0]
    failed = len(failed_nodes)
    ok_nodes = [n for n, _, rc in results if rc == 0]
    if failed:
        logger.warning("%d/%d node(s) failed!", failed, total_nodes)
    else:
        logger.info("Done in %.1fs", total_elapsed)

    # Resolve the proceed-with-spares threshold.
    #
    # IMPORTANT: --min-success-fraction computes against the FULL
    # hostfile node count (`len(nodes)`), NOT `total_nodes` (which
    # is the rsync-op count — smaller by 1 when needs_local_copy is
    # False because the local node isn't rsync'd to). This matches
    # the docs' promise that the fraction is "of the total node
    # count", which is the user-mental-model for a hostfile of N
    # nodes regardless of which of those N happen to be the
    # launcher's local host.
    #
    # --min-success-nodes is a HARD lower bound: it MUST be met
    # regardless of whether there were failures. A clean run that
    # only synced to 500 nodes when the user asked for >= 512
    # MUST fail loudly, not silently under-provision downstream
    # experiments (codex P2).
    hostfile_node_count = len(nodes)

    # Count successes per HOSTFILE node holding the data so the
    # numerator matches the hostfile-based denominator:
    #   - the current node is credited iff it is in the hostfile. If
    #     a local copy was needed it succeeded (a failure returned 1
    #     above); if not, it already held the data and was never
    #     rsync'd — previously uncounted, which made e.g.
    #     `--min-success-fraction 1.0` fail on a clean run.
    #   - a launcher outside the hostfile is not a worker, so its
    #     local copy must not count toward the worker threshold.
    local_in_hostfile = any(n in current_variants for n in nodes)
    local_credit = 1 if local_in_hostfile else 0
    ok_count = sum(1 for n in ok_nodes if n != current) + local_credit
    counted_total = len(remote_nodes) + local_credit

    if args.min_success_nodes is not None:
        threshold = args.min_success_nodes
        threshold_src = f"--min-success-nodes={args.min_success_nodes}"
    elif args.min_success_fraction is not None:
        threshold = math.ceil(
            args.min_success_fraction * hostfile_node_count
        )
        threshold_src = (
            f"--min-success-fraction={args.min_success_fraction} "
            f"× {hostfile_node_count} = {threshold}"
        )
    else:
        # No flag → every counted node must succeed, i.e. every remote
        # rsync (original fail-on-any-failure behavior preserved).
        threshold = counted_total
        threshold_src = "default (all nodes must succeed)"

    # The threshold is met when at least `threshold` nodes hold the
    # data. proceed_with_spares is True iff:
    #   - we have at least one failure (otherwise sentinel path is
    #     a no-op even if threshold is met), AND
    #   - the threshold is met
    threshold_met = ok_count >= threshold
    proceed_with_spares = failed > 0 and threshold_met

    # ── Sentinel cleanup + write ────────────────────────────────────
    # ALWAYS remove a stale sentinel from a prior run first. Without
    # this, re-running into the same `dst` (common when `src` is
    # already under /tmp and the local-copy step is skipped) would
    # leave behind the previous run's failed-node list, and the
    # documented downstream snippet would incorrectly exclude hosts
    # that aren't actually problematic on this run (codex/copilot
    # P2). Only re-write when proceed_with_spares is True.
    sentinel_path = dst / ".ezpz-yeet-failed-nodes.txt"
    try:
        sentinel_path.unlink(missing_ok=True)
    except OSError as exc:
        # We can't proceed safely with a stale sentinel — fail
        # rather than risk downstream tooling reading wrong data.
        # Most likely a permissions issue.
        logger.error(
            "Could not remove possibly-stale sentinel %s: %s. "
            "Refusing to proceed — fix the permissions then retry.",
            sentinel_path,
            exc,
        )
        return 1

    if proceed_with_spares:
        try:
            sentinel_path.parent.mkdir(parents=True, exist_ok=True)
            sentinel_path.write_text(
                "\n".join(failed_nodes) + "\n"
            )
            logger.warning(
                "Proceeding despite %d failure(s) — threshold met "
                "(%d/%d nodes ok, need >=%d via %s). Failed nodes "
                "written to %s",
                failed,
                ok_count,
                counted_total,
                threshold,
                threshold_src,
                sentinel_path,
            )
        except OSError as exc:
            # The threshold IS met — don't fail the run just because
            # we couldn't write the bookkeeping file. Log loudly so
            # the user knows the file is missing.
            logger.warning(
                "Threshold met (%d/%d ok via %s) but couldn't "
                "write failed-nodes file at %s: %s. Failed nodes "
                "were: %s",
                ok_count, counted_total, threshold_src,
                sentinel_path, exc,
                ", ".join(failed_nodes),
            )

    # ── Guidance ────────────────────────────────────────────────────
    # Detect from `dst` so tarball sources (where the directory only
    # exists after extraction) are classified correctly.
    is_venv = (dst / "bin" / "activate").exists()
    is_conda = (dst / "conda-meta").is_dir()
    has_bin = (dst / "bin").is_dir()
    print()
    if is_venv:
        print("  To use this environment:")
        print("    deactivate 2>/dev/null")
        print(f"    source {dst}/bin/activate")
        print()
        print("  Then launch your training (from a shared filesystem path):")
        print("    cd /path/to/your/project")
        print("    ezpz launch python3 -m your_app.train")
        print()
        print("  Note: /tmp is node-local. Make sure your working directory")
        print("  is on a shared filesystem (e.g. Lustre) before launching,")
        print("  so all ranks can access data and outputs.")
    elif is_conda:
        print("  To use this environment:")
        print("    conda deactivate")
        print(f"    conda activate {dst}")
        print()
        print("  Then launch your training (from a shared filesystem path):")
        print("    cd /path/to/your/project")
        print("    ezpz launch python3 -m your_app.train")
        print()
        print("  Note: /tmp is node-local. Make sure your working directory")
        print("  is on a shared filesystem (e.g. Lustre) before launching,")
        print("  so all ranks can access data and outputs.")
    else:
        successful = len(results) - failed
        print(f"  Synced to {dst}/ on {successful} node(s).")
        if has_bin:
            print("    (looks like a tool directory — add to PATH if needed:")
            print(f"     export PATH={dst}/bin:$PATH)")
        print()
        print("  Note: /tmp is node-local. Reference the synced path on each")
        print(f"  worker (e.g. {dst}) — the shared-filesystem source path will")
        print("  not see writes from worker nodes.")

    # rc semantics:
    #
    # When --min-success-* is set, the threshold is a HARD lower
    # bound: it MUST be met regardless of failure count. Otherwise
    # `yeet --min-success-nodes 512` on a 500-node hostfile with
    # all 500 ok would return 0 and silently under-provision the
    # downstream training (codex P2).
    #
    #   no flag set:        original behavior — 0 iff no failures
    #   threshold set:      0 iff threshold met (regardless of `failed`)
    threshold_explicitly_set = (
        args.min_success_nodes is not None
        or args.min_success_fraction is not None
    )
    if threshold_explicitly_set:
        if not threshold_met:
            logger.error(
                "Threshold NOT met: only %d/%d nodes succeeded, "
                "need >=%d via %s",
                ok_count, counted_total, threshold, threshold_src,
            )
            return 1
        return 0
    # Default path: original fail-on-any behavior.
    return 1 if failed else 0