Fix TorchAO quant in VLM #13508
Changes from all commits
```diff
@@ -36,6 +36,17 @@ def proj_filter(
     return "proj" in fqn


+# TODO: implement a more general filter function
+def proj_filter_conv3d(
+    module: torch.nn.Module,
+    fqn: str,
+):
+    if isinstance(module, torch.nn.Conv3d):
+        logger.warning(f"Quantize: skipping {fqn} because it's a Conv3d")
+        return False
+    return "proj" in fqn


 def apply_torchao_config_to_model(
     model: torch.nn.Module,
     torchao_config: str,
@@ -63,7 +74,7 @@ def apply_torchao_config_to_model(
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
-        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
+        quantize_(model, int8_weight_only(), filter_fn=proj_filter_conv3d)
```
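To make the filter's behavior concrete, here is a minimal, self-contained sketch of how `proj_filter_conv3d` selects modules. The module names (`patch_proj`, `qkv_proj`) are illustrative stand-ins, not names from the actual vLLM model code:

```python
# Minimal sketch of the new filter's behavior; the modules and fqn strings
# below are hypothetical examples, not the real vLLM model layout.
import logging

import torch

logger = logging.getLogger(__name__)


def proj_filter_conv3d(module: torch.nn.Module, fqn: str) -> bool:
    # Skip Conv3d modules even when their fully-qualified name
    # contains "proj"; only non-Conv3d "proj" modules are quantized.
    if isinstance(module, torch.nn.Conv3d):
        logger.warning(f"Quantize: skipping {fqn} because it's a Conv3d")
        return False
    return "proj" in fqn


# A Conv3d whose name matches "proj" is skipped, while a Linear
# projection is still selected for quantization.
patch_proj = torch.nn.Conv3d(3, 8, kernel_size=2)
qkv_proj = torch.nn.Linear(16, 16)
print(proj_filter_conv3d(patch_proj, "vision.patch_proj"))  # False
print(proj_filter_conv3d(qkv_proj, "layers.0.qkv_proj"))  # True
```

This is why the `quantize_` call sites above switch from the generic `filter_fn` to `proj_filter_conv3d`: Conv3d layers in the VLM's vision tower would otherwise match the `"proj"` name check and be quantized.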
Contributor:
Instead of hardcoding the filter at each call site, you could change the function signature on line 49 to `filter_fn: Optional[Callable] = proj_filter_conv3d,` and then change this line and the others to use the default.

Contributor (Author):
But the default …

Collaborator:
Why delete …

Collaborator:
Because we need to choose between the original filter and the new filter based on the quantization configuration, passing this decision in from an upper layer would create some coupling: the upper layer would need to be aware of the different quantization settings. However, we still need to consider how to handle modules like Deconv3D. The current …
```diff
     elif "int8dq" in torchao_config:
         quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
@@ -101,7 +112,7 @@ def apply_torchao_config_to_model(
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
+        quantize_(model, float8_weight_only(), filter_fn=proj_filter_conv3d)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -116,7 +127,7 @@ def apply_torchao_config_to_model(
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
-            filter_fn=filter_fn,
+            filter_fn=proj_filter_conv3d,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
```
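The `fp8dq` branch above parses a granularity suffix out of the config string before looking it up in `GRANULARITY_MAP`. A hedged sketch of that parsing step, with placeholder strings standing in for the torchao granularity objects the real map holds:

```python
# Hypothetical illustration of the fp8dq granularity parsing: the suffix
# after the last "-" selects an entry in GRANULARITY_MAP. The map values
# here are placeholder strings, not torchao's actual granularity objects.
GRANULARITY_MAP = {
    "per_row": "PerRow()",        # placeholder for per-row granularity
    "per_tensor": "PerTensor()",  # placeholder for per-tensor granularity
}


def parse_granularity(torchao_config: str) -> str:
    # e.g. "fp8dq-per_row" -> "per_row"
    return torchao_config.split("-")[-1]


granularity = parse_granularity("fp8dq-per_row")
print(GRANULARITY_MAP[granularity])  # PerRow()
```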