[Fix] SmoothQuant MoE Support: Smooth All Experts, Not Just expert.0 #2084

Merged
fynnsu merged 3 commits into main from smoothquant-issue on Dec 3, 2025

Conversation

@rahul-tuli
Collaborator

Problem

SmoothQuant only smoothed the first expert (expert.0) in Mixture of Experts (MoE) models, leaving all other experts unsmoothed. This caused severe accuracy degradation for MoE models.

Root Cause

The _resolve_mappings() method in SmoothQuantModifier used get_matching_layer(), which returns only the first regex match instead of iterating over all matches. For MoE models with regex patterns like "re:.*experts.*w1", this meant only expert.0.w1 was smoothed while experts 1-N were ignored.

# Before (BUGGY)
for balance_suffix in to_balance:
    _, balance_layer = get_matching_layer(balance_suffix, layer_name, model)
    # ❌ Returns only expert.0, ignores experts 1-15

Solution

Replace get_matching_layer() with match_named_modules() from compressed_tensors.utils to iterate over ALL matched layers. This follows the same proven pattern used in AWQModifier.

# After (FIXED)
for balance_regex in to_balance:
    for _, balance_layer in match_named_modules(smooth_parent, [balance_regex], self.ignore):
        balance_layers.append(balance_layer)
    # ✅ Returns ALL experts (expert.0, expert.1, ..., expert.15)
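To make the difference concrete, here is a minimal standalone sketch of first-match vs. all-matches semantics. This is plain Python, not the llm-compressor or compressed_tensors code; `first_match` and `all_matches` are illustrative stand-ins for `get_matching_layer` and `match_named_modules`:

```python
import re

# Stand-in for named_modules() of an MoE layer with 4 experts.
named_modules = {f"experts.{i}.w1": object() for i in range(4)}

def first_match(pattern, modules):
    # Mimics get_matching_layer: stops at the first hit.
    for name, module in modules.items():
        if re.fullmatch(pattern, name):
            return name, module

def all_matches(pattern, modules):
    # Mimics match_named_modules: collects every hit.
    return [(n, m) for n, m in modules.items() if re.fullmatch(pattern, n)]

print(first_match(r".*experts.*w1", named_modules)[0])    # experts.0.w1 only
print(len(all_matches(r".*experts.*w1", named_modules)))  # 4 -- all experts
```

With first-match semantics, experts 1-3 never appear in the result, which is exactly how only expert.0 ended up in `balance_layers`.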

Key Changes

  1. Updated imports: Use match_named_modules from compressed_tensors.utils
  2. Rewrote _resolve_mappings(): Iterate over all matched layers instead of just the first

Tests Added

Added unit tests covering the issue and verifying MoE support; these tests fail on main but pass with the current diff:

1. test_moe_all_experts_smoothed

Verifies all 8 experts in a single MoE layer are included in balance_layers:

num_experts = 8
# ... create MoE model with 8 experts ...
resolved_mappings = sq._resolve_mappings(model)
assert len(mapping.balance_layers) == num_experts  # All 8 experts

2. test_moe_multiple_layers_all_experts_smoothed

Verifies correct behavior across multiple transformer layers:

num_layers = 2
num_experts = 4
# ... create model with 2 layers, 4 experts each ...
assert len(resolved_mappings) == num_layers
for mapping in resolved_mappings:
    assert len(mapping.balance_layers) == num_experts  # All 4 experts per layer

Test Results

All tests pass successfully:

$ python -m pytest tests/llmcompressor/modifiers/smoothquant/test_base.py -v

test_smooth_quant_is_registered                          ✅ PASSED
test_smooth_quant_defaults                               ✅ PASSED
test_override_defaults                                   ✅ PASSED
test_moe_all_experts_smoothed                            ✅ PASSED
test_moe_multiple_layers_all_experts_smoothed            ✅ PASSED

========================= 5 passed in 0.41s =========================

Before Fix (Tests Failed)

AssertionError: Expected 8 balance layers, got 1
# Only expert.0 was smoothed ❌

After Fix (Tests Pass)

All 8 experts smoothed ✅
All tests passing ✅

Related Issues

Fixes the SmoothQuant MoE bug reported in the community discussion about MoE quantization support.

@github-actions

github-actions bot commented Dec 2, 2025

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite, please only add the label once the PR is code complete and local testing has been performed.

Add comprehensive unit tests to verify that SmoothQuant correctly
handles Mixture of Experts (MoE) models by smoothing all experts,
not just the first one.

Tests added:
- test_moe_all_experts_smoothed: Verifies all 8 experts in a single
  MoE layer are included in balance_layers
- test_moe_multiple_layers_all_experts_smoothed: Verifies correct
  behavior across multiple transformer layers with 4 experts each

These tests currently fail with the existing implementation, which
only matches the first expert due to get_matching_layer() returning
a single match instead of iterating over all matches.

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@rahul-tuli rahul-tuli marked this pull request as ready for review December 2, 2025 15:50
@rahul-tuli rahul-tuli requested review from HDCharles, dsikka, fynnsu, kylesayrs and shanjiaz and removed request for fynnsu December 2, 2025 15:50
@rahul-tuli rahul-tuli self-assigned this Dec 2, 2025
@rahul-tuli rahul-tuli added bug Something isn't working ready When a PR is ready for review labels Dec 2, 2025
fynnsu previously approved these changes Dec 2, 2025

@fynnsu (Collaborator) left a comment:

Looks good!

HDCharles previously approved these changes Dec 2, 2025

@HDCharles (Collaborator) left a comment:

seems good to me

@rahul-tuli rahul-tuli dismissed stale reviews from HDCharles and fynnsu via 5028ebf December 3, 2025 11:53
Replace get_matching_layer() with match_named_modules() to iterate
over ALL matched layers instead of returning only the first match.
This fixes a critical bug where only expert.0 was smoothed in MoE
models, leaving all other experts unsmoothed and causing severe
accuracy degradation.

Changes:
- Use match_named_modules from compressed_tensors.utils to iterate
  over all matching modules
- Search for balance layers within the parent module scope for
  better locality
- Follow the same pattern already proven to work in AWQModifier

This fix ensures all experts in MoE models (Mixtral, Qwen3, Phi,
DeepSeek) are properly smoothed during quantization.

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@rahul-tuli
Collaborator Author

@fynnsu Great suggestions! I've applied both:

1. Simplified the nested loop: You're absolutely right that we can pass to_balance directly to match_named_modules, since it already iterates over the targets list. Changed to a list comprehension for cleaner code.
2. Updated get_layer_by_name: Added special handling for the empty string to return the module itself. This cleaned up the ternary logic nicely. Also added a test case to cover this behavior.

Both changes are now committed. Thanks for the thoughtful review!
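For illustration, the two suggestions above can be sketched with toy code. These are simplified stand-ins, not the actual llm-compressor or compressed_tensors implementations; `match_named_modules` here is a minimal regex matcher, and `get_layer_by_name` is a toy attribute walker:

```python
import re
from types import SimpleNamespace

def match_named_modules(names, targets):
    # Toy matcher: yields every name matching any regex in `targets`,
    # so callers can pass the whole target list at once.
    for name in names:
        if any(re.fullmatch(t, name) for t in targets):
            yield name

names = [f"experts.{i}.{w}" for i in range(4) for w in ("w1", "w2", "gate")]
to_balance = [r"experts.*w1", r"experts.*w2"]

# 1. Pass to_balance directly and collect via a list comprehension,
#    instead of looping over each target regex separately.
balance_layers = [name for name in match_named_modules(names, to_balance)]
print(len(balance_layers))  # w1 and w2 for all 4 experts -> 8

# 2. get_layer_by_name("") returns the module itself (toy version).
def get_layer_by_name(name, module):
    if name == "":
        return module  # empty string -> the root module itself
    for part in name.split("."):
        module = getattr(module, part)
    return module

root = SimpleNamespace(child=SimpleNamespace(leaf="ok"))
print(get_layer_by_name("", root) is root)  # True
print(get_layer_by_name("child.leaf", root))  # ok
```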

@fynnsu fynnsu merged commit 3cd0300 into main Dec 3, 2025
11 checks passed
@fynnsu fynnsu deleted the smoothquant-issue branch December 3, 2025 15:17
fynnsu added a commit that referenced this pull request Dec 10, 2025
Depends on vllm-project/compressed-tensors#524

Summary:
- modified AWQ _set_resolved_mappings
  - get smoothing and balance layers at the same time using match_modules_set
  - (bugfix) corrected logic so that if any balance layers are incompatible,
    that matching is skipped
  - added warnings
  - got rid of tqdm and skip counting @kylesayrs
  - added a helper for module_to_name
- removed hardcoded handling for a single balance layer by updating
  get_lowest_common_module to handle it
- modified SmoothQuant _resolve_mappings
  - brought into alignment with AWQ
  - this is largely a horizontal move, though there is now handling for
    situations that would have been missed before, like:
      - multiple smooth layer matches in a single set
      - parent contexts further than one layer away
- updated mapping definitions to always be tuple(list[str], str), which was
  always the case in practice but, unlike in AWQ, wasn't enforced
- removed get_lowest_common_parent
  - we can now use CT's get_lowest_common_ancestor_name, so we only need to
    check for module_list (it has a lot of bugfixes compared to the
    get_lowest_common_parent implementation in LLMC)
- updated test_base for AWQ and smoothquant
  - added a test case for _set_resolved_mappings to check that partially
    skipped matches are handled correctly
  - added tests for MoE matching being handled correctly
  - added test cases for get_lowest_non_module_list_ancestor
  - imported Linear and used that instead of torch.nn.Linear
- reverted test_pytorch.py for logarithmic_equalizations and smoothquant
  - The test was updated in #2084 by @rahul-tuli to ignore some modules, but
    in general, because of the way the new logic works, you need to ignore
    the whole set.
  - If you only ignore one element, the matching logic would need to
    determine whether there's a full set or not *somehow*, which it doesn't
    do. In the previous logic this was possible because the whole set was
    assumed to be siblings of the smooth_layer, but the new util tries to be
    more flexible and relaxes this assumption, which prevents the same
    approach from working. If this is a common need, perhaps we can add a
    util that checks for a parent context of size N or something.
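The ignore-the-whole-set behavior described above can be sketched with a toy function (illustrative only; `resolve_balance_set` is not the CT or llm-compressor API):

```python
def resolve_balance_set(balance_names, ignore):
    # A mapping applies only if the entire set of balance layers
    # survives `ignore`; a partially ignored set is skipped rather
    # than half-applied.
    if any(name in ignore for name in balance_names):
        return None  # skip the whole set
    return list(balance_names)

experts = [f"experts.{i}.w1" for i in range(2)]
print(resolve_balance_set(experts, {"experts.1.w1"}))  # None -> set skipped
print(resolve_balance_set(experts, set()))             # full set applied
```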

TEST PLAN:
pytest
/home/HDCharles/repos/llm-compressor/tests/llmcompressor/modifiers/awq/test_base.py
pytest
/home/HDCharles/repos/llm-compressor/tests/llmcompressor/modifiers/smoothquant/test_base.py

---------

Signed-off-by: HDCharles <charlesdavidhernandez@gmail.com>
Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Fynn Schmitt-Ulms <fynnsu@outlook.com>