From 38c67752ee55ca4a73e8eee4347c697c4011b4f0 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 2 Dec 2025 22:01:02 +0000 Subject: [PATCH 01/14] fix match_modules_set to work with MoE Summary: match_modules_set isn't currently as useful as it could be because it lacks the ability to match multiple results for each set like in the case of a MoE model where you have 128 experts. ``` [`layers.32.mlp.experts.0.gate_up_proj`, ..., `layers.32.mlp.experts.127.gate_up_proj`] ``` In order to make it so this can still work for matching simple cases and MoE cases we use the following approach. 1) match modules until we have at least 1 match per target 2) when we have 1 match per target, our set is 'full' and we calculate the common parent context 3) continue matching and for each match, check if parent context would change given the new match 4) if we find a match that changes the parent context, this is the first element of the next set. Yield the existing matched set and then reset, using the current match as the first element of the new set. To facilitate this algorithm I also added get_lowest_common_module_name which basically copies a similar function in llm-compressor though significantly simpler. Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 141 ++++++++++++++++++++------ tests/test_utils/test_match.py | 70 ++++++++++--- 2 files changed, 168 insertions(+), 43 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index f26400b0b..c68067e4c 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging +import os import re from collections.abc import Generator from typing import Iterable, List, Mapping, Optional, Tuple, Union @@ -29,6 +30,7 @@ "match_named_parameters", "match_targets", "match_modules_set", + "get_lowest_common_module_name", "is_match", "is_narrow_match", ] @@ -157,33 +159,76 @@ def match_targets( return matched_targets +def get_lowest_common_module_name(names: list[str | None]) -> str: + """ + Given a list of names, returns the lowest-scope common name ignoring None's. + + Implementation is a small alteration of os.path.commonprefix + https://docs.python.org/3/library/os.path.html#os.path.commonprefix + + ([s1, s2]->prefix->result) + # case 0: multiple modules: [abc.a., abc.b.] -> .abc. -> abc + # case 1: single module: [abc.] -> .abc. -> abc + # case 2: substring modules: [abc., ab.] -> .ab -> "" + # case 3: parent & child: [ab., ab.a.] -> .ab. -> ab + """ + names = [name for name in names if name is not None] + if len(names) == 0: + return "" + + # 1) find longest shared prefix + s1 = "." + min(names) + "." + s2 = "." + max(names) + "." + common_prefix = os.path.commonprefix([s1, s2]) + # 2) throw away right most dot and name fragment, throw away leftmost char + # ".keep.thro" -> "keep", "." -> "" + return common_prefix[1 : common_prefix.rfind(".")] + + def match_modules_set( model: torch.nn.Module, targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, -) -> Generator[Iterable[torch.nn.Module]]: +) -> Generator[Iterable[Iterable[torch.nn.Module]]]: """ - Yields modules grouped with the same order and size as `targets`. - Values are returned in order of `model.named_modules()` + Yields modules grouped by parent context. - E.g. the following targets would yield module belonging to the following layers: + We group by parent context so that we can return ALL matches of a + specific target that can be paired with another target. 
This is most + relevant in the case of MoE modules with multiple modules for each + expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj, + mlp.expert.N.up_proj for all N. The parent context will differ from + one layer to another while being the same for one expert to another. + + Values are returned in order of `model.named_modules()` where possible + + E.g. the following targets would yield modules belonging to the following layers: ```python3 - match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == ( + match_modules_set(model, ["re:*q_proj", "re:*k_proj", "re:*v_proj"]) == ( ( - `model.layers.0.self_attn.q_proj`, - `model.layers.0.self_attn.k_proj`, - `model.layers.0.self_attn.v_proj`, + `layers.0.self_attn.q_proj`, + `layers.0.self_attn.k_proj`, + `layers.0.self_attn.v_proj`, ), + ... ( - `model.layers.1.self_attn.q_proj`, - `model.layers.1.self_attn.k_proj`, - `model.layers.1.self_attn.v_proj`, + `layers.32.self_attn.q_proj`, + `layers.32.self_attn.k_proj`, + `layers.32.self_attn.v_proj`, + ), + match_modules_set(model, ["re:*gate_up_proj", "down_proj"]) == ( + ( + [`layers.0.mlp.experts.0.gate_up_proj`, ..., + `layers.0.mlp.experts.127.gate_up_proj`] + [`layers.0.mlp.experts.0.down_proj`, ..., + `layers.0.mlp.experts.127.down_proj`] ), ... 
( - `model.layers.32.self_attn.q_proj`, - `model.layers.32.self_attn.k_proj`, - `model.layers.32.self_attn.v_proj`, + [`layers.32.mlp.experts.0.gate_up_proj`, ..., + `layers.32.mlp.experts.127.gate_up_proj`] + [`layers.32.mlp.experts.0.down_proj`, ..., + `layers.32.mlp.experts.127.down_proj`] ), ) ``` @@ -192,7 +237,7 @@ def match_modules_set( For example, matching layer norms to their subsequent linear layers ```python3 for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)): - fuse_norm_linears(norm, [q, k, v]) + fuse_norm_linears(*norm, [*q, *k, *v]) :param model: model containing modules to match against :param targets: target strings, potentially containing "re:" prefixes @@ -201,24 +246,60 @@ def match_modules_set( targets = targets or [] ignore = ignore or [] - matches = dict.fromkeys(targets, None) + # as we iterate through modules and try to match them with targets, + # the set of matches can be in 2 possible states: + # 0) unmatched_targets > 0, i.e. some of the targets haven't been + # matched. Keep matching until all targets have at least one match + # 1) unmatched_targets == 0 i.e. we have at least one match for each + # target. At this point we are unsure if we have a full set or if + # we need to add more matches. + # There are 3 things that can happen once were in state 1: + # A) found a new match with same parent_context, (add it to matches + # and keep going) + # B) found a new match with different parent_context, i.e. we found a + # match that requires a deeper parent context, this indicates that + # this match should be part of a new set. + # (yield current set [not including newest match] and go back to + # state 0) + # C) ran out of modules, we will always yield the final remaining set + # when we we've iterated through all the modules in the model. + # (yield final set then exit.) 
+ # Note: its possible to iterate through all the modules in the model + # while not having a full matched set if the user specified a + # bad matching, in that case something has gone wrong and we + # error + matches = dict.fromkeys(targets, []) + parent_context = None + unmatched_targets = len(targets) + for name, module in model.named_modules(): - # match until we get a full set for target in targets: if is_match(name, module, target, ignore): - if matches[target] is not None: - raise ValueError(f"Matched a {target} twice before completing set") - matches[target] = module - - # once we have a full set, yield and reset - if targets and all((matches[target] is not None for target in targets)): - yield [matches[target] for target in targets] # ensure correct ordering - matches = dict.fromkeys(targets, None) - - # check that none are left over - unmatched_keys = [match for match, value in matches.items() if value is not None] - if len(unmatched_keys): - raise ValueError(f"Unable to match targets into set: {unmatched_keys}") + new_parent_context = get_lowest_common_module_name(name, parent_context) + + # code for (B) + if unmatched_targets == 0 and new_parent_context != parent_context: + yield [matches[target] for target in targets] + matches = dict.fromkeys(targets, []) + parent_context = None + unmatched_targets = len(targets) + + # add match to mathes dict and do bookkeeping + unmatched_targets -= len(matches[target]) == 0 + matches[target].append(module) + parent_context = new_parent_context + + # code for (C) + if unmatched_targets == 0: + yield [matches[target] for target in targets] + return + + raise ValueError( + f"Found a final incomplete set with matches found for keys: " + f"{[t for t, m in matches if len(m)>0]} " + f"but no matches found for keys: " + f"{[t for t, m in matches if len(m)==0]}" + ) def is_match( diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 1129120c6..71a4a4c3e 100644 --- 
a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -20,6 +20,7 @@ # Assuming the module is named "module_matching" - adjust import as needed from compressed_tensors.utils import ( InternalModule, + get_lowest_common_module_name, is_match, is_narrow_match, match_modules_set, @@ -412,6 +413,58 @@ class InternalLinear(InternalModule, nn.Linear): assert len(matches) == 0 +class TestGetLowestCommonModuleName: + """Test cases for get_lowest_common_module_name function""" + + def test_multiple_modules(self): + assert "abc" == get_lowest_common_module_name( + [ + "abc.a", + "abc.b", + "abc.c", + ] + ) + + def test_single_module(self): + assert "abc.abc" == get_lowest_common_module_name( + [ + "abc.abc", + ] + ) + + def test_substring_modules(self): + assert "abc" == get_lowest_common_module_name( + [ + "abc.abc", + "abc.ab", + ] + ) + + def test_parent_and_child_modules(self): + assert "abc.abc" == get_lowest_common_module_name( + [ + "abc.abc.ab", + "abc.abc", + ] + ) + + def test_root(self): + assert "" == get_lowest_common_module_name( + [ + "abc.abc", + "b.abc", + ] + ) + + def test_ignore_none(self): + assert "abc.abc" == get_lowest_common_module_name( + [ + "abc.abc", + None, + ] + ) + + class TestMatchModulesSet: """Test cases for match_modules_set function""" @@ -432,7 +485,7 @@ def test_simple_module_set(self): # Each set should have 3 modules for module_set in matches: assert len(module_set) == 3 - assert all(isinstance(m, nn.Linear) for m in module_set) + assert all(isinstance(*m, nn.Linear) for m in module_set) def test_module_set_ordering(self): """Test that module sets maintain target ordering""" @@ -448,6 +501,7 @@ def test_module_set_ordering(self): for module_set in matches: # Check that modules are returned in target order (v, q, k) v_proj, q_proj, k_proj = module_set + v_proj, q_proj, k_proj = *v_proj, *q_proj, *k_proj # We can't easily check the exact modules, but can check they're all Linear assert all(isinstance(m, nn.Linear) for 
m in [v_proj, q_proj, k_proj]) @@ -456,18 +510,8 @@ def test_incomplete_set_error(self): model = DummyModel() targets = ["layer1", "nonexistent_module"] - with pytest.raises(ValueError, match="Unable to match targets into set"): - list(match_modules_set(model, targets)) - - def test_duplicate_match_error(self): - """Test error when same target matches multiple times before set completion""" - model = DummyModel() - # This should cause the same target to match multiple times - # before we can complete a set - targets = ["Linear", "Linear"] # Two identical targets - with pytest.raises( - ValueError, match="Matched a .* twice before completing set" + ValueError, match="Found a final incomplete set with matches found for keys" ): list(match_modules_set(model, targets)) @@ -476,7 +520,7 @@ def test_empty_targets_set(self): model = DummyModel() matches = list(match_modules_set(model, [])) # Should yield one empty set for each module traversed? - # Actually, with empty targets, we expect no matches + # with empty targets, we expect no matches assert len(matches) == 0 def test_module_set_with_ignore(self): From dd3bcfc478bf03d7bd8d6b0631110386a2059eb0 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:08:39 -0500 Subject: [PATCH 02/14] Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com> --- src/compressed_tensors/utils/match.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index c68067e4c..d3cd8b0e5 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -161,7 +161,7 @@ def match_targets( def get_lowest_common_module_name(names: list[str | None]) -> str: """ - Given a list of names, returns the lowest-scope common name ignoring None's. 
+ Given a list of names, returns the lowest-scope common name, ignoring Nones. Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix From ae9dc047929812a24a67aba7b8dd9daa24975506 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:09:46 -0500 Subject: [PATCH 03/14] Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com> --- src/compressed_tensors/utils/match.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index d3cd8b0e5..71f47ffb1 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -159,7 +159,7 @@ def match_targets( return matched_targets -def get_lowest_common_module_name(names: list[str | None]) -> str: +def get_lowest_common_module_name(names: Iterable[str | None]) -> str: """ Given a list of names, returns the lowest-scope common name, ignoring Nones. From 37a22467ba41f4ba20a06b3690ec411c56ac32e1 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:10:05 -0500 Subject: [PATCH 04/14] Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com> --- src/compressed_tensors/utils/match.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 71f47ffb1..62ac8c82d 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -264,7 +264,7 @@ def match_modules_set( # C) ran out of modules, we will always yield the final remaining set # when we we've iterated through all the modules in the model. # (yield final set then exit.) 
- # Note: its possible to iterate through all the modules in the model + # Note: it's possible to iterate through all the modules in the model # while not having a full matched set if the user specified a # bad matching, in that case something has gone wrong and we # error From 412c666606da8d334fe0d711f88802316a0b977f Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 2 Dec 2025 22:10:44 -0500 Subject: [PATCH 05/14] Update src/compressed_tensors/utils/match.py Co-authored-by: Kyle Sayers Signed-off-by: HDCharles <39544797+HDCharles@users.noreply.github.com> --- src/compressed_tensors/utils/match.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 62ac8c82d..3fb8e5496 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -284,7 +284,7 @@ def match_modules_set( parent_context = None unmatched_targets = len(targets) - # add match to mathes dict and do bookkeeping + # add match to matches dict and do bookkeeping unmatched_targets -= len(matches[target]) == 0 matches[target].append(module) parent_context = new_parent_context From 9c6d205b10199f0f64bbd1aaf088066900e2f7fc Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 04:23:08 +0000 Subject: [PATCH 06/14] format fixes and bug fixes Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 154 ++++++++++++++++---------- tests/test_utils/test_match.py | 103 +++++++++++++++++ 2 files changed, 200 insertions(+), 57 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 3fb8e5496..888c46bb6 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -15,6 +15,7 @@ import logging import os import re +from collections import defaultdict from collections.abc import Generator from typing import Iterable, List, 
Mapping, Optional, Tuple, Union @@ -159,9 +160,9 @@ def match_targets( return matched_targets -def get_lowest_common_module_name(names: Iterable[str | None]) -> str: +def get_lowest_common_module_name(names: list[str | None]) -> str: """ - Given a list of names, returns the lowest-scope common name, ignoring Nones. + Given a list of names, returns the lowest-scope common name ignoring None's. Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix @@ -200,36 +201,26 @@ def match_modules_set( mlp.expert.N.up_proj for all N. The parent context will differ from one layer to another while being the same for one expert to another. - Values are returned in order of `model.named_modules()` where possible + Each returned group is a list (of lists) with the same size + and order as `targets` while all matches for each target and + the overall order of the groups are ordered in the same way + as `model.named_modules` + E.g. the following targets would yield modules belonging to the following layers: ```python3 match_modules_set(model, ["re:*q_proj", "re:*k_proj", "re:*v_proj"]) == ( - ( - `layers.0.self_attn.q_proj`, - `layers.0.self_attn.k_proj`, - `layers.0.self_attn.v_proj`, - ), + [ + [`layers.0.self_attn.q_proj`], + [`layers.0.self_attn.k_proj`], + [`layers.0.self_attn.v_proj`], + ], ... - ( + [ `layers.32.self_attn.q_proj`, `layers.32.self_attn.k_proj`, `layers.32.self_attn.v_proj`, - ), - match_modules_set(model, ["re:*gate_up_proj", "down_proj"]) == ( - ( - [`layers.0.mlp.experts.0.gate_up_proj`, ..., - `layers.0.mlp.experts.127.gate_up_proj`] - [`layers.0.mlp.experts.0.down_proj`, ..., - `layers.0.mlp.experts.127.down_proj`] - ), - ... 
- ( - [`layers.32.mlp.experts.0.gate_up_proj`, ..., - `layers.32.mlp.experts.127.gate_up_proj`] - [`layers.32.mlp.experts.0.down_proj`, ..., - `layers.32.mlp.experts.127.down_proj`] - ), + ], ) ``` @@ -238,7 +229,48 @@ def match_modules_set( ```python3 for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)): fuse_norm_linears(*norm, [*q, *k, *v]) + ``` + + Alternatively for MoE you would get multiple matches + per target per group, E.g. + + ```python3 + match_modules_set(model, ["re:*up_proj", "down_proj"]) == ( + [ + [ + `layers.0.mlp.experts.0.up`, + ... + `layers.0.mlp.experts.127.up_proj` + ] + ... + [ + `layers.0.mlp.experts.0.down_proj`, + ... + `layers.0.mlp.experts.127.down_proj` + ] + ], # <- first yield + ... + [ + [ + `layers.32.mlp.experts.0.up_proj`, + ... + `layers.32.mlp.experts.127.up_proj` + ] + [ + `layers.32.mlp.experts.0.down_proj`, + ... + `layers.32.mlp.experts.127.down_proj` + ] + ], + ) + ``` + Note: if you only have one target i.e. match_modules_set(model, ["re:*up_proj") + it will yield one expert at a time rather than grouping experts by layer + This occurs because each single match fills all targets and the next expert + will change the parent context. Thus for single targets include another layer + to stabilize the parent context i.e. match_modules_set(model, ["re:*up_proj", "re:*experts") + :param model: model containing modules to match against :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes @@ -246,59 +278,67 @@ def match_modules_set( targets = targets or [] ignore = ignore or [] + # Early return for empty targets + if not targets: + return + # as we iterate through modules and try to match them with targets, - # the set of matches can be in 2 possible states: - # 0) unmatched_targets > 0, i.e. some of the targets haven't been - # matched. 
Keep matching until all targets have at least one match - # 1) unmatched_targets == 0 i.e. we have at least one match for each - # target. At this point we are unsure if we have a full set or if - # we need to add more matches. + # the algorithm can be in 2 possible states: + # 0) unmatched_targets > 0, i.e. some of the targets haven't been matched. + # Keep matching until all targets have at least one match + # 1) unmatched_targets == 0 i.e. we have at least one match for each target. + # At this point we are unsure if we have a full set or if we need to add + # more matches. # There are 3 things that can happen once were in state 1: - # A) found a new match with same parent_context, (add it to matches - # and keep going) - # B) found a new match with different parent_context, i.e. we found a - # match that requires a deeper parent context, this indicates that - # this match should be part of a new set. - # (yield current set [not including newest match] and go back to - # state 0) - # C) ran out of modules, we will always yield the final remaining set - # when we we've iterated through all the modules in the model. - # (yield final set then exit.) - # Note: it's possible to iterate through all the modules in the model - # while not having a full matched set if the user specified a - # bad matching, in that case something has gone wrong and we - # error - matches = dict.fromkeys(targets, []) + # A) found a new match with same parent_context, + # (add it to matches and keep going) + # B) found a new match with different parent_context, i.e. we found a match + # that requires a deeper parent context, this indicates that this match + # should be part of a new set. + # (yield current set [not including newest match] and go back to state 0) + # C) ran out of modules, we will always yield the final remaining set when + # we we've iterated through all the modules in the model. + # (yield final set then exit.) 
+ # Note: its possible to iterate through all the modules in the model while + # not having a full matched set if the user specified a bad matching, in + # that case something has gone wrong and we error + matches = defaultdict(list) parent_context = None - unmatched_targets = len(targets) + unmatched_targets = set(targets) for name, module in model.named_modules(): for target in targets: if is_match(name, module, target, ignore): - new_parent_context = get_lowest_common_module_name(name, parent_context) + new_parent_context = get_lowest_common_module_name( + [name, parent_context] + ) # code for (B) - if unmatched_targets == 0 and new_parent_context != parent_context: + if not unmatched_targets and new_parent_context != parent_context: yield [matches[target] for target in targets] - matches = dict.fromkeys(targets, []) - parent_context = None - unmatched_targets = len(targets) + matches = defaultdict(list) + new_parent_context = name + unmatched_targets = set(targets) - # add match to matches dict and do bookkeeping - unmatched_targets -= len(matches[target]) == 0 matches[target].append(module) parent_context = new_parent_context + unmatched_targets -= {target} + # target has now been matched (this does no-op if not in set) # code for (C) - if unmatched_targets == 0: + if not unmatched_targets: yield [matches[target] for target in targets] return + # If no matches were found at all (e.g., all modules are internal), + # just return without yielding or raising an error + if unmatched_targets == set(targets): + return + raise ValueError( f"Found a final incomplete set with matches found for keys: " - f"{[t for t, m in matches if len(m)>0]} " - f"but no matches found for keys: " - f"{[t for t, m in matches if len(m)==0]}" + f"{set(targets)-unmatched_targets}" + f"but no matches found for keys: {unmatched_targets}" ) diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 71a4a4c3e..150434855 100644 --- a/tests/test_utils/test_match.py +++ 
b/tests/test_utils/test_match.py @@ -80,6 +80,45 @@ def __init__(self): ) +class DummyMoEModel(nn.Module): + """Test MoE model for unit tests. Weights are initialized on meta device""" + + def __init__(self, num_layers=2, num_experts=4): + try: + from accelerate import init_empty_weights + except ImportError: + pytest.skip("Skipping weight init requires accelerate") + + super().__init__() + with init_empty_weights(): + self.layers = nn.ModuleList( + [ + nn.ModuleDict( + { + "post_attention_layernorm": nn.LayerNorm(30), + "mlp": nn.ModuleDict( + { + "experts": nn.ModuleList( + [ + nn.ModuleDict( + { + "gate_proj": nn.Linear(30, 60), + "up_proj": nn.Linear(30, 60), + "down_proj": nn.Linear(60, 30), + } + ) + for _ in range(num_experts) + ] + ), + } + ), + } + ) + for _ in range(num_layers) + ] + ) + + class TestMatchName: """Test cases for _match_name function""" @@ -487,6 +526,70 @@ def test_simple_module_set(self): assert len(module_set) == 3 assert all(isinstance(*m, nn.Linear) for m in module_set) + def test_moe_module_match(self): + """Test matching MoE modules with multiple experts per layer""" + model = DummyMoEModel(num_layers=2, num_experts=4) + + # Test matching expert projections - each expert becomes its own set + # because the parent context differs between experts + targets = [ + "re:.*gate_proj$", + "re:.*up_proj$", + ] + + matches = list(match_modules_set(model, targets)) + + # Should have 8 sets (2 layers * 4 experts) + assert len(matches) == 8 + + # Each set should have 2 target lists (gate_proj, up_proj) + for expert_group in matches: + assert len(expert_group) == 2 + gate_modules, up_modules = expert_group + + # Each target should have matched 1 module (single expert) + assert len(gate_modules) == 1 + assert len(up_modules) == 1 + + # All modules should be Linear layers + assert isinstance(gate_modules[0], nn.Linear) + assert isinstance(up_modules[0], nn.Linear) + + def test_moe_with_layernorm_match(self): + """ + Test matching MoE modules with 
their corresponding layer norms. + Including a layer-level module (layernorm) groups all experts in that layer together. + """ + model = DummyMoEModel(num_layers=2, num_experts=3) + + # Match layer norm with expert projections - the layernorm is at layer level, + # so it establishes a common parent context for all experts in that layer + targets = [ + "re:.*post_attention_layernorm$", + "re:.*gate_proj$", + "re:.*up_proj$", + ] + + matches = list(match_modules_set(model, targets)) + + # Should have 2 layer groups (one per layer) + assert len(matches) == 2 + + for layer_group in matches: + assert len(layer_group) == 3 + norm_modules, gate_modules, up_modules = layer_group + + # LayerNorm should have 1 module (single per layer) + assert len(norm_modules) == 1 + assert isinstance(norm_modules[0], nn.LayerNorm) + + # Each projection should have 3 experts (all experts in the layer) + assert len(gate_modules) == 3 + assert len(up_modules) == 3 + assert all(isinstance(m, nn.Linear) for m in gate_modules) + assert all(isinstance(m, nn.Linear) for m in up_modules) + + def test_module_set_ordering(self): """Test that module sets maintain target ordering""" model = DummyModel() From a09a4f733927af8b559d9d167b5e672a789fa26c Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 05:05:08 +0000 Subject: [PATCH 07/14] formatting and fixes Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 46 +++++++++---------- tests/test_utils/test_match.py | 64 ++++++++++++--------------- 2 files changed, 50 insertions(+), 60 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 888c46bb6..b811f47b7 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -209,7 +209,7 @@ def match_modules_set( E.g. 
the following targets would yield modules belonging to the following layers: ```python3 - match_modules_set(model, ["re:*q_proj", "re:*k_proj", "re:*v_proj"]) == ( + match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == ( [ [`layers.0.self_attn.q_proj`], [`layers.0.self_attn.k_proj`], @@ -235,29 +235,36 @@ def match_modules_set( per target per group, E.g. ```python3 - match_modules_set(model, ["re:*up_proj", "down_proj"]) == ( + + targets = [ + "post_attention_layernorm", + "up_proj", + "down_proj" + ] + match_modules_set(model, targets) == ( [ + [layers.0.post_attention_layernorm], [ - `layers.0.mlp.experts.0.up`, + `layers.0.mlp.experts.0.up`, ... `layers.0.mlp.experts.127.up_proj` - ] - ... + ], [ - `layers.0.mlp.experts.0.down_proj`, + `layers.0.mlp.experts.0.down_proj`, ... `layers.0.mlp.experts.127.down_proj` ] ], # <- first yield ... [ + [layers.0.post_attention_layernorm], [ - `layers.32.mlp.experts.0.up_proj`, + `layers.32.mlp.experts.0.up_proj`, ... `layers.32.mlp.experts.127.up_proj` - ] + ], [ - `layers.32.mlp.experts.0.down_proj`, + `layers.32.mlp.experts.0.down_proj`, ... `layers.32.mlp.experts.127.down_proj` ] @@ -265,12 +272,6 @@ def match_modules_set( ) ``` - Note: if you only have one target i.e. match_modules_set(model, ["re:*up_proj") - it will yield one expert at a time rather than grouping experts by layer - This occurs because each single match fills all targets and the next expert - will change the parent context. Thus for single targets include another layer - to stabilize the parent context i.e. 
match_modules_set(model, ["re:*up_proj", "re:*experts") - :param model: model containing modules to match against :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes @@ -278,10 +279,6 @@ def match_modules_set( targets = targets or [] ignore = ignore or [] - # Early return for empty targets - if not targets: - return - # as we iterate through modules and try to match them with targets, # the algorithm can be in 2 possible states: # 0) unmatched_targets > 0, i.e. some of the targets haven't been matched. @@ -325,14 +322,13 @@ def match_modules_set( unmatched_targets -= {target} # target has now been matched (this does no-op if not in set) - # code for (C) - if not unmatched_targets: - yield [matches[target] for target in targets] + # never found anything + if unmatched_targets == set(targets): return - # If no matches were found at all (e.g., all modules are internal), - # just return without yielding or raising an error - if unmatched_targets == set(targets): + # code for (C) + if not unmatched_targets: # have a full matching + yield [matches[target] for target in targets] return raise ValueError( diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 150434855..0c4af3dac 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -84,39 +84,33 @@ class DummyMoEModel(nn.Module): """Test MoE model for unit tests. 
Weights are initialized on meta device""" def __init__(self, num_layers=2, num_experts=4): - try: - from accelerate import init_empty_weights - except ImportError: - pytest.skip("Skipping weight init requires accelerate") - super().__init__() - with init_empty_weights(): - self.layers = nn.ModuleList( - [ - nn.ModuleDict( - { - "post_attention_layernorm": nn.LayerNorm(30), - "mlp": nn.ModuleDict( - { - "experts": nn.ModuleList( - [ - nn.ModuleDict( - { - "gate_proj": nn.Linear(30, 60), - "up_proj": nn.Linear(30, 60), - "down_proj": nn.Linear(60, 30), - } - ) - for _ in range(num_experts) - ] - ), - } - ), - } - ) - for _ in range(num_layers) - ] - ) + self.layers = nn.ModuleList( + [ + nn.ModuleDict( + { + "post_attention_layernorm": nn.LayerNorm(3), + "mlp": nn.ModuleDict( + { + "experts": nn.ModuleList( + [ + nn.ModuleDict( + { + "gate_proj": nn.Linear(3, 6), + "up_proj": nn.Linear(3, 6), + "down_proj": nn.Linear(6, 3), + } + ) + for _ in range(num_experts) + ] + ), + } + ), + } + ) + for _ in range(num_layers) + ] + ) class TestMatchName: @@ -558,7 +552,8 @@ def test_moe_module_match(self): def test_moe_with_layernorm_match(self): """ Test matching MoE modules with their corresponding layer norms. - Including a layer-level module (layernorm) groups all experts in that layer together. + Including a layer-level module (layernorm) groups all experts in + that layer together. """ model = DummyMoEModel(num_layers=2, num_experts=3) @@ -589,7 +584,6 @@ def test_moe_with_layernorm_match(self): assert all(isinstance(m, nn.Linear) for m in gate_modules) assert all(isinstance(m, nn.Linear) for m in up_modules) - def test_module_set_ordering(self): """Test that module sets maintain target ordering""" model = DummyModel() @@ -623,7 +617,7 @@ def test_empty_targets_set(self): model = DummyModel() matches = list(match_modules_set(model, [])) # Should yield one empty set for each module traversed? 
- # with empty targets, we expect no matches + # Actually, with empty targets, we expect no matches assert len(matches) == 0 def test_module_set_with_ignore(self): From 345f353f2e3c57a991d93e670a216ab4a2cc4d94 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 05:11:48 +0000 Subject: [PATCH 08/14] formatting Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index b811f47b7..ad800a144 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -215,12 +215,12 @@ def match_modules_set( [`layers.0.self_attn.k_proj`], [`layers.0.self_attn.v_proj`], ], - ... [ - `layers.32.self_attn.q_proj`, - `layers.32.self_attn.k_proj`, - `layers.32.self_attn.v_proj`, + `layers.1.self_attn.q_proj`, + `layers.1.self_attn.k_proj`, + `layers.1.self_attn.v_proj`, ], + ... ) ``` @@ -245,30 +245,31 @@ def match_modules_set( [ [layers.0.post_attention_layernorm], [ - `layers.0.mlp.experts.0.up`, + `layers.0.mlp.experts.0.up_proj`, + `layers.0.mlp.experts.1.up_proj`, ... - `layers.0.mlp.experts.127.up_proj` ], [ `layers.0.mlp.experts.0.down_proj`, + `layers.0.mlp.experts.1.down_proj`, ... - `layers.0.mlp.experts.127.down_proj` + ] ], # <- first yield - ... [ - [layers.0.post_attention_layernorm], + [layers.1.post_attention_layernorm], [ - `layers.32.mlp.experts.0.up_proj`, + `layers.1.mlp.experts.0.up_proj`, + `layers.1.mlp.experts.1.up_proj`, ... - `layers.32.mlp.experts.127.up_proj` ], [ - `layers.32.mlp.experts.0.down_proj`, + `layers.1.mlp.experts.0.down_proj`, + `layers.1.mlp.experts.1.down_proj`, ... - `layers.32.mlp.experts.127.down_proj` ] ], + ... 
) ``` From 5a878f4b44c175d8366944f27e52110168566c19 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 14:52:56 +0000 Subject: [PATCH 09/14] formatting the formatting of format Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index ad800a144..888518c65 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -237,8 +237,8 @@ def match_modules_set( ```python3 targets = [ - "post_attention_layernorm", - "up_proj", + "post_attention_layernorm", + "up_proj", "down_proj" ] match_modules_set(model, targets) == ( @@ -324,11 +324,11 @@ def match_modules_set( # target has now been matched (this does no-op if not in set) # never found anything - if unmatched_targets == set(targets): + if unmatched_targets == set(targets): return # code for (C) - if not unmatched_targets: # have a full matching + if not unmatched_targets: # have a full matching yield [matches[target] for target in targets] return From ec187be072d2efbc77f0e23d5169aea84288cfb7 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 14:58:16 +0000 Subject: [PATCH 10/14] making it look nice Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 888518c65..af0759dc5 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -162,7 +162,7 @@ def match_targets( def get_lowest_common_module_name(names: list[str | None]) -> str: """ - Given a list of names, returns the lowest-scope common name ignoring None's. + Given a list of names, returns the lowest-scope common name ignoring Nones. 
Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix @@ -334,7 +334,7 @@ def match_modules_set( raise ValueError( f"Found a final incomplete set with matches found for keys: " - f"{set(targets)-unmatched_targets}" + f"{set(targets) - unmatched_targets} " f"but no matches found for keys: {unmatched_targets}" ) From b2e977d153c08f8640adab479ba2399be43eb894 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 15:49:06 +0000 Subject: [PATCH 11/14] improve name to lowest_common_ancestor Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 6 +++--- tests/test_utils/test_match.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index af0759dc5..a869d9e6d 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -31,7 +31,7 @@ "match_named_parameters", "match_targets", "match_modules_set", - "get_lowest_common_module_name", + "get_lowest_common_ancestor_name", "is_match", "is_narrow_match", ] @@ -160,7 +160,7 @@ def match_targets( return matched_targets -def get_lowest_common_module_name(names: list[str | None]) -> str: +def get_lowest_common_ancestor_name(names: list[str | None]) -> str: """ Given a list of names, returns the lowest-scope common name ignoring Nones. 
@@ -307,7 +307,7 @@ def match_modules_set( for name, module in model.named_modules(): for target in targets: if is_match(name, module, target, ignore): - new_parent_context = get_lowest_common_module_name( + new_parent_context = get_lowest_common_ancestor_name( [name, parent_context] ) diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 0c4af3dac..86bf639be 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -20,7 +20,7 @@ # Assuming the module is named "module_matching" - adjust import as needed from compressed_tensors.utils import ( InternalModule, - get_lowest_common_module_name, + get_lowest_common_ancestor_name, is_match, is_narrow_match, match_modules_set, @@ -447,10 +447,10 @@ class InternalLinear(InternalModule, nn.Linear): class TestGetLowestCommonModuleName: - """Test cases for get_lowest_common_module_name function""" + """Test cases for get_lowest_common_ancestor_name function""" def test_multiple_modules(self): - assert "abc" == get_lowest_common_module_name( + assert "abc" == get_lowest_common_ancestor_name( [ "abc.a", "abc.b", @@ -459,14 +459,14 @@ def test_multiple_modules(self): ) def test_single_module(self): - assert "abc.abc" == get_lowest_common_module_name( + assert "abc.abc" == get_lowest_common_ancestor_name( [ "abc.abc", ] ) def test_substring_modules(self): - assert "abc" == get_lowest_common_module_name( + assert "abc" == get_lowest_common_ancestor_name( [ "abc.abc", "abc.ab", @@ -474,7 +474,7 @@ def test_substring_modules(self): ) def test_parent_and_child_modules(self): - assert "abc.abc" == get_lowest_common_module_name( + assert "abc.abc" == get_lowest_common_ancestor_name( [ "abc.abc.ab", "abc.abc", @@ -482,7 +482,7 @@ def test_parent_and_child_modules(self): ) def test_root(self): - assert "" == get_lowest_common_module_name( + assert "" == get_lowest_common_ancestor_name( [ "abc.abc", "b.abc", @@ -490,7 +490,7 @@ def test_root(self): ) def test_ignore_none(self): - 
assert "abc.abc" == get_lowest_common_module_name( + assert "abc.abc" == get_lowest_common_ancestor_name( [ "abc.abc", None, From 5f8481be9042bdfe40270f0a82ef2f536e916281 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 16:55:14 +0000 Subject: [PATCH 12/14] check for multiple matches, formatting, List typehint Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index a869d9e6d..e259d34cf 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -190,7 +190,7 @@ def match_modules_set( model: torch.nn.Module, targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, -) -> Generator[Iterable[Iterable[torch.nn.Module]]]: +) -> Generator[List[List[torch.nn.Module]]]: """ Yields modules grouped by parent context. @@ -216,9 +216,9 @@ def match_modules_set( [`layers.0.self_attn.v_proj`], ], [ - `layers.1.self_attn.q_proj`, - `layers.1.self_attn.k_proj`, - `layers.1.self_attn.v_proj`, + [`layers.1.self_attn.q_proj`], + [`layers.1.self_attn.k_proj`], + [`layers.1.self_attn.v_proj`], ], ... ) @@ -305,6 +305,7 @@ def match_modules_set( unmatched_targets = set(targets) for name, module in model.named_modules(): + matched_targets_for_cur_module = {} for target in targets: if is_match(name, module, target, ignore): new_parent_context = get_lowest_common_ancestor_name( @@ -321,7 +322,14 @@ def match_modules_set( matches[target].append(module) parent_context = new_parent_context unmatched_targets -= {target} - # target has now been matched (this does no-op if not in set) + matched_targets_for_cur_module += {target} + + if len(matched_targets_for_cur_module) > 1: + _LOGGER.warning( + f"found multiple matching targets for module: {name} which matched to " + f"targets: {matched_targets_for_cur_module}. 
" + " this can result in unexpected behavior if not intended" + ) # never found anything if unmatched_targets == set(targets): From b8ed7453936e3155bb5e38d63bb1458a33322102 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 18:13:24 +0000 Subject: [PATCH 13/14] error instead of warn Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index e259d34cf..c8b25ff10 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -190,6 +190,7 @@ def match_modules_set( model: torch.nn.Module, targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, + error_on_module_rematch: bool = True, ) -> Generator[List[List[torch.nn.Module]]]: """ Yields modules grouped by parent context. @@ -276,6 +277,8 @@ def match_modules_set( :param model: model containing modules to match against :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes + :param error_on_module_rematch: if True, errors when a module gets + matched to multiple targets, if False, no error. (Defaults to True) """ targets = targets or [] ignore = ignore or [] @@ -324,11 +327,11 @@ def match_modules_set( unmatched_targets -= {target} matched_targets_for_cur_module += {target} - if len(matched_targets_for_cur_module) > 1: - _LOGGER.warning( - f"found multiple matching targets for module: {name} which matched to " - f"targets: {matched_targets_for_cur_module}. 
" - " this can result in unexpected behavior if not intended" + if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch: + raise ValueError( + f"module: {name} was matched with multiple targets: " + f"{matched_targets_for_cur_module} which is unexpected " + "disable this check by setting `error_on_module_rematch = False`" ) # never found anything From 308585a4fbcceb1fba4395341946e48513039cd1 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 18:23:36 +0000 Subject: [PATCH 14/14] fix Summary Signed-off-by: HDCharles --- src/compressed_tensors/utils/match.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index c8b25ff10..72f61945e 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -308,7 +308,7 @@ def match_modules_set( unmatched_targets = set(targets) for name, module in model.named_modules(): - matched_targets_for_cur_module = {} + matched_targets_for_cur_module = set() for target in targets: if is_match(name, module, target, ignore): new_parent_context = get_lowest_common_ancestor_name( @@ -325,7 +325,7 @@ def match_modules_set( matches[target].append(module) parent_context = new_parent_context unmatched_targets -= {target} - matched_targets_for_cur_module += {target} + matched_targets_for_cur_module |= {target} if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch: raise ValueError(