# ezpz.tp.utils

See `src/ezpz/tp/utils.py`.

## VocabUtility
Split the vocabulary into world_size chunks and return the first and last
index of the vocabulary belonging to the rank partition.
Note that the returned indices form the half-open range [first, last).
Source code in src/ezpz/tp/utils.py
class VocabUtility:
    """Helpers for partitioning a vocabulary across ``world_size`` ranks.

    Each helper returns the half-open index range ``[first, last)`` of the
    vocabulary slice owned by ``rank``.
    """

    @staticmethod
    def get_vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank: int, _: int
    ) -> tuple[int, int]:
        """Return the index range for ``rank`` given a fixed per-rank chunk size.

        The trailing argument (the world size) is accepted for signature
        symmetry but unused.
        """
        start = rank * per_partition_vocab_size
        return start, start + per_partition_vocab_size

    @staticmethod
    def vocab_range_from_global_vocab_size(
        global_vocab_size: int, rank: int, world_size: int
    ) -> tuple[int, int]:
        """Return the index range for ``rank`` given the global vocabulary size.

        The vocabulary must divide evenly across ranks; the division is
        validated by ``divide_and_check_no_remainder``.
        """
        chunk = divide_and_check_no_remainder(global_vocab_size, world_size)
        return VocabUtility.get_vocab_range_from_per_partition_vocab_size(
            chunk, rank, world_size
        )
divide_and_check_no_remainder(numerator, denominator)
⚓︎
Divide the numerator by the denominator and check that there is no remainder.
ensure_divisibility(numerator, denominator)
⚓︎
Ensure that the numerator is divisible by the denominator.
split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False)
⚓︎
Split a tensor along its last dimension.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The tensor to split. | *required* |
| `num_partitions` | `int` | The number of partitions to split the tensor into. | *required* |
| `contiguous_split_chunks` | `bool` | Whether to return contiguous split chunks. | `False` |
Source code in src/ezpz/tp/utils.py
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> tuple[torch.Tensor, ...]:
    """Split ``tensor`` into ``num_partitions`` equal chunks along its last dim.

    Args:
        tensor: The tensor to split.
        num_partitions: The number of partitions to split the tensor into.
            The last dimension must divide evenly (checked by
            ``divide_and_check_no_remainder``).
        contiguous_split_chunks: Whether to return contiguous split chunks.

    Returns:
        A tuple of ``num_partitions`` tensors. Unless made contiguous, each
        is a view into ``tensor``.
    """
    dim = tensor.dim() - 1
    chunk_size = divide_and_check_no_remainder(tensor.size()[dim], num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=dim)
    if not contiguous_split_chunks:
        return chunks
    return tuple(piece.contiguous() for piece in chunks)