diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 286803f7fe44f..45ed0edae7dc1 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -49,21 +49,18 @@ def write_on_epoch_end(self, *args, **kwargs): cb = CustomPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) assert cb.write_on_batch_end_called assert cb.write_on_epoch_end_called - assert results == 1 cb = CustomPredictionWriter("batch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) assert cb.write_on_batch_end_called assert not cb.write_on_epoch_end_called - assert results == 1 cb = CustomPredictionWriter("epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) assert not cb.write_on_batch_end_called assert cb.write_on_epoch_end_called - assert results == 1