diff --git a/torchtnt/utils/memory_snapshot_profiler.py b/torchtnt/utils/memory_snapshot_profiler.py
index 58f2b69b99..07c31f0c77 100644
--- a/torchtnt/utils/memory_snapshot_profiler.py
+++ b/torchtnt/utils/memory_snapshot_profiler.py
@@ -6,8 +6,7 @@
 
 import logging
 from dataclasses import dataclass
-from types import TracebackType
-from typing import Optional, Type
+from typing import Optional
 
 import torch
 from torchtnt.utils.oom import attach_oom_observer, log_memory_snapshot
@@ -115,17 +114,6 @@ def __init__(
             f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}."
         )
 
-    def __enter__(self) -> None:
-        self.start()
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_value: Optional[BaseException],
-        tb: Optional[TracebackType],
-    ) -> Optional[bool]:
-        self.stop()
-
     def start(self) -> None:
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
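
For context, this diff drops the context-manager protocol (`__enter__`/`__exit__`) from `MemorySnapshotProfiler`, so the profiler can no longer be used in a `with` block; callers would instead drive it through the explicit `start()`/`stop()` methods that the removed dunder methods delegated to. Below is a minimal usage sketch under that assumption. The constructor arguments (`output_dir`, `memory_snapshot_params`) and the placeholder `run_training_step()` are illustrative only; the exact constructor signature is not shown in this hunk.

```python
from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)


def run_training_step() -> None:
    # Placeholder for the workload whose CUDA memory usage is being profiled.
    pass


# Hypothetical constructor arguments; adjust to the actual signature.
profiler = MemorySnapshotProfiler(
    output_dir="/tmp/memory_snapshots",
    memory_snapshot_params=MemorySnapshotParams(),
)

# With __enter__/__exit__ removed, start/stop are called explicitly
# instead of entering the profiler as a context manager.
profiler.start()
run_training_step()
profiler.stop()
```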