Add context manager to properly convert the precision #10079
Conversation
# this hook is not wrapped. # TODO: should it be?
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64
Not sure whether this is working as expected or a bug. The precision context manager is only active during the forward context, and this hook is not part of it.
Should we instead enter the context manager on setup and exit on teardown?
I am not sure, but I would say yes. real + imag in float32 -> complex64, and real + imag in float64 -> complex128. Makes sense to me at least.
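For reference, a quick standalone illustration of this default-dtype behaviour in plain PyTorch (not Lightning-specific):

```python
import torch

# with the default floating point dtype at float32, a mixed real/imaginary literal is complex64
torch.set_default_dtype(torch.float32)
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64

# with the default floating point dtype at float64, the same literal becomes complex128
torch.set_default_dtype(torch.float64)
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128

# restore the usual default
torch.set_default_dtype(torch.float32)
```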
Yes, that's as expected.
The problem here is that we only wrap the precision for the forward hooks. So other hooks like `setup` and `on_fit_start` are not wrapped and, as tested here, they do not use the correct precision.
Maybe we could change this to wrap everything from `setup` to `teardown`.
Note after discussion with Thomas: it's likely we would need to disable it for `backward` and `optimizer.step`.
This will also need to be considered for Lite.
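For illustration, a rough sketch of the setup-to-teardown idea being discussed; the class name and hook signatures here are hypothetical, not the actual Lightning plugin API:

```python
from typing import Optional

import torch


class WholeRunDtypePrecision:
    """Hypothetical helper: switch the default dtype at setup and restore it at teardown."""

    def __init__(self, dtype: torch.dtype = torch.float64) -> None:
        self.dtype = dtype
        self._previous: Optional[torch.dtype] = None

    def setup(self) -> None:
        # everything running between setup and teardown (e.g. on_fit_start)
        # would then see the requested default dtype
        self._previous = torch.get_default_dtype()
        torch.set_default_dtype(self.dtype)

    def teardown(self) -> None:
        # restore the previous default so code outside the run is unaffected
        if self._previous is not None:
            torch.set_default_dtype(self._previous)
            self._previous = None
```

As noted above, `backward` and `optimizer.step` would likely need the changed default temporarily disabled while this is active.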
        return self.autodtype()


class DoublePrecisionPlugin(DtypePrecisionPlugin):
Could this be a dataclass?
Maybe, but I don't think we want to. It's still a `PrecisionPlugin` (not a dataclass).
# this hook is not wrapped. # TODO: should it be?
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64
This is no longer necessary after the addition of …
What does this PR do?
The fix in #8208 isn't working as expected. The check for `complex64` should be `assert self.complex_buffer.dtype == torch.complex128`. This is because (from https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html) when the default floating point type is float64, the default complex type is complex128.

However, I didn't find a way to fix it, because moving the model with `model.double()` or `model.to(torch.double)` does not convert it to `complex128` after it has been created. The only option I see is to instantiate the model after calling `torch.set_default_dtype(torch.double)`, but that's outside of Lightning's scope.

To make this easier, this PR adds the following context manager:
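A minimal sketch of what such a context manager could look like, based on the `autodtype` name visible in the diff above (illustrative only, not necessarily the exact code added by this PR):

```python
from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def autodtype(dtype: torch.dtype = torch.float64) -> Generator[None, None, None]:
    """Temporarily switch the default floating point dtype, restoring the previous one on exit."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # restore the original default so code outside the context is unaffected
        torch.set_default_dtype(previous)
```

Instantiating the model inside such a context would then pick up `float64`/`complex128` defaults directly, instead of moving the model afterwards.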
This is similar to how #9920 works.
Some other changes included:
set_default_dtype
instead ofset_default_tensor_type
.set_default_dtpye
only supports floating types currently.Does your PR introduce any breaking changes? If yes, please list them.
The model is no longer moved.