diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index 524e122478bad1..bb9ef53e968b2b 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -436,7 +436,7 @@ def test_replication_factor(tmpdir): plugin.model = model model.trainer = trainer trainer.state.fn = TrainerFn.FITTING - trainer.training_type_plugin.pre_dispatch() + trainer.training_type_plugin.pre_dispatch(trainer) trainer.state.stage = RunningStage.TRAINING assert trainer.training_type_plugin.replication_factor == 8 @@ -450,7 +450,7 @@ def test_replication_factor(tmpdir): ): trainer.state.fn = fn trainer.state.stage = stage - trainer.training_type_plugin.pre_dispatch() + trainer.training_type_plugin.pre_dispatch(trainer) assert trainer.training_type_plugin.replication_factor == 7 @@ -585,7 +585,7 @@ def test_poptorch_models_at_different_stages(tmpdir): trainer.optimizers = model.configure_optimizers()[0] trainer.state.fn = TrainerFn.FITTING - trainer.training_type_plugin.pre_dispatch() + trainer.training_type_plugin.pre_dispatch(trainer) assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING] for fn, stage in ( @@ -595,7 +595,7 @@ def test_poptorch_models_at_different_stages(tmpdir): ): trainer.state.fn = fn trainer.state.stage = stage - trainer.training_type_plugin.pre_dispatch() + trainer.training_type_plugin.pre_dispatch(trainer) assert list(trainer.training_type_plugin.poptorch_models) == [stage]