ezpz.models.attention⚓︎
Attention primitives for transformer models.
Provides :class:FlexAttention (torch.nn.attention.flex_attention-based)
and :class:ScaledDotProductAttention (F.scaled_dot_product_attention-based)
modules, plus helpers for constructing and initialising block masks.
FlexAttention
⚓︎
Bases: Module
FlexAttention module that uses torch.nn.attention.flex_attention.
This module is a wrapper around torch.nn.attention.flex_attention. This module implements certain common attention types, such as causal and block_causal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
attn_mask_type
|
str
|
The type of attention mask. Currently, we support "causal" and "block_causal". "causal" means each query attends only to keys at the same or earlier positions, i.e. only the lower triangle of the attention matrix (including the diagonal) is kept. "block_causal" means the attention matrix is divided into blocks, where block boundaries are defined by the EOS token, and the same causal (lower-triangular) mask is applied within each block. |
required |
fixed_block_size
|
int | None
|
The block size to be used to perform attention.
If specified, each sequence will be further divided into blocks, where each
block has the maximum size of fixed_block_size. A query will only attend to the keys within the same block. |
None
|
Source code in src/ezpz/models/attention.py
class FlexAttention(torch.nn.Module):
"""FlexAttention module that uses torch.nn.attention.flex_attention.
This module is a wrapper around torch.nn.attention.flex_attention. This module
implements certain common attention types, such as causal and block_causal.
Args:
attn_mask_type (str): The type of attention mask. Currently, we support
"causal" and "block_causal". "causal" means the lower triangle of the
attention matrix is masked. "block_causal" means the attention matrix
is divided into blocks, where block boundary is defined by EOS token,
and the lower triangle of each block is masked.
fixed_block_size (int | None): The block size to be used to perform attention.
If specified, each sequence will be further divided to blocks, where each
block has the maximum size of ``fixed_block_size``. A query will only attend
to the keys within the same block.
"""
# We registered flex_attention related attributes as class variables as we
# need to amortize the cost of compilation.
flex_attn: ClassVar[Callable] = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
# Attention mask type to the created BlockMask.
# This allows us to keep track the created block masks for each
# new batch. We will use this to update the block mask when a
# new batch is created. This also allows user to create different
# block masks for different layers.
block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {}
# Instance variables.
attn_mask_type: str
def __init__(
self, attn_mask_type: str, fixed_block_size: int | None = None
) -> None:
super().__init__()
if attn_mask_type not in ["causal", "block_causal"]:
raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
self.attn_mask_type = attn_mask_type
self.fixed_block_size = fixed_block_size
FlexAttention.used_attn_mask_types.add(self.mask_key)
@property
def mask_key(self) -> FLEX_ATTN_MASK_T:
return (self.attn_mask_type, self.fixed_block_size)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float | None = None,
) -> torch.Tensor:
block_mask = FlexAttention.block_masks[self.mask_key]
return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
@staticmethod
def _get_causal_mask_mod() -> _mask_mod_signature:
def causal_mask(
b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
):
return q_idx >= kv_idx
return causal_mask
@staticmethod
def _get_block_causal_mask_mod(
batch: torch.Tensor, eos_id: int
) -> _mask_mod_signature:
# batch is [b, s, h, d] shape
mask = batch == eos_id
mask[:, -1] = True
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]
def block_causal_mask(
b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
):
return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
return block_causal_mask
@staticmethod
def _fixed_block_mask_mod(
mask_mod: _mask_mod_signature, fixed_block_size: int
) -> _mask_mod_signature:
"""
Given an arbitrary mask_mod, divide the input sequence to blocks
and only allow attention within the same block.
Args:
mask_mod: The mask mod to apply to the documents
fixed_block_size: The number of tokens in each block.
"""
# Credit to @drisspg.
def blocked_mask_mod(
b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
):
# Get the block index of the query and key
q_block = q_idx // fixed_block_size
kv_block = kv_idx // fixed_block_size
# Only allow attention within the same block
same_block = q_block == kv_block
# Apply the original mask mod
inner_mask = mask_mod(
b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size
)
return same_block & inner_mask
blocked_mask_mod.__name__ = (
f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}"
)
return blocked_mask_mod
@staticmethod
@torch.no_grad()
def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
# batch is [b, s, h, d] shape
for mask_key in FlexAttention.used_attn_mask_types:
attn_mask_type, fixed_block_size = mask_key
match attn_mask_type:
case "causal":
if FlexAttention.block_masks.get(mask_key, None) is not None:
continue
# We don't care about batch dimension --
# all samples have the same lower triangle mask.
batch_dimension = 1
mask_mod = FlexAttention._get_causal_mask_mod()
case "block_causal":
if eos_id is None:
raise RuntimeError(
"eos_id must be provided for block_causal mask."
)
batch_dimension = batch.shape[0]
mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id)
case _:
raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
if fixed_block_size is not None and fixed_block_size > 0:
mask_mod = FlexAttention._fixed_block_mask_mod(
mask_mod, fixed_block_size
)
seq_len = batch.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
mask_mod, batch_dimension, None, seq_len, seq_len
)
FlexAttention.block_masks[mask_key] = block_mask
ScaledDotProductAttention
⚓︎
Bases: Module
Attention using F.scaled_dot_product_attention with automatic backend selection.
Currently only supports causal masking. On first instantiation the class selects the best available SDPA backend (Flash, Efficient, Math).
Source code in src/ezpz/models/attention.py
class ScaledDotProductAttention(torch.nn.Module):
"""Attention using ``F.scaled_dot_product_attention`` with automatic backend selection.
Currently only supports causal masking. On first instantiation the class
selects the best available SDPA backend (Flash, Efficient, Math).
"""
backends: ClassVar[list[SDPBackend]] = []
def __init__(self, attn_mask_type: str) -> None:
super().__init__()
if attn_mask_type != "causal":
raise ValueError(
"TorchTitan with SDPA currently only supports causal mask."
)
ScaledDotProductAttention._init_backend()
@classmethod
def _init_backend(cls) -> None:
if cls.backends:
return
# Add CuDNN on B200 w/ highest priority
cls.backends = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
# if has_cuda_capability(10, 0):
# cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float | None = None,
) -> torch.Tensor:
assert self.backends, "SDPA Backends should not be empty."
with sdpa_kernel(self.backends, set_priority=True):
return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
build_attention(use_flex_attn, attn_mask_type, fixed_block_size=None)
⚓︎
Factory that returns the appropriate attention module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
use_flex_attn
|
bool
|
If True, return a FlexAttention instance; otherwise return ScaledDotProductAttention. |
required |
attn_mask_type
|
str
|
Mask type ("causal" or "block_causal"). |
required |
fixed_block_size
|
int | None
|
Optional fixed block size for :class:FlexAttention. |
None
|
Source code in src/ezpz/models/attention.py
def build_attention(
    use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None
) -> FlexAttention | ScaledDotProductAttention:
    """Construct the attention module selected by ``use_flex_attn``.

    Args:
        use_flex_attn: When truthy, a :class:`FlexAttention` is built;
            otherwise a :class:`ScaledDotProductAttention` is built.
        attn_mask_type: Mask type (``"causal"`` or ``"block_causal"``). The
            SDPA path accepts ``"causal"`` only.
        fixed_block_size: Optional fixed block size, forwarded to
            :class:`FlexAttention`. Must be ``None`` on the SDPA path.

    Raises:
        ValueError: On the SDPA path, if ``fixed_block_size`` is given or
            ``attn_mask_type`` is not ``"causal"``.
    """
    if not use_flex_attn:
        # SDPA path: reject options it cannot honour before constructing.
        if fixed_block_size is not None:
            raise ValueError(
                "TorchTitan with SDPA currently does not support fixed_block_size."
            )
        if attn_mask_type != "causal":
            raise ValueError(
                "TorchTitan with SDPA currently only supports causal mask."
            )
        return ScaledDotProductAttention(attn_mask_type)

    return FlexAttention(attn_mask_type, fixed_block_size)
init_attention_mask(batch, eos_id, cp_mesh=None)
⚓︎
Initialise the :class:FlexAttention block masks for the current batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch
|
Tensor
|
Input token tensor; dimension 0 is the batch and dimension 1 is the sequence (token ids of shape (B, S) for block_causal). |
required |
eos_id
|
int | None
|
End-of-sequence token id (required for block_causal). |
required |
cp_mesh
|
DeviceMesh | None
|
Optional context-parallel device mesh. |
None
|
Source code in src/ezpz/models/attention.py
def init_attention_mask(
    batch: torch.Tensor,
    eos_id: int | None,
    cp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
) -> None:
    """Build the :class:`FlexAttention` block masks for the current batch.

    Args:
        batch: Input token tensor, forwarded unchanged to
            ``FlexAttention.init_attention_mask``.
        eos_id: End-of-sequence token id (required for ``block_causal``).
        cp_mesh: Optional context-parallel device mesh. When given, the
            class-level block-mask builder on :class:`FlexAttention` is
            replaced by ``create_cp_block_mask`` bound to this mesh; the
            replacement stays in place for subsequent calls.
    """
    use_context_parallel = cp_mesh is not None
    if use_context_parallel:
        # Not functional yet: Flex + CP is currently gated while accuracy
        # issues are being debugged. The hook is kept so the user experience
        # with CP enabled can still be evaluated.
        cp_block_mask_builder = functools.partial(
            create_cp_block_mask, device_mesh=cp_mesh
        )
        FlexAttention.compiled_create_block_mask = cp_block_mask_builder

    FlexAttention.init_attention_mask(batch, eos_id)