diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index 83c71a9e0f..f234f6c2b4 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -53,6 +53,7 @@ prepare_ddp, prepare_fsdp, ) +from .profiler import IProfiler from .progress import Progress from .rank_zero_log import ( rank_zero_critical, @@ -121,6 +122,7 @@ "NOOPStrategy", "prepare_ddp", "prepare_fsdp", + "IProfiler", "Progress", "rank_zero_critical", "rank_zero_debug", diff --git a/torchtnt/utils/profiler.py b/torchtnt/utils/profiler.py new file mode 100644 index 0000000000..278f1dde46 --- /dev/null +++ b/torchtnt/utils/profiler.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from types import TracebackType +from typing import Optional, Protocol, Type + +logger: logging.Logger = logging.getLogger(__name__) + + +class IProfiler(Protocol): + """Protocol for profilers. Can be used as a context manager.""" + + def __enter__(self) -> None: + """Enters the context manager and starts the profiler.""" + pass + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + """Exits the context manager and stops the profiler.""" + pass + + def start(self) -> None: + """Starts the profiler.""" + pass + + def stop(self) -> None: + """Stops the profiler.""" + pass + + def step(self) -> None: + """Signals to the profiler that the next profiling step has started.""" + pass