Remove __enter__ and __exit__ from MemorySnapshotProfiler #611

Open · wants to merge 2 commits into master
59 changes: 57 additions & 2 deletions tests/utils/test_memory_snapshot_profiler.py
@@ -35,9 +35,8 @@ def test_stop_step(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             memory_snapshot_profiler = MemorySnapshotProfiler(
                 output_dir=temp_dir,
-                memory_snapshot_params=MemorySnapshotParams(stop_step=2),
+                memory_snapshot_params=MemorySnapshotParams(start_step=0, stop_step=2),
             )
-            memory_snapshot_profiler.start()
 
             # initialize device & allocate memory for tensors
             device = get_device_from_env()
@@ -64,3 +63,59 @@ def test_stop_step(self) -> None:
             self.assertTrue(os.path.exists(pickle_dump_path))
             self.assertTrue(os.path.exists(trace_path))
             self.assertTrue(os.path.exists(segment_plot_path))
+
+    @unittest.skipUnless(
+        condition=torch_version_geq_2_0,
+        reason="This test needs changes from PyTorch 2.0 to run.",
+    )
+    def test_validation(self) -> None:
+        """Test parameter validation."""
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with self.assertRaisesRegex(ValueError, "start_step must be nonnegative."):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=-1, stop_step=0
+                    ),
+                )
+            with self.assertRaisesRegex(
+                ValueError, "stop_step must be specified when start_step is set."
+            ):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=2, stop_step=None
+                    ),
+                )
+            with self.assertRaisesRegex(ValueError, "start_step must be < stop_step."):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=2, stop_step=0
+                    ),
+                )
+            with self.assertRaisesRegex(ValueError, "stop_step must be positive."):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(stop_step=0),
+                )
+            with self.assertRaisesRegex(
+                ValueError,
+                "stop_step requires either start_step or enable_oom_observer to be set.",
+            ):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        stop_step=2, enable_oom_observer=False
+                    ),
+                )
+            with self.assertRaisesRegex(
+                ValueError,
+                "At least one of start_step/stop_step or enable_oom_observer must be set.",
+            ):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=None, stop_step=None, enable_oom_observer=False
+                    ),
+                )
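
For context, a minimal sketch of how these validation checks surface at construction time; the temporary directory and parameter values are arbitrary illustrative choices:

import tempfile

from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)

with tempfile.TemporaryDirectory() as temp_dir:
    # Valid: recording starts once the step count reaches 5; a snapshot is
    # dumped and recording stops at step 10.
    profiler = MemorySnapshotProfiler(
        output_dir=temp_dir,
        memory_snapshot_params=MemorySnapshotParams(start_step=5, stop_step=10),
    )

    # Invalid: start_step >= stop_step, so construction raises ValueError.
    try:
        MemorySnapshotProfiler(
            output_dir=temp_dir,
            memory_snapshot_params=MemorySnapshotParams(start_step=10, stop_step=5),
        )
    except ValueError as err:
        print(err)  # "start_step must be < stop_step."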
14 changes: 1 addition & 13 deletions torchtnt/framework/callbacks/memory_snapshot.py
@@ -8,7 +8,7 @@
 from typing import Optional
 
 from torchtnt.framework.callback import Callback
-from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.state import State
 from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
 from torchtnt.utils.memory_snapshot_profiler import (
     MemorySnapshotParams,
@@ -42,24 +42,12 @@ def __init__(
         self.memory_snapshot_profiler = MemorySnapshotProfiler(
             output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
         )
-        self.memory_snapshot_profiler.start()
-
     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         self.memory_snapshot_profiler.step()
 
-    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
-        self.memory_snapshot_profiler.stop()
-
     def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
         self.memory_snapshot_profiler.step()
 
-    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
-        # if in fit do nothing since the profiler will be stopped in on_train_end
-        if state.entry_point == EntryPoint.EVALUATE:
-            self.memory_snapshot_profiler.stop()
-
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         self.memory_snapshot_profiler.step()
 
-    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
-        self.memory_snapshot_profiler.stop()
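
With the explicit start()/stop() hooks removed, the callback only forwards step ends to the profiler, which now starts and stops itself based on start_step and stop_step. A minimal usage sketch; the unit, dataloader, and output directory are hypothetical placeholders:

from torchtnt.framework.callbacks.memory_snapshot import MemorySnapshot
from torchtnt.utils.memory_snapshot_profiler import MemorySnapshotParams

snapshot_callback = MemorySnapshot(
    output_dir="/tmp/snapshots",  # hypothetical output directory
    memory_snapshot_params=MemorySnapshotParams(start_step=5, stop_step=10),
)

# Hand the callback to the framework entry point, e.g.:
# train(my_unit, my_dataloader, callbacks=[snapshot_callback])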
68 changes: 54 additions & 14 deletions torchtnt/utils/memory_snapshot_profiler.py
@@ -6,8 +6,7 @@
 
 import logging
 from dataclasses import dataclass
-from types import TracebackType
-from typing import Optional, Type
+from typing import Optional
 
 import torch
 from torchtnt.utils.oom import attach_oom_observer, log_memory_snapshot
@@ -22,12 +21,18 @@ class MemorySnapshotParams:
     Memory snapshot parameters.
 
     Args:
-        stop_step: Number of steps after which to dump memory snapshot, and stop recording memory history.
+        start_step: Step from which to start recording memory history.
+        stop_step: Step after which to dump memory snapshot, and stop recording memory history.
         max_entries: Maximum number of events to keep in memory.
         enable_oom_observer: Whether to attach an observer to record OOM events. If stop_step is set, the
             OOM observer will only be active until stop_step is reached.
+
+    Note: If enable_oom_observer is True, you don't necessarily have to set start_step, since attach_oom_observer
+        starts recording memory history on its own. If you don't set a stop_step, recording continues until the
+        program exits, which may incur a slight performance cost.
     """
 
+    start_step: Optional[int] = None
     stop_step: Optional[int] = 2
     max_entries: int = 100000
     enable_oom_observer: bool = True
@@ -44,6 +49,20 @@ class MemorySnapshotProfiler:
     Args:
         output_dir: Directory where to save the memory snapshots.
         memory_snapshot_params: Instance of MemorySnapshotParams.
+
+    Raises:
+        ValueError: If `start_step` is negative, or `stop_step` is less than or equal to zero.
+        ValueError: If `start_step` is greater than or equal to `stop_step`.
+        ValueError: If `start_step` is set and `stop_step` is not set.
+        ValueError: If `stop_step` is set and neither `start_step` nor `enable_oom_observer` is set.
+        ValueError: If `enable_oom_observer` is False and neither `start_step` nor `stop_step` is set.
+
+    Examples::
+        memory_snapshot_params = MemorySnapshotParams(start_step=5, stop_step=10, enable_oom_observer=True)
+        memory_snapshot_profiler = MemorySnapshotProfiler(output_dir="/tmp", memory_snapshot_params=memory_snapshot_params)
+        for batch in dataloader:
+            ...
+            memory_snapshot_profiler.step()
     """
 
     def __init__(
@@ -55,6 +74,31 @@ def __init__(
         self.params: MemorySnapshotParams = (
             memory_snapshot_params or MemorySnapshotParams()
         )
+        start_step = self.params.start_step
+        stop_step = self.params.stop_step
+        if start_step is not None:
+            if start_step < 0:
+                raise ValueError("start_step must be nonnegative.")
+            elif stop_step is None:
+                raise ValueError("stop_step must be specified when start_step is set.")
+            elif start_step >= stop_step:
+                raise ValueError("start_step must be < stop_step.")
+        if stop_step is not None:
+            if stop_step <= 0:
+                raise ValueError("stop_step must be positive.")
+            elif start_step is None and not self.params.enable_oom_observer:
+                raise ValueError(
+                    "stop_step requires either start_step or enable_oom_observer to be set."
+                )
+        if (
+            start_step is None
+            and stop_step is None
+            and not self.params.enable_oom_observer
+        ):
+            raise ValueError(
+                "At least one of start_step/stop_step or enable_oom_observer must be set."
+            )
+
         self.step_num: int = 0
 
         if not is_torch_version_geq_2_0():
@@ -63,22 +107,13 @@ def __init__(
             attach_oom_observer(
                 output_dir=output_dir, trace_max_entries=self.params.max_entries
             )
+        if start_step is not None and start_step == 0:
+            self.start()
 
         logger.info(
             f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}."
         )
 
-    def __enter__(self) -> None:
-        self.start()
-
-    def __exit__(
-        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_value: Optional[BaseException],
-        tb: Optional[TracebackType],
-    ) -> Optional[bool]:
-        self.stop()
-
     def start(self) -> None:
         if not torch.cuda.is_available():
             logger.warn("CUDA unavailable. Not recording memory history.")
@@ -97,6 +132,11 @@ def stop(self) -> None:
 
     def step(self) -> None:
         self.step_num += 1
+        if (
+            self.params.start_step is not None
+            and self.step_num == self.params.start_step
+        ):
+            self.start()
         if self.params.stop_step is not None and self.step_num == self.params.stop_step:
             log_memory_snapshot(output_dir=self.output_dir)
             self.stop()
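
Since __enter__ and __exit__ are gone, code that previously wrapped a loop in `with MemorySnapshotProfiler(...)` would migrate to the step-driven lifecycle from the class docstring. A minimal sketch, assuming a CUDA-capable machine; the loop body is a placeholder workload and the output directory is hypothetical:

import torch

from torchtnt.utils.memory_snapshot_profiler import (
    MemorySnapshotParams,
    MemorySnapshotProfiler,
)

profiler = MemorySnapshotProfiler(
    output_dir="/tmp/snapshots",  # hypothetical output directory
    memory_snapshot_params=MemorySnapshotParams(start_step=5, stop_step=10),
)

device = "cuda" if torch.cuda.is_available() else "cpu"  # history is only recorded on CUDA
for _ in range(20):
    # Placeholder workload that allocates device memory.
    x = torch.randn(1024, 1024, device=device)
    y = x @ x
    # Starts recording when the step count reaches start_step (5); dumps a
    # snapshot and stops recording at stop_step (10).
    profiler.step()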