Description
Bug description
There has been a lot of discussion around logging, trainer.predict, evaluation hooks, and callbacks. I think I can boil this down to a reproducible example that will be useful for the community.
What has been discussed so far.
#10365
#16258 (where I started the example below)
#16822
#7333
From these links, there is no clear guidance on why logging works after calling self.predict_step() directly but fails after calling trainer.predict. This is flirting with being a bug, but appears to be intended behavior from the comment below.
We are not inside a predict hook; we are inside an evaluation hook (on_validation_epoch_end). We did use trainer.predict, with all of its great functionality, to generate a set of predictions.
Expected behavior
I understand from the above issues, as stated by @carmocca (#7333 (comment)), that we cannot overwrite the trainer state. But why doesn't logging work when the predictions come from a separate, newly created trainer?
What version are you seeing the problem on?
v2.1
How to reproduce the bug
import os

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class MyModel(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return
        print("Start predicting!")
        for i, batch in enumerate(self.predict_dataloader()):
            batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
            out = self.predict_step(batch, i)
            print(i, out)
        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run_predict_step():
    model = MyModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator="auto",
        limit_train_batches=1,
        limit_val_batches=1,
        fast_dev_run=True,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model)


class MyModel2(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return
        print("Start predicting!")
        dataloader = self.predict_dataloader()
        new_trainer = Trainer(
            default_root_dir=os.getcwd(),
            accelerator="auto",
            limit_train_batches=1,
            limit_val_batches=1,
            fast_dev_run=True,
            max_epochs=1,
            enable_model_summary=False,
        )
        new_trainer.predict(self, dataloaders=dataloader)
        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run_trainer_predict():
    model = MyModel2()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator="auto",
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        fast_dev_run=True,
        enable_model_summary=False,
    )
    trainer.fit(model)


if __name__ == "__main__":
    # This works
    run_predict_step()
    # This does not work
    run_trainer_predict()

Error messages and logs
You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging
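One possible reading of this error, offered as an assumption rather than anything the maintainers have confirmed: calling predict on a second trainer may re-attach the module to that trainer, so the later self.log call is checked against a trainer that is no longer inside any loop. The sketch below models only that idea in plain Python. ToyTrainer, ToyModule, active_loop, restore_trainer, and run are all hypothetical names invented for illustration; none of them are Lightning APIs, and the real internals may differ.

```python
class ToyTrainer:
    """Hypothetical stand-in for a trainer; not Lightning's implementation."""

    def __init__(self, name):
        self.name = name
        self.active_loop = None  # only set while a loop is running

    def fit(self, module):
        module.trainer = self  # attaching rebinds the module's trainer
        self.active_loop = "fit"
        try:
            module.on_validation_epoch_end()
        finally:
            self.active_loop = None

    def predict(self, module):
        module.trainer = self  # the module now points at this trainer
        self.active_loop = "predict"
        # ... predictions would be produced here ...
        self.active_loop = None  # the predict loop has finished


class ToyModule:
    def __init__(self, nested_predict, restore_trainer=False):
        self.trainer = None
        self.nested_predict = nested_predict
        self.restore_trainer = restore_trainer
        self.logged = {}

    def log(self, name, value):
        # Assumption: logging needs the attached trainer to be inside fit.
        if self.trainer.active_loop != "fit":
            raise RuntimeError("result collection is not registered yet")
        self.logged[name] = value

    def on_validation_epoch_end(self):
        if self.nested_predict:
            outer = self.trainer  # remember the trainer that is fitting
            ToyTrainer("inner").predict(self)
            if self.restore_trainer:
                self.trainer = outer  # point back at the fitting trainer
        self.log("metric", 1.0)


def run(nested_predict, restore_trainer=False):
    """Fit a toy module and return what it logged, or the error text."""
    module = ToyModule(nested_predict, restore_trainer)
    try:
        ToyTrainer("outer").fit(module)
    except RuntimeError as err:
        return f"error: {err}"
    return module.logged
```

In this toy model, run(False) logs normally, run(True) fails in log because the module is left attached to the idle inner trainer, and run(True, restore_trainer=True) logs again once the outer reference is put back. Whether reassigning the trainer like this is safe in Lightning itself has not been verified here.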
Environment
(DeepForest) benweinstein@Bens-MacBook-Pro Downloads % python collect_env_details.py
<details>
<summary>Current environment</summary>
* CUDA:
- GPU: None
- available: False
- version: None
* Lightning:
- lightning-lite: 1.8.0.post1
- lightning-utilities: 0.8.0
- pytorch-lightning: 2.1.2
- torch: 1.12.1
- torchmetrics: 1.2.0
- torchvision: 0.13.1
* Packages:
- absl-py: 0.13.0
- accessible-pygments: 0.0.4
- affine: 2.3.0
- aiohttp: 3.7.4.post0
- alabaster: 0.7.12
- albumentations: 1.1.0
- async-timeout: 3.0.1
- attrs: 21.2.0
- babel: 2.9.1
- beautifulsoup4: 4.12.2
- bleach: 4.0.0
- brotlipy: 0.7.0
- bumpversion: 0.5.3
- cached-property: 1.5.2
- cachetools: 4.2.2
- certifi: 2021.5.30
- cffi: 1.14.6
- chardet: 4.0.0
- click: 7.1.2
- click-plugins: 1.1.1
- cligj: 0.7.2
- cmarkgfm: 0.4.2
- colorama: 0.4.4
- commonmark: 0.9.1
- cryptography: 3.4.7
- cycler: 0.10.0
- docutils: 0.18.1
- execnet: 2.0.2
- fiona: 1.8.20
- fire: 0.4.0
- fonttools: 4.25.0
- fsspec: 2021.7.0
- furo: 2023.9.10
- future: 0.18.2
- gdal: 3.3.1
- geopandas: 0.9.0
- google-auth: 1.34.0
- google-auth-oauthlib: 0.4.5
- gprof2dot: 2022.7.29
- grpcio: 1.39.0
- h5py: 3.3.0
- idna: 2.10
- imagecodecs: 2021.7.30
- imageio: 2.9.0
- imagesize: 1.4.1
- importlib-metadata: 6.8.0
- iniconfig: 1.1.1
- jinja2: 3.0.1
- joblib: 1.0.1
- keyring: 23.0.1
- kiwisolver: 1.3.1
- lightning-lite: 1.8.0.post1
- lightning-utilities: 0.8.0
- mapclassify: 2.4.3
- markdown: 3.3.4
- markupsafe: 2.0.1
- matplotlib: 3.4.2
- more-itertools: 8.8.0
- multidict: 5.1.0
- munch: 2.5.0
- munkres: 1.1.4
- networkx: 2.6.2
- numpy: 1.21.1
- numpydoc: 1.1.0
- oauthlib: 3.1.1
- olefile: 0.46
- opencv-python: 4.6.0.66
- packaging: 21.0
- pandas: 1.3.1
- pillow: 9.2.0
- pip: 21.2.2
- pkginfo: 1.7.1
- platformdirs: 3.11.0
- pluggy: 0.13.1
- progressbar2: 4.2.0
- protobuf: 3.17.3
- psutil: 5.8.0
- py: 1.10.0
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pycocotools: 2.0.7
- pycparser: 2.20
- pydata-sphinx-theme: 0.14.1
- pydeprecate: 0.3.1
- pygments: 2.16.1
- pyopenssl: 20.0.1
- pyparsing: 2.4.7
- pyproj: 3.1.0
- pysocks: 1.7.1
- pytest: 6.2.4
- pytest-profiling: 1.7.0
- pytest-xdist: 3.3.1
- python-dateutil: 2.8.2
- python-utils: 3.4.5
- pytorch-lightning: 2.1.2
- pytz: 2021.1
- pywavelets: 1.1.1
- pyyaml: 5.4.1
- qudida: 0.0.4
- rasterio: 1.2.6
- readme-renderer: 24.0
- recommonmark: 0.7.1
- requests: 2.25.1
- requests-oauthlib: 1.3.0
- requests-toolbelt: 0.9.1
- rfc3986: 1.4.0
- rsa: 4.7.2
- rtree: 0.9.7
- scikit-image: 0.18.2
- scikit-learn: 0.24.2
- scipy: 1.7.0
- setuptools: 59.5.0
- shapely: 1.7.1
- six: 1.16.0
- slidingwindow: 0.0.14
- snakeviz: 2.1.1
- snowballstemmer: 2.1.0
- snuggs: 1.4.7
- soupsieve: 2.5
- sphinx: 7.2.6
- sphinx-basic-ng: 1.0.0b2
- sphinx-markdown-tables: 0.0.15
- sphinx-rtd-theme: 1.3.0
- sphinxcontrib-applehelp: 1.0.2
- sphinxcontrib-devhelp: 1.0.2
- sphinxcontrib-htmlhelp: 2.0.0
- sphinxcontrib-jquery: 4.1
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-qthelp: 1.0.3
- sphinxcontrib-serializinghtml: 1.1.9
- tensorboard: 2.10.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.0
- termcolor: 2.1.0
- threadpoolctl: 2.2.0
- tifffile: 2021.7.30
- toml: 0.10.2
- tomli: 2.0.1
- torch: 1.12.1
- torchmetrics: 1.2.0
- torchvision: 0.13.1
- tornado: 6.1
- tqdm: 4.62.0
- twine: 0.0.0
- typing-extensions: 4.3.0
- urllib3: 1.26.6
- webencodings: 0.5.1
- werkzeug: 2.0.1
- wheel: 0.36.2
- xmltodict: 0.12.0
- yapf: 0.40.2
- yarl: 1.6.3
- zipp: 3.5.0
* System:
- OS: Darwin
- architecture:
- 64bit
-
- processor: i386
- python: 3.9.6
- release: 23.1.0
- version: Darwin Kernel Version 23.1.0: Mon Oct 9 21:27:27 PDT 2023; root:xnu-10002.41.9~6/RELEASE_X86_64
</details>
More info
No response