diff --git a/ignite/engine/__init__.py b/ignite/engine/__init__.py
index 42bf114448a..f20458cd8f2 100644
--- a/ignite/engine/__init__.py
+++ b/ignite/engine/__init__.py
@@ -95,6 +95,8 @@ def supervised_training_step(
     )
 
     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
@@ -104,7 +106,6 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         loss.backward()
         if engine.state.iteration % gradient_accumulation_steps == 0:
             optimizer.step()
-            optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)
 
     return update
@@ -173,6 +174,8 @@ def supervised_training_step_amp(
     )
 
     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         with autocast(enabled=True):
@@ -185,12 +188,10 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
             if engine.state.iteration % gradient_accumulation_steps == 0:
                 scaler.step(optimizer)
                 scaler.update()
-                optimizer.zero_grad()
         else:
             loss.backward()
             if engine.state.iteration % gradient_accumulation_steps == 0:
                 optimizer.step()
-                optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)
 
     return update
@@ -256,6 +257,8 @@ def supervised_training_step_apex(
     )
 
     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
@@ -266,7 +269,6 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
             scaled_loss.backward()
         if engine.state.iteration % gradient_accumulation_steps == 0:
             optimizer.step()
-            optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)
 
     return update
@@ -331,6 +333,8 @@ def supervised_training_step_tpu(
     )
 
     def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+        if (engine.state.iteration - 1) % gradient_accumulation_steps == 0:
+            optimizer.zero_grad()
         model.train()
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
@@ -340,7 +344,6 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
         loss.backward()
         if engine.state.iteration % gradient_accumulation_steps == 0:
             xm.optimizer_step(optimizer, barrier=True)
-            optimizer.zero_grad()
         return output_transform(x, y, y_pred, loss)
 
     return update
diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py
index 13cb4a8f364..c214f41527c 100644
--- a/tests/ignite/engine/test_create_supervised.py
+++ b/tests/ignite/engine/test_create_supervised.py
@@ -124,6 +124,35 @@ def _():
     trainer.run(data)
 
 
+def _test_create_supervised_trainer_have_grad_after_iteration(
+    model_device: Optional[str] = None,
+    trainer_device: Optional[str] = None,
+    trace: bool = False,
+    amp_mode: str = None,
+    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
+    gradient_accumulation_steps: int = 1,
+):
+
+    trainer, model = _default_create_supervised_trainer(
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        scaler=scaler,
+    )
+
+    x = torch.tensor([[1.0], [1.0], [1.0], [1.0], [1.0]])
+    y = torch.tensor([[2.0], [3.0], [4.0], [5.0], [6.0]])
+    data = [(_x, _y) for _x, _y in zip(x, y)]
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def _():
+        assert model.weight.grad != 0
+
+    trainer.run(data)
+
+
 @pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
 def test_create_supervised_training_scalar_assignment():
 
@@ -340,25 +369,119 @@ def _test_create_evaluation_step(
     assert output_transform_mock.called
 
 
-def test_create_supervised_trainer():
-    _test_create_supervised_trainer_wrong_accumulation()
-    _test_create_supervised_trainer(gradient_accumulation_steps=1)
-    _test_create_supervised_trainer(gradient_accumulation_steps=3)
-    _test_create_mocked_supervised_trainer()
-
-
-def test_create_supervised_trainer_with_cpu():
-    _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu")
-    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu")
-    _test_create_mocked_supervised_trainer(trainer_device="cpu")
-
-
-def test_create_supervised_trainer_traced_with_cpu():
-    _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu")
-    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cpu", trace=True)
-    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cpu", trace=True)
-    _test_create_mocked_supervised_trainer(trainer_device="cpu", trace=True)
+@pytest.mark.parametrize(
+    ("trainer_device", "model_device", "trace", "amp_mode", "scaler"),
+    [
+        pytest.param(None, None, False, None, False, id="default"),
+        pytest.param("cpu", None, False, None, False, id="cpu"),
+        pytest.param("cpu", None, True, None, False, id="traced_with_cpu"),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            None,
+            False,
+            marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+            id="cuda",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "amp",
+            False,
+            marks=[
+                pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0"),
+            ],
+            id="cuda_amp",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "amp",
+            True,
+            marks=[
+                pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0"),
+            ],
+            id="cuda_amp_scaler",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "amp",
+            torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available()),
+            marks=[
+                pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0"),
+            ],
+            id="cuda_amp_gradscaler",
+        ),
+        pytest.param(
+            "cuda",
+            "cuda",
+            False,
+            "apex",
+            False,
+            marks=[
+                pytest.mark.skip(reason="Temporarily disabled, as it fails because of an issue from apex side"),
+                # pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU"),
+                # pytest.mark.skipif(not find_spec("apex"), reason="Skip if no APEX")
+            ],
+            id="cuda_apex",
+        ),
+        pytest.param(
+            "xla",
+            "xla",
+            False,
+            None,
+            False,
+            marks=[
+                pytest.mark.tpu,
+                pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars"),
+                pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package"),
+            ],
+            id="tpu",
+        ),
+        pytest.param(
+            "cuda",
+            None,
+            False,
+            None,
+            False,
+            marks=[pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")],
+            id="cuda_with_model_on_cpu",
+        ),
+    ],
+)
+def test_create_supervised_trainer(trainer_device, model_device, trace, amp_mode, scaler):
+    _test_create_supervised_trainer_wrong_accumulation(model_device, trainer_device, amp_mode)
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=1,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        scaler=scaler,
+    )
+    _test_create_supervised_trainer(
+        gradient_accumulation_steps=3,
+        model_device=model_device,
+        trainer_device=trainer_device,
+        trace=trace,
+        amp_mode=amp_mode,
+        scaler=scaler,
+    )
+    _test_create_mocked_supervised_trainer(model_device, trainer_device, trace, amp_mode, scaler)
+    _test_create_supervised_trainer_have_grad_after_iteration(
+        model_device, trainer_device, trace, amp_mode, scaler, gradient_accumulation_steps=1
+    )
+    _test_create_supervised_trainer_have_grad_after_iteration(
+        model_device, trainer_device, trace, amp_mode, scaler, gradient_accumulation_steps=3
+    )
 
 
 @pytest.mark.skipif(find_spec("apex"), reason="Skip if APEX")
@@ -405,96 +528,6 @@ def test_create_supervised_trainer_scaler_not_amp():
         _test_create_supervised_trainer(amp_mode="apex", scaler=scaler)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
-
-
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda_amp():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")
-
-
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda_amp_scaler():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=True,
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=True,
-    )
-    _test_create_mocked_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=True
-    )
-    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=scaler,
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3,
-        model_device=model_device,
-        trainer_device=trainer_device,
-        amp_mode="amp",
-        scaler=scaler,
-    )
-    _test_create_mocked_supervised_trainer(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="amp", scaler=scaler
-    )
-
-
-# @pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-# @pytest.mark.skipif(not find_spec("apex"), reason="Skip if no APEX")
-@pytest.mark.skip(reason="Temporarily disabled, as it fails because of an issue from apex side")
-def test_create_supervised_trainer_on_cuda_apex():
-    model_device = trainer_device = "cuda"
-    _test_create_supervised_trainer_wrong_accumulation(
-        model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device, amp_mode="apex"
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="apex")
-
-
 @pytest.mark.skipif(idist.has_xla_support, reason="Skip if has PyTorch XLA package")
 def test_supervised_training_step_tpu_no_xla():
     with pytest.raises(ModuleNotFoundError, match="torch_xla cannot be imported, please install PyTorch XLA."):
@@ -509,21 +542,6 @@ def test_create_supervised_trainer_on_tpu_no_xla():
         _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_create_supervised_trainer_on_tpu():
-    model_device = trainer_device = "xla"
-    _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=1, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_supervised_trainer(
-        gradient_accumulation_steps=3, model_device=model_device, trainer_device=trainer_device
-    )
-    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
-
-
 @pytest.mark.tpu
 @pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
 def test_create_supervised_trainer_on_tpu_amp():
@@ -532,14 +550,6 @@ def test_create_supervised_trainer_on_tpu_amp():
         _test_create_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
-def test_create_supervised_trainer_on_cuda_with_model_on_cpu():
-    _test_create_supervised_trainer_wrong_accumulation(trainer_device="cuda")
-    _test_create_supervised_trainer(gradient_accumulation_steps=1, trainer_device="cuda")
-    _test_create_supervised_trainer(gradient_accumulation_steps=3, trainer_device="cuda")
-    _test_create_mocked_supervised_trainer(trainer_device="cuda")
-
-
 def test_create_supervised_evaluator():
     _test_create_supervised_evaluator()
     _test_mocked_supervised_evaluator()