ezpz.profile⚓︎
- See ezpz/
profile.py
profile.py
Sam Foreman [2024-06-21]
Contains implementation of:
get_context_manager and PyInstrumentProfiler
which can be used as a context manager to profile a block of code, e.g.
# test.py
def main():
print("Hello!")
from ezpz.profile import get_context_manager
# NOTE:
# 1. if `rank` is passed to `get_context_manager`:
# - it will ONLY be instantiated if rank == 0,
# otherwise, it will return a contextlib.nullcontext() instance.
# 2. if `strict=True`:
# - only run if "PYINSTRUMENT_PROFILER=1" in environment
cm = get_context_manager(rank=RANK, strict=False)
with cm:
main()
if __name__ == "__main__":
main()
get_context_manager(rank=None, outdir=None, strict=True, *, profiler_type='pyinstrument', rank_zero_only=True, **profile_kwargs)
⚓︎
Returns a context manager for profiling code blocks using PyInstrument.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rank
|
Optional[int]
|
The rank of the process (default: None). If provided, the profiler will only run if rank is 0. |
None
|
outdir
|
Optional[str]
|
The output directory for saving profiles. Defaults to `ezpz.OUTPUTS_DIR`. |
None
|
strict
|
Optional[bool]
|
If True, the profiler will only run if "PYINSTRUMENT_PROFILER" is set in the environment. Defaults to True. |
True
|
Returns:
| Name | Type | Description |
|---|---|---|
AbstractContextManager |
AbstractContextManager
|
A context manager that starts and stops the PyInstrument profiler. |
Source code in src/ezpz/profile.py
def get_context_manager(
    rank: Optional[int] = None,
    outdir: Optional[str] = None,
    strict: Optional[bool] = True,
    *,
    profiler_type: str = "pyinstrument",
    rank_zero_only: bool = True,
    **profile_kwargs: Any,
) -> AbstractContextManager:
    """
    Return a context manager that profiles the enclosed block of code.

    By default a PyInstrument-based profiler is used; any other
    ``profiler_type`` is delegated to :func:`get_profiling_context`.

    Args:
        rank (Optional[int]): Rank of the calling process. When
            ``rank_zero_only`` is True, known non-zero ranks receive a
            no-op ``nullcontext()``.
        outdir (Optional[str]): Output directory for saved profiles.
            Defaults to ``ezpz.OUTPUTS_DIR``.
        strict (Optional[bool]): When True, the profiler only runs if the
            ``PYINSTRUMENT_PROFILER`` environment variable is set.
            Defaults to True.
        profiler_type (str): Profiler backend to use (keyword-only).
            Defaults to ``"pyinstrument"``.
        rank_zero_only (bool): Restrict profiling to rank 0 (keyword-only).
            Defaults to True.
        **profile_kwargs: Extra options forwarded to
            :func:`get_profiling_context` for non-pyinstrument backends.

    Returns:
        AbstractContextManager: A profiler context manager, or a
        ``nullcontext()`` when profiling is disabled for this rank / env.
    """
    # Non-pyinstrument backends (e.g. torch) are handled elsewhere.
    if profiler_type != "pyinstrument":
        opt = profile_kwargs.get
        return get_profiling_context(
            profiler_type=profiler_type,
            wait=opt("wait", 0),
            warmup=opt("warmup", 0),
            active=opt("active", 1),
            repeat=opt("repeat", 1),
            rank_zero_only=rank_zero_only,
            record_shapes=opt("record_shapes", True),
            with_stack=opt("with_stack", True),
            with_flops=opt("with_flops", True),
            with_modules=opt("with_modules", True),
            acc_events=opt("acc_events", False),
            profile_memory=opt("profile_memory", False),
            outdir=outdir,
            strict=strict,
        )
    # A known non-zero rank is a no-op when restricted to rank 0.
    # NOTE(review): rank=None is treated as "run" here, unlike
    # get_torch_profiler where None is also skipped — confirm intended.
    if rank_zero_only and rank not in (None, 0):
        return nullcontext()
    # Opt-in gate: in strict mode, require PYINSTRUMENT_PROFILER in the env.
    if strict and os.environ.get("PYINSTRUMENT_PROFILER") is None:
        return nullcontext()
    base = outdir if outdir is not None else ezpz.OUTPUTS_DIR
    profile_dir = Path(base).joinpath("ezpz", "pyinstrument_profiles")
    return PyInstrumentProfiler(
        rank=rank,
        outdir=profile_dir.as_posix(),
        rank_zero_only=rank_zero_only,
    )
get_profiling_context(profiler_type, wait, warmup, active, repeat, rank_zero_only, record_shapes=True, with_stack=True, with_flops=True, with_modules=True, acc_events=False, profile_memory=False, outdir=None, strict=True)
⚓︎
Returns a context manager for profiling code blocks using either PyTorch Profiler or PyInstrument.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
profiler_type
|
str
|
The type of profiler to use. Must be one of ['torch', 'pyinstrument']. |
required |
wait
|
int
|
The number of steps to wait before starting profiling. |
required |
warmup
|
int
|
The number of warmup steps before profiling starts. |
required |
active
|
int
|
The number of active profiling steps. |
required |
repeat
|
int
|
The number of times to repeat the profiling schedule. |
required |
rank_zero_only
|
bool
|
If True, the profiler will only run on rank 0. Defaults to True. |
required |
record_shapes
|
bool
|
If True, shapes of tensors are recorded. Defaults to True. |
True
|
with_stack
|
bool
|
If True, stack traces are recorded. Defaults to True. |
True
|
with_flops
|
bool
|
If True, FLOPs are recorded. Defaults to True. |
True
|
with_modules
|
bool
|
If True, module information is recorded. Defaults to True. |
True
|
acc_events
|
bool
|
If True, accumulated events are recorded. Defaults to False. |
False
|
profile_memory
|
bool
|
If True, memory profiling is enabled. Defaults to False. |
False
|
outdir
|
Optional[str | Path | PathLike]
|
The output directory for saving profiles. Defaults to an `ezpz/torch_profiles` directory under the current working directory. |
None
|
strict
|
Optional[bool]
|
If True, the profiler will only run if "PYINSTRUMENT_PROFILER" is set in the environment. Defaults to True. |
True
|
Returns: AbstractContextManager: A context manager that starts and stops the profiler.
Source code in src/ezpz/profile.py
def get_profiling_context(
    profiler_type: str,
    wait: int,
    warmup: int,
    active: int,
    repeat: int,
    rank_zero_only: bool,
    record_shapes: bool = True,
    with_stack: bool = True,
    with_flops: bool = True,
    with_modules: bool = True,
    acc_events: bool = False,
    profile_memory: bool = False,
    outdir: Optional[str | Path | os.PathLike] = None,
    strict: Optional[bool] = True,
) -> AbstractContextManager:
    """
    Returns a context manager for profiling code blocks using either
    PyTorch Profiler or PyInstrument.

    Args:
        profiler_type (str): The type of profiler to use.
            Must be one of ['pt', 'pytorch', 'torch', 'pyinstrument'].
        wait (int): The number of steps to wait before starting profiling.
        warmup (int): The number of warmup steps before profiling starts.
        active (int): The number of active profiling steps.
        repeat (int): The number of times to repeat the profiling schedule.
        rank_zero_only (bool): If True, the profiler will only run on rank 0.
        record_shapes (bool): If True, shapes of tensors are recorded.
            Defaults to True.
        with_stack (bool): If True, stack traces are recorded.
            Defaults to True.
        with_flops (bool): If True, FLOPs are recorded.
            Defaults to True.
        with_modules (bool): If True, module information is recorded.
            Defaults to True.
        acc_events (bool): If True, accumulated events are recorded.
            Defaults to False.
        profile_memory (bool): If True, memory profiling is enabled.
            Defaults to False.
        outdir (Optional[str | Path | os.PathLike]): The output directory
            for saving profiles. For the torch backend this defaults to
            ``<cwd>/ezpz/torch_profiles`` (NOT ``ezpz.OUTPUTS_DIR``);
            for pyinstrument it is forwarded to ``get_context_manager``.
        strict (Optional[bool]): If True, the pyinstrument profiler will
            only run if "PYINSTRUMENT_PROFILER" is set in the environment.
            Defaults to True.

    Returns:
        AbstractContextManager: A context manager that starts and stops
        the profiler.

    Raises:
        ValueError: If ``profiler_type`` is not a recognized profiler name.
    """
    if profiler_type not in {"pt", "pytorch", "torch", "pyinstrument"}:
        raise ValueError(
            f"Invalid profiling type: {profiler_type}. "
            "Must be one of ['torch', 'pyinstrument']"
        )
    if profiler_type == "pyinstrument":
        # BUGFIX: forward `outdir` so a caller-specified directory is not
        # silently dropped for the pyinstrument backend.
        return get_context_manager(
            rank=ezpz.get_rank(),
            outdir=None if outdir is None else str(outdir),
            strict=strict,
        )
    # --- torch / pytorch / pt branch ---
    # Only this branch writes traces below, so the fallback directory is
    # resolved (and created) here rather than unconditionally.
    trace_dir = (
        Path(os.getcwd()).joinpath("ezpz", "torch_profiles")
        if outdir is None
        else Path(outdir)
    )
    trace_dir.mkdir(parents=True, exist_ok=True)

    def trace_handler(p: torch.profiler.profile):
        """Log a summary table and export a chrome trace when ready."""
        logger.info(
            "\n"
            + p.key_averages().table(
                sort_by=(f"self_{ezpz.get_torch_device_type()}_time_total"),
                row_limit=-1,
            )
        )
        # Unique, sortable filename: backend-rank-step-timestamp.
        fname: str = "-".join(
            [
                "torch-profiler",
                f"rank{ezpz.get_rank()}",
                f"step{p.step_num}",
                f"{ezpz.get_timestamp()}",
            ]
        )
        trace_output = trace_dir.joinpath(f"{fname}.json")
        logger.info(f"Saving torch profiler trace to: {trace_output.as_posix()}")
        p.export_chrome_trace(trace_output.as_posix())

    schedule = torch.profiler.schedule(
        wait=wait,
        warmup=warmup,
        active=active,
        repeat=repeat,
    )
    return get_torch_profiler(
        rank=ezpz.get_rank(),
        schedule=schedule,
        on_trace_ready=trace_handler,
        rank_zero_only=rank_zero_only,
        profile_memory=profile_memory,
        record_shapes=record_shapes,
        with_stack=with_stack,
        with_flops=with_flops,
        with_modules=with_modules,
        acc_events=acc_events,
    )
get_torch_profiler(rank=None, schedule=None, on_trace_ready=None, rank_zero_only=True, profile_memory=False, record_shapes=True, with_stack=True, with_flops=True, with_modules=True, acc_events=False)
⚓︎
A thin wrapper around torch.profiler.profile that:
- Supports automatic device detection {CPU, CUDA, XPU}
- Runs on rank 0 only (by default)
- To run from all ranks, set
rank_zero_only=False
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
rank
|
Optional[int]
|
The rank of the process (default: None). If provided, the profiler will only run if rank is 0. |
None
|
schedule
|
Optional[Callable[[int], ProfilerAction]]
|
A callable that returns a `ProfilerAction` for the profiler schedule. |
None
|
on_trace_ready
|
Optional[Callable]
|
A callback function that is called when the trace is ready. |
None
|
rank_zero_only
|
bool
|
If True, the profiler will only run on rank 0. Defaults to True. |
True
|
profile_memory
|
bool
|
If True, memory profiling is enabled. Defaults to False. |
False
|
record_shapes
|
bool
|
If True, shapes of tensors are recorded. Defaults to True. |
True
|
with_stack
|
bool
|
If True, stack traces are recorded. Defaults to True. |
True
|
with_flops
|
bool
|
If True, FLOPs are recorded. Defaults to True. |
True
|
with_modules
|
bool
|
If True, module information is recorded. Defaults to True. |
True
|
acc_events
|
bool
|
If True, accumulated events are recorded. Defaults to False. |
False
|
Returns: torch.profiler.profile: A profiler context manager that can be used to profile code blocks.
Source code in src/ezpz/profile.py
def get_torch_profiler(
    rank: Optional[int] = None,
    schedule: Optional[Callable[[int], ProfilerAction]] = None,
    on_trace_ready: Optional[Callable] = None,
    rank_zero_only: bool = True,
    profile_memory: bool = False,
    record_shapes: bool = True,
    with_stack: bool = True,
    with_flops: bool = True,
    with_modules: bool = True,
    acc_events: bool = False,
):
    """
    Thin wrapper around `torch.profiler.profile`.

    Automatically selects the available profiler activities (CPU always,
    plus CUDA and/or XPU when present) and, by default, only profiles on
    rank 0. To profile from _all_ ranks, set `rank_zero_only=False`.

    Args:
        rank (Optional[int]): The rank of the process (default: None).
            With `rank_zero_only=True`, any value other than 0 —
            including None — yields a no-op context.
        schedule (Optional[Callable[[int], ProfilerAction]]): A callable
            that returns a `ProfilerAction` for the profiler schedule.
        on_trace_ready (Optional[Callable]): Callback invoked when a
            trace becomes available.
        rank_zero_only (bool): If True, the profiler will only run on
            rank 0. Defaults to True.
        profile_memory (bool): If True, memory profiling is enabled.
            Defaults to False.
        record_shapes (bool): If True, shapes of tensors are recorded.
            Defaults to True.
        with_stack (bool): If True, stack traces are recorded.
            Defaults to True.
        with_flops (bool): If True, FLOPs are recorded.
            Defaults to True.
        with_modules (bool): If True, module information is recorded.
            Defaults to True.
        acc_events (bool): Accepted for API compatibility but currently
            NOT forwarded to `torch.profiler.profile`.

    Returns:
        torch.profiler.profile | contextlib.nullcontext: A profiler
        context manager, or a no-op context for non-zero ranks.
    """
    # `rank is None or rank != 0` collapses to `rank != 0` (None != 0 is
    # True), so only an explicit rank of 0 actually profiles.
    if rank_zero_only and rank != 0:
        return nullcontext()
    devices = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        devices.append(ProfilerActivity.CUDA)
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        devices.append(ProfilerActivity.XPU)
    # NOTE(review): `acc_events` is deliberately not wired through here
    # (it was commented out in the original call) — confirm torch-version
    # support before enabling it.
    return profile(
        activities=devices,
        schedule=schedule,
        on_trace_ready=on_trace_ready,
        record_shapes=record_shapes,
        profile_memory=profile_memory,
        with_stack=with_stack,
        with_flops=with_flops,
        with_modules=with_modules,
    )