# ezpz.configs

See `ezpz/configs.py`.

Configuration dataclasses and utility functions for training, DeepSpeed, and HuggingFace integration.
## TrainConfig

High-level training configuration:

```python
from ezpz.configs import TrainConfig

config = TrainConfig(
    seed=42,
    dtype="bf16",
    use_wandb=True,
    wandb_project_name="my-project",
    ds_config_path="./ds_config.json",
)
```
| Field | Type | Default | Description |
|---|---|---|---|
| `gas` | `int` | `1` | Gradient accumulation steps |
| `use_wandb` | `bool` | `False` | Enable Weights & Biases logging |
| `seed` | `int \| None` | `None` | Random seed |
| `port` | `str \| None` | `None` | Rendezvous port |
| `dtype` | `Any \| None` | `None` | Data type for training |
| `load_from` | `str \| None` | `None` | Path to load checkpoint from |
| `save_to` | `str \| None` | `None` | Path to save checkpoint to |
| `ds_config_path` | `str \| None` | `None` | Path to DeepSpeed config |
| `wandb_project_name` | `str \| None` | `None` | W&B project name |
| `ngpus` | `int \| None` | `None` | Number of GPUs to use |
Unknown keyword arguments passed to the constructor are captured in `self.extras` rather than raising an error.
## ZeroConfig

DeepSpeed ZeRO optimizer configuration with all ZeRO stage options:

```python
from ezpz.configs import ZeroConfig

zero = ZeroConfig(stage=2, overlap_comm=True, contiguous_gradients=True)
```
## HuggingFace Configs

### HfModelArguments

Configuration for HuggingFace model loading:

```python
from ezpz.configs import HfModelArguments

model_args = HfModelArguments(
    model_name_or_path="gpt2",
    torch_dtype="bfloat16",
    use_fast_tokenizer=True,
)
```
### HfDataTrainingArguments

Configuration for HuggingFace dataset loading and preprocessing:

```python
from ezpz.configs import HfDataTrainingArguments

data_args = HfDataTrainingArguments(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    block_size=1024,
)
```
## Vision Transformer Configs

### ViTConfig

Standard Vision Transformer configuration:

```python
from ezpz.configs import ViTConfig

vit = ViTConfig(
    img_size=224,
    patch_size=16,
    depth=12,
    num_heads=12,
    hidden_dim=768,
    num_classes=10,
)
```
### timmViTConfig

Timm-compatible Vision Transformer configuration with larger defaults and additional training parameters (e.g. `batch_size`, `head_dim`).
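A quick arithmetic sanity check on those larger defaults (plain Python, no ezpz import; the values are restated from the source listing further below): the embedding width factors exactly into the attention heads, and the MLP inner dimension is the usual 4x expansion:

```python
# timmViTConfig defaults: 16 heads of dim 64 give the 1024-wide embedding.
num_heads, head_dim = 16, 64
hidden_dim, mlp_dim = 1024, 4096

assert hidden_dim == num_heads * head_dim
assert mlp_dim == 4 * hidden_dim
```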
## TrainArgs

Training hyperparameters dataclass used by the example scripts:

```python
from ezpz.configs import TrainArgs

args = TrainArgs(
    batch_size=32,
    max_iters=1000,
    fsdp=True,
    dtype="bf16",
    compile=True,
)
```
## Utility Functions

### get_scheduler()

Detect the active job scheduler from environment variables:

```python
from ezpz.configs import get_scheduler

scheduler = get_scheduler()  # Returns "PBS", "SLURM", or falls back to hostname-based detection
```

### print_config_tree()

Display an OmegaConf `DictConfig` as a rich tree in the terminal.
Configuration dataclasses, logging setup, and path constants for ezpz. Provides `TrainConfig`, `ZeroConfig`, `ViTConfig`, and related dataclasses used to configure distributed training runs, DeepSpeed ZeRO optimisation, and vision model architectures. Also exposes directory constants (`LOGS_DIR`, `OUTPUTS_DIR`) and helpers for scheduler detection and DeepSpeed config loading.
### HfDataTrainingArguments (dataclass)

Arguments pertaining to what data we are going to input our model for training and eval.

Source code in `src/ezpz/configs.py`:
```python
@dataclass
class HfDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the dataset to use (via the datasets library)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    train_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the train split (via the datasets library)."
        },
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the train split to use (via the datasets library)."
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation split to use (via the datasets library)."
        },
    )
    validation_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the validation split (via the datasets library)."
        },
    )
    test_split_name: Optional[str] = field(
        default="test",
        metadata={
            "help": "The name of the test split to use (via the datasets library)."
        },
    )
    test_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the test split (via the datasets library)."
        },
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(
        default=False, metadata={"help": "Enable streaming mode"}
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing."
        },
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={
            "help": "Whether to keep line breaks when using TXT files or not."
        },
    )

    def __post_init__(self):
        """Validate dataset arguments and ensure required files are present."""
        from transformers.utils.versions import require_version

        if self.streaming:
            require_version(
                "datasets>=2.0.0",
                "The streaming feature requires `datasets>=2.0.0`",
            )
        if (
            self.dataset_name is None
            and self.data_path is None
            and self.train_file is None
            and self.validation_file is None
        ):
            raise ValueError(
                "You must specify at least one of the following: "
                "a dataset name, a data path, a training file, or a validation file."
            )
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in [
                "csv",
                "json",
                "txt",
            ], "`train_file` should be a csv, a json or a txt file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in [
                "csv",
                "json",
                "txt",
            ], "`validation_file` should be a csv, a json or a txt file."
```
#### `__post_init__()`

Validate dataset arguments and ensure required files are present (source shown in the class listing above).
### HfModelArguments (dataclass)

Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

Source code in `src/ezpz/configs.py`:
```python
@dataclass
class HfModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    wandb_project_name: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The name of the wandb project to use. If not specified, will use the model name."
            )
        },
    )
    model_name_or_path: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str | None] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join("https://huggingface.co/docs/transformers/en/models")
        },
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        """Validate mutually exclusive Hugging Face model configuration options."""
        if self.config_overrides is not None and (
            self.config_name is not None or self.model_name_or_path is not None
        ):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
```
#### `__post_init__()`

Validate mutually exclusive Hugging Face model configuration options (source shown in the class listing above).
### TrainArgs (dataclass)

Arguments for vision model training runs.

Source code in `src/ezpz/configs.py`:
```python
@dataclass
class TrainArgs:
    """Arguments for vision model training runs."""

    img_size: int
    """Input image resolution."""

    batch_size: int
    """Training batch size per device."""

    num_heads: int
    """Number of self-attention heads."""

    head_dim: int
    """Dimension per attention head."""

    depth: int
    """Number of transformer encoder layers."""

    patch_size: int
    """Size of each image patch."""

    dtype: str
    """Data type string (e.g. ``"bf16"``, ``"fp32"``)."""

    compile: bool
    """Enable ``torch.compile`` for the model."""

    attn_type: str
    """Attention backend (``"native"``, ``"sdpa"``, ``"flash"``, etc.)."""

    warmup: int | float
    """Learning rate warmup steps or fraction of total steps."""

    num_workers: int
    """Number of data loader workers."""

    max_iters: int
    """Maximum training iterations."""

    fsdp: Optional[bool] = False
    """Enable Fully Sharded Data Parallel wrapping."""

    format: Optional[str] = field(default_factory=str)
    """Output format string for logging."""

    cuda_sdpa_backend: Optional[str] = field(default_factory=str)
    """Override the CUDA SDPA backend selection."""
```
### TrainConfig (dataclass)

Bases: `BaseConfig`

High-level training options shared by ezpz scripts.

Source code in `src/ezpz/configs.py`:
```python
@dataclass(init=False)
@rich.repr.auto
class TrainConfig(BaseConfig):
    """High-level training options shared by ezpz scripts."""

    gas: int = 1
    """Gradient accumulation steps."""

    use_wandb: bool = False
    """Enable Weights & Biases logging."""

    seed: Optional[int] = None
    """Random seed for reproducibility. ``None`` means no explicit seeding."""

    port: Optional[str] = None
    """Port for the distributed rendezvous endpoint."""

    dtype: Optional[Any] = None
    """Default data type for model parameters (e.g. ``"bf16"``)."""

    load_from: Optional[str] = None
    """Path to a checkpoint to resume training from."""

    save_to: Optional[str] = None
    """Directory where checkpoints will be saved."""

    ds_config_path: Optional[str] = None
    """Path to a DeepSpeed JSON/YAML config file."""

    wandb_project_name: Optional[str] = None
    """W&B project name. Auto-set from ``WANDB_PROJECT`` when ``use_wandb`` is True."""

    ngpus: Optional[int] = None
    """Number of GPUs to use. ``None`` means use all available."""

    extras: dict[str, Any] = field(
        default_factory=dict, init=False, repr=False
    )
    """Extra keyword arguments not matching any declared field."""

    def __init__(self, **kwargs: Any) -> None:
        """Populate known fields while capturing any extras in ``self.extras``."""
        extras: dict[str, Any] = {}
        for name, field_def in self.__dataclass_fields__.items():
            if name == "extras":
                continue
            if name in kwargs:
                value = kwargs.pop(name)
            elif field_def.default is not MISSING:
                value = field_def.default
            elif field_def.default_factory is not MISSING:  # type: ignore[attr-defined]
                value = field_def.default_factory()  # type: ignore[misc]
            else:
                value = None
            setattr(self, name, value)
        for key, value in kwargs.items():
            setattr(self, key, value)
            extras[key] = value
        self.extras = extras
        self.__post_init__()

    def to_str(self) -> str:
        """Return a compact identifier for this config."""
        parts = [f"gas-{self.gas}"]
        if self.seed is not None:
            parts.append(f"seed-{self.seed}")
        return "_".join(parts)

    def __post_init__(self):
        """Validate configuration after initialisation."""
        if self.use_wandb and self.wandb_project_name is None:
            self.wandb_project_name = os.environ.get(
                "WANDB_PROJECT", os.environ.get("WB_PROJECT", "ezpz")
            )
```
#### Methods

- `__getitem__(key)`
- `__init__(**kwargs)`: populate known fields while capturing any extras in `self.extras` (source shown in the class listing above)
- `__post_init__()`: validate configuration after initialisation
- `from_file(fpath)`
- `get_config()`
- `to_dict()`
- `to_file(fpath)`
- `to_json()`: return a JSON string representation of the configuration
- `to_str()`: return a compact identifier for this config
### ViTConfig (dataclass)

Configuration for a Vision Transformer model.

Source code in `src/ezpz/configs.py`:
```python
@dataclass
class ViTConfig:
    """Configuration for a Vision Transformer model."""

    img_size: int = 224
    """Input image resolution (height and width)."""

    patch_size: int = 16
    """Size of each image patch."""

    depth: int = 12
    """Number of transformer encoder layers."""

    num_heads: int = 12
    """Number of self-attention heads."""

    hidden_dim: int = 768
    """Embedding / hidden dimension."""

    mlp_dim: int = 3072
    """Feed-forward network inner dimension."""

    dropout: float = 0.0
    """Dropout rate for embeddings and feed-forward layers."""

    attention_dropout: float = 0.0
    """Dropout rate applied to attention weights."""

    num_classes: int = 10
    """Number of output classification classes."""

    def __post_init__(self):
        self.seq_len = (self.img_size // self.patch_size) ** 2  # 196, default
```
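`__post_init__` derives the token sequence length from the patch grid. With the defaults above, this is just the arithmetic from that line, restated standalone:

```python
# ViTConfig defaults: a 224x224 image split into 16x16 patches
# gives a 14x14 grid, i.e. 196 tokens.
img_size, patch_size = 224, 16
seq_len = (img_size // patch_size) ** 2
```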
### ZeroConfig (dataclass)

Subset of DeepSpeed ZeRO options exposed via the ezpz CLI.

Source code in `src/ezpz/configs.py`:
```python
@dataclass
class ZeroConfig:
    """Subset of DeepSpeed ZeRO options exposed via the ezpz CLI."""

    stage: int = 0
    """ZeRO optimisation stage (0–3)."""

    allgather_partitions: Optional[bool] = None
    """Use allgather for partition reconstruction."""

    allgather_bucket_size: int = int(5e8)
    """Number of elements allgathered at a time."""

    overlap_comm: Optional[bool] = None
    """Overlap gradient reduction with backward pass."""

    reduce_scatter: bool = True
    """Use reduce-scatter instead of allreduce for gradients."""

    reduce_bucket_size: int = int(5e8)
    """Number of elements reduced at a time."""

    contiguous_gradients: Optional[bool] = None
    """Copy gradients to a contiguous buffer as they are produced."""

    offload_param: Optional[dict] = None
    """Parameter offloading config (e.g. ``{"device": "cpu"}``)."""

    offload_optimizer: Optional[dict] = None
    """Optimizer state offloading config (e.g. ``{"device": "cpu"}``)."""

    stage3_max_live_parameters: int = int(1e9)
    """Max number of parameters resident per GPU before releasing (stage 3)."""

    stage3_max_reuse_distance: int = int(1e9)
    """Max reuse distance for parameters before re-fetching (stage 3)."""

    stage3_prefetch_bucket_size: int = int(5e8)
    """Prefetch bucket size for stage 3 parameter fetching."""

    stage3_param_persistence_threshold: int = int(1e6)
    """Parameters smaller than this stay on the GPU (stage 3)."""

    sub_group_size: Optional[int] = None
    """Number of parameters within a sub-group for ZeRO-3 partitioning."""

    elastic_checkpoint: Optional[dict] = None
    """Configuration for elastic checkpointing."""

    stage3_gather_16bit_weights_on_model_save: Optional[bool] = None
    """Gather full fp16 weights on save for portability (stage 3)."""

    ignore_unused_parameters: Optional[bool] = None
    """Silence errors about parameters not used in the forward pass."""

    round_robin_gradients: Optional[bool] = None
    """Distribute gradient partitions round-robin across ranks."""

    zero_hpz_partition_size: Optional[int] = None
    """Hierarchical partitioning group size (ZeRO++)."""

    zero_quantized_weights: Optional[bool] = None
    """Enable weight quantisation for communication (ZeRO++)."""

    zero_quantized_gradients: Optional[bool] = None
    """Enable gradient quantisation for communication (ZeRO++)."""

    log_trace_cache_warnings: Optional[bool] = None
    """Log warnings about trace cache misses."""
```
### timmViTConfig (dataclass)

Configuration for a timm-style Vision Transformer (larger defaults).

Source code in `src/ezpz/configs.py`:
```python
@dataclass
class timmViTConfig:
    """Configuration for a timm-style Vision Transformer (larger defaults)."""

    img_size: int = 224
    """Input image resolution."""

    batch_size: int = 128
    """Training batch size per device."""

    num_heads: int = 16
    """Number of self-attention heads."""

    head_dim: int = 64
    """Dimension per attention head."""

    depth: int = 24
    """Number of transformer encoder layers."""

    patch_size: int = 16
    """Size of each image patch."""

    hidden_dim: int = 1024
    """Embedding / hidden dimension."""

    mlp_dim: int = 4096
    """Feed-forward network inner dimension."""

    dropout: float = 0.0
    """Dropout rate for embeddings and feed-forward layers."""

    attention_dropout: float = 0.0
    """Dropout rate applied to attention weights."""

    num_classes: int = 1000
    """Number of output classification classes."""

    def __post_init__(self):
        self.seq_len = (self.img_size // self.patch_size) ** 2  # 196, default
```
### _derive_default_job_name(argv)

Derive a log-friendly job name from command-line arguments.

Looks for `python -m <module>` patterns and converts the module name (e.g., `ezpz.examples.vit`) to a filename-safe form (`ezpz-examples-vit`). Falls back to the executable base name plus the first non-option arg.

Source code in `src/ezpz/configs.py`:
```python
def _derive_default_job_name(argv: Sequence[str]) -> str:
    """Derive a log-friendly job name from command-line arguments.

    Looks for ``python -m <module>`` patterns and converts the module name
    (e.g., ``ezpz.examples.vit``) to a filename-safe form (``ezpz-examples-vit``).
    Falls back to the executable base name plus the first non-option arg.
    """
    if not argv:
        return "ezpz"
    # Detect `python -m <module>` invocations
    for i, arg in enumerate(argv):
        if arg == "-m" and i + 1 < len(argv):
            module_name = argv[i + 1]
            return module_name
    # Fallback: use executable basename + optional first subcommand
    base = Path(argv[0]).name
    if base.endswith(".py"):
        base = base[:-3]
    parts = [base]
    if len(argv) > 1 and argv[1] and not argv[1].startswith("-"):
        parts.append(argv[1])
    return "-".join(parts)
```
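The behaviour can be exercised without importing ezpz by re-stating the function verbatim (note that, as written, the `-m` branch returns the module name unmodified; the dot-to-dash conversion mentioned in the docstring is not performed by this code path):

```python
from pathlib import Path
from typing import Sequence


def derive_default_job_name(argv: Sequence[str]) -> str:
    # Standalone copy of _derive_default_job_name above.
    if not argv:
        return "ezpz"
    for i, arg in enumerate(argv):
        if arg == "-m" and i + 1 < len(argv):
            return argv[i + 1]
    base = Path(argv[0]).name
    if base.endswith(".py"):
        base = base[:-3]
    parts = [base]
    if len(argv) > 1 and argv[1] and not argv[1].startswith("-"):
        parts.append(argv[1])
    return "-".join(parts)
```

For example, `["python", "-m", "ezpz.examples.vit"]` yields `"ezpz.examples.vit"`, while `["train.py", "vit"]` falls back to `"train-vit"`.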
### cmd_exists(cmd)

Check whether a command exists.

Examples:
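A minimal sketch of such a check, assuming a `shutil.which`-based implementation (the actual ezpz implementation is not shown here and may differ):

```python
import shutil


def cmd_exists(cmd: str) -> bool:
    # Resolve the command on PATH; None means it is not installed.
    # (Illustrative only; ezpz's own cmd_exists may be implemented differently.)
    return shutil.which(cmd) is not None
```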
### command_exists(cmd)
### get_json_log_file()
### get_logging_config(rank=None)

Return the logging configuration dictionary used by `logging.config`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rank` | `Optional[int]` | The distributed rank. If `None`, will be determined automatically. Used to create per-rank log files when `EZPZ_LOG_FROM_ALL_RANKS` is set. | `None` |

Source code in `src/ezpz/configs.py`:
```python
def get_logging_config(rank: Optional[int] = None) -> dict:
    """Return the logging configuration dictionary used by ``logging.config``.

    Args:
        rank: The distributed rank. If None, will be determined automatically.
            Used to create per-rank log files when EZPZ_LOG_FROM_ALL_RANKS is set.
    """
    log_from_all_ranks = _get_log_from_all_ranks()
    resolved_rank = _determine_rank(rank)
    cache_key = (resolved_rank, log_from_all_ranks)
    cached_config = _LOGGING_CONFIG_CACHE.get(cache_key)
    if cached_config is not None:
        return cached_config
    config_dict = _build_logging_config(
        rank=resolved_rank,
        log_from_all_ranks=log_from_all_ranks,
    )
    _LOGGING_CONFIG_CACHE[cache_key] = config_dict
    return config_dict
```
### get_scheduler(_scheduler=None)

Infer the active scheduler from environment variables or hostname.

Source code in `src/ezpz/configs.py`:
```python
def get_scheduler(_scheduler: Optional[str] = None) -> str:
    """Infer the active scheduler from environment variables or hostname."""
    from ezpz import get_hostname, get_machine

    if _scheduler is not None:
        log.info(f"Using user-specified scheduler: {_scheduler}")
        return _scheduler.upper()
    if os.environ.get("PBS_JOBID"):
        return "PBS"
    if os.environ.get("SLURM_JOB_ID") or os.environ.get("SLURM_JOBID"):
        return "SLURM"
    machine = get_machine(get_hostname())
    if machine.lower() in [
        "thetagpu",
        "sunspot",
        "polaris",
        "aurora",
        "sophia",
    ]:
        return SCHEDULERS["ALCF"]
    if machine.lower() in ["frontier"]:
        return SCHEDULERS["OLCF"]
    if machine.lower() in ["nersc", "perlmutter"]:
        return SCHEDULERS["NERSC"]
    return "UNKNOWN"
```
### getjobenv_dep()
### git_ds_info()

Log the output of DeepSpeed's environment report plus Git metadata.

Source code in `src/ezpz/configs.py`:
```python
def git_ds_info():
    """Log the output of DeepSpeed's environment report plus Git metadata."""
    from deepspeed.env_report import main as ds_report  # type: ignore[import-not-found]

    ds_report()
    # Write out version/git info
    git_hash_cmd = "git rev-parse --short HEAD"
    git_branch_cmd = "git rev-parse --abbrev-ref HEAD"
    if command_exists("git"):
        try:
            result = subprocess.check_output(git_hash_cmd, shell=True)
            git_hash = result.decode("utf-8").strip()
            result = subprocess.check_output(git_branch_cmd, shell=True)
            git_branch = result.decode("utf-8").strip()
        except subprocess.CalledProcessError:
            git_hash = "unknown"
            git_branch = "unknown"
    else:
        git_hash = "unknown"
        git_branch = "unknown"
    log.info(
        f"**** Git info for DeepSpeed:"
        f" git_hash={git_hash} git_branch={git_branch} ****"
    )
```
### load_ds_config(fpath=None)

Load a DeepSpeed configuration file (JSON or YAML).

Source code in `src/ezpz/configs.py`:
def load_ds_config(
fpath: Optional[Union[str, os.PathLike, Path]] = None, # type:ignore[reportDeprecated]
) -> dict[str, Any]:
"""Load a DeepSpeed configuration file (JSON or YAML)."""
fpath = Path(DS_CONFIG_PATH) if fpath is None else f"{fpath}"
cfgpath = Path(fpath)
if cfgpath.suffix == ".json":
with cfgpath.open("r") as f:
ds_config: dict[str, Any] = json.load(f)
return ds_config
if cfgpath.suffix == ".yaml":
with cfgpath.open("r") as stream:
dsconfig: dict[str, Any] = dict(yaml.safe_load(stream))
return dsconfig
raise TypeError("Unexpected FileType")
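A minimal round-trip of the JSON branch (the `.yaml` branch works the same way via `yaml.safe_load`). `load_config` here is a cut-down stand-in for `load_ds_config`, and the config contents are illustrative, not a complete DeepSpeed config:

```python
import json
import tempfile
from pathlib import Path

def load_config(fpath) -> dict:
    """Suffix-dispatched loader, JSON branch only (ezpz also handles .yaml)."""
    cfgpath = Path(fpath)
    if cfgpath.suffix == ".json":
        with cfgpath.open("r") as f:
            return json.load(f)
    raise TypeError("Unexpected FileType")

# Write a toy DeepSpeed-style config, then load it back.
with tempfile.TemporaryDirectory() as tmpdir:
    cfg_file = Path(tmpdir) / "ds_config.json"
    cfg_file.write_text(
        json.dumps({"train_batch_size": 8, "zero_optimization": {"stage": 2}})
    )
    cfg = load_config(cfg_file)

print(cfg["zero_optimization"]["stage"])  # 2
```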
print_config(cfg)
⚓︎
Render cfg to the active rich console.
Source code in src/ezpz/configs.py
def print_config(cfg: Union[dict, str]) -> None:
"""Render ``cfg`` to the active rich console."""
# try:
# from hydra.utils import instantiate
# config = instantiate(cfg)
# except Exception:
# config = OmegaConf.to_container(cfg, resolve=True)
# config = OmegaConf.to_container(cfg, resolve=True)
# if isinstance(cfg, dict):
# jstr = json.dumps(cfg, indent=4)
# else:
# jstr = cfg
from rich.logging import RichHandler
from ezpz.log.handler import RichHandler as EnrichHandler
console = None
for handler in log.handlers:
if isinstance(handler, (RichHandler, EnrichHandler)):
console = handler.console
if console is None:
from ezpz.log.console import get_console
console = get_console()
# console.print_json(data=cfg, indent=4, highlight=True)
print_json(data=cfg, console=console, indent=4, highlight=True)
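The handler-discovery loop (reuse the console of an attached rich handler, else fall back to a fresh one) can be sketched with the stdlib alone. `ConsoleHandler` below is a hypothetical stand-in for `RichHandler`, with strings standing in for console objects:

```python
import logging
import sys

class ConsoleHandler(logging.StreamHandler):
    """Stand-in for rich's RichHandler: exposes a .console attribute."""
    def __init__(self, console):
        super().__init__(sys.stderr)
        self.console = console

logger = logging.getLogger("print_config_demo")
logger.addHandler(ConsoleHandler(console="attached-console"))

# Mirror print_config's discovery: scan attached handlers for a console,
# falling back to a default when none is found.
console = None
for handler in logger.handlers:
    if isinstance(handler, ConsoleHandler):
        console = handler.console
if console is None:
    console = "fresh-console"

print(console)  # attached-console
```

Reusing the logging handler's console keeps output routed through the same rich instance instead of creating a second, competing one.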
print_config_tree(cfg, resolve=True, save_to_file=True, verbose=True, style='tree', print_order=None, highlight=True, outfile=None)
⚓︎
Prints the contents of a DictConfig as a tree structure using the Rich library.
- cfg: A DictConfig composed by Hydra.
- print_order: Determines in what order config components are printed.
- resolve: Whether to resolve reference fields of DictConfig.
- save_to_file: Whether to export config to the hydra output folder.
Source code in src/ezpz/configs.py
def print_config_tree(
cfg: DictConfig,
resolve: bool = True,
save_to_file: bool = True,
verbose: bool = True,
style: str = "tree",
print_order: Optional[Sequence[str]] = None,
highlight: bool = True,
outfile: Optional[Union[str, os.PathLike, Path]] = None,
) -> Tree:
"""Prints the contents of a DictConfig as a tree structure using the Rich
library.
- cfg: A DictConfig composed by Hydra.
- print_order: Determines in what order config components are printed.
- resolve: Whether to resolve reference fields of DictConfig.
- save_to_file: Whether to export config to the hydra output folder.
"""
from rich.console import Console
from rich.theme import Theme
from ezpz.log.config import STYLES
name = cfg.get("_target_", "cfg")
console = Console(record=True, theme=Theme(STYLES))
tree = Tree(label=name, highlight=highlight)
queue = []
# add fields from `print_order` to queue
if print_order is not None:
for field in print_order:
(
queue.append(field)
if field in cfg
else log.warning(
f"Field '{field}' not found in config. "
f"Skipping '{field}' config printing..."
)
)
# add all the other fields to queue (not specified in `print_order`)
for field in cfg:
if field not in queue:
queue.append(field)
# generate config tree from queue
for field in queue:
branch = tree.add(field, highlight=highlight) # , guide_style=style)
config_group = cfg[field]
if isinstance(config_group, DictConfig):
branch_content = str(
OmegaConf.to_yaml(config_group, resolve=resolve)
)
branch.add(Text(branch_content, style="red"))
else:
branch_content = str(config_group)
branch.add(Text(branch_content, style="blue"))
if verbose or save_to_file:
console.print(tree)
if save_to_file:
outfpath = (
Path(os.getcwd()).joinpath("config_tree.log")
if outfile is None
else Path(outfile)
)
console.save_text(outfpath.as_posix())
return tree
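The tree layout is easiest to see with a toy config. This stdlib-only sketch approximates what rich's `Tree` renders (one branch per top-level field, nested mappings expanded beneath it); it is a rough stand-in, not the actual rich output:

```python
def render_tree(name: str, cfg: dict) -> str:
    """Approximate print_config_tree's layout as plain text:
    one branch per field, nested dicts shown as key: value lines."""
    lines = [name]
    for field, value in cfg.items():
        lines.append(f"├── {field}")
        if isinstance(value, dict):
            lines.extend(f"│     {key}: {val}" for key, val in value.items())
        else:
            lines.append(f"│     {value}")
    return "\n".join(lines)

print(render_tree("cfg", {"seed": 42, "optimizer": {"lr": 0.001, "name": "adamw"}}))
```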
print_json(json_str=None, console=None, *, data=None, indent=2, highlight=True, skip_keys=False, ensure_ascii=False, check_circular=True, allow_nan=True, default=None, sort_keys=False)
⚓︎
Pretty prints JSON. Output will be valid JSON.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `json_str` | `Optional[str]` | A string containing JSON. | `None` |
| `data` | `Any` | If `json_str` is not supplied, then encode this data. | `None` |
| `indent` | `Union[None, int, str]` | Number of spaces to indent. | `2` |
| `highlight` | `bool` | Enable highlighting of output. | `True` |
| `skip_keys` | `bool` | Skip keys not of a basic type. | `False` |
| `ensure_ascii` | `bool` | Escape all non-ASCII characters. | `False` |
| `check_circular` | `bool` | Check for circular references. | `True` |
| `allow_nan` | `bool` | Allow NaN and Infinity values. | `True` |
| `default` | `Callable` | A callable that converts values that cannot be JSON-encoded into something that can be. | `None` |
| `sort_keys` | `bool` | Sort dictionary keys. | `False` |
Source code in src/ezpz/configs.py
def print_json(
json_str: Optional[str] = None,
console: Optional[Console] = None,
*,
data: Any = None,
indent: Union[None, int, str] = 2,
highlight: bool = True,
skip_keys: bool = False,
ensure_ascii: bool = False,
check_circular: bool = True,
allow_nan: bool = True,
default: Optional[Callable[[Any], Any]] = None,
sort_keys: bool = False,
) -> None:
"""Pretty prints JSON. Output will be valid JSON.
Args:
json_str (Optional[str]): A string containing JSON.
data (Any): If json is not supplied, then encode this data.
indent (Union[None, int, str], optional): Number of spaces to indent.
Defaults to 2.
highlight (bool, optional): Enable highlighting of output:
Defaults to True.
skip_keys (bool, optional): Skip keys not of a basic type.
Defaults to False.
ensure_ascii (bool, optional): Escape all non-ascii characters.
Defaults to False.
check_circular (bool, optional): Check for circular references.
Defaults to True.
allow_nan (bool, optional): Allow NaN and Infinity values.
Defaults to True.
default (Callable, optional): A callable that converts values
that can not be encoded in to something that can be JSON
encoded.
Defaults to None.
sort_keys (bool, optional): Sort dictionary keys. Defaults to False.
"""
if json_str is None and data is None:
raise ValueError(
"Either `json_str` or `data` must be provided. "
"Did you mean print_json(data={data!r}) ?"
)
if json_str is not None and data is not None:
raise ValueError(
" ".join(
[
"Only one of `json_str` or `data` should be provided.",
"Did you mean print_json(json_str={json_str!r}) ?",
"Or print_json(data={data!r}) ?",
"Received both:",
f"json_str={json_str!r}",
f"data={data!r}",
]
)
)
from rich.json import JSON
from ezpz.log.console import get_console
console = get_console() if console is None else console
if json_str is None:
json_renderable = JSON.from_data(
data,
indent=indent,
highlight=highlight,
skip_keys=skip_keys,
ensure_ascii=ensure_ascii,
check_circular=check_circular,
allow_nan=allow_nan,
default=default,
sort_keys=sort_keys,
)
else:
json_renderable = JSON(
json_str,
indent=indent,
highlight=highlight,
skip_keys=skip_keys,
ensure_ascii=ensure_ascii,
check_circular=check_circular,
allow_nan=allow_nan,
default=default,
sort_keys=sort_keys,
)
assert console is not None and isinstance(console, Console)
log.info(Text(str(json_renderable)).render(console=console))
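The argument contract (exactly one of `json_str` or `data` must be provided) is the part most worth testing. This sketch reproduces it with `json.dumps` standing in for `rich.json.JSON`; `pretty_json` is an illustrative name, not part of ezpz:

```python
import json

def pretty_json(json_str=None, *, data=None, indent=2, sort_keys=False) -> str:
    """Mirror print_json's argument handling, rendering with json.dumps
    instead of rich."""
    if json_str is None and data is None:
        raise ValueError("Either `json_str` or `data` must be provided.")
    if json_str is not None and data is not None:
        raise ValueError("Only one of `json_str` or `data` should be provided.")
    obj = json.loads(json_str) if json_str is not None else data
    return json.dumps(obj, indent=indent, sort_keys=sort_keys)

print(pretty_json(data={"b": 2, "a": 1}, sort_keys=True))
```

Validating the mutually exclusive arguments up front, as the source above does, gives callers an immediate error instead of silently preferring one input over the other.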