diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e78d1c66c..db0dc3f7a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version ([#493](https://github.com/PyTorchLightning/lightning-flash/pull/493)) + + ## [0.4.0] - 2021-06-22 ### Added diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 70e06af707..44faef2810 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -87,8 +87,15 @@ def __init__(self, *args, serve_sanity_check: bool = False, **kwargs): self.serve_sanity_check = serve_sanity_check + def _run_sanity_check(self, ref_model): + if hasattr(super(), "_run_sanity_check"): + super()._run_sanity_check(ref_model) + + self.run_sanity_check(ref_model) + def run_sanity_check(self, ref_model): - super().run_sanity_check(ref_model) + if hasattr(super(), "run_sanity_check"): + super().run_sanity_check(ref_model) if self.serve_sanity_check and ref_model.is_servable and _SERVE_AVAILABLE: ref_model.run_serve_sanity_check()