Copy weights and preserve device for 8da4w QAT linear #211
Conversation
@@ -47,6 +47,7 @@ def prepare(
         *args: Any,
         **kwargs: Any
     ) -> torch.nn.Module:
+        state_dict = model.state_dict()
oh OK, right, this function has to do create_quantized_state_dict() and load_state_dict() in the gpt-fast API, but I feel we could also change _replace_linear_8da4w
to instantiate from the existing floating point module; that might be clearer
Yeah I agree, updated
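For reference, here is a minimal sketch of what instantiating the QAT linear from the existing floating point module could look like; `_replace_linear_with_qat` and `qat_linear_cls` are illustrative placeholders, not the actual torchao implementation:

```python
# Illustrative sketch only: swap each nn.Linear for a QAT linear built from the
# existing float module, copying its weight and preserving its device (which may
# be "meta" under FSDP). qat_linear_cls stands in for the 8da4w QAT linear class.
import torch.nn as nn

def _replace_linear_with_qat(module: nn.Module, qat_linear_cls) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_linear = qat_linear_cls(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,  # preserve the original device
            )
            # Only copy weights that are actually materialized; meta tensors
            # carry no data and will be initialized later (e.g. by FSDP).
            if not child.weight.is_meta:
                new_linear.weight = child.weight
            setattr(module, name, new_linear)
        else:
            _replace_linear_with_qat(child, qat_linear_cls)
```

The point of this shape is that the replacement module is constructed from the existing child, so both its weights and its device carry over without a separate state dict round trip.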
            )
            break
    if should_load_state_dict:
        model.load_state_dict(state_dict)
will this work if we use tensor subclasses?
    # on the meta device, in which case there is no need to
    # load the state dict, and doing so will lead to an error
    should_load_state_dict = True
    for k, v in state_dict.items():
this can probably be simplified a bit:
should_load_state_dict = all(not v.is_meta for v in state_dict.values())
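As a standalone illustration (not project code) of how that check behaves: tensors created on the meta device report `is_meta=True`, so the expression is `False` as soon as any value in the state dict lives on meta.

```python
import torch

# A state dict mixing a real tensor with a meta tensor (shape/dtype only, no data).
state_dict = {
    "fc1.weight": torch.randn(4, 4),
    "fc2.weight": torch.empty(4, 4, device="meta"),
}

# Load only when no tensor is on the meta device.
should_load_state_dict = all(not v.is_meta for v in state_dict.values())
print(should_load_state_dict)  # False, because fc2.weight is a meta tensor
```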
Summary: This fixes two correctness bugs. First, we never copied over the weights from the existing linear, so we would start from random weights even when loading from checkpoints. Second, we never preserved the device of the original linear. This is important for settings like FSDP, where we expect non-zero ranks to have their parameters on the meta device in order to initialize these parameters correctly.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_meta_weights

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
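A minimal standalone sketch of the FSDP scenario mentioned above, assuming a recent PyTorch where `torch.device` can be used as a context manager; this is illustrative and not code from this PR.

```python
import torch
import torch.nn as nn

# Non-zero FSDP ranks typically build the model under the meta device, so
# parameters have shape and dtype but no storage.
with torch.device("meta"):
    linear = nn.Linear(8, 8)

print(linear.weight.is_meta)  # True: no data behind this parameter
# Copying real weights into such parameters raises an error, which is why
# prepare() skips load_state_dict when any tensor in the state dict is on meta.
```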
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Sounds good, we'll refine the API/impl more when we move to tensor subclasses
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
* Copy weights and preserve device for 8da4w QAT linear

Summary: This fixes two correctness bugs. First, we never copied over the weights from the existing linear, so we would start from random weights even when loading from checkpoints. Second, we never preserved the device of the original linear. This is important for settings like FSDP, where we expect non-zero ranks to have their parameters on the meta device in order to initialize these parameters correctly.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_meta_weights

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar

* Update test_qat.py
Summary: This fixes two correctness bugs. First, we never copied over the weights from the existing linear, so we would start from random weights even when loading from checkpoints. Second, we never preserved the device of the original linear. This is important for settings like FSDP, where we expect non-zero ranks to have their parameters on the meta device in order to initialize these parameters correctly.
Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_load_state_dict_meta
Reviewers: jerryzh168, cpuhrsch
Subscribers: jerryzh168, cpuhrsch, supriyar