ezpz.configs⚓︎
- See ezpz/configs.py
ezpz/configs.py
HfDataTrainingArguments
dataclass
⚓︎
Arguments pertaining to what data we are going to input our model for training and eval.
Source code in src/ezpz/configs.py
@dataclass
class HfDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    data_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the dataset to use (via the datasets library)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    train_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the train split (via the datasets library)."
        },
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the train split to use (via the datasets library)."
        },
    )
    validation_split_name: Optional[str] = field(
        default="validation",
        metadata={
            "help": "The name of the validation split to use (via the datasets library)."
        },
    )
    validation_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the validation split (via the datasets library)."
        },
    )
    test_split_name: Optional[str] = field(
        default="test",
        metadata={
            "help": "The name of the test split to use (via the datasets library)."
        },
    )
    test_split_str: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split string to use for the test split (via the datasets library)."
        },
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(
        default=False, metadata={"help": "Enable streaming mode"}
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing."
        },
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={
            "help": "Whether to keep line breaks when using TXT files or not."
        },
    )

    def __post_init__(self):
        """Validate the dataset arguments.

        Checks that at least one data source (``dataset_name``, ``data_path``,
        ``train_file`` or ``validation_file``) is given, and that any
        ``train_file`` / ``validation_file`` has a ``csv``, ``json`` or ``txt``
        extension. File existence is *not* checked here.

        When ``streaming`` is enabled, also requires ``datasets>=2.0.0`` via
        the ``transformers`` version helper.

        Raises:
            ValueError: if no data source is given, or a train/validation file
                has an unsupported extension.
        """
        if self.streaming:
            # Imported lazily so that non-streaming configs can be built
            # without ``transformers`` installed.
            from transformers.utils.versions import require_version

            require_version(
                "datasets>=2.0.0",
                "The streaming feature requires `datasets>=2.0.0`",
            )
        if (
            self.dataset_name is None
            and self.data_path is None
            and self.train_file is None
            and self.validation_file is None
        ):
            raise ValueError(
                "You must specify at least one of the following: "
                "a dataset name, a data path, a training file, or a validation file."
            )
        for attr in ("train_file", "validation_file"):
            fpath = getattr(self, attr)
            if fpath is None:
                continue
            extension = fpath.split(".")[-1]
            if extension not in ("csv", "json", "txt"):
                # Explicit raise (not ``assert``) so the check survives ``python -O``.
                raise ValueError(
                    f"`{attr}` should be a csv, a json or a txt file."
                )
__post_init__()
⚓︎
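Validate dataset arguments: require at least one data source, and check that any train/validation file has a csv, json, or txt extension (file existence is not checked).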
Source code in src/ezpz/configs.py
def __post_init__(self):
    """Validate the dataset arguments.

    Checks that at least one data source (``dataset_name``, ``data_path``,
    ``train_file`` or ``validation_file``) is given, and that any
    ``train_file`` / ``validation_file`` has a ``csv``, ``json`` or ``txt``
    extension. File existence is *not* checked here.

    When ``streaming`` is enabled, also requires ``datasets>=2.0.0`` via the
    ``transformers`` version helper.

    Raises:
        ValueError: if no data source is given, or a train/validation file has
            an unsupported extension.
    """
    if self.streaming:
        # Imported lazily so that non-streaming configs can be built without
        # ``transformers`` installed.
        from transformers.utils.versions import require_version

        require_version(
            "datasets>=2.0.0",
            "The streaming feature requires `datasets>=2.0.0`",
        )
    if (
        self.dataset_name is None
        and self.data_path is None
        and self.train_file is None
        and self.validation_file is None
    ):
        raise ValueError(
            "You must specify at least one of the following: "
            "a dataset name, a data path, a training file, or a validation file."
        )
    for attr in ("train_file", "validation_file"):
        fpath = getattr(self, attr)
        if fpath is None:
            continue
        extension = fpath.split(".")[-1]
        if extension not in ("csv", "json", "txt"):
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise ValueError(f"`{attr}` should be a csv, a json or a txt file.")
HfModelArguments
dataclass
⚓︎
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
Source code in src/ezpz/configs.py
@dataclass
class HfModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    wandb_project_name: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The name of the wandb project to use. If not specified, will use the model name."
            )
        },
    )
    model_name_or_path: Optional[str] = field(  # type:ignore
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={
            # Plain concatenation: ``", ".join(url)`` would join the URL's
            # individual characters ("h, t, t, p, ...").
            "help": "If training from scratch, pass a model type from the list: "
            + "https://huggingface.co/docs/transformers/en/models"
        },
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        """Validate mutually exclusive Hugging Face model configuration options.

        Raises:
            ValueError: if ``config_overrides`` is combined with
                ``config_name`` or ``model_name_or_path``.
        """
        if self.config_overrides is not None and (
            self.config_name is not None or self.model_name_or_path is not None
        ):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
__post_init__()
⚓︎
Validate mutually exclusive Hugging Face model configuration options.
Source code in src/ezpz/configs.py
def __post_init__(self):
    """Reject ``config_overrides`` when a base config or checkpoint is also named.

    Raises:
        ValueError: if ``config_overrides`` is set together with either
            ``config_name`` or ``model_name_or_path``.
    """
    if self.config_overrides is None:
        return
    has_base = not (self.config_name is None and self.model_name_or_path is None)
    if has_base:
        raise ValueError(
            "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
        )
TrainConfig
dataclass
⚓︎
Bases: BaseConfig
High-level training options shared by ezpz scripts.
Source code in src/ezpz/configs.py
@dataclass(init=False)
@rich.repr.auto
class TrainConfig(BaseConfig):
    """High-level training options shared by ezpz scripts.

    Unknown keyword arguments passed to the constructor are not rejected:
    they are set as attributes on the instance and also collected in
    ``self.extras``.
    """

    gas: int = 1
    # ---- [NOTE]+ Framework + Backend ----------------
    # `framework`: `{'backend'}`
    #   • `tensorflow`: `{'horovod'}`
    #   • `pytorch`: `{'DDP', 'deepspeed', 'horovod'}`
    # -------------------------------------------------
    framework: str = "pytorch"
    backend: str = "DDP"
    use_wandb: bool = False
    seed: Optional[int] = None
    port: Optional[str] = None
    dtype: Optional[Any] = None
    load_from: Optional[str] = None
    save_to: Optional[str] = None
    ds_config_path: Optional[str] = None
    wandb_project_name: Optional[str] = None
    ngpus: Optional[int] = None
    extras: dict[str, Any] = field(
        default_factory=dict, init=False, repr=False
    )

    def __init__(self, **kwargs: Any) -> None:
        """Populate known fields, capture extras in ``self.extras``, then validate.

        Each dataclass field (other than ``extras``) is taken from ``kwargs``
        if present, else from its ``default`` / ``default_factory``, else set
        to ``None``. Remaining keyword arguments are set as attributes and
        recorded in ``self.extras``.
        """
        extras: dict[str, Any] = {}
        for name, field_def in self.__dataclass_fields__.items():
            if name == "extras":
                continue
            if name in kwargs:
                value = kwargs.pop(name)
            elif field_def.default is not MISSING:
                value = field_def.default
            elif field_def.default_factory is not MISSING:  # type: ignore[attr-defined]
                value = field_def.default_factory()  # type: ignore[misc]
            else:
                value = None
            setattr(self, name, value)
        for key, value in kwargs.items():
            setattr(self, key, value)
            extras[key] = value
        self.extras = extras
        # With ``@dataclass(init=False)`` no generated ``__init__`` calls
        # ``__post_init__``, so the validation below must be run explicitly.
        self.__post_init__()

    def to_str(self) -> str:
        """Return a compact identifier combining framework and backend."""
        return "_".join(
            [
                f"fw-{self.framework}",
                f"be-{self.backend}",
            ]
        )

    def __post_init__(self):
        """Validate framework/backend compatibility after initialisation.

        ``framework`` and ``backend`` are matched case-insensitively. Also:
          * when ``use_wandb`` is set and ``wandb_project_name`` is ``None``,
            fills it from ``$WANDB_PROJECT``, then ``$WB_PROJECT``, else
            ``"ezpz"``;
          * for a PyTorch DeepSpeed backend (``ds``/``deepspeed``/``dspeed``),
            loads ``self.ds_config`` from ``ds_config_path`` (or
            ``DS_CONFIG_PATH`` when unset).

        Raises:
            ValueError: if ``framework`` is not a recognised alias, or
                ``backend`` is not listed in ``BACKENDS`` for that framework.
        """
        tf_aliases = ("t", "tf", "tflow", "tensorflow")
        pt_aliases = ("p", "pt", "ptorch", "torch", "pytorch")
        framework = str(self.framework).lower()
        backend = str(self.backend).lower()
        if framework in tf_aliases:
            family = "tensorflow"
        elif framework in pt_aliases:
            family = "pytorch"
        else:
            raise ValueError(
                f"Unsupported framework {self.framework!r}; expected one of "
                f"{tf_aliases + pt_aliases}"
            )
        if backend not in BACKENDS[family]:
            raise ValueError(
                f"Backend {self.backend!r} is not supported for {family}; "
                f"expected one of {BACKENDS[family]}"
            )
        if self.use_wandb and self.wandb_project_name is None:
            self.wandb_project_name = os.environ.get(
                "WANDB_PROJECT", os.environ.get("WB_PROJECT", "ezpz")
            )
        if family == "pytorch" and backend in ("ds", "deepspeed", "dspeed"):
            self.ds_config = load_ds_config(
                DS_CONFIG_PATH
                if self.ds_config_path is None
                else self.ds_config_path
            )
__getitem__(key)
⚓︎
__init__(**kwargs)
⚓︎
Populate known fields while capturing any extras in self.extras.
Source code in src/ezpz/configs.py
def __init__(self, **kwargs: Any) -> None:
    """Populate known fields, capture extras in ``self.extras``, then validate.

    Each dataclass field (other than ``extras``) is taken from ``kwargs`` if
    present, else from its ``default`` / ``default_factory``, else set to
    ``None``. Remaining keyword arguments are set as attributes and recorded in
    ``self.extras``. Finally ``self.__post_init__()`` runs: the owning class is
    declared ``@dataclass(init=False)``, so no generated ``__init__`` would
    call it.
    """
    extras: dict[str, Any] = {}
    for name, field_def in self.__dataclass_fields__.items():
        if name == "extras":
            continue
        if name in kwargs:
            value = kwargs.pop(name)
        elif field_def.default is not MISSING:
            value = field_def.default
        elif field_def.default_factory is not MISSING:  # type: ignore[attr-defined]
            value = field_def.default_factory()  # type: ignore[misc]
        else:
            value = None
        setattr(self, name, value)
    for key, value in kwargs.items():
        setattr(self, key, value)
        extras[key] = value
    self.extras = extras
    self.__post_init__()
__post_init__()
⚓︎
Validate framework/backend compatibility after initialisation.
Source code in src/ezpz/configs.py
def __post_init__(self):
    """Validate framework/backend compatibility after initialisation.

    ``framework`` and ``backend`` are matched case-insensitively. Also:
      * when ``use_wandb`` is set and ``wandb_project_name`` is ``None``, fills
        it from ``$WANDB_PROJECT``, then ``$WB_PROJECT``, else ``"ezpz"``;
      * for a PyTorch DeepSpeed backend (``ds``/``deepspeed``/``dspeed``),
        loads ``self.ds_config`` from ``ds_config_path`` (or
        ``DS_CONFIG_PATH`` when unset).

    Raises:
        ValueError: if ``framework`` is not a recognised alias, or ``backend``
            is not listed in ``BACKENDS`` for that framework.
    """
    tf_aliases = ("t", "tf", "tflow", "tensorflow")
    pt_aliases = ("p", "pt", "ptorch", "torch", "pytorch")
    framework = str(self.framework).lower()
    backend = str(self.backend).lower()
    # Resolve the framework family once so that every alias (including
    # "tflow") is checked against the right backend list.
    if framework in tf_aliases:
        family = "tensorflow"
    elif framework in pt_aliases:
        family = "pytorch"
    else:
        raise ValueError(
            f"Unsupported framework {self.framework!r}; expected one of "
            f"{tf_aliases + pt_aliases}"
        )
    if backend not in BACKENDS[family]:
        raise ValueError(
            f"Backend {self.backend!r} is not supported for {family}; "
            f"expected one of {BACKENDS[family]}"
        )
    if self.use_wandb and self.wandb_project_name is None:
        self.wandb_project_name = os.environ.get(
            "WANDB_PROJECT", os.environ.get("WB_PROJECT", "ezpz")
        )
    if family == "pytorch" and backend in ("ds", "deepspeed", "dspeed"):
        self.ds_config = load_ds_config(
            DS_CONFIG_PATH
            if self.ds_config_path is None
            else self.ds_config_path
        )
from_file(fpath)
⚓︎
get_config()
⚓︎
to_dict()
⚓︎
to_file(fpath)
⚓︎
to_json()
⚓︎
Return a JSON string representation of the configuration.
to_str()
⚓︎
ZeroConfig
dataclass
⚓︎
Subset of DeepSpeed ZeRO options exposed via the ezpz CLI.
Source code in src/ezpz/configs.py
@dataclass
class ZeroConfig:
"""Subset of DeepSpeed ZeRO options exposed via the ezpz CLI."""
stage: int = 0
allgather_partitions: Optional[bool] = None
allgather_bucket_size: int = int(5e8)
overlap_comm: Optional[bool] = None
reduce_scatter: bool = True
reduce_bucket_size: int = int(5e8)
contiguous_gradients: Optional[bool] = None
offload_param: Optional[dict] = None
offload_optimizer: Optional[dict] = None
stage3_max_live_parameters: int = int(1e9)
stage3_max_reuse_distance: int = int(1e9)
stage3_prefetch_bucket_size: int = int(5e8)
stage3_param_persistence_threshold: int = int(1e6)
sub_group_size: Optional[int] = None
elastic_checkpoint: Optional[dict] = None
stage3_gather_16bit_weights_on_model_save: Optional[bool] = None
ignore_unused_parameters: Optional[bool] = None
round_robin_gradients: Optional[bool] = None
zero_hpz_partition_size: Optional[int] = None
zero_quantized_weights: Optional[bool] = None
zero_quantized_gradients: Optional[bool] = None
log_trace_cache_warnings: Optional[bool] = None
cmd_exists(cmd)
⚓︎
Check whether the command `cmd` exists.
command_exists(cmd)
⚓︎
get_logging_config()
⚓︎
Return the logging configuration dictionary used by logging.config.
Source code in src/ezpz/configs.py
def get_logging_config() -> dict:
    """Load ezpz's Hydra job-logging YAML and return it as a dict.

    Reads ``CONF_DIR/hydra/job_logging/custom.yaml``; the result is intended
    for ``logging.config``. A ``"loggers"`` key is guaranteed to be present
    (an empty dict is inserted when the YAML does not define one).
    """
    import yaml

    logging_yaml = CONF_DIR.joinpath("hydra", "job_logging", "custom.yaml")
    with logging_yaml.open("r") as handle:
        logging_cfg = yaml.load(handle, Loader=yaml.FullLoader)
    logging_cfg.setdefault("loggers", {})
    return logging_cfg
get_scheduler(_scheduler=None)
⚓︎
Infer the active scheduler from environment variables or hostname.
Source code in src/ezpz/configs.py
def get_scheduler(_scheduler: Optional[str] = None) -> str:
    """Infer the active scheduler from environment variables or hostname.

    Resolution order:
      1. an explicit ``_scheduler`` argument (returned upper-cased);
      2. ``PBS_JOBID`` set -> ``"PBS"``;
      3. ``SLURM_JOB_ID`` / ``SLURM_JOBID`` set -> ``"SLURM"``;
      4. the machine name (from the hostname) mapped to a site key and looked
         up in ``SCHEDULERS``;
      5. otherwise ``"UNKNOWN"``.
    """
    from ezpz import get_hostname, get_machine

    if _scheduler is not None:
        log.info(f"Using user-specified scheduler: {_scheduler}")
        return _scheduler.upper()

    env = os.environ
    if env.get("PBS_JOBID"):
        return "PBS"
    if env.get("SLURM_JOB_ID") or env.get("SLURM_JOBID"):
        return "SLURM"

    site_for_machine = {
        "thetagpu": "ALCF",
        "sunspot": "ALCF",
        "polaris": "ALCF",
        "aurora": "ALCF",
        "sophia": "ALCF",
        "frontier": "OLCF",
        "nersc": "NERSC",
        "perlmutter": "NERSC",
    }
    site = site_for_machine.get(get_machine(get_hostname()).lower())
    if site is None:
        return "UNKNOWN"
    return SCHEDULERS[site]
getjobenv_dep()
⚓︎
git_ds_info()
⚓︎
Log the output of DeepSpeed's environment report plus Git metadata.
Source code in src/ezpz/configs.py
def git_ds_info():
    """Run DeepSpeed's environment report, then log the current git hash/branch.

    The git commands run through the shell in the current working directory.
    If ``git`` is unavailable (per ``command_exists``) or either command fails
    with ``CalledProcessError``, both values are reported as ``"unknown"``.
    """
    from deepspeed.env_report import main as ds_report

    ds_report()

    # Write out version/git info
    git_hash, git_branch = "unknown", "unknown"
    if command_exists("git"):
        git_cmds = (
            "git rev-parse --short HEAD",
            "git rev-parse --abbrev-ref HEAD",
        )
        try:
            git_hash, git_branch = [
                subprocess.check_output(cmd, shell=True).decode("utf-8").strip()
                for cmd in git_cmds
            ]
        except subprocess.CalledProcessError:
            git_hash, git_branch = "unknown", "unknown"
    log.info(
        f"**** Git info for DeepSpeed:"
        f" git_hash={git_hash} git_branch={git_branch} ****"
    )
load_ds_config(fpath=None)
⚓︎
Load a DeepSpeed configuration file (JSON or YAML).
Source code in src/ezpz/configs.py
def load_ds_config(
    fpath: Optional[Union[str, os.PathLike, Path]] = None,  # type:ignore[reportDeprecated]
) -> dict[str, Any]:
    """Load a DeepSpeed configuration file (JSON or YAML).

    Args:
        fpath: Path to the config file. Defaults to ``DS_CONFIG_PATH`` when
            ``None``.

    Returns:
        The parsed configuration as a dict. An empty YAML file yields ``{}``.

    Raises:
        TypeError: if the file suffix is not ``.json``, ``.yaml`` or ``.yml``
            (compared case-insensitively).
    """
    cfgpath = Path(DS_CONFIG_PATH if fpath is None else fpath)
    suffix = cfgpath.suffix.lower()
    if suffix == ".json":
        with cfgpath.open("r", encoding="utf-8") as f:
            ds_config: dict[str, Any] = json.load(f)
        return ds_config
    if suffix in (".yaml", ".yml"):
        # Local import: ``yaml`` is only needed for YAML configs.
        import yaml

        with cfgpath.open("r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
        # ``safe_load`` returns None for an empty document.
        return dict(loaded) if loaded is not None else {}
    raise TypeError(
        f"Unexpected FileType: {cfgpath} (expected .json, .yaml or .yml)"
    )
print_config(cfg)
⚓︎
Render cfg to the active rich console.
Source code in src/ezpz/configs.py
def print_config(cfg: Union[dict, str]) -> None:
    """Pretty-print ``cfg`` as JSON (indent 4, highlighted) on a rich console.

    Uses the console of the last rich ``RichHandler`` or ezpz ``RichHandler``
    attached to ``log``. Falls back to ``ezpz.log.console.get_console()``
    when none is found or its console is ``None``.
    """
    from rich.logging import RichHandler
    from ezpz.log.handler import RichHandler as EnrichHandler

    rich_handler_types = (RichHandler, EnrichHandler)
    # The last matching handler wins, mirroring a forward scan that overwrites.
    console = next(
        (
            handler.console
            for handler in reversed(log.handlers)
            if isinstance(handler, rich_handler_types)
        ),
        None,
    )
    if console is None:
        from ezpz.log.console import get_console

        console = get_console()
    print_json(data=cfg, console=console, indent=4, highlight=True)
print_config_tree(cfg, resolve=True, save_to_file=True, verbose=True, style='tree', print_order=None, highlight=True, outfile=None)
⚓︎
Prints the contents of a DictConfig as a tree structure using the Rich library.
- cfg: A DictConfig composed by Hydra.
- print_order: Determines in what order config components are printed.
- resolve: Whether to resolve reference fields of DictConfig.
- save_to_file: Whether to export config to the hydra output folder.
Source code in src/ezpz/configs.py
def print_config_tree(
    cfg: DictConfig,
    resolve: bool = True,
    save_to_file: bool = True,
    verbose: bool = True,
    style: str = "tree",
    print_order: Optional[Sequence[str]] = None,
    highlight: bool = True,
    outfile: Optional[Union[str, os.PathLike, Path]] = None,
) -> Tree:
    """Prints the contents of a DictConfig as a tree structure using the Rich
    library.

    Args:
        cfg: A DictConfig composed by Hydra.
        resolve: Whether to resolve reference fields of DictConfig.
        save_to_file: Whether to save the rendered tree as plain text to
            ``outfile`` (default: ``config_tree.log`` in the current working
            directory).
        verbose: Whether to print the tree to the terminal.
        style: Currently unused; kept for backward compatibility.
        print_order: Top-level keys to print first, in this order. Keys not
            found in ``cfg`` are skipped with a warning.
        highlight: Passed through to rich ``Tree`` / ``Tree.add``.
        outfile: Destination for the saved text when ``save_to_file`` is set.

    Returns:
        The constructed rich ``Tree``.
    """
    import io

    from rich.console import Console
    from rich.theme import Theme
    from ezpz.log.config import STYLES

    name = cfg.get("_target_", "cfg")
    # ``record=True`` lets ``save_text`` export what was printed. When not
    # verbose, print into an in-memory buffer instead of the terminal so the
    # tree can still be saved without being echoed (the saved text then uses
    # rich's default width for non-terminal output).
    console = Console(
        record=True,
        theme=Theme(STYLES),
        file=None if verbose else io.StringIO(),
    )
    tree = Tree(label=name, highlight=highlight)
    queue: list = []
    # Keys from ``print_order`` go first.
    if print_order is not None:
        for key in print_order:
            if key in cfg:
                queue.append(key)
            else:
                log.warning(
                    f"Field '{key}' not found in config. "
                    f"Skipping '{key}' config printing..."
                )
    # Then every remaining top-level key, in its natural order.
    for key in cfg:
        if key not in queue:
            queue.append(key)
    # Build one branch per key: nested configs rendered as YAML, leaves as str.
    for key in queue:
        branch = tree.add(key, highlight=highlight)
        config_group = cfg[key]
        if isinstance(config_group, DictConfig):
            branch_content = str(
                OmegaConf.to_yaml(config_group, resolve=resolve)
            )
            branch.add(Text(branch_content, style="red"))
        else:
            branch_content = str(config_group)
            branch.add(Text(branch_content, style="blue"))
    if verbose or save_to_file:
        console.print(tree)
    if save_to_file:
        outfpath = (
            Path(os.getcwd()).joinpath("config_tree.log")
            if outfile is None
            else Path(outfile)
        )
        console.save_text(outfpath.as_posix())
    return tree
print_json(json_str=None, console=None, *, data=None, indent=2, highlight=True, skip_keys=False, ensure_ascii=False, check_circular=True, allow_nan=True, default=None, sort_keys=False)
⚓︎
Pretty prints JSON. Output will be valid JSON.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `json_str` | `Optional[str]` | A string containing JSON. | `None` |
| `data` | `Any` | If `json_str` is not supplied, then encode this data. | `None` |
| `indent` | `Union[None, int, str]` | Number of spaces to indent. Defaults to 2. | `2` |
| `highlight` | `bool` | Enable highlighting of output. Defaults to True. | `True` |
| `skip_keys` | `bool` | Skip keys not of a basic type. Defaults to False. | `False` |
| `ensure_ascii` | `bool` | Escape all non-ASCII characters. Defaults to False. | `False` |
| `check_circular` | `bool` | Check for circular references. Defaults to True. | `True` |
| `allow_nan` | `bool` | Allow NaN and Infinity values. Defaults to True. | `True` |
| `default` | `Callable` | A callable that converts values that cannot be encoded into something that can be JSON encoded. Defaults to None. | `None` |
| `sort_keys` | `bool` | Sort dictionary keys. Defaults to False. | `False` |
Source code in src/ezpz/configs.py
def print_json(
    json_str: Optional[str] = None,
    console: Optional["Console"] = None,
    *,
    data: Any = None,
    indent: Union[None, int, str] = 2,
    highlight: bool = True,
    skip_keys: bool = False,
    ensure_ascii: bool = False,
    check_circular: bool = True,
    allow_nan: bool = True,
    default: Optional[Callable[[Any], Any]] = None,
    sort_keys: bool = False,
) -> None:
    """Pretty prints JSON. Output will be valid JSON.

    Args:
        json_str (Optional[str]): A string containing JSON.
        console (Optional[Console]): Rich console to print to. Defaults to
            ``ezpz.log.console.get_console()``.
        data (Any): If json is not supplied, then encode this data.
        indent (Union[None, int, str], optional): Number of spaces to indent.
            Defaults to 2.
        highlight (bool, optional): Enable highlighting of output.
            Defaults to True.
        skip_keys (bool, optional): Skip keys not of a basic type.
            Defaults to False.
        ensure_ascii (bool, optional): Escape all non-ascii characters.
            Defaults to False.
        check_circular (bool, optional): Check for circular references.
            Defaults to True.
        allow_nan (bool, optional): Allow NaN and Infinity values.
            Defaults to True.
        default (Callable, optional): A callable that converts values
            that cannot be encoded into something that can be JSON
            encoded.
            Defaults to None.
        sort_keys (bool, optional): Sort dictionary keys. Defaults to False.

    Raises:
        ValueError: if neither or both of ``json_str`` and ``data`` are given.
        TypeError: if the resolved console is not a rich ``Console``.
    """
    if json_str is None and data is None:
        raise ValueError(
            "Either `json_str` or `data` must be provided, e.g. "
            "print_json(json_str='{\"a\": 1}') or print_json(data={'a': 1})."
        )
    if json_str is not None and data is not None:
        raise ValueError(
            "Only one of `json_str` or `data` should be provided. "
            f"Received both: json_str={json_str!r}, data={data!r}"
        )
    from rich.json import JSON
    from ezpz.log.console import get_console

    console = get_console() if console is None else console
    if json_str is None:
        json_renderable = JSON.from_data(
            data,
            indent=indent,
            highlight=highlight,
            skip_keys=skip_keys,
            ensure_ascii=ensure_ascii,
            check_circular=check_circular,
            allow_nan=allow_nan,
            default=default,
            sort_keys=sort_keys,
        )
    else:
        json_renderable = JSON(
            json_str,
            indent=indent,
            highlight=highlight,
            skip_keys=skip_keys,
            ensure_ascii=ensure_ascii,
            check_circular=check_circular,
            allow_nan=allow_nan,
            default=default,
            sort_keys=sort_keys,
        )
    if not isinstance(console, Console):
        raise TypeError(
            f"Expected a rich Console, got {type(console).__name__}"
        )
    # Print the renderable itself: ``str()`` of a rich ``JSON`` object is not
    # its JSON text, and ``Text.render`` returns segments, not a string.
    console.print(json_renderable)