
Cannot call self.log in evaluation_hooks after using trainer.predict, even if using a new trainer object.  #19101

@bw4sz

Description


Bug description

There has been a lot of discussion around logging, trainer.predict, evaluation hooks, and callbacks. I think I can boil this down to a reproducible example that will be useful for the community.

What has been discussed so far:

#10365
#16258 (where I started the example below)
#16822
#7333

From these links, there is no clear guidance on why logging works when predictions are generated by calling the LightningModule's predict_step() directly, but fails when they are generated with trainer.predict(). This is flirting with being a bug, but it appears to be intended behavior, judging from the comment referenced below.

We are not inside a predict hook; we are inside an evaluation hook (on_validation_epoch_end). We only used trainer.predict, with all of its great functionality, to generate a set of predictions.

Expected behavior

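I understand from the issues above, as stated by @carmocca (#7333 (comment)), that we cannot overwrite the trainer state. But why doesn't this work with a new trainer? I would expect self.log("metric", 1.0) to succeed in on_validation_epoch_end even after generating predictions with a separate Trainer object.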

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

class MyModel(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return
        print("Start predicting!")
        # Generate predictions by calling predict_step directly on each batch,
        # without going through a Trainer's predict loop.
        for i, batch in enumerate(self.predict_dataloader()):
            batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
            out = self.predict_step(batch, i)
            print(i, out)

        # Logging here succeeds.
        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run_predict_step():
    model = MyModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator="auto",
        limit_train_batches=1,
        limit_val_batches=1,
        fast_dev_run=True,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model)


class MyModel2(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return
        print("Start predicting!")
        dataloader = self.predict_dataloader()

        # Generate predictions with a separate, freshly created Trainer.
        new_trainer = Trainer(
            default_root_dir=os.getcwd(),
            accelerator="auto",
            limit_train_batches=1,
            limit_val_batches=1,
            fast_dev_run=True,
            max_epochs=1,
            enable_model_summary=False,
        )

        new_trainer.predict(self, dataloaders=dataloader)

        # Logging here raises the error shown under "Error messages and logs".
        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run_trainer_predict():
    model = MyModel2()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator="auto",
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        fast_dev_run=True,
        enable_model_summary=False,
    )
    trainer.fit(model)

if __name__ == "__main__":
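    # This works: self.log succeeds after calling predict_step manually.
    run_predict_step()

    # This does not work: self.log raises after new_trainer.predict().
    run_trainer_predict()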

Error messages and logs

You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging
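One hypothesis that would explain why a new Trainer does not help (an assumption about Lightning internals, not confirmed in this issue) is that `new_trainer.predict(self, ...)` re-attaches the module to `new_trainer`, so the subsequent `self.log` consults the finished predict run instead of the original validation loop. The pure-Python toy below models only that hypothesis. `ToyTrainer`, `ToyLoop`, `ToyModule` and `run` are hypothetical names, not Lightning APIs; only the error text is copied from the message above.

```python
from typing import List, Optional


class ToyLoop:
    """Stand-in for a trainer loop. Only loops that support logging own a results dict."""

    def __init__(self, supports_logging: bool) -> None:
        self.results: Optional[dict] = {} if supports_logging else None


class ToyTrainer:
    """Hypothetical trainer used only to illustrate the hypothesis; NOT Lightning code."""

    def __init__(self) -> None:
        self.active_loop: Optional[ToyLoop] = None

    def _attach(self, module: "ToyModule") -> None:
        # ASSUMPTION under test: every run re-points module.trainer at this trainer.
        module.trainer = self

    def validate(self, module: "ToyModule") -> None:
        self._attach(module)
        self.active_loop = ToyLoop(supports_logging=True)
        module.on_validation_epoch_end()
        self.active_loop = None

    def predict(self, module: "ToyModule") -> List[int]:
        self._attach(module)
        self.active_loop = ToyLoop(supports_logging=False)
        preds = [module.predict_step(i) for i in range(2)]
        self.active_loop = None  # the predict run is finished; no loop is active
        return preds


class ToyModule:
    def __init__(self, use_second_trainer: bool) -> None:
        self.trainer: Optional[ToyTrainer] = None
        self.use_second_trainer = use_second_trainer

    def predict_step(self, i: int) -> int:
        return 2 * i

    def log(self, name: str, value: float) -> None:
        # Look up the result collection of whichever trainer the module points at.
        loop = self.trainer.active_loop if self.trainer is not None else None
        if loop is None or loop.results is None:
            raise RuntimeError(
                "You are trying to `self.log()` but the loop's result collection "
                "is not registered yet."
            )
        loop.results[name] = value

    def on_validation_epoch_end(self) -> None:
        if self.use_second_trainer:
            ToyTrainer().predict(self)  # analogous to new_trainer.predict(self, ...)
        else:
            for i in range(2):  # analogous to MyModel's manual predict_step loop
                self.predict_step(i)
        self.log("metric", 1.0)


def run(use_second_trainer: bool) -> str:
    """Return "ok" if logging succeeded, otherwise the error message."""
    try:
        ToyTrainer().validate(ToyModule(use_second_trainer))
        return "ok"
    except RuntimeError as err:
        return str(err)


print(run(use_second_trainer=False))  # ok
print(run(use_second_trainer=True))   # You are trying to `self.log()` ...
```

To check whether this hypothesis holds for the real repro, one option is to print `id(self.trainer)` in `MyModel2.on_validation_epoch_end` before and after `new_trainer.predict(...)` and compare it with `id(new_trainer)`.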

Environment

(DeepForest) benweinstein@Bens-MacBook-Pro Downloads % python collect_env_details.py          
<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning-lite:    1.8.0.post1
        - lightning-utilities: 0.8.0
        - pytorch-lightning: 2.1.2
        - torch:             1.12.1
        - torchmetrics:      1.2.0
        - torchvision:       0.13.1
* Packages:
        - absl-py:           0.13.0
        - accessible-pygments: 0.0.4
        - affine:            2.3.0
        - aiohttp:           3.7.4.post0
        - alabaster:         0.7.12
        - albumentations:    1.1.0
        - async-timeout:     3.0.1
        - attrs:             21.2.0
        - babel:             2.9.1
        - beautifulsoup4:    4.12.2
        - bleach:            4.0.0
        - brotlipy:          0.7.0
        - bumpversion:       0.5.3
        - cached-property:   1.5.2
        - cachetools:        4.2.2
        - certifi:           2021.5.30
        - cffi:              1.14.6
        - chardet:           4.0.0
        - click:             7.1.2
        - click-plugins:     1.1.1
        - cligj:             0.7.2
        - cmarkgfm:          0.4.2
        - colorama:          0.4.4
        - commonmark:        0.9.1
        - cryptography:      3.4.7
        - cycler:            0.10.0
        - docutils:          0.18.1
        - execnet:           2.0.2
        - fiona:             1.8.20
        - fire:              0.4.0
        - fonttools:         4.25.0
        - fsspec:            2021.7.0
        - furo:              2023.9.10
        - future:            0.18.2
        - gdal:              3.3.1
        - geopandas:         0.9.0
        - google-auth:       1.34.0
        - google-auth-oauthlib: 0.4.5
        - gprof2dot:         2022.7.29
        - grpcio:            1.39.0
        - h5py:              3.3.0
        - idna:              2.10
        - imagecodecs:       2021.7.30
        - imageio:           2.9.0
        - imagesize:         1.4.1
        - importlib-metadata: 6.8.0
        - iniconfig:         1.1.1
        - jinja2:            3.0.1
        - joblib:            1.0.1
        - keyring:           23.0.1
        - kiwisolver:        1.3.1
        - lightning-lite:    1.8.0.post1
        - lightning-utilities: 0.8.0
        - mapclassify:       2.4.3
        - markdown:          3.3.4
        - markupsafe:        2.0.1
        - matplotlib:        3.4.2
        - more-itertools:    8.8.0
        - multidict:         5.1.0
        - munch:             2.5.0
        - munkres:           1.1.4
        - networkx:          2.6.2
        - numpy:             1.21.1
        - numpydoc:          1.1.0
        - oauthlib:          3.1.1
        - olefile:           0.46
        - opencv-python:     4.6.0.66
        - packaging:         21.0
        - pandas:            1.3.1
        - pillow:            9.2.0
        - pip:               21.2.2
        - pkginfo:           1.7.1
        - platformdirs:      3.11.0
        - pluggy:            0.13.1
        - progressbar2:      4.2.0
        - protobuf:          3.17.3
        - psutil:            5.8.0
        - py:                1.10.0
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycocotools:       2.0.7
        - pycparser:         2.20
        - pydata-sphinx-theme: 0.14.1
        - pydeprecate:       0.3.1
        - pygments:          2.16.1
        - pyopenssl:         20.0.1
        - pyparsing:         2.4.7
        - pyproj:            3.1.0
        - pysocks:           1.7.1
        - pytest:            6.2.4
        - pytest-profiling:  1.7.0
        - pytest-xdist:      3.3.1
        - python-dateutil:   2.8.2
        - python-utils:      3.4.5
        - pytorch-lightning: 2.1.2
        - pytz:              2021.1
        - pywavelets:        1.1.1
        - pyyaml:            5.4.1
        - qudida:            0.0.4
        - rasterio:          1.2.6
        - readme-renderer:   24.0
        - recommonmark:      0.7.1
        - requests:          2.25.1
        - requests-oauthlib: 1.3.0
        - requests-toolbelt: 0.9.1
        - rfc3986:           1.4.0
        - rsa:               4.7.2
        - rtree:             0.9.7
        - scikit-image:      0.18.2
        - scikit-learn:      0.24.2
        - scipy:             1.7.0
        - setuptools:        59.5.0
        - shapely:           1.7.1
        - six:               1.16.0
        - slidingwindow:     0.0.14
        - snakeviz:          2.1.1
        - snowballstemmer:   2.1.0
        - snuggs:            1.4.7
        - soupsieve:         2.5
        - sphinx:            7.2.6
        - sphinx-basic-ng:   1.0.0b2
        - sphinx-markdown-tables: 0.0.15
        - sphinx-rtd-theme:  1.3.0
        - sphinxcontrib-applehelp: 1.0.2
        - sphinxcontrib-devhelp: 1.0.2
        - sphinxcontrib-htmlhelp: 2.0.0
        - sphinxcontrib-jquery: 4.1
        - sphinxcontrib-jsmath: 1.0.1
        - sphinxcontrib-qthelp: 1.0.3
        - sphinxcontrib-serializinghtml: 1.1.9
        - tensorboard:       2.10.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.0
        - termcolor:         2.1.0
        - threadpoolctl:     2.2.0
        - tifffile:          2021.7.30
        - toml:              0.10.2
        - tomli:             2.0.1
        - torch:             1.12.1
        - torchmetrics:      1.2.0
        - torchvision:       0.13.1
        - tornado:           6.1
        - tqdm:              4.62.0
        - twine:             0.0.0
        - typing-extensions: 4.3.0
        - urllib3:           1.26.6
        - webencodings:      0.5.1
        - werkzeug:          2.0.1
        - wheel:             0.36.2
        - xmltodict:         0.12.0
        - yapf:              0.40.2
        - yarl:              1.6.3
        - zipp:              3.5.0
* System:
        - OS:                Darwin
        - architecture:
                - 64bit
                - 
        - processor:         i386
        - python:            3.9.6
        - release:           23.1.0
        - version:           Darwin Kernel Version 23.1.0: Mon Oct  9 21:27:27 PDT 2023; root:xnu-10002.41.9~6/RELEASE_X86_64

</details>

More info

No response
