ezpz.examples.diffusion
Tiny diffusion example for short text generation.
This script trains a tiny denoising diffusion model on a handful of toy sentences and then samples new sentences by running the reverse process. The goal is to keep the code minimal while showcasing the full flow:
1. Build a small vocabulary from a list of prompts.
2. Train a denoising network to predict noise on token embeddings.
3. Sample text by iterating the reverse diffusion process.
Typical usage (customize with args as needed):
ezpz-launch -m ezpz.examples.diffusion --timesteps 64 --train-steps 500 --batch-size 16
# with FSDP and a HF dataset slice:
WORLD_SIZE=2 ezpz-launch -m ezpz.examples.diffusion --hf-dataset ag_news --fsdp   # with FSDP and an HF dataset slice
Launch with:
ezpz launch -m ezpz.examples.diffusion --timesteps 64 --train-steps 500
Help output (python3 -m ezpz.examples.diffusion --help):
usage: diffusion.py [-h] [--batch-size BATCH_SIZE] [--dtype DTYPE]
[--extra-text [EXTRA_TEXT ...]] [--fsdp]
[--fsdp-mixed-precision] [--hidden HIDDEN]
[--hf-dataset HF_DATASET] [--hf-split HF_SPLIT]
[--hf-text-column HF_TEXT_COLUMN] [--hf-limit HF_LIMIT]
[--log_freq LOG_FREQ] [--outdir OUTDIR]
[--samples SAMPLES] [--seed SEED] [--seq-len SEQ_LEN]
[--timesteps TIMESTEPS] [--train-steps TRAIN_STEPS]
[--lr LR]
Tiny diffusion example for text generation.
options:
-h, --help show this help message and exit
--batch-size BATCH_SIZE
--dtype DTYPE
--extra-text [EXTRA_TEXT ...]
Additional sentences to add to the tiny corpus.
--fsdp Enable FSDP wrapping (requires WORLD_SIZE>1 and torch.distributed init).
--fsdp-mixed-precision
Use bfloat16 parameters with FSDP for speed (defaults to float32).
--hidden HIDDEN
--hf-dataset HF_DATASET
Optional Hugging Face dataset name (e.g., 'ag_news'). When set, replaces the toy corpus.
--hf-split HF_SPLIT Dataset split to load.
--hf-text-column HF_TEXT_COLUMN
Column containing raw text in the dataset.
--hf-limit HF_LIMIT Number of rows to sample from the HF dataset for quick experiments.
--log_freq LOG_FREQ
--outdir OUTDIR
--samples SAMPLES
--seed SEED
--seq-len SEQ_LEN
--timesteps TIMESTEPS
--train-steps TRAIN_STEPS
--lr LR
DiffusionSchedule
dataclass
⚓︎
Precompute alpha/beta schedule values for DDPM-style updates.
Source code in src/ezpz/examples/diffusion.py
@dataclass
class DiffusionSchedule:
    """Linear DDPM noise schedule with cached per-step coefficients.

    Attributes:
        timesteps: Number of diffusion steps ``T``.
        beta_start: Noise variance of the first forward step.
        beta_end: Noise variance of the last forward step.

    After construction the instance also carries three 1-D tensors of length
    ``timesteps``: ``betas`` (linearly spaced from ``beta_start`` to
    ``beta_end``), ``alphas = 1 - betas`` and ``alpha_bars``, the running
    product of ``alphas``.
    """

    timesteps: int = 64
    beta_start: float = 1e-4
    beta_end: float = 0.02

    def __post_init__(self) -> None:
        """Build the ``betas``, ``alphas`` and ``alpha_bars`` tensors."""
        betas = torch.linspace(self.beta_start, self.beta_end, self.timesteps)
        alphas = 1.0 - betas
        self.betas = betas
        self.alphas = alphas
        # alpha_bar_t = prod_{s <= t} alpha_s, used for the closed-form q(x_t | x_0).
        self.alpha_bars = alphas.cumprod(dim=0)
__post_init__()
⚓︎
Precompute alpha and alpha_bar schedules for diffusion steps.
Source code in src/ezpz/examples/diffusion.py
DiffusionTextModel
⚓︎
Bases: Module
Simple transformer that predicts noise on token embeddings.
Source code in src/ezpz/examples/diffusion.py
class DiffusionTextModel(nn.Module):
    """Transformer encoder that predicts the noise added to token embeddings.

    ``forward`` receives a noisy embedding sequence. It adds learned position
    and timestep embeddings, runs the encoder stack, and projects back to the
    embedding width. ``decode_tokens`` reuses the token embedding table to map
    embeddings back to token ids.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_seq_len: int,
        timesteps: int,
        n_layers: int = 2,
        n_heads: int = 4,
    ) -> None:
        """Create the embedding tables, encoder stack and output projection.

        Args:
            vocab_size: Number of rows in the token embedding table.
            hidden_size: Model width (PyTorch attention requires it to be
                divisible by ``n_heads``).
            max_seq_len: Number of learned positions, i.e. longest usable input.
            timesteps: Number of diffusion steps (rows of the time embedding).
            n_layers: Depth of the transformer encoder.
            n_heads: Attention heads per encoder layer.
        """
        super().__init__()
        self.hidden_size = hidden_size  # type:ignore
        # Submodules are created in a fixed order so seeded initialization
        # stays reproducible.
        self.token_emb = nn.Embedding(vocab_size, hidden_size)
        self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
        self.time_emb = nn.Embedding(timesteps, hidden_size)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=n_heads,
            dim_feedforward=hidden_size * 4,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings and multiply them by ``sqrt(hidden_size)``."""
        scale = math.sqrt(self.hidden_size)
        # Clone avoids autograd complaints about views when using sharded params.
        looked_up = self.token_emb(tokens).clone()
        return looked_up * scale

    def forward(
        self, noisy_embs: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """Estimate the noise present in ``noisy_embs`` at timesteps ``t``.

        Args:
            noisy_embs: Noisy embeddings of shape ``(batch, seq_len, hidden)``.
            t: Integer timestep per batch element, shape ``(batch,)``.

        Returns:
            The predicted noise, shaped like ``noisy_embs``.
        """
        _batch, seq_len, _width = noisy_embs.shape
        positions = torch.arange(seq_len, device=noisy_embs.device)
        pos_term = self.pos_emb(positions).unsqueeze(0)
        time_term = self.time_emb(t).unsqueeze(1)
        hidden = noisy_embs + pos_term + time_term
        return self.proj(self.encoder(hidden))

    def decode_tokens(self, embs: torch.Tensor) -> torch.Tensor:
        """Return, per position, the id whose embedding has the largest dot product with ``embs``."""
        table = self.token_emb.weight  # (vocab, hidden)
        scores = torch.einsum("bld,vd->blv", embs, table)
        return scores.argmax(dim=-1)
__init__(vocab_size, hidden_size, max_seq_len, timesteps, n_layers=2, n_heads=4)
⚓︎
Initialize embeddings and transformer encoder.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
vocab_size
|
int
|
Size of the token vocabulary. |
required |
hidden_size
|
int
|
Dimensionality of embeddings and model width. |
required |
max_seq_len
|
int
|
Maximum sequence length. |
required |
timesteps
|
int
|
Number of diffusion steps. |
required |
n_layers
|
int
|
Number of transformer encoder layers. |
2
|
n_heads
|
int
|
Attention heads per layer. |
4
|
Source code in src/ezpz/examples/diffusion.py
def __init__(
    self,
    vocab_size: int,
    hidden_size: int,
    max_seq_len: int,
    timesteps: int,
    n_layers: int = 2,
    n_heads: int = 4,
) -> None:
    """Initialize embeddings and transformer encoder.

    Args:
        vocab_size: Size of the token vocabulary (rows of ``token_emb``).
        hidden_size: Dimensionality of embeddings and model width.
        max_seq_len: Maximum sequence length (rows of ``pos_emb``).
        timesteps: Number of diffusion steps (rows of ``time_emb``).
        n_layers: Number of transformer encoder layers.
        n_heads: Attention heads per layer.
    """
    super().__init__()
    self.hidden_size = hidden_size  # type:ignore
    # Token ids -> vectors; decoding also reuses this table.
    self.token_emb = nn.Embedding(vocab_size, hidden_size)
    # One learned vector per position index.
    self.pos_emb = nn.Embedding(max_seq_len, hidden_size)
    # One learned vector per diffusion timestep.
    self.time_emb = nn.Embedding(timesteps, hidden_size)
    # batch_first=True: the encoder consumes (batch, seq, hidden) tensors.
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=hidden_size,
        nhead=n_heads,
        dim_feedforward=4 * hidden_size,
        batch_first=True,
    )
    self.encoder = nn.TransformerEncoder(
        encoder_layer, num_layers=n_layers
    )
    # Maps encoder output back to the embedding width (the noise estimate).
    self.proj = nn.Linear(hidden_size, hidden_size)
decode_tokens(embs)
⚓︎
Project embeddings back to token ids via tied embeddings.
Source code in src/ezpz/examples/diffusion.py
embed_tokens(tokens)
⚓︎
Embed token ids and scale them for transformer input.
Source code in src/ezpz/examples/diffusion.py
forward(noisy_embs, t)
⚓︎
Predict noise residuals given noisy embeddings and timestep.
Source code in src/ezpz/examples/diffusion.py
def forward(
    self, noisy_embs: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """Predict noise residuals given noisy embeddings and timestep.

    Args:
        noisy_embs: Noisy embeddings, unpacked as ``(batch, seq_len, hidden)``.
        t: Timestep ids used to index ``time_emb``; callers pass one per
            batch element.

    Returns:
        The predicted noise, shaped like ``noisy_embs``.
    """
    _, seq_len, _ = noisy_embs.shape
    # Learned positions 0..seq_len-1 (seq_len must not exceed max_seq_len).
    pos = self.pos_emb(torch.arange(seq_len, device=noisy_embs.device))
    # Assuming t has shape (batch,): (batch, hidden) -> (batch, 1, hidden),
    # so the timestep embedding broadcasts over positions.
    temb = self.time_emb(t).unsqueeze(1)
    h = noisy_embs + pos.unsqueeze(0) + temb
    h = self.encoder(h)
    return self.proj(h)
ToyTextDataset
⚓︎
Bases: Dataset
Pads or truncates sentences to a fixed length.
Source code in src/ezpz/examples/diffusion.py
class ToyTextDataset(Dataset):
    """Map-style dataset yielding fixed-length token-id tensors.

    Sentences are lower-cased and whitespace-split. Out-of-vocabulary words
    become ``<unk>``. Long sentences are truncated and short ones are
    right-padded with ``<pad>`` to exactly ``seq_len`` ids.
    """

    def __init__(
        self, texts: List[str], vocab: Dict[str, int], seq_len: int = 12
    ):
        """Keep the corpus and vocabulary used for on-the-fly encoding.

        Args:
            texts: Raw sentences, encoded lazily in ``__getitem__``.
            vocab: Token-to-id mapping; must contain ``"<pad>"`` and ``"<unk>"``.
            seq_len: Length every encoded sequence is truncated/padded to.
        """
        self.texts = texts
        self.vocab = vocab
        self.seq_len = seq_len
        # Looked up eagerly so a vocab without the specials fails here (KeyError).
        self.pad_id = vocab["<pad>"]
        self.unk_id = vocab["<unk>"]

    def __len__(self) -> int:
        """Number of sentences in the corpus."""
        return len(self.texts)

    def _encode(self, text: str) -> torch.Tensor:
        """Turn one sentence into a ``seq_len``-long ``torch.long`` tensor."""
        lookup = self.vocab.get
        ids = [lookup(word, self.unk_id) for word in text.lower().split()]
        ids = ids[: self.seq_len]
        padding = [self.pad_id] * (self.seq_len - len(ids))
        return torch.tensor(ids + padding, dtype=torch.long)

    def __getitem__(self, idx: int) -> torch.Tensor:  # type:ignore
        """Encode and return the sentence at position ``idx``."""
        return self._encode(self.texts[idx])
__getitem__(idx)
⚓︎
__init__(texts, vocab, seq_len=12)
⚓︎
Store corpus and vocabulary for encoding.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
texts
|
List[str]
|
Raw sentences. |
required |
vocab
|
Dict[str, int]
|
Token-to-id mapping. |
required |
seq_len
|
int
|
Target sequence length for padding/truncation. |
12
|
Source code in src/ezpz/examples/diffusion.py
def __init__(
    self, texts: List[str], vocab: Dict[str, int], seq_len: int = 12
):
    """Store corpus and vocabulary for encoding.

    Args:
        texts: Raw sentences (encoded lazily, per item).
        vocab: Token-to-id mapping; must contain ``"<pad>"`` and ``"<unk>"``.
        seq_len: Target sequence length for padding/truncation.

    Raises:
        KeyError: If ``vocab`` lacks ``"<pad>"`` or ``"<unk>"``.
    """
    self.texts = texts
    self.vocab = vocab
    self.seq_len = seq_len
    # Cached special ids: pad fills short sequences, unk replaces unknown words.
    self.pad_id = vocab["<pad>"]
    self.unk_id = vocab["<unk>"]
__len__()
⚓︎
add_noise(x0, t, schedule)
⚓︎
Apply forward diffusion noise to clean embeddings.
Source code in src/ezpz/examples/diffusion.py
def add_noise(
    x0: torch.Tensor, t: torch.Tensor, schedule: "DiffusionSchedule"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply forward diffusion noise to clean embeddings.

    Samples ``x_t ~ q(x_t | x_0)`` in closed form:
    ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`` with
    ``noise ~ N(0, I)`` drawn fresh for every element.

    Args:
        x0: Clean inputs whose first dimension is the batch. Any number of
            trailing dimensions is supported (e.g. ``(batch, seq, hidden)``).
        t: Integer timesteps, one per batch element, shape ``(batch,)``.
        schedule: Object providing ``alpha_bars`` indexed by timestep.

    Returns:
        ``(noisy, noise)``: the noised inputs and the exact noise used, both
        shaped like ``x0``.
    """
    noise = torch.randn_like(x0)
    alpha_bar = schedule.alpha_bars.to(x0.device)[t]
    # Reshape to (batch, 1, ..., 1) so each sample's coefficient broadcasts
    # over all trailing dims. A fixed view(-1, 1, 1) would silently broadcast
    # a 2-D x0 of shape (batch, dim) into (batch, batch, dim).
    alpha_bar = alpha_bar.view(-1, *([1] * (x0.dim() - 1)))
    noisy = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise
    return noisy, noise
build_vocab(texts)
⚓︎
Create a tiny vocabulary from a list of strings.
Source code in src/ezpz/examples/diffusion.py
def build_vocab(texts: Iterable[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create a tiny vocabulary from a list of strings.

    Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``. Every other
    distinct lower-cased, whitespace-separated word gets the next id in
    sorted order, so ids are always contiguous ``0 .. len(vocab) - 1``.

    Args:
        texts: Sentences to scan for words (iterated once).

    Returns:
        ``(vocab, inv_vocab)``: token-to-id and id-to-token mappings.
    """
    specials = ["<pad>", "<unk>"]
    # A corpus may literally contain a special token (e.g. text that already
    # has "<unk>" in it). Excluding specials keeps ids contiguous. Otherwise
    # the duplicate key leaves a gap and the largest id equals len(vocab),
    # overflowing an nn.Embedding sized by len(vocab).
    words = sorted(
        {word for text in texts for word in text.lower().split()}
        - set(specials)
    )
    vocab = {tok: idx for idx, tok in enumerate(specials + words)}
    inv_vocab = {idx: tok for tok, idx in vocab.items()}
    return vocab, inv_vocab
generate_text(model, schedule, inv_vocab, seq_len, num_samples, skip_tokens=('<pad>', '<unk>'))
⚓︎
Sample sequences from the trained diffusion model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
model
|
DiffusionTextModel
|
Trained diffusion network (possibly FSDP wrapped). |
required |
schedule
|
DiffusionSchedule
|
Noise schedule with precomputed alphas. |
required |
inv_vocab
|
Dict[int, str]
|
Mapping from token ids back to string tokens. |
required |
seq_len
|
int
|
Maximum sequence length. |
required |
num_samples
|
int
|
Number of sentences to generate. |
required |
skip_tokens
|
Tuple[str, ...]
|
Tokens to drop from decoded outputs. |
('<pad>', '<unk>')
|
Returns:
| Type | Description |
|---|---|
List[str]
|
List of generated text strings. |
Source code in src/ezpz/examples/diffusion.py
def generate_text(
    model: DiffusionTextModel,
    schedule: DiffusionSchedule,
    inv_vocab: Dict[int, str],
    seq_len: int,
    num_samples: int,
    skip_tokens: Tuple[str, ...] = ("<pad>", "<unk>"),
) -> List[str]:
    """Sample sequences from the trained diffusion model.

    Each sample starts from Gaussian noise of shape ``(1, seq_len, hidden)``.
    It is denoised with ``p_sample`` from ``t = schedule.timesteps - 1`` down
    to ``0``. It is then decoded by picking, per position, the token whose
    embedding has the largest dot product with the final vector. The model is
    left in eval mode afterwards.

    Args:
        model: Trained diffusion network (possibly FSDP wrapped).
        schedule: Noise schedule with precomputed alphas.
        inv_vocab: Mapping from token ids back to string tokens.
        seq_len: Maximum sequence length.
        num_samples: Number of sentences to generate.
        skip_tokens: Tokens to drop from decoded outputs.

    Returns:
        List of generated text strings on rank 0; an empty list on all other
        ranks.
    """
    model.eval()
    samples: List[str] = []
    do_sample = ezpz.get_rank() == 0
    is_fsdp = isinstance(model, FSDP)
    # Sampling calls the unwrapped module directly (wrappers expose it as .module).
    base_model = model.module if hasattr(model, "module") else model
    # Under FSDP, materialize full (unsharded) parameters for the duration of
    # sampling; otherwise use a no-op context.
    full_param_ctx = (
        FSDP.summon_full_params(model)  # , recursive=True)
        if is_fsdp
        else nullcontext()
    )
    with torch.no_grad():
        with full_param_ctx:
            # Non-zero ranks enter the context before returning. Presumably
            # this is needed because summon_full_params gathers across ranks,
            # so keep this ordering.
            if not do_sample:
                return samples
            token_emb_weight = base_model.token_emb.weight  # type:ignore
            for _ in range(num_samples):
                # NOTE(review): xt uses torch.randn's default dtype; if the
                # model weights are lower precision a cast may be needed - confirm.
                xt = torch.randn(
                    (1, seq_len, base_model.hidden_size),
                    device=token_emb_weight.device,
                )
                # Reverse process: T-1, T-2, ..., 0.
                for t in reversed(range(schedule.timesteps)):
                    xt = p_sample(base_model, xt, t, schedule)
                # Score each position against every token embedding, keep the best id.
                logits = torch.einsum("bld,vd->blv", xt, token_emb_weight)
                token_ids = logits.argmax(dim=-1)[0].tolist()
                words = [
                    inv_vocab.get(tok_id, "<unk>") for tok_id in token_ids
                ]
                words = [w for w in words if w not in skip_tokens]
                samples.append(" ".join(words))
    return samples
get_default_texts()
⚓︎
Return a small corpus of seed sentences for toy training.
Source code in src/ezpz/examples/diffusion.py
def get_default_texts() -> List[str]:
    """Return the built-in eight-sentence toy corpus.

    A new list is built on every call, so callers may extend it freely.
    """
    corpus = (
        "the product team ships updates every week",
        "customers ask for faster onboarding",
        "the service autoscaling keeps latency steady",
        "data pipelines need reliable monitoring",
        "large language models assist with code reviews",
        "cloud costs drop when workloads are right sized",
        "edge devices sync logs during quiet hours",
        "dashboards show live metrics for incidents",
    )
    return list(corpus)
load_hf_texts(dataset_name, split, text_column, limit)
⚓︎
Pull a small slice of text from a Hugging Face dataset for quick experiments.
Only the first `limit` rows are read, which keeps the example light.
Source code in src/ezpz/examples/diffusion.py
def load_hf_texts(
    dataset_name: str,
    split: str,
    text_column: str,
    limit: int,
) -> List[str]:
    """
    Pull a small slice of text from a Hugging Face dataset for quick experiments.

    Only the first ``limit`` rows are read (or the whole split, if it is
    shorter), which keeps the example light.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        split: Split to load (e.g. ``"train"``).
        text_column: Column holding the raw text; values are passed through ``str``.
        limit: Maximum number of leading rows to keep.

    Returns:
        The selected rows' text, in dataset order.

    Raises:
        RuntimeError: If the ``datasets`` package cannot be imported.
        ValueError: If ``text_column`` is missing or no rows were selected.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as exc:  # pragma: no cover - best-effort import
        raise RuntimeError(
            "datasets package is required for --hf-dataset usage"
        ) from exc
    logger.info(
        "Loading HF dataset %s split=%s column=%s limit=%s",
        dataset_name,
        split,
        text_column,
        limit,
    )
    dataset = load_dataset(dataset_name, split=split)
    if text_column not in list(dataset.column_names):
        raise ValueError(
            f"text_column '{text_column}' not in dataset columns {dataset.column_names}"
        )
    # Dataset.select raises IndexError for indices past the end, so clamp the
    # request to the split size; a short split then yields all of its rows.
    n_rows = min(limit, len(dataset))
    texts = [str(row[text_column]) for row in dataset.select(range(n_rows))]
    if not texts:
        raise ValueError("No text rows found from HF dataset.")
    return texts
main(args)
⚓︎
Set up distributed training, fit the model, and log samples.
Source code in src/ezpz/examples/diffusion.py
def main(args: argparse.Namespace) -> None:
    """Set up distributed training, fit the model, and log samples.

    Steps:
      1. Initialize torch/distributed via ``ezpz.setup_torch`` and agree on
         one output directory (chosen on rank 0, then broadcast).
      2. On rank 0, start a W&B run and record ``outdir`` and the CLI args.
      3. Build the corpus (an HF dataset slice or the built-in sentences, plus
         any ``--extra-text``), the vocabulary, the dataset and the loader.
         The loader uses a ``DistributedSampler`` when ``world_size > 1``.
      4. Train, finalize the history on rank 0, then sample and log text.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    rank = ezpz.setup_torch(seed=args.seed)
    if rank == 0:
        outdir = args.outdir if args.outdir is not None else OUTDIR
    else:
        outdir = None
    # Rank 0 picks the directory; broadcasting it keeps every rank in agreement.
    outdir = ezpz.dist.broadcast(outdir, root=0)
    logger.info(f"Using {outdir=}")
    if ezpz.get_rank() == 0:
        run = ezpz.dist.setup_wandb(
            project_name=WBPROJ_NAME,
        )
        assert run is not None and run is wandb.run
        wandb.config.update({"outdir": outdir, "args": {**vars(args)}})
    base_texts: List[str]
    if args.hf_dataset:
        base_texts = load_hf_texts(
            dataset_name=args.hf_dataset,
            split=args.hf_split,
            text_column=args.hf_text_column,
            limit=args.hf_limit,
        )
    else:
        base_texts = get_default_texts()
    if args.extra_text:
        base_texts = base_texts + args.extra_text
    vocab, inv_vocab = build_vocab(base_texts)
    dataset = ToyTextDataset(base_texts, vocab, seq_len=args.seq_len)
    # Shard the corpus across ranks when distributed; otherwise shuffle locally.
    sampler = (
        DistributedSampler(dataset) if ezpz.get_world_size() > 1 else None
    )
    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        shuffle=(sampler is None),
        drop_last=False,
    )
    schedule = DiffusionSchedule(timesteps=args.timesteps)
    model = DiffusionTextModel(
        vocab_size=len(vocab),
        hidden_size=args.hidden,
        max_seq_len=args.seq_len,
        timesteps=args.timesteps,
    )
    device = ezpz.get_torch_device(as_torch_device=True)
    model.to(device)
    history, wrapped_model = train(
        model=model,
        loader=loader,
        schedule=schedule,
        args=args,
        steps=args.train_steps,
        lr=args.lr,
        outdir=outdir,
    )
    if ezpz.get_rank() == 0:
        # The return value is not used here. It is deliberately not bound to
        # ``dataset``, which would shadow the training dataset above.
        history.finalize(
            run_name=WBRUN_NAME,
            dataset_fname="train",
            warmup=0.1,
        )
    # Called on every rank: under FSDP, generate_text enters a
    # parameter-gathering context on all ranks. Only rank 0 gets samples back.
    samples = generate_text(
        wrapped_model,
        schedule,
        inv_vocab,
        seq_len=args.seq_len,
        num_samples=args.samples,
    )
    if ezpz.get_rank() == 0:
        for idx, text in enumerate(samples):
            logger.info("sample %s: %s", idx, text)
p_sample(model, xt, t, schedule)
⚓︎
Compute one reverse-diffusion (denoising) step at timestep `t`, returning the sample for step `t - 1`.
Source code in src/ezpz/examples/diffusion.py
def p_sample(
    model: DiffusionTextModel,
    xt: torch.Tensor,
    t: int,
    schedule: DiffusionSchedule,
) -> torch.Tensor:
    """Predict one reverse-diffusion step at timestep ``t``.

    Forms the DDPM posterior mean
    ``(x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)`` from the
    model's noise estimate ``eps``. For every step except the last
    (``t == 0``) it then adds ``sqrt(beta_t) * z`` with ``z ~ N(0, I)``.

    Args:
        model: Noise-prediction network, called as ``model(xt, t_batch)``.
        xt: Current sample; its first dimension is the batch.
        t: Integer timestep shared by the whole batch.
        schedule: Source of ``betas``, ``alphas`` and ``alpha_bars``.

    Returns:
        The sample for step ``t - 1`` (the noiseless mean when ``t == 0``).
    """
    dev = xt.device
    timestep = torch.full((xt.size(0),), t, device=dev, dtype=torch.long)
    beta, alpha, alpha_bar = (
        table.to(dev)[t]
        for table in (schedule.betas, schedule.alphas, schedule.alpha_bars)
    )
    eps = model(xt, timestep)
    coef = beta / torch.sqrt(1 - alpha_bar)
    mean = (xt - coef * eps) / torch.sqrt(alpha)
    if t != 0:
        return mean + torch.sqrt(beta) * torch.randn_like(xt)
    return mean
parse_args()
⚓︎
Parse CLI arguments for the diffusion text example.
Source code in src/ezpz/examples/diffusion.py
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for the diffusion text example.

    Several defaults can be overridden through environment variables
    (``BATCH_SIZE``, ``DTYPE``, ``HIDDEN``, ``LOG_FREQ``, ``SAMPLES``,
    ``SEED``, ``SEQ_LEN``, ``TIMESTEPS``, ``TRAIN_STEPS``, ``LR``).
    Explicit command-line flags always take precedence.
    """
    env = os.environ
    parser = argparse.ArgumentParser(
        description="Tiny diffusion example for text generation."
    )
    arg = parser.add_argument

    arg("--batch-size", type=int, default=int(env.get("BATCH_SIZE", 8)))
    arg("--dtype", type=str, default=env.get("DTYPE", "float32"))
    arg(
        "--extra-text",
        type=str,
        nargs="*",
        default=None,
        help="Additional sentences to add to the tiny corpus.",
    )
    arg(
        "--fsdp",
        action="store_true",
        help="Enable FSDP wrapping (requires WORLD_SIZE>1 and torch.distributed init).",
    )
    arg(
        "--fsdp-mixed-precision",
        action="store_true",
        help="Use bfloat16 parameters with FSDP for speed (defaults to float32).",
    )
    arg("--hidden", type=int, default=int(env.get("HIDDEN", 128)))
    # Optional Hugging Face corpus; without --hf-dataset the toy corpus is used.
    arg(
        "--hf-dataset",
        type=str,
        default=None,
        help="Optional Hugging Face dataset name (e.g., 'ag_news'). When set, replaces the toy corpus.",
    )
    arg("--hf-split", type=str, default="train", help="Dataset split to load.")
    arg(
        "--hf-text-column",
        type=str,
        default="text",
        help="Column containing raw text in the dataset.",
    )
    arg(
        "--hf-limit",
        type=int,
        default=512,
        help="Number of rows to sample from the HF dataset for quick experiments.",
    )
    arg("--log_freq", type=int, default=int(env.get("LOG_FREQ", 1)))
    arg("--outdir", type=str, default=None)
    arg("--samples", type=int, default=int(env.get("SAMPLES", 3)))
    arg("--seed", type=int, default=int(env.get("SEED", 0)))
    arg("--seq-len", type=int, default=int(env.get("SEQ_LEN", 12)))
    arg("--timesteps", type=int, default=int(env.get("TIMESTEPS", 64)))
    arg("--train-steps", type=int, default=int(env.get("TRAIN_STEPS", 400)))
    arg("--lr", type=float, default=float(env.get("LR", 3e-3)))
    return parser.parse_args()
sample_timesteps(batch_size, schedule, device)
⚓︎
Uniformly sample diffusion steps for a batch.
test(model, test_loader)
⚓︎
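Evaluate a classifier (summed NLL loss and accuracy) on a held-out loader. Note: this helper is not called by `main`, and it invokes `model(batch)`, which does not match the `DiffusionTextModel.forward(noisy_embs, t)` signature.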
Source code in src/ezpz/examples/diffusion.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def test(model, test_loader):
    """Evaluate the classifier outputs on a held-out loader.

    Expects ``test_loader`` to yield ``(inputs, targets)`` pairs and ``model``
    to return per-class log-probabilities (it uses ``F.nll_loss`` and an
    argmax over dim 1). This helper is not called by ``main``, and
    ``model(batch)`` does not match ``DiffusionTextModel.forward(noisy_embs, t)``.

    Returns:
        Dict with ``test_loss`` (mean NLL per example) and ``test_acc``
        (percent correct) as 0-d tensors. Totals are summed over ranks when a
        process group is initialized. With an empty loader both are NaN (0/0).
    """
    device = ezpz.get_torch_device()
    device_id = f"{device}:{ezpz.get_local_rank()}"
    model.eval()
    # totals = [sum of NLL, number correct, number of examples]
    totals = torch.zeros(3).to(device_id)
    with torch.no_grad():
        for batch, target in test_loader:
            batch, target = batch.to(device_id), target.to(device_id)
            output = model(batch)
            totals[0] += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            totals[1] += pred.eq(target.view_as(pred)).sum().item()
            totals[2] += len(batch)
    # all_reduce requires an initialized process group; single-process runs
    # (no torch.distributed init) would otherwise raise here.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)  # type:ignore
    test_loss = totals[0] / totals[2]
    return {
        "test_loss": test_loss,
        "test_acc": 100.0 * totals[1] / totals[2],
    }
train(model, loader, schedule, args, steps, outdir, lr=0.001)
⚓︎
Train the diffusion text model for a fixed number of steps.
Source code in src/ezpz/examples/diffusion.py
@ezpz.timeitlogit(rank=ezpz.get_rank())
def train(
    model: DiffusionTextModel,
    loader: DataLoader,
    schedule: DiffusionSchedule,
    args: argparse.Namespace,
    steps: int,
    outdir: Path | os.PathLike | str,
    lr: float = 1e-3,
) -> tuple[ezpz.History, torch.nn.Module]:
    """Train the diffusion text model for a fixed number of steps.

    Each step does the following:
      - draw a batch of token ids, restarting ``loader`` when it is exhausted;
      - embed the ids with the unwrapped model's ``embed_tokens``;
      - noise them at sampled timesteps via ``add_noise``;
      - regress the wrapped model's output onto the injected noise (MSE).

    Args:
        model: Unwrapped diffusion model. It is moved to the active device and
            wrapped with ``ezpz.dist.wrap_model`` (FSDP when ``args.fsdp``).
        loader: Yields token-id tensors, one row per sequence.
        schedule: Noise schedule for the forward process.
        args: Parsed CLI namespace; reads ``fsdp``, ``dtype`` and ``log_freq``.
        steps: Number of optimizer steps to run.
        outdir: Directory for the history report and per-rank JSONL metrics.
        lr: AdamW learning rate.

    Returns:
        The populated history and the wrapped model.
    """
    device = ezpz.get_torch_device(as_torch_device=True)
    model.to(device)
    model.train()
    wrapped_model = ezpz.dist.wrap_model(
        model, use_fsdp=args.fsdp, dtype=args.dtype
    )
    optim = torch.optim.AdamW(wrapped_model.parameters(), lr=lr)
    mstr = ezpz.models.summarize_model(
        wrapped_model,
        verbose=False,
        depth=2,
    )
    logger.info("Model summary:\n%s", mstr)
    metrics_path = Path(outdir).joinpath(f"metrics-{ezpz.get_rank()}.jsonl")
    history = ezpz.history.History(
        report_dir=outdir,
        report_enabled=True,
        jsonl_path=metrics_path,
        jsonl_overwrite=True,
        # Cross-rank history only for 2..384 ranks (presumably to bound the
        # gather cost - TODO confirm the reason for the cap).
        distributed_history=(1 < ezpz.get_world_size() <= 384),
    )
    assert isinstance(
        wrapped_model, (nn.Module, FSDP, DistributedDataParallel)
    ), "Model should be wrapped for training."
    # Wrappers keep the raw module (which defines embed_tokens) under .module.
    base_model = (
        wrapped_model.module
        if hasattr(wrapped_model, "module")
        else wrapped_model
    )
    assert callable(getattr(base_model, "embed_tokens", None)), (
        "Model should have embed_tokens method."
    )
    is_fsdp = isinstance(wrapped_model, FSDP)
    # A non-positive log_freq would make `step % log_freq` raise ZeroDivisionError.
    log_freq = max(1, args.log_freq)
    epoch = 0
    loader_iter = iter(loader)
    for step in range(steps):
        t0 = time.perf_counter()
        try:
            tokens = next(loader_iter)
        except StopIteration:
            # New pass over the data. DistributedSampler only reshuffles when
            # told the epoch changed; samplers without set_epoch are unaffected.
            epoch += 1
            set_epoch = getattr(loader.sampler, "set_epoch", None)
            if callable(set_epoch):
                set_epoch(epoch)
            loader_iter = iter(loader)
            tokens = next(loader_iter)
        tokens = tokens.to(device)
        # Synchronize BEFORE each timestamp so asynchronous device work is
        # counted in the phase that launched it (same number of syncs per step).
        ezpz.dist.synchronize()
        t1 = time.perf_counter()
        full_param_ctx = (
            FSDP.summon_full_params(wrapped_model)
            if is_fsdp
            else nullcontext()
        )
        # NOTE(review): embed_tokens runs on the unwrapped module, so full
        # parameters are gathered for it under FSDP. PyTorch's FSDP docs say
        # forward/backward should not be started inside summon_full_params;
        # confirm this path works with the wrapping ezpz applies.
        with full_param_ctx:
            x0 = base_model.embed_tokens(tokens)
            t = sample_timesteps(tokens.size(0), schedule, device=device)
            xt, noise = add_noise(x0, t, schedule)
            pred_noise = wrapped_model(xt, t)
            loss = torch.mean((pred_noise - noise) ** 2)
        ezpz.dist.synchronize()
        t2 = time.perf_counter()
        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)
        ezpz.dist.synchronize()
        t3 = time.perf_counter()
        if step % log_freq == 0 or step == steps - 1:
            logger.info(
                history.update(
                    {
                        "train/step": step,
                        "train/loss": loss.item(),
                        "train/dt": t3 - t0,
                        "train/dtd": t1 - t0,
                        "train/dtf": t2 - t1,
                        "train/dtb": t3 - t2,
                    }
                ).replace("train/", "")
            )
    return history, wrapped_model