From ed673529633659576e7e94ff3a0e90867463e602 Mon Sep 17 00:00:00 2001 From: Danielle Pintz <daniellepintz@fb.com> Date: Mon, 30 Oct 2023 15:28:05 -0700 Subject: [PATCH] Add IProfiler protocol (#599) Summary: Add a protocol to define a Profiler. Reviewed By: crassirostris Differential Revision: D50765964 --- torchtnt/utils/__init__.py | 2 ++ torchtnt/utils/profiler.py | 40 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 torchtnt/utils/profiler.py diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index 83c71a9e0f..f234f6c2b4 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -53,6 +53,7 @@ prepare_ddp, prepare_fsdp, ) +from .profiler import IProfiler from .progress import Progress from .rank_zero_log import ( rank_zero_critical, @@ -121,6 +122,7 @@ "NOOPStrategy", "prepare_ddp", "prepare_fsdp", + "IProfiler", "Progress", "rank_zero_critical", "rank_zero_debug", diff --git a/torchtnt/utils/profiler.py b/torchtnt/utils/profiler.py new file mode 100644 index 0000000000..278f1dde46 --- /dev/null +++ b/torchtnt/utils/profiler.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from types import TracebackType +from typing import Optional, Protocol, Type + +logger: logging.Logger = logging.getLogger(__name__) + + +class IProfiler(Protocol): + """Protocol for profilers. Can be used as a context manager.""" + + def __enter__(self) -> None: + """Enters the context manager and starts the profiler.""" + pass + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: + """Exits the context manager and stops the profiler.""" + pass + + def start(self) -> None: + """Starts the profiler.""" + pass + + def stop(self) -> None: + """Stops the profiler.""" + pass + + def step(self) -> None: + """Signals to the profiler that the next profiling step has started.""" + pass