
ezpz.configs⚓︎

Configuration dataclasses and utility functions for training, DeepSpeed, and HuggingFace integration.

TrainConfig⚓︎

High-level training configuration:

from ezpz.configs import TrainConfig

config = TrainConfig(
    seed=42,
    dtype="bf16",
    use_wandb=True,
    wandb_project_name="my-project",
    ds_config_path="./ds_config.json",
)
| Field | Type | Default | Description |
| --- | --- | --- | --- |
| gas | int | 1 | Gradient accumulation steps |
| use_wandb | bool | False | Enable Weights & Biases logging |
| seed | int \| None | None | Random seed |
| port | str \| None | None | Rendezvous port |
| dtype | Any \| None | None | Data type for training |
| load_from | str \| None | None | Path to load checkpoint from |
| save_to | str \| None | None | Path to save checkpoint to |
| ds_config_path | str \| None | None | Path to DeepSpeed config |
| wandb_project_name | str \| None | None | W&B project name |
| ngpus | int \| None | None | Number of GPUs to use |

Unknown keyword arguments passed to the constructor are captured in self.extras rather than raising an error.
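This capture behaviour can be illustrated with a minimal stand-in class (a sketch of the pattern, not the real TrainConfig):

```python
class ExtrasCapture:
    """Minimal stand-in mirroring TrainConfig's extras-capturing __init__."""

    _fields = {"gas": 1, "use_wandb": False}  # declared fields and defaults

    def __init__(self, **kwargs):
        for name, default in self._fields.items():
            setattr(self, name, kwargs.pop(name, default))
        # Anything left over is stored (and set as an attribute)
        # rather than raising TypeError.
        self.extras = dict(kwargs)
        for key, value in kwargs.items():
            setattr(self, key, value)


cfg = ExtrasCapture(gas=4, my_custom_flag=True)
print(cfg.gas)     # 4
print(cfg.extras)  # {'my_custom_flag': True}
```

This makes the config forgiving when driven from loosely-typed sources such as CLI flags or YAML files.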

ZeroConfig⚓︎

DeepSpeed ZeRO optimizer configuration with all ZeRO stage options:

from ezpz.configs import ZeroConfig

zero = ZeroConfig(stage=2, overlap_comm=True, contiguous_gradients=True)
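One common pattern is to turn such a dataclass into the zero_optimization section of a DeepSpeed config dict, dropping unset (None) fields so DeepSpeed applies its own defaults. A sketch using a minimal stand-in dataclass (field names taken from the real ZeroConfig; the stand-in itself is illustrative):

```python
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class ZeroConfigSketch:
    """Stand-in with a few of the fields the real ZeroConfig declares."""

    stage: int = 0
    overlap_comm: Optional[bool] = None
    contiguous_gradients: Optional[bool] = None


zero = ZeroConfigSketch(stage=2, overlap_comm=True, contiguous_gradients=True)

# Keep only explicitly-set options; DeepSpeed fills in the rest.
zero_section = {k: v for k, v in asdict(zero).items() if v is not None}
ds_config = {"train_batch_size": 32, "zero_optimization": zero_section}
print(ds_config["zero_optimization"])
# {'stage': 2, 'overlap_comm': True, 'contiguous_gradients': True}
```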

HuggingFace Configs⚓︎

HfModelArguments⚓︎

Configuration for HuggingFace model loading:

from ezpz.configs import HfModelArguments

model_args = HfModelArguments(
    model_name_or_path="gpt2",
    torch_dtype="bfloat16",
    use_fast_tokenizer=True,
)

HfDataTrainingArguments⚓︎

Configuration for HuggingFace dataset loading and preprocessing:

from ezpz.configs import HfDataTrainingArguments

data_args = HfDataTrainingArguments(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    block_size=1024,
)

Vision Transformer Configs⚓︎

ViTConfig⚓︎

Standard Vision Transformer configuration:

from ezpz.configs import ViTConfig

vit = ViTConfig(
    img_size=224,
    patch_size=16,
    depth=12,
    num_heads=12,
    hidden_dim=768,
    num_classes=10,
)
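ViTConfig derives the patch-sequence length in __post_init__ as (img_size // patch_size) ** 2; with the defaults above:

```python
img_size, patch_size = 224, 16
seq_len = (img_size // patch_size) ** 2  # 14 patches per side, squared
print(seq_len)  # 196
```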

timmViTConfig⚓︎

Timm-compatible Vision Transformer configuration with additional training parameters:

from ezpz.configs import timmViTConfig

vit = timmViTConfig(batch_size=128, head_dim=64)
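With timmViTConfig's defaults (num_heads=16, head_dim=64), the per-head dimensions multiply out to the default hidden_dim of 1024:

```python
num_heads, head_dim = 16, 64
hidden_dim = num_heads * head_dim
print(hidden_dim)  # 1024
```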

TrainArgs⚓︎

Training hyperparameters dataclass used by the example scripts:

from ezpz.configs import TrainArgs

args = TrainArgs(
    batch_size=32,
    max_iters=1000,
    fsdp=True,
    dtype="bf16",
    compile=True,
)

Utility Functions⚓︎

get_scheduler()⚓︎

Detect the active job scheduler from environment variables:

from ezpz.configs import get_scheduler

scheduler = get_scheduler()  # Returns "PBS", "SLURM", or falls back to hostname-based detection
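The detection itself is straightforward: schedulers export well-known variables into the job environment. A hypothetical sketch of the idea (the real get_scheduler may check different variables and performs an additional hostname-based fallback):

```python
import os


def detect_scheduler(env=None) -> str:
    """Hypothetical mirror of scheduler detection via environment variables."""
    env = os.environ if env is None else env
    if "PBS_JOBID" in env or "PBS_NODEFILE" in env:
        return "PBS"
    if "SLURM_JOB_ID" in env:
        return "SLURM"
    return "UNKNOWN"  # the real helper falls back to hostname-based detection


print(detect_scheduler({"SLURM_JOB_ID": "12345"}))  # SLURM
```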

print_config_tree()⚓︎

Display an OmegaConf DictConfig as a rich tree in the terminal:

from ezpz.configs import print_config_tree

print_config_tree(cfg, resolve=True, style="tree")

Configuration dataclasses, logging setup, and path constants for ezpz.

Provides TrainConfig, ZeroConfig, ViTConfig, and related dataclasses used to configure distributed training runs, DeepSpeed ZeRO optimisation, and vision model architectures. Also exposes directory constants (LOGS_DIR, OUTPUTS_DIR) and helpers for scheduler detection and DeepSpeed config loading.

HfDataTrainingArguments dataclass ⚓︎

Arguments pertaining to the data we are going to feed our model for training and evaluation.

Source code in src/ezpz/configs.py
@dataclass
class HfDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the dataset to use (via the datasets library)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    train_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the train split (via the datasets library)."
        },
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the train split to use (via the datasets library)."
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation split to use (via the datasets library)."
        },
    )
    validation_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the validation split (via the datasets library)."
        },
    )
    test_split_name: Optional[str] = field(
        default="test",
        metadata={
            "help": "The name of the test split to use (via the datasets library)."
        },
    )
    test_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the test split (via the datasets library)."
        },
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(
        default=False, metadata={"help": "Enable streaming mode"}
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing."
        },
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={
            "help": "Whether to keep line breaks when using TXT files or not."
        },
    )

    def __post_init__(self):
        """Validate dataset arguments and ensure required files are present."""
        from transformers.utils.versions import require_version

        if self.streaming:
            require_version(
                "datasets>=2.0.0",
                "The streaming feature requires `datasets>=2.0.0`",
            )

        if (
            self.dataset_name is None
            and self.data_path is None
            and self.train_file is None
            and self.validation_file is None
        ):
            raise ValueError(
                "You must specify at least one of the following: "
                "a dataset name, a data path, a training file, or a validation file."
            )

        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in [
                "csv",
                "json",
                "txt",
            ], "`train_file` should be a csv, a json or a txt file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in [
                "csv",
                "json",
                "txt",
            ], "`validation_file` should be a csv, a json or a txt file."

__post_init__() ⚓︎

Validate dataset arguments and ensure required files are present.

Source code in src/ezpz/configs.py
def __post_init__(self):
    """Validate dataset arguments and ensure required files are present."""
    from transformers.utils.versions import require_version

    if self.streaming:
        require_version(
            "datasets>=2.0.0",
            "The streaming feature requires `datasets>=2.0.0`",
        )

    if (
        self.dataset_name is None
        and self.data_path is None
        and self.train_file is None
        and self.validation_file is None
    ):
        raise ValueError(
            "You must specify at least one of the following: "
            "a dataset name, a data path, a training file, or a validation file."
        )

    if self.train_file is not None:
        extension = self.train_file.split(".")[-1]
        assert extension in [
            "csv",
            "json",
            "txt",
        ], "`train_file` should be a csv, a json or a txt file."
    if self.validation_file is not None:
        extension = self.validation_file.split(".")[-1]
        assert extension in [
            "csv",
            "json",
            "txt",
        ], "`validation_file` should be a csv, a json or a txt file."

HfModelArguments dataclass ⚓︎

Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

Source code in src/ezpz/configs.py
@dataclass
class HfModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    wandb_project_name: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The name of the wandb project to use. If not specified, will use the model name."
            )
        },
    )

    model_name_or_path: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str | None] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join("https://huggingface.co/docs/transformers/en/models")
        },
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        """Validate mutually exclusive Hugging Face model configuration options."""
        if self.config_overrides is not None and (
            self.config_name is not None or self.model_name_or_path is not None
        ):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

__post_init__() ⚓︎

Validate mutually exclusive Hugging Face model configuration options.

Source code in src/ezpz/configs.py
def __post_init__(self):
    """Validate mutually exclusive Hugging Face model configuration options."""
    if self.config_overrides is not None and (
        self.config_name is not None or self.model_name_or_path is not None
    ):
        raise ValueError(
            "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
        )
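The same check can be run standalone; this mirror of the logic raises exactly when config_overrides is combined with either of the other two options:

```python
def validate_overrides(config_overrides=None, config_name=None,
                       model_name_or_path=None) -> None:
    """Standalone mirror of HfModelArguments.__post_init__'s check."""
    if config_overrides is not None and (
        config_name is not None or model_name_or_path is not None
    ):
        raise ValueError(
            "--config_overrides can't be used in combination with "
            "--config_name or --model_name_or_path"
        )


validate_overrides(config_overrides="n_embd=10")  # fine: training from scratch
validate_overrides(model_name_or_path="gpt2")     # fine: no overrides
try:
    validate_overrides(config_overrides="n_embd=10", model_name_or_path="gpt2")
except ValueError as exc:
    print("raised:", exc)
```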

TrainArgs dataclass ⚓︎

Arguments for vision model training runs.

Source code in src/ezpz/configs.py
@dataclass
class TrainArgs:
    """Arguments for vision model training runs."""

    img_size: int
    """Input image resolution."""

    batch_size: int
    """Training batch size per device."""

    num_heads: int
    """Number of self-attention heads."""

    head_dim: int
    """Dimension per attention head."""

    depth: int
    """Number of transformer encoder layers."""

    patch_size: int
    """Size of each image patch."""

    dtype: str
    """Data type string (e.g. ``"bf16"``, ``"fp32"``)."""

    compile: bool
    """Enable ``torch.compile`` for the model."""

    attn_type: str
    """Attention backend (``"native"``, ``"sdpa"``, ``"flash"``, etc.)."""

    warmup: int | float
    """Learning rate warmup steps or fraction of total steps."""

    num_workers: int
    """Number of data loader workers."""

    max_iters: int
    """Maximum training iterations."""

    fsdp: Optional[bool] = False
    """Enable Fully Sharded Data Parallel wrapping."""

    format: Optional[str] = field(default_factory=str)
    """Output format string for logging."""

    cuda_sdpa_backend: Optional[str] = field(default_factory=str)
    """Override the CUDA SDPA backend selection."""

attn_type instance-attribute ⚓︎

Attention backend ("native", "sdpa", "flash", etc.).

batch_size instance-attribute ⚓︎

Training batch size per device.

compile instance-attribute ⚓︎

Enable torch.compile for the model.

cuda_sdpa_backend = field(default_factory=str) class-attribute instance-attribute ⚓︎

Override the CUDA SDPA backend selection.

depth instance-attribute ⚓︎

Number of transformer encoder layers.

dtype instance-attribute ⚓︎

Data type string (e.g. "bf16", "fp32").

format = field(default_factory=str) class-attribute instance-attribute ⚓︎

Output format string for logging.

fsdp = False class-attribute instance-attribute ⚓︎

Enable Fully Sharded Data Parallel wrapping.

head_dim instance-attribute ⚓︎

Dimension per attention head.

img_size instance-attribute ⚓︎

Input image resolution.

max_iters instance-attribute ⚓︎

Maximum training iterations.

num_heads instance-attribute ⚓︎

Number of self-attention heads.

num_workers instance-attribute ⚓︎

Number of data loader workers.

patch_size instance-attribute ⚓︎

Size of each image patch.

warmup instance-attribute ⚓︎

Learning rate warmup steps or fraction of total steps.

TrainConfig dataclass ⚓︎

Bases: BaseConfig

High-level training options shared by ezpz scripts.

Source code in src/ezpz/configs.py
@dataclass(init=False)
@rich.repr.auto
class TrainConfig(BaseConfig):
    """High-level training options shared by ezpz scripts."""

    gas: int = 1
    """Gradient accumulation steps."""

    use_wandb: bool = False
    """Enable Weights & Biases logging."""

    seed: Optional[int] = None
    """Random seed for reproducibility. ``None`` means no explicit seeding."""

    port: Optional[str] = None
    """Port for the distributed rendezvous endpoint."""

    dtype: Optional[Any] = None
    """Default data type for model parameters (e.g. ``"bf16"``)."""

    load_from: Optional[str] = None
    """Path to a checkpoint to resume training from."""

    save_to: Optional[str] = None
    """Directory where checkpoints will be saved."""

    ds_config_path: Optional[str] = None
    """Path to a DeepSpeed JSON/YAML config file."""

    wandb_project_name: Optional[str] = None
    """W&B project name. Auto-set from ``WANDB_PROJECT`` when ``use_wandb`` is True."""

    ngpus: Optional[int] = None
    """Number of GPUs to use. ``None`` means use all available."""

    extras: dict[str, Any] = field(
        default_factory=dict, init=False, repr=False
    )
    """Extra keyword arguments not matching any declared field."""

    def __init__(self, **kwargs: Any) -> None:
        """Populate known fields while capturing any extras in ``self.extras``."""
        extras: dict[str, Any] = {}
        for name, field_def in self.__dataclass_fields__.items():
            if name == "extras":
                continue
            if name in kwargs:
                value = kwargs.pop(name)
            elif field_def.default is not MISSING:
                value = field_def.default
            elif field_def.default_factory is not MISSING:  # type: ignore[attr-defined]
                value = field_def.default_factory()  # type: ignore[misc]
            else:
                value = None
            setattr(self, name, value)

        for key, value in kwargs.items():
            setattr(self, key, value)
            extras[key] = value
        self.extras = extras
        self.__post_init__()

    def to_str(self) -> str:
        """Return a compact identifier for this config."""
        parts = [f"gas-{self.gas}"]
        if self.seed is not None:
            parts.append(f"seed-{self.seed}")
        return "_".join(parts)

    def __post_init__(self):
        """Validate configuration after initialisation."""
        if self.use_wandb and self.wandb_project_name is None:
            self.wandb_project_name = os.environ.get(
                "WANDB_PROJECT", os.environ.get("WB_PROJECT", "ezpz")
            )

ds_config_path = None class-attribute instance-attribute ⚓︎

Path to a DeepSpeed JSON/YAML config file.

dtype = None class-attribute instance-attribute ⚓︎

Default data type for model parameters (e.g. "bf16").

extras = extras class-attribute instance-attribute ⚓︎

Extra keyword arguments not matching any declared field.

gas = 1 class-attribute instance-attribute ⚓︎

Gradient accumulation steps.

load_from = None class-attribute instance-attribute ⚓︎

Path to a checkpoint to resume training from.

ngpus = None class-attribute instance-attribute ⚓︎

Number of GPUs to use. None means use all available.

port = None class-attribute instance-attribute ⚓︎

Port for the distributed rendezvous endpoint.

save_to = None class-attribute instance-attribute ⚓︎

Directory where checkpoints will be saved.

seed = None class-attribute instance-attribute ⚓︎

Random seed for reproducibility. None means no explicit seeding.

use_wandb = False class-attribute instance-attribute ⚓︎

Enable Weights & Biases logging.

wandb_project_name = None class-attribute instance-attribute ⚓︎

W&B project name. Auto-set from WANDB_PROJECT when use_wandb is True.

__getitem__(key) ⚓︎

Provide dictionary-style indexing for convenience.

Source code in src/ezpz/configs.py
def __getitem__(self, key):
    """Provide dictionary-style indexing for convenience."""
    return super().__getattribute__(key)

__init__(**kwargs) ⚓︎

Populate known fields while capturing any extras in self.extras.

Source code in src/ezpz/configs.py
def __init__(self, **kwargs: Any) -> None:
    """Populate known fields while capturing any extras in ``self.extras``."""
    extras: dict[str, Any] = {}
    for name, field_def in self.__dataclass_fields__.items():
        if name == "extras":
            continue
        if name in kwargs:
            value = kwargs.pop(name)
        elif field_def.default is not MISSING:
            value = field_def.default
        elif field_def.default_factory is not MISSING:  # type: ignore[attr-defined]
            value = field_def.default_factory()  # type: ignore[misc]
        else:
            value = None
        setattr(self, name, value)

    for key, value in kwargs.items():
        setattr(self, key, value)
        extras[key] = value
    self.extras = extras
    self.__post_init__()

__post_init__() ⚓︎

Validate configuration after initialisation.

Source code in src/ezpz/configs.py
def __post_init__(self):
    """Validate configuration after initialisation."""
    if self.use_wandb and self.wandb_project_name is None:
        self.wandb_project_name = os.environ.get(
            "WANDB_PROJECT", os.environ.get("WB_PROJECT", "ezpz")
        )

from_file(fpath) ⚓︎

Populate the configuration from a JSON file.

Source code in src/ezpz/configs.py
def from_file(self, fpath: os.PathLike) -> None:
    """Populate the configuration from a JSON file."""
    with Path(fpath).open("r") as f:
        config = json.load(f)
    self.__init__(**config)

get_config() ⚓︎

Return the configuration as a standard dictionary via asdict.

Source code in src/ezpz/configs.py
def get_config(self) -> dict:
    """Return the configuration as a standard dictionary via ``asdict``."""
    return cast(dict[str, Any], asdict(self))

to_dict() ⚓︎

Return a deep copy of the dataclass __dict__.

Source code in src/ezpz/configs.py
def to_dict(self) -> dict[str, Any]:
    """Return a deep copy of the dataclass ``__dict__``."""
    return cast(dict[str, Any], deepcopy(self.__dict__))

to_file(fpath) ⚓︎

Write the configuration to fpath in JSON format.

Source code in src/ezpz/configs.py
def to_file(self, fpath: os.PathLike) -> None:
    """Write the configuration to ``fpath`` in JSON format."""
    with Path(fpath).open("w") as f:
        json.dump(self.to_json(), f, indent=4)
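Together, to_file and from_file give a JSON round trip. The underlying mechanics look roughly like this, with a plain dict standing in for the config:

```python
import json
import tempfile
from pathlib import Path

config = {"gas": 2, "seed": 7, "use_wandb": False}

with tempfile.TemporaryDirectory() as tmp:
    fpath = Path(tmp) / "train_config.json"
    # Roughly what to_file writes and from_file reads back
    fpath.write_text(json.dumps(config, indent=4))
    loaded = json.loads(fpath.read_text())

print(loaded == config)  # True
```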

to_json() ⚓︎

Return a JSON string representation of the configuration.

Source code in src/ezpz/configs.py
def to_json(self) -> str:
    """Return a JSON string representation of the configuration."""
    # name = (
    #     f'{name=}' if name is not None
    #     else f'{self.__class__.__name__}'
    # )
    return json.dumps(deepcopy(self.__dict__), indent=4)

to_str() ⚓︎

Return a compact identifier for this config.

Source code in src/ezpz/configs.py
def to_str(self) -> str:
    """Return a compact identifier for this config."""
    parts = [f"gas-{self.gas}"]
    if self.seed is not None:
        parts.append(f"seed-{self.seed}")
    return "_".join(parts)

ViTConfig dataclass ⚓︎

Configuration for a Vision Transformer model.

Source code in src/ezpz/configs.py
@dataclass
class ViTConfig:
    """Configuration for a Vision Transformer model."""

    img_size: int = 224
    """Input image resolution (height and width)."""

    patch_size: int = 16
    """Size of each image patch."""

    depth: int = 12
    """Number of transformer encoder layers."""

    num_heads: int = 12
    """Number of self-attention heads."""

    hidden_dim: int = 768
    """Embedding / hidden dimension."""

    mlp_dim: int = 3072
    """Feed-forward network inner dimension."""

    dropout: float = 0.0
    """Dropout rate for embeddings and feed-forward layers."""

    attention_dropout: float = 0.0
    """Dropout rate applied to attention weights."""

    num_classes: int = 10
    """Number of output classification classes."""

    def __post_init__(self):
        self.seq_len = (self.img_size // self.patch_size) ** 2  # 196, default

attention_dropout = 0.0 class-attribute instance-attribute ⚓︎

Dropout rate applied to attention weights.

depth = 12 class-attribute instance-attribute ⚓︎

Number of transformer encoder layers.

dropout = 0.0 class-attribute instance-attribute ⚓︎

Dropout rate for embeddings and feed-forward layers.

hidden_dim = 768 class-attribute instance-attribute ⚓︎

Embedding / hidden dimension.

img_size = 224 class-attribute instance-attribute ⚓︎

Input image resolution (height and width).

mlp_dim = 3072 class-attribute instance-attribute ⚓︎

Feed-forward network inner dimension.

num_classes = 10 class-attribute instance-attribute ⚓︎

Number of output classification classes.

num_heads = 12 class-attribute instance-attribute ⚓︎

Number of self-attention heads.

patch_size = 16 class-attribute instance-attribute ⚓︎

Size of each image patch.

ZeroConfig dataclass ⚓︎

Subset of DeepSpeed ZeRO options exposed via the ezpz CLI.

Source code in src/ezpz/configs.py
@dataclass
class ZeroConfig:
    """Subset of DeepSpeed ZeRO options exposed via the ezpz CLI."""

    stage: int = 0
    """ZeRO optimisation stage (0–3)."""

    allgather_partitions: Optional[bool] = None
    """Use allgather for partition reconstruction."""

    allgather_bucket_size: int = int(5e8)
    """Number of elements allgathered at a time."""

    overlap_comm: Optional[bool] = None
    """Overlap gradient reduction with backward pass."""

    reduce_scatter: bool = True
    """Use reduce-scatter instead of allreduce for gradients."""

    reduce_bucket_size: int = int(5e8)
    """Number of elements reduced at a time."""

    contiguous_gradients: Optional[bool] = None
    """Copy gradients to a contiguous buffer as they are produced."""

    offload_param: Optional[dict] = None
    """Parameter offloading config (e.g. ``{"device": "cpu"}``)."""

    offload_optimizer: Optional[dict] = None
    """Optimizer state offloading config (e.g. ``{"device": "cpu"}``)."""

    stage3_max_live_parameters: int = int(1e9)
    """Max number of parameters resident per GPU before releasing (stage 3)."""

    stage3_max_reuse_distance: int = int(1e9)
    """Max reuse distance for parameters before re-fetching (stage 3)."""

    stage3_prefetch_bucket_size: int = int(5e8)
    """Prefetch bucket size for stage 3 parameter fetching."""

    stage3_param_persistence_threshold: int = int(1e6)
    """Parameters smaller than this stay on the GPU (stage 3)."""

    sub_group_size: Optional[int] = None
    """Number of parameters within a sub-group for ZeRO-3 partitioning."""

    elastic_checkpoint: Optional[dict] = None
    """Configuration for elastic checkpointing."""

    stage3_gather_16bit_weights_on_model_save: Optional[bool] = None
    """Gather full fp16 weights on save for portability (stage 3)."""

    ignore_unused_parameters: Optional[bool] = None
    """Silence errors about parameters not used in the forward pass."""

    round_robin_gradients: Optional[bool] = None
    """Distribute gradient partitions round-robin across ranks."""

    zero_hpz_partition_size: Optional[int] = None
    """Hierarchical partitioning group size (ZeRO++)."""

    zero_quantized_weights: Optional[bool] = None
    """Enable weight quantisation for communication (ZeRO++)."""

    zero_quantized_gradients: Optional[bool] = None
    """Enable gradient quantisation for communication (ZeRO++)."""

    log_trace_cache_warnings: Optional[bool] = None
    """Log warnings about trace cache misses."""

allgather_bucket_size = int(500000000.0) class-attribute instance-attribute ⚓︎

Number of elements allgathered at a time.

allgather_partitions = None class-attribute instance-attribute ⚓︎

Use allgather for partition reconstruction.

contiguous_gradients = None class-attribute instance-attribute ⚓︎

Copy gradients to a contiguous buffer as they are produced.

elastic_checkpoint = None class-attribute instance-attribute ⚓︎

Configuration for elastic checkpointing.

ignore_unused_parameters = None class-attribute instance-attribute ⚓︎

Silence errors about parameters not used in the forward pass.

log_trace_cache_warnings = None class-attribute instance-attribute ⚓︎

Log warnings about trace cache misses.

offload_optimizer = None class-attribute instance-attribute ⚓︎

Optimizer state offloading config (e.g. {"device": "cpu"}).

offload_param = None class-attribute instance-attribute ⚓︎

Parameter offloading config (e.g. {"device": "cpu"}).

overlap_comm = None class-attribute instance-attribute ⚓︎

Overlap gradient reduction with backward pass.

reduce_bucket_size = int(500000000.0) class-attribute instance-attribute ⚓︎

Number of elements reduced at a time.

reduce_scatter = True class-attribute instance-attribute ⚓︎

Use reduce-scatter instead of allreduce for gradients.

round_robin_gradients = None class-attribute instance-attribute ⚓︎

Distribute gradient partitions round-robin across ranks.

stage = 0 class-attribute instance-attribute ⚓︎

ZeRO optimisation stage (0–3).

stage3_gather_16bit_weights_on_model_save = None class-attribute instance-attribute ⚓︎

Gather full fp16 weights on save for portability (stage 3).

stage3_max_live_parameters = int(1000000000.0) class-attribute instance-attribute ⚓︎

Max number of parameters resident per GPU before releasing (stage 3).

stage3_max_reuse_distance = int(1000000000.0) class-attribute instance-attribute ⚓︎

Max reuse distance for parameters before re-fetching (stage 3).

stage3_param_persistence_threshold = int(1000000.0) class-attribute instance-attribute ⚓︎

Parameters smaller than this stay on the GPU (stage 3).

stage3_prefetch_bucket_size = int(500000000.0) class-attribute instance-attribute ⚓︎

Prefetch bucket size for stage 3 parameter fetching.

sub_group_size = None class-attribute instance-attribute ⚓︎

Number of parameters within a sub-group for ZeRO-3 partitioning.

zero_hpz_partition_size = None class-attribute instance-attribute ⚓︎

Hierarchical partitioning group size (ZeRO++).

zero_quantized_gradients = None class-attribute instance-attribute ⚓︎

Enable gradient quantisation for communication (ZeRO++).

zero_quantized_weights = None class-attribute instance-attribute ⚓︎

Enable weight quantisation for communication (ZeRO++).

timmViTConfig dataclass ⚓︎

Configuration for a timm-style Vision Transformer (larger defaults).

Source code in src/ezpz/configs.py
@dataclass
class timmViTConfig:
    """Configuration for a timm-style Vision Transformer (larger defaults)."""

    img_size: int = 224
    """Input image resolution."""

    batch_size: int = 128
    """Training batch size per device."""

    num_heads: int = 16
    """Number of self-attention heads."""

    head_dim: int = 64
    """Dimension per attention head."""

    depth: int = 24
    """Number of transformer encoder layers."""

    patch_size: int = 16
    """Size of each image patch."""

    hidden_dim: int = 1024
    """Embedding / hidden dimension."""

    mlp_dim: int = 4096
    """Feed-forward network inner dimension."""

    dropout: float = 0.0
    """Dropout rate for embeddings and feed-forward layers."""

    attention_dropout: float = 0.0
    """Dropout rate applied to attention weights."""

    num_classes: int = 1000
    """Number of output classification classes."""

    def __post_init__(self):
        self.seq_len = (self.img_size // self.patch_size) ** 2  # 196, default

attention_dropout = 0.0 class-attribute instance-attribute ⚓︎

Dropout rate applied to attention weights.

batch_size = 128 class-attribute instance-attribute ⚓︎

Training batch size per device.

depth = 24 class-attribute instance-attribute ⚓︎

Number of transformer encoder layers.

dropout = 0.0 class-attribute instance-attribute ⚓︎

Dropout rate for embeddings and feed-forward layers.

head_dim = 64 class-attribute instance-attribute ⚓︎

Dimension per attention head.

hidden_dim = 1024 class-attribute instance-attribute ⚓︎

Embedding / hidden dimension.

img_size = 224 class-attribute instance-attribute ⚓︎

Input image resolution.

mlp_dim = 4096 class-attribute instance-attribute ⚓︎

Feed-forward network inner dimension.

num_classes = 1000 class-attribute instance-attribute ⚓︎

Number of output classification classes.

num_heads = 16 class-attribute instance-attribute ⚓︎

Number of self-attention heads.

patch_size = 16 class-attribute instance-attribute ⚓︎

Size of each image patch.
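`__post_init__` derives the token sequence length from the image and patch sizes; the arithmetic for the defaults above:

```python
# Patches per side, squared, gives the transformer sequence length
img_size, patch_size = 224, 16
seq_len = (img_size // patch_size) ** 2
print(seq_len)  # 196 for the defaults above
```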

_derive_default_job_name(argv) ⚓︎

Derive a log-friendly job name from command-line arguments.

Looks for python -m <module> patterns and converts the module name (e.g., ezpz.examples.vit) to a filename-safe form (ezpz-examples-vit). Falls back to the executable base name plus the first non-option arg.

Source code in src/ezpz/configs.py
def _derive_default_job_name(argv: Sequence[str]) -> str:
    """Derive a log-friendly job name from command-line arguments.

    Looks for ``python -m <module>`` patterns and converts the module name
    (e.g., ``ezpz.examples.vit``) to a filename-safe form (``ezpz-examples-vit``).
    Falls back to the executable base name plus the first non-option arg.
    """
    if not argv:
        return "ezpz"

    # Detect `python -m <module>` invocations
    for i, arg in enumerate(argv):
        if arg == "-m" and i + 1 < len(argv):
            # Convert e.g. `ezpz.examples.vit` -> `ezpz-examples-vit`
            return argv[i + 1].replace(".", "-")

    # Fallback: use executable basename + optional first subcommand
    base = Path(argv[0]).name
    if base.endswith(".py"):
        base = base[:-3]
    parts = [base]
    if len(argv) > 1 and argv[1] and not argv[1].startswith("-"):
        parts.append(argv[1])
    return "-".join(parts)

cmd_exists(cmd) ⚓︎

Check whether command exists.

Examples:

>>> cmd_exists("ls")
True
>>> cmd_exists("hostname")
True
Source code in src/ezpz/configs.py
def cmd_exists(cmd: str) -> bool:
    """Check whether command exists.

    Examples:
        >>> cmd_exists("ls")
        True
        >>> cmd_exists("hostname")
        True
    """
    return shutil.which(cmd) is not None

command_exists(cmd) ⚓︎

Return True if cmd is available on PATH.

Source code in src/ezpz/configs.py
def command_exists(cmd: str) -> bool:
    """Return ``True`` if ``cmd`` is available on ``PATH``."""
    return shutil.which(cmd) is not None

get_json_log_file() ⚓︎

Return the path to the current JSON log file, if configured.

Source code in src/ezpz/configs.py
def get_json_log_file() -> Optional[Path]:
    """Return the path to the current JSON log file, if configured."""
    return _CURRENT_JSON_LOG_FILE

get_logging_config(rank=None) ⚓︎

Return the logging configuration dictionary used by logging.config.

Parameters:

Name Type Description Default
rank Optional[int]

The distributed rank. If None, will be determined automatically. Used to create per-rank log files when EZPZ_LOG_FROM_ALL_RANKS is set.

None
Source code in src/ezpz/configs.py
def get_logging_config(rank: Optional[int] = None) -> dict:
    """Return the logging configuration dictionary used by ``logging.config``.

    Args:
        rank: The distributed rank. If None, will be determined automatically.
              Used to create per-rank log files when EZPZ_LOG_FROM_ALL_RANKS is set.
    """
    log_from_all_ranks = _get_log_from_all_ranks()
    resolved_rank = _determine_rank(rank)
    cache_key = (resolved_rank, log_from_all_ranks)
    cached_config = _LOGGING_CONFIG_CACHE.get(cache_key)
    if cached_config is not None:
        return cached_config

    config_dict = _build_logging_config(
        rank=resolved_rank,
        log_from_all_ranks=log_from_all_ranks,
    )
    _LOGGING_CONFIG_CACHE[cache_key] = config_dict
    return config_dict

get_scheduler(_scheduler=None) ⚓︎

Infer the active scheduler from environment variables or hostname.

Source code in src/ezpz/configs.py
def get_scheduler(_scheduler: Optional[str] = None) -> str:
    """Infer the active scheduler from environment variables or hostname."""
    from ezpz import get_hostname, get_machine

    if _scheduler is not None:
        log.info(f"Using user-specified scheduler: {_scheduler}")
        return _scheduler.upper()

    if os.environ.get("PBS_JOBID"):
        return "PBS"
    if os.environ.get("SLURM_JOB_ID") or os.environ.get("SLURM_JOBID"):
        return "SLURM"

    machine = get_machine(get_hostname())
    if machine.lower() in [
        "thetagpu",
        "sunspot",
        "polaris",
        "aurora",
        "sophia",
    ]:
        return SCHEDULERS["ALCF"]
    if machine.lower() in ["frontier"]:
        return SCHEDULERS["OLCF"]
    if machine.lower() in ["nersc", "perlmutter"]:
        return SCHEDULERS["NERSC"]
    return "UNKNOWN"

getjobenv_dep() ⚓︎

Return the getjobenv helper path (for debugging).

Source code in src/ezpz/configs.py
def getjobenv_dep():
    """Return the ``getjobenv`` helper path (for debugging)."""
    print(GETJOBENV)
    return GETJOBENV

git_ds_info() ⚓︎

Log the output of DeepSpeed's environment report plus Git metadata.

Source code in src/ezpz/configs.py
def git_ds_info():
    """Log the output of DeepSpeed's environment report plus Git metadata."""
    from deepspeed.env_report import main as ds_report  # type: ignore[import-not-found]

    ds_report()

    # Write out version/git info
    git_hash_cmd = "git rev-parse --short HEAD"
    git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
    if command_exists("git"):
        try:
            result = subprocess.check_output(git_hash_cmd, shell=True)
            git_hash = result.decode("utf-8").strip()
            result = subprocess.check_output(git_branch_cmd, shell=True)
            git_branch = result.decode("utf-8").strip()
        except subprocess.CalledProcessError:
            git_hash = "unknown"
            git_branch = "unknown"
    else:
        git_hash = "unknown"
        git_branch = "unknown"
    log.info(
        f"**** Git info for DeepSpeed:"
        f" git_hash={git_hash} git_branch={git_branch} ****"
    )

load_ds_config(fpath=None) ⚓︎

Load a DeepSpeed configuration file (JSON or YAML).

Source code in src/ezpz/configs.py
def load_ds_config(
    fpath: Optional[Union[str, os.PathLike, Path]] = None,  # type:ignore[reportDeprecated]
) -> dict[str, Any]:
    """Load a DeepSpeed configuration file (JSON or YAML)."""
    cfgpath = Path(fpath) if fpath is not None else Path(DS_CONFIG_PATH)
    if cfgpath.suffix == ".json":
        with cfgpath.open("r") as f:
            ds_config: dict[str, Any] = json.load(f)
        return ds_config
    if cfgpath.suffix in {".yaml", ".yml"}:
        with cfgpath.open("r") as stream:
            ds_config = dict(yaml.safe_load(stream))
        return ds_config
    raise TypeError(
        f"Unexpected file type {cfgpath.suffix!r}; expected .json or .yaml"
    )

print_config(cfg) ⚓︎

Render cfg to the active rich console.

Source code in src/ezpz/configs.py
def print_config(cfg: Union[dict, str]) -> None:
    """Render ``cfg`` to the active rich console."""
    from rich.logging import RichHandler

    from ezpz.log.handler import RichHandler as EnrichHandler

    # Reuse the console from an existing rich logging handler, if any
    console = None
    for handler in log.handlers:
        if isinstance(handler, (RichHandler, EnrichHandler)):
            console = handler.console
            break
    if console is None:
        from ezpz.log.console import get_console

        console = get_console()
    print_json(data=cfg, console=console, indent=4, highlight=True)

print_config_tree(cfg, resolve=True, save_to_file=True, verbose=True, style='tree', print_order=None, highlight=True, outfile=None) ⚓︎

Prints the contents of a DictConfig as a tree structure using the Rich library.

  • cfg: A DictConfig composed by Hydra.
  • print_order: Determines in what order config components are printed.
  • resolve: Whether to resolve reference fields of DictConfig.
  • save_to_file: Whether to export config to the hydra output folder.
Source code in src/ezpz/configs.py
def print_config_tree(
    cfg: DictConfig,
    resolve: bool = True,
    save_to_file: bool = True,
    verbose: bool = True,
    style: str = "tree",
    print_order: Optional[Sequence[str]] = None,
    highlight: bool = True,
    outfile: Optional[Union[str, os.PathLike, Path]] = None,
) -> Tree:
    """Prints the contents of a DictConfig as a tree structure using the Rich
    library.

    - cfg: A DictConfig composed by Hydra.
    - print_order: Determines in what order config components are printed.
    - resolve: Whether to resolve reference fields of DictConfig.
    - save_to_file: Whether to export config to the hydra output folder.
    """
    from rich.console import Console
    from rich.theme import Theme

    from ezpz.log.config import STYLES

    name = cfg.get("_target_", "cfg")
    console = Console(record=True, theme=Theme(STYLES))
    tree = Tree(label=name, highlight=highlight)
    queue = []
    # add fields from `print_order` to queue
    if print_order is not None:
        for field in print_order:
            if field in cfg:
                queue.append(field)
            else:
                log.warning(
                    f"Field '{field}' not found in config. "
                    f"Skipping '{field}' config printing..."
                )
    # add all the other fields to queue (not specified in `print_order`)
    for field in cfg:
        if field not in queue:
            queue.append(field)
    # generate config tree from queue
    for field in queue:
        branch = tree.add(field, highlight=highlight)  # , guide_style=style)
        config_group = cfg[field]
        if isinstance(config_group, DictConfig):
            branch_content = str(
                OmegaConf.to_yaml(config_group, resolve=resolve)
            )
            branch.add(Text(branch_content, style="red"))
        else:
            branch_content = str(config_group)
            branch.add(Text(branch_content, style="blue"))
    if verbose or save_to_file:
        console.print(tree)
        if save_to_file:
            outfpath = (
                Path(os.getcwd()).joinpath("config_tree.log")
                if outfile is None
                else Path(outfile)
            )
            console.save_text(outfpath.as_posix())
    return tree

print_json(json_str=None, console=None, *, data=None, indent=2, highlight=True, skip_keys=False, ensure_ascii=False, check_circular=True, allow_nan=True, default=None, sort_keys=False) ⚓︎

Pretty prints JSON. Output will be valid JSON.

Parameters:

Name Type Description Default
json_str Optional[str]

A string containing JSON.

None
data Any

If json is not supplied, then encode this data.

None
indent Union[None, int, str]

Number of spaces to indent. Defaults to 2.

2
highlight bool

Enable highlighting of output: Defaults to True.

True
skip_keys bool

Skip keys not of a basic type. Defaults to False.

False
ensure_ascii bool

Escape all non-ascii characters. Defaults to False.

False
check_circular bool

Check for circular references. Defaults to True.

True
allow_nan bool

Allow NaN and Infinity values. Defaults to True.

True
default Callable

A callable that converts values that can not be encoded in to something that can be JSON encoded. Defaults to None.

None
sort_keys bool

Sort dictionary keys. Defaults to False.

False
Source code in src/ezpz/configs.py
def print_json(
    json_str: Optional[str] = None,
    console: Optional[Console] = None,
    *,
    data: Any = None,
    indent: Union[None, int, str] = 2,
    highlight: bool = True,
    skip_keys: bool = False,
    ensure_ascii: bool = False,
    check_circular: bool = True,
    allow_nan: bool = True,
    default: Optional[Callable[[Any], Any]] = None,
    sort_keys: bool = False,
) -> None:
    """Pretty prints JSON. Output will be valid JSON.

    Args:
        json_str (Optional[str]): A string containing JSON.
        data (Any): If json is not supplied, then encode this data.
        indent (Union[None, int, str], optional): Number of spaces to indent.
            Defaults to 2.
        highlight (bool, optional): Enable highlighting of output:
            Defaults to True.
        skip_keys (bool, optional): Skip keys not of a basic type.
            Defaults to False.
        ensure_ascii (bool, optional): Escape all non-ascii characters.
            Defaults to False.
        check_circular (bool, optional): Check for circular references.
            Defaults to True.
        allow_nan (bool, optional): Allow NaN and Infinity values.
            Defaults to True.
        default (Callable, optional): A callable that converts values
            that can not be encoded in to something that can be JSON
            encoded.
            Defaults to None.
        sort_keys (bool, optional): Sort dictionary keys. Defaults to False.
    """
    if json_str is None and data is None:
        raise ValueError("Either `json_str` or `data` must be provided.")
    if json_str is not None and data is not None:
        raise ValueError(
            "Only one of `json_str` or `data` should be provided. "
            f"Received both: json_str={json_str!r}, data={data!r}"
        )
    from rich.json import JSON

    from ezpz.log.console import get_console

    console = get_console() if console is None else console
    if json_str is None:
        json_renderable = JSON.from_data(
            data,
            indent=indent,
            highlight=highlight,
            skip_keys=skip_keys,
            ensure_ascii=ensure_ascii,
            check_circular=check_circular,
            allow_nan=allow_nan,
            default=default,
            sort_keys=sort_keys,
        )
    else:
        json_renderable = JSON(
            json_str,
            indent=indent,
            highlight=highlight,
            skip_keys=skip_keys,
            ensure_ascii=ensure_ascii,
            check_circular=check_circular,
            allow_nan=allow_nan,
            default=default,
            sort_keys=sort_keys,
        )
    assert console is not None and isinstance(console, Console)
    console.print(json_renderable)
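The rendering is delegated to `rich.json.JSON`; the plain-stdlib equivalent of the encoding step (`indent`, `sort_keys`, `default`, and friends map onto the same-named `json.dumps` parameters) looks like:

```python
import json

data = {"stage": 2, "overlap_comm": True, "reduce_bucket_size": 500_000_000}
pretty = json.dumps(data, indent=2, sort_keys=False)
print(pretty)
```

As the docstring promises, the output is itself valid JSON and round-trips through `json.loads`.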

savejobenv_dep() ⚓︎

Return the savejobenv helper path (for debugging).

Source code in src/ezpz/configs.py
def savejobenv_dep():
    """Return the ``savejobenv`` helper path (for debugging)."""
    print(SAVEJOBENV)
    return SAVEJOBENV