diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py index 1014e70c45..7c9ccbd5a0 100644 --- a/tests/framework/test_auto_unit.py +++ b/tests/framework/test_auto_unit.py @@ -658,7 +658,7 @@ def test_is_last_batch(self) -> None: my_unit = DummyAutoUnit(module=my_module) train(my_unit, dataloader, max_epochs=1, max_steps_per_epoch=4) - self.assertFalse(my_unit._is_last_train_batch) + self.assertFalse(my_unit._is_last_batch) def test_auto_unit_timing_train(self) -> None: """ @@ -844,12 +844,12 @@ def test_get_next_batch_with_single_phase(self) -> None: batch = auto_unit._get_next_batch(state, second_data_iter) self.assertEqual(batch, 3) self._assert_next_batch_dicts(auto_unit, train_prefetched=True) - self.assertTrue(auto_unit._is_last_train_batch) + self.assertTrue(auto_unit._is_last_batch) with move_data_to_device_mock, self.assertRaises(StopIteration): auto_unit._get_next_batch(state, second_data_iter) self._assert_next_batch_dicts(auto_unit) - self.assertFalse(auto_unit._is_last_train_batch) + self.assertFalse(auto_unit._is_last_batch) def test_get_next_batch_with_multiple_phases(self) -> None: auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2)) @@ -990,7 +990,7 @@ def compute_loss( ) -> Tuple[torch.Tensor, torch.Tensor]: tc = unittest.TestCase() tc.assertEqual( - self._is_last_train_batch, + self._is_last_batch, self.train_progress.num_steps_completed_in_epoch + 1 == self.expected_steps_per_epoch, ) diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index 37fdff7c2b..895543c634 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -9,7 +9,17 @@ from abc import ABCMeta, abstractmethod from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Generic, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) import torch from pyre_extensions import none_throws @@ -115,47 +125,24 @@ def __call__(self, *args, **kwargs): return x -class AutoPredictUnit(PredictUnit[TPredictData]): +class _AutoUnitMixin(Generic[TData]): + """ + A mixin to share initialization of shared attributes and introduce prefetching. + """ + def __init__( self, *, module: torch.nn.Module, device: Optional[torch.device] = None, - strategy: Optional[Union[Strategy, str]] = None, precision: Optional[Union[str, torch.dtype]] = None, - torch_compile_params: Optional[TorchCompileParams] = None, detect_anomaly: Optional[bool] = None, + torch_compile_params: Optional[TorchCompileParams] = None, ) -> None: - """ - AutoPredictUnit is a convenience for users who are running inference and would like to have certain features handled for them, such as: - - Moving data to the correct device. - - Running inference under a mixed precision context. - - Handling data parallel replication, especially if the module cannot fit on a single device using FullyShardedDataParallel. - - Profiling the data transfer to device and forward pass. - - Interleaving moving the next batch to the device with running the module's forward pass on the current batch. - - Additionally, the AutoPredictUnit offers an optional hook ``on_predict_step_end`` to further post-process module outputs if desired. - - Then use with the :py:func:`~torchtnt.framework.predict` entry point. - - For more advanced customization, directly use the :class:`~torchtnt.framework.unit.PredictUnit` interface. - - Args: - module: module to be used during prediction. - device: the device to be used. - precision: the precision to use in training, as either a string or a torch.dtype. - strategy: the data parallelization strategy to be used. if a string, must be one of ``ddp`` or ``fsdp``. - torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html - detect_anomaly: whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection - - Note: - Torch compile support is only available in PyTorch 2.0 or higher. - """ + super().__init__() if torch_compile_params: _validate_torch_compile_available() - super().__init__() - self.device: torch.device = device or init_from_env() self.precision: Optional[torch.dtype] = ( convert_precision_str_to_dtype(precision) @@ -163,25 +150,21 @@ def __init__( else precision ) + self.detect_anomaly = detect_anomaly + # create autocast context based on precision and device type self.maybe_autocast_precision = torch.autocast( device_type=self.device.type, dtype=self.precision, enabled=self.precision is not None, ) - self.module: torch.nn.Module = prepare_module( - module, - self.device, - strategy=strategy, - torch_compile_params=torch_compile_params, - ) # cuda stream to use for moving data to device self._prefetch_stream: Optional[torch.cuda.streams.Stream] = ( torch.cuda.Stream() if self.device.type == "cuda" else None ) # dict mapping phase to whether the next batch which has been prefetched for that phase and is ready to be used - self._phase_to_next_batch: dict[ActivePhase, Optional[TPredictData]] = { + self._phase_to_next_batch: dict[ActivePhase, Optional[TData]] = { ActivePhase.TRAIN: None, ActivePhase.EVALUATE: None, ActivePhase.PREDICT: None, @@ -193,62 +176,21 @@ def __init__( ActivePhase.EVALUATE: False, ActivePhase.PREDICT: False, } - - self.detect_anomaly = detect_anomaly - - # pyre-fixme[3]: Return annotation cannot be `Any`. - def predict_step(self, state: State, data: Iterator[TPredictData]) -> Any: - with none_throws(state.predict_state).iteration_timer.time("data_wait_time"): - batch = self._get_next_batch(state, data) - - # if detect_anomaly is true, run forward pass under detect_anomaly context - detect_anomaly = self.detect_anomaly - maybe_detect_anomaly = ( - torch.autograd.set_detect_anomaly(detect_anomaly) - if detect_anomaly is not None - else contextlib.nullcontext() - ) - - with self.maybe_autocast_precision, maybe_detect_anomaly: - with get_timing_context(state, f"{self.__class__.__name__}.forward"): - outputs = self.module(batch) - - step = self.predict_progress.num_steps_completed - self.on_predict_step_end(state, batch, step, outputs) - return outputs - - def on_predict_step_end( - self, - state: State, - data: TPredictData, - step: int, - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - outputs: Any, - ) -> None: - """ - This will be called at the end of every ``predict_step`` before returning. The user can implement this method with code to update and log their metrics, - or do anything else. - - Args: - state: a State object which is passed from the ``predict_step`` - data: a batch of data which is passed from the ``predict_step`` - step: how many ``predict_step`` s have been completed - outputs: the outputs of the model forward pass - """ - pass + # whether the current batch is the last train batch + self._is_last_batch: bool = False def move_data_to_device( - self, state: State, data: TPredictData, non_blocking: bool - ) -> TPredictData: + self, state: State, data: TData, non_blocking: bool + ) -> TData: """ - The user can override this method with custom code to copy data to device. This will be called at the start of every ``predict_step``. + The user can override this method with custom code to copy data to device. This will be called at the start of every ``train_step``/``eval_step``/``predict_step``. By default this uses the utility function :py:func:`~torchtnt.utils.copy_data_to_device`. If on GPU, this method will be called on a separate CUDA stream. Args: - state: a State object which is passed from the ``predict_step`` - data: a batch of data which is passed from the ``predict_step`` + state: a State object which is passed from the ``train_step``/``eval_step``/``predict_step`` + data: a batch of data which is passed from the ``train_step``/``eval_step``/``predict_step`` non_blocking: parameter to pass to ``torch.tensor.to`` Returns: @@ -256,9 +198,33 @@ def move_data_to_device( """ return copy_data_to_device(data, self.device, non_blocking=non_blocking) - def _get_next_batch( - self, state: State, data: Iterator[TPredictData] - ) -> TPredictData: + def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None: + """Prefetch the next batch on a separate CUDA stream.""" + active_phase = state.active_phase + phase = state.active_phase.name.lower() + try: + with get_timing_context( + state, f"{self.__class__.__name__}.{phase}.next(data_iter)" + ): + next_batch = next(data_iter) + except StopIteration: + self._phase_to_next_batch[active_phase] = None + self._is_last_batch = True + return + + non_blocking = bool( + self.device.type == "cuda" and self._phase_to_prefetched[active_phase] + ) + + # if on cpu, self._prefetch_stream is None so the torch.cuda.stream call is a no-op + with torch.cuda.stream(self._prefetch_stream), get_timing_context( + state, f"{self.__class__.__name__}.{phase}.move_data_to_device" + ): + self._phase_to_next_batch[active_phase] = self.move_data_to_device( + state, next_batch, non_blocking=non_blocking + ) + + def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData: active_phase = state.active_phase if not self._phase_to_prefetched[active_phase]: self._prefetch_next_batch(state, data) @@ -273,6 +239,7 @@ def _get_next_batch( batch = self._phase_to_next_batch[active_phase] if batch is None: self._phase_to_prefetched[active_phase] = False + self._is_last_batch = False raise StopIteration if self._prefetch_stream: @@ -282,38 +249,106 @@ def _get_next_batch( # record the batch in the current stream record_data_in_stream(batch, torch.cuda.current_stream()) - # kick off prefetching the next batch + # prefetch the next batch self._prefetch_next_batch(state, data) + return batch - def _prefetch_next_batch( - self, state: State, data_iter: Iterator[TPredictData] + +class AutoPredictUnit(_AutoUnitMixin[TPredictData], PredictUnit[TPredictData]): + def __init__( + self, + *, + module: torch.nn.Module, + device: Optional[torch.device] = None, + strategy: Optional[Union[Strategy, str]] = None, + precision: Optional[Union[str, torch.dtype]] = None, + torch_compile_params: Optional[TorchCompileParams] = None, + detect_anomaly: Optional[bool] = None, ) -> None: - """Prefetch the next batch on a separate CUDA stream.""" - active_phase = state.active_phase - try: - with get_timing_context( - state, f"{self.__class__.__name__}.next(data_iter)" - ): - next_batch = next(data_iter) - except StopIteration: - self._phase_to_next_batch[active_phase] = None - return + """ + AutoPredictUnit is a convenience for users who are running inference and would like to have certain features handled for them, such as: + - Moving data to the correct device. + - Running inference under a mixed precision context. + - Handling data parallel replication, especially if the module cannot fit on a single device using FullyShardedDataParallel. + - Profiling the data transfer to device and forward pass. + - Interleaving moving the next batch to the device with running the module's forward pass on the current batch. - non_blocking = bool( - self.device.type == "cuda" and self._phase_to_prefetched[active_phase] + Additionally, the AutoPredictUnit offers an optional hook ``on_predict_step_end`` to further post-process module outputs if desired. + + Then use with the :py:func:`~torchtnt.framework.predict` entry point. + + For more advanced customization, directly use the :class:`~torchtnt.framework.unit.PredictUnit` interface. + + Args: + module: module to be used during prediction. + device: the device to be used. + precision: the precision to use in training, as either a string or a torch.dtype. + strategy: the data parallelization strategy to be used. if a string, must be one of ``ddp`` or ``fsdp``. + torch_compile_params: params for Torch compile https://pytorch.org/docs/stable/generated/torch.compile.html + detect_anomaly: whether to enable anomaly detection for the autograd engine https://pytorch.org/docs/stable/autograd.html#anomaly-detection + + Note: + Torch compile support is only available in PyTorch 2.0 or higher. + """ + super().__init__( + module=module, + device=device, + precision=precision, + torch_compile_params=torch_compile_params, + detect_anomaly=detect_anomaly, + ) + self.module: torch.nn.Module = prepare_module( + module, + self.device, + strategy=strategy, + torch_compile_params=torch_compile_params, ) - # if on cpu, self._prefetch_stream is None so the torch.cuda.stream call is a no-op - with torch.cuda.stream(self._prefetch_stream), get_timing_context( - state, f"{self.__class__.__name__}.move_data_to_device" - ): - self._phase_to_next_batch[active_phase] = self.move_data_to_device( - state, next_batch, non_blocking=non_blocking - ) + # pyre-fixme[3]: Return annotation cannot be `Any`. + def predict_step(self, state: State, data: Iterator[TPredictData]) -> Any: + with none_throws(state.predict_state).iteration_timer.time("data_wait_time"): + batch = self._get_next_batch(state, data) + + # if detect_anomaly is true, run forward pass under detect_anomaly context + detect_anomaly = self.detect_anomaly + maybe_detect_anomaly = ( + torch.autograd.set_detect_anomaly(detect_anomaly) + if detect_anomaly is not None + else contextlib.nullcontext() + ) + + with self.maybe_autocast_precision, maybe_detect_anomaly: + with get_timing_context(state, f"{self.__class__.__name__}.forward"): + outputs = self.module(batch) + + step = self.predict_progress.num_steps_completed + self.on_predict_step_end(state, batch, step, outputs) + return outputs + + def on_predict_step_end( + self, + state: State, + data: TPredictData, + step: int, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + outputs: Any, + ) -> None: + """ + This will be called at the end of every ``predict_step`` before returning. The user can implement this method with code to update and log their metrics, + or do anything else. + + Args: + state: a State object which is passed from the ``predict_step`` + data: a batch of data which is passed from the ``predict_step`` + step: how many ``predict_step`` s have been completed + outputs: the outputs of the model forward pass + """ + pass class AutoUnit( + _AutoUnitMixin[TData], TrainUnit[TData], EvalUnit[TData], PredictUnit[TData], @@ -388,21 +423,18 @@ def __init__( activation_checkpoint_params: Optional[ActivationCheckpointParams] = None, training: bool = True, ) -> None: - super().__init__() + super().__init__( + module=module, + device=device, + precision=precision, + detect_anomaly=detect_anomaly, + torch_compile_params=torch_compile_params, + ) if not gradient_accumulation_steps > 0: raise ValueError( f"gradient_accumulation_steps must be > 0. Got {gradient_accumulation_steps}" ) - if torch_compile_params: - _validate_torch_compile_available() - - self.device: torch.device = device or init_from_env() - self.precision: Optional[torch.dtype] = ( - convert_precision_str_to_dtype(precision) - if isinstance(precision, str) - else precision - ) self.swa_params: Optional[SWAParams] = swa_params self.swa_model: Optional[AveragedModel] = None @@ -444,39 +476,13 @@ def __init__( self.gradient_accumulation_steps = gradient_accumulation_steps - self.detect_anomaly = detect_anomaly self.clip_grad_norm = clip_grad_norm self.clip_grad_value = clip_grad_value # create autocast context based on precision and device type - self.maybe_autocast_precision = torch.autocast( - device_type=self.device.type, - dtype=self.precision, - enabled=self.precision is not None, - ) self.training = training - # cuda stream to use for moving data to device - self._prefetch_stream: Optional[torch.cuda.streams.Stream] = ( - torch.cuda.Stream() if self.device.type == "cuda" else None - ) - # dict mapping phase to whether the next batch which has been prefetched for that phase and is ready to be used - self._phase_to_next_batch: dict[ActivePhase, Optional[TData]] = { - ActivePhase.TRAIN: None, - ActivePhase.EVALUATE: None, - ActivePhase.PREDICT: None, - } - - # dict mapping phase to whether the next batch for that phase has been prefetched and is ready to be used - self._phase_to_prefetched: dict[ActivePhase, bool] = { - ActivePhase.TRAIN: False, - ActivePhase.EVALUATE: False, - ActivePhase.PREDICT: False, - } - # whether the current batch is the last train batch - self._is_last_train_batch: bool = False - self.optimizer: Optional[torch.optim.Optimizer] = None self.lr_scheduler: Optional[TLRScheduler] = None self.swa_scheduler: Optional[SWALR] = None @@ -515,81 +521,6 @@ def compute_loss(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]: """ ... - def move_data_to_device( - self, state: State, data: TData, non_blocking: bool - ) -> TData: - """ - The user can override this method with custom code to copy data to device. This will be called at the start of every ``train_step``/``eval_step``. - By default this uses the utility function :py:func:`~torchtnt.utils.copy_data_to_device`. - - If on GPU, this method will be called on a separate CUDA stream. - - Args: - state: a State object which is passed from the ``train_step``/``eval_step`` - data: a batch of data which is passed from the ``train_step``/``eval_step`` - non_blocking: parameter to pass to ``torch.tensor.to`` - - Returns: - A batch of data which is on the device - """ - return copy_data_to_device(data, self.device, non_blocking=non_blocking) - - def _prefetch_next_batch(self, state: State, data_iter: Iterator[TData]) -> None: - """Prefetch the next batch on a separate CUDA stream.""" - active_phase = state.active_phase - phase = state.active_phase.name.lower() - try: - with get_timing_context( - state, f"{self.__class__.__name__}.{phase}.next(data_iter)" - ): - next_batch = next(data_iter) - except StopIteration: - self._phase_to_next_batch[active_phase] = None - self._is_last_train_batch = True - return - - non_blocking = bool( - self.device.type == "cuda" and self._phase_to_prefetched[active_phase] - ) - - # if on cpu, self._prefetch_stream is None so the torch.cuda.stream call is a no-op - with torch.cuda.stream(self._prefetch_stream), get_timing_context( - state, f"{self.__class__.__name__}.{phase}.move_data_to_device" - ): - self._phase_to_next_batch[active_phase] = self.move_data_to_device( - state, next_batch, non_blocking=non_blocking - ) - - def _get_next_batch(self, state: State, data: Iterator[TData]) -> TData: - active_phase = state.active_phase - if not self._phase_to_prefetched[active_phase]: - self._prefetch_next_batch(state, data) - self._phase_to_prefetched[active_phase] = True - - if self._prefetch_stream: - with get_timing_context(state, f"{self.__class__.__name__}.wait_stream"): - # wait on the CUDA stream to complete the host to device copy - torch.cuda.current_stream().wait_stream(self._prefetch_stream) - - # get the next batch which was stored by _prefetch_next_batch - batch = self._phase_to_next_batch[active_phase] - if batch is None: - self._phase_to_prefetched[active_phase] = False - self._is_last_train_batch = False - raise StopIteration - - if self._prefetch_stream: - with get_timing_context( - state, f"{self.__class__.__name__}.record_data_in_stream" - ): - # record the batch in the current stream - record_data_in_stream(batch, torch.cuda.current_stream()) - - # prefetch the next batch - self._prefetch_next_batch(state, data) - - return batch - def _should_update_swa(self) -> bool: if not self.swa_params: return False @@ -624,7 +555,7 @@ def train_step( should_update_weights = ( self.train_progress.num_steps_completed_in_epoch + 1 - ) % self.gradient_accumulation_steps == 0 or self._is_last_train_batch + ) % self.gradient_accumulation_steps == 0 or self._is_last_batch # for pyre, assign to local variable module = self.module @@ -798,7 +729,7 @@ def on_train_epoch_end(self, state: State) -> None: ): lr_scheduler.step() - self._is_last_train_batch = False + self._is_last_batch = False # pyre-fixme[3]: Return annotation cannot contain `Any`. def eval_step(