Skip to content

Commit

Permalink
make evaluate private (Lightning-AI#1260)
Browse files Browse the repository at this point in the history
* make evaluate private

* changelog
  • Loading branch information
Borda authored and akarnachev committed Apr 3, 2020
1 parent f9a2b75 commit 097f8fe
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))

### Deprecated

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def reset_test_dataloader(self, *args):
def reset_val_dataloader(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_mode: bool = False):
"""Run evaluation code.
Args:
Expand Down Expand Up @@ -365,7 +365,7 @@ def run_evaluation(self, test_mode: bool = False):
setattr(self, f'{"test" if test_mode else "val"}_progress_bar', pbar)

# run evaluation
eval_results = self.evaluate(self.model, dataloaders, max_batches, test_mode)
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)

Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,10 +894,10 @@ def run_pretrain_routine(self, model: LightningModule):
# dummy validation progress bar
self.val_progress_bar = tqdm(disable=True)

eval_results = self.evaluate(model,
self.val_dataloaders,
self.num_sanity_val_steps,
False)
eval_results = self._evaluate(model,
self.val_dataloaders,
self.num_sanity_val_steps,
False)
_, _, _, callback_metrics, _ = self.process_output(eval_results)

# close progress bars
Expand Down
4 changes: 2 additions & 2 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks():

trainer = Trainer(logger=False)
# TODO: why `dataloder` is required if it is not used
result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': 0.6}

model = ModelVer0_7(hparams)
Expand All @@ -106,5 +106,5 @@ def test_tbd_remove_in_v1_0_0_model_hooks():

trainer = Trainer(logger=False)
# TODO: why `dataloder` is required if it is not used
result = trainer.evaluate(model, dataloaders=[[None]], max_batches=1)
result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1)
assert result == {'val_loss': 0.7}

0 comments on commit 097f8fe

Please sign in to comment.