
[AWQ][Smooth] mapping shouldn't use ignore #2152

Merged
HDCharles merged 9 commits into main from 93_mapping_match
Jan 7, 2026

Conversation

@HDCharles
Collaborator

@HDCharles HDCharles commented Dec 18, 2025

Summary:

After the change to use match_modules_set, I made AWQ take the ignore list into account during module mapping. This was a mistake, since the ignore list should only be used for quantization.

A similar issue occurred with SmoothQuant, which is also fixed; new tests were added to verify this behavior.

At the same time, I added behavior so that if the smooth layer and all balance layers are ignored, that mapping is skipped.

These changes should fix #2151; however, I have not been able to get it to run. While the matching now works correctly, something is happening with the tracing. I'm investigating that separately, since it runs when I switch to the basic pipeline, so this isn't an issue with this PR.
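
In rough terms, the mapping-skip rule described above works like this (an illustrative sketch, not the actual llm-compressor implementation; the module names and the `is_ignored` helper are hypothetical):

```python
# Illustrative sketch of the skip rule -- not the real llm-compressor
# code. Module names and the is_ignored helper are hypothetical.
from fnmatch import fnmatch

def is_ignored(name, ignore):
    """True if a module name matches any ignore pattern."""
    return any(fnmatch(name, pattern) for pattern in ignore)

def resolve_mappings(mappings, ignore):
    """Keep every (smooth, balance) mapping unless the smooth layer
    AND all balance layers are ignored; individual matches are no
    longer filtered by `ignore` during mapping resolution."""
    resolved = []
    for smooth, balance in mappings:
        if is_ignored(smooth, ignore) and all(
            is_ignored(b, ignore) for b in balance
        ):
            continue  # nothing left to smooth or balance -> skip mapping
        resolved.append((smooth, balance))
    return resolved

mappings = [
    ("model.layers.0.input_layernorm",
     ["model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.k_proj"]),
    ("lm_head_norm", ["lm_head"]),
]
ignore = ["lm_head*"]
print(resolve_mappings(mappings, ignore))
# the lm_head mapping is skipped; the attention mapping is kept
```

Note that in this sketch `ignore` only triggers a skip when the entire mapping would be dead weight; it never removes single modules from a kept mapping.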

Test Plan:
pytest /home/HDCharles/repos/llm-compressor/tests/llmcompressor/modifiers/smoothquant/test_base.py
pytest /home/HDCharles/repos/llm-compressor/tests/llmcompressor/modifiers/awq/test_base.py

@github-actions

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.

@gemini-code-assist
Contributor

Summary of Changes

Hello @HDCharles, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses and rectifies an issue in the AWQ (Activation-aware Weight Quantization) modifier where the ignore list was inadvertently influencing module mapping for smoothing operations. The change ensures a clear separation of concerns, with the ignore list now solely dedicated to specifying layers that should not be quantized, thereby resolving a reported bug and improving the precision of AWQ configurations.

Highlights

  • AWQ Module Mapping Logic Correction: The match_modules_set function within AWQModifier._set_resolved_mappings no longer uses the self.ignore list when identifying modules for smoothing. This ensures that the ignore list is exclusively applied to quantization targets, not to the smoothing process, correcting an earlier mistake.
  • Documentation Clarification: The examples/awq/README.md has been updated with a note clarifying that mappings define layers for smoothing, while targets and ignore lists are specifically for quantization. This helps users understand the distinct roles of these configurations.
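
The separation of concerns in the two highlights can be sketched as follows (hypothetical helpers; the real `match_modules_set` signature in llm-compressor may differ):

```python
# Toy illustration of the separation described above -- not the actual
# llm-compressor API. matches_any and both functions are hypothetical.
from fnmatch import fnmatch

def matches_any(name, patterns):
    return any(fnmatch(name, p) for p in patterns)

def smoothing_matches(module_names, targets):
    # Mapping resolution: match every module, regardless of `ignore`.
    return [n for n in module_names if matches_any(n, targets)]

def quantization_targets(module_names, targets, ignore):
    # Quantization: this is the only place `ignore` is applied.
    return [n for n in module_names
            if matches_any(n, targets) and not matches_any(n, ignore)]

names = ["model.layers.0.self_attn.q_proj", "lm_head"]
targets, ignore = ["*proj", "lm_head"], ["lm_head"]
print(smoothing_matches(names, targets))             # both modules match
print(quantization_targets(names, targets, ignore))  # lm_head excluded
```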


@HDCharles HDCharles requested review from dsikka and fynnsu December 18, 2025 18:00
@HDCharles HDCharles added bug Something isn't working ready When a PR is ready for review awq For any issue / PR related to AWQ support labels Dec 18, 2025
Collaborator

@brian-dellabetta brian-dellabetta left a comment


awesome

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug where the ignore list was improperly used during module mapping for AWQ smoothing. The change correctly removes the ignore parameter from the match_modules_set call, ensuring it's only used for quantization as intended. The accompanying update to README.md and the new code comment clearly explain this behavior, improving both user documentation and code maintainability. I've added one comment regarding an outdated docstring that this change introduces, which would be good to address for consistency.

fynnsu
fynnsu previously approved these changes Dec 18, 2025
@HDCharles HDCharles dismissed stale reviews from fynnsu and brian-dellabetta via 7cc1bf0 December 18, 2025 18:38
kylesayrs
kylesayrs previously approved these changes Dec 18, 2025
fynnsu
fynnsu previously approved these changes Dec 18, 2025
@HDCharles HDCharles enabled auto-merge (squash) December 18, 2025 21:04
@HDCharles HDCharles disabled auto-merge December 18, 2025 21:05
@ZewenShen-Cohere
Contributor

Hi Charles, I'm working on the same issue as well. I don't think your change will work correctly for multimodal LLMs. In practice, we usually do not want to quantize the vision encoder. Since the vision encoder usually shares a similar transformer structure, it would also be picked up by the AWQ mapping-resolution logic when you remove the ignore, which is not desirable. In my draft PR, I added an additional field in the modifier to separate the logic.

@HDCharles
Collaborator Author

I don't think it is an issue; it's just going to do smoothing, which will be mathematically irrelevant, though likely slow. Also, this was tested with python /home/HDCharles/repos/llm-compressor/examples/awq/qwen3-vl-30b-a3b-Instruct-example.py, which is multimodal. I left comments on your PR; I think we can make further enhancements on top of this PR since they seem orthogonal, and the changes in this PR are enabling behaviors we had previously and are expected.
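
For context on why the extra smoothing is harmless: AWQ/SmoothQuant-style smoothing rescales the smooth layer's output channels by 1/s and the following balance layer's matching input channels by s, which leaves the composed output unchanged until quantization rounds the weights. A small numeric check of that identity (a generic sketch, not code from this repo):

```python
# Smoothing is an invariant rescaling: scaling the smooth layer's
# channels by 1/s and the balance layer's rows by s cancels exactly.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # activations
w_smooth = rng.normal(size=8)        # per-channel scale (e.g. a norm layer)
w_balance = rng.normal(size=(8, 8))  # weight of the following linear layer
s = rng.uniform(0.5, 2.0, size=8)    # smoothing scales

y = (x * w_smooth) @ w_balance
y_smoothed = (x * (w_smooth / s)) @ (w_balance * s[:, None])

# Without quantization the outputs agree, which is why smoothing
# layers that will never be quantized changes nothing mathematically.
assert np.allclose(y, y_smoothed)
```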

@ZewenShen-Cohere
Contributor

Hi Charles, thank you for the suggestion. I've made a PR to filter out these redundant mappings: #2179

@HDCharles HDCharles dismissed stale reviews from brian-dellabetta and fynnsu via 1f3c8a6 January 5, 2026 16:22
@HDCharles
Collaborator Author

@ZewenShen-Cohere sorry, I'm back from the break now.

I've added that functionality to this PR, added tests, and consolidated the logic for SmoothQuant and AWQ.

@HDCharles HDCharles changed the title [AWQ] mapping shouldn't use ignore [AWQ][Smooth] mapping shouldn't use ignore Jan 5, 2026
Collaborator

@brian-dellabetta brian-dellabetta left a comment


Updates and tests look good! One nit.

HDCharles and others added 5 commits January 6, 2026 12:36
Summary:

After the change to match_modules_set, I made AWQ take the ignore list into account during module mapping. This was a mistake, since the ignore list should only be used for quantization.

this fixes #2151

Test Plan:
python /home/HDCharles/repos/llm-compressor/examples/awq/qwen3-vl-30b-a3b-Instruct-example.py

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Summary

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Collaborator

@shanjiaz shanjiaz left a comment


Looks good!

@HDCharles HDCharles merged commit 5d15cc8 into main Jan 7, 2026
11 checks passed
@HDCharles HDCharles deleted the 93_mapping_match branch January 7, 2026 15:44

Labels

awq For any issue / PR related to AWQ support bug Something isn't working ready When a PR is ready for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: qwen3-vl-30b-a3b v0.9.0 awq no matches found for input_layernorm

7 participants
