Fine-Tune a Causal LM (HuggingFace + Accelerate)

Use this example when you want to fine-tune a pretrained HuggingFace language model with full control over the training loop. Unlike the hf_trainer example, which uses the HF Trainer abstraction, this example writes the loop explicitly. That gives you direct control over gradient accumulation, custom evaluation, and learning rate scheduling, while Accelerate still handles distributed coordination and ezpz handles setup and metrics.
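To make the gradient-accumulation bookkeeping concrete, here is a minimal sketch of the step arithmetic the script performs before training: optimizer updates per epoch, total optimization steps when `--max_steps` is not set, and the effective global batch size. The function name `plan_steps` and the sample numbers are illustrative assumptions, not part of ezpz.

```python
import math


def plan_steps(
    num_batches: int,
    grad_accum_steps: int,
    num_epochs: int,
    per_device_batch: int,
    num_processes: int,
) -> tuple[int, int, int]:
    """Return (updates_per_epoch, max_steps, total_batch_size)."""
    # One optimizer update happens every `grad_accum_steps` micro-batches;
    # a trailing partial group still produces an update, hence the ceil.
    updates_per_epoch = math.ceil(num_batches / grad_accum_steps)
    # Used when max_steps is not given explicitly.
    max_steps = int(num_epochs * updates_per_epoch)
    # Samples contributing to a single optimizer update across all ranks.
    total_batch_size = per_device_batch * num_processes * grad_accum_steps
    return updates_per_epoch, max_steps, total_batch_size


# Hypothetical run: 1000 micro-batches per epoch on 8 processes.
plan = plan_steps(
    num_batches=1000,
    grad_accum_steps=4,
    num_epochs=3,
    per_device_batch=2,
    num_processes=8,
)
print(plan)  # (250, 750, 64)
```

Streaming datasets have no length, so the script cannot derive these numbers and requires an explicit `--max_steps` in that case.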

Key API Functions

See:

ezpz launch python3 -m ezpz.examples.hf \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 2 \
    --output_dir ./output-hf
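
The `parse_args` function in the source below also accepts a single argument ending in `.json`, in which case it reads all settings from that file through `HfArgumentParser.parse_json_file`. The following is a minimal sketch that writes the same settings as the command above to such a file. The file name `hf_config.json` and the helper `write_config` are hypothetical, and the file is written to a temporary directory only to keep the sketch self-contained.

```python
import json
import tempfile
from pathlib import Path

# Same settings as the command-line invocation, keyed by argument name.
CONFIG = {
    "model_name_or_path": "gpt2",
    "dataset_name": "wikitext",
    "dataset_config_name": "wikitext-2-raw-v1",
    "do_train": True,
    "do_eval": True,
    "per_device_train_batch_size": 2,
    "output_dir": "./output-hf",
}


def write_config(path: Path, config: dict) -> Path:
    """Serialize the argument dictionary to a JSON file and return its path."""
    path.write_text(json.dumps(config, indent=2) + "\n")
    return path


# hf_config.json is a hypothetical name; any path ending in .json qualifies.
config_path = write_config(Path(tempfile.mkdtemp()) / "hf_config.json", CONFIG)
print(config_path.name)
```

Passing that file as the only argument, for example `ezpz launch python3 -m ezpz.examples.hf hf_config.json`, should select the JSON branch of `parse_args`, assuming the launcher forwards the script arguments unchanged.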

Source

src/ezpz/examples/hf.py
#!/usr/bin/env python
"""
Fine-tune a causal LM with a hand-rolled training loop.

This mirrors the dataset/model setup used in ``ezpz.examples.hf_trainer`` while
keeping an explicit training loop like the other examples.
"""

from __future__ import annotations

# pyright: reportArgumentType=false
# pyright: reportGeneralTypeIssues=false

import json
import math
import os
import sys
import time
from itertools import chain
from pathlib import Path
from typing import Optional, cast

import datasets
import torch
import transformers
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from transformers import (
    CONFIG_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    get_scheduler,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version

import ezpz
from ezpz.configs import HfDataTrainingArguments, HfModelArguments
from ezpz.flops import compute_mfu, try_estimate

logger = ezpz.get_logger(__name__)

try:
    from accelerate import Accelerator, FullyShardedDataParallelPlugin  # noqa: E402 type:ignore
    from accelerate.utils import set_seed
except ImportError as exc:
    logger.error(
        "Please install accelerate to run this script: `pip install accelerate`"
    )
    raise exc

require_version(
    "datasets>=2.14.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)


def _safetensors_save_errors() -> tuple[type[Exception], ...]:
    """Exception types we accept as triggers for the safetensors retry.

    Always includes (OSError, RuntimeError, ValueError) — those cover
    the common parallel-filesystem failures (E2BIG "Argument list too
    long" on Lustre, RuntimeError from torch save shims).  When the
    optional ``safetensors`` library is installed we also include its
    native ``SafetensorError`` because the rust core raises that
    directly for header/metadata/shared-tensor failures, not OSError.

    Return type is narrowed to ``Exception`` (not ``BaseException``) so
    a future addition can't accidentally include ``SystemExit`` or
    ``KeyboardInterrupt`` and turn the retry into an interrupt swallow.
    """
    errors: tuple[type[Exception], ...] = (OSError, RuntimeError, ValueError)
    try:
        from safetensors import SafetensorError
    except ImportError:
        return errors
    return errors + (SafetensorError,)


_SAFETENSORS_SAVE_ERRORS = _safetensors_save_errors()


def _save_pretrained_with_fallback(
    model: object,
    output_dir: str,
    *,
    is_main_process: bool,
    save_function: object,
) -> None:
    """``model.save_pretrained`` with a safetensors fallback.

    First tries the default safetensors serializer; on a parallel-FS
    failure (OSError "Argument list too long", RuntimeError from torch
    save shims, or safetensors.SafetensorError from the rust core)
    retries with ``safe_serialization=False``.  Genuine bugs (TypeError,
    attribute errors, OOM) are not caught.

    Used at both the mid-training epoch save and the end-of-training
    save so a Lustre/safetensors failure mid-run doesn't crash the
    whole job.
    """
    try:
        model.save_pretrained(  # type: ignore[attr-defined]
            output_dir,
            is_main_process=is_main_process,
            save_function=save_function,
        )
    except _SAFETENSORS_SAVE_ERRORS as e:
        logger.warning(
            "save_pretrained with safetensors failed (%s: %s); "
            "retrying with safe_serialization=False",
            type(e).__name__, e,
        )
        model.save_pretrained(  # type: ignore[attr-defined]
            output_dir,
            is_main_process=is_main_process,
            save_function=save_function,
            safe_serialization=False,
        )


def _strip_metric_prefix(summary: str, prefix: str) -> str:
    """Drop *prefix* from metric tokens in a History summary string.

    History.update() returns a space-separated string like
    ``"train/loss=0.5 train/dt=0.1"``.  The previous implementation
    used ``str.replace(prefix, "")`` which would mangle a metric
    whose name contains the prefix as a substring (e.g.
    ``cosine_train/x``).  This helper splits on whitespace and
    only strips the prefix when it actually anchors a token.
    """
    return " ".join(
        token.removeprefix(prefix) for token in summary.split()
    )


def parse_args(
    ) -> tuple[HfModelArguments, HfDataTrainingArguments, TrainingArguments]:
    """Parse Hugging Face model, data, and training arguments.

    Returns:
        Tuple of ``(model_args, data_args, training_args)`` dataclass instances.
    """
    parser = HfArgumentParser(
        (HfModelArguments, HfDataTrainingArguments, TrainingArguments)  # type:ignore
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = (
            parser.parse_args_into_dataclasses()
        )

    if training_args.should_log:
        from transformers.utils import logging as hf_logging

        hf_logging.set_verbosity_info()

    rank = ezpz.get_rank()
    log_level_info = 20
    log_level_critical = 50
    log_level = log_level_info if rank == 0 else log_level_critical
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    if rank == 0:
        logger.info("Training/evaluation parameters %s", training_args)

    return model_args, data_args, training_args


@ezpz.timeitlogit(rank=ezpz.get_rank())
def split_dataset(
    model_args: HfModelArguments,
    data_args: HfDataTrainingArguments,
    train_split_name: str = "train",
    validation_split_name: Optional[str] = None,
) -> datasets.IterableDatasetDict | datasets.DatasetDict:
    """Split a Hugging Face dataset into train/validation splits.

    Args:
        model_args: Model configuration arguments for cache/token settings.
        data_args: Data-related arguments for dataset selection.
        train_split_name: Name of the training split.
        validation_split_name: Name of the validation split (if any).

    Returns:
        Dataset dictionary with requested splits.
    """
    dataset_name = data_args.dataset_name
    assert dataset_name is not None, (
        "dataset_name must be provided to split the dataset."
    )
    dsets: dict[str, datasets.Dataset | datasets.IterableDataset] = {}
    if validation_split_name is not None:
        try:
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
            dsets[train_split_name] = datasets.load_dataset(  # type: ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
        except ValueError:
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=train_split_name,
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
            try:
                dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                    dataset_name,
                    data_args.dataset_config_name,
                    split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                    cache_dir=model_args.cache_dir,
                    token=model_args.token,
                    streaming=data_args.streaming,
                    trust_remote_code=model_args.trust_remote_code,
                )
            except Exception:
                dsets[train_split_name] = datasets.load_dataset(  # type:ignore
                    dataset_name,
                    data_args.dataset_config_name,
                    split=train_split_name,
                    cache_dir=model_args.cache_dir,
                    token=model_args.token,
                    streaming=data_args.streaming,
                    trust_remote_code=model_args.trust_remote_code,
                )

    if data_args.streaming:
        return datasets.IterableDatasetDict(  # type: ignore
            cast(dict[str, datasets.IterableDataset], dsets)
        )
    return datasets.DatasetDict(  # type: ignore
        cast(dict[str, datasets.Dataset], dsets)
    )


@ezpz.timeitlogit(rank=ezpz.get_rank())
def main() -> None:
    """Entrypoint for standalone HF causal LM fine-tuning without Trainer."""
    import logging as _logging
    for _noisy in ("httpx", "huggingface_hub", "filelock"):
        _logging.getLogger(_noisy).setLevel(_logging.WARNING)
    t0 = time.perf_counter()
    rank = ezpz.setup_torch()
    model_args, data_args, training_args = parse_args()

    output_dir = training_args.output_dir or os.getcwd()
    wandb = None
    report_to = training_args.report_to

    # Build FSDP plugin explicitly when --fsdp is requested, bypassing
    # the env-var machinery which can pick up stale/conflicting defaults.
    fsdp_plugin = None
    use_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "").lower() == "true"
    if use_fsdp:
        from torch.distributed.fsdp import (
            BackwardPrefetch,
            MixedPrecision,
            ShardingStrategy,
        )

        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ) if training_args.bf16 else None
        fsdp_plugin = FullyShardedDataParallelPlugin(
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            mixed_precision_policy=mp_policy,
            use_orig_params=True,
            sync_module_states=False,
            cpu_ram_efficient_loading=False,
            limit_all_gathers=True,
        )
        logger.info("[rank %d] using explicit FSDP plugin: %s", rank, fsdp_plugin)

    # Don't let Accelerator manage wandb — we handle it via History's tracker
    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        fsdp_plugin=fsdp_plugin,
    )
    t_setup = time.perf_counter()

    # Resolve project name for tracker backends
    wbproj_name = (
        model_args.wandb_project_name
        if getattr(model_args, "wandb_project_name", None) is not None
        else (model_args.model_name_or_path or "ezpz-hf-default-project")
    )
    wbproj_name = f"ezpz-hf-{wbproj_name}".replace("/", "-")
    ezpz.barrier()

    logger.warning(accelerator.state)
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if training_args.seed is not None:
        set_seed(training_args.seed)

    api = None
    repo_id = None
    if accelerator.is_main_process:
        if training_args.push_to_hub:
            repo_name = training_args.hub_model_id
            if repo_name is None:
                repo_name = Path(output_dir).absolute().name
            api = HfApi()
            repo_id = api.create_repo(
                repo_name, exist_ok=True, token=training_args.hub_token
            ).repo_id

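            os.makedirs(output_dir, exist_ok=True)
            gitignore_path = Path(output_dir) / ".gitignore"
            # Append the checkpoint patterns only if they are not already listed.
            listed = gitignore_path.read_text().split() if gitignore_path.exists() else []
            with open(gitignore_path, "a") as gitignore:
                for pattern in ("step_*", "epoch_*"):
                    if pattern not in listed:
                        gitignore.write(f"{pattern}\n")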
        else:
            os.makedirs(output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    last_checkpoint = None
    overwrite = getattr(training_args, "overwrite_output_dir", False)
    if (
        os.path.isdir(output_dir)
        and training_args.do_train
        and not overwrite
    ):
        last_checkpoint = get_last_checkpoint(output_dir)
        if (
            last_checkpoint is None
            and len(os.listdir(output_dir)) > 0
        ):
            raise ValueError(
                "Output directory already exists and is not empty."
            )
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                "Checkpoint detected, resuming training at %s. To avoid this behavior, change the output_dir.",
                last_checkpoint,
            )

    train_split_name = data_args.train_split_name or "train"
    validation_split_name = data_args.validation_split_name or "validation"
    if data_args.dataset_name is not None:
        raw_datasets = split_dataset(
            model_args,
            data_args,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
        )
    else:
        data_files: dict[str, str] = {}
        dataset_args: dict[str, object] = {}
        if data_args.train_file is not None:
            data_files[train_split_name] = data_args.train_file
        if data_args.validation_file is not None:
            data_files[validation_split_name] = data_args.validation_file
        if data_args.train_file is not None:
            extension = data_args.train_file.split(".")[-1]
        elif data_args.validation_file is not None:
            extension = data_args.validation_file.split(".")[-1]
        else:
            raise ValueError("Expected a train or validation file.")
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = datasets.load_dataset(  # type: ignore[arg-type]
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        if validation_split_name not in raw_datasets.keys():
            raw_datasets[validation_split_name] = datasets.load_dataset(  # type:ignore
                extension,
                data_files=data_files,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
            raw_datasets[train_split_name] = datasets.load_dataset(  # type:ignore
                extension,
                data_files=data_files,
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name, **config_kwargs
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path, **config_kwargs
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info("Overriding config: %s", model_args.config_overrides)
            config.update_from_string(model_args.config_overrides)
            logger.info("New config: %s", config)

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, **tokenizer_kwargs
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, **tokenizer_kwargs
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(  # type:ignore
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(  # type:ignore
            config, trust_remote_code=model_args.trust_remote_code
        )

    if callable(getattr(model, "get_input_embeddings", None)):
        embedding_size = model.get_input_embeddings().weight.shape[0]
        if len(tokenizer) > embedding_size:
            model.resize_token_embeddings(len(tokenizer))

    logger.info("[rank %d] proceeding to data prep", rank)

    if training_args.do_train:
        column_names = list(raw_datasets[train_split_name].features)  # type:ignore
    else:
        column_names = list(raw_datasets[validation_split_name].features)  # type:ignore
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples: dict[str, object]) -> dict[str, object]:
        """Tokenize raw text using the configured tokenizer."""
        return tokenizer(examples[text_column_name])

    logger.info("[rank %d] entering tokenization", rank)
    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,  # type:ignore
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,  # type:ignore
            )
        else:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )
    logger.info("[rank %d] tokenization done", rank)

    if hasattr(config, "max_position_embeddings"):
        max_pos_embeddings = config.max_position_embeddings
    else:
        max_pos_embeddings = 1024
        logger.warning(
            "Config %s does not have 'max_position_embeddings'; using %s.",
            config,
            max_pos_embeddings,
        )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > max_pos_embeddings:
            logger.warning(
                "The tokenizer picked seems to have a very large `model_max_length` (%s). "
                "Using block_size=%s instead. You can change that default value by passing --block_size xxx.",
                tokenizer.model_max_length,
                min(1024, max_pos_embeddings),
            )
            if max_pos_embeddings > 0:
                block_size = min(1024, max_pos_embeddings)
            else:
                block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                "The block_size passed (%s) is larger than the maximum length for the model (%s). Using block_size=%s.",
                data_args.block_size,
                tokenizer.model_max_length,
                tokenizer.model_max_length,
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    def group_texts(
        examples: dict[str, list[list[int]]]
    ) -> dict[str, list[list[int]]]:
        """Concatenate and chunk tokenized text into fixed-size blocks."""
        concatenated_examples = {
            k: [int(x) for x in chain(*examples[k])]
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [
                t[i : i + block_size]
                for i in range(0, total_length, block_size)
            ]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    logger.info("[rank %d] entering group_texts", rank)
    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,  # type:ignore
                load_from_cache_file=not data_args.overwrite_cache,  # type:ignore
            )
        else:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
            )
    logger.info("[rank %d] group_texts done", rank)

    train_dataset = None
    if training_args.do_train:
        if train_split_name not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets[train_split_name]  # type:ignore
        if data_args.max_train_samples is not None:
            if isinstance(train_dataset, datasets.IterableDataset):
                train_dataset = train_dataset.take(data_args.max_train_samples)
            else:
                max_train_samples = min(
                    train_dataset.num_rows, data_args.max_train_samples
                )
                train_dataset = train_dataset.select(range(max_train_samples))

    eval_dataset = None
    if training_args.do_eval:
        if validation_split_name not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets[validation_split_name]  # type:ignore
        if data_args.max_eval_samples is not None:
            if isinstance(eval_dataset, datasets.IterableDataset):
                eval_dataset = eval_dataset.take(data_args.max_eval_samples)
            else:
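                # Clamp to the dataset size so select() never indexes past the end.
                max_eval_samples = min(
                    eval_dataset.num_rows, data_args.max_eval_samples
                )
                eval_dataset = eval_dataset.select(range(max_eval_samples))  # type:ignore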

    assert train_dataset is not None
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=not data_args.streaming,
        collate_fn=default_data_collator,
        batch_size=training_args.per_device_train_batch_size,
        num_workers=training_args.dataloader_num_workers,
    )
    eval_dataloader = None
    if eval_dataset is not None:
        eval_dataloader = DataLoader(
            eval_dataset,
            collate_fn=default_data_collator,
            batch_size=training_args.per_device_eval_batch_size,
            num_workers=training_args.dataloader_num_workers,
        )

    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters, lr=training_args.learning_rate
    )

    overrode_max_train_steps = False
    try:
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / training_args.gradient_accumulation_steps
        )
    except TypeError:
        # IterableDataset (streaming) has no len(); require max_steps
        num_update_steps_per_epoch = None
    if training_args.max_steps <= 0:
        if num_update_steps_per_epoch is None:
            raise ValueError(
                "max_steps must be set when using a streaming / "
                "IterableDataset (dataset length is unknown)."
            )
        training_args.max_steps = int(
            training_args.num_train_epochs * num_update_steps_per_epoch
        )
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps
        * accelerator.num_processes,
        num_training_steps=training_args.max_steps
        if overrode_max_train_steps
        else training_args.max_steps * accelerator.num_processes,
    )

    _model_flops = try_estimate(
        model, (training_args.per_device_train_batch_size, block_size),
    )

    logger.info("[rank %d] calling accelerator.prepare() ...", rank)
    (
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
    logger.info("[rank %d] accelerator.prepare() complete", rank)

    try:
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / training_args.gradient_accumulation_steps
        )
    except TypeError:
        num_update_steps_per_epoch = None
    if num_update_steps_per_epoch is not None:
        if overrode_max_train_steps:
            training_args.max_steps = int(
                training_args.num_train_epochs * num_update_steps_per_epoch
            )
        training_args.num_train_epochs = math.ceil(
            training_args.max_steps / num_update_steps_per_epoch
        )

    checkpointing_steps = training_args.save_steps

    train_start = time.perf_counter()
    total_batch_size = (
        training_args.per_device_train_batch_size
        * accelerator.num_processes
        * training_args.gradient_accumulation_steps
    )

    logger.info("***** Model *****")
    logger.info("  Model = %s", model)
    logger.info("***** Args *****")
    logger.info(
        json.dumps(
            {
                "model": model_args.__dict__,
                "data": data_args.__dict__,
                "training": training_args.to_dict(),
            },
            indent=4,
            sort_keys=True,
        )
    )
    logger.info("***** Running training *****")
    logger.info("  Num processes = %s", accelerator.num_processes)
    logger.info(
        "  Num examples = %s",
        len(train_dataset) if hasattr(train_dataset, "__len__") else "unknown (streaming)",
    )
    logger.info("  Num Epochs = %s", training_args.num_train_epochs)
    logger.info(
        "  Instantaneous batch size per device = %s",
        training_args.per_device_train_batch_size,
    )
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %s",
        total_batch_size,
    )
    logger.info(
        "  Gradient Accumulation steps = %s",
        training_args.gradient_accumulation_steps,
    )
    logger.info("  Total optimization steps = %s", training_args.max_steps)

    logging_steps = max(1, int(training_args.logging_steps))
    outdir = Path(training_args.output_dir) if training_args.output_dir else Path.cwd() / "outputs"
    outdir.mkdir(parents=True, exist_ok=True)
    logger.info("Outputs will be saved to %s", outdir)
    history = ezpz.history.History(
        project_name=wbproj_name,
        config={
            "model": model_args.__dict__,
            "data": data_args.__dict__,
            "training": training_args.to_dict(),
        },
        outdir=outdir,
        report_dir=outdir,
        report_enabled=True,
        jsonl_path=outdir / "metrics.jsonl",
        jsonl_overwrite=True,
    )
    completed_steps = 0
    starting_epoch = 0

    if training_args.resume_from_checkpoint:
        if (
            training_args.resume_from_checkpoint is not None
            and training_args.resume_from_checkpoint != ""
        ):
            checkpoint_path = training_args.resume_from_checkpoint
            path = os.path.basename(training_args.resume_from_checkpoint)
        else:
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1]
            checkpoint_path = path
            path = os.path.basename(checkpoint_path)

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = (
                starting_epoch * num_update_steps_per_epoch
                if num_update_steps_per_epoch is not None
                else 0
            )
        else:
            resume_step = (
                int(training_difference.replace("step_", ""))
                * training_args.gradient_accumulation_steps
            )
            completed_steps = resume_step // training_args.gradient_accumulation_steps
            if num_update_steps_per_epoch is not None:
                starting_epoch = resume_step // len(train_dataloader)
                resume_step -= starting_epoch * len(train_dataloader)
            else:
                starting_epoch = 0
    else:
        resume_step = None

    total_loss = 0
    perplexity: Optional[float] = None
    for epoch in range(starting_epoch, int(training_args.num_train_epochs)):
        model.train()
        total_loss = 0
        if (
            training_args.resume_from_checkpoint
            and epoch == starting_epoch
            and resume_step is not None
        ):
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader
        for _, batch in enumerate(active_dataloader):
            t0step = time.perf_counter()
            t1step = 0.0
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                t1step = time.perf_counter() - t0step
                completed_steps += 1
                # Read the loss once; each .item() call forces a device sync.
                step_loss = loss.detach().float().item()
                total_loss += step_loss

                tokens_per_step = total_batch_size * block_size
                tps = tokens_per_step / t1step if t1step > 0 else 0.0
                try:
                    step_ppl = math.exp(step_loss)
                except OverflowError:
                    step_ppl = float("inf")

                metrics = {
                    "train/step": completed_steps,
                    "train/epoch": epoch,
                    "train/loss": step_loss,
                    "train/perplexity": step_ppl,
                    "train/lr": lr_scheduler.get_last_lr()[0],
                    "train/dt": t1step,
                    "train/tokens_per_sec": tps,
                }
                if _model_flops > 0 and t1step > 0:
                    # t1step is timed from the start of the current micro-batch
                    # (forward + backward + optimizer.step + sync); it excludes
                    # data loading and earlier accumulation micro-batches.
                    metrics["train/tflops"] = _model_flops / t1step / 1e12
                    metrics["train/mfu"] = compute_mfu(_model_flops, t1step)

                if completed_steps % logging_steps == 0:
                    summary = history.update(metrics)
                    logger.info(
                        "[train] %s",
                        _strip_metric_prefix(summary, "train/"),
                    )

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0 and accelerator.sync_gradients:
                    # Keep the checkpoint path in its own variable so that
                    # `output_dir` is not overwritten here.
                    ckpt_dir = f"step_{completed_steps}"
                    if training_args.output_dir is not None:
                        ckpt_dir = os.path.join(
                            training_args.output_dir, ckpt_dir
                        )
                    accelerator.save_state(ckpt_dir)

            if completed_steps >= training_args.max_steps:
                break

        if eval_dataloader is not None:
            model.eval()
            losses = []
            for _, batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = model(**batch)

                loss = outputs.loss
                losses.append(
                    accelerator.gather_for_metrics(
                        loss.repeat(training_args.per_device_eval_batch_size)
                    )
                )

            losses = torch.cat(losses)
            eval_loss = torch.mean(losses)
            try:
                perplexity = math.exp(eval_loss)
            except OverflowError:
                perplexity = float("inf")

            avg_train_loss = float(total_loss) / max(completed_steps, 1)
            eval_metrics = {
                "eval/step": completed_steps,
                "eval/epoch": epoch,
                "eval/loss": float(eval_loss),
                "eval/perplexity": perplexity,
                "eval/train_loss": avg_train_loss,
            }
            summary = history.update(eval_metrics)
            logger.info("[eval] %s", _strip_metric_prefix(summary, "eval/"))

        if training_args.push_to_hub and epoch < training_args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            _save_pretrained_with_fallback(
                unwrapped_model,
                output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
            if accelerator.is_main_process and api is not None and repo_id is not None:
                tokenizer.save_pretrained(output_dir)
                api.upload_folder(  # type: ignore[arg-type]
                    commit_message=f"Training in progress epoch {epoch}",
                    folder_path=output_dir,
                    repo_id=repo_id,
                    repo_type="model",
                    token=training_args.hub_token,
                )

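        if training_args.save_strategy == "epoch":
            # Save under <output_dir>/epoch_<n>. A separate variable keeps
            # `output_dir` from being overwritten here.
            ckpt_dir = f"epoch_{epoch}"
            if training_args.output_dir is not None:
                ckpt_dir = os.path.join(training_args.output_dir, ckpt_dir)
            accelerator.save_state(ckpt_dir)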

        if completed_steps >= training_args.max_steps:
            break

    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    _save_pretrained_with_fallback(
        unwrapped_model,
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        if training_args.push_to_hub and api is not None and repo_id is not None:
            api.upload_folder(  # type: ignore[arg-type]
                commit_message="End of training",
                folder_path=output_dir,
                repo_id=repo_id,
                repo_type="model",
                token=training_args.hub_token,
            )
        if perplexity is not None:
            with open(
                os.path.join(output_dir, "all_results.json"),
                "w",
            ) as f:
                json.dump({"perplexity": perplexity}, f)

    train_end = time.perf_counter()
    timings = {
        "main/setup_torch": t_setup - t0,
        "main/train": train_end - train_start,
        "main/total": train_end - t0,
        "timings/training_start": train_start - t0,
        "timings/train_duration": train_end - train_start,
        "timings/end-to-end": train_end - t0,
    }
    logger.info("Timings: %s", timings)

    history.tracker.log(
        {
            (f"timings/{k}" if not k.startswith("timings/") else k): v
            for k, v in timings.items()
        }
    )

    if accelerator.is_main_process:
        history.finalize(
            run_name="ezpz.examples.hf",
            dataset_fname="train",
            warmup=0,
            save=True,
            plot=True,
            outdir=outdir,
            timings=timings,
        )


if __name__ == "__main__":
    main()

Code Walkthrough

Imports

Standard library, HuggingFace, and ezpz imports. The Accelerator and FullyShardedDataParallelPlugin from accelerate are loaded in a try/except so the error message is clear if the package is missing.

src/ezpz/examples/hf.py:1:59
#!/usr/bin/env python
"""
Fine-tune a causal LM with a hand-rolled training loop.

This mirrors the dataset/model setup used in ``ezpz.examples.hf_trainer`` while
keeping an explicit training loop like the other examples.
"""

from __future__ import annotations

# pyright: reportArgumentType=false
# pyright: reportGeneralTypeIssues=false

import json
import math
import os
import sys
import time
from itertools import chain
from pathlib import Path
from typing import Optional, cast

import datasets
import torch
import transformers
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from transformers import (
    CONFIG_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    get_scheduler,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.versions import require_version

import ezpz
from ezpz.configs import HfDataTrainingArguments, HfModelArguments
from ezpz.flops import compute_mfu, try_estimate

logger = ezpz.get_logger(__name__)

try:
    from accelerate import Accelerator, FullyShardedDataParallelPlugin  # noqa: E402 type:ignore
    from accelerate.utils import set_seed
except ImportError as exc:
    logger.error(
        "Please install accelerate to run this script: `pip install accelerate`"
    )
    raise exc

require_version(
    "datasets>=2.14.0",
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt",
)
parse_args

Uses HuggingFace's HfArgumentParser to parse three dataclass groups in one pass. Supports loading arguments from a JSON file when the sole CLI argument is a .json path. Rank 0 logs at INFO level; all other ranks are limited to CRITICAL.
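The JSON-versus-CLI dispatch can be sketched with the standard library alone. This is a simplified stand-in, not HfArgumentParser itself: `DemoArgs` and `load_args` are hypothetical names, and a single small dataclass replaces the three real argument groups.

```python
import argparse
import json
import tempfile
from dataclasses import dataclass, fields
from pathlib import Path


@dataclass
class DemoArgs:
    """Hypothetical argument group standing in for the real HF dataclasses."""

    model_name_or_path: str = "gpt2"
    per_device_train_batch_size: int = 8


def load_args(argv: list[str]) -> DemoArgs:
    """Read a JSON file if it is the sole argument, else parse CLI flags.

    `argv` excludes the program name, so the check is `len(argv) == 1`
    where the script itself tests `len(sys.argv) == 2`.
    """
    if len(argv) == 1 and argv[0].endswith(".json"):
        return DemoArgs(**json.loads(Path(argv[0]).read_text()))
    parser = argparse.ArgumentParser()
    for f in fields(DemoArgs):
        # Infer each flag's type from the dataclass default.
        parser.add_argument(f"--{f.name}", type=type(f.default), default=f.default)
    return DemoArgs(**vars(parser.parse_args(argv)))


# Same arguments, supplied once through a JSON file and once through flags.
with tempfile.TemporaryDirectory() as tmp:
    cfg = Path(tmp) / "args.json"
    cfg.write_text(json.dumps({"model_name_or_path": "distilgpt2"}))
    from_json = load_args([str(cfg)])
from_cli = load_args(["--per_device_train_batch_size", "2"])
```

Keys missing from the JSON file fall back to the dataclass defaults, which mirrors how unspecified CLI flags behave.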

src/ezpz/examples/hf.py:62:98
def _safetensors_save_errors() -> tuple[type[Exception], ...]:
    """Exception types we accept as triggers for the safetensors retry.

    Always includes (OSError, RuntimeError, ValueError) — those cover
    the common parallel-filesystem failures (E2BIG "Argument list too
    long" on Lustre, RuntimeError from torch save shims).  When the
    optional ``safetensors`` library is installed we also include its
    native ``SafetensorError`` because the rust core raises that
    directly for header/metadata/shared-tensor failures, not OSError.

    Return type is narrowed to ``Exception`` (not ``BaseException``) so
    a future addition can't accidentally include ``SystemExit`` or
    ``KeyboardInterrupt`` and turn the retry into an interrupt swallow.
    """
    errors: tuple[type[Exception], ...] = (OSError, RuntimeError, ValueError)
    try:
        from safetensors import SafetensorError
    except ImportError:
        return errors
    return errors + (SafetensorError,)


_SAFETENSORS_SAVE_ERRORS = _safetensors_save_errors()


def _save_pretrained_with_fallback(
    model: object,
    output_dir: str,
    *,
    is_main_process: bool,
    save_function: object,
) -> None:
    """``model.save_pretrained`` with a safetensors fallback.

    First tries the default safetensors serializer; on a parallel-FS
    failure (OSError "Argument list too long", RuntimeError from torch
    save shims, or safetensors.SafetensorError from the rust core)
split_dataset

Loads a named dataset from the HuggingFace Hub, splitting it into train/validation by percentage. Falls back gracefully when the requested split syntax is not supported by the dataset.
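The percentage split relies on slice expressions embedded in the split string. A minimal sketch of how the two expressions are built follows; `split_specs` is a hypothetical helper, and the range check is an added assumption rather than something the script performs.

```python
def split_specs(train_split_name: str, validation_pct: int) -> tuple[str, str]:
    """Return (validation_spec, train_spec) slicing one split by percentage."""
    if not 0 < validation_pct < 100:
        # Assumed guard: an empty train or validation slice is rarely intended.
        raise ValueError("validation_pct must be between 1 and 99")
    validation_spec = f"{train_split_name}[:{validation_pct}%]"
    train_spec = f"{train_split_name}[{validation_pct}%:]"
    return validation_spec, train_spec


val_spec, train_spec = split_specs("train", 5)
```

The first expression selects the leading 5% of the train split for validation and the second keeps the remaining 95% for training, so the two slices do not overlap.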

src/ezpz/examples/hf.py:101:181
    Used at both the mid-training epoch save and the end-of-training
    save so a Lustre/safetensors failure mid-run doesn't crash the
    whole job.
    """
    try:
        model.save_pretrained(  # type: ignore[attr-defined]
            output_dir,
            is_main_process=is_main_process,
            save_function=save_function,
        )
    except _SAFETENSORS_SAVE_ERRORS as e:
        logger.warning(
            "save_pretrained with safetensors failed (%s: %s); "
            "retrying with safe_serialization=False",
            type(e).__name__, e,
        )
        model.save_pretrained(  # type: ignore[attr-defined]
            output_dir,
            is_main_process=is_main_process,
            save_function=save_function,
            safe_serialization=False,
        )


def _strip_metric_prefix(summary: str, prefix: str) -> str:
    """Drop *prefix* from metric tokens in a History summary string.

    History.update() returns a space-separated string like
    ``"train/loss=0.5 train/dt=0.1"``.  The previous implementation
    used ``str.replace(prefix, "")`` which would mangle a metric
    whose name contains the prefix as a substring (e.g.
    ``cosine_train/x``).  This helper splits on whitespace and
    only strips the prefix when it actually anchors a token.
    """
    return " ".join(
        token.removeprefix(prefix) for token in summary.split()
    )


def parse_args(
    ) -> tuple[HfModelArguments, HfDataTrainingArguments, TrainingArguments]:
    """Parse Hugging Face model, data, and training arguments.

    Returns:
        Tuple of ``(model, data, training)`` argument objects.
    """
    parser = HfArgumentParser(
        (HfModelArguments, HfDataTrainingArguments, TrainingArguments)  # type:ignore
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = (
            parser.parse_args_into_dataclasses()
        )

    if training_args.should_log:
        from transformers.utils import logging as hf_logging

        hf_logging.set_verbosity_info()

    rank = ezpz.get_rank()
    log_level_info = 20
    log_level_critical = 50
    log_level = log_level_info if rank == 0 else log_level_critical
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    if rank == 0:
        logger.info("Training/evaluation parameters %s", training_args)

    return model_args, data_args, training_args


@ezpz.timeitlogit(rank=ezpz.get_rank())
def split_dataset(
main -- Distributed Setup

setup_torch() initializes the process group and returns the local rank. Arguments are parsed, and the output directory is established.

src/ezpz/examples/hf.py:184:196
    train_split_name: str = "train",
    validation_split_name: Optional[str] = None,
) -> datasets.IterableDatasetDict | datasets.DatasetDict:
    """Split a Hugging Face dataset into train/validation splits.

    Args:
        model_args: Model configuration arguments for cache/token settings.
        data_args: Data-related arguments for dataset selection.
        train_split_name: Name of the training split.
        validation_split_name: Name of the validation split (if any).

    Returns:
        Dataset dictionary with requested splits.
main -- FSDP Plugin

When ACCELERATE_USE_FSDP=true, an explicit FullyShardedDataParallelPlugin is constructed with a bf16 mixed-precision policy, bypassing Accelerate's env-var machinery, which can pick up stale defaults.

src/ezpz/examples/hf.py:200:223
        "dataset_name must be provided to split the dataset."
    )
    dsets: dict[str, datasets.Dataset | datasets.IterableDataset] = {}
    if validation_split_name is not None:
        try:
            dsets[validation_split_name] = datasets.load_dataset(  # type:ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
            dsets[train_split_name] = datasets.load_dataset(  # type: ignore
                dataset_name,
                data_args.dataset_config_name,
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
        except ValueError:
main -- Accelerator

The Accelerator is created with gradient accumulation and the optional FSDP plugin. W&B logging is handled separately via ezpz.setup_wandb(), so it is not passed to Accelerate.

src/ezpz/examples/hf.py
    # Don't let Accelerator manage wandb — we handle it via ezpz.setup_wandb()
    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        fsdp_plugin=fsdp_plugin,
    )
    t_setup = time.perf_counter()
main -- Dataset Loading

Datasets are loaded from the Hub via split_dataset (or from local files), then the text column is identified for tokenization.
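For local files, the builder name passed to datasets.load_dataset is derived from the file suffix. The sketch below isolates that rule; `builder_for` is a hypothetical helper that mirrors the branch in the script without calling the datasets library.

```python
from typing import Optional


def builder_for(
    train_file: Optional[str],
    validation_file: Optional[str],
    keep_linebreaks: bool = True,
) -> tuple[str, dict[str, object]]:
    """Pick the builder name and extra kwargs from a data file suffix."""
    source = train_file if train_file is not None else validation_file
    if source is None:
        raise ValueError("Expected a train or validation file.")
    extension = source.split(".")[-1]
    extra: dict[str, object] = {}
    if extension == "txt":
        # Plain-text files are handled by the "text" builder.
        extension = "text"
        extra["keep_linebreaks"] = keep_linebreaks
    return extension, extra


text_builder = builder_for("corpus.txt", None)
json_builder = builder_for(None, "val.json")
```

The train file takes precedence when both are given, so a mixed pair such as a .txt train file and a .json validation file would be loaded with the "text" builder.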

src/ezpz/examples/hf.py:294:342
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            mixed_precision_policy=mp_policy,
            use_orig_params=True,
            sync_module_states=False,
            cpu_ram_efficient_loading=False,
            limit_all_gathers=True,
        )
        logger.info("[rank %d] using explicit FSDP plugin: %s", rank, fsdp_plugin)

    # Don't let Accelerator manage wandb — we handle it via History's tracker
    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        fsdp_plugin=fsdp_plugin,
    )
    t_setup = time.perf_counter()

    # Resolve project name for tracker backends
    wbproj_name = (
        model_args.wandb_project_name
        if getattr(model_args, "wandb_project_name", None) is not None
        else (model_args.model_name_or_path or "ezpz-hf-default-project")
    )
    wbproj_name = f"ezpz-hf-{wbproj_name}".replace("/", "-")
    ezpz.barrier()

    logger.warning(accelerator.state)
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if training_args.seed is not None:
        set_seed(training_args.seed)

    api = None
    repo_id = None
    if accelerator.is_main_process:
        if training_args.push_to_hub:
            repo_name = training_args.hub_model_id
            if repo_name is None:
                repo_name = Path(output_dir).absolute().name
            api = HfApi()
            repo_id = api.create_repo(
                repo_name, exist_ok=True, token=training_args.hub_token
            ).repo_id
main -- Model & Tokenizer

Config, tokenizer, and model are resolved from AutoConfig, AutoTokenizer, and AutoModelForCausalLM. The embedding layer is resized if the tokenizer vocabulary is larger than the model's embedding matrix.

src/ezpz/examples/hf.py:344:413
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        else:
            os.makedirs(output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    last_checkpoint = None
    overwrite = getattr(training_args, "overwrite_output_dir", False)
    if (
        os.path.isdir(output_dir)
        and training_args.do_train
        and not overwrite
    ):
        last_checkpoint = get_last_checkpoint(output_dir)
        if (
            last_checkpoint is None
            and len(os.listdir(output_dir)) > 0
        ):
            raise ValueError(
                "Output directory already exists and is not empty."
            )
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                "Checkpoint detected, resuming training at %s. To avoid this behavior, change the output_dir.",
                last_checkpoint,
            )

    train_split_name = data_args.train_split_name or "train"
    validation_split_name = data_args.validation_split_name or "validation"
    if data_args.dataset_name is not None:
        raw_datasets = split_dataset(
            model_args,
            data_args,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
        )
    else:
        data_files: dict[str, str] = {}
        dataset_args: dict[str, object] = {}
        if data_args.train_file is not None:
            data_files[train_split_name] = data_args.train_file
        if data_args.validation_file is not None:
            data_files[validation_split_name] = data_args.validation_file
        if data_args.train_file is not None:
            extension = data_args.train_file.split(".")[-1]
        elif data_args.validation_file is not None:
            extension = data_args.validation_file.split(".")[-1]
        else:
            raise ValueError("Expected a train or validation file.")
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = datasets.load_dataset(  # type: ignore[arg-type]
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        if validation_split_name not in raw_datasets.keys():
            raw_datasets[validation_split_name] = datasets.load_dataset(  # type:ignore
                extension,
                data_files=data_files,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
main -- W&B Setup

Weights & Biases is initialized on rank 0 only via ezpz.setup_wandb(). A custom metric axis (num_input_tokens_seen) is defined and the full config is uploaded to the run.

src/ezpz/examples/hf.py
    if rank == 0:
        try:
            import wandb

            if (
                wandb is not None
                and report_to is not None
                and report_to != "none"
                and not os.environ.get("WANDB_DISABLED", False)
            ):
                wbproj_name = (
                    model_args.wandb_project_name
                    if model_args.wandb_project_name is not None
                    else model_args.model_name_or_path
                )
                if wbproj_name is None:
                    wbproj_name = "ezpz-hf-default-project"
                wbproj_name = f"ezpz-hf-{wbproj_name}".replace("/", "-")
                run = ezpz.setup_wandb(project_name=wbproj_name)
                if run is not None and run is wandb.run:
                    wandb.define_metric("num_input_tokens_seen")
                    run.config.update(
                        {
                            "model": model_args.__dict__,
                            "data": data_args.__dict__,
                            "training": training_args.to_dict(),
                            "ezpz.dist_info": ezpz.get_dist_info(),
                        }
                    )
        except Exception:
            logger.info("W&B setup skipped")

    ezpz.barrier()  # sync all ranks after rank-0 wandb setup
main -- Tokenization

Raw text is tokenized with dataset.map(). The main process runs the map first (via main_process_first) so other ranks can reuse the cache.

src/ezpz/examples/hf.py:417:443
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name, **config_kwargs
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path, **config_kwargs
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info("Overriding config: %s", model_args.config_overrides)
            config.update_from_string(model_args.config_overrides)
            logger.info("New config: %s", config)
main -- group_texts

Tokenized sequences are concatenated end-to-end, then sliced into fixed-length block_size chunks. This maximizes token utilization by avoiding padding. Labels are a copy of input_ids (standard causal LM objective).
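A small worked example makes the chunking concrete. This sketch follows the same logic on toy token IDs, with `block_size` passed as a parameter instead of being captured from the enclosing scope.

```python
from itertools import chain


def group_texts(
    examples: dict[str, list[list[int]]], block_size: int
) -> dict[str, list[list[int]]]:
    """Concatenate tokenized rows, then cut them into block_size chunks."""
    concatenated = {k: list(chain(*v)) for k, v in examples.items()}
    total_length = len(concatenated["input_ids"])
    # Drop the tail that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


batch = {
    "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]],
    "attention_mask": [[1, 1, 1], [1, 1], [1, 1, 1, 1]],
}
blocks = group_texts(batch, block_size=4)
```

Nine tokens with a block size of four yield two full blocks; the ninth token is discarded. Block boundaries ignore document boundaries, so a single block can span the end of one text and the start of the next.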

src/ezpz/examples/hf.py:478:512
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(  # type:ignore
            config, trust_remote_code=model_args.trust_remote_code
        )

    if callable(getattr(model, "get_input_embeddings", None)):
        embedding_size = model.get_input_embeddings().weight.shape[0]
        if len(tokenizer) > embedding_size:
            model.resize_token_embeddings(len(tokenizer))

    logger.info("[rank %d] proceeding to data prep", rank)

    if training_args.do_train:
        column_names = list(raw_datasets[train_split_name].features)  # type:ignore
    else:
        column_names = list(raw_datasets[validation_split_name].features)  # type:ignore
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples: dict[str, object]) -> dict[str, object]:
        """Tokenize raw text using the configured tokenizer."""
        return tokenizer(examples[text_column_name])

    logger.info("[rank %d] entering tokenization", rank)
    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,  # type:ignore
main -- DataLoaders

Train and eval DataLoaders are built from the processed datasets. The train loader shuffles when not streaming.

src/ezpz/examples/hf.py:541:556
                min(1024, max_pos_embeddings),
            )
            if max_pos_embeddings > 0:
                block_size = min(1024, max_pos_embeddings)
            else:
                block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                "The block_size passed (%s) is larger than the maximum length for the model (%s). Using block_size=%s.",
                data_args.block_size,
                tokenizer.model_max_length,
                tokenizer.model_max_length,
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)
main -- Optimizer & LR Scheduler

AdamW is used with separate weight-decay groups (bias and LayerNorm weights are excluded). The LR scheduler is created from the HuggingFace get_scheduler helper.
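The grouping can be sketched without torch by sorting parameter names. The `no_decay` substrings below are an assumption borrowed from common HuggingFace examples, since the patterns used in hf.py are not shown in this excerpt, and `split_decay_groups` is a hypothetical helper.

```python
def split_decay_groups(
    param_names: list[str],
    weight_decay: float,
    no_decay: tuple[str, ...] = ("bias", "layer_norm.weight"),  # assumed patterns
) -> list[dict[str, object]]:
    """Build two optimizer groups: one with weight decay, one without."""
    decayed = [n for n in param_names if not any(nd in n for nd in no_decay)]
    exempt = [n for n in param_names if any(nd in n for nd in no_decay)]
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]


names = ["block.dense.weight", "block.dense.bias", "block.layer_norm.weight"]
groups = split_decay_groups(names, weight_decay=0.01)
```

In a real run the lists would hold the parameter tensors from model.named_parameters() and be passed to the optimizer. Name-based matching depends on how the architecture names its layers, so the patterns are worth checking against the model in use.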

src/ezpz/examples/hf.py:558:579
        examples: dict[str, list[list[int]]]
    ) -> dict[str, list[list[int]]]:
        """Concatenate and chunk tokenized text into fixed-size blocks."""
        concatenated_examples = {
            k: [int(x) for x in chain(*examples[k])]
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [
                t[i : i + block_size]
                for i in range(0, total_length, block_size)
            ]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    logger.info("[rank %d] entering group_texts", rank)
    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
src/ezpz/examples/hf.py:600:608
                train_dataset = train_dataset.take(data_args.max_train_samples)
            else:
                max_train_samples = min(
                    train_dataset.num_rows, data_args.max_train_samples
                )
                train_dataset = train_dataset.select(range(max_train_samples))

    eval_dataset = None
    if training_args.do_eval:
main -- accelerator.prepare

All training objects are wrapped by Accelerate in one call. This handles DDP/FSDP wrapping, optimizer state sharding, and dataloader distribution.

src/ezpz/examples/hf.py:614:624
                eval_dataset = eval_dataset.take(data_args.max_eval_samples)
            else:
                eval_dataset = eval_dataset.select(
                    range(data_args.max_eval_samples)
                )  # type:ignore

    assert train_dataset is not None
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=not data_args.streaming,
        collate_fn=default_data_collator,
main -- History & Checkpointing Setup

An ezpz.history.History object is created for metric tracking. On resume, accelerator.load_state restores the checkpoint, and the checkpoint directory name (step_* or epoch_*) is parsed to recover the starting epoch and step.
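The resume arithmetic is easier to follow on concrete numbers. The sketch below reproduces it as a pure function; `resume_position` is a hypothetical helper, and it assumes the number of update steps per epoch is known, which skips the script's fallback for the unknown case.

```python
from typing import Optional


def resume_position(
    checkpoint_name: str,
    grad_accum_steps: int,
    batches_per_epoch: int,
    update_steps_per_epoch: int,
) -> tuple[int, Optional[int], int]:
    """Return (starting_epoch, resume_step, completed_steps) for a checkpoint."""
    if checkpoint_name.startswith("epoch_"):
        # An epoch checkpoint resumes at the start of the next epoch.
        starting_epoch = int(checkpoint_name.removeprefix("epoch_")) + 1
        return starting_epoch, None, starting_epoch * update_steps_per_epoch
    # step_<n> counts optimizer updates; convert to micro-batches consumed.
    resume_step = int(checkpoint_name.removeprefix("step_")) * grad_accum_steps
    completed_steps = resume_step // grad_accum_steps
    starting_epoch = resume_step // batches_per_epoch
    # Remaining micro-batches to skip inside the starting epoch.
    resume_step -= starting_epoch * batches_per_epoch
    return starting_epoch, resume_step, completed_steps


from_step = resume_position("step_30", 4, 100, 25)
from_epoch = resume_position("epoch_2", 4, 100, 25)
```

With 4 accumulation steps and 100 batches per epoch, step_30 corresponds to 120 micro-batches: one full epoch plus 20 batches to skip in the next. The skip count is what the training loop hands to accelerator.skip_first_batches.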

src/ezpz/examples/hf.py:685:703
        if overrode_max_train_steps
        else training_args.max_steps * accelerator.num_processes,
    )

    _model_flops = try_estimate(
        model, (training_args.per_device_train_batch_size, block_size),
    )

    logger.info("[rank %d] calling accelerator.prepare() ...", rank)
    (
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
    logger.info("[rank %d] accelerator.prepare() complete", rank)
main -- Training Loop

Each batch's forward and backward pass runs inside accelerator.accumulate, which handles gradient accumulation transparently. After each optimizer step (when accelerator.sync_gradients is true), per-step metrics (loss, perplexity, tokens/sec) are computed, and every logging_steps steps they are recorded via history.update.
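Which micro-batches actually trigger an optimizer update can be sketched in a few lines. This assumes Accelerate's default behavior, where the last batch of the dataloader also forces a sync; `sync_points` and `tokens_per_sec` are hypothetical helpers, not part of the script.

```python
def sync_points(num_batches: int, grad_accum_steps: int) -> list[int]:
    """0-based micro-batch indices after which the optimizer really steps."""
    return [
        i
        for i in range(num_batches)
        if (i + 1) % grad_accum_steps == 0 or i + 1 == num_batches
    ]


def tokens_per_sec(total_batch_size: int, block_size: int, dt: float) -> float:
    """Throughput for one optimizer step, guarding against a zero duration."""
    return total_batch_size * block_size / dt if dt > 0 else 0.0


steps = sync_points(num_batches=10, grad_accum_steps=4)
tps = tokens_per_sec(total_batch_size=8, block_size=1024, dt=0.5)
```

With 10 micro-batches and 4 accumulation steps, updates happen after indices 3, 7, and 9. On the other micro-batches the wrapped optimizer skips the update, which is why the loop can call optimizer.step() unconditionally.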

src/ezpz/examples/hf.py:745:816
    logger.info(
        "  Num examples = %s",
        len(train_dataset) if hasattr(train_dataset, "__len__") else "unknown (streaming)",
    )
    logger.info("  Num Epochs = %s", training_args.num_train_epochs)
    logger.info(
        "  Instantaneous batch size per device = %s",
        training_args.per_device_train_batch_size,
    )
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %s",
        total_batch_size,
    )
    logger.info(
        "  Gradient Accumulation steps = %s",
        training_args.gradient_accumulation_steps,
    )
    logger.info("  Total optimization steps = %s", training_args.max_steps)

    logging_steps = max(1, int(training_args.logging_steps))
    outdir = Path(training_args.output_dir) if training_args.output_dir else Path.cwd() / "outputs"
    outdir.mkdir(parents=True, exist_ok=True)
    logger.info("Outputs will be saved to %s", outdir)
    history = ezpz.history.History(
        project_name=wbproj_name,
        config={
            "model": model_args.__dict__,
            "data": data_args.__dict__,
            "training": training_args.to_dict(),
        },
        outdir=outdir,
        report_dir=outdir,
        report_enabled=True,
        jsonl_path=outdir / "metrics.jsonl",
        jsonl_overwrite=True,
    )
    completed_steps = 0
    starting_epoch = 0

    if training_args.resume_from_checkpoint:
        if (
            training_args.resume_from_checkpoint is not None
            or training_args.resume_from_checkpoint != ""
        ):
            checkpoint_path = training_args.resume_from_checkpoint
            path = os.path.basename(training_args.resume_from_checkpoint)
        else:
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1]
            checkpoint_path = path
            path = os.path.basename(checkpoint_path)

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = (
                starting_epoch * num_update_steps_per_epoch
                if num_update_steps_per_epoch is not None
                else 0
            )
        else:
            resume_step = (
                int(training_difference.replace("step_", ""))
                * training_args.gradient_accumulation_steps
            )
            completed_steps = resume_step // training_args.gradient_accumulation_steps
            if num_update_steps_per_epoch is not None:
main -- Evaluation Loop

At the end of each epoch, an eval pass gathers losses across all ranks and computes perplexity. Results are logged through history.update.

src/ezpz/examples/hf.py
        if eval_dataloader is not None:
            model.eval()
            losses = []
            for _, batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = model(**batch)

                loss = outputs.loss
                losses.append(
                    accelerator.gather_for_metrics(
                        loss.repeat(training_args.per_device_eval_batch_size)
                    )
                )

            losses = torch.cat(losses)
            eval_loss = torch.mean(losses)
            try:
                perplexity = math.exp(eval_loss)
            except OverflowError:
                perplexity = float("inf")

            avg_train_loss = float(total_loss) / max(completed_steps, 1)
            eval_metrics = {
                "step": completed_steps,
                "epoch": epoch,
                "eval_loss": float(eval_loss),
                "eval_perplexity": perplexity,
                "train_loss": avg_train_loss,
            }
            summary = history.update(eval_metrics)
            logger.info(summary)
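
The reduction in the eval pass above can be illustrated without torch or Accelerate. The following is a minimal sketch, not the code in hf.py: `perplexity_from_losses` is a hypothetical helper, and plain nested lists stand in for what `accelerator.gather_for_metrics` and `torch.cat` produce.

```python
import math


def perplexity_from_losses(per_rank_losses):
    """Flatten per-rank loss lists, average them, and exponentiate the mean.

    Hypothetical helper: the nested lists stand in for the tensors that
    gather_for_metrics + torch.cat would return in the real eval loop.
    """
    flat = [loss for rank_losses in per_rank_losses for loss in rank_losses]
    if not flat:
        raise ValueError("no eval losses were collected")
    eval_loss = sum(flat) / len(flat)
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        # math.exp raises for very large inputs; report infinity instead.
        perplexity = float("inf")
    return eval_loss, perplexity


# Two ranks, two eval batches each (illustrative numbers only).
loss, ppl = perplexity_from_losses([[2.0, 2.2], [1.8, 2.0]])
print(f"eval_loss={loss:.3f} eval_perplexity={ppl:.3f}")
```

Perplexity here is the exponential of the mean loss, so a diverged run with a very large loss reports `inf` rather than raising.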
main -- Epoch Checkpointing & Hub Upload

If push_to_hub is enabled, the model and tokenizer are saved and uploaded after each epoch. Per-epoch state can also be saved locally when save_strategy="epoch".

src/ezpz/examples/hf.py:850:871
            if accelerator.sync_gradients:
                t1step = time.perf_counter() - t0step
                completed_steps += 1
                total_loss += loss.detach().float().item()

                tokens_per_step = total_batch_size * block_size
                tps = tokens_per_step / t1step if t1step > 0 else 0.0
                step_loss = loss.detach().float().item()
                try:
                    step_ppl = math.exp(step_loss)
                except OverflowError:
                    step_ppl = float("inf")

                metrics = {
                    "train/step": completed_steps,
                    "train/epoch": epoch,
                    "train/loss": step_loss,
                    "train/perplexity": step_ppl,
                    "train/lr": lr_scheduler.get_last_lr()[0],
                    "train/dt": t1step,
                    "train/tokens_per_sec": tps,
                }
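
The throughput arithmetic in the snippet above can be worked through in isolation. This sketch assumes `total_batch_size` is the product of per-device batch size, process count, and gradient accumulation steps, as the "Total train batch size (w. parallel, distributed & accumulation)" log message suggests. `tokens_per_second` is a hypothetical helper, not a function from hf.py.

```python
def tokens_per_second(
    per_device_batch_size,
    num_processes,
    grad_accum_steps,
    block_size,
    step_seconds,
):
    """Tokens processed per second for one optimizer step (hypothetical helper)."""
    # Assumed definition of the global batch: every rank contributes
    # per_device_batch_size sequences for each accumulation micro-step.
    total_batch_size = per_device_batch_size * num_processes * grad_accum_steps
    tokens_per_step = total_batch_size * block_size
    # Guard against a zero timer reading, mirroring the snippet above.
    return tokens_per_step / step_seconds if step_seconds > 0 else 0.0


# 2 sequences per device, 8 ranks, 4 accumulation steps, 1024-token blocks,
# and an optimizer step that took 2 seconds (illustrative numbers only).
tps = tokens_per_second(2, 8, 4, 1024, 2.0)
print(f"train/tokens_per_sec = {tps:.1f}")
```

Because every sequence is grouped into exactly `block_size` tokens, the token count per step is a constant and only the step time varies.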
main -- Finalization & Save

After all epochs complete, the unwrapped model is saved. Timing information is collected and history.finalize() generates summary plots and writes final metrics. W&B timings are logged as a last step.

src/ezpz/examples/hf.py
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        if training_args.push_to_hub and api is not None and repo_id is not None:
            api.upload_folder(  # type: ignore[arg-type]
                commit_message="End of training",
                folder_path=output_dir,
                repo_id=repo_id,
                repo_type="model",
                token=training_args.hub_token,
            )
        if perplexity is not None:
            with open(
                os.path.join(output_dir, "all_results.json"),
                "w",
            ) as f:
                json.dump({"perplexity": perplexity}, f)

    train_end = time.perf_counter()
    timings = {
        "main/setup_torch": t_setup - t0,
        "main/train": train_end - train_start,
        "main/total": train_end - t0,
        "timings/training_start": train_start - t0,
        "timings/train_duration": train_end - train_start,
        "timings/end-to-end": train_end - t0,
    }
    logger.info("Timings: %s", timings)

    if accelerator.is_main_process:
        history.finalize(
            run_name="ezpz.examples.hf",
            dataset_fname="train",
            warmup=0,
            save=True,
            plot=True,
            outdir=outdir,
            timings=timings,
        )

    if wandb is not None and getattr(wandb, "run", None) is not None:
        try:
            wandb.log(
                {
                    (f"timings/{k}" if not k.startswith("timings/") else k): v
                    for k, v in timings.items()
                }
            )
        except Exception:
            logger.warning("Failed to log timings to wandb")
__main__ guard
src/ezpz/examples/hf.py:941:942
                    commit_message=f"Training in progress epoch {epoch}",
                    folder_path=output_dir,

MFU Tracking⚓︎

hf.py estimates model FLOPS via try_estimate before accelerator.prepare() (FlopCounterMode can't run through DDP/FSDP wrappers). Per-step TFLOPS and MFU are reported as train/tflops and train/mfu.

_model_flops = try_estimate(
    model, (training_args.per_device_train_batch_size, block_size),
)
# ... per step:
metrics["train/tflops"] = _model_flops / t1step / 1e12
metrics["train/mfu"] = compute_mfu(_model_flops, t1step)

For HF causal LMs, estimate_model_flops extracts output.logits.sum() as the backward target since output.loss is None without labels. See ezpz.flops for details.
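
To make the two reported quantities concrete, here is a minimal sketch of the arithmetic behind achieved TFLOPS and MFU. It is not the implementation of `compute_mfu` in ezpz.flops: `mfu_fraction` and `PEAK_FLOPS_PER_SECOND` are hypothetical names, the peak value is an illustrative assumption that must be replaced with the figure for your accelerator, and whether ezpz reports MFU as a fraction or a percentage is not stated here.

```python
# Assumed hardware peak in FLOP/s, for illustration only
# (for example, a device rated at 312 TFLOPS dense BF16).
PEAK_FLOPS_PER_SECOND = 312e12


def achieved_tflops(model_flops, step_seconds):
    """FLOPs executed per second for one step, in units of 1e12."""
    return model_flops / step_seconds / 1e12


def mfu_fraction(model_flops, step_seconds, peak=PEAK_FLOPS_PER_SECOND):
    """Model FLOPS utilization: achieved FLOP/s divided by peak FLOP/s."""
    if step_seconds <= 0:
        return 0.0
    return (model_flops / step_seconds) / peak


# A step that performs 1.56e14 FLOPs in one second (illustrative numbers).
flops, dt = 1.56e14, 1.0
print(f"train/tflops = {achieved_tflops(flops, dt):.1f}")
print(f"train/mfu    = {mfu_fraction(flops, dt):.2f}")
```

Since the FLOP estimate is taken for a `(per_device_train_batch_size, block_size)` input, the sketch treats it as a per-device count and compares it against a single device's peak.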

Metric Keys, Logging, and Output⚓︎

  • Prefixed keys — Train metrics use train/ prefix (train/loss, train/perplexity, train/tflops, train/mfu); eval metrics use eval/ prefix. This makes History.finalize() produce separate train.h5 / eval.h5 datasets and grouped plots, instead of flattening everything into one column-shared table.
  • Log line cleanup — Each log line is tagged [train] or [eval], so the prefix is stripped from the per-line summary to reduce noise:
    logger.info("[train] %s", summary.replace("train/", ""))
    
  • HTTP log suppression — httpx, huggingface_hub, and filelock are silenced to WARNING at startup. Without this, every Hub HEAD/GET would produce an INFO log line per rank, which adds up to hundreds of lines on multi-rank jobs.
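
The HTTP log suppression described in the last bullet can be reproduced with the standard `logging` module. This is a minimal sketch under the assumption that the three libraries log through loggers named after their packages. `quiet_http_loggers` and `NOISY_LOGGERS` are hypothetical names, not identifiers from hf.py.

```python
import logging

# Logger names assumed to match the package names listed above.
NOISY_LOGGERS = ("httpx", "huggingface_hub", "filelock")


def quiet_http_loggers(level=logging.WARNING):
    """Raise the threshold of chatty third-party loggers (hypothetical helper)."""
    for name in NOISY_LOGGERS:
        logging.getLogger(name).setLevel(level)


quiet_http_loggers()

# INFO records from these loggers are now dropped; warnings still pass.
httpx_logger = logging.getLogger("httpx")
print(httpx_logger.isEnabledFor(logging.INFO))
print(httpx_logger.isEnabledFor(logging.WARNING))
```

Setting the level on the named logger, rather than on the root logger, leaves the example's own INFO output untouched.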

Robustness⚓︎

  • Safetensors fallback — On some parallel filesystems (Lustre), safetensors raises Argument list too long (OS error 7) during save_pretrained. The example catches this and retries with safe_serialization=False to write .bin instead — the model still gets saved.
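  • max_steps exits early — completed_steps is a global counter that carries across epochs, so the inner (per-batch) loop breaks as soon as it reaches max_steps. Without a matching break in the outer loop, the next epoch would still start and run a single wasted step before hitting the max_steps check again. The example therefore breaks out of the epoch loop as well.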
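
The max_steps behaviour in the last bullet is easy to see in a toy loop. This is a sketch of the control flow only, with no model or data. `run` and its `break_outer` flag are hypothetical and exist just to compare the two variants.

```python
def run(num_epochs, steps_per_epoch, max_steps, break_outer=True):
    """Count optimizer steps taken by a nested epoch/batch loop."""
    completed_steps = 0  # global across epochs, as in the example
    for _epoch in range(num_epochs):
        for _batch in range(steps_per_epoch):
            completed_steps += 1  # stands in for one optimizer step
            if completed_steps >= max_steps:
                break  # leaves only the inner (per-batch) loop
        if break_outer and completed_steps >= max_steps:
            break  # also leave the epoch loop
    return completed_steps


# 3 epochs of 4 steps each, capped at max_steps=4 (illustrative numbers).
print(run(3, 4, 4, break_outer=True))
print(run(3, 4, 4, break_outer=False))
```

With the outer break the loop stops at exactly `max_steps`. Without it, each later epoch takes one extra step before the inner check fires again.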

Comparison with hf_trainer.py⚓︎

This example (hf.py) uses an explicit training loop — you control the forward/backward/optimizer step directly, giving maximum flexibility for custom metrics, gradient manipulation, or unconventional schedules.

The companion hf_trainer.py uses the HuggingFace Trainer abstraction, which handles the loop internally. Use the Trainer version when you want standard training with minimal code; use this version when you need full control.

Help⚓︎

--help
$ python3 -m ezpz.examples.hf --help
# Accepts all HuggingFace TrainingArguments plus:
#   --model_name_or_path    Pretrained model name or path
#   --dataset_name          HuggingFace dataset name
#   --dataset_config_name   Dataset configuration
#   --do_train / --do_eval  Enable training / evaluation
#   --block_size            Token block size for grouping
#   --wandb_project_name    Custom W&B project name
# See HfModelArguments and HfDataTrainingArguments for full list.