ezpz.tp⚓︎
- See ezpz/tp
ezpz/tp/__init__.py
Modified from: https://github.com/facebookresearch/fairscale/blob/5f484b3545f27eddb19d970fbe1d361b9c5f2b07/fairscale/nn/tensor_parallel/initialize.py
destroy_tensor_parallel()
⚓︎
Set all parallel groups and rank lists to None.
Source code in src/ezpz/tp/__init__.py
def destroy_tensor_parallel() -> None:
"""Set the groups to none."""
global _TENSOR_PARALLEL_GROUP
_TENSOR_PARALLEL_GROUP = None
global _TENSOR_PARALLEL_RANKS
_TENSOR_PARALLEL_RANKS = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _DATA_PARALLEL_RANKS
_DATA_PARALLEL_RANKS = None
global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None
global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None
global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None
global _CONTEXT_PARALLEL_GROUP_RANKS
_CONTEXT_PARALLEL_GROUP_RANKS = None
divide_and_check_no_remainder(numerator, denominator)
⚓︎
Divide the numerator by the denominator and check that there is no remainder.
ensure_divisibility(numerator, denominator)
⚓︎
Ensure that the numerator is divisible by the denominator.
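Both helpers are small integer guards; a minimal sketch of plausible implementations (the actual ezpz code may differ in wording):

```python
def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Raise if numerator is not evenly divisible by denominator."""
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}"
    )


def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
    """Integer-divide numerator by denominator after checking the remainder is zero."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
```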
get_context_parallel_group()
⚓︎
Get the context parallel group the caller rank belongs to.
get_context_parallel_rank()
⚓︎
Return the caller's rank within the context parallel group.
get_context_parallel_ranks()
⚓︎
Return context parallel ranks for the context parallel group.
get_context_parallel_world_size()
⚓︎
Return the world size of the context parallel group.
get_data_parallel_group()
⚓︎
Get the data parallel group the caller rank belongs to.
get_data_parallel_rank()
⚓︎
Return the caller's rank within the data parallel group.
get_data_parallel_ranks()
⚓︎
Get the ranks of the data parallel group the caller rank belongs to.
get_data_parallel_world_size()
⚓︎
Return the world size of the data parallel group.
get_pipeline_parallel_group()
⚓︎
Get the pipeline parallel group the caller rank belongs to.
get_pipeline_parallel_rank()
⚓︎
Return the caller's rank within the pipeline parallel group.
get_pipeline_parallel_ranks()
⚓︎
Get the ranks of the pipeline parallel group the caller rank belongs to.
get_pipeline_parallel_world_size()
⚓︎
Return the world size of the pipeline parallel group.
get_tensor_parallel_group()
⚓︎
Get the tensor parallel group the caller rank belongs to.
get_tensor_parallel_rank()
⚓︎
Return the caller's rank within the tensor parallel group.
get_tensor_parallel_ranks()
⚓︎
Get the ranks of the tensor parallel group the caller rank belongs to.
get_tensor_parallel_src_rank()
⚓︎
Calculate the global rank corresponding to local rank 0 in the TP group.
Source code in src/ezpz/tp/__init__.py
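The source listing is elided above. As a rough sketch of how such a helper is commonly implemented (this mirrors fairscale's `get_model_parallel_src_rank`; the exact ezpz code may differ):

```python
import torch.distributed as tdist

def tensor_parallel_src_rank_sketch() -> int:
    """Illustrative only: global rank of local rank 0 in the caller's TP group.

    Assumes TP groups are contiguous blocks of global ranks, as in the
    examples below ([g0, g1], [g2, g3], ...).
    """
    global_rank = tdist.get_rank()
    tp_world_size = get_tensor_parallel_world_size()
    # Round down to the first rank of the caller's contiguous TP block.
    return (global_rank // tp_world_size) * tp_world_size
```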
get_tensor_parallel_world_size()
⚓︎
Return the world size of the tensor parallel group.
initialize_tensor_parallel(tensor_parallel_size=1, pipeline_parallel_size=1, context_parallel_size=1, tensor_parallel_backend=None, pipeline_parallel_backend=None, context_parallel_backend=None, data_parallel_backend=None, timeout=None)
⚓︎
Initialize tensor, pipeline, context, and data parallel groups.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor_parallel_size` | `int` | Number of GPUs used to parallelize the model. | `1` |
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model. The present function will create 4 tensor parallel groups and 2 data parallel groups:
- 4 tensor parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7]
- 2 data parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example, if we are using 2 DGX-1 boxes with a total of 16 GPUs, ranks 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box.
Process groups are initialized in the order TP, CP, PP, DP.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs to parallelize the tensor computation, 2 GPUs to parallelize the context (sequence length), and 2 GPUs to parallelize the pipeline. The present function will create 8 tensor model-parallel groups, 8 context-parallel groups, 8 pipeline model-parallel groups, and 8 data-parallel groups (a standalone sketch reproducing this layout follows the list):
- 8 data-parallel groups: [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
- 8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
- 8 context-parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
- 8 pipeline model-parallel groups: [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
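The layouts above follow directly from reshaping the flat list of ranks into a (DP, PP, CP, TP) grid and slicing along one axis at a time, exactly as the source below does. A standalone sketch that reproduces the 16-GPU layout:

```python
import torch

world_size, pp, cp, tp = 16, 2, 2, 2
dp = world_size // (tp * cp * pp)  # 2
grid = torch.arange(world_size).reshape(dp, pp, cp, tp)

# Tensor parallel groups vary the last axis (stride 1): [0, 1], [2, 3], ...
tp_groups = [grid[i, j, k, :].tolist()
             for i in range(dp) for j in range(pp) for k in range(cp)]
# Context parallel groups vary the cp axis (stride 2): [0, 2], [1, 3], ...
cp_groups = [grid[i, j, :, k].tolist()
             for i in range(dp) for j in range(pp) for k in range(tp)]
# Data parallel groups vary the first axis (stride 8): [0, 8], [1, 9], ...
dp_groups = [grid[:, i, j, k].tolist()
             for i in range(pp) for j in range(cp) for k in range(tp)]
print(tp_groups, cp_groups, dp_groups, sep="\n")
```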
Source code in src/ezpz/tp/__init__.py
def initialize_tensor_parallel(
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
context_parallel_size: int = 1,
tensor_parallel_backend: Optional[str] = None,
pipeline_parallel_backend: Optional[str] = None,
context_parallel_backend: Optional[str] = None,
data_parallel_backend: Optional[str] = None,
timeout: Optional[timedelta] = None,
) -> None:
"""
Initialize tensor, pipeline, context, and data parallel groups.
Arguments:
tensor_parallel_size: number of GPUs used to parallelize model.
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 tensor parallel groups and 2 data parallel groups as:
- 4 tensor parallel groups:
```
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
```
- 2 data parallel groups:
```
[g0, g2, g4, g6], [g1, g3, g5, g7]
```
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example, if we are using 2 DGX-1 boxes
with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
Process groups are initialized in the order TP, CP, PP, DP.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the tensor computation, 2 GPUs to
parallelize the context (sequence length), and 2 GPUs to
parallelize the pipeline. The present function will create
8 tensor model-parallel groups, 8 context-parallel groups,
8 pipeline model-parallel groups, and 8 data-parallel groups:
- 8 data-parallel groups:
[g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
- 8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
- 8 context-parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
- 8 pipeline model-parallel groups:
[g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
"""
# Get world size and rank; check the divisibility constraints.
assert tdist.is_initialized()
world_size = tdist.get_world_size()
tensor_parallel_size = int(min(tensor_parallel_size, world_size))
ensure_divisibility(world_size, tensor_parallel_size)
ensure_divisibility(world_size, context_parallel_size)
ensure_divisibility(
world_size,
tensor_parallel_size * pipeline_parallel_size * context_parallel_size,
)
rank = tdist.get_rank()
dpsize = int(
world_size
/ (tensor_parallel_size * pipeline_parallel_size * context_parallel_size)
)
if rank == 0:
pstr = ", ".join(
[
f"TP: {tensor_parallel_size}",
f"PP: {pipeline_parallel_size}",
f"CP: {context_parallel_size}",
f"DP: {dpsize}",
]
)
logger.info(pstr)
groups = torch.LongTensor(range(world_size)).reshape(
dpsize,
pipeline_parallel_size,
context_parallel_size,
tensor_parallel_size,
)
found = torch.where(groups == rank)
assert all(len(x) == 1 for x in found)
found = [x[0] for x in found]
# Build the data parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_RANKS
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
assert _DATA_PARALLEL_RANKS is None, "data parallel ranks are already initialized"
for i in range(pipeline_parallel_size):
for j in range(context_parallel_size):
for k in range(tensor_parallel_size):
ranks = groups[:, i, j, k].tolist()
group = tdist.new_group(
ranks,
backend=data_parallel_backend,
timeout=timeout,
)
if i == found[1] and j == found[2] and k == found[3]:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_RANKS = ranks
# Build the tensor parallel groups.
global _TENSOR_PARALLEL_GROUP
global _TENSOR_PARALLEL_RANKS
assert (
_TENSOR_PARALLEL_GROUP is None
), "tensor parallel group is already initialized"
assert (
_TENSOR_PARALLEL_RANKS is None
), "tensor parallel ranks are already initialized"
for i in range(dpsize):
for j in range(pipeline_parallel_size):
for k in range(context_parallel_size):
ranks = groups[i, j, k, :].tolist()
group = tdist.new_group(
ranks,
backend=tensor_parallel_backend,
timeout=timeout,
)
if i == found[0] and j == found[1] and k == found[2]:
_TENSOR_PARALLEL_GROUP = group
_TENSOR_PARALLEL_RANKS = ranks
# Build the pipeline parallel groups.
global _PIPELINE_PARALLEL_GROUP
global _PIPELINE_PARALLEL_RANKS
assert (
_PIPELINE_PARALLEL_GROUP is None
), "Pipeline parallel group is already initialized"
for i in range(dpsize):
for j in range(context_parallel_size):
for k in range(tensor_parallel_size):
ranks = groups[i, :, j, k].tolist()
group = tdist.new_group(
ranks, backend=pipeline_parallel_backend, timeout=timeout
)
if i == found[0] and j == found[2] and k == found[3]:
_PIPELINE_PARALLEL_GROUP = group
_PIPELINE_PARALLEL_RANKS = ranks
# Build the context parallel groups.
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_GROUP_RANKS
assert (
_CONTEXT_PARALLEL_GROUP is None
), "Context parallelism is already initialized."
for i in range(dpsize):
for j in range(pipeline_parallel_size):
for k in range(tensor_parallel_size):
ranks = groups[i, j, :, k].tolist()
group = tdist.new_group(
ranks, backend=context_parallel_backend, timeout=timeout
)
if i == found[0] and j == found[1] and k == found[3]:
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_GROUP_RANKS = ranks
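A minimal usage sketch (assuming `torch.distributed` has already been initialized, e.g. under `torchrun` with 16 ranks, matching the layout above):

```python
import torch.distributed as tdist
from ezpz.tp import (
    initialize_tensor_parallel,
    tensor_parallel_is_initialized,
    get_tensor_parallel_rank,
    get_data_parallel_world_size,
)

tdist.init_process_group(backend="nccl")
initialize_tensor_parallel(
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    context_parallel_size=2,
)
assert tensor_parallel_is_initialized()
print(get_tensor_parallel_rank())      # 0 or 1 within this rank's TP group
print(get_data_parallel_world_size())  # 16 // (2 * 2 * 2) == 2
```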
split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False)
⚓︎
Split a tensor along its last dimension.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | The tensor to split. | *required* |
| `num_partitions` | `int` | The number of partitions to split the tensor into. | *required* |
| `contiguous_split_chunks` | `bool` | Whether to return contiguous split chunks. | `False` |
Source code in src/ezpz/tp/utils.py
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> tuple[torch.Tensor, ...]:
"""Split a tensor along its last dimension.
Arguments:
tensor: The tensor to split.
num_partitions: The number of partitions to split the tensor into.
contiguous_split_chunks: Whether to return contiguous split chunks.
"""
last_dim = tensor.dim() - 1
last_dim_size = divide_and_check_no_remainder(
tensor.size()[last_dim], num_partitions
)
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
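For example, splitting a (4, 6) tensor into 3 partitions yields three (4, 2) views:

```python
import torch
from ezpz.tp.utils import split_tensor_along_last_dim

x = torch.arange(24.0).reshape(4, 6)
chunks = split_tensor_along_last_dim(x, num_partitions=3)
print([c.shape for c in chunks])  # three chunks of torch.Size([4, 2])
# torch.split returns views into x; pass contiguous_split_chunks=True
# when downstream ops need contiguous memory.
```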
tensor_parallel_is_initialized()
⚓︎
Check if tensor and data parallel groups are initialized.
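The source listing is omitted here; a plausible implementation consistent with the module globals above (a sketch, not necessarily the exact ezpz code):

```python
def tensor_parallel_is_initialized() -> bool:
    """True once initialize_tensor_parallel() has built the TP and DP groups."""
    return (
        _TENSOR_PARALLEL_GROUP is not None
        and _DATA_PARALLEL_GROUP is not None
    )
```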