🔮 ezpz.examples.inference⚓︎
Distributed inference over a HuggingFace model + dataset, with three distinct modes: benchmark, generate, and eval.
Each rank loads the model and processes a disjoint shard of inputs
(data parallelism). Per-batch latency and throughput (tokens/sec)
are tracked through ezpz.History. With --mode eval an accuracy
metric is added. TFLOPS and MFU are opt-in via --flops —
without that flag the columns are simply absent rather than
approximated (see MFU Tracking (opt-in)).
Source⚓︎
src/ezpz/examples/inference.py
"""Distributed inference over a HuggingFace model + dataset.
A general-purpose inference example: shard a HuggingFace dataset across
ranks (data parallel), generate completions with a HuggingFace model on
each rank, and aggregate the results.
Launch it with::
ezpz launch python3 -m ezpz.examples.inference \\
--model meta-llama/Llama-3.2-1B \\
--dataset wikitext --dataset-config wikitext-2-raw-v1 \\
--max-samples 256 --max-new-tokens 32
Each rank loads the model, processes its shard of prompts, and
records per-batch latency, throughput (tokens/sec), MFU, and the
generated text. Outputs go to ``outputs/ezpz.examples.inference/<ts>/``:
``predictions-rank<N>.jsonl`` (one row per sample) and a finalized
``History`` dataset with the timing/perf metrics.
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
from typing import Any, Optional, Sequence
import torch
import ezpz
from ezpz.examples import get_example_outdir
from ezpz.flops import compute_mfu
logger = ezpz.get_logger(__name__)
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
"""Parse inference command-line arguments."""
parser = argparse.ArgumentParser(
prog="ezpz.examples.inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=(
"Distributed inference over a HuggingFace model + dataset. "
"Three modes: --mode benchmark (throughput), generate "
"(synthetic data corpus), eval (accuracy on labeled data). "
"Each rank processes a disjoint shard of prompts."
),
)
parser.add_argument(
"--mode",
type=str,
default="generate",
choices=["benchmark", "generate", "eval"],
help=(
"Inference mode: "
"'benchmark' = synthetic random tokens, no dataset, focus "
"on tokens/sec/MFU. "
"'generate' = dataset prompts → completions, save to JSONL "
"(synthetic data / distillation use case). "
"'eval' = dataset prompts + gold labels, compare generated "
"text to label, report accuracy."
),
)
parser.add_argument(
"--model",
type=str,
default="meta-llama/Llama-3.2-1B",
help="HuggingFace model name or local path",
)
parser.add_argument(
"--dataset",
type=str,
default="wikitext",
help="HuggingFace dataset name or local path (ignored in --mode benchmark)",
)
parser.add_argument(
"--dataset-config",
type=str,
default="wikitext-2-raw-v1",
help="Dataset configuration (subset name)",
)
parser.add_argument(
"--dataset-split",
type=str,
default="test",
help="Dataset split (train/validation/test)",
)
parser.add_argument(
"--text-column",
type=str,
default="text",
help="Dataset column containing the prompt text",
)
parser.add_argument(
"--label-column",
type=str,
default=None,
help="Dataset column containing the gold label (required for --mode eval)",
)
parser.add_argument(
"--benchmark-iters",
type=int,
default=20,
help="Number of benchmark iterations (only with --mode benchmark)",
)
parser.add_argument(
"--benchmark-warmup",
type=int,
default=3,
help="Warmup iterations to skip when reporting (only with --mode benchmark)",
)
parser.add_argument(
"--max-samples",
type=int,
default=128,
help="Maximum number of samples to process across all ranks",
)
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Per-rank batch size",
)
parser.add_argument(
"--max-input-tokens",
type=int,
default=512,
help="Truncate prompts to this many tokens",
)
parser.add_argument(
"--max-new-tokens",
type=int,
default=64,
help="Maximum tokens to generate per sample",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
help="Model dtype",
)
parser.add_argument(
"--flops",
action="store_true",
help=(
"Measure exact per-batch FLOPS via FlopCounterMode and "
"report tflops + mfu in metrics. Off by default — without "
"this flag, MFU/TFLOPS columns are simply omitted (rather "
"than reporting approximated values). Adds ~15-40%% "
"overhead per step."
),
)
parser.add_argument(
"--flops-every-n-steps",
type=int,
default=1,
help=(
"When --flops is set, measure FLOPS every N steps. "
"Use a higher value to amortize the overhead across batches."
),
)
parser.add_argument(
"--do-sample",
action="store_true",
help="Use sampling instead of greedy decoding",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature (only used with --do-sample)",
)
parser.add_argument(
"--top-p",
type=float,
default=1.0,
help="Nucleus sampling threshold (only used with --do-sample)",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Random seed",
)
parser.add_argument(
# BooleanOptionalAction generates both --save-predictions and
# --no-save-predictions automatically. The previous hand-rolled
# store_true/store_false pair rendered as
# --no-save-predictions ... (default: True)
# which read as if the *flag* defaulted to True instead of the
# underlying behavior. BooleanOptionalAction shows
# --save-predictions, --no-save-predictions ... (default: True)
# which is the standard argparse way to express this and reads
# correctly under ArgumentDefaultsHelpFormatter.
"--save-predictions",
action=argparse.BooleanOptionalAction,
default=True,
help="Write per-sample predictions to JSONL",
)
return parser.parse_args(argv)
def shard_indices(total: int, rank: int, world_size: int) -> list[int]:
"""Return the subset of indices [0, total) assigned to *rank*.
Uses contiguous block sharding — rank ``r`` gets indices
``[r*chunk, (r+1)*chunk)`` where ``chunk = ceil(total / world_size)``.
The last rank may receive fewer samples.
"""
if total <= 0 or world_size <= 0:
return []
chunk = (total + world_size - 1) // world_size
start = rank * chunk
end = min(start + chunk, total)
return list(range(start, end))
def _torch_dtype(name: str) -> torch.dtype:
return {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[name]
def _normalize(text: str) -> str:
"""Lowercase + collapse whitespace + strip — used for eval matching."""
if text is None:
return ""
return " ".join(str(text).lower().split())
def _run_with_optional_flops(
fn,
*args,
measure: bool,
**kwargs,
):
"""Run *fn(*args, **kwargs)* under FlopCounterMode if *measure* is True.
Returns ``(result, measured_flops)``. When *measure* is False or
when the counter raises (XPU, etc.), ``measured_flops`` is ``0``.
Pulled out of the hot loop so the eval-unlabeled forward path and
the generate path share one wrapper instead of duplicating the
enter/exit/try plumbing.
"""
if not measure:
return fn(*args, **kwargs), 0
from torch.utils.flop_counter import FlopCounterMode
with FlopCounterMode(display=False) as fc:
result = fn(*args, **kwargs)
try:
return result, int(fc.get_total_flops())
except Exception:
return result, 0
def main(argv: Optional[Sequence[str]] = None) -> int:
"""Entry point for distributed HF inference."""
# Silence noisy per-request HTTP logs from HF Hub clients
import logging as _logging
for _noisy in ("httpx", "huggingface_hub", "filelock", "urllib3"):
_logging.getLogger(_noisy).setLevel(_logging.WARNING)
# Silence noisy transformers messages (e.g. the BPE
# clean_up_tokenization_spaces warning fired on every decode call)
try:
import transformers as _transformers
_transformers.logging.set_verbosity_error()
_transformers.logging.disable_progress_bar()
except Exception:
pass
args = parse_args(argv)
# ── Distributed setup ──────────────────────────────────────────
rank = ezpz.setup_torch(seed=args.seed)
world_size = ezpz.get_world_size()
device = ezpz.get_torch_device(as_torch_device=True)
dtype = _torch_dtype(args.dtype)
# ── Model + tokenizer ──────────────────────────────────────────
# Load on rank 0 first to populate the HF cache, then let other
# ranks read from cache. Avoids thundering-herd downloads on first
# run from a clean cache.
if rank == 0:
logger.info("Loading model: %s", args.model)
from transformers import AutoModelForCausalLM, AutoTokenizer
def _load_tokenizer_and_model():
tok = AutoTokenizer.from_pretrained(
args.model,
clean_up_tokenization_spaces=False,
)
# Resolve a usable pad id: prefer existing pad → reuse eos →
# add a new <pad> token (last-resort, requires resizing the
# model embedding table to match — done below after model load).
added_pad_token = False
if tok.pad_token_id is None:
if tok.eos_token_id is not None:
tok.pad_token = tok.eos_token
else:
tok.add_special_tokens({"pad_token": "<pad>"})
added_pad_token = True
# Left-pad so generation continues from the end of each prompt
tok.padding_side = "left"
m = AutoModelForCausalLM.from_pretrained(
args.model,
torch_dtype=dtype,
).to(device)
# If we extended the vocab, resize the embedding table so the
# new pad id doesn't index out of bounds.
if added_pad_token:
m.resize_token_embeddings(len(tok))
m.eval()
return tok, m
if rank == 0:
tokenizer, model = _load_tokenizer_and_model()
if world_size > 1:
ezpz.synchronize()
else:
if world_size > 1:
ezpz.synchronize()
tokenizer, model = _load_tokenizer_and_model()
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
if rank == 0:
n_params = sum(p.numel() for p in model.parameters())
logger.info("Model loaded: %.2fB parameters", n_params / 1e9)
# FLOPS / MFU tracking is opt-in via --flops. Without that flag
# the metrics are simply omitted — better than reporting
# approximated numbers that can be misleading (linear scaling
# under-counts attention's O(seq^2) cost; fixed-shape startup
# estimates over-count short batches). When --flops is set, we
# measure exact per-batch FLOPS via FlopCounterMode every
# --flops-every-n-steps batches.
# ── Prompts (and optional labels) ──────────────────────────────
# In benchmark mode we synthesize random tokens — no dataset load.
# In generate / eval mode we pull from a HuggingFace dataset.
# eval_unlabeled = True means we'll score next-token prediction
# directly (no labels, no generation).
prompts: list[str]
labels: list[Optional[str]]
eval_unlabeled = args.mode == "eval" and not args.label_column
if args.mode == "benchmark":
if rank == 0:
logger.info(
"Benchmark mode: %d iterations × batch_size=%d at "
"max_input_tokens=%d, max_new_tokens=%d",
args.benchmark_iters, args.batch_size,
args.max_input_tokens, args.max_new_tokens,
)
# Synthesize random token IDs as "prompts" — we'll skip the
# tokenizer in the loop and feed input_ids directly.
prompts = [""] * args.benchmark_iters * args.batch_size
labels = [None] * len(prompts)
else:
if rank == 0 and eval_unlabeled:
logger.info(
"--mode eval without --label-column: scoring next-token "
"prediction (accuracy + perplexity) on the prompt itself"
)
if rank == 0:
logger.info(
"Loading dataset: %s [%s] split=%s",
args.dataset, args.dataset_config, args.dataset_split,
)
from datasets import load_dataset
def _load_ds():
return load_dataset(
args.dataset, args.dataset_config, split=args.dataset_split,
)
if rank == 0:
ds = _load_ds()
if world_size > 1:
ezpz.synchronize()
else:
if world_size > 1:
ezpz.synchronize()
ds = _load_ds()
prompts = []
labels = []
for row in ds:
if not isinstance(row, dict):
continue
text = row.get(args.text_column, "")
if not (isinstance(text, str) and text.strip()):
continue
prompts.append(text.strip())
if args.label_column:
lbl = row.get(args.label_column)
labels.append(str(lbl) if lbl is not None else None)
else:
labels.append(None)
if len(prompts) >= args.max_samples:
break
if not prompts:
if rank == 0:
logger.error(
"No prompts found in column %r of %s/%s split=%s — aborting.",
args.text_column, args.dataset, args.dataset_config,
args.dataset_split,
)
ezpz.cleanup()
return 1
my_indices = shard_indices(len(prompts), rank, world_size)
my_prompts = [prompts[i] for i in my_indices]
my_labels = [labels[i] for i in my_indices]
if rank == 0:
logger.info(
"Total samples: %d → %d per rank (rank 0 has %d)",
len(prompts),
(len(prompts) + world_size - 1) // world_size,
len(my_prompts),
)
# ── Output paths ───────────────────────────────────────────────
module_name = "ezpz.examples.inference"
outdir = get_example_outdir(module_name)
if rank == 0:
logger.info("Outputs will be saved to %s", outdir)
history = ezpz.History(
project_name=module_name,
config={
"model": args.model,
"dataset": f"{args.dataset}/{args.dataset_config}",
"world_size": world_size,
"batch_size": args.batch_size,
"max_new_tokens": args.max_new_tokens,
"dtype": args.dtype,
},
outdir=outdir,
report_dir=outdir,
report_enabled=True,
distributed_history=(1 < world_size <= 384),
)
# Predictions file: only useful for generate / eval. Disable in
# benchmark mode (the "completions" are random-token noise).
pred_file = None
if args.save_predictions and args.mode != "benchmark":
pred_path = Path(outdir) / f"predictions-rank{rank}.jsonl"
pred_path.parent.mkdir(parents=True, exist_ok=True)
pred_file = pred_path.open("w", encoding="utf-8")
# ── Inference loop ─────────────────────────────────────────────
gen_kwargs: dict[str, Any] = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
"pad_token_id": pad_id,
}
if args.do_sample:
gen_kwargs["temperature"] = args.temperature
gen_kwargs["top_p"] = args.top_p
total_samples = 0
total_new_tokens = 0
n_correct = 0 # only meaningful in --mode eval (labeled)
n_with_label = 0
# For unlabeled --mode eval (next-token prediction):
n_token_correct = 0
n_tokens_scored = 0
nll_sum = 0.0 # sum of -log p(next token); perplexity = exp(nll/n)
benchmark_warmup = (
args.benchmark_warmup if args.mode == "benchmark" else 0
)
t_start = time.perf_counter()
with torch.inference_mode():
for batch_idx in range(0, len(my_prompts), args.batch_size):
batch_prompts = my_prompts[batch_idx : batch_idx + args.batch_size]
batch_labels = my_labels[batch_idx : batch_idx + args.batch_size]
if not batch_prompts:
continue
ezpz.synchronize()
t0 = time.perf_counter()
if args.mode == "benchmark":
# Synthesize random input_ids — skip tokenizer entirely
input_ids = torch.randint(
0,
int(getattr(tokenizer, "vocab_size", 32000)),
(len(batch_prompts), args.max_input_tokens),
device=device,
dtype=torch.long,
)
attention_mask = torch.ones_like(input_ids)
else:
enc = tokenizer(
batch_prompts,
padding=True,
truncation=True,
max_length=args.max_input_tokens,
return_tensors="pt",
).to(device)
input_ids = enc.input_ids
attention_mask = enc.attention_mask
# Unlabeled eval path: forward pass + score next-token
# prediction. No autoregressive generation.
decoded: list[str] = []
batch_token_correct = 0
batch_token_scored = 0
batch_nll_sum = 0.0
n_new = 0
measured_flops = 0 # only set when this step runs FlopCounterMode
batch_num = batch_idx // args.batch_size
measure_flops_now = (
args.flops
and args.flops_every_n_steps > 0
and batch_num % args.flops_every_n_steps == 0
)
if eval_unlabeled:
outputs, measured_flops = _run_with_optional_flops(
model,
input_ids=input_ids,
attention_mask=attention_mask,
measure=measure_flops_now,
)
logits = outputs.logits # [B, T, V]
# Predict token i+1 from logits at position i
shift_logits = logits[:, :-1, :].contiguous()
shift_targets = input_ids[:, 1:].contiguous()
# Mask padded positions out of the score
shift_mask = attention_mask[:, 1:].to(torch.bool)
preds = shift_logits.argmax(dim=-1)
correct = (preds == shift_targets) & shift_mask
batch_token_correct = int(correct.sum().item())
batch_token_scored = int(shift_mask.sum().item())
# Cross-entropy in float32 for numerical stability
ce = torch.nn.functional.cross_entropy(
shift_logits.float().reshape(-1, shift_logits.size(-1)),
shift_targets.reshape(-1),
reduction="none",
).reshape(shift_targets.shape)
batch_nll_sum = float((ce * shift_mask).sum().item())
ezpz.synchronize()
dt = time.perf_counter() - t0
else:
output_ids, measured_flops = _run_with_optional_flops(
model.generate,
input_ids=input_ids,
attention_mask=attention_mask,
measure=measure_flops_now,
**gen_kwargs,
)
ezpz.synchronize()
dt = time.perf_counter() - t0
# Slice off the prompt portion so we report only new tokens
input_len = input_ids.shape[1]
new_token_ids = output_ids[:, input_len:]
n_new = int((new_token_ids != pad_id).sum().item())
# Decode (skip in benchmark mode — random tokens aren't useful)
if args.mode != "benchmark":
decoded = tokenizer.batch_decode(
new_token_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
n_in = int(attention_mask.sum().item())
is_warmup = batch_num < benchmark_warmup
# Labeled-eval string match (only when we have labels)
batch_correct = 0
batch_with_label = 0
if args.mode == "eval" and not eval_unlabeled:
for completion, label in zip(decoded, batch_labels):
if label is None:
continue
batch_with_label += 1
if _normalize(completion) == _normalize(label) or (
_normalize(label) and _normalize(label) in _normalize(completion)
):
batch_correct += 1
n_correct += batch_correct
n_with_label += batch_with_label
if eval_unlabeled:
n_token_correct += batch_token_correct
n_tokens_scored += batch_token_scored
nll_sum += batch_nll_sum
if not is_warmup:
total_samples += len(batch_prompts)
total_new_tokens += n_new
# In unlabeled-eval mode there are no generated tokens —
# we ran a single forward pass and scored prompt tokens.
# Report tokens/sec and MFU based on tokens *scored*
# (one forward pass per token), not generated.
if eval_unlabeled:
tokens_for_throughput = batch_token_scored
else:
tokens_for_throughput = n_new
metrics: dict[str, Any] = {
"batch": batch_num,
"samples": len(batch_prompts),
"input_tokens": n_in,
"dt": dt,
"tokens_per_sec": (
tokens_for_throughput / dt if dt > 0 else 0.0
),
"warmup": is_warmup,
}
# In unlabeled-eval mode there are no generated tokens —
# only scored tokens. Reporting "new_tokens=0" on every
# row would be noise in the JSONL/CSV; emit one or the
# other based on the mode.
if eval_unlabeled:
metrics["tokens_scored"] = batch_token_scored
else:
metrics["new_tokens"] = n_new
# FLOPS / MFU only when --flops was set AND this step
# actually ran FlopCounterMode. Skip on the in-between
# batches when --flops-every-n-steps > 1. No approximation
# is reported — the metrics are simply absent on unmeasured
# steps (you can filter on flops_measured == True downstream).
if measured_flops > 0 and dt > 0:
metrics["tflops"] = measured_flops / dt / 1e12
metrics["mfu"] = compute_mfu(measured_flops, dt)
metrics["flops_measured"] = True
if args.mode == "eval" and batch_with_label > 0:
metrics["accuracy"] = batch_correct / batch_with_label
if eval_unlabeled and batch_token_scored > 0:
metrics["next_token_accuracy"] = (
batch_token_correct / batch_token_scored
)
metrics["perplexity"] = float(
torch.exp(
torch.tensor(batch_nll_sum / batch_token_scored)
).item()
)
if not is_warmup:
history.update(metrics)
if pred_file is not None:
for i, (prompt, completion) in enumerate(zip(batch_prompts, decoded)):
row = {
"rank": rank,
"prompt": prompt,
"completion": completion,
}
if args.mode == "eval":
row["label"] = batch_labels[i]
pred_file.write(json.dumps(row) + "\n")
if rank == 0:
tag = " [warmup]" if is_warmup else ""
parts = [
f"batch={batch_num}{tag}",
f"samples={metrics['samples']}",
]
# In unlabeled-eval mode there are no generated
# tokens — show what was *scored* instead.
if "tokens_scored" in metrics:
parts.append(f"tokens_scored={metrics['tokens_scored']}")
else:
parts.append(f"new_tokens={metrics['new_tokens']}")
parts.extend([
f"dt={dt:.3f}s",
f"tps={metrics['tokens_per_sec']:.1f}",
])
if "tflops" in metrics:
parts.append(f"tflops={metrics['tflops']:.2f}")
if "mfu" in metrics:
parts.append(f"mfu={metrics['mfu']:.2f}%")
if "accuracy" in metrics:
parts.append(f"accuracy={metrics['accuracy']:.3f}")
if "next_token_accuracy" in metrics:
parts.append(
f"next_token_acc={metrics['next_token_accuracy']:.3f}"
)
if "perplexity" in metrics:
parts.append(f"perplexity={metrics['perplexity']:.2f}")
logger.info(" ".join(parts))
if pred_file is not None:
pred_file.close()
t_total = time.perf_counter() - t_start
if rank == 0:
logger.info(
"Done [%s] — %d samples, %d new tokens in %.1fs (%.1f tok/s aggregate per rank)",
args.mode,
total_samples,
total_new_tokens,
t_total,
total_new_tokens / t_total if t_total > 0 else 0.0,
)
if args.mode == "eval" and not eval_unlabeled and n_with_label > 0:
logger.info(
"Accuracy [rank 0]: %d/%d = %.2f%%",
n_correct, n_with_label,
100.0 * n_correct / n_with_label,
)
if eval_unlabeled and n_tokens_scored > 0:
avg_nll = nll_sum / n_tokens_scored
ppl = float(torch.exp(torch.tensor(avg_nll)).item())
logger.info(
"Next-token [rank 0]: accuracy=%d/%d=%.2f%% perplexity=%.3f (NLL/token=%.4f)",
n_token_correct, n_tokens_scored,
100.0 * n_token_correct / n_tokens_scored,
ppl, avg_nll,
)
if rank == 0:
history.finalize(
outdir=outdir,
run_name=module_name,
dataset_fname="inference",
verbose=False,
)
ezpz.cleanup()
return 0
if __name__ == "__main__":
raise SystemExit(main())
Code Walkthrough⚓︎
shard_indices — data-parallel input partitioning
Each rank computes its slice of the input array; the helper handles uneven splits and the "more ranks than samples" edge case.
def shard_indices(total: int, rank: int, world_size: int) -> list[int]:
"""Return the subset of indices [0, total) assigned to *rank*.
Uses contiguous block sharding — rank ``r`` gets indices
``[r*chunk, (r+1)*chunk)`` where ``chunk = ceil(total / world_size)``.
The last rank may receive fewer samples.
"""
if total <= 0 or world_size <= 0:
return []
chunk = (total + world_size - 1) // world_size
start = rank * chunk
end = min(start + chunk, total)
return list(range(start, end))
_run_with_optional_flops — opt-in FLOPS measurement
The eval-unlabeled forward path and the model.generate path share
this wrapper. When measure=False (the default — no --flops),
the call is a passthrough; otherwise it runs inside FlopCounterMode
and returns the measured FLOP count alongside the result.
return " ".join(str(text).lower().split())
def _run_with_optional_flops(
fn,
*args,
measure: bool,
**kwargs,
):
"""Run *fn(*args, **kwargs)* under FlopCounterMode if *measure* is True.
Returns ``(result, measured_flops)``. When *measure* is False or
when the counter raises (XPU, etc.), ``measured_flops`` is ``0``.
Pulled out of the hot loop so the eval-unlabeled forward path and
the generate path share one wrapper instead of duplicating the
enter/exit/try plumbing.
"""
if not measure:
return fn(*args, **kwargs), 0
from torch.utils.flop_counter import FlopCounterMode
with FlopCounterMode(display=False) as fc:
result = fn(*args, **kwargs)
try:
return result, int(fc.get_total_flops())
except Exception:
return result, 0
Per-batch metrics — measured vs absent
tflops and mfu are only added to the metrics dict when the step
actually ran FlopCounterMode. No approximation is reported on
unmeasured batches; downstream filters can rely on
flops_measured == True to identify trustworthy points.
The unlabeled-eval path emits tokens_scored (one forward pass over
the prompt batch); the generate path emits new_tokens (one forward
per autoregressive token). Never both.
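A minimal post-processing sketch (the example rows are made up and the loading step is omitted; only the `flops_measured` key comes from the metrics described above):

```python
# Hypothetical post-processing: average MFU only over batches where FLOPS
# were actually measured. Load the real per-batch metrics from the finalized
# History outputs; the two rows below are illustrative placeholders.
rows: list[dict] = [
    {"batch": 0, "dt": 0.81, "tokens_per_sec": 410.2},  # unmeasured step: no tflops/mfu keys
    {"batch": 1, "dt": 0.97, "tflops": 41.3, "mfu": 4.2, "flops_measured": True},
]
measured = [r for r in rows if r.get("flops_measured")]
if measured:
    avg_mfu = sum(r["mfu"] for r in measured) / len(measured)
    print(f"mean MFU over {len(measured)} measured batches: {avg_mfu:.2f}%")
```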
Eval-unlabeled scoring — perplexity from logits
Without --label-column, eval mode runs a single forward pass per
batch and scores next-token prediction at every position. Argmax
accuracy and per-token cross-entropy go into the metrics; perplexity
= exp(NLL/token) is reported per batch and overall.
This is much cheaper than autoregressive generation and matches the "language modeling perplexity" you'd get from eval-harness tools.
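Condensed, the scoring step is a handful of tensor ops. This helper restates the code above with illustrative names:

```python
import torch
import torch.nn.functional as F

def score_next_token(logits, input_ids, attention_mask):
    """Next-token accuracy and perplexity from one forward pass (logits: [B, T, V])."""
    shift_logits = logits[:, :-1, :]            # position i predicts token i+1
    shift_targets = input_ids[:, 1:]
    shift_mask = attention_mask[:, 1:].bool()   # padded positions don't count
    preds = shift_logits.argmax(dim=-1)
    n_correct = int(((preds == shift_targets) & shift_mask).sum())
    n_scored = int(shift_mask.sum())
    ce = F.cross_entropy(                       # per-token NLL, float32 for stability
        shift_logits.float().reshape(-1, shift_logits.size(-1)),
        shift_targets.reshape(-1),
        reduction="none",
    ).reshape(shift_targets.shape)
    nll_per_token = float((ce * shift_mask).sum()) / max(n_scored, 1)
    perplexity = float(torch.exp(torch.tensor(nll_per_token)))
    return n_correct / max(n_scored, 1), perplexity
```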
Modes⚓︎
# Throughput benchmark — synthetic random tokens, no dataset
ezpz launch python3 -m ezpz.examples.inference --mode benchmark
# Synthetic data generation — dataset prompts → completions JSONL
ezpz launch python3 -m ezpz.examples.inference --mode generate
# Evaluation — generate, then score against gold labels
ezpz launch python3 -m ezpz.examples.inference --mode eval \
--dataset gsm8k --dataset-config main --dataset-split test \
--text-column question --label-column answer
| Mode | Source | Saves predictions? | Reports accuracy? |
|---|---|---|---|
| `benchmark` | random tokens | no | no |
| `generate` (default) | dataset prompts | yes | no |
| `eval` | dataset prompts + labels | yes (with label) | yes |
--mode benchmark⚓︎
Pure throughput measurement. Skips the tokenizer entirely and feeds
random input_ids of shape (batch_size, max_input_tokens) into
the model. Configurable via:
- `--benchmark-iters` (default `20`) — number of forward passes
- `--benchmark-warmup` (default `3`) — warmup iters excluded from totals
Use this for hardware/configuration comparisons (different
--batch-size, --dtype, --world-size, --max-input-tokens).
The reported tokens/sec and MFU are uncontaminated by tokenizer
overhead or dataset variance.
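A sketch of what one benchmark step feeds the model (shapes match the defaults above; the vocab size is the fallback value used in the source, and `model`/`device` are assumed to exist):

```python
import torch

batch_size, max_input_tokens, vocab_size = 4, 512, 32000  # defaults; 32000 is the source's fallback vocab size
input_ids = torch.randint(0, vocab_size, (batch_size, max_input_tokens), dtype=torch.long)
attention_mask = torch.ones_like(input_ids)  # every position counts as "real"
# One step would then be timed around a generate call, e.g.:
# output_ids = model.generate(input_ids=input_ids.to(device),
#                             attention_mask=attention_mask.to(device),
#                             max_new_tokens=64)
```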
--mode generate⚓︎
Reads prompts from a HuggingFace dataset, generates completions, and
writes them to predictions-rank<N>.jsonl. Useful for:
- Synthetic data generation for downstream training
- Distillation — generate teacher completions to train a student
- Spot-checking model behavior on a known corpus
--mode eval (with --label-column)⚓︎
Same as generate, but also extracts a gold label from
--label-column and compares the completion. The match rule is
"normalized exact-match OR substring": both completion and label
are lowercased, whitespace-collapsed, then checked. The label
counts as correct if it appears as a substring of the normalized
completion (handles "the answer is 42" vs gold "42").
Reports accuracy = correct / labeled per batch (in History)
and overall (in the final log line).
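The rule reduces to a few lines; this restates it outside the script (function names here are illustrative):

```python
def normalize(text) -> str:
    # lowercase + collapse whitespace + strip, as in _normalize() above
    return " ".join(str(text).lower().split()) if text is not None else ""

def is_correct(completion: str, label: str) -> bool:
    c, l = normalize(completion), normalize(label)
    return c == l or (bool(l) and l in c)

assert is_correct("The answer is  42.", "42")   # substring match after normalization
assert not is_correct("no idea", "42")
```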
--mode eval (without --label-column)⚓︎
If you omit --label-column, eval falls back to next-token
prediction scoring — useful on any text dataset, no labels needed:
ezpz launch python3 -m ezpz.examples.inference --mode eval \
--dataset wikitext --dataset-config wikitext-2-raw-v1
This path:
- Runs one forward pass over the full prompt batch (not generation)
- Computes argmax accuracy at each position (token `i+1` predicted from logits at position `i`)
- Computes per-token cross-entropy → perplexity = `exp(NLL/token)`
- Reports `next_token_accuracy` and `perplexity` per batch + overall
Much faster than the labeled path (one forward pass vs autoregressive generation), and gives the standard "language modeling perplexity" metric you'd see in eval-harness tools — without needing a labeled dataset.
The per-batch log line shows tokens_scored= instead of new_tokens=
to make the difference obvious.
How sharding works⚓︎
Inputs are loaded once on every rank, then each rank processes a
contiguous block via shard_indices:
total = 10 samples, world_size = 4
rank 0: [0, 1, 2]
rank 1: [3, 4, 5]
rank 2: [6, 7, 8]
rank 3: [9]
This is data parallelism — each rank holds the full model. There is no model-parallel sharding, so the model must fit in a single device's memory. For larger models, use a model-parallel inference framework (vLLM, DeepSpeed-Inference) instead.
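The layout above can be reproduced directly with the helper:

```python
from ezpz.examples.inference import shard_indices

total, world_size = 10, 4
for rank in range(world_size):
    print(rank, shard_indices(total, rank, world_size))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9]
```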
CLI options⚓︎
| Flag | Default | Description |
|---|---|---|
| `--mode` | `generate` | `benchmark` / `generate` / `eval` |
| `--model` | `meta-llama/Llama-3.2-1B` | HF model name or local path |
| `--dataset` | `wikitext` | HF dataset (ignored in benchmark mode) |
| `--dataset-config` | `wikitext-2-raw-v1` | Dataset configuration / subset |
| `--dataset-split` | `test` | Split (train/validation/test) |
| `--text-column` | `text` | Column containing the prompt |
| `--label-column` | — | Gold label column (optional for `--mode eval`; without it, scores next-token prediction) |
| `--max-samples` | `128` | Total samples across all ranks |
| `--batch-size` | `4` | Per-rank batch size |
| `--max-input-tokens` | `512` | Truncate prompts to this many tokens |
| `--max-new-tokens` | `64` | Tokens to generate per sample |
| `--dtype` | `bfloat16` | Model dtype (float32/float16/bfloat16) |
| `--do-sample` | off | Sample instead of greedy decoding |
| `--temperature` | `1.0` | Sampling temperature (with `--do-sample`) |
| `--top-p` | `1.0` | Nucleus sampling cutoff (with `--do-sample`) |
| `--benchmark-iters` | `20` | Iterations (only `--mode benchmark`) |
| `--benchmark-warmup` | `3` | Warmup iters excluded from totals |
| `--seed` | `0` | Random seed (sampling and synthetic benchmark tokens) |
| `--flops` | off | Measure real per-batch FLOPS via `FlopCounterMode`. Without this flag, `tflops` and `mfu` are not reported (rather than reporting approximated values). |
| `--flops-every-n-steps` | `1` | When `--flops` is set, measure every N steps to amortize the overhead |
| `--no-save-predictions` | save on | Skip writing per-sample JSONL |
Outputs⚓︎
outputs/ezpz.examples.inference/<timestamp>/
├── predictions-rank0.jsonl # not written in benchmark mode
├── predictions-rank1.jsonl
├── ...
├── inference.h5 # finalized History dataset
├── report-inference.md # markdown summary
└── plots/ # auto-generated metric plots
Each predictions-rank<N>.jsonl row:
// generate mode
{"rank": 0, "prompt": "...", "completion": "..."}
// eval mode
{"rank": 0, "prompt": "...", "completion": "...", "label": "42"}
MFU Tracking (opt-in)⚓︎
tflops and mfu are off by default because the only honest
measurement is via FlopCounterMode, which adds ~15-40% per-step
overhead. Approximated values (linear-scaled startup estimates,
n_tokens × forward_flops, etc.) tend to be misleading — they ignore
attention's O(seq²) cost and KV-cache savings, and have produced
MFU values >100% in practice.
To opt in:
# Measure real FLOPS on every batch
ezpz launch python3 -m ezpz.examples.inference --mode eval --flops
# Or amortize the overhead — measure every 10th batch
ezpz launch python3 -m ezpz.examples.inference \
--mode eval --flops --flops-every-n-steps 10
Behavior:
- Without `--flops`: `tflops` and `mfu` keys are absent from `metrics`. Per-batch logs show timing/throughput/eval metrics only.
- With `--flops` and `--flops-every-n-steps 1` (default): every batch is profiled. ~15-40% slower overall, exact MFU on every step.
- With `--flops` and `--flops-every-n-steps N` (N > 1): every Nth batch is profiled, others run normally. Profiled batches get `metrics["flops_measured"] = True` so post-analysis can filter to the trustworthy points.
See ezpz.flops for the
per-device MFU formula.
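What `--flops` does per measured batch amounts to the following sketch. The `peak_tflops` argument must come from your device's datasheet for the dtype in use; in the real code that lookup is delegated to `ezpz.flops.compute_mfu`.

```python
import time
import torch
from torch.utils.flop_counter import FlopCounterMode

def measured_mfu(model, batch: dict, peak_tflops: float) -> tuple[float, float]:
    """Return (achieved TFLOPS, MFU %) for one forward pass of *batch*."""
    t0 = time.perf_counter()
    with torch.inference_mode():
        with FlopCounterMode(display=False) as fc:
            model(**batch)
    dt = time.perf_counter() - t0
    flops = fc.get_total_flops()        # exact FLOPs counted for this batch
    tflops = flops / dt / 1e12          # achieved TFLOPS
    mfu = 100.0 * tflops / peak_tflops  # percent of device peak for this dtype
    return tflops, mfu
```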
Help⚓︎
--help
$ ezpz launch python3 -m ezpz.examples.inference --help
usage: ezpz.examples.inference [-h] [--mode {benchmark,generate,eval}]
[--model MODEL] [--dataset DATASET]
[--dataset-config DATASET_CONFIG]
[--dataset-split DATASET_SPLIT]
[--text-column TEXT_COLUMN]
[--label-column LABEL_COLUMN]
[--benchmark-iters BENCHMARK_ITERS]
[--benchmark-warmup BENCHMARK_WARMUP]
[--max-samples MAX_SAMPLES]
[--batch-size BATCH_SIZE]
[--max-input-tokens MAX_INPUT_TOKENS]
[--max-new-tokens MAX_NEW_TOKENS]
[--dtype {float32,float16,bfloat16}] [--flops]
[--flops-every-n-steps FLOPS_EVERY_N_STEPS]
[--do-sample] [--temperature TEMPERATURE]
[--top-p TOP_P] [--seed SEED]
[--save-predictions | --no-save-predictions]
Distributed inference over a HuggingFace model + dataset. Three modes: --mode
benchmark (throughput), generate (synthetic data corpus), eval (accuracy on
labeled data). Each rank processes a disjoint shard of prompts.
options:
-h, --help show this help message and exit
--mode {benchmark,generate,eval}
Inference mode: 'benchmark' = synthetic random tokens,
no dataset, focus on tokens/sec/MFU. 'generate' =
dataset prompts → completions, save to JSONL
(synthetic data / distillation use case). 'eval' =
dataset prompts + gold labels, compare generated text
to label, report accuracy. (default: generate)
--model MODEL HuggingFace model name or local path (default: meta-
llama/Llama-3.2-1B)
--dataset DATASET HuggingFace dataset name or local path (ignored in
--mode benchmark) (default: wikitext)
--dataset-config DATASET_CONFIG
Dataset configuration (subset name) (default:
wikitext-2-raw-v1)
--dataset-split DATASET_SPLIT
Dataset split (train/validation/test) (default: test)
--text-column TEXT_COLUMN
Dataset column containing the prompt text (default:
text)
--label-column LABEL_COLUMN
Dataset column containing the gold label (optional for
--mode eval) (default: None)
--benchmark-iters BENCHMARK_ITERS
Number of benchmark iterations (only with --mode
benchmark) (default: 20)
--benchmark-warmup BENCHMARK_WARMUP
Warmup iterations to skip when reporting (only with
--mode benchmark) (default: 3)
--max-samples MAX_SAMPLES
Maximum number of samples to process across all ranks
(default: 128)
--batch-size BATCH_SIZE
Per-rank batch size (default: 4)
--max-input-tokens MAX_INPUT_TOKENS
Truncate prompts to this many tokens (default: 512)
--max-new-tokens MAX_NEW_TOKENS
Maximum tokens to generate per sample (default: 64)
--dtype {float32,float16,bfloat16}
Model dtype (default: bfloat16)
--flops Measure exact per-batch FLOPS via FlopCounterMode and
report tflops + mfu in metrics. Off by default —
without this flag, MFU/TFLOPS columns are simply
omitted (rather than reporting approximated values).
Adds ~15-40% overhead per step. (default: False)
--flops-every-n-steps FLOPS_EVERY_N_STEPS
When --flops is set, measure FLOPS every N steps. Use
a higher value to amortize the overhead across
batches. (default: 1)
--do-sample Use sampling instead of greedy decoding (default:
False)
--temperature TEMPERATURE
Sampling temperature (only used with --do-sample)
(default: 1.0)
--top-p TOP_P Nucleus sampling threshold (only used with --do-
sample) (default: 1.0)
--seed SEED Random seed (default: 0)
--save-predictions, --no-save-predictions
Write per-sample predictions to JSONL (default: True)
See Also⚓︎
- `ezpz.examples.hf` — fine-tuning loop (the training counterpart to this inference example)
- Recipes › MFU Tracking