
Copy weights and preserve device for 8da4w QAT linear #211

Merged
merged 7 commits into main on May 6, 2024

Conversation

andrewor14
Contributor

Summary: This fixes two correctness bugs. First, we never copied over the weights from the existing linear, so we would start from random weights even when loading from checkpoints. Second, we never preserved the device of the original linear. This is important for settings like FSDP, where we expect non-zero ranks to have their parameters on the meta device in order to initialize these parameters correctly.
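To make the FSDP point concrete, here is a small runnable illustration (not code from this PR) of why the replacement module must inherit the original device: a module recreated without an explicit device silently materializes real, randomly initialized storage on CPU, while passing the source weight's device keeps the meta placement intact.

import torch

# On FSDP non-zero ranks, parameters are expected to live on the meta device.
lin = torch.nn.Linear(4, 4, device="meta")
print(lin.weight.is_meta)  # True

# Recreating the module without a device loses that placement:
bad = torch.nn.Linear(4, 4)
print(bad.weight.is_meta)  # False: real, randomly initialized CPU storage

# Preserving the original device keeps the meta placement:
good = torch.nn.Linear(4, 4, device=lin.weight.device)
print(good.weight.is_meta)  # True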

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_load_state_dict_meta

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar

@facebook-github-bot added the CLA Signed label on May 2, 2024
@andrewor14 requested a review from jerryzh168 on May 2, 2024 at 21:45
@@ -47,6 +47,7 @@ def prepare(
*args: Any,
**kwargs: Any
) -> torch.nn.Module:
state_dict = model.state_dict()
Contributor

oh OK, right, this function has to do create_quantized_state_dict() and load_state_dict() in the gpt-fast API, but I feel we could also change _replace_linear_8da4w to instantiate from the existing floating point module; that might be clearer

Contributor Author

Yeah I agree, updated
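For readers following along, the suggested refactor roughly matches the from_float convention used by PyTorch's quantized modules; the sketch below is an assumption about the shape of the change, not the merged code (the class name is borrowed from the PR's 8da4w QAT context, and the real implementation does more than this):

import torch

class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
    @classmethod
    def from_float(cls, mod: torch.nn.Linear) -> "Int8DynActInt4WeightQATLinear":
        # Instantiating from the existing float module addresses both bugs
        # at once: the device comes from the source weight (meta stays
        # meta), and the weights are copied instead of left at random init.
        new_mod = cls(
            mod.in_features,
            mod.out_features,
            bias=mod.bias is not None,
            device=mod.weight.device,
        )
        if not mod.weight.is_meta:  # meta tensors have no data to copy
            with torch.no_grad():
                new_mod.weight.copy_(mod.weight)
                if mod.bias is not None:
                    new_mod.bias.copy_(mod.bias)
        return new_mod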

)
break
if should_load_state_dict:
model.load_state_dict(state_dict)
Contributor

will this work if we use tensor subclasses?

# on the meta device, in which case there is no need to
# load the state dict, and doing so will lead to an error
should_load_state_dict = True
for k, v in state_dict.items():
Contributor

this can probably be simplified a bit:

should_load_state_dict = all(not v.is_meta for v in state_dict.values())
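As a quick sanity check of the suggested one-liner, here is a self-contained demo with a plain dict standing in for the model's state_dict:

import torch

state_dict = {
    "real": torch.randn(2, 2),
    "meta": torch.empty(2, 2, device="meta"),
}
# Load only when no tensor is on the meta device:
should_load_state_dict = all(not v.is_meta for v in state_dict.values())
print(should_load_state_dict)  # False: the meta tensor means skip loading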

@andrewor14 force-pushed the 8da4w_qat_state_dict branch from bac2301 to cd50d7a on May 2, 2024 at 22:10
@andrewor14 requested a review from jerryzh168 on May 2, 2024 at 22:11
@facebook-github-bot
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Contributor

@jerryzh168 left a comment


Sounds good, we'll refine the API/implementation more when we move to tensor subclasses

@facebook-github-bot
@andrewor14 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@andrewor14 merged commit ce78e79 into main on May 6, 2024
15 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request on Jul 31, 2024
* Copy weights and preserve device for 8da4w QAT linear

Summary: This fixes two correctness bugs. First, we never copied
over the weights from the existing linear, so we would start from
random weights even when loading from checkpoints. Second, we
never preserved the device of the original linear. This is
important for settings like FSDP, where we expect non-zero ranks
to have their parameters on the meta device in order to
initialize these parameters correctly.

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_meta_weights

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar

* Update test_qat.py