ezpz.flops⚓︎
- See ezpz/
flops.py
FLOPS estimation and Model FLOPS Utilization (MFU) calculation.
Provides utilities for measuring how efficiently a model uses the available hardware compute:
- `get_peak_flops` — peak BF16 FLOPS for common accelerators
- `estimate_model_flops` — count FLOPS for one forward+backward pass
- `compute_mfu` — calculate MFU% from model FLOPS and step timing
Usage::
import ezpz
from ezpz.flops import estimate_model_flops, compute_mfu
rank = ezpz.setup_torch()
model = MyModel().to(ezpz.get_torch_device())
# Count FLOPS once before training
model_flops = estimate_model_flops(model, input_shape=(batch_size, seq_len))
for step in range(num_steps):
t0 = time.perf_counter()
loss = train_step(model, batch)
ezpz.synchronize()
dt = time.perf_counter() - t0
mfu = compute_mfu(model_flops, dt)
history.update({"loss": loss, "mfu": mfu}, step=step)
compute_mfu(model_flops, step_duration, *, device_name=None, peak_flops=None)
⚓︎
Calculate per-device Model FLOPS Utilization (MFU) as a percentage.
MFU measures what fraction of a single device's theoretical peak compute is used by the model's actual operations::
MFU = model_flops / (peak_flops_per_device × step_duration)
Both model_flops and peak_flops are per-device quantities,
so world_size is not needed. In data-parallel training
(DDP/FSDP), every device performs the same forward+backward work.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model_flops` | `int \| float` | FLOPS per forward+backward pass (from `estimate_model_flops`). | required |
| `step_duration` | `float` | Wall-clock time for one training step (seconds). | required |
| `device_name` | `str \| None` | Device name for peak FLOPS lookup. Auto-detected if `None`. | `None` |
| `peak_flops` | `float \| None` | Override the peak FLOPS value directly (skips device lookup). | `None` |
Returns:
| Type | Description |
|---|---|
| `float` | MFU as a percentage (0–100). Returns 0.0 if inputs are invalid. |
Source code in src/ezpz/flops.py
def compute_mfu(
model_flops: int | float,
step_duration: float,
*,
device_name: str | None = None,
peak_flops: float | None = None,
) -> float:
"""Calculate per-device Model FLOPS Utilization (MFU) as a percentage.
MFU measures what fraction of a single device's theoretical peak
compute is used by the model's actual operations::
MFU = model_flops / (peak_flops_per_device × step_duration)
Both ``model_flops`` and ``peak_flops`` are per-device quantities,
so ``world_size`` is not needed. In data-parallel training
(DDP/FSDP), every device performs the same forward+backward work.
Args:
model_flops: FLOPS per forward+backward pass (from
:func:`estimate_model_flops`). This is the per-device
workload — the same value on every rank.
step_duration: Wall-clock time for one training step (seconds).
device_name: Device name for peak FLOPS lookup. Auto-detected
if ``None``.
peak_flops: Override the peak FLOPS value directly (skips
device lookup).
Returns:
MFU as a percentage (0–100). Returns 0.0 if inputs are invalid.
"""
if step_duration <= 0 or model_flops <= 0:
return 0.0
if peak_flops is None:
peak_flops = get_peak_flops(device_name)
if peak_flops is None:
return 0.0
if peak_flops <= 0:
return 0.0
return (model_flops / step_duration / peak_flops) * 100.0
estimate_model_flops(model, input_shape, *, device=None, backward=True)
⚓︎
Count FLOPS for one forward (+ optional backward) pass.
Uses PyTorch's built-in FlopCounterMode to count actual
floating-point operations, not parameter-based estimates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The model to profile. | required |
| `input_shape` | `tuple[int, ...] \| list[int]` | Shape of the input tensor (e.g. `(batch, seq_len)` or `(batch, channels, height, width)`). | required |
| `device` | `device \| str \| None` | Device for the dummy input. Auto-detected if `None`. | `None` |
| `backward` | `bool` | If `True`, also count the backward pass FLOPS. | `True` |
Returns:
| Type | Description |
|---|---|
| `int` | Total FLOPS (forward + backward if requested). |
Source code in src/ezpz/flops.py
def estimate_model_flops(
    model: torch.nn.Module,
    input_shape: tuple[int, ...] | list[int],
    *,
    device: torch.device | str | None = None,
    backward: bool = True,
) -> int:
    """Count FLOPS for one forward (+ optional backward) pass.

    Uses PyTorch's built-in ``FlopCounterMode`` to count actual
    floating-point operations, not parameter-based estimates.

    Args:
        model: The model to profile.
        input_shape: Shape of the input tensor (e.g. ``(batch, seq_len)``
            or ``(batch, channels, height, width)``).
        device: Device for the dummy input. Auto-detected if ``None``.
        backward: If ``True``, also count the backward pass FLOPS
            (typically ~2× forward for most architectures).

    Returns:
        Total FLOPS (forward + backward if requested).
    """
    if device is None:
        # Place the dummy input wherever the model's weights live;
        # a parameter-free model falls back to CPU.
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = "cpu"
    # Detect if model expects integer inputs (has an embedding layer)
    # vs float inputs (CNN, MLP, etc.)
    has_embedding = any(
        isinstance(m, (torch.nn.Embedding, torch.nn.EmbeddingBag))
        for m in model.modules()
    )
    if has_embedding:
        # Language model: expects Long token IDs.
        # Find vocab size from the first embedding layer so the random
        # token IDs stay in range.
        vocab_size = 32000  # fallback
        for m in model.modules():
            if isinstance(m, torch.nn.Embedding):
                vocab_size = m.num_embeddings
                break
        dummy = torch.randint(0, vocab_size, input_shape, device=device)
    else:
        # Vision/MLP model: expects float tensors matching model dtype
        try:
            dtype = next(model.parameters()).dtype
        except StopIteration:
            dtype = torch.float32
        dummy = torch.randn(*input_shape, device=device, dtype=dtype)
    # Stay in training mode so dropout/BN paths are counted realistically.
    # FlopCounterMode is a measurement, not a forward pass we want to
    # ship — flipping to eval() would under-count BN-heavy models.
    flops = 0
    backward_ran = False
    try:
        with FlopCounterMode(display=False) as counter:
            output = model(dummy)
            if backward:
                # NOTE(review): _extract_loss is a module-level helper
                # not visible here — presumably it reduces the model
                # output to a scalar; when it returns None the backward
                # count is skipped rather than failing. Confirm against
                # its definition.
                loss = _extract_loss(output)
                if loss is not None:
                    loss.backward()
                    backward_ran = True
        flops = counter.get_total_flops()
    except Exception as exc:
        # Counting is best-effort: any failure (unsupported op, device
        # quirk) falls through to the parameter-based estimate below.
        logger.debug("FlopCounterMode failed: %s", exc)
        flops = 0
    finally:
        # Only clear gradients if backward actually ran — otherwise
        # we'd clobber caller-set grads even though we never touched
        # them. This matters for the parameter-fallback path, which
        # never executes loss.backward() and must leave grads alone.
        if backward_ran:
            model.zero_grad(set_to_none=True)
    if flops > 0:
        return flops
    # FlopCounterMode returned 0 (common on XPU / non-CUDA devices).
    # Fall back to parameter-based estimate:
    #   forward ≈ 2 * params * tokens, backward ≈ 4 * params * tokens
    #   total   ≈ 6 * params * tokens (Kaplan et al.)
    num_params = sum(p.numel() for p in model.parameters())
    # For embedding models, last dim is sequence length (tokens per sample).
    # For vision/MLP, use product of spatial dims as "elements per sample".
    batch_size = input_shape[0]
    tokens = (
        input_shape[-1] if has_embedding
        else math.prod(input_shape[1:])
    )
    multiplier = 6 if backward else 2
    fallback = multiplier * num_params * batch_size * tokens
    logger.info(
        "FlopCounterMode returned 0 — using parameter-based estimate: %.2e",
        fallback,
    )
    return fallback
get_device_name()
⚓︎
Return a human-readable name for the current accelerator.
Tries torch.cuda.get_device_name() for NVIDIA/AMD, falls back
to torch.xpu.get_device_properties() for Intel, and finally
"cpu" if nothing is available.
Source code in src/ezpz/flops.py
def get_device_name() -> str:
    """Return a human-readable name for the current accelerator.

    Checks ``torch.cuda`` first (covers NVIDIA/AMD builds), then
    Intel's ``torch.xpu``, and reports ``"cpu"`` when neither backend
    has a device available.
    """
    if torch.cuda.is_available():
        return torch.cuda.get_device_name(0)
    xpu_backend = getattr(torch, "xpu", None)
    if xpu_backend is not None and xpu_backend.is_available():
        properties = xpu_backend.get_device_properties(0)
        # Prefer the `.name` attribute when the XPU build exposes it;
        # otherwise fall back to the properties object's repr.
        return getattr(properties, "name", str(properties))
    return "cpu"
get_peak_flops(device_name=None)
⚓︎
Return peak BF16 FLOPS for the given device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device_name` | `str \| None` | GPU name string (e.g. `"NVIDIA A100-SXM4-80GB"`). Auto-detected if `None`. | `None` |
Returns:
| Type | Description |
|---|---|
| `float \| None` | Peak FLOPS as a float, or `None` if the device is not recognized (e.g. CPU). |
Source code in src/ezpz/flops.py
def get_peak_flops(
device_name: str | None = None,
) -> float | None:
"""Return peak BF16 FLOPS for the given device.
Args:
device_name: GPU name string (e.g. ``"NVIDIA A100-SXM4-80GB"``).
If ``None``, auto-detected from the current device.
Returns:
Peak FLOPS as a float, or ``None`` if the device is not
recognized (e.g. CPU).
"""
if device_name is None:
device_name = get_device_name()
name = device_name.upper()
# CPU — no meaningful peak FLOPS
if name == "CPU":
return None
# Check known devices (order matters — specific matches first)
for key, flops in _PEAK_FLOPS:
if key.upper() in name:
if flops is not None:
return flops
# Dynamic computation for Intel PVC
return _compute_pvc_peak_flops()
# Unknown accelerator — warn once per device and return None
if device_name not in _WARNED_UNKNOWN_DEVICES:
_WARNED_UNKNOWN_DEVICES.add(device_name)
warnings.warn(
f"Peak FLOPS unknown for {device_name!r} — MFU tracking disabled",
stacklevel=2,
)
return None
try_estimate(model, input_shape, *, device=None, backward=True)
⚓︎
Estimate model FLOPS, returning 0 on failure.
Convenience wrapper around `estimate_model_flops` that catches
exceptions and logs a warning instead of propagating. Intended to
replace the repeated try/except/log boilerplate in example scripts.
Source code in src/ezpz/flops.py
def try_estimate(
    model: torch.nn.Module,
    input_shape: tuple[int, ...] | list[int],
    *,
    device: torch.device | str | None = None,
    backward: bool = True,
) -> int:
    """Estimate model FLOPS, returning 0 on failure.

    Convenience wrapper around :func:`estimate_model_flops` that catches
    exceptions and logs a warning instead of propagating. Intended to
    replace the repeated try/except/log boilerplate in example scripts.
    """
    try:
        counted = estimate_model_flops(
            model,
            input_shape,
            device=device,
            backward=backward,
        )
        if counted > 0:
            # Log from rank 0 only so multi-process runs don't emit
            # duplicate lines; fall back to 0 when distributed state
            # is unavailable.
            try:
                from ezpz.distributed import get_rank

                rank = get_rank()
            except Exception:
                rank = 0
            if rank == 0:
                logger.info("Model FLOPS (fwd+bwd): %.2e", counted)
        return counted
    except Exception as err:
        logger.warning("FLOPS estimation failed: %s", err)
    return 0