diff --git a/tests/utils/test_memory_snapshot_profiler.py b/tests/utils/test_memory_snapshot_profiler.py
index 936382e14d..3b5a112a64 100644
--- a/tests/utils/test_memory_snapshot_profiler.py
+++ b/tests/utils/test_memory_snapshot_profiler.py
@@ -35,9 +35,8 @@ def test_stop_step(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             memory_snapshot_profiler = MemorySnapshotProfiler(
                 output_dir=temp_dir,
-                memory_snapshot_params=MemorySnapshotParams(stop_step=2),
+                memory_snapshot_params=MemorySnapshotParams(start_step=0, stop_step=2),
             )
-            memory_snapshot_profiler.start()

             # initialize device & allocate memory for tensors
             device = get_device_from_env()
@@ -64,3 +63,59 @@ def test_stop_step(self) -> None:
             self.assertTrue(os.path.exists(pickle_dump_path))
             self.assertTrue(os.path.exists(trace_path))
             self.assertTrue(os.path.exists(segment_plot_path))
+
+    @unittest.skipUnless(
+        condition=torch_version_geq_2_0,
+        reason="This test needs changes from PyTorch 2.0 to run.",
+    )
+    def test_validation(self) -> None:
+        """Test parameter validation."""
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with self.assertRaisesRegex(ValueError, "start_step must be nonnegative."):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=-1, stop_step=0
+                    ),
+                )
+            with self.assertRaisesRegex(
+                ValueError, "stop_step must be specified when start_step is set."
+            ):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=2, stop_step=None
+                    ),
+                )
+            with self.assertRaisesRegex(ValueError, "start_step must be < stop_step."):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=2, stop_step=0
+                    ),
+                )
+            with self.assertRaisesRegex(ValueError, "stop_step must be positive."):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(stop_step=0),
+                )
+            with self.assertRaisesRegex(
+                ValueError,
+                "stop_step must be enabled with either start_step or enable_oom_observer.",
+            ):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        stop_step=2, enable_oom_observer=False
+                    ),
+                )
+            with self.assertRaisesRegex(
+                ValueError,
+                "At least one of start_step/stop_step or enable_oom_observer must be set.",
+            ):
+                _ = MemorySnapshotProfiler(
+                    output_dir=temp_dir,
+                    memory_snapshot_params=MemorySnapshotParams(
+                        start_step=None, stop_step=None, enable_oom_observer=False
+                    ),
+                )
diff --git a/torchtnt/framework/callbacks/memory_snapshot.py b/torchtnt/framework/callbacks/memory_snapshot.py
index 2436523062..061750b34a 100644
--- a/torchtnt/framework/callbacks/memory_snapshot.py
+++ b/torchtnt/framework/callbacks/memory_snapshot.py
@@ -8,7 +8,7 @@
 from typing import Optional

 from torchtnt.framework.callback import Callback
-from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.state import State
 from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
 from torchtnt.utils.memory_snapshot_profiler import (
     MemorySnapshotParams,
@@ -42,24 +42,12 @@ def __init__(
         self.memory_snapshot_profiler = MemorySnapshotProfiler(
             output_dir=output_dir, memory_snapshot_params=memory_snapshot_params
         )
-        self.memory_snapshot_profiler.start()

     def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         self.memory_snapshot_profiler.step()

-    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
-        self.memory_snapshot_profiler.stop()
-
     def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
         self.memory_snapshot_profiler.step()

-    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
-        # if in fit do nothing since the profiler will be stopped in on_train_end
-        if state.entry_point == EntryPoint.EVALUATE:
-            self.memory_snapshot_profiler.stop()
-
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         self.memory_snapshot_profiler.step()
-
-    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
-        self.memory_snapshot_profiler.stop()
diff --git a/torchtnt/utils/memory_snapshot_profiler.py b/torchtnt/utils/memory_snapshot_profiler.py
index 72a96e6e76..58f2b69b99 100644
--- a/torchtnt/utils/memory_snapshot_profiler.py
+++ b/torchtnt/utils/memory_snapshot_profiler.py
@@ -22,12 +22,18 @@ class MemorySnapshotParams:
     Memory snapshot parameters.

     Args:
-        stop_step: Number of steps after which to dump memory snapshot, and stop recording memory history.
+        start_step: Step from which to start recording memory history.
+        stop_step: Step after which to dump memory snapshot, and stop recording memory history.
         max_entries: Maximum number of events to keep in memory.
         enable_oom_observer: Whether to attach an observer to record OOM events. If stop_step is set,
             the OOM observer will only be active until stop_step is reached.
+
+    Note: If you set enable_oom_observer to True, you don't necessarily have to set a start_step as attach_oom_observer
+        will start recording memory history. Note that if you don't set a stop_step, it will continue recording memory
+        history until the program exits, which may incur a slight performance cost.
     """

+    start_step: Optional[int] = None
     stop_step: Optional[int] = 2
     max_entries: int = 100000
     enable_oom_observer: bool = True
@@ -44,6 +50,20 @@ class MemorySnapshotProfiler:
     Args:
         output_dir: Directory where to save the memory snapshots.
         memory_snapshot_params: Instance of MemorySnapshotParams.
+
+    Raises:
+        ValueError: If `start_step` is negative, or `stop_step` is less than or equal to zero.
+        ValueError: If `start_step` is greater than or equal to `stop_step`.
+        ValueError: If `start_step` is set and `stop_step` is not set.
+        ValueError: If `stop_step` is set and neither `start_step` nor `enable_oom_observer` are set.
+        ValueError: If `enable_oom_observer` is False and neither `start_step` nor `stop_step` is set
+
+    Examples::
+        memory_snapshot_params = MemorySnapshotParams(start_step=5, stop_step=10, enable_oom_observer=True)
+        memory_snapshot_profiler = MemorySnapshotProfiler(output_dir="/tmp", memory_snapshot_params=memory_snapshot_params)
+        for batch in dataloader:
+            ...
+            memory_snapshot_profiler.step()
     """

     def __init__(
@@ -55,6 +75,31 @@ def __init__(
         self.params: MemorySnapshotParams = (
             memory_snapshot_params or MemorySnapshotParams()
         )
+        start_step = self.params.start_step
+        stop_step = self.params.stop_step
+        if start_step is not None:
+            if start_step < 0:
+                raise ValueError("start_step must be nonnegative.")
+            elif stop_step is None:
+                raise ValueError("stop_step must be specified when start_step is set.")
+            elif start_step >= stop_step:
+                raise ValueError("start_step must be < stop_step.")
+        if stop_step is not None:
+            if stop_step <= 0:
+                raise ValueError("stop_step must be positive.")
+            elif start_step is None and not self.params.enable_oom_observer:
+                raise ValueError(
+                    "stop_step must be enabled with either start_step or enable_oom_observer."
+                )
+        if (
+            start_step is None
+            and stop_step is None
+            and not self.params.enable_oom_observer
+        ):
+            raise ValueError(
+                "At least one of start_step/stop_step or enable_oom_observer must be set."
+            )
+
         self.step_num: int = 0

         if not is_torch_version_geq_2_0():
@@ -63,6 +108,8 @@ def __init__(
             attach_oom_observer(
                 output_dir=output_dir, trace_max_entries=self.params.max_entries
             )
+        if start_step is not None and start_step == 0:
+            self.start()

         logger.info(
             f"Created MemorySnapshotProfiler with MemorySnapshotParams={self.params}."
@@ -97,6 +144,11 @@ def stop(self) -> None:

     def step(self) -> None:
         self.step_num += 1
+        if (
+            self.params.start_step is not None
+            and self.step_num == self.params.start_step
+        ):
+            self.start()
        if self.params.stop_step is not None and self.step_num == self.params.stop_step:
             log_memory_snapshot(output_dir=self.output_dir)
             self.stop()
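
For reference, a minimal usage sketch of the step-based flow introduced above, using only names that appear in the diff; the training loop, dataloader, and output directory are illustrative placeholders, not part of the patch. Recording now starts when step() reaches start_step (or at construction if start_step=0), and the snapshot is dumped and recording stopped once stop_step is reached, so callers no longer invoke start()/stop() themselves.

    from torchtnt.utils.memory_snapshot_profiler import (
        MemorySnapshotParams,
        MemorySnapshotProfiler,
    )

    # Record CUDA memory history for steps 5 through 10, then dump the snapshot
    # to the output directory and stop recording. enable_oom_observer defaults
    # to True, so OOM events are also captured.
    profiler = MemorySnapshotProfiler(
        output_dir="/tmp/memory_snapshots",  # placeholder path
        memory_snapshot_params=MemorySnapshotParams(start_step=5, stop_step=10),
    )

    for batch in dataloader:  # placeholder training loop
        ...
        profiler.step()  # starts recording at step 5, dumps and stops at step 10

    # Invalid combinations now fail fast at construction time, for example:
    # MemorySnapshotParams(start_step=2, stop_step=0) -> ValueError("start_step must be < stop_step.")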