From e3bbb6dc2965365c72f4b2ada61cb36469f60c22 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 17:01:37 +0200 Subject: [PATCH] Fix test --- tests/callbacks/test_prediction_writer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 286803f7fe44f..45ed0edae7dc1 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -49,21 +49,18 @@ def write_on_epoch_end(self, *args, **kwargs): cb = CustomPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) assert cb.write_on_batch_end_called assert cb.write_on_epoch_end_called - assert results == 1 cb = CustomPredictionWriter("batch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) assert cb.write_on_batch_end_called assert not cb.write_on_epoch_end_called - assert results == 1 cb = CustomPredictionWriter("epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) + trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) assert not cb.write_on_batch_end_called assert cb.write_on_epoch_end_called - assert results == 1