Skip to content

ezpz.examples.hf

Fine-tune a causal LM with a hand-rolled training loop.

This mirrors the dataset/model setup used in ezpz.examples.hf_trainer while keeping an explicit training loop like the other examples.

main()

Entrypoint for standalone HF causal LM fine-tuning without Trainer.

Source code in src/ezpz/examples/hf.py
def _ensure_gitignore_entries(directory: str, patterns: tuple[str, ...]) -> None:
    """Make sure ``directory/.gitignore`` lists every entry of ``patterns``.

    The directory is created if needed. Existing lines are preserved and only
    patterns that are not already present (as whole lines) are appended.
    """
    os.makedirs(directory, exist_ok=True)
    gitignore_path = os.path.join(directory, ".gitignore")
    existing_text = ""
    if os.path.isfile(gitignore_path):
        with open(gitignore_path, encoding="utf-8") as fh:
            existing_text = fh.read()
    existing_lines = {line.strip() for line in existing_text.splitlines()}
    missing = [p for p in patterns if p not in existing_lines]
    if not missing:
        return
    with open(gitignore_path, "a", encoding="utf-8") as fh:
        # Don't glue the first new pattern onto an unterminated last line.
        if existing_text and not existing_text.endswith("\n"):
            fh.write("\n")
        fh.writelines(f"{p}\n" for p in missing)


@ezpz.timeitlogit(rank=ezpz.get_rank())
def main() -> None:
    """Entrypoint for standalone HF causal LM fine-tuning without Trainer.

    Sets up torch.distributed via ``ezpz.setup_torch``, parses arguments,
    builds an ``Accelerator`` (with an explicit FSDP plugin when
    ``ACCELERATE_USE_FSDP=true``), optionally initialises W&B on rank 0,
    loads/tokenizes the dataset, chunks it into ``block_size`` blocks and runs
    an explicit AdamW training loop with optional per-epoch evaluation,
    accelerate checkpoints (``step_N`` / ``epoch_N`` under the output
    directory) and Hub uploads. The final model and tokenizer are saved to
    ``--output_dir`` (or the current directory when it is unset).

    Raises:
        ValueError: if the output directory is non-empty and holds no
            checkpoint (with ``--do_train`` and no overwrite), if a requested
            train/validation split is missing, if no training dataset was
            built, or if a streaming dataset is used without ``--max_steps``.
    """
    t0 = time.perf_counter()
    rank = ezpz.setup_torch()
    model_args, data_args, training_args = parse_args()

    # Base directory for the final model/tokenizer, checkpoints and results.
    # Checkpoint paths use their own variable so this is never clobbered.
    output_dir = training_args.output_dir or os.getcwd()
    wandb = None
    report_to = training_args.report_to

    # Build the FSDP plugin explicitly instead of relying on the env-var
    # machinery, which can pick up stale/conflicting defaults. FSDP is
    # switched on by ACCELERATE_USE_FSDP=true.
    fsdp_plugin = None
    use_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "").lower() == "true"
    if use_fsdp:
        from torch.distributed.fsdp import (
            BackwardPrefetch,
            MixedPrecision,
            ShardingStrategy,
        )

        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ) if training_args.bf16 else None
        fsdp_plugin = FullyShardedDataParallelPlugin(
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            mixed_precision_policy=mp_policy,
            use_orig_params=True,
            sync_module_states=False,
            cpu_ram_efficient_loading=False,
            limit_all_gathers=True,
        )
        logger.info("[rank %d] using explicit FSDP plugin: %s", rank, fsdp_plugin)

    # Don't let Accelerator manage wandb — we handle it via ezpz.setup_wandb()
    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        fsdp_plugin=fsdp_plugin,
    )
    t_setup = time.perf_counter()

    # Initialise wandb early (rank 0 only) so console capture covers the run.
    if rank == 0:
        try:
            import wandb

            if (
                wandb is not None
                and report_to is not None
                and report_to != "none"
                and not os.environ.get("WANDB_DISABLED", False)
            ):
                wbproj_name = (
                    model_args.wandb_project_name
                    if model_args.wandb_project_name is not None
                    else model_args.model_name_or_path
                )
                if wbproj_name is None:
                    wbproj_name = "ezpz-hf-default-project"
                wbproj_name = f"ezpz-hf-{wbproj_name}".replace("/", "-")
                run = ezpz.setup_wandb(project_name=wbproj_name)
                if run is not None and run is wandb.run:
                    wandb.define_metric("num_input_tokens_seen")
                    run.config.update(
                        {
                            "model": model_args.__dict__,
                            "data": data_args.__dict__,
                            "training": training_args.to_dict(),
                            "ezpz.dist_info": ezpz.get_dist_info(),
                        }
                    )
        except Exception as exc:  # best-effort: W&B must never abort training
            logger.warning("W&B setup skipped: %s", exc)
    ezpz.barrier()

    logger.warning(accelerator.state)
    if accelerator.is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    if training_args.seed is not None:
        set_seed(training_args.seed)

    api = None
    repo_id = None
    if accelerator.is_main_process:
        if training_args.push_to_hub:
            repo_name = training_args.hub_model_id
            if repo_name is None:
                repo_name = Path(output_dir).absolute().name
            api = HfApi()
            repo_id = api.create_repo(
                repo_name, exist_ok=True, token=training_args.hub_token
            ).repo_id
            # Appends (never truncates) and creates output_dir if missing.
            _ensure_gitignore_entries(output_dir, ("step_*", "epoch_*"))
        else:
            os.makedirs(output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    last_checkpoint = None
    overwrite = getattr(training_args, "overwrite_output_dir", False)
    if (
        os.path.isdir(output_dir)
        and training_args.do_train
        and not overwrite
    ):
        last_checkpoint = get_last_checkpoint(output_dir)
        if (
            last_checkpoint is None
            and len(os.listdir(output_dir)) > 0
        ):
            raise ValueError(
                "Output directory already exists and is not empty."
            )
        # NOTE(review): last_checkpoint is only logged here; resumption below
        # happens only via --resume_from_checkpoint — confirm intent.
        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                "Checkpoint detected, resuming training at %s. To avoid this behavior, change the output_dir.",
                last_checkpoint,
            )

    train_split_name = data_args.train_split_name or "train"
    validation_split_name = data_args.validation_split_name or "validation"
    if data_args.dataset_name is not None:
        raw_datasets = split_dataset(
            model_args,
            data_args,
            train_split_name=train_split_name,
            validation_split_name=validation_split_name,
        )
    else:
        data_files: dict[str, str] = {}
        dataset_args: dict[str, object] = {}
        if data_args.train_file is not None:
            data_files[train_split_name] = data_args.train_file
        if data_args.validation_file is not None:
            data_files[validation_split_name] = data_args.validation_file
        if data_args.train_file is not None:
            extension = data_args.train_file.split(".")[-1]
        elif data_args.validation_file is not None:
            extension = data_args.validation_file.split(".")[-1]
        else:
            raise ValueError("Expected a train or validation file.")
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = datasets.load_dataset(  # type: ignore[arg-type]
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        if validation_split_name not in raw_datasets.keys():
            # No validation file: carve one out of the training file.
            raw_datasets[validation_split_name] = datasets.load_dataset(  # type:ignore
                extension,
                data_files=data_files,
                split=f"{train_split_name}[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
            raw_datasets[train_split_name] = datasets.load_dataset(  # type:ignore
                extension,
                data_files=data_files,
                split=f"{train_split_name}[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name, **config_kwargs
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path, **config_kwargs
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info("Overriding config: %s", model_args.config_overrides)
            config.update_from_string(model_args.config_overrides)
            logger.info("New config: %s", config)

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, **tokenizer_kwargs
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, **tokenizer_kwargs
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(  # type:ignore
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForCausalLM.from_config(  # type:ignore
            config, trust_remote_code=model_args.trust_remote_code
        )

    # Grow the embedding matrix if the tokenizer has more tokens than it.
    if callable(getattr(model, "get_input_embeddings")):
        embedding_size = model.get_input_embeddings().weight.shape[0]
        if len(tokenizer) > embedding_size:
            model.resize_token_embeddings(len(tokenizer))

    logger.info("[rank %d] proceeding to data prep", rank)

    if training_args.do_train:
        column_names = list(raw_datasets[train_split_name].features)  # type:ignore
    else:
        column_names = list(raw_datasets[validation_split_name].features)  # type:ignore
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples: dict[str, object]) -> dict[str, object]:
        """Tokenize raw text using the configured tokenizer."""
        return tokenizer(examples[text_column_name])

    logger.info("[rank %d] entering tokenization", rank)
    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,  # type:ignore
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,  # type:ignore
            )
        else:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )
    logger.info("[rank %d] tokenization done", rank)

    if hasattr(config, "max_position_embeddings"):
        max_pos_embeddings = config.max_position_embeddings
    else:
        max_pos_embeddings = 1024
        logger.warning(
            "Config %s does not have 'max_position_embeddings'; using %s.",
            config,
            max_pos_embeddings,
        )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > max_pos_embeddings:
            logger.warning(
                "The tokenizer picked seems to have a very large `model_max_length` (%s). "
                "Using block_size=%s instead. You can change that default value by passing --block_size xxx.",
                tokenizer.model_max_length,
                min(1024, max_pos_embeddings),
            )
            if max_pos_embeddings > 0:
                block_size = min(1024, max_pos_embeddings)
            else:
                block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                "The block_size passed (%s) is larger than the maximum length for the model (%s). Using block_size=%s.",
                data_args.block_size,
                tokenizer.model_max_length,
                tokenizer.model_max_length,
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    def group_texts(
        examples: dict[str, list[list[int]]]
    ) -> dict[str, list[list[int]]]:
        """Concatenate and chunk tokenized text into fixed-size blocks.

        The trailing remainder shorter than ``block_size`` is dropped, and
        ``labels`` is a copy of ``input_ids``.
        """
        concatenated_examples = {
            k: [int(x) for x in chain(*examples[k])]
            for k in examples.keys()
        }
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        total_length = (total_length // block_size) * block_size
        result = {
            k: [
                t[i : i + block_size]
                for i in range(0, total_length, block_size)
            ]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    logger.info("[rank %d] entering group_texts", rank)
    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,  # type:ignore
                load_from_cache_file=not data_args.overwrite_cache,  # type:ignore
            )
        else:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
            )
    logger.info("[rank %d] group_texts done", rank)

    train_dataset = None
    if training_args.do_train:
        if train_split_name not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets[train_split_name]  # type:ignore
        if data_args.max_train_samples is not None:
            if isinstance(train_dataset, datasets.IterableDataset):
                train_dataset = train_dataset.take(data_args.max_train_samples)
            else:
                max_train_samples = min(
                    train_dataset.num_rows, data_args.max_train_samples
                )
                train_dataset = train_dataset.select(range(max_train_samples))

    eval_dataset = None
    if training_args.do_eval:
        if validation_split_name not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets[validation_split_name]  # type:ignore
        if data_args.max_eval_samples is not None:
            if isinstance(eval_dataset, datasets.IterableDataset):
                eval_dataset = eval_dataset.take(data_args.max_eval_samples)
            else:
                # Clamp like max_train_samples so select() cannot overrun.
                max_eval_samples = min(
                    eval_dataset.num_rows, data_args.max_eval_samples
                )
                eval_dataset = eval_dataset.select(
                    range(max_eval_samples)
                )  # type:ignore

    if train_dataset is None:
        raise ValueError(
            "No training dataset was built; this script requires --do_train."
        )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=not data_args.streaming,
        collate_fn=default_data_collator,
        batch_size=training_args.per_device_train_batch_size,
        num_workers=training_args.dataloader_num_workers,
    )
    eval_dataloader = None
    if eval_dataset is not None:
        eval_dataloader = DataLoader(
            eval_dataset,
            collate_fn=default_data_collator,
            batch_size=training_args.per_device_eval_batch_size,
            num_workers=training_args.dataloader_num_workers,
        )

    # Biases and layer-norm weights are excluded from weight decay.
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters, lr=training_args.learning_rate
    )

    overrode_max_train_steps = False
    try:
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / training_args.gradient_accumulation_steps
        )
    except TypeError:
        # IterableDataset (streaming) has no len(); require max_steps
        num_update_steps_per_epoch = None
    if training_args.max_steps <= 0:
        if num_update_steps_per_epoch is None:
            raise ValueError(
                "max_steps must be set when using a streaming / "
                "IterableDataset (dataset length is unknown)."
            )
        training_args.max_steps = int(
            training_args.num_train_epochs * num_update_steps_per_epoch
        )
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps
        * accelerator.num_processes,
        num_training_steps=training_args.max_steps
        if overrode_max_train_steps
        else training_args.max_steps * accelerator.num_processes,
    )

    logger.info("[rank %d] calling accelerator.prepare() ...", rank)
    (
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )
    logger.info("[rank %d] accelerator.prepare() complete", rank)

    # prepare() may shard the dataloader, so recompute the per-epoch counts.
    try:
        num_update_steps_per_epoch = math.ceil(
            len(train_dataloader) / training_args.gradient_accumulation_steps
        )
    except TypeError:
        num_update_steps_per_epoch = None
    if num_update_steps_per_epoch is not None:
        if overrode_max_train_steps:
            training_args.max_steps = int(
                training_args.num_train_epochs * num_update_steps_per_epoch
            )
        training_args.num_train_epochs = math.ceil(
            training_args.max_steps / num_update_steps_per_epoch
        )

    checkpointing_steps = training_args.save_steps

    train_start = time.perf_counter()
    total_batch_size = (
        training_args.per_device_train_batch_size
        * accelerator.num_processes
        * training_args.gradient_accumulation_steps
    )
    tokens_per_step = total_batch_size * block_size

    logger.info("***** Model *****")
    logger.info("  Model = %s", model)
    logger.info("***** Args *****")
    logger.info(
        json.dumps(
            {
                "model": model_args.__dict__,
                "data": data_args.__dict__,
                "training": training_args.to_dict(),
            },
            indent=4,
            sort_keys=True,
        )
    )
    logger.info("***** Running training *****")
    logger.info("  Num processes = %s", accelerator.num_processes)
    logger.info(
        "  Num examples = %s",
        len(train_dataset) if hasattr(train_dataset, "__len__") else "unknown (streaming)",
    )
    logger.info("  Num Epochs = %s", training_args.num_train_epochs)
    logger.info(
        "  Instantaneous batch size per device = %s",
        training_args.per_device_train_batch_size,
    )
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %s",
        total_batch_size,
    )
    logger.info(
        "  Gradient Accumulation steps = %s",
        training_args.gradient_accumulation_steps,
    )
    logger.info("  Total optimization steps = %s", training_args.max_steps)

    logging_steps = max(1, int(training_args.logging_steps))
    outdir = Path(training_args.output_dir) if training_args.output_dir else Path.cwd() / "outputs"
    outdir.mkdir(parents=True, exist_ok=True)
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=True,
        jsonl_path=outdir / "metrics.jsonl",
        jsonl_overwrite=True,
    )
    completed_steps = 0
    starting_epoch = 0
    resume_step: Optional[int] = None

    if training_args.resume_from_checkpoint:
        checkpoint_path = training_args.resume_from_checkpoint
        # normpath drops a trailing separator so basename yields "step_N" /
        # "epoch_N" rather than "".
        path = os.path.basename(os.path.normpath(checkpoint_path))

        accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
        accelerator.load_state(checkpoint_path)
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = (
                starting_epoch * num_update_steps_per_epoch
                if num_update_steps_per_epoch is not None
                else 0
            )
        else:
            # step_N counts optimizer steps; batches are skipped in
            # micro-batches, hence the gradient-accumulation multiplier.
            resume_step = (
                int(training_difference.replace("step_", ""))
                * training_args.gradient_accumulation_steps
            )
            completed_steps = resume_step // training_args.gradient_accumulation_steps
            if num_update_steps_per_epoch is not None:
                starting_epoch = resume_step // len(train_dataloader)
                resume_step -= starting_epoch * len(train_dataloader)
            else:
                starting_epoch = 0

    total_loss = 0.0
    perplexity: Optional[float] = None
    for epoch in range(starting_epoch, int(training_args.num_train_epochs)):
        # Stop before starting another epoch once max_steps is reached;
        # otherwise the first batch of the next epoch would still take an
        # optimizer step (the in-loop check runs after the step).
        if completed_steps >= training_args.max_steps:
            break
        model.train()
        total_loss = 0.0
        epoch_steps = 0  # optimizer steps taken in this epoch
        if (
            training_args.resume_from_checkpoint
            and epoch == starting_epoch
            and resume_step is not None
        ):
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader
        for batch in active_dataloader:
            t0step = time.perf_counter()
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Only count/log on micro-batches that completed an optimizer step.
            if accelerator.sync_gradients:
                t1step = time.perf_counter() - t0step
                completed_steps += 1
                epoch_steps += 1
                step_loss = loss.detach().float().item()
                total_loss += step_loss

                tps = tokens_per_step / t1step if t1step > 0 else 0.0
                try:
                    step_ppl = math.exp(step_loss)
                except OverflowError:
                    step_ppl = float("inf")

                metrics = {
                    "step": completed_steps,
                    "epoch": epoch,
                    "loss": step_loss,
                    "perplexity": step_ppl,
                    "learning_rate": lr_scheduler.get_last_lr()[0],
                    "dts": t1step,
                    "tokens_per_sec": tps,
                }

                if completed_steps % logging_steps == 0:
                    summary = history.update(metrics)
                    logger.info(summary)

            if (
                accelerator.sync_gradients
                and training_args.save_strategy == "steps"
                and isinstance(checkpointing_steps, int)
                and checkpointing_steps > 0
                and completed_steps % checkpointing_steps == 0
            ):
                accelerator.save_state(
                    os.path.join(output_dir, f"step_{completed_steps}")
                )

            if completed_steps >= training_args.max_steps:
                break

        if eval_dataloader is not None:
            model.eval()
            gathered_losses = []
            for batch in eval_dataloader:
                with torch.no_grad():
                    outputs = model(**batch)
                # One entry per sample so the gathered mean weights batches
                # by their size.
                gathered_losses.append(
                    accelerator.gather_for_metrics(
                        outputs.loss.repeat(
                            training_args.per_device_eval_batch_size
                        )
                    )
                )

            if gathered_losses:
                eval_loss = torch.mean(torch.cat(gathered_losses))
                try:
                    perplexity = math.exp(float(eval_loss))
                except OverflowError:
                    perplexity = float("inf")

                # total_loss is reset every epoch, so average over this
                # epoch's steps (not the cumulative completed_steps).
                avg_train_loss = float(total_loss) / max(epoch_steps, 1)
                eval_metrics = {
                    "step": completed_steps,
                    "epoch": epoch,
                    "eval_loss": float(eval_loss),
                    "eval_perplexity": perplexity,
                    "train_loss": avg_train_loss,
                }
                summary = history.update(eval_metrics)
                logger.info(summary)
            else:
                logger.warning(
                    "Evaluation dataloader yielded no batches; skipping eval for epoch %s",
                    epoch,
                )

        if training_args.push_to_hub and epoch < training_args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
            if accelerator.is_main_process and api is not None and repo_id is not None:
                tokenizer.save_pretrained(output_dir)
                api.upload_folder(  # type: ignore[arg-type]
                    commit_message=f"Training in progress epoch {epoch}",
                    folder_path=output_dir,
                    repo_id=repo_id,
                    repo_type="model",
                    token=training_args.hub_token,
                    # Keep accelerate checkpoint dirs out of the Hub repo.
                    ignore_patterns=["step_*", "epoch_*"],
                )

        if training_args.save_strategy == "epoch":
            accelerator.save_state(os.path.join(output_dir, f"epoch_{epoch}"))

    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        if training_args.push_to_hub and api is not None and repo_id is not None:
            api.upload_folder(  # type: ignore[arg-type]
                commit_message="End of training",
                folder_path=output_dir,
                repo_id=repo_id,
                repo_type="model",
                token=training_args.hub_token,
                ignore_patterns=["step_*", "epoch_*"],
            )
        if perplexity is not None:
            with open(
                os.path.join(output_dir, "all_results.json"),
                "w",
            ) as f:
                json.dump({"perplexity": perplexity}, f)

    train_end = time.perf_counter()
    timings = {
        "main/setup_torch": t_setup - t0,
        "main/train": train_end - train_start,
        "main/total": train_end - t0,
        "timings/training_start": train_start - t0,
        "timings/train_duration": train_end - train_start,
        "timings/end-to-end": train_end - t0,
    }
    logger.info("Timings: %s", timings)

    if accelerator.is_main_process:
        history.finalize(
            run_name="ezpz.examples.hf",
            dataset_fname="train",
            warmup=0,
            save=True,
            plot=True,
            outdir=outdir,
            timings=timings,
        )

    if wandb is not None and getattr(wandb, "run", None) is not None:
        try:
            wandb.log(
                {
                    (f"timings/{k}" if not k.startswith("timings/") else k): v
                    for k, v in timings.items()
                }
            )
        except Exception:
            logger.warning("Failed to log timings to wandb")

parse_args()

Parse Hugging Face model, data, and training arguments.

Returns:

Type Description
tuple[HfModelArguments, HfDataTrainingArguments, TrainingArguments]

Tuple of the model, data, and training argument objects, in that order.

Source code in src/ezpz/examples/hf.py
def parse_args() -> tuple[HfModelArguments, HfDataTrainingArguments, TrainingArguments]:
    """Build the model, data and training argument dataclasses.

    If the only command-line argument is a path ending in ``.json``, the
    arguments are read from that file; otherwise ``sys.argv`` is parsed as
    flags. As a side effect, ``datasets``/``transformers`` log verbosity is
    set to level 20 (INFO) on rank 0 and 50 (CRITICAL) on every other rank.

    Returns:
        ``(model_args, data_args, training_args)``.
    """
    dataclass_types = (HfModelArguments, HfDataTrainingArguments, TrainingArguments)
    parser = HfArgumentParser(dataclass_types)  # type:ignore

    cli_args = sys.argv[1:]
    single_json = len(cli_args) == 1 and cli_args[0].endswith(".json")
    if single_json:
        parsed = parser.parse_json_file(json_file=os.path.abspath(cli_args[0]))
    else:
        parsed = parser.parse_args_into_dataclasses()
    model_args, data_args, training_args = parsed

    if training_args.should_log:
        from transformers.utils import logging as hf_logging

        hf_logging.set_verbosity_info()

    is_rank_zero = ezpz.get_rank() == 0
    # 20 == logging.INFO, 50 == logging.CRITICAL
    verbosity = 20 if is_rank_zero else 50
    for set_verbosity in (
        datasets.utils.logging.set_verbosity,
        transformers.utils.logging.set_verbosity,
    ):
        set_verbosity(verbosity)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    if is_rank_zero:
        logger.info("Training/evaluation parameters %s", training_args)

    return model_args, data_args, training_args

split_dataset(model_args, data_args, train_split_name='train', validation_split_name=None)

Split a Hugging Face dataset into train/validation splits.

Parameters:

Name Type Description Default
model_args HfModelArguments

Model configuration arguments for cache/token settings.

required
data_args HfDataTrainingArguments

Data-related arguments for dataset selection.

required
train_split_name str

Name of the training split.

'train'
validation_split_name Optional[str]

Name of the validation split (if any).

None

Returns:

Type Description
IterableDatasetDict | DatasetDict

Dataset dictionary with requested splits.

Source code in src/ezpz/examples/hf.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def split_dataset(
    model_args: HfModelArguments,
    data_args: HfDataTrainingArguments,
    train_split_name: str = "train",
    validation_split_name: Optional[str] = None,
) -> datasets.IterableDatasetDict | datasets.DatasetDict:
    """Split a Hugging Face dataset into train/validation splits.

    When ``validation_split_name`` is given, the first
    ``data_args.validation_split_percentage`` percent of the training split
    becomes the validation split and the remainder becomes the training
    split. If ``datasets`` rejects the percent-slice syntax with a
    ``ValueError`` (for example in streaming mode), the whole training split
    is used for training. Validation then comes from the dataset's own
    ``validation_split_name`` split if one exists, otherwise from the whole
    training split. When ``validation_split_name`` is ``None``, only the
    training split is loaded.

    Args:
        model_args: Model configuration arguments for cache/token settings.
        data_args: Data-related arguments for dataset selection.
        train_split_name: Name of the training split.
        validation_split_name: Name of the validation split (if any).

    Returns:
        Dataset dictionary with requested splits. It is an
        ``IterableDatasetDict`` when ``data_args.streaming`` is set, otherwise
        a ``DatasetDict``.

    Raises:
        ValueError: if ``data_args.dataset_name`` is ``None``.
    """
    dataset_name = data_args.dataset_name
    if dataset_name is None:
        raise ValueError("dataset_name must be provided to split the dataset.")

    def _load(split: str) -> datasets.Dataset | datasets.IterableDataset:
        """Load one split (slice syntax allowed) with the shared settings."""
        return datasets.load_dataset(  # type:ignore
            dataset_name,
            data_args.dataset_config_name,
            split=split,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            streaming=data_args.streaming,
            trust_remote_code=model_args.trust_remote_code,
        )

    dsets: dict[str, datasets.Dataset | datasets.IterableDataset] = {}
    if validation_split_name is None:
        # Nothing to carve out: return the training split on its own.
        dsets[train_split_name] = _load(train_split_name)
    else:
        pct = data_args.validation_split_percentage
        try:
            dsets[validation_split_name] = _load(f"{train_split_name}[:{pct}%]")
            dsets[train_split_name] = _load(f"{train_split_name}[{pct}%:]")
        except ValueError:
            # Percent slicing was rejected; retrying a slice would fail the
            # same way, so fall back to whole splits.
            try:
                validation = _load(validation_split_name)
            except (ValueError, KeyError):
                # The dataset has no such split: reuse the training data.
                validation = _load(train_split_name)
            dsets[validation_split_name] = validation
            dsets[train_split_name] = _load(train_split_name)

    if data_args.streaming:
        return datasets.IterableDatasetDict(  # type: ignore
            cast(dict[str, datasets.IterableDataset], dsets)
        )
    return datasets.DatasetDict(  # type: ignore
        cast(dict[str, datasets.Dataset], dsets)
    )