Skip to content

ezpz.examples.vit⚓︎

Train a lightweight Vision Transformer on fake or MNIST data.

Launch with:

ezpz launch -m ezpz.examples.vit --dataset mnist --batch_size 256

Quick smoke test on a laptop:

python -m ezpz.examples.vit --dataset fake --max_iters 1 --batch_size 4 --img_size 64 --patch_size 8 --num_heads 2 --head_dim 16 --depth 2 --num_classes 10

Model presets:

--model debug|small|medium|med|large

Help output (python3 -m ezpz.examples.vit --help):

usage: ezpz.examples.vit [-h] [--img_size IMG_SIZE] [--batch_size BATCH_SIZE]
                         [--num_heads NUM_HEADS] [--head_dim HEAD_DIM]
                         [--hidden-dim HIDDEN_DIM] [--mlp-dim MLP_DIM]
                         [--dropout DROPOUT]
                         [--attention-dropout ATTENTION_DROPOUT]
                         [--num_classes NUM_CLASSES] [--dataset {fake,mnist}]
                         [--depth DEPTH] [--patch_size PATCH_SIZE]
                         [--dtype DTYPE] [--compile]
                         [--num_workers NUM_WORKERS] [--max_iters MAX_ITERS]
                         [--warmup WARMUP] [--attn_type {native,sdpa}]
                         [--cuda_sdpa_backend {flash_sdp,mem_efficient_sdp,math_sdp,cudnn_sdp,all}]
                         [--fsdp]

Train a simple ViT

options:
  -h, --help            show this help message and exit
  --img_size IMG_SIZE, --img-size IMG_SIZE
                        Image size
  --batch_size BATCH_SIZE, --batch-size BATCH_SIZE
                        Batch size
  --num_heads NUM_HEADS, --num-heads NUM_HEADS
                        Number of heads
  --head_dim HEAD_DIM, --head-dim HEAD_DIM
                        Hidden Dimension
  --hidden-dim HIDDEN_DIM, --hidden_dim HIDDEN_DIM
                        Hidden Dimension
  --mlp-dim MLP_DIM, --mlp_dim MLP_DIM
                        MLP Dimension
  --dropout DROPOUT     Dropout rate
  --attention-dropout ATTENTION_DROPOUT, --attention_dropout ATTENTION_DROPOUT
                        Attention Dropout rate
  --num_classes NUM_CLASSES, --num-classes NUM_CLASSES
                        Number of classes
  --dataset {fake,mnist}
                        Dataset to use
  --depth DEPTH         Depth
  --patch_size PATCH_SIZE, --patch-size PATCH_SIZE
                        Patch size
  --dtype DTYPE         Data type
  --compile             Compile model
  --num_workers NUM_WORKERS, --num-workers NUM_WORKERS
                        Number of workers
  --max_iters MAX_ITERS, --max-iters MAX_ITERS
                        Maximum iterations
  --warmup WARMUP       Warmup iterations (or fraction) before starting to collect metrics.
  --attn_type {native,sdpa}, --attn-type {native,sdpa}
                        Attention function to use.
  --cuda_sdpa_backend {flash_sdp,mem_efficient_sdp,math_sdp,cudnn_sdp,all}, --cuda-sdpa-backend {flash_sdp,mem_efficient_sdp,math_sdp,cudnn_sdp,all}
                        CUDA SDPA backend to use.
  --fsdp                Use FSDP

PatchEmbed ⚓︎

Bases: Module

Convert images into patch embeddings.

Source code in src/ezpz/examples/vit.py
class PatchEmbed(torch.nn.Module):
    """Cut an image into square patches and project each one to a vector.

    The projection is a single ``Conv2d`` whose kernel size and stride both
    equal ``patch_size``, so every output position sees exactly one
    non-overlapping patch.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch; must divide
            ``img_size`` evenly.
        in_chans: Number of channels in the input image.
        embed_dim: Size of the embedding produced for every patch.

    Raises:
        ValueError: If ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        embed_dim: int,
    ) -> None:
        super().__init__()
        patches_per_side, leftover = divmod(img_size, patch_size)
        if leftover:
            raise ValueError("img_size must be divisible by patch_size")
        self.img_size = img_size
        self.patch_size = patch_size
        # Square grid: the same number of patches along each axis.
        self.num_patches = patches_per_side * patches_per_side
        self.proj = torch.nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return patch embeddings with the patch axis before the channel axis."""
        projected = self.proj(x)
        # Collapse the two spatial axes into one, then swap it with channels.
        return torch.flatten(projected, start_dim=2).transpose(1, 2)

SimpleVisionTransformer ⚓︎

Bases: Module

Minimal Vision Transformer implementation without timm.

Source code in src/ezpz/examples/vit.py
class SimpleVisionTransformer(torch.nn.Module):
    """A small, self-contained Vision Transformer (no timm dependency).

    The model embeds image patches, optionally prepends a learnable class
    token, adds a learnable positional embedding, runs ``depth`` blocks built
    by ``block_fn``, applies a final ``LayerNorm``, pools over the token
    axis, and feeds the pooled vector to the classification head.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Width of the token embeddings.
        depth: Number of blocks to stack.
        num_heads: Number of attention heads passed to each block.
        num_classes: Output size of the head; a value ``<= 0`` replaces the
            head with ``Identity``.
        block_fn: Callable invoked as ``block_fn(dim=..., num_heads=...)``
            to build each block.
        class_token: Whether to prepend a learnable class token.
        global_pool: ``"avg"`` averages the patch tokens; any other value
            uses the class token when present and otherwise averages.
        dropout: Dropout probability applied after adding positions.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
        num_classes: int,
        block_fn: Any,
        class_token: bool = False,
        global_pool: str = "avg",
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.class_token = class_token
        self.global_pool = global_pool

        # The class token, when used, occupies one extra sequence position.
        seq_len = self.patch_embed.num_patches + (1 if class_token else 0)
        self.cls_token = (
            torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
            if class_token
            else None
        )
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = torch.nn.Dropout(p=dropout)

        layers = []
        for _ in range(depth):
            layers.append(block_fn(dim=embed_dim, num_heads=num_heads))
        self.blocks = torch.nn.ModuleList(layers)

        self.norm = torch.nn.LayerNorm(embed_dim)
        if num_classes > 0:
            self.head = torch.nn.Linear(embed_dim, num_classes)
        else:
            self.head = torch.nn.Identity()
        self._init_weights()

    def _init_weights(self) -> None:
        """Draw the positional embedding (and class token) from a truncated normal."""
        learned = [self.pos_embed]
        if self.cls_token is not None:
            learned.append(self.cls_token)
        for tensor in learned:
            torch.nn.init.trunc_normal_(tensor, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of images and return the head output."""
        tokens = self.patch_embed(x)
        has_cls = self.cls_token is not None
        if has_cls:
            # Broadcast the single learned token across the batch.
            cls = self.cls_token.expand(tokens.size(0), -1, -1)
            tokens = torch.cat((cls, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)

        if has_cls and self.global_pool != "avg":
            # Token pooling: read the class token only.
            pooled = tokens[:, 0]
        else:
            if has_cls:
                # Average pooling ignores the class token.
                tokens = tokens[:, 1:]
            pooled = tokens.mean(dim=1)
        return self.head(pooled)

main(args) ⚓︎

CLI entrypoint to configure logging and launch ViT training.

Source code in src/ezpz/examples/vit.py
def main(args: argparse.Namespace):
    """Set up the distributed runtime, pick an attention kernel, and train.

    Steps, in order:

    1. Initialise torch distributed through ``ezpz.dist.setup_torch``.
    2. On rank 0, when ``ezpz.verify_wandb()`` passes, start a wandb run and
       record the distributed info and CLI arguments in its config. Any
       failure here is logged as a warning and training continues.
    3. Build the attention-block constructor for ``args.attn_type``:
       ``native`` uses the hand-written attention below, ``sdpa`` uses
       ``torch.nn.functional.scaled_dot_product_attention``.
    4. Hand off to ``train_fn``.

    Args:
        args: Parsed command-line arguments.

    Raises:
        ValueError: If ``args.attn_type`` is neither ``native`` nor ``sdpa``.
    """
    rank = ezpz.dist.setup_torch()
    if rank == 0 and ezpz.verify_wandb():
        try:
            this_file = Path(__file__).resolve()
            project = f"ezpz.{this_file.parent.name}.{this_file.stem}"
            run = ezpz.setup_wandb(project_name=project)
            run_is_active = (
                wandb is not None and run is not None and run is wandb.run
            )
            if run_is_active:
                wandb.config.update(ezpz.get_dist_info())
                wandb.config.update({**vars(args)})
        except Exception:
            logger.warning("Failed to setup wandb, continuing without!")

    def attn_fn(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Hand-written scaled dot-product attention."""
        scale = args.head_dim ** (-0.5)
        scores = torch.matmul(q * scale, k.transpose(-2, -1))
        weights = torch.softmax(scores, dim=-1)
        if args.attention_dropout > 0.0:
            # Dropout is tied to autograd state rather than module mode.
            weights = torch.nn.functional.dropout(
                weights,
                p=args.attention_dropout,
                training=torch.is_grad_enabled(),
            )
        return torch.matmul(weights, v)

    def sdpa_attn_fn(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Attention via torch's fused scaled_dot_product_attention."""
        drop_p = args.attention_dropout if torch.is_grad_enabled() else 0.0
        return torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=drop_p,
        )

    logger.info(f"Using {args.attn_type} for SDPA backend")
    if args.attn_type not in ("native", "sdpa"):
        raise ValueError(f"Unknown attention type: {args.attn_type}")

    if args.attn_type == "native":
        block_fn = functools.partial(AttentionBlock, attn_fn=attn_fn)
    else:
        if torch.cuda.is_available():
            backend_names = (
                "flash_sdp",
                "mem_efficient_sdp",
                "math_sdp",
                "cudnn_sdp",
            )
            # Start from a clean slate: switch every backend off ...
            for name in backend_names:
                getattr(torch.backends.cuda, f"enable_{name}")(False)
            # ... then switch on only what was requested ("all" = each one).
            for name in backend_names:
                if args.cuda_sdpa_backend in (name, "all"):
                    getattr(torch.backends.cuda, f"enable_{name}")(True)
        block_fn = functools.partial(AttentionBlock, attn_fn=sdpa_attn_fn)

    logger.info(f"Using AttentionBlock Attention with {args.compile=}")
    train_fn(block_fn, args=args, dataset=args.dataset)

parse_args(argv=None) ⚓︎

Parse CLI arguments for ViT training.

Source code in src/ezpz/examples/vit.py
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments for ViT training.

    Most multi-word options are registered under both an underscore and a
    hyphen spelling (for example ``--batch_size`` and ``--batch-size``).

    After parsing, the namespace is passed through ``apply_model_preset``,
    ``apply_dataset_overrides`` and ``validate_dataset_args`` (all defined
    elsewhere in this module). The first two also receive the raw ``argv``,
    presumably so they can tell which options were given explicitly.

    Args:
        argv: Argument strings to parse. When ``None``, ``sys.argv[1:]`` is
            used.

    Returns:
        The parsed and post-processed argument namespace.
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(
        prog="ezpz.examples.vit",
        description="Train a simple ViT",
    )
    parser.add_argument(
        "--img_size",
        "--img-size",
        type=int,
        default=224,
        help="Image size",
    )
    parser.add_argument(
        "--batch_size",
        "--batch-size",
        type=int,
        default=128,
        help="Batch size",
    )
    parser.add_argument(
        "--num_heads",
        "--num-heads",
        type=int,
        default=16,
        help="Number of heads",
    )
    parser.add_argument(
        "--head_dim",
        "--head-dim",
        type=int,
        default=64,
        # Previously read "Hidden Dimension", duplicating --hidden-dim; the
        # model width is num_heads * head_dim, so this is the per-head size.
        help="Dimension of each attention head",
    )
    parser.add_argument(
        "--hidden-dim",
        "--hidden_dim",
        type=int,
        default=1024,
        help="Hidden Dimension",
    )
    parser.add_argument(
        "--mlp-dim", "--mlp_dim", type=int, default=2048, help="MLP Dimension"
    )
    parser.add_argument(
        "--dropout", type=float, default=0.1, help="Dropout rate"
    )
    parser.add_argument(
        "--attention-dropout",
        "--attention_dropout",
        type=float,
        default=0.0,
        help="Attention Dropout rate",
    )
    parser.add_argument(
        "--num_classes",
        "--num-classes",
        type=int,
        default=1000,
        help="Number of classes",
    )
    parser.add_argument(
        "--dataset",
        default="mnist",
        choices=["fake", "mnist"],
        help="Dataset to use",
    )
    parser.add_argument(
        "--model",
        default=None,
        # Accept canonical preset names as well as their aliases.
        choices=sorted([*MODEL_PRESETS.keys(), *MODEL_ALIASES.keys()]),
        help="Model size preset (overrides defaults)",
    )
    parser.add_argument("--depth", type=int, default=24, help="Depth")
    parser.add_argument(
        "--patch_size",
        "--patch-size",
        type=int,
        default=16,
        help="Patch size",
    )
    parser.add_argument("--dtype", type=str, default="bf16", help="Data type")
    parser.add_argument("--compile", action="store_true", help="Compile model")
    parser.add_argument(
        "--num_workers",
        "--num-workers",
        type=int,
        default=0,
        help="Number of workers",
    )
    parser.add_argument(
        "--max_iters",
        "--max-iters",
        type=int,
        default=100,
        help="Maximum iterations",
    )
    parser.add_argument(
        "--warmup",
        # Float so the value can be either a step count (>= 1) or a fraction.
        type=float,
        default=0.1,
        help="Warmup iterations (or fraction) before starting to collect metrics.",
    )
    parser.add_argument(
        "--attn_type",
        "--attn-type",
        default="native",
        choices=["native", "sdpa"],
        help="Attention function to use.",
    )
    parser.add_argument(
        "--cuda_sdpa_backend",
        "--cuda-sdpa-backend",
        default="all",
        choices=[
            "flash_sdp",
            "mem_efficient_sdp",
            "math_sdp",
            "cudnn_sdp",
            "all",
        ],
        help="CUDA SDPA backend to use.",
    )
    parser.add_argument("--fsdp", action="store_true", help="Use FSDP")
    args = parser.parse_args(argv)
    apply_model_preset(args, argv)
    apply_dataset_overrides(args, argv)
    validate_dataset_args(args)
    return args

train_fn(block_fn, args, dataset='fake') ⚓︎

Train the Vision Transformer on fake or MNIST data.

Parameters:

| Name | Type | Description | Default |
|:-----|:-----|:------------|:--------|
| block_fn | Any | Attention block constructor with attn_fn injected. | required |
| args | Namespace | Training hyperparameters. | required |
| dataset | Optional[str] | Dataset choice, either fake or mnist. | 'fake' |

Returns:

| Type | Description |
|:-----|:------------|
| History | History of training metrics. |

Source code in src/ezpz/examples/vit.py
def train_fn(
    block_fn: Any,
    args: argparse.Namespace,
    dataset: Optional[str] = "fake",
) -> ezpz.History:
    """Train the Vision Transformer on fake or MNIST data.

    The function builds the data loaders and the model, optionally wraps the
    model for distributed training and/or compiles it, runs at most
    ``args.max_iters`` training steps, evaluates on the test split when the
    dataset provides one, and finally (on rank 0) finalizes the metric
    histories under ``outputs/<WBRUN_NAME>/<timestamp>``.

    Training metrics are only recorded for steps at or after the warmup
    boundary: ``args.warmup`` is taken as a step count when ``>= 1`` and as a
    fraction of the total number of steps otherwise.

    Args:
        block_fn: Attention block constructor with attn_fn injected.
        args: Training hyperparameters.
        dataset: Dataset choice, either ``fake`` or ``mnist``.

    Returns:
        History of training metrics.

    Raises:
        ValueError: If ``dataset`` is neither ``fake`` nor ``mnist``.
    """
    world_size = ezpz.dist.get_world_size()
    local_rank = ezpz.dist.get_local_rank()
    device_type = ezpz.dist.get_torch_device_type()
    device = torch.device(f"{device_type}:{local_rank}")
    logger.info("train_args=%s", vars(args))

    if dataset == "fake":
        dataset_dict = get_fake_data(
            img_size=args.img_size,
            batch_size=args.batch_size,
        )
    elif dataset == "mnist":
        dataset_dict = get_mnist(
            train_batch_size=args.batch_size,
            test_batch_size=args.batch_size,
            # Only rank 0 requests the download.
            download=(ezpz.dist.get_rank() == 0),
        )
    else:
        raise ValueError(
            f"Unknown dataset: {dataset}. Expected 'fake' or 'mnist'."
        )

    # MNIST images are single-channel; the fake data is built with three.
    in_chans = 1 if dataset == "mnist" else 3
    model = SimpleVisionTransformer(
        img_size=args.img_size,
        patch_size=args.patch_size,
        in_chans=in_chans,
        embed_dim=(args.num_heads * args.head_dim),
        depth=args.depth,
        num_heads=args.num_heads,
        num_classes=args.num_classes,
        class_token=False,
        global_pool="avg",
        block_fn=block_fn,
        dropout=args.dropout,
    )

    mstr = summarize_model(
        model,
        verbose=False,
        depth=1,
        input_size=(
            args.batch_size,
            in_chans,
            args.img_size,
            args.img_size,
        ),
    )
    model.to(device)
    # ``Module.parameters()`` already recurses into submodules, so iterate it
    # once on the root. (Summing it over ``model.modules()`` counts every
    # parameter once per ancestor module and inflates the total.)
    num_params = sum(
        getattr(p, "ds_numel", 0) if hasattr(p, "ds_id") else p.nelement()
        for p in model.parameters()
    )
    model_size_in_billions = num_params / 1e9
    logger.info(f"\n{mstr}")
    logger.info(f"Model size: nparams={model_size_in_billions:.2f} B")
    if (
        wandb is not None
        and ezpz.verify_wandb()
        and wandb.run is not None
    ):
        try:
            if args.compile:
                logger.info("Skipping wandb watch while compiling")
            else:
                wandb.run.watch(model, log="all")  # type:ignore
        except Exception as e:
            # Watching is best-effort; never let it abort training.
            logger.exception(e)
            logger.warning("Failed to watch model with wandb; continuing...")

    if world_size > 1:
        model = ezpz.dist.wrap_model(
            model=model,
            use_fsdp=args.fsdp,
            dtype=args.dtype,
            device_id=ezpz.get_torch_device(as_torch_device=True),
        )

    if args.compile:
        logger.info("Compiling model")
        model = torch.compile(model)

    torch_dtype = ezpz.dist.TORCH_DTYPES_MAP[args.dtype]
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters())  # type:ignore
    model.train()  # type:ignore

    history = ezpz.history.History()
    eval_history = ezpz.history.History()
    run_stamp = ezpz.get_timestamp()
    run_dir = Path.cwd().joinpath("outputs", WBRUN_NAME, run_stamp)
    logger.info(
        f"Training with {world_size} x {device_type} (s), using {torch_dtype=}"
    )
    train_loader = dataset_dict["train"]["loader"]
    if args.warmup >= 1.0:
        # Interpreted as an absolute number of steps.
        warmup_iters = int(args.warmup)
    else:
        # Interpreted as a fraction of the planned number of steps.
        total_iters = (
            args.max_iters if args.max_iters is not None else len(train_loader)
        )
        warmup_iters = int(args.warmup * total_iters)

    for step, batch in enumerate(train_loader):
        # ``step`` is zero-based, so stop once ``max_iters`` steps have run.
        if args.max_iters is not None and step >= int(args.max_iters):
            break
        t0 = time.perf_counter()
        inputs = batch[0].to(device=device, non_blocking=True)
        label = batch[1].to(device=device, non_blocking=True)
        # NOTE(review): presumably a device barrier so the perf_counter
        # readings bracket finished work -- confirm in ezpz.dist.
        ezpz.dist.synchronize()
        with torch.autocast(device_type=device_type, dtype=torch_dtype):
            t1 = time.perf_counter()
            outputs = model(inputs)
            loss = criterion(outputs, label)
            acc = (outputs.argmax(dim=-1) == label).float().mean()
            t2 = time.perf_counter()
        ezpz.dist.synchronize()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        ezpz.dist.synchronize()
        t3 = time.perf_counter()
        optimizer.step()
        ezpz.dist.synchronize()
        t4 = time.perf_counter()
        if step >= warmup_iters:
            loss_value = float(loss.detach().item())
            acc_value = float(acc.detach().item())
            if not math.isfinite(loss_value) or not math.isfinite(acc_value):
                logger.warning(
                    "Skipping non-finite train metrics at step=%s", step
                )
                continue
            # NOTE(review): "dto" spans the backward pass (t2 -> t3) and
            # "dtb" spans optimizer.step() (t3 -> t4); if the suffixes mean
            # optimizer/backward the two labels are swapped -- confirm.
            train_msg = history.update(
                {
                    "train/iter": step,
                    "train/loss": loss_value,
                    "train/acc": acc_value,
                    "train/dt": t4 - t0,
                    "train/dtd": t1 - t0,
                    "train/dtf": t2 - t1,
                    "train/dto": t3 - t2,
                    "train/dtb": t4 - t3,
                }
            ).replace("train/", "")
            logger.info("[train] %s", train_msg)

    if "test" in dataset_dict:
        model.eval()  # type:ignore
        # Running totals are weighted by batch size so that the final
        # averages are per-sample.
        eval_loss = 0.0
        eval_acc = 0.0
        eval_count = 0
        eval_step = 0
        with torch.no_grad():
            for batch in dataset_dict["test"]["loader"]:
                inputs = batch[0].to(device=device, non_blocking=True)
                labels = batch[1].to(device=device, non_blocking=True)
                with torch.autocast(device_type=device_type, dtype=torch_dtype):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    correct = (outputs.argmax(dim=-1) == labels).sum()
                batch_size = labels.numel()
                eval_loss += loss.item() * batch_size
                eval_acc += correct.item()
                eval_count += batch_size
                batch_loss = float(loss.detach().item())
                batch_acc = float(correct.item() / batch_size)
                if math.isfinite(batch_loss) and math.isfinite(batch_acc):
                    eval_msg = eval_history.update(
                        {
                            "eval/iter": eval_step,
                            "eval/loss": batch_loss,
                            "eval/acc": batch_acc,
                        }
                    ).replace("eval/", "")
                    logger.info("[eval] %s", eval_msg)
                eval_step += 1
        if eval_count:
            total_loss = torch.tensor(eval_loss, device=device)
            total_correct = torch.tensor(eval_acc, device=device)
            total_count = torch.tensor(eval_count, device=device)
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                # Combine the per-rank totals before averaging.
                torch.distributed.all_reduce(total_loss)
                torch.distributed.all_reduce(total_correct)
                torch.distributed.all_reduce(total_count)
            eval_loss_value = float(total_loss.item() / total_count.item())
            eval_acc_value = float(total_correct.item() / total_count.item())
            if not math.isfinite(eval_loss_value) or not math.isfinite(
                eval_acc_value
            ):
                # NOTE(review): this returns before the histories are
                # finalized below, so nothing is written for this run --
                # confirm that is intended.
                logger.warning("Skipping non-finite eval metrics")
                model.train()  # type:ignore
                return history
            summary_rows = [
                ("loss", f"{eval_loss_value:.6f}"),
                ("acc", f"{eval_acc_value:.6f}"),
                ("samples", f"{int(total_count.item())}"),
            ]
            header = ("eval metric", "value")
            # Column widths fit the widest cell, header included.
            col1 = max(len(header[0]), *(len(row[0]) for row in summary_rows))
            col2 = max(len(header[1]), *(len(row[1]) for row in summary_rows))
            summary_table = [
                "Eval summary:",
                f"| {header[0]:<{col1}} | {header[1]:>{col2}} |",
                f"|:{'-' * (col1 - 1)} | {'-' * (col2 - 1)}:|",
            ]
            summary_table.extend(
                f"| {name:<{col1}} | {value:>{col2}} |"
                for name, value in summary_rows
            )
            logger.info("\n".join(f"[eval] {line}" for line in summary_table))
        model.train()  # type:ignore

    if ezpz.dist.get_rank() == 0:
        # Kept in its own name so the ``dataset`` argument is not rebound;
        # the log text is unchanged.
        train_dataset = history.finalize(
            outdir=run_dir,
            run_name=WBRUN_NAME,
            dataset_fname="train",
            verbose=False,
        )
        logger.info(f"dataset={train_dataset!r}")
        if "test" in dataset_dict:
            eval_dataset = eval_history.finalize(
                outdir=run_dir,
                run_name=WBRUN_NAME,
                dataset_fname="eval",
                verbose=False,
            )
            logger.info(f"{eval_dataset=}")

    return history