🔮 ezpz.examples.inference⚓︎

Distributed inference over a HuggingFace model + dataset, with three distinct modes: benchmark, generate, and eval.

Each rank loads the model and processes a disjoint shard of inputs (data parallelism). Per-batch latency and throughput (tokens/sec) are tracked through ezpz.History. With --mode eval an accuracy metric is added. TFLOPS and MFU are opt-in via --flops — without that flag the columns are simply absent rather than approximated (see MFU Tracking (opt-in)).

Source⚓︎

src/ezpz/examples/inference.py
"""Distributed inference over a HuggingFace model + dataset.

A general-purpose inference example: shard a HuggingFace dataset across
ranks (data parallel), generate completions with a HuggingFace model on
each rank, and aggregate the results.

Launch it with::

    ezpz launch python3 -m ezpz.examples.inference \\
        --model meta-llama/Llama-3.2-1B \\
        --dataset wikitext --dataset-config wikitext-2-raw-v1 \\
        --max-samples 256 --max-new-tokens 32

Each rank loads the model, processes its shard of prompts, and
records per-batch latency, throughput (tokens/sec), MFU, and the
generated text.  Outputs go to ``outputs/ezpz.examples.inference/<ts>/``:
``predictions-rank<N>.jsonl`` (one row per sample) and a finalized
``History`` dataset with the timing/perf metrics.
"""

from __future__ import annotations

import argparse
import json
import time
from pathlib import Path
from typing import Any, Optional, Sequence

import torch

import ezpz
from ezpz.examples import get_example_outdir
from ezpz.flops import compute_mfu

logger = ezpz.get_logger(__name__)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse inference command-line arguments."""
    parser = argparse.ArgumentParser(
        prog="ezpz.examples.inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=(
            "Distributed inference over a HuggingFace model + dataset. "
            "Three modes: --mode benchmark (throughput), generate "
            "(synthetic data corpus), eval (accuracy on labeled data). "
            "Each rank processes a disjoint shard of prompts."
        ),
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="generate",
        choices=["benchmark", "generate", "eval"],
        help=(
            "Inference mode: "
            "'benchmark' = synthetic random tokens, no dataset, focus "
            "on tokens/sec/MFU. "
            "'generate' = dataset prompts → completions, save to JSONL "
            "(synthetic data / distillation use case). "
            "'eval' = dataset prompts + gold labels, compare generated "
            "text to label, report accuracy."
        ),
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Llama-3.2-1B",
        help="HuggingFace model name or local path",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="wikitext",
        help="HuggingFace dataset name or local path (ignored in --mode benchmark)",
    )
    parser.add_argument(
        "--dataset-config",
        type=str,
        default="wikitext-2-raw-v1",
        help="Dataset configuration (subset name)",
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="test",
        help="Dataset split (train/validation/test)",
    )
    parser.add_argument(
        "--text-column",
        type=str,
        default="text",
        help="Dataset column containing the prompt text",
    )
    parser.add_argument(
        "--label-column",
        type=str,
        default=None,
        help="Dataset column containing the gold label (required for --mode eval)",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=20,
        help="Number of benchmark iterations (only with --mode benchmark)",
    )
    parser.add_argument(
        "--benchmark-warmup",
        type=int,
        default=3,
        help="Warmup iterations to skip when reporting (only with --mode benchmark)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=128,
        help="Maximum number of samples to process across all ranks",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Per-rank batch size",
    )
    parser.add_argument(
        "--max-input-tokens",
        type=int,
        default=512,
        help="Truncate prompts to this many tokens",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=64,
        help="Maximum tokens to generate per sample",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model dtype",
    )
    parser.add_argument(
        "--flops",
        action="store_true",
        help=(
            "Measure exact per-batch FLOPS via FlopCounterMode and "
            "report tflops + mfu in metrics. Off by default — without "
            "this flag, MFU/TFLOPS columns are simply omitted (rather "
            "than reporting approximated values). Adds ~15-40%% "
            "overhead per step."
        ),
    )
    parser.add_argument(
        "--flops-every-n-steps",
        type=int,
        default=1,
        help=(
            "When --flops is set, measure FLOPS every N steps. "
            "Use a higher value to amortize the overhead across batches."
        ),
    )
    parser.add_argument(
        "--do-sample",
        action="store_true",
        help="Use sampling instead of greedy decoding",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature (only used with --do-sample)",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=1.0,
        help="Nucleus sampling threshold (only used with --do-sample)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed",
    )
    parser.add_argument(
        # BooleanOptionalAction generates both --save-predictions and
        # --no-save-predictions automatically.  The previous hand-rolled
        # store_true/store_false pair rendered as
        #   --no-save-predictions ... (default: True)
        # which read as if the *flag* defaulted to True instead of the
        # underlying behavior.  BooleanOptionalAction shows
        #   --save-predictions, --no-save-predictions ... (default: True)
        # which is the standard argparse way to express this and reads
        # correctly under ArgumentDefaultsHelpFormatter.
        "--save-predictions",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write per-sample predictions to JSONL",
    )
    return parser.parse_args(argv)


def shard_indices(total: int, rank: int, world_size: int) -> list[int]:
    """Return the subset of indices [0, total) assigned to *rank*.

    Uses contiguous block sharding — rank ``r`` gets indices
    ``[r*chunk, (r+1)*chunk)`` where ``chunk = ceil(total / world_size)``.
    The last rank may receive fewer samples.
    """
    if total <= 0 or world_size <= 0:
        return []
    chunk = (total + world_size - 1) // world_size
    start = rank * chunk
    end = min(start + chunk, total)
    return list(range(start, end))


def _torch_dtype(name: str) -> torch.dtype:
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }[name]


def _normalize(text: str) -> str:
    """Lowercase + collapse whitespace + strip — used for eval matching."""
    if text is None:
        return ""
    return " ".join(str(text).lower().split())


def _run_with_optional_flops(
    fn,
    *args,
    measure: bool,
    **kwargs,
):
    """Run *fn(*args, **kwargs)* under FlopCounterMode if *measure* is True.

    Returns ``(result, measured_flops)``.  When *measure* is False or
    when the counter raises (XPU, etc.), ``measured_flops`` is ``0``.
    Pulled out of the hot loop so the eval-unlabeled forward path and
    the generate path share one wrapper instead of duplicating the
    enter/exit/try plumbing.
    """
    if not measure:
        return fn(*args, **kwargs), 0
    from torch.utils.flop_counter import FlopCounterMode
    with FlopCounterMode(display=False) as fc:
        result = fn(*args, **kwargs)
    try:
        return result, int(fc.get_total_flops())
    except Exception:
        return result, 0


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Entry point for distributed HF inference."""
    # Silence noisy per-request HTTP logs from HF Hub clients
    import logging as _logging
    for _noisy in ("httpx", "huggingface_hub", "filelock", "urllib3"):
        _logging.getLogger(_noisy).setLevel(_logging.WARNING)
    # Silence noisy transformers messages (e.g. the BPE
    # clean_up_tokenization_spaces warning fired on every decode call)
    try:
        import transformers as _transformers
        _transformers.logging.set_verbosity_error()
        _transformers.logging.disable_progress_bar()
    except Exception:
        pass

    args = parse_args(argv)

    # ── Distributed setup ──────────────────────────────────────────
    rank = ezpz.setup_torch(seed=args.seed)
    world_size = ezpz.get_world_size()
    device = ezpz.get_torch_device(as_torch_device=True)
    dtype = _torch_dtype(args.dtype)

    # ── Model + tokenizer ──────────────────────────────────────────
    # Load on rank 0 first to populate the HF cache, then let other
    # ranks read from cache. Avoids thundering-herd downloads on first
    # run from a clean cache.
    if rank == 0:
        logger.info("Loading model: %s", args.model)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _load_tokenizer_and_model():
        tok = AutoTokenizer.from_pretrained(
            args.model,
            clean_up_tokenization_spaces=False,
        )
        # Resolve a usable pad id: prefer existing pad → reuse eos →
        # add a new <pad> token (last-resort, requires resizing the
        # model embedding table to match — done below after model load).
        added_pad_token = False
        if tok.pad_token_id is None:
            if tok.eos_token_id is not None:
                tok.pad_token = tok.eos_token
            else:
                tok.add_special_tokens({"pad_token": "<pad>"})
                added_pad_token = True
        # Left-pad so generation continues from the end of each prompt
        tok.padding_side = "left"
        m = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
        ).to(device)
        # If we extended the vocab, resize the embedding table so the
        # new pad id doesn't index out of bounds.
        if added_pad_token:
            m.resize_token_embeddings(len(tok))
        m.eval()
        return tok, m

    if rank == 0:
        tokenizer, model = _load_tokenizer_and_model()
        if world_size > 1:
            ezpz.synchronize()
    else:
        if world_size > 1:
            ezpz.synchronize()
        tokenizer, model = _load_tokenizer_and_model()
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    if rank == 0:
        n_params = sum(p.numel() for p in model.parameters())
        logger.info("Model loaded: %.2fB parameters", n_params / 1e9)

    # FLOPS / MFU tracking is opt-in via --flops.  Without that flag
    # the metrics are simply omitted — better than reporting
    # approximated numbers that can be misleading (linear scaling
    # under-counts attention's O(seq^2) cost; fixed-shape startup
    # estimates over-count short batches).  When --flops is set, we
    # measure exact per-batch FLOPS via FlopCounterMode every
    # --flops-every-n-steps batches.

    # ── Prompts (and optional labels) ──────────────────────────────
    # In benchmark mode we synthesize random tokens — no dataset load.
    # In generate / eval mode we pull from a HuggingFace dataset.
    # eval_unlabeled = True means we'll score next-token prediction
    # directly (no labels, no generation).
    prompts: list[str]
    labels: list[Optional[str]]
    eval_unlabeled = args.mode == "eval" and not args.label_column

    if args.mode == "benchmark":
        if rank == 0:
            logger.info(
                "Benchmark mode: %d iterations × batch_size=%d at "
                "max_input_tokens=%d, max_new_tokens=%d",
                args.benchmark_iters, args.batch_size,
                args.max_input_tokens, args.max_new_tokens,
            )
        # Synthesize random token IDs as "prompts" — we'll skip the
        # tokenizer in the loop and feed input_ids directly.
        prompts = [""] * args.benchmark_iters * args.batch_size
        labels = [None] * len(prompts)
    else:
        if rank == 0 and eval_unlabeled:
            logger.info(
                "--mode eval without --label-column: scoring next-token "
                "prediction (accuracy + perplexity) on the prompt itself"
            )

        if rank == 0:
            logger.info(
                "Loading dataset: %s [%s] split=%s",
                args.dataset, args.dataset_config, args.dataset_split,
            )
        from datasets import load_dataset

        def _load_ds():
            return load_dataset(
                args.dataset, args.dataset_config, split=args.dataset_split,
            )

        if rank == 0:
            ds = _load_ds()
            if world_size > 1:
                ezpz.synchronize()
        else:
            if world_size > 1:
                ezpz.synchronize()
            ds = _load_ds()

        prompts = []
        labels = []
        for row in ds:
            if not isinstance(row, dict):
                continue
            text = row.get(args.text_column, "")
            if not (isinstance(text, str) and text.strip()):
                continue
            prompts.append(text.strip())
            if args.label_column:
                lbl = row.get(args.label_column)
                labels.append(str(lbl) if lbl is not None else None)
            else:
                labels.append(None)
            if len(prompts) >= args.max_samples:
                break

        if not prompts:
            if rank == 0:
                logger.error(
                    "No prompts found in column %r of %s/%s split=%s — aborting.",
                    args.text_column, args.dataset, args.dataset_config,
                    args.dataset_split,
                )
            ezpz.cleanup()
            return 1

    my_indices = shard_indices(len(prompts), rank, world_size)
    my_prompts = [prompts[i] for i in my_indices]
    my_labels = [labels[i] for i in my_indices]

    if rank == 0:
        logger.info(
            "Total samples: %d%d per rank (rank 0 has %d)",
            len(prompts),
            (len(prompts) + world_size - 1) // world_size,
            len(my_prompts),
        )

    # ── Output paths ───────────────────────────────────────────────
    module_name = "ezpz.examples.inference"
    outdir = get_example_outdir(module_name)
    if rank == 0:
        logger.info("Outputs will be saved to %s", outdir)

    history = ezpz.History(
        project_name=module_name,
        config={
            "model": args.model,
            "dataset": f"{args.dataset}/{args.dataset_config}",
            "world_size": world_size,
            "batch_size": args.batch_size,
            "max_new_tokens": args.max_new_tokens,
            "dtype": args.dtype,
        },
        outdir=outdir,
        report_dir=outdir,
        report_enabled=True,
        distributed_history=(1 < world_size <= 384),
    )

    # Predictions file: only useful for generate / eval.  Disable in
    # benchmark mode (the "completions" are random-token noise).
    pred_file = None
    if args.save_predictions and args.mode != "benchmark":
        pred_path = Path(outdir) / f"predictions-rank{rank}.jsonl"
        pred_path.parent.mkdir(parents=True, exist_ok=True)
        pred_file = pred_path.open("w", encoding="utf-8")

    # ── Inference loop ─────────────────────────────────────────────
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
        "pad_token_id": pad_id,
    }
    if args.do_sample:
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p

    total_samples = 0
    total_new_tokens = 0
    n_correct = 0  # only meaningful in --mode eval (labeled)
    n_with_label = 0
    # For unlabeled --mode eval (next-token prediction):
    n_token_correct = 0
    n_tokens_scored = 0
    nll_sum = 0.0  # sum of -log p(next token); perplexity = exp(nll/n)
    benchmark_warmup = (
        args.benchmark_warmup if args.mode == "benchmark" else 0
    )
    t_start = time.perf_counter()

    with torch.inference_mode():
        for batch_idx in range(0, len(my_prompts), args.batch_size):
            batch_prompts = my_prompts[batch_idx : batch_idx + args.batch_size]
            batch_labels = my_labels[batch_idx : batch_idx + args.batch_size]
            if not batch_prompts:
                continue

            ezpz.synchronize()
            t0 = time.perf_counter()

            if args.mode == "benchmark":
                # Synthesize random input_ids — skip tokenizer entirely
                input_ids = torch.randint(
                    0,
                    int(getattr(tokenizer, "vocab_size", 32000)),
                    (len(batch_prompts), args.max_input_tokens),
                    device=device,
                    dtype=torch.long,
                )
                attention_mask = torch.ones_like(input_ids)
            else:
                enc = tokenizer(
                    batch_prompts,
                    padding=True,
                    truncation=True,
                    max_length=args.max_input_tokens,
                    return_tensors="pt",
                ).to(device)
                input_ids = enc.input_ids
                attention_mask = enc.attention_mask

            # Unlabeled eval path: forward pass + score next-token
            # prediction. No autoregressive generation.
            decoded: list[str] = []
            batch_token_correct = 0
            batch_token_scored = 0
            batch_nll_sum = 0.0
            n_new = 0
            measured_flops = 0  # only set when this step runs FlopCounterMode
            batch_num = batch_idx // args.batch_size
            measure_flops_now = (
                args.flops
                and args.flops_every_n_steps > 0
                and batch_num % args.flops_every_n_steps == 0
            )
            if eval_unlabeled:
                outputs, measured_flops = _run_with_optional_flops(
                    model,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    measure=measure_flops_now,
                )
                logits = outputs.logits  # [B, T, V]
                # Predict token i+1 from logits at position i
                shift_logits = logits[:, :-1, :].contiguous()
                shift_targets = input_ids[:, 1:].contiguous()
                # Mask padded positions out of the score
                shift_mask = attention_mask[:, 1:].to(torch.bool)
                preds = shift_logits.argmax(dim=-1)
                correct = (preds == shift_targets) & shift_mask
                batch_token_correct = int(correct.sum().item())
                batch_token_scored = int(shift_mask.sum().item())
                # Cross-entropy in float32 for numerical stability
                ce = torch.nn.functional.cross_entropy(
                    shift_logits.float().reshape(-1, shift_logits.size(-1)),
                    shift_targets.reshape(-1),
                    reduction="none",
                ).reshape(shift_targets.shape)
                batch_nll_sum = float((ce * shift_mask).sum().item())
                ezpz.synchronize()
                dt = time.perf_counter() - t0
            else:
                output_ids, measured_flops = _run_with_optional_flops(
                    model.generate,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    measure=measure_flops_now,
                    **gen_kwargs,
                )
                ezpz.synchronize()
                dt = time.perf_counter() - t0
                # Slice off the prompt portion so we report only new tokens
                input_len = input_ids.shape[1]
                new_token_ids = output_ids[:, input_len:]
                n_new = int((new_token_ids != pad_id).sum().item())
                # Decode (skip in benchmark mode — random tokens aren't useful)
                if args.mode != "benchmark":
                    decoded = tokenizer.batch_decode(
                        new_token_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )

            n_in = int(attention_mask.sum().item())
            is_warmup = batch_num < benchmark_warmup

            # Labeled-eval string match (only when we have labels)
            batch_correct = 0
            batch_with_label = 0
            if args.mode == "eval" and not eval_unlabeled:
                for completion, label in zip(decoded, batch_labels):
                    if label is None:
                        continue
                    batch_with_label += 1
                    if _normalize(completion) == _normalize(label) or (
                        _normalize(label) and _normalize(label) in _normalize(completion)
                    ):
                        batch_correct += 1
                n_correct += batch_correct
                n_with_label += batch_with_label

            if eval_unlabeled:
                n_token_correct += batch_token_correct
                n_tokens_scored += batch_token_scored
                nll_sum += batch_nll_sum

            if not is_warmup:
                total_samples += len(batch_prompts)
                total_new_tokens += n_new

            # In unlabeled-eval mode there are no generated tokens —
            # we ran a single forward pass and scored prompt tokens.
            # Report tokens/sec and MFU based on tokens *scored*
            # (one forward pass per token), not generated.
            if eval_unlabeled:
                tokens_for_throughput = batch_token_scored
            else:
                tokens_for_throughput = n_new

            metrics: dict[str, Any] = {
                "batch": batch_num,
                "samples": len(batch_prompts),
                "input_tokens": n_in,
                "dt": dt,
                "tokens_per_sec": (
                    tokens_for_throughput / dt if dt > 0 else 0.0
                ),
                "warmup": is_warmup,
            }
            # In unlabeled-eval mode there are no generated tokens —
            # only scored tokens.  Reporting "new_tokens=0" on every
            # row would be noise in the JSONL/CSV; emit one or the
            # other based on the mode.
            if eval_unlabeled:
                metrics["tokens_scored"] = batch_token_scored
            else:
                metrics["new_tokens"] = n_new
            # FLOPS / MFU only when --flops was set AND this step
            # actually ran FlopCounterMode.  Skip on the in-between
            # batches when --flops-every-n-steps > 1.  No approximation
            # is reported — the metrics are simply absent on unmeasured
            # steps (you can filter on flops_measured == True downstream).
            if measured_flops > 0 and dt > 0:
                metrics["tflops"] = measured_flops / dt / 1e12
                metrics["mfu"] = compute_mfu(measured_flops, dt)
                metrics["flops_measured"] = True
            if args.mode == "eval" and batch_with_label > 0:
                metrics["accuracy"] = batch_correct / batch_with_label
            if eval_unlabeled and batch_token_scored > 0:
                metrics["next_token_accuracy"] = (
                    batch_token_correct / batch_token_scored
                )
                metrics["perplexity"] = float(
                    torch.exp(
                        torch.tensor(batch_nll_sum / batch_token_scored)
                    ).item()
                )

            if not is_warmup:
                history.update(metrics)

            if pred_file is not None:
                for i, (prompt, completion) in enumerate(zip(batch_prompts, decoded)):
                    row = {
                        "rank": rank,
                        "prompt": prompt,
                        "completion": completion,
                    }
                    if args.mode == "eval":
                        row["label"] = batch_labels[i]
                    pred_file.write(json.dumps(row) + "\n")

            if rank == 0:
                tag = " [warmup]" if is_warmup else ""
                parts = [
                    f"batch={batch_num}{tag}",
                    f"samples={metrics['samples']}",
                ]
                # In unlabeled-eval mode there are no generated
                # tokens — show what was *scored* instead.
                if "tokens_scored" in metrics:
                    parts.append(f"tokens_scored={metrics['tokens_scored']}")
                else:
                    parts.append(f"new_tokens={metrics['new_tokens']}")
                parts.extend([
                    f"dt={dt:.3f}s",
                    f"tps={metrics['tokens_per_sec']:.1f}",
                ])
                if "tflops" in metrics:
                    parts.append(f"tflops={metrics['tflops']:.2f}")
                if "mfu" in metrics:
                    parts.append(f"mfu={metrics['mfu']:.2f}%")
                if "accuracy" in metrics:
                    parts.append(f"accuracy={metrics['accuracy']:.3f}")
                if "next_token_accuracy" in metrics:
                    parts.append(
                        f"next_token_acc={metrics['next_token_accuracy']:.3f}"
                    )
                if "perplexity" in metrics:
                    parts.append(f"perplexity={metrics['perplexity']:.2f}")
                logger.info(" ".join(parts))

    if pred_file is not None:
        pred_file.close()

    t_total = time.perf_counter() - t_start
    if rank == 0:
        logger.info(
            "Done [%s] — %d samples, %d new tokens in %.1fs (%.1f tok/s aggregate per rank)",
            args.mode,
            total_samples,
            total_new_tokens,
            t_total,
            total_new_tokens / t_total if t_total > 0 else 0.0,
        )
        if args.mode == "eval" and not eval_unlabeled and n_with_label > 0:
            logger.info(
                "Accuracy [rank 0]: %d/%d = %.2f%%",
                n_correct, n_with_label,
                100.0 * n_correct / n_with_label,
            )
        if eval_unlabeled and n_tokens_scored > 0:
            avg_nll = nll_sum / n_tokens_scored
            ppl = float(torch.exp(torch.tensor(avg_nll)).item())
            logger.info(
                "Next-token [rank 0]: accuracy=%d/%d=%.2f%%  perplexity=%.3f  (NLL/token=%.4f)",
                n_token_correct, n_tokens_scored,
                100.0 * n_token_correct / n_tokens_scored,
                ppl, avg_nll,
            )

    if rank == 0:
        history.finalize(
            outdir=outdir,
            run_name=module_name,
            dataset_fname="inference",
            verbose=False,
        )

    ezpz.cleanup()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

Code Walkthrough⚓︎

shard_indices — data-parallel input partitioning

Each rank computes its slice of the input array; the helper handles uneven splits and the "more ranks than samples" edge case.

src/ezpz/examples/inference.py:202:215
    return parser.parse_args(argv)


def shard_indices(total: int, rank: int, world_size: int) -> list[int]:
    """Return the subset of indices [0, total) assigned to *rank*.

    Uses contiguous block sharding — rank ``r`` gets indices
    ``[r*chunk, (r+1)*chunk)`` where ``chunk = ceil(total / world_size)``.
    The last rank may receive fewer samples.
    """
    if total <= 0 or world_size <= 0:
        return []
    chunk = (total + world_size - 1) // world_size
    start = rank * chunk
    end = min(start + chunk, total)
    return list(range(start, end))

_run_with_optional_flops — opt-in FLOPS measurement

The eval-unlabeled forward path and the model.generate path share this wrapper. When measure=False (the default — no --flops), the call is a passthrough; otherwise it runs inside FlopCounterMode and returns the measured FLOP count alongside the result.

src/ezpz/examples/inference.py:232:254
    return " ".join(str(text).lower().split())


def _run_with_optional_flops(
    fn,
    *args,
    measure: bool,
    **kwargs,
):
    """Run *fn(*args, **kwargs)* under FlopCounterMode if *measure* is True.

    Returns ``(result, measured_flops)``.  When *measure* is False or
    when the counter raises (XPU, etc.), ``measured_flops`` is ``0``.
    Pulled out of the hot loop so the eval-unlabeled forward path and
    the generate path share one wrapper instead of duplicating the
    enter/exit/try plumbing.
    """
    if not measure:
        return fn(*args, **kwargs), 0
    from torch.utils.flop_counter import FlopCounterMode
    with FlopCounterMode(display=False) as fc:
        result = fn(*args, **kwargs)
    try:
        return result, int(fc.get_total_flops())
    except Exception:
        return result, 0

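Both call sites in main go through this wrapper. A minimal usage sketch — input_ids, attention_mask, and gen_kwargs stand in for the real objects built in main:

# Unlabeled-eval path: single forward pass, measured on this step
outputs, flops = _run_with_optional_flops(
    model, input_ids=input_ids, attention_mask=attention_mask,
    measure=True,
)

# Generate path: passthrough on unmeasured steps (flops comes back as 0)
output_ids, flops = _run_with_optional_flops(
    model.generate, input_ids=input_ids, attention_mask=attention_mask,
    measure=False, **gen_kwargs,
)
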
Per-batch metrics — measured vs absent

tflops and mfu are only added to the metrics dict when the step actually ran FlopCounterMode. No approximation is reported on unmeasured batches; downstream filters can rely on flops_measured == True to identify trustworthy points.

The unlabeled-eval path emits tokens_scored (one forward pass over the prompt batch); the generate path emits new_tokens (one forward per autoregressive token). Never both.

Eval-unlabeled scoring — perplexity from logits

Without --label-column, eval mode runs a single forward pass per batch and scores next-token prediction at every position. Argmax accuracy and per-token cross-entropy go into the metrics; perplexity = exp(NLL/token) is reported per batch and overall.

This is much cheaper than autoregressive generation and matches the "language modeling perplexity" you'd get from eval-harness tools.
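
Condensed into a standalone helper, the scoring logic looks like this (score_next_token is a name used here for illustration; the body mirrors the eval-unlabeled branch of the source above):

import torch
import torch.nn.functional as F

def score_next_token(logits, input_ids, attention_mask):
    """Next-token accuracy and perplexity from one forward pass."""
    shift_logits = logits[:, :-1, :]      # predict token i+1 from position i
    shift_targets = input_ids[:, 1:]
    mask = attention_mask[:, 1:].bool()   # drop padded positions
    preds = shift_logits.argmax(dim=-1)
    accuracy = ((preds == shift_targets) & mask).sum() / mask.sum()
    # Cross-entropy in float32 for numerical stability
    nll = F.cross_entropy(
        shift_logits.float().flatten(0, 1),
        shift_targets.flatten(),
        reduction="none",
    ).reshape(shift_targets.shape)
    perplexity = torch.exp((nll * mask).sum() / mask.sum())
    return accuracy.item(), perplexity.item()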

Modes⚓︎

# Throughput benchmark — synthetic random tokens, no dataset
ezpz launch python3 -m ezpz.examples.inference --mode benchmark

# Synthetic data generation — dataset prompts → completions JSONL
ezpz launch python3 -m ezpz.examples.inference --mode generate

# Evaluation — generate, then score against gold labels
ezpz launch python3 -m ezpz.examples.inference --mode eval \
    --dataset gsm8k --dataset-config main --dataset-split test \
    --text-column question --label-column answer
Mode                 Source                      Saves predictions?   Reports accuracy?
benchmark            random tokens               no                   no
generate (default)   dataset prompts             yes                  no
eval                 dataset prompts + labels    yes (with label)     yes

--mode benchmark⚓︎

Pure throughput measurement. Skips the tokenizer entirely and feeds random input_ids of shape (batch_size, max_input_tokens) into the model. Configurable via:

  • --benchmark-iters (default 20) — number of forward passes
  • --benchmark-warmup (default 3) — warmup iters excluded from totals

Use this for hardware/configuration comparisons (different --batch-size, --dtype, --world-size, --max-input-tokens). The reported tokens/sec and MFU are uncontaminated by tokenizer overhead or dataset variance.

--mode generate⚓︎

Reads prompts from a HuggingFace dataset, generates completions, and writes them to predictions-rank<N>.jsonl. Useful for:

  • Synthetic data generation for downstream training
  • Distillation — generate teacher completions to train a student
  • Spot-checking model behavior on a known corpus
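
The per-rank files merge trivially into a single corpus. A minimal sketch, assuming the default output layout shown under "Outputs" below and that the timestamped run directories sort lexicographically:

import json
from pathlib import Path

runs = sorted(Path("outputs/ezpz.examples.inference").glob("*/"))
latest = runs[-1]  # most recent run
pairs = []
for f in sorted(latest.glob("predictions-rank*.jsonl")):
    with f.open(encoding="utf-8") as fh:
        for line in fh:
            row = json.loads(line)
            pairs.append((row["prompt"], row["completion"]))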

--mode eval (with --label-column)⚓︎

Same as generate, but also extracts a gold label from --label-column and compares the completion. The match rule is "normalized exact-match OR substring": both completion and label are lowercased, whitespace-collapsed, then checked. The label counts as correct if it appears as a substring of the normalized completion (handles "the answer is 42" vs gold "42").

Reports accuracy = correct / labeled per batch (in History) and overall (in the final log line).
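
The rule is easy to reproduce standalone (is_correct is an illustrative name; the logic matches _normalize and the comparison in the source):

def _normalize(text):
    return " ".join(str(text).lower().split()) if text is not None else ""

def is_correct(completion, label):
    c, l = _normalize(completion), _normalize(label)
    return c == l or bool(l and l in c)

assert is_correct("The answer is  42", "42")   # substring match
assert not is_correct("I don't know", "42")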

--mode eval (without --label-column)⚓︎

If you omit --label-column, eval falls back to next-token prediction scoring — useful on any text dataset, no labels needed:

ezpz launch python3 -m ezpz.examples.inference --mode eval \
    --dataset wikitext --dataset-config wikitext-2-raw-v1

This path:

  • Runs one forward pass over the full prompt batch (not generation)
  • Computes argmax accuracy at each position (token i+1 predicted from logits at position i)
  • Computes per-token cross-entropy → perplexity = exp(NLL/token)
  • Reports next_token_accuracy and perplexity per batch + overall

Much faster than the labeled path (one forward pass vs autoregressive generation), and gives the standard "language modeling perplexity" metric you'd see in eval-harness tools — without needing a labeled dataset.

The per-batch log line shows tokens_scored= instead of new_tokens= to make the difference obvious.

How sharding works⚓︎

Inputs are loaded once on every rank, then each rank processes a contiguous block via shard_indices:

total = 10 samples, world_size = 4

rank 0: [0, 1, 2]
rank 1: [3, 4, 5]
rank 2: [6, 7, 8]
rank 3: [9]

This is data parallelism — each rank holds the full model. There is no model-parallel sharding, so the model must fit in a single device's memory. For larger models, use a model-parallel inference framework (vLLM, DeepSpeed-Inference) instead.
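
Since the helper is pure arithmetic, the diagram is directly reproducible — including the empty shards you get with more ranks than samples:

from ezpz.examples.inference import shard_indices

print([shard_indices(10, r, 4) for r in range(4)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print([shard_indices(2, r, 4) for r in range(4)])
# [[0], [1], [], []]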

CLI options⚓︎

Flag                    Default                   Description
--mode                  generate                  benchmark / generate / eval
--model                 meta-llama/Llama-3.2-1B   HF model name or local path
--dataset               wikitext                  HF dataset (ignored in benchmark mode)
--dataset-config        wikitext-2-raw-v1         Dataset configuration / subset
--dataset-split         test                      Split (train/validation/test)
--text-column           text                      Column containing the prompt
--label-column          None                      Gold label column (optional for --mode eval; without it, scores next-token prediction)
--max-samples           128                       Total samples across all ranks
--batch-size            4                         Per-rank batch size
--max-input-tokens      512                       Truncate prompts to this many tokens
--max-new-tokens        64                        Tokens to generate per sample
--dtype                 bfloat16                  Model dtype (float32/float16/bfloat16)
--do-sample             off                       Sample instead of greedy decoding
--temperature           1.0                       Sampling temperature (with --do-sample)
--top-p                 1.0                       Nucleus sampling cutoff (with --do-sample)
--benchmark-iters       20                        Iterations (only --mode benchmark)
--benchmark-warmup      3                         Warmup iters excluded from totals
--seed                  0                         Random seed (sampling and synthetic benchmark tokens)
--flops                 off                       Measure real per-batch FLOPS via FlopCounterMode; without this flag, tflops and mfu are not reported (rather than approximated)
--flops-every-n-steps   1                         When --flops is set, measure every N steps to amortize the overhead
--no-save-predictions   save on                   Skip writing per-sample JSONL

Outputs⚓︎

outputs/ezpz.examples.inference/<timestamp>/
├── predictions-rank0.jsonl       # not written in benchmark mode
├── predictions-rank1.jsonl
├── ...
├── inference.h5                  # finalized History dataset
├── report-inference.md           # markdown summary
└── plots/                        # auto-generated metric plots

Each predictions-rank<N>.jsonl row:

// generate mode
{"rank": 0, "prompt": "...", "completion": "..."}

// eval mode
{"rank": 0, "prompt": "...", "completion": "...", "label": "42"}

MFU Tracking (opt-in)⚓︎

tflops and mfu are off by default because the only honest measurement is via FlopCounterMode, which adds ~15-40% per-step overhead. Approximated values (linear-scaled startup estimates, n_tokens × forward_flops, etc.) tend to be misleading — they ignore attention's O(seq²) cost and KV-cache savings, and have produced MFU values >100% in practice.
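
To see what the counter reports in isolation, here is a standalone sanity check, independent of this script:

import torch
from torch.utils.flop_counter import FlopCounterMode

linear = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024)
with FlopCounterMode(display=False) as fc:
    linear(x)
# A linear layer costs ~2 * batch * in_features * out_features FLOPs
print(fc.get_total_flops())  # ≈ 2 * 8 * 1024 * 1024 = 16,777,216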

To opt in:

# Measure real FLOPS on every batch
ezpz launch python3 -m ezpz.examples.inference --mode eval --flops

# Or amortize the overhead — measure every 10th batch
ezpz launch python3 -m ezpz.examples.inference \
    --mode eval --flops --flops-every-n-steps 10

Behavior:

  • Without --flops: tflops and mfu keys are absent from metrics. Per-batch logs show timing/throughput/eval metrics only.
  • With --flops and --flops-every-n-steps 1 (default): every batch is profiled. ~15-40% slower overall, exact MFU on every step.
  • With --flops and --flops-every-n-steps N (N > 1): every Nth batch is profiled, others run normally. Profiled batches get metrics["flops_measured"] = True so post-analysis can filter to the trustworthy points.
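
A sketch of that downstream filter, assuming the per-batch metrics have been exported as JSONL (the file name here is hypothetical; adapt it to wherever your History run lands):

import pandas as pd

df = pd.read_json("metrics.jsonl", lines=True)  # hypothetical export path
# Rows from unmeasured steps lack flops_measured (NaN after load),
# so this keeps exactly the batches that ran FlopCounterMode.
measured = df[df["flops_measured"] == True]
print(measured[["batch", "tflops", "mfu"]].mean())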

See ezpz.flops for the per-device MFU formula.

Help⚓︎

--help
$ ezpz launch python3 -m ezpz.examples.inference --help
usage: ezpz.examples.inference [-h] [--mode {benchmark,generate,eval}]
                               [--model MODEL] [--dataset DATASET]
                               [--dataset-config DATASET_CONFIG]
                               [--dataset-split DATASET_SPLIT]
                               [--text-column TEXT_COLUMN]
                               [--label-column LABEL_COLUMN]
                               [--benchmark-iters BENCHMARK_ITERS]
                               [--benchmark-warmup BENCHMARK_WARMUP]
                               [--max-samples MAX_SAMPLES]
                               [--batch-size BATCH_SIZE]
                               [--max-input-tokens MAX_INPUT_TOKENS]
                               [--max-new-tokens MAX_NEW_TOKENS]
                               [--dtype {float32,float16,bfloat16}] [--flops]
                               [--flops-every-n-steps FLOPS_EVERY_N_STEPS]
                               [--do-sample] [--temperature TEMPERATURE]
                               [--top-p TOP_P] [--seed SEED]
                               [--save-predictions | --no-save-predictions]

Distributed inference over a HuggingFace model + dataset. Three modes: --mode
benchmark (throughput), generate (synthetic data corpus), eval (accuracy on
labeled data). Each rank processes a disjoint shard of prompts.

options:
  -h, --help            show this help message and exit
  --mode {benchmark,generate,eval}
                        Inference mode: 'benchmark' = synthetic random tokens,
                        no dataset, focus on tokens/sec/MFU. 'generate' =
                        dataset prompts → completions, save to JSONL
                        (synthetic data / distillation use case). 'eval' =
                        dataset prompts + gold labels, compare generated text
                        to label, report accuracy. (default: generate)
  --model MODEL         HuggingFace model name or local path (default: meta-
                        llama/Llama-3.2-1B)
  --dataset DATASET     HuggingFace dataset name or local path (ignored in
                        --mode benchmark) (default: wikitext)
  --dataset-config DATASET_CONFIG
                        Dataset configuration (subset name) (default:
                        wikitext-2-raw-v1)
  --dataset-split DATASET_SPLIT
                        Dataset split (train/validation/test) (default: test)
  --text-column TEXT_COLUMN
                        Dataset column containing the prompt text (default:
                        text)
  --label-column LABEL_COLUMN
                        Dataset column containing the gold label (required for
                        --mode eval) (default: None)
  --benchmark-iters BENCHMARK_ITERS
                        Number of benchmark iterations (only with --mode
                        benchmark) (default: 20)
  --benchmark-warmup BENCHMARK_WARMUP
                        Warmup iterations to skip when reporting (only with
                        --mode benchmark) (default: 3)
  --max-samples MAX_SAMPLES
                        Maximum number of samples to process across all ranks
                        (default: 128)
  --batch-size BATCH_SIZE
                        Per-rank batch size (default: 4)
  --max-input-tokens MAX_INPUT_TOKENS
                        Truncate prompts to this many tokens (default: 512)
  --max-new-tokens MAX_NEW_TOKENS
                        Maximum tokens to generate per sample (default: 64)
  --dtype {float32,float16,bfloat16}
                        Model dtype (default: bfloat16)
  --flops               Measure exact per-batch FLOPS via FlopCounterMode and
                        report tflops + mfu in metrics. Off by default —
                        without this flag, MFU/TFLOPS columns are simply
                        omitted (rather than reporting approximated values).
                        Adds ~15-40% overhead per step. (default: False)
  --flops-every-n-steps FLOPS_EVERY_N_STEPS
                        When --flops is set, measure FLOPS every N steps. Use
                        a higher value to amortize the overhead across
                        batches. (default: 1)
  --do-sample           Use sampling instead of greedy decoding (default:
                        False)
  --temperature TEMPERATURE
                        Sampling temperature (only used with --do-sample)
                        (default: 1.0)
  --top-p TOP_P         Nucleus sampling threshold (only used with --do-
                        sample) (default: 1.0)
  --seed SEED           Random seed (default: 0)
  --save-predictions, --no-save-predictions
                        Write per-sample predictions to JSONL (default: True)

See Also⚓︎