ezpz.utils⚓︎
- See ezpz/
utils
Utility functions for debugging, GPU memory monitoring, DeepSpeed configuration, and dataset I/O.
Distributed Debugging⚓︎
ezpz provides debugger classes that work correctly in multi-process
distributed environments.
DistributedPdb⚓︎
A pdb.Pdb subclass that synchronizes all ranks at breakpoints, allowing
interactive debugging of a single rank while others wait at a barrier.
ForkedPdb⚓︎
A pdb.Pdb subclass that works in forked (multiprocessing) contexts by
redirecting stdin.
breakpoint(rank=0)⚓︎
Set a breakpoint that only activates on the specified rank. All other ranks wait at a barrier until the debugger continues.
Debugging rank 0 in a distributed job
GPU Memory Monitoring⚓︎
from ezpz.utils import get_max_memory_allocated, get_max_memory_reserved
device = ezpz.get_torch_device()
peak_allocated = get_max_memory_allocated(device) # in bytes
peak_reserved = get_max_memory_reserved(device) # in bytes
print(f"Peak allocated: {peak_allocated / 1e9:.2f} GB")
print(f"Peak reserved: {peak_reserved / 1e9:.2f} GB")
Both functions work with CUDA and XPU devices, falling back gracefully on unsupported platforms.
Peak FLOPS Lookup⚓︎
Look up the theoretical peak BF16 FLOPS for known GPU types:
from ezpz.utils import get_peak_flops
flops = get_peak_flops("A100") # NVIDIA A100
flops = get_peak_flops("H100 SXM") # NVIDIA H100 SXM
flops = get_peak_flops("H100 NVL") # NVIDIA H100 NVL variant
flops = get_peak_flops("H200") # NVIDIA H200
flops = get_peak_flops("B200") # NVIDIA B200
flops = get_peak_flops("Max 1550") # Intel Data Center GPU Max (PVC)
Note: The lookup is case-sensitive and uses substring matching (e.g.
"A100" in device_name). On systems with lspci available, the function
will auto-detect the GPU name from PCI device listings.
DeepSpeed Config Generators⚓︎
Generate DeepSpeed configuration dictionaries for various ZeRO stages and precision modes:
ZeRO Stage ½ auto config
ZeRO Stage 3 auto config
Precision configs
Full DeepSpeed config
FLOPs profiler config
Dataset I/O⚓︎
Save and load xarray.Dataset objects to/from HDF5 files:
from ezpz.utils import save_dataset, dataset_to_h5pyfile, dataset_from_h5pyfile
# Save a dataset (uses HDF5 by default)
path = save_dataset(dataset, outdir="./data", fname="metrics.h5")
# Direct HDF5 operations
dataset_to_h5pyfile("metrics.h5", dataset)
loaded = dataset_from_h5pyfile("metrics.h5")
ezpz/utils/init.py
DistributedPdb
⚓︎
Bases: Pdb
Supports using PDB from inside a multiprocessing child process.
Usage: DistributedPdb().set_trace()
Source code in src/ezpz/utils/__init__.py
class DistributedPdb(pdb.Pdb):
"""
Supports using PDB from inside a multiprocessing child process.
Usage:
DistributedPdb().set_trace()
"""
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
DummyTqdmFile
⚓︎
Dummy file-like wrapper that forwards writes to tqdm.
Source code in src/ezpz/utils/__init__.py
class DummyTqdmFile:
"""Dummy file-like wrapper that forwards writes to tqdm."""
file = None
def __init__(self, file):
self.file = file
def write(self, text):
if len(text.rstrip()) > 0:
tqdm.tqdm.write(text, file=self.file, end="\n")
def flush(self):
return getattr(self.file, "flush", lambda: None)()
ForkedPdb
⚓︎
Bases: Pdb
PDB subclass for debugging multi-processed code.
Source code in src/ezpz/utils/__init__.py
breakpoint(rank=0)
⚓︎
Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rank
|
int
|
Which rank to break on. Default: |
0
|
Source code in src/ezpz/utils/__init__.py
def breakpoint(rank: int = 0):
"""
Set a breakpoint, but only on a single rank. All other ranks will wait for you to be
done with the breakpoint before continuing.
Args:
rank (int): Which rank to break on. Default: ``0``
"""
if ezpz.get_rank() == rank:
pdb = DistributedPdb()
pdb.message(
"\n!!! ATTENTION !!!\n\n"
f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
)
pdb.set_trace()
# torch.distributed.barrier()
ezpz.distributed.barrier()
check_for_tarball(env_prefix=None, overwrite=False)
⚓︎
Locate or create a .tar.gz of env_prefix; return its absolute path.
Search order (first hit wins, unless overwrite=True):
<env_prefix.parent>/<env_name>.tar.gz— alongside the env. This is where_suggest_tarball_if_present(inezpz.utils.yeet_env) looks first, so co-locating here means a subsequentezpz yeet(no args) will see and suggest it./tmp/<env_name>.tar.gz— node-local fallback.<cwd>/<env_name>.tar.gz— current-directory fallback.
If none exist (or overwrite=True), creates a new tarball at
location #1 (next to the env).
Source code in src/ezpz/utils/__init__.py
def check_for_tarball(
env_prefix: Optional[str | os.PathLike | Path] = None,
overwrite: Optional[bool] = False,
):
"""Locate or create a `.tar.gz` of *env_prefix*; return its absolute path.
Search order (first hit wins, unless ``overwrite=True``):
1. ``<env_prefix.parent>/<env_name>.tar.gz`` — alongside the env.
This is where ``_suggest_tarball_if_present`` (in
``ezpz.utils.yeet_env``) looks first, so co-locating here means
a subsequent ``ezpz yeet`` (no args) will see and suggest it.
2. ``/tmp/<env_name>.tar.gz`` — node-local fallback.
3. ``<cwd>/<env_name>.tar.gz`` — current-directory fallback.
If none exist (or ``overwrite=True``), creates a new tarball at
location #1 (next to the env).
"""
if env_prefix is None:
# NOTE:
# - `sys.executable` looks like:
# `/path/to/some/envs/env_name/bin/python`
fpl = sys.executable.split("/")
# `env_prefix` looks like `/path/to/some/envs/env_name`
env_prefix = "/".join(fpl[:-2])
env_path = Path(env_prefix).resolve()
env_name = env_path.name
tarball_name = f"{env_name}.tar.gz"
candidates = [
env_path.parent / tarball_name, # next to the venv (canonical)
Path("/tmp") / tarball_name,
Path.cwd() / tarball_name,
]
if overwrite:
for c in candidates:
if c.exists():
logger.info(f"Removing existing tarball at {c}")
c.unlink()
else:
for c in candidates:
if c.exists():
logger.info(f"Tarball {c} already exists, skipping creation")
return c
target = candidates[0]
logger.info(f"Creating tarball {target} from {env_prefix}")
make_tarfile(str(target), str(env_prefix))
return target
format_compact_summary(metrics, precision=6, keys_to_skip=None, min_widths=None, constant_keys=None)
⚓︎
Render metrics as a compact key=value(±std) summary line.
Designed to replace the noisy per-step summary that looks like:
loss=0.047 loss/mean=0.030 loss/max=0.120 loss/min=0.011 loss/std=0.022
accuracy=1.000 accuracy/mean=0.997 accuracy/max=1.000 ...
with the much tighter:
loss=0.047(±0.022) accuracy=1.000(±0.006) ...
Rules
- For each base metric
X, look upX/stdin metrics and append it inline asX=val(±std). Other aggregation suffixes (/mean,/min,/max,/avg) are dropped entirely from the console line — trackers still get them. - Memory keys (handled by :func:
format_memory_summary) are skipped here. The caller is expected to append the compact memory token separately at the end of the line. - Counter-like base names (
iter,step,epoch,batch,idx) suppress the(±std)suffix even if std is present — a counter's std is meaningless noise. - Hyperparameters that are replicated across ranks (
lr,momentum,weight_decay,beta1,beta2, …) also suppress(±std)because their per-step std is always 0. Extend the recognised set viaconstant_keys. - Counter tokens are right-padded so successive lines align at
the left edge:
iter=8 loss=...lines up underiter=180 loss=.... Override widths viamin_widths.
Aggregation values (X/mean, X/std, etc.) that have NO
corresponding base value in metrics are still emitted as
standalone keys, so we don't silently lose data.
Source code in src/ezpz/utils/__init__.py
def format_compact_summary(
metrics: dict[str, float],
precision: int = 6,
keys_to_skip: Iterable | None = None,
min_widths: "dict[str, int] | None" = None,
constant_keys: Iterable[str] | None = None,
) -> str:
"""Render *metrics* as a compact ``key=value(±std)`` summary line.
Designed to replace the noisy per-step summary that looks like:
loss=0.047 loss/mean=0.030 loss/max=0.120 loss/min=0.011 loss/std=0.022
accuracy=1.000 accuracy/mean=0.997 accuracy/max=1.000 ...
with the much tighter:
loss=0.047(±0.022) accuracy=1.000(±0.006) ...
Rules:
- For each base metric ``X``, look up ``X/std`` in *metrics* and
append it inline as ``X=val(±std)``. Other aggregation suffixes
(``/mean``, ``/min``, ``/max``, ``/avg``) are dropped entirely
from the console line — trackers still get them.
- Memory keys (handled by :func:`format_memory_summary`) are
skipped here. The caller is expected to append the compact
memory token separately at the end of the line.
- Counter-like base names (``iter``, ``step``, ``epoch``, ``batch``,
``idx``) suppress the ``(±std)`` suffix even if std is present
— a counter's std is meaningless noise.
- Hyperparameters that are replicated across ranks (``lr``,
``momentum``, ``weight_decay``, ``beta1``, ``beta2``, …) also
suppress ``(±std)`` because their per-step std is always 0.
Extend the recognised set via ``constant_keys``.
- Counter tokens are right-padded so successive lines align at
the left edge: ``iter=8 loss=...`` lines up under
``iter=180 loss=...``. Override widths via ``min_widths``.
Aggregation values (``X/mean``, ``X/std``, etc.) that have NO
corresponding base value in *metrics* are still emitted as
standalone keys, so we don't silently lose data.
"""
skip = set(keys_to_skip or ())
# Merge caller-supplied widths over the defaults.
widths: dict[str, int] = dict(_DEFAULT_MIN_WIDTHS)
if min_widths:
widths.update(min_widths)
def _pad(base_name: str, token: str) -> str:
"""Right-pad ``token`` so each line's counter aligns with prior
lines. ``base_name`` strips any namespace prefix (``train/iter``
→ ``iter``) before looking up the configured width."""
leaf = base_name.rsplit("/", 1)[-1]
target = widths.get(leaf)
if target is None or len(token) >= target:
return token
return token + " " * (target - len(token))
# Pre-build a lookup of std values keyed by base name so we can
# match them onto bases in a single pass.
std_lookup: dict[str, float] = {}
aggregation_suffixes = ("/mean", "/min", "/max", "/std", "/avg")
aggregation_keys: set[str] = set()
for k, v in metrics.items():
for suffix in aggregation_suffixes:
if k.endswith(suffix):
aggregation_keys.add(k)
if suffix == "/std":
std_lookup[k[: -len(suffix)]] = float(v)
break
# Counter-like bases for which (±std) is meaningless.
_counter_bases = ("iter", "step", "epoch", "batch", "idx")
def _is_counter(base: str) -> bool:
# Match exact name AND prefixed forms (e.g. "train/iter").
return base.rsplit("/", 1)[-1] in _counter_bases
# Hyperparameters that are replicated identically across ranks —
# their /std is always 0, so emitting `(±0)` or padding a 9-char
# gap to "reserve" space for a future non-zero std is pure noise.
# Treat them like counters: bare `key=value`, no parenthetical,
# no padding. Caller can extend via `constant_keys` kwarg.
_known_constant_bases = frozenset((
"lr", "learning_rate",
"momentum",
"beta1", "beta2",
"weight_decay", "wd",
"eps", "epsilon",
"clip_grad", "grad_clip", "clip_norm", "max_grad_norm",
"warmup_steps", "warmup_iters", "warmup",
))
extra_constants = (
frozenset(constant_keys) if constant_keys is not None else frozenset()
)
def _is_known_constant(base: str) -> bool:
leaf = base.rsplit("/", 1)[-1]
return leaf in _known_constant_bases or leaf in extra_constants
tokens: list[str] = []
seen_bases: set[str] = set()
for k, v in metrics.items():
if k in skip:
continue
if is_memory_metric_key(k):
continue
if k in aggregation_keys:
continue # handled inline (via std_lookup) or skipped
seen_bases.add(k)
base_token = format_pair(k, v, precision=precision)
std = std_lookup.get(k)
if std is not None and not _is_counter(k) and not _is_known_constant(k):
std_token = _format_std(std, precision=precision)
if std_token is None:
# std rounds to zero at the chosen precision (e.g.
# `lr/std=1e-12` with precision=2). `(±0)` adds no
# signal; drop it. Pad with spaces matching the width
# of a full `(±XXXXXX)` parenthetical so this token's
# successor still aligns with its neighbors above/below
# — important for metrics that swing between zero and
# non-zero std across rows (e.g. small/sporadic noise).
pad = " " * (_STD_TOKEN_MAX_WIDTH + 3)
tokens.append(f"{base_token}{pad}")
else:
# Right-align the std token so `(±0.070)`, `(±0.12)`,
# and `(±5.1e-4)` all occupy the same number of columns.
padded = std_token.rjust(_STD_TOKEN_MAX_WIDTH)
tokens.append(f"{base_token}(±{padded})")
else:
# Counters and known-constant hyperparameters: no padding,
# no parenthetical. Counters still get left-edge padding so
# the *next* field aligns across rows.
tokens.append(_pad(k, base_token))
# Emit aggregation keys whose base wasn't present in the dict — so
# we don't silently drop them. (Rare; happens when caller passes
# only an aggregated metric, e.g. ``loss/mean`` without ``loss``.)
for k in aggregation_keys:
base = next(
k[: -len(s)] for s in aggregation_suffixes if k.endswith(s)
)
if base in seen_bases or k in skip:
continue
# Memory-metric aggregations are handled by format_memory_summary;
# never emit them as standalone tokens here.
if is_memory_metric_key(k):
continue
# /std for a base we already showed inline → drop. Other
# aggregations without a base → keep (visible debug info).
if k.endswith("/std") and base in std_lookup and base in seen_bases:
continue
tokens.append(format_pair(k, metrics[k], precision=precision))
# rstrip trailing whitespace introduced by std-None padding on the
# last token — interior padding (between tokens) is preserved by
# the join, so column alignment across rows still works.
return " ".join(tokens).rstrip()
format_memory_summary(metrics, *, device=None, prefix=None)
⚓︎
Condense the four mem_* keys into a single console-friendly string.
Input: a dict that contains (some subset of) the keys produced by
:func:get_memory_metrics — {prefix}mem_alloc,
{prefix}mem_peak_alloc, {prefix}mem_reserved,
{prefix}mem_peak_reserved.
Output: "X.XX/Y.YYGiB (Z%)" where X is current alloc, Y is peak
alloc, and Z is current alloc as a percent of device total memory
(omitted when device total isn't available — e.g. unknown XPU, CPU
fallback).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
metrics
|
dict[str, float]
|
dict that may contain mem_* keys. |
required |
device
|
'torch.device | int | str | None'
|
optional device for total-memory lookup. |
None
|
prefix
|
'str | None'
|
explicit prefix (e.g. |
None
|
Returns an empty string if no mem_* keys are present (CPU/MPS) — so
callers can " ".join(filter(None, [...])) without checking.
Source code in src/ezpz/utils/__init__.py
def format_memory_summary(
metrics: dict[str, float],
*,
device: "torch.device | int | str | None" = None,
prefix: "str | None" = None,
) -> str:
"""Condense the four mem_* keys into a single console-friendly string.
Input: a dict that contains (some subset of) the keys produced by
:func:`get_memory_metrics` — ``{prefix}mem_alloc``,
``{prefix}mem_peak_alloc``, ``{prefix}mem_reserved``,
``{prefix}mem_peak_reserved``.
Output: ``"X.XX/Y.YYGiB (Z%)"`` where X is current alloc, Y is peak
alloc, and Z is current alloc as a percent of device total memory
(omitted when device total isn't available — e.g. unknown XPU, CPU
fallback).
Args:
metrics: dict that may contain mem_* keys.
device: optional device for total-memory lookup. ``int`` index
or ``torch.device('cuda:N')`` honored; ``None`` uses the
local rank's device.
prefix: explicit prefix (e.g. ``"train/"``). When ``None``
(default), the prefix is inferred by scanning ``metrics`` for
``*mem_alloc`` / ``*mem_peak_alloc`` keys — so callers that
don't know whether their metrics are namespaced don't have
to probe twice.
Returns an empty string if no mem_* keys are present (CPU/MPS) — so
callers can ``" ".join(filter(None, [...]))`` without checking.
"""
if prefix is None:
prefix = ""
for key in metrics:
if key.endswith("mem_alloc") or key.endswith("mem_peak_alloc"):
# Strip the suffix to recover whatever namespace the
# caller used (e.g. "train/", "eval/", or "").
if key.endswith("mem_peak_alloc"):
prefix = key[: -len("mem_peak_alloc")]
else:
prefix = key[: -len("mem_alloc")]
break
alloc = metrics.get(f"{prefix}mem_alloc")
peak = metrics.get(f"{prefix}mem_peak_alloc")
if alloc is None and peak is None:
return ""
# Total device VRAM for the percentage. Lazy-resolve to avoid the
# import cost when caller has no memory keys to format anyway.
# Normalize `device` → int index where possible, so callers passing
# `torch.device('cuda:1')` get the right device's total (not rank 0's
# device by way of get_local_rank()).
pct_str = ""
try:
import ezpz
idx: int | None
if isinstance(device, int):
idx = device
elif device is None:
idx = None
else:
# torch.device('cuda:1').index == 1; torch.device('cuda').index is None
try:
idx = torch.device(device).index
except (TypeError, RuntimeError):
idx = None
props = ezpz.distributed.get_device_properties(idx)
total_bytes = props.get("total_memory", -1)
if total_bytes and total_bytes > 0 and alloc is not None:
total_gib = total_bytes / (1024 ** 3)
pct = 100.0 * alloc / total_gib
pct_str = f" ({pct:.0f}%)"
except Exception:
# Any failure resolving total memory: omit the percentage rather
# than break logging. Raw numbers still print.
pct_str = ""
if alloc is not None and peak is not None:
return f"{alloc:.2f}/{peak:.2f}GiB{pct_str}"
if alloc is not None:
return f"{alloc:.2f}GiB{pct_str}"
# Only peak is present (rare — caller passed `mem_peak_alloc` without
# `mem_alloc`). Format matches the alloc-only branch for consistency.
return f"{peak:.2f}GiB{pct_str}"
format_pair(k, v, precision=6)
⚓︎
Format a key-value pair (supports nested dict/list/tuple/set).
Nested dicts become dotted keys: key.subkey=value Sequences become indexed keys: key[0]=value
Returns a newline-joined string if multiple leaf pairs are produced.
Source code in src/ezpz/utils/__init__.py
def format_pair(k: str, v: Any, precision: int = 6) -> str:
"""Format a key-value pair (supports nested dict/list/tuple/set).
Nested dicts become dotted keys: key.subkey=value
Sequences become indexed keys: key[0]=value
Returns a newline-joined string if multiple leaf pairs are produced.
"""
def _is_int_like(x: Any) -> bool:
return (
isinstance(x, (bool, int, np.integer))
and not isinstance(x, (bool,)) is False
) # keep bool distinct below
def _is_bool_like(x: Any) -> bool:
return isinstance(x, (bool, np.bool_))
def _is_float_like(x: Any) -> bool:
return isinstance(x, (float, np.floating))
def _scalar_str(key: str, val: Any) -> str:
# numpy scalar -> python scalar (helps consistent isinstance checks)
if isinstance(val, np.generic):
val = val.item()
if _is_bool_like(val):
return f"{key}={bool(val)}"
if isinstance(val, (int, np.integer)):
return f"{key}={int(val)}"
if isinstance(val, float):
# be explicit for non-finite floats (avoids ValueError with format spec)
if not math.isfinite(val):
return f"{key}={val}"
return f"{key}={val:.{precision}f}"
# fallback: strings, None, objects, etc.
return f"{key}={val}"
def _flatten(key: str, val: Any) -> list[str]:
# numpy scalar -> python scalar early
if isinstance(val, np.generic):
val = val.item()
if isinstance(val, dict):
out: list[str] = []
for kk, vv in val.items():
out.extend(_flatten(f"{key}.{kk}", vv))
return out
if isinstance(val, (list, tuple)):
out: list[str] = []
for i, vv in enumerate(val):
out.extend(_flatten(f"{key}[{i}]", vv))
return out
if isinstance(val, set):
# sets are unordered; make deterministic
out: list[str] = []
for i, vv in enumerate(sorted(val, key=lambda x: repr(x))):
out.extend(_flatten(f"{key}[{i}]", vv))
return out
return [_scalar_str(key, val)]
return "\n".join(_flatten(k, v))
get_bf16_config_json(enabled=True)
⚓︎
get_current_memory_allocated(device)
⚓︎
Currently allocated memory in bytes on device. 0.0 on CPU/MPS.
Source code in src/ezpz/utils/__init__.py
def get_current_memory_allocated(device: "torch.device | int | str") -> float:
"""Currently allocated memory in bytes on ``device``. 0.0 on CPU/MPS."""
dtype = _device_type(device)
if dtype == "cuda":
return torch.cuda.memory_allocated(device)
if dtype == "xpu":
try:
return torch.xpu.memory_allocated(device)
except (ImportError, AttributeError):
return 0.0
return 0.0
get_current_memory_reserved(device)
⚓︎
Currently reserved memory in bytes on device. 0.0 on CPU/MPS.
Source code in src/ezpz/utils/__init__.py
def get_current_memory_reserved(device: "torch.device | int | str") -> float:
"""Currently reserved memory in bytes on ``device``. 0.0 on CPU/MPS."""
dtype = _device_type(device)
if dtype == "cuda":
return torch.cuda.memory_reserved(device)
if dtype == "xpu":
try:
return torch.xpu.memory_reserved(device)
except (ImportError, AttributeError):
return 0.0
return 0.0
get_deepspeed_adamw_optimizer_config_json(auto_config=True)
⚓︎
Get the deepspeed adamw optimizer config json.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
auto_config
|
bool
|
Whether to use the auto config. Default: |
True
|
Returns:
| Name | Type | Description |
|---|---|---|
dict |
dict
|
Deepspeed adamw optimizer config. |
Source code in src/ezpz/utils/__init__.py
def get_deepspeed_adamw_optimizer_config_json(
auto_config: Optional[bool] = True,
) -> dict:
"""
Get the deepspeed adamw optimizer config json.
Args:
auto_config (bool): Whether to use the auto config. Default: ``True``.
Returns:
dict: Deepspeed adamw optimizer config.
"""
return (
{"type": "AdamW"}
if not auto_config
else {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": True,
"adam_w_mode": True,
},
}
)
get_deepspeed_config_json(auto_config=True, gradient_accumulation_steps=1, gradient_clipping='auto', steps_per_print=10, train_batch_size='auto', train_micro_batch_size_per_gpu='auto', wall_clock_breakdown=False, wandb=True, bf16=True, fp16=None, flops_profiler=None, optimizer=None, scheduler=None, zero_optimization=None, stage=0, allgather_partitions=None, allgather_bucket_size=int(500000000.0), overlap_comm=None, reduce_scatter=True, reduce_bucket_size=int(500000000.0), contiguous_gradients=None, offload_param=None, offload_optimizer=None, stage3_max_live_parameters=int(1000000000.0), stage3_max_reuse_distance=int(1000000000.0), stage3_prefetch_bucket_size=int(500000000.0), stage3_param_persistence_threshold=int(1000000.0), sub_group_size=None, elastic_checkpoint=None, stage3_gather_16bit_weights_on_model_save=None, ignore_unused_parameters=None, round_robin_gradients=None, zero_hpz_partition_size=None, zero_quantized_weights=None, zero_quantized_gradients=None, log_trace_cache_warnings=None, save_config=True, output_file=None, output_dir=None)
⚓︎
Write a deepspeed config to the output directory.
Source code in src/ezpz/utils/__init__.py
def get_deepspeed_config_json(
auto_config: Optional[bool] = True,
gradient_accumulation_steps: int = 1,
gradient_clipping: Optional[str | float] = "auto",
steps_per_print: Optional[int] = 10,
train_batch_size: str = "auto",
train_micro_batch_size_per_gpu: str = "auto",
wall_clock_breakdown: bool = False,
wandb: bool = True, # NOTE: Opinionated, W&B is enabled by default
bf16: bool = True, # NOTE: Opinionated, BF16 is enabled by default
fp16: Optional[bool] = None,
flops_profiler: Optional[dict] = None,
optimizer: Optional[dict] = None,
scheduler: Optional[dict] = None,
zero_optimization: Optional[dict] = None,
stage: Optional[int] = 0,
allgather_partitions: Optional[bool] = None,
allgather_bucket_size: Optional[int] = int(5e8),
overlap_comm: Optional[bool] = None,
reduce_scatter: Optional[bool] = True,
reduce_bucket_size: Optional[int] = int(5e8),
contiguous_gradients: Optional[bool] = None,
offload_param: Optional[dict] = None,
offload_optimizer: Optional[dict] = None,
stage3_max_live_parameters: Optional[int] = int(1e9),
stage3_max_reuse_distance: Optional[int] = int(1e9),
stage3_prefetch_bucket_size: Optional[int] = int(5e8),
stage3_param_persistence_threshold: Optional[int] = int(1e6),
sub_group_size: Optional[int] = None,
elastic_checkpoint: Optional[dict] = None,
stage3_gather_16bit_weights_on_model_save: Optional[bool] = None,
ignore_unused_parameters: Optional[bool] = None,
round_robin_gradients: Optional[bool] = None,
zero_hpz_partition_size: Optional[int] = None,
zero_quantized_weights: Optional[bool] = None,
zero_quantized_gradients: Optional[bool] = None,
log_trace_cache_warnings: Optional[bool] = None,
save_config: bool = True,
output_file: Optional[str] = None,
output_dir: Optional[PathLike] = None,
) -> dict[str, Any]:
"""
Write a deepspeed config to the output directory.
"""
import json
wandb_config = {"enabled": wandb}
bf16_config = {"enabled": bf16}
fp16_config = {"enabled": fp16}
flops_profiler_config = (
get_flops_profiler_config_json()
if flops_profiler is None
else flops_profiler
)
optimizer = (
get_deepspeed_adamw_optimizer_config_json()
if optimizer is None
else optimizer
)
scheduler = (
get_deepspeed_warmup_decay_scheduler_config_json()
if scheduler is None
else scheduler
)
if stage is not None and int(stage) > 0:
zero_optimization = (
get_deepspeed_zero_config_json(
stage=stage,
allgather_partitions=allgather_partitions,
allgather_bucket_size=allgather_bucket_size,
overlap_comm=overlap_comm,
reduce_scatter=reduce_scatter,
reduce_bucket_size=reduce_bucket_size,
contiguous_gradients=contiguous_gradients,
offload_param=offload_param,
offload_optimizer=offload_optimizer,
stage3_max_live_parameters=stage3_max_live_parameters,
stage3_max_reuse_distance=stage3_max_reuse_distance,
stage3_prefetch_bucket_size=stage3_prefetch_bucket_size,
stage3_param_persistence_threshold=stage3_param_persistence_threshold,
sub_group_size=sub_group_size,
elastic_checkpoint=elastic_checkpoint,
stage3_gather_16bit_weights_on_model_save=stage3_gather_16bit_weights_on_model_save,
ignore_unused_parameters=ignore_unused_parameters,
round_robin_gradients=round_robin_gradients,
zero_hpz_partition_size=zero_hpz_partition_size,
zero_quantized_weights=zero_quantized_weights,
zero_quantized_gradients=zero_quantized_gradients,
log_trace_cache_warnings=log_trace_cache_warnings,
)
if zero_optimization is None
else zero_optimization
)
else:
zero_optimization = None
ds_config = {
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_clipping": gradient_clipping,
"steps_per_print": steps_per_print,
"train_batch_size": train_batch_size,
"train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
"wall_clock_breakdown": wall_clock_breakdown,
"wandb": wandb,
"bf16": bf16,
"fp16": fp16,
"flops_profiler": flops_profiler,
"optimizer": optimizer,
"scheduler": scheduler,
"zero_optimization": zero_optimization,
}
if save_config:
if output_file is None:
if output_dir is None:
output_dir = Path(os.getcwd()).joinpath("ds_configs")
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
outfile = output_dir.joinpath("deepspeed_config.json")
else:
outfile = Path(output_file)
logger.info(f"Saving DeepSpeed config to: {outfile.as_posix()}")
logger.info(json.dumps(ds_config, indent=4))
with outfile.open("w") as f:
json.dump(
ds_config,
fp=f,
indent=4,
)
return ds_config
get_deepspeed_warmup_decay_scheduler_config_json(auto_config=True)
⚓︎
Get the deepspeed warmup decay scheduler config json.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
auto_config
|
bool
|
Whether to use the auto config. Default: |
True
|
Returns:
| Name | Type | Description |
|---|---|---|
dict |
dict
|
Deepspeed warmup decay scheduler config. |
Source code in src/ezpz/utils/__init__.py
def get_deepspeed_warmup_decay_scheduler_config_json(
auto_config: Optional[bool] = True,
) -> dict:
"""
Get the deepspeed warmup decay scheduler config json.
Args:
auto_config (bool): Whether to use the auto config. Default: ``True``.
Returns:
dict: Deepspeed warmup decay scheduler config.
"""
return (
{"type": "WarmupDecayLR"}
if not auto_config
else {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto",
},
}
)
get_deepspeed_zero_config_json(zero_config)
⚓︎
get_flops_profiler_config_json(enabled=True, profile_step=1, module_depth=-1, top_modules=1, detailed=True)
⚓︎
Get the deepspeed flops profiler config json.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
enabled
|
bool
|
Whether to use the flops profiler. Default: |
True
|
profile_step
|
int
|
The step to profile. Default: |
1
|
module_depth
|
int
|
The depth of the module. Default: |
-1
|
top_modules
|
int
|
The number of top modules to show. Default: |
1
|
detailed
|
bool
|
Whether to show detailed profiling. Default: |
True
|
Returns:
| Name | Type | Description |
|---|---|---|
dict |
dict
|
Deepspeed flops profiler config. |
Source code in src/ezpz/utils/__init__.py
def get_flops_profiler_config_json(
enabled: bool = True,
profile_step: int = 1,
module_depth: int = -1,
top_modules: int = 1,
detailed: bool = True,
) -> dict:
"""
Get the deepspeed flops profiler config json.
Args:
enabled (bool): Whether to use the flops profiler. Default: ``True``.
profile_step (int): The step to profile. Default: ``1``.
module_depth (int): The depth of the module. Default: ``-1``.
top_modules (int): The number of top modules to show. Default: ``1``.
detailed (bool): Whether to show detailed profiling. Default: ``True``.
Returns:
dict: Deepspeed flops profiler config.
"""
return {
"enabled": enabled,
"profile_step": profile_step,
"module_depth": module_depth,
"top_modules": top_modules,
"detailed": detailed,
}
get_fp16_config_json(enabled=True)
⚓︎
get_max_memory_allocated(device)
⚓︎
Peak allocated memory in bytes on device. 0.0 on CPU/MPS.
Routes to the backend matching device's type, not whichever
accelerator happens to be globally available. So
get_max_memory_allocated("cpu") returns 0.0 even on a CUDA box.
Source code in src/ezpz/utils/__init__.py
def get_max_memory_allocated(device: "torch.device | int | str") -> float:
"""Peak allocated memory in bytes on ``device``. 0.0 on CPU/MPS.
Routes to the backend matching ``device``'s type, not whichever
accelerator happens to be globally available. So
``get_max_memory_allocated("cpu")`` returns 0.0 even on a CUDA box.
"""
dtype = _device_type(device)
if dtype == "cuda":
return torch.cuda.max_memory_allocated(device)
if dtype == "xpu":
try:
return torch.xpu.max_memory_allocated(device)
except (ImportError, AttributeError):
return 0.0
return 0.0
get_max_memory_reserved(device)
⚓︎
Peak reserved memory in bytes on device. 0.0 on CPU/MPS.
Source code in src/ezpz/utils/__init__.py
def get_max_memory_reserved(device: "torch.device | int | str") -> float:
"""Peak reserved memory in bytes on ``device``. 0.0 on CPU/MPS."""
dtype = _device_type(device)
if dtype == "cuda":
return torch.cuda.max_memory_reserved(device)
if dtype == "xpu":
try:
return torch.xpu.max_memory_reserved(device)
except (ImportError, AttributeError):
return 0.0
return 0.0
get_memory_metrics(device=None, *, reset_peak=True, prefix='')
⚓︎
Return device memory metrics in GiB.
Returns 4 keys when supported (CUDA, XPU):
{prefix}mem_alloc currently allocated
{prefix}mem_peak_alloc peak allocated since last reset
{prefix}mem_reserved currently reserved by the allocator
{prefix}mem_peak_reserved peak reserved since last reset
Returns {} on CPU / MPS (silent — caller's metrics dict simply
doesn't gain these keys), and unconditionally when the env var
EZPZ_TRACK_MEMORY=0 is set.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
device
|
'torch.device | int | str | None'
|
device to query. If None, uses |
None
|
reset_peak
|
bool
|
if True (default), reset peak counters AFTER reading.
Next call's |
True
|
prefix
|
str
|
optional string prepended to every key. Useful for the
examples that namespace their metrics (e.g. |
''
|
Source code in src/ezpz/utils/__init__.py
def get_memory_metrics(
device: "torch.device | int | str | None" = None,
*,
reset_peak: bool = True,
prefix: str = "",
) -> dict[str, float]:
"""Return device memory metrics in GiB.
Returns 4 keys when supported (CUDA, XPU):
{prefix}mem_alloc currently allocated
{prefix}mem_peak_alloc peak allocated since last reset
{prefix}mem_reserved currently reserved by the allocator
{prefix}mem_peak_reserved peak reserved since last reset
Returns ``{}`` on CPU / MPS (silent — caller's metrics dict simply
doesn't gain these keys), and unconditionally when the env var
``EZPZ_TRACK_MEMORY=0`` is set.
Args:
device: device to query. If None, uses ``ezpz.get_torch_device()``.
reset_peak: if True (default), reset peak counters AFTER reading.
Next call's ``mem_peak_*`` then reflect only what happened
between calls — the standard per-step pattern.
prefix: optional string prepended to every key. Useful for the
examples that namespace their metrics (e.g. ``"train/"``).
"""
if os.environ.get("EZPZ_TRACK_MEMORY", "1") == "0":
return {}
# Lazy default device resolution — only pay the cost when caller
# didn't pass one explicitly.
if device is None:
import ezpz
device = ezpz.get_torch_device()
assert device is not None # type narrowing for pyright
# On CPU/MPS, all four helpers return 0.0. Short-circuit to avoid
# emitting a row of zeros. Route via the canonical device type so
# the check works for torch.device('cpu'), 'cpu', 'cpu:0', etc.
if _device_type(device) not in ("cuda", "xpu"):
return {}
_GIB = 1024 ** 3
metrics = {
f"{prefix}mem_alloc": get_current_memory_allocated(device) / _GIB,
f"{prefix}mem_peak_alloc": get_max_memory_allocated(device) / _GIB,
f"{prefix}mem_reserved": get_current_memory_reserved(device) / _GIB,
f"{prefix}mem_peak_reserved": get_max_memory_reserved(device) / _GIB,
}
if reset_peak:
reset_peak_memory_stats(device)
return metrics
get_timestamp(fstr=None)
⚓︎
Get formatted timestamp.
Returns the current date and time as a formatted string. By default, returns a timestamp in the format 'YYYY-MM-DD-HHMMSS'. A custom format string can be provided to change the output format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
fstr
|
str
|
Format string for strftime. If None, uses default format '%Y-%m-%d-%H%M%S'. Defaults to None. |
None
|
Returns:
| Name | Type | Description |
|---|---|---|
str |
str
|
Formatted timestamp string. |
Examples:
>>> get_timestamp() # Returns something like '2023-12-01-143022'
>>> get_timestamp("%Y-%m-%d") # Returns something like '2023-12-01'
Source code in src/ezpz/utils/__init__.py
def get_timestamp(fstr: Optional[str] = None) -> str:
"""Get formatted timestamp.
Returns the current date and time as a formatted string. By default, returns
a timestamp in the format 'YYYY-MM-DD-HHMMSS'. A custom format string can
be provided to change the output format.
Args:
fstr (str, optional): Format string for strftime. If None, uses default
format '%Y-%m-%d-%H%M%S'. Defaults to None.
Returns:
str: Formatted timestamp string.
Examples:
>>> get_timestamp() # Returns something like '2023-12-01-143022'
>>> get_timestamp("%Y-%m-%d") # Returns something like '2023-12-01'
"""
import datetime
now = datetime.datetime.now()
return (
now.strftime("%Y-%m-%d-%H%M%S") if fstr is None else now.strftime(fstr)
)
grab_tensor(x, force=False)
⚓︎
Convert various tensor/array-like objects to numpy arrays.
This function converts different types of array-like objects (tensors, lists, etc.) to numpy arrays for consistent handling. Supports PyTorch tensors, numpy arrays, and nested lists.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Any
|
The object to convert to a numpy array. Can be None, scalar values, lists, numpy arrays, or PyTorch tensors. |
required |
force
|
bool
|
Force conversion even if it requires copying data. Defaults to False. |
False
|
Returns:
| Type | Description |
|---|---|
Union[ndarray, ScalarLike, None]
|
Union[np.ndarray, ScalarLike, None]: Numpy array representation of the input, or the original scalar value, or None if input was None. |
Raises:
| Type | Description |
|---|---|
ValueError
|
If unable to convert a list to array. |
Examples:
>>> import torch
>>> import numpy as np
>>> grab_tensor([1, 2, 3])
array([1, 2, 3])
>>> grab_tensor(torch.tensor([1, 2, 3]))
array([1, 2, 3])
>>> grab_tensor(np.array([1, 2, 3]))
array([1, 2, 3])
Source code in src/ezpz/utils/__init__.py
def grab_tensor(
x: Any, force: bool = False
) -> Union[np.ndarray, ScalarLike, None]:
"""Convert various tensor/array-like objects to numpy arrays.
This function converts different types of array-like objects (tensors, lists, etc.)
to numpy arrays for consistent handling. Supports PyTorch tensors, numpy arrays,
and nested lists.
Args:
x (Any): The object to convert to a numpy array. Can be None, scalar values,
lists, numpy arrays, or PyTorch tensors.
force (bool, optional): Force conversion even if it requires copying data.
Defaults to False.
Returns:
Union[np.ndarray, ScalarLike, None]: Numpy array representation of the input,
or the original scalar value, or None if input was None.
Raises:
ValueError: If unable to convert a list to array.
Examples:
>>> import torch
>>> import numpy as np
>>> grab_tensor([1, 2, 3])
array([1, 2, 3])
>>> grab_tensor(torch.tensor([1, 2, 3]))
array([1, 2, 3])
>>> grab_tensor(np.array([1, 2, 3]))
array([1, 2, 3])
"""
if x is None:
return None
if isinstance(x, (int, float, bool, np.floating)):
return x
if isinstance(x, tuple):
x = list(x)
if isinstance(x, list):
if len(x) == 0:
return np.array([])
if isinstance(x[0], torch.Tensor):
return grab_tensor(torch.stack(x))
if isinstance(x[0], np.ndarray):
return np.stack(x)
if isinstance(x[0], (int, float, bool, np.floating)):
return np.array(x)
if isinstance(x[0], (tuple, list)):
return np.array(x)
else:
raise ValueError(f"Unable to convert list: \n {x=}\n to array")
# else:
# try:
# import tensorflow as tf # type:ignore
# except (ImportError, ModuleNotFoundError) as exc:
# raise exc
# if isinstance(x[0], tf.Tensor):
# return grab_tensor(tf.stack(x))
elif isinstance(x, np.ndarray):
return x
elif isinstance(x, torch.Tensor):
return x.numpy(force=force)
# return x.detach().cpu().numpy()
elif callable(getattr(x, "numpy", None)):
assert callable(getattr(x, "numpy"))
return x.numpy(force=force)
is_memory_metric_key(key)
⚓︎
True if key is one of the 4 mem_* metrics (raw OR aggregated).
Matches both
- raw:
mem_alloc,train/mem_peak_reserved, etc. - aggregated:
mem_alloc/mean,train/mem_alloc/max, etc. (History._compute_distributed_metrics emits these per rank.)
Does NOT match unrelated keys that happen to contain mem_ —
e.g. mem_loss or memo_field — because we anchor on the
full base name + aggregation suffix.
Source code in src/ezpz/utils/__init__.py
def is_memory_metric_key(key: str) -> bool:
"""True if *key* is one of the 4 mem_* metrics (raw OR aggregated).
Matches both:
- raw: ``mem_alloc``, ``train/mem_peak_reserved``, etc.
- aggregated: ``mem_alloc/mean``, ``train/mem_alloc/max``, etc.
(History._compute_distributed_metrics emits these per rank.)
Does NOT match unrelated keys that happen to contain ``mem_`` —
e.g. ``mem_loss`` or ``memo_field`` — because we anchor on the
full base name + aggregation suffix.
"""
for base in _MEMORY_METRIC_BASES:
for suffix in _AGGREGATION_SUFFIXES:
target = f"{base}{suffix}"
# endswith() catches both 'mem_alloc' and 'train/mem_alloc';
# the explicit base+suffix list keeps 'mem_loss' from matching.
if key == target or key.endswith(f"/{target}"):
return True
return False
make_tarfile(output_filename, source_dir)
⚓︎
Create a gzipped tar archive of source_dir at output_filename.
Normalizes the output to end in .tar.gz, then runs
tar -czvf <out> -C <parent> <dirname>. Uses subprocess (not
os.system + f-string) so paths with spaces or shell-meta characters
don't break or get reinterpreted.
Source code in src/ezpz/utils/__init__.py
def make_tarfile(
output_filename: str,
source_dir: str | os.PathLike | Path,
) -> str:
"""Create a gzipped tar archive of *source_dir* at *output_filename*.
Normalizes the output to end in `.tar.gz`, then runs
``tar -czvf <out> -C <parent> <dirname>``. Uses subprocess (not
os.system + f-string) so paths with spaces or shell-meta characters
don't break or get reinterpreted.
"""
output_filename = (
output_filename.replace(".tar", "").replace(".gz", "") + ".tar.gz"
)
srcfp = Path(source_dir).absolute().resolve()
dirname = srcfp.name
cmd = [
"tar", "-czvf", output_filename,
"--directory", str(srcfp.parent), dirname,
]
logger.info(f"Creating tarball at {output_filename} from {source_dir}")
logger.info("Executing: %s", " ".join(cmd))
subprocess.run(cmd, check=True)
return output_filename
model_summary(model, verbose=False, depth=1, input_size=None)
⚓︎
Print a summary of the model using torchinfo.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model
|
Any
|
The model to summarize. |
required |
verbose
|
bool
|
Whether to print the summary. Default: |
False
|
depth
|
int
|
The depth of the summary. Default: |
1
|
input_size
|
Optional[Sequence[int]]
|
The input size for the model. Default: |
None
|
Returns:
| Type | Description |
|---|---|
ModelStatistics | None
|
ModelStatistics | None: The model summary if torchinfo is available, otherwise None. |
Source code in src/ezpz/utils/__init__.py
def model_summary(
model: Any,
verbose: bool = False,
depth: int = 1,
input_size: Optional[Sequence[int]] = None,
) -> ModelStatistics | None:
"""
Print a summary of the model using torchinfo.
Args:
model: The model to summarize.
verbose (bool): Whether to print the summary. Default: ``False``.
depth (int): The depth of the summary. Default: ``1``.
input_size (Optional[Sequence[int]]): The input size for the model. Default: ``None``.
Returns:
ModelStatistics | None: The model summary if torchinfo is available, otherwise None.
"""
try:
from torchinfo import summary
return summary(
model,
input_size=input_size,
depth=depth,
verbose=verbose,
)
# logger.info(f'\n{summary_str}')
except (ImportError, ModuleNotFoundError):
logger.warning(
"torchinfo not installed, unable to print model summary!"
)
reset_peak_memory_stats(device)
⚓︎
Reset peak-memory counters on device. No-op on CPU/MPS.
Source code in src/ezpz/utils/__init__.py
def reset_peak_memory_stats(device: "torch.device | int | str") -> None:
"""Reset peak-memory counters on ``device``. No-op on CPU/MPS."""
dtype = _device_type(device)
if dtype == "cuda":
torch.cuda.reset_peak_memory_stats(device)
return
if dtype == "xpu":
try:
torch.xpu.reset_peak_memory_stats(device)
except (ImportError, AttributeError):
pass
summarize_dict(d, precision=6, keys_to_skip=None)
⚓︎
Summarize a dictionary into a string with formatted key-value pairs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
d
|
dict
|
The dictionary to summarize. |
required |
precision
|
int
|
The precision for floating point values. Default: |
6
|
Returns:
| Name | Type | Description |
|---|---|---|
str |
str
|
A string representation of the dictionary with formatted key-value pairs. |
Source code in src/ezpz/utils/__init__.py
def summarize_dict(
d: dict,
precision: int = 6,
keys_to_skip: Iterable | None = None,
) -> str:
"""
Summarize a dictionary into a string with formatted key-value pairs.
Args:
d (dict): The dictionary to summarize.
precision (int): The precision for floating point values. Default: ``6``.
Returns:
str: A string representation of the dictionary with formatted key-value pairs.
"""
keys_to_skip = [] if keys_to_skip is None else keys_to_skip
return " ".join(
[
format_pair(k, v, precision=precision)
for k, v in d.items()
if k not in keys_to_skip
]
)
write_deepspeed_zero12_auto_config(zero_stage=1, output_dir=None)
⚓︎
Write a deepspeed zero1 auto config to the output directory.
Source code in src/ezpz/utils/__init__.py
def write_deepspeed_zero12_auto_config(
zero_stage: int = 1, output_dir: Optional[PathLike] = None
) -> dict:
"""
Write a deepspeed zero1 auto config to the output directory.
"""
import json
ds_config = {
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 1,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": True,
"wandb": {"enabled": True},
"bf16": {"enabled": True},
"flops_profiler": {
"enabled": True,
"profile_step": 1,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": True,
"adam_w_mode": True,
},
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto",
},
},
"zero_optimization": {
"stage": zero_stage,
"allgather_partitions": True,
"allgather_bucket_size": 2e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": "auto",
"contiguous_gradients": True,
},
}
if output_dir is None:
output_dir = Path(os.getcwd()).joinpath("ds_configs")
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
outfile = output_dir.joinpath(
f"deepspeed_zero{zero_stage}_auto_config.json"
)
logger.info(
f"Saving DeepSpeed ZeRO Stage {zero_stage} "
f"auto config to: {outfile.as_posix()}"
)
with outfile.open("w") as f:
json.dump(
ds_config,
fp=f,
indent=4,
)
return ds_config
write_deepspeed_zero3_auto_config(zero_stage=3, output_dir=None)
⚓︎
Write a deepspeed zero1 auto config to the output directory.
Source code in src/ezpz/utils/__init__.py
def write_deepspeed_zero3_auto_config(
zero_stage: int = 3, output_dir: Optional[PathLike] = None
) -> dict:
"""
Write a deepspeed zero1 auto config to the output directory.
"""
import json
ds_config = {
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 1,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": True,
"wandb": {"enabled": True},
"bf16": {"enabled": True},
"flops_profiler": {
"enabled": True,
"profile_step": 1,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"torch_adam": True,
"adam_w_mode": True,
},
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto",
},
},
"zero_optimization": {
"stage": zero_stage,
"allgather_partitions": True,
"allgather_bucket_size": 2e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": "auto",
"contiguous_gradients": True,
},
}
if output_dir is None:
output_dir = Path(os.getcwd()).joinpath("ds_configs")
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True, parents=True)
outfile = output_dir.joinpath(
f"deepspeed_zero{zero_stage}_auto_config.json"
)
logger.info(
f"Saving DeepSpeed ZeRO Stage {zero_stage} "
f"auto config to: {outfile.as_posix()}"
)
with outfile.open("w") as f:
json.dump(
ds_config,
fp=f,
indent=4,
)
return ds_config
write_generic_deepspeed_config(gradient_accumulation_steps=1, gradient_clipping='auto', steps_per_print=10, train_batch_size='auto', train_micro_batch_size_per_gpu='auto', wall_clock_breakdown=False, wandb=None, bf16=None, fp16=None, flops_profiler=None, optimizer=None, scheduler=None, zero_optimization=None)
⚓︎
Write a generic deepspeed config to the output directory.
Source code in src/ezpz/utils/__init__.py
def write_generic_deepspeed_config(
gradient_accumulation_steps: int = 1,
gradient_clipping: str | float = "auto",
steps_per_print: int = 10,
train_batch_size: str = "auto",
train_micro_batch_size_per_gpu: str = "auto",
wall_clock_breakdown: bool = False,
wandb: Optional[dict] = None,
bf16: Optional[dict] = None,
fp16: Optional[dict] = None,
flops_profiler: Optional[dict] = None,
optimizer: Optional[dict] = None,
scheduler: Optional[dict] = None,
zero_optimization: Optional[dict] = None,
):
"""
Write a generic deepspeed config to the output directory.
"""
ds_config = {
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_clipping": gradient_clipping,
"steps_per_print": steps_per_print,
"train_batch_size": train_batch_size,
"train_micro_batch_size_per_gpu": train_micro_batch_size_per_gpu,
"wall_clock_breakdown": wall_clock_breakdown,
"wandb": wandb,
"bf16": bf16,
"fp16": fp16,
"flops_profiler": flops_profiler,
"optimizer": optimizer,
"scheduler": scheduler,
"zero_optimization": zero_optimization,
}
return ds_config