From ed673529633659576e7e94ff3a0e90867463e602 Mon Sep 17 00:00:00 2001
From: Danielle Pintz <daniellepintz@fb.com>
Date: Mon, 30 Oct 2023 15:28:05 -0700
Subject: [PATCH] Add IProfiler protocol (#599)

Summary:

Add a protocol to define a Profiler.

Reviewed By: crassirostris

Differential Revision: D50765964
---
 torchtnt/utils/__init__.py |  2 ++
 torchtnt/utils/profiler.py | 40 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)
 create mode 100644 torchtnt/utils/profiler.py

diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index 83c71a9e0f..f234f6c2b4 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -53,6 +53,7 @@
     prepare_ddp,
     prepare_fsdp,
 )
+from .profiler import IProfiler
 from .progress import Progress
 from .rank_zero_log import (
     rank_zero_critical,
@@ -121,6 +122,7 @@
     "NOOPStrategy",
     "prepare_ddp",
     "prepare_fsdp",
+    "IProfiler",
     "Progress",
     "rank_zero_critical",
     "rank_zero_debug",
diff --git a/torchtnt/utils/profiler.py b/torchtnt/utils/profiler.py
new file mode 100644
index 0000000000..278f1dde46
--- /dev/null
+++ b/torchtnt/utils/profiler.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from types import TracebackType
+from typing import Optional, Protocol, Type
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class IProfiler(Protocol):
+    """Protocol for profilers. Can be used as a context manager."""
+
+    def __enter__(self) -> None:
+        """Enters the context manager and starts the profiler."""
+        pass
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
+        """Exits the context manager and stops the profiler."""
+        pass
+
+    def start(self) -> None:
+        """Starts the profiler."""
+        pass
+
+    def stop(self) -> None:
+        """Stops the profiler."""
+        pass
+
+    def step(self) -> None:
+        """Signals to the profiler that the next profiling step has started."""
+        pass