Skip to content

Commit

Permalink
fix ipu tests
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Dec 1, 2021
1 parent 558e4f5 commit 7192d12
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_replication_factor(tmpdir):
plugin.model = model
model.trainer = trainer
trainer.state.fn = TrainerFn.FITTING
trainer.training_type_plugin.pre_dispatch()
trainer.training_type_plugin.pre_dispatch(trainer)

trainer.state.stage = RunningStage.TRAINING
assert trainer.training_type_plugin.replication_factor == 8
Expand All @@ -450,7 +450,7 @@ def test_replication_factor(tmpdir):
):
trainer.state.fn = fn
trainer.state.stage = stage
trainer.training_type_plugin.pre_dispatch()
trainer.training_type_plugin.pre_dispatch(trainer)
assert trainer.training_type_plugin.replication_factor == 7


Expand Down Expand Up @@ -585,7 +585,7 @@ def test_poptorch_models_at_different_stages(tmpdir):

trainer.optimizers = model.configure_optimizers()[0]
trainer.state.fn = TrainerFn.FITTING
trainer.training_type_plugin.pre_dispatch()
trainer.training_type_plugin.pre_dispatch(trainer)
assert list(trainer.training_type_plugin.poptorch_models) == [RunningStage.TRAINING, RunningStage.VALIDATING]

for fn, stage in (
Expand All @@ -595,7 +595,7 @@ def test_poptorch_models_at_different_stages(tmpdir):
):
trainer.state.fn = fn
trainer.state.stage = stage
trainer.training_type_plugin.pre_dispatch()
trainer.training_type_plugin.pre_dispatch(trainer)
assert list(trainer.training_type_plugin.poptorch_models) == [stage]


Expand Down

0 comments on commit 7192d12

Please sign in to comment.