Skip to content

ezpz.examples.inference⚓︎

Distributed inference over a HuggingFace model + dataset.

A general-purpose inference example: shard a HuggingFace dataset across ranks (data parallel), generate completions with a HuggingFace model on each rank, and aggregate the results.

Launch it with:

ezpz launch python3 -m ezpz.examples.inference \
    --model meta-llama/Llama-3.2-1B \
    --dataset wikitext --dataset-config wikitext-2-raw-v1 \
    --max-samples 256 --max-new-tokens 32

Each rank loads the model, processes its shard of prompts, and records per-batch latency, throughput (tokens/sec), MFU, and the generated text. Outputs go to outputs/ezpz.examples.inference/<ts>/: predictions-rank<N>.jsonl (one row per sample) and a finalized History dataset with the timing/perf metrics.

main(argv=None) ⚓︎

Entry point for distributed HF inference.

Source code in src/ezpz/examples/inference.py
def main(argv: Optional[Sequence[str]] = None) -> int:
    """Entry point for distributed HF inference.

    Sets up torch distributed via ezpz, loads the model/tokenizer (rank 0
    first to warm the HF cache), builds the prompt list according to
    ``--mode`` (benchmark / generate / eval), shards prompts across ranks,
    runs the batched inference loop while recording per-batch metrics to an
    ``ezpz.History``, and finalizes outputs on rank 0.

    Args:
        argv: Optional argument vector forwarded to ``parse_args``
            (defaults to ``sys.argv[1:]``).

    Returns:
        Process exit code: ``0`` on success, ``1`` when no usable prompts
        could be loaded from the requested dataset column.
    """
    # Silence noisy per-request HTTP logs from HF Hub clients
    import logging as _logging
    for _noisy in ("httpx", "huggingface_hub", "filelock", "urllib3"):
        _logging.getLogger(_noisy).setLevel(_logging.WARNING)
    # Silence noisy transformers messages (e.g. the BPE
    # clean_up_tokenization_spaces warning fired on every decode call)
    try:
        import transformers as _transformers
        _transformers.logging.set_verbosity_error()
        _transformers.logging.disable_progress_bar()
    except Exception:
        pass

    args = parse_args(argv)

    # ── Distributed setup ──────────────────────────────────────────
    rank = ezpz.setup_torch(seed=args.seed)
    world_size = ezpz.get_world_size()
    device = ezpz.get_torch_device(as_torch_device=True)
    dtype = _torch_dtype(args.dtype)

    # ── Model + tokenizer ──────────────────────────────────────────
    # Load on rank 0 first to populate the HF cache, then let other
    # ranks read from cache. Avoids thundering-herd downloads on first
    # run from a clean cache.
    if rank == 0:
        logger.info("Loading model: %s", args.model)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _load_tokenizer_and_model():
        tok = AutoTokenizer.from_pretrained(
            args.model,
            clean_up_tokenization_spaces=False,
        )
        # Resolve a usable pad id: prefer existing pad → reuse eos →
        # add a new <pad> token (last-resort, requires resizing the
        # model embedding table to match — done below after model load).
        added_pad_token = False
        if tok.pad_token_id is None:
            if tok.eos_token_id is not None:
                tok.pad_token = tok.eos_token
            else:
                tok.add_special_tokens({"pad_token": "<pad>"})
                added_pad_token = True
        # Left-pad so generation continues from the end of each prompt
        tok.padding_side = "left"
        m = AutoModelForCausalLM.from_pretrained(
            args.model,
            torch_dtype=dtype,
        ).to(device)
        # If we extended the vocab, resize the embedding table so the
        # new pad id doesn't index out of bounds.
        if added_pad_token:
            m.resize_token_embeddings(len(tok))
        m.eval()
        return tok, m

    if rank == 0:
        tokenizer, model = _load_tokenizer_and_model()
        if world_size > 1:
            ezpz.synchronize()
    else:
        if world_size > 1:
            ezpz.synchronize()
        tokenizer, model = _load_tokenizer_and_model()
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    if rank == 0:
        n_params = sum(p.numel() for p in model.parameters())
        logger.info("Model loaded: %.2fB parameters", n_params / 1e9)

    # FLOPS / MFU tracking is opt-in via --flops.  Without that flag
    # the metrics are simply omitted — better than reporting
    # approximated numbers that can be misleading (linear scaling
    # under-counts attention's O(seq^2) cost; fixed-shape startup
    # estimates over-count short batches).  When --flops is set, we
    # measure exact per-batch FLOPS via FlopCounterMode every
    # --flops-every-n-steps batches.

    # ── Prompts (and optional labels) ──────────────────────────────
    # In benchmark mode we synthesize random tokens — no dataset load.
    # In generate / eval mode we pull from a HuggingFace dataset.
    # eval_unlabeled = True means we'll score next-token prediction
    # directly (no labels, no generation).
    prompts: list[str]
    labels: list[Optional[str]]
    eval_unlabeled = args.mode == "eval" and not args.label_column

    if args.mode == "benchmark":
        if rank == 0:
            logger.info(
                "Benchmark mode: %d iterations × batch_size=%d at "
                "max_input_tokens=%d, max_new_tokens=%d",
                args.benchmark_iters, args.batch_size,
                args.max_input_tokens, args.max_new_tokens,
            )
        # Synthesize random token IDs as "prompts" — we'll skip the
        # tokenizer in the loop and feed input_ids directly.
        prompts = [""] * args.benchmark_iters * args.batch_size
        labels = [None] * len(prompts)
    else:
        if rank == 0 and eval_unlabeled:
            logger.info(
                "--mode eval without --label-column: scoring next-token "
                "prediction (accuracy + perplexity) on the prompt itself"
            )

        if rank == 0:
            logger.info(
                "Loading dataset: %s [%s] split=%s",
                args.dataset, args.dataset_config, args.dataset_split,
            )
        from datasets import load_dataset

        def _load_ds():
            return load_dataset(
                args.dataset, args.dataset_config, split=args.dataset_split,
            )

        if rank == 0:
            ds = _load_ds()
            if world_size > 1:
                ezpz.synchronize()
        else:
            if world_size > 1:
                ezpz.synchronize()
            ds = _load_ds()

        prompts = []
        labels = []
        for row in ds:
            if not isinstance(row, dict):
                continue
            text = row.get(args.text_column, "")
            if not (isinstance(text, str) and text.strip()):
                continue
            prompts.append(text.strip())
            if args.label_column:
                lbl = row.get(args.label_column)
                labels.append(str(lbl) if lbl is not None else None)
            else:
                labels.append(None)
            if len(prompts) >= args.max_samples:
                break

        if not prompts:
            if rank == 0:
                logger.error(
                    "No prompts found in column %r of %s/%s split=%s — aborting.",
                    args.text_column, args.dataset, args.dataset_config,
                    args.dataset_split,
                )
            ezpz.cleanup()
            return 1

    my_indices = shard_indices(len(prompts), rank, world_size)
    my_prompts = [prompts[i] for i in my_indices]
    my_labels = [labels[i] for i in my_indices]

    if rank == 0:
        # NOTE: format string previously read "%d%d per rank", fusing the
        # total and per-rank counts into one unreadable number.
        logger.info(
            "Total samples: %d (~%d per rank; rank 0 has %d)",
            len(prompts),
            (len(prompts) + world_size - 1) // world_size,
            len(my_prompts),
        )

    # ── Output paths ───────────────────────────────────────────────
    module_name = "ezpz.examples.inference"
    outdir = get_example_outdir(module_name)
    if rank == 0:
        logger.info("Outputs will be saved to %s", outdir)

    history = ezpz.History(
        project_name=module_name,
        config={
            "model": args.model,
            "dataset": f"{args.dataset}/{args.dataset_config}",
            "world_size": world_size,
            "batch_size": args.batch_size,
            "max_new_tokens": args.max_new_tokens,
            "dtype": args.dtype,
        },
        outdir=outdir,
        report_dir=outdir,
        report_enabled=True,
        distributed_history=(1 < world_size <= 384),
    )

    # Predictions file: only useful for generate / eval.  Disable in
    # benchmark mode (the "completions" are random-token noise).
    pred_file = None
    if args.save_predictions and args.mode != "benchmark":
        pred_path = Path(outdir) / f"predictions-rank{rank}.jsonl"
        pred_path.parent.mkdir(parents=True, exist_ok=True)
        pred_file = pred_path.open("w", encoding="utf-8")

    # ── Inference loop ─────────────────────────────────────────────
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
        "pad_token_id": pad_id,
    }
    if args.do_sample:
        gen_kwargs["temperature"] = args.temperature
        gen_kwargs["top_p"] = args.top_p

    total_samples = 0
    total_new_tokens = 0
    n_correct = 0  # only meaningful in --mode eval (labeled)
    n_with_label = 0
    # For unlabeled --mode eval (next-token prediction):
    n_token_correct = 0
    n_tokens_scored = 0
    nll_sum = 0.0  # sum of -log p(next token); perplexity = exp(nll/n)
    benchmark_warmup = (
        args.benchmark_warmup if args.mode == "benchmark" else 0
    )
    t_start = time.perf_counter()

    with torch.inference_mode():
        for batch_idx in range(0, len(my_prompts), args.batch_size):
            batch_prompts = my_prompts[batch_idx : batch_idx + args.batch_size]
            batch_labels = my_labels[batch_idx : batch_idx + args.batch_size]
            if not batch_prompts:
                continue

            ezpz.synchronize()
            t0 = time.perf_counter()

            if args.mode == "benchmark":
                # Synthesize random input_ids — skip tokenizer entirely
                input_ids = torch.randint(
                    0,
                    int(getattr(tokenizer, "vocab_size", 32000)),
                    (len(batch_prompts), args.max_input_tokens),
                    device=device,
                    dtype=torch.long,
                )
                attention_mask = torch.ones_like(input_ids)
            else:
                enc = tokenizer(
                    batch_prompts,
                    padding=True,
                    truncation=True,
                    max_length=args.max_input_tokens,
                    return_tensors="pt",
                ).to(device)
                input_ids = enc.input_ids
                attention_mask = enc.attention_mask

            # Unlabeled eval path: forward pass + score next-token
            # prediction. No autoregressive generation.
            decoded: list[str] = []
            batch_token_correct = 0
            batch_token_scored = 0
            batch_nll_sum = 0.0
            n_new = 0
            measured_flops = 0  # only set when this step runs FlopCounterMode
            batch_num = batch_idx // args.batch_size
            measure_flops_now = (
                args.flops
                and args.flops_every_n_steps > 0
                and batch_num % args.flops_every_n_steps == 0
            )
            if eval_unlabeled:
                outputs, measured_flops = _run_with_optional_flops(
                    model,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    measure=measure_flops_now,
                )
                logits = outputs.logits  # [B, T, V]
                # Predict token i+1 from logits at position i
                shift_logits = logits[:, :-1, :].contiguous()
                shift_targets = input_ids[:, 1:].contiguous()
                # Mask padded positions out of the score
                shift_mask = attention_mask[:, 1:].to(torch.bool)
                preds = shift_logits.argmax(dim=-1)
                correct = (preds == shift_targets) & shift_mask
                batch_token_correct = int(correct.sum().item())
                batch_token_scored = int(shift_mask.sum().item())
                # Cross-entropy in float32 for numerical stability
                ce = torch.nn.functional.cross_entropy(
                    shift_logits.float().reshape(-1, shift_logits.size(-1)),
                    shift_targets.reshape(-1),
                    reduction="none",
                ).reshape(shift_targets.shape)
                batch_nll_sum = float((ce * shift_mask).sum().item())
                ezpz.synchronize()
                dt = time.perf_counter() - t0
            else:
                output_ids, measured_flops = _run_with_optional_flops(
                    model.generate,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    measure=measure_flops_now,
                    **gen_kwargs,
                )
                ezpz.synchronize()
                dt = time.perf_counter() - t0
                # Slice off the prompt portion so we report only new tokens
                input_len = input_ids.shape[1]
                new_token_ids = output_ids[:, input_len:]
                n_new = int((new_token_ids != pad_id).sum().item())
                # Decode (skip in benchmark mode — random tokens aren't useful)
                if args.mode != "benchmark":
                    decoded = tokenizer.batch_decode(
                        new_token_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=False,
                    )

            n_in = int(attention_mask.sum().item())
            is_warmup = batch_num < benchmark_warmup

            # Labeled-eval string match (only when we have labels)
            batch_correct = 0
            batch_with_label = 0
            if args.mode == "eval" and not eval_unlabeled:
                for completion, label in zip(decoded, batch_labels):
                    if label is None:
                        continue
                    batch_with_label += 1
                    if _normalize(completion) == _normalize(label) or (
                        _normalize(label) and _normalize(label) in _normalize(completion)
                    ):
                        batch_correct += 1
                n_correct += batch_correct
                n_with_label += batch_with_label

            if eval_unlabeled:
                n_token_correct += batch_token_correct
                n_tokens_scored += batch_token_scored
                nll_sum += batch_nll_sum

            if not is_warmup:
                total_samples += len(batch_prompts)
                total_new_tokens += n_new

            # In unlabeled-eval mode there are no generated tokens —
            # we ran a single forward pass and scored prompt tokens.
            # Report tokens/sec and MFU based on tokens *scored*
            # (one forward pass per token), not generated.
            if eval_unlabeled:
                tokens_for_throughput = batch_token_scored
            else:
                tokens_for_throughput = n_new

            metrics: dict[str, Any] = {
                "batch": batch_num,
                "samples": len(batch_prompts),
                "input_tokens": n_in,
                "dt": dt,
                "tokens_per_sec": (
                    tokens_for_throughput / dt if dt > 0 else 0.0
                ),
                "warmup": is_warmup,
            }
            # In unlabeled-eval mode there are no generated tokens —
            # only scored tokens.  Reporting "new_tokens=0" on every
            # row would be noise in the JSONL/CSV; emit one or the
            # other based on the mode.
            if eval_unlabeled:
                metrics["tokens_scored"] = batch_token_scored
            else:
                metrics["new_tokens"] = n_new
            # FLOPS / MFU only when --flops was set AND this step
            # actually ran FlopCounterMode.  Skip on the in-between
            # batches when --flops-every-n-steps > 1.  No approximation
            # is reported — the metrics are simply absent on unmeasured
            # steps (you can filter on flops_measured == True downstream).
            if measured_flops > 0 and dt > 0:
                metrics["tflops"] = measured_flops / dt / 1e12
                metrics["mfu"] = compute_mfu(measured_flops, dt)
                metrics["flops_measured"] = True
            if args.mode == "eval" and batch_with_label > 0:
                metrics["accuracy"] = batch_correct / batch_with_label
            if eval_unlabeled and batch_token_scored > 0:
                metrics["next_token_accuracy"] = (
                    batch_token_correct / batch_token_scored
                )
                metrics["perplexity"] = float(
                    torch.exp(
                        torch.tensor(batch_nll_sum / batch_token_scored)
                    ).item()
                )

            if not is_warmup:
                history.update(metrics)

            if pred_file is not None:
                for i, (prompt, completion) in enumerate(zip(batch_prompts, decoded)):
                    row = {
                        "rank": rank,
                        "prompt": prompt,
                        "completion": completion,
                    }
                    if args.mode == "eval":
                        row["label"] = batch_labels[i]
                    pred_file.write(json.dumps(row) + "\n")

            if rank == 0:
                tag = " [warmup]" if is_warmup else ""
                parts = [
                    f"batch={batch_num}{tag}",
                    f"samples={metrics['samples']}",
                ]
                # In unlabeled-eval mode there are no generated
                # tokens — show what was *scored* instead.
                if "tokens_scored" in metrics:
                    parts.append(f"tokens_scored={metrics['tokens_scored']}")
                else:
                    parts.append(f"new_tokens={metrics['new_tokens']}")
                parts.extend([
                    f"dt={dt:.3f}s",
                    f"tps={metrics['tokens_per_sec']:.1f}",
                ])
                if "tflops" in metrics:
                    parts.append(f"tflops={metrics['tflops']:.2f}")
                if "mfu" in metrics:
                    parts.append(f"mfu={metrics['mfu']:.2f}%")
                if "accuracy" in metrics:
                    parts.append(f"accuracy={metrics['accuracy']:.3f}")
                if "next_token_accuracy" in metrics:
                    parts.append(
                        f"next_token_acc={metrics['next_token_accuracy']:.3f}"
                    )
                if "perplexity" in metrics:
                    parts.append(f"perplexity={metrics['perplexity']:.2f}")
                logger.info(" ".join(parts))

    if pred_file is not None:
        pred_file.close()

    t_total = time.perf_counter() - t_start
    if rank == 0:
        logger.info(
            "Done [%s] — %d samples, %d new tokens in %.1fs (%.1f tok/s aggregate per rank)",
            args.mode,
            total_samples,
            total_new_tokens,
            t_total,
            total_new_tokens / t_total if t_total > 0 else 0.0,
        )
        if args.mode == "eval" and not eval_unlabeled and n_with_label > 0:
            logger.info(
                "Accuracy [rank 0]: %d/%d = %.2f%%",
                n_correct, n_with_label,
                100.0 * n_correct / n_with_label,
            )
        if eval_unlabeled and n_tokens_scored > 0:
            avg_nll = nll_sum / n_tokens_scored
            ppl = float(torch.exp(torch.tensor(avg_nll)).item())
            logger.info(
                "Next-token [rank 0]: accuracy=%d/%d=%.2f%%  perplexity=%.3f  (NLL/token=%.4f)",
                n_token_correct, n_tokens_scored,
                100.0 * n_token_correct / n_tokens_scored,
                ppl, avg_nll,
            )

    if rank == 0:
        history.finalize(
            outdir=outdir,
            run_name=module_name,
            dataset_fname="inference",
            verbose=False,
        )

    ezpz.cleanup()
    return 0

parse_args(argv=None) ⚓︎

Parse inference command-line arguments.

Source code in src/ezpz/examples/inference.py
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse inference command-line arguments."""
    p = argparse.ArgumentParser(
        prog="ezpz.examples.inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=(
            "Distributed inference over a HuggingFace model + dataset. "
            "Three modes: --mode benchmark (throughput), generate "
            "(synthetic data corpus), eval (accuracy on labeled data). "
            "Each rank processes a disjoint shard of prompts."
        ),
    )
    add = p.add_argument

    # Mode selection and model/data sources.
    add(
        "--mode",
        type=str,
        default="generate",
        choices=["benchmark", "generate", "eval"],
        help=(
            "Inference mode: "
            "'benchmark' = synthetic random tokens, no dataset, focus "
            "on tokens/sec/MFU. "
            "'generate' = dataset prompts → completions, save to JSONL "
            "(synthetic data / distillation use case). "
            "'eval' = dataset prompts + gold labels, compare generated "
            "text to label, report accuracy."
        ),
    )
    add(
        "--model",
        type=str,
        default="meta-llama/Llama-3.2-1B",
        help="HuggingFace model name or local path",
    )
    add(
        "--dataset",
        type=str,
        default="wikitext",
        help="HuggingFace dataset name or local path (ignored in --mode benchmark)",
    )
    add(
        "--dataset-config",
        type=str,
        default="wikitext-2-raw-v1",
        help="Dataset configuration (subset name)",
    )
    add(
        "--dataset-split",
        type=str,
        default="test",
        help="Dataset split (train/validation/test)",
    )
    add(
        "--text-column",
        type=str,
        default="text",
        help="Dataset column containing the prompt text",
    )
    add(
        "--label-column",
        type=str,
        default=None,
        help="Dataset column containing the gold label (required for --mode eval)",
    )

    # Benchmark-mode knobs.
    add(
        "--benchmark-iters",
        type=int,
        default=20,
        help="Number of benchmark iterations (only with --mode benchmark)",
    )
    add(
        "--benchmark-warmup",
        type=int,
        default=3,
        help="Warmup iterations to skip when reporting (only with --mode benchmark)",
    )

    # Batch / sequence-length limits.
    add(
        "--max-samples",
        type=int,
        default=128,
        help="Maximum number of samples to process across all ranks",
    )
    add(
        "--batch-size",
        type=int,
        default=4,
        help="Per-rank batch size",
    )
    add(
        "--max-input-tokens",
        type=int,
        default=512,
        help="Truncate prompts to this many tokens",
    )
    add(
        "--max-new-tokens",
        type=int,
        default=64,
        help="Maximum tokens to generate per sample",
    )
    add(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model dtype",
    )

    # FLOPS / MFU measurement.
    add(
        "--flops",
        action="store_true",
        help=(
            "Measure exact per-batch FLOPS via FlopCounterMode and "
            "report tflops + mfu in metrics. Off by default — without "
            "this flag, MFU/TFLOPS columns are simply omitted (rather "
            "than reporting approximated values). Adds ~15-40%% "
            "overhead per step."
        ),
    )
    add(
        "--flops-every-n-steps",
        type=int,
        default=1,
        help=(
            "When --flops is set, measure FLOPS every N steps. "
            "Use a higher value to amortize the overhead across batches."
        ),
    )

    # Decoding strategy.
    add(
        "--do-sample",
        action="store_true",
        help="Use sampling instead of greedy decoding",
    )
    add(
        "--temperature",
        type=float,
        default=1.0,
        help="Sampling temperature (only used with --do-sample)",
    )
    add(
        "--top-p",
        type=float,
        default=1.0,
        help="Nucleus sampling threshold (only used with --do-sample)",
    )
    add(
        "--seed",
        type=int,
        default=0,
        help="Random seed",
    )

    # BooleanOptionalAction generates both --save-predictions and
    # --no-save-predictions automatically, so ArgumentDefaultsHelpFormatter
    # renders "(default: True)" against the paired flags rather than
    # against a lone --no-save-predictions, which previously read as if
    # the *negation* flag defaulted to True.
    add(
        "--save-predictions",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Write per-sample predictions to JSONL",
    )
    return p.parse_args(argv)

shard_indices(total, rank, world_size) ⚓︎

Return the subset of indices [0, total) assigned to rank.

Uses contiguous block sharding — rank r gets indices [r*chunk, (r+1)*chunk) where chunk = ceil(total / world_size). The last rank may receive fewer samples.

Source code in src/ezpz/examples/inference.py
def shard_indices(total: int, rank: int, world_size: int) -> list[int]:
    """Return the subset of indices [0, total) assigned to *rank*.

    Uses contiguous block sharding — rank ``r`` gets indices
    ``[r*chunk, (r+1)*chunk)`` where ``chunk = ceil(total / world_size)``.
    The last rank may receive fewer samples.
    """
    # Degenerate inputs (empty dataset, nonsensical world size) → no work.
    if total <= 0 or world_size <= 0:
        return []
    # Ceil division without importing math: -(-a // b) == ceil(a / b).
    chunk = -(-total // world_size)
    lo = chunk * rank
    hi = min(lo + chunk, total)
    # range() is empty when lo >= hi, so out-of-range ranks get [].
    return list(range(lo, hi))