
Fix: avoid automatic device detection via serialized tensors when deserializing PyTorchPredictor. #2576

Merged
merged 1 commit into awslabs:dev from fix_predictor on Jan 20, 2023

Conversation

shubhamkapoor
Contributor

@shubhamkapoor shubhamkapoor commented Jan 20, 2023

Issue #, if available:

Description of changes:
When map_location=None is passed to torch.load, tensors are automatically loaded onto the device they were serialized on. This causes runtime errors when the model is trained and serialized on a GPU but inference is done on CPU, i.e. when the predictor is constructed via the deserialize method of PyTorchPredictor with device=None.
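
For illustration, a minimal sketch of the failure mode and the kind of fix described above, using a plain torch.load call (the actual code path inside PyTorchPredictor.deserialize may differ, and the checkpoint filename is hypothetical):

```python
import torch

# With map_location=None, tensors are restored to the device recorded in the
# checkpoint; on a CPU-only host this fails for a model serialized on GPU.
# state = torch.load("prediction_net.pt")  # may raise without CUDA available

# Resolving the device explicitly and passing it as map_location remaps every
# tensor, so the same checkpoint loads regardless of where it was written.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state = torch.load("prediction_net.pt", map_location=device)
```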

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this PR with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@shubhamkapoor shubhamkapoor added the bug fix (one of pr required labels) label Jan 20, 2023
@lostella lostella added models This item concerns models implementations torch This concerns the PyTorch side of GluonTS labels Jan 20, 2023
@lostella lostella changed the title update Device to avoid automatic detection via serialized tensors. Fix: avoid automatic device detection via serialized tensors when deserializing PyTorchPredictor. Jan 20, 2023
Contributor

@lostella lostella left a comment


Thanks! We can then consolidate this with the same logic happening in other places; maybe having a default_device function to do this would be good.
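
One possible shape for such a helper (a sketch only; the name default_device comes from the comment above, and the body is an assumption rather than existing GluonTS code):

```python
import torch

def default_device() -> torch.device:
    # Sketch: prefer a visible GPU, otherwise fall back to CPU, so callers
    # never depend on the device the serialized tensors happen to carry.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```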

@lostella lostella added the pending v0.11.x backport This contains a fix to be backported to the v0.11.x branch label Jan 20, 2023
@shubhamkapoor shubhamkapoor merged commit d5bfd79 into awslabs:dev Jan 20, 2023
@shubhamkapoor shubhamkapoor deleted the fix_predictor branch January 20, 2023 16:02
lostella pushed a commit to lostella/gluonts that referenced this pull request Jan 30, 2023
@lostella lostella mentioned this pull request Jan 30, 2023
lostella added a commit that referenced this pull request Jan 30, 2023
* Fix: avoid automatic device detection via serialized tensors when deserializing. (#2576)

* Make itertools Map/Filter dataclasses. (#2579)

* serde: Fix encoding of dtypes. (#2586)

* Add assertion to split function ensuring valid windows (#2587)

* Ensure dtype on feat_time in torch DeepAR. (#2596)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* fix requirements further

* fix style

* remove undesired change

---------

Co-authored-by: Shubham Kapoor <[email protected]>
Co-authored-by: Jasper <[email protected]>
Co-authored-by: MarcelK1102 <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
@lostella lostella removed the pending v0.11.x backport This contains a fix to be backported to the v0.11.x branch label Feb 7, 2023