This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[RFC] Deprecate model.predict in favour of Trainer.predict #1004

Closed
ethanwharris opened this issue Nov 29, 2021 · 2 comments · Fixed by #1030

Comments

@ethanwharris
Collaborator

ethanwharris commented Nov 29, 2021

model.predict has a lot of usability issues:

  • not clear how to load from different inputs
  • no way to pass additional data loading arguments
  • no docstring specific to the particular task
  • lots of background magic to resolve things such as transforms
  • not able to handle much load (see text classify prediction - memory overflow #983)
  • two ways to do prediction

We should deprecate it in favour of using a datamodule and Trainer. E.g. image classification prediction is currently:

predictions = model.predict(
    [
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
        "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
    ]
)

Would become:

predict_data_module = ImageClassificationData.from_files(predict_files=[
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
predictions = trainer.predict(model, datamodule=predict_data_module)
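
One way the deprecation itself could be staged is sketched below: the old entry point keeps working for a release but emits a DeprecationWarning and delegates to the trainer path. This is only an illustrative sketch in plain Python. Trainer, ListDataModule, Model, and predict_batches are hypothetical stand-ins defined here, not the Flash or Lightning classes, and the "inference" is a placeholder.

```python
import warnings


class Trainer:
    """Hypothetical stand-in for a Lightning-style trainer."""

    def predict(self, model, datamodule):
        # Run the model over each batch the datamodule yields;
        # the result is one list of predictions per batch.
        return [model.forward(batch) for batch in datamodule.predict_batches()]


class ListDataModule:
    """Hypothetical datamodule serving items in fixed-size batches."""

    def __init__(self, items, batch_size=2):
        self.items = list(items)
        self.batch_size = batch_size

    def predict_batches(self):
        for start in range(0, len(self.items), self.batch_size):
            yield self.items[start:start + self.batch_size]


class Model:
    """Hypothetical task model with a deprecated predict method."""

    def forward(self, batch):
        # Placeholder inference: report the length of each input.
        return [len(item) for item in batch]

    def predict(self, items, **data_kwargs):
        warnings.warn(
            "Model.predict is deprecated; build a datamodule and call "
            "Trainer.predict(model, datamodule=...) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        datamodule = ListDataModule(items, **data_kwargs)
        batches = Trainer().predict(self, datamodule=datamodule)
        # Flatten per-batch results to keep the old flat return shape.
        return [pred for batch in batches for pred in batch]
```

With a shim like this, existing callers of model.predict keep getting results while being pointed at the single supported path, and data loading arguments such as batch_size have one obvious place to go.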
@tchaton
Contributor

tchaton commented Nov 29, 2021

I believe model.predict is confusing, as users don't have fine-grained control over the data being processed.

I would prefer relying on trainer.predict, which is the default recommended way with PyTorch Lightning.

@Borda
Member

Borda commented Nov 29, 2021

I would tend to limit the input to a Dataset (maybe a DataLoader) only, and then it is up to the user to load the data within it...

Positive:

  • users can more easily manage what happens to the input data
  • we provide basic tooling, so it should be at most one more line of code
  • adding extra arguments is simple and clean

Negative:

  • one more line of code, and it loses some of the magic (although the magic is also the source of confusion)
  • we need to document this well, check dataset compatibility, and list all available datasets

Would become:

predict_data_module = ImageClassificationData.from_files(predict_files=[
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
predictions = trainer.predict(model, datamodule=predict_data_module)

that is also a good way...
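
To make the Dataset-only suggestion above concrete, here is a rough sketch of a map-style dataset (the `__len__` / `__getitem__` protocol that torch.utils.data.Dataset also follows) in which the user supplies the loading function and any extra arguments. FileListDataset, fake_loader, and the size argument are hypothetical names for illustration only; nothing here is an existing Flash class, and the loader does not actually read files.

```python
class FileListDataset:
    """Hypothetical map-style dataset: the caller decides how files are loaded."""

    def __init__(self, paths, loader, **loader_kwargs):
        self.paths = list(paths)
        self.loader = loader                # user-supplied loading function
        self.loader_kwargs = loader_kwargs  # extra arguments passed through as-is

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # All loading logic lives in the user's loader, so there is no hidden magic.
        return self.loader(self.paths[index], **self.loader_kwargs)


def fake_loader(path, size=32):
    # Placeholder for real image loading; returns what it was asked to load.
    return (path, size)


dataset = FileListDataset(["a.jpg", "b.jpg"], loader=fake_loader, size=224)
```

The cost is the one extra line noted in the list of negatives; the gain is that every data loading argument is explicit at the call site.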
