def get_lowest_common_ancestor_name(names: Iterable[Optional[str]]) -> str:
    """
    Given a collection of dotted module names, return the name of their
    lowest (deepest) common ancestor, ignoring `None` entries.

    The implementation is a small alteration of `os.path.commonprefix`
    https://docs.python.org/3/library/os.path.html#os.path.commonprefix

    Every name is wrapped in dots before comparison so that a name fragment
    only counts as shared when it is shared in full, i.e. up to a dot:

    ([s1, s2] -> prefix -> result)
    # case 0: multiple modules:  [.abc.a., .abc.b.] -> .abc. -> abc
    # case 1: single module:     [.abc.]            -> .abc. -> abc
    # case 2: substring modules: [.abc., .ab.]      -> .ab   -> ""
    # case 3: parent & child:    [.ab., .ab.a.]     -> .ab.  -> ab

    :param names: dotted names; `None` entries are skipped
    :return: the common ancestor name, or "" if the names share no ancestor
        or if there are no non-`None` names
    """
    # imported at function scope so the helper is self-contained
    import os

    present = [name for name in names if name is not None]
    if not present:
        return ""

    # 1) find the longest shared prefix. The lexicographic min and max are the
    #    two most dissimilar names, so their common prefix is common to all.
    s1 = "." + min(present) + "."
    s2 = "." + max(present) + "."
    common_prefix = os.path.commonprefix([s1, s2])

    # 2) throw away the right-most dot and any partial name fragment after it,
    #    and throw away the left-most char (the dot that was prepended)
    #    ".keep.thro" -> "keep", "." -> ""
    return common_prefix[1 : common_prefix.rfind(".")]


def match_modules_set(
    model: torch.nn.Module,
    targets: Optional[Iterable[str]],
    ignore: Optional[Iterable[str]] = None,
    error_on_module_rematch: bool = True,
) -> Generator[List[List[torch.nn.Module]], None, None]:
    """
    Yields modules grouped by parent context.

    We group by parent context so that we can return ALL matches of a
    specific target that can be paired with another target. This is most
    relevant in the case of MoE modules with multiple modules for each
    expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj,
    mlp.expert.N.up_proj for all N. The parent context will differ from
    one layer to another while being the same for one expert to another.

    Each yielded group is a list (of lists) with the same size and order as
    `targets`, while the matches within each target's list, and the groups
    themselves, are ordered in the same way as `model.named_modules()`.

    E.g. the following targets would yield modules belonging to the following
    layers:
    ```python3
    match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
        [
            [`layers.0.self_attn.q_proj`],
            [`layers.0.self_attn.k_proj`],
            [`layers.0.self_attn.v_proj`],
        ],
        [
            [`layers.1.self_attn.q_proj`],
            [`layers.1.self_attn.k_proj`],
            [`layers.1.self_attn.v_proj`],
        ],
        ...
    )
    ```

    This can be used to match layers to their corresponding downstream
    counterparts. For example, matching layer norms to their subsequent
    linear layers
    ```python3
    for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
        fuse_norm_linears(*norm, [*q, *k, *v])
    ```

    Alternatively for MoE you would get multiple matches per target per
    group, E.g.

    ```python3
    targets = [
        "post_attention_layernorm",
        "up_proj",
        "down_proj"
    ]
    match_modules_set(model, targets) == (
        [
            [`layers.0.post_attention_layernorm`],
            [
                `layers.0.mlp.experts.0.up_proj`,
                `layers.0.mlp.experts.1.up_proj`,
                ...
            ],
            [
                `layers.0.mlp.experts.0.down_proj`,
                `layers.0.mlp.experts.1.down_proj`,
                ...
            ]
        ],  # <- first yield
        [
            [`layers.1.post_attention_layernorm`],
            [
                `layers.1.mlp.experts.0.up_proj`,
                `layers.1.mlp.experts.1.up_proj`,
                ...
            ],
            [
                `layers.1.mlp.experts.0.down_proj`,
                `layers.1.mlp.experts.1.down_proj`,
                ...
            ]
        ],
        ...
    )
    ```

    :param model: model containing modules to match against
    :param targets: target strings, potentially containing "re:" prefixes.
        May be any iterable, including a one-shot iterable such as a generator
    :param ignore: targets to ignore, potentially containing "re:" prefixes
    :param error_on_module_rematch: if True, errors when a module gets
        matched to multiple targets, if False, no error. (Defaults to True)
    :raises ValueError: if a single module matches more than one target while
        `error_on_module_rematch` is True, or if the modules run out while the
        final set has matches for some targets but not for all of them
    """
    # imported at function scope so the block is self-contained
    from collections import defaultdict

    # `targets` is walked once per module and once more for every yield, and
    # `ignore` is handed to `is_match` on every call. Materialize both up front
    # so a one-shot iterable (e.g. a generator) is not exhausted by the first
    # pass, which would make every later module silently match nothing.
    targets = list(targets or [])
    ignore = ignore or []
    if not isinstance(ignore, str):
        # a bare string is passed through untouched; how `is_match` treats it
        # is decided there, not here
        ignore = list(ignore)

    all_targets = set(targets)

    # as we iterate through modules and try to match them with targets,
    # the algorithm can be in 2 possible states:
    # 0) unmatched_targets > 0, i.e. some of the targets haven't been matched.
    #    Keep matching until all targets have at least one match
    # 1) unmatched_targets == 0 i.e. we have at least one match for each target.
    #    At this point we are unsure if we have a full set or if we need to add
    #    more matches.
    # There are 3 things that can happen once we're in state 1:
    # A) found a new match with same parent_context,
    #    (add it to matches and keep going)
    # B) found a new match with different parent_context, i.e. we found a match
    #    that requires a deeper parent context, this indicates that this match
    #    should be part of a new set.
    #    (yield current set [not including newest match] and go back to state 0)
    # C) ran out of modules, we will always yield the final remaining set when
    #    we've iterated through all the modules in the model.
    #    (yield final set then exit.)
    # Note: it's possible to iterate through all the modules in the model while
    # not having a full matched set if the user specified a bad matching, in
    # that case something has gone wrong and we error
    #
    # NOTE: matches are keyed by target string, so a target listed twice shares
    # one match list and receives each matching module once per occurrence.
    matches = defaultdict(list)
    parent_context = None
    unmatched_targets = set(all_targets)

    for name, module in model.named_modules():
        matched_targets_for_cur_module = set()
        for target in targets:
            if not is_match(name, module, target, ignore):
                continue

            # the context this match would need in order to join the current set
            new_parent_context = get_lowest_common_ancestor_name(
                [name, parent_context]
            )

            # code for (B)
            if not unmatched_targets and new_parent_context != parent_context:
                yield [matches[tgt] for tgt in targets]
                matches = defaultdict(list)
                new_parent_context = name
                unmatched_targets = set(all_targets)

            matches[target].append(module)
            parent_context = new_parent_context
            unmatched_targets.discard(target)
            matched_targets_for_cur_module.add(target)

        if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch:
            raise ValueError(
                f"module: {name} was matched with multiple targets: "
                f"{matched_targets_for_cur_module} which is unexpected "
                "disable this check by setting `error_on_module_rematch = False`"
            )

    # never found anything (also covers an empty `targets`)
    if unmatched_targets == all_targets:
        return

    # code for (C)
    if not unmatched_targets:  # have a full matching
        yield [matches[tgt] for tgt in targets]
        return

    raise ValueError(
        f"Found a final incomplete set with matches found for keys: "
        f"{all_targets - unmatched_targets} "
        f"but no matches found for keys: {unmatched_targets}"
    )
def test_substring_modules(self): + assert "abc" == get_lowest_common_ancestor_name( + [ + "abc.abc", + "abc.ab", + ] + ) + + def test_parent_and_child_modules(self): + assert "abc.abc" == get_lowest_common_ancestor_name( + [ + "abc.abc.ab", + "abc.abc", + ] + ) + + def test_root(self): + assert "" == get_lowest_common_ancestor_name( + [ + "abc.abc", + "b.abc", + ] + ) + + def test_ignore_none(self): + assert "abc.abc" == get_lowest_common_ancestor_name( + [ + "abc.abc", + None, + ] + ) + + class TestMatchModulesSet: """Test cases for match_modules_set function""" @@ -432,7 +518,71 @@ def test_simple_module_set(self): # Each set should have 3 modules for module_set in matches: assert len(module_set) == 3 - assert all(isinstance(m, nn.Linear) for m in module_set) + assert all(isinstance(*m, nn.Linear) for m in module_set) + + def test_moe_module_match(self): + """Test matching MoE modules with multiple experts per layer""" + model = DummyMoEModel(num_layers=2, num_experts=4) + + # Test matching expert projections - each expert becomes its own set + # because the parent context differs between experts + targets = [ + "re:.*gate_proj$", + "re:.*up_proj$", + ] + + matches = list(match_modules_set(model, targets)) + + # Should have 8 sets (2 layers * 4 experts) + assert len(matches) == 8 + + # Each set should have 2 target lists (gate_proj, up_proj) + for expert_group in matches: + assert len(expert_group) == 2 + gate_modules, up_modules = expert_group + + # Each target should have matched 1 module (single expert) + assert len(gate_modules) == 1 + assert len(up_modules) == 1 + + # All modules should be Linear layers + assert isinstance(gate_modules[0], nn.Linear) + assert isinstance(up_modules[0], nn.Linear) + + def test_moe_with_layernorm_match(self): + """ + Test matching MoE modules with their corresponding layer norms. + Including a layer-level module (layernorm) groups all experts in + that layer together. 
+ """ + model = DummyMoEModel(num_layers=2, num_experts=3) + + # Match layer norm with expert projections - the layernorm is at layer level, + # so it establishes a common parent context for all experts in that layer + targets = [ + "re:.*post_attention_layernorm$", + "re:.*gate_proj$", + "re:.*up_proj$", + ] + + matches = list(match_modules_set(model, targets)) + + # Should have 2 layer groups (one per layer) + assert len(matches) == 2 + + for layer_group in matches: + assert len(layer_group) == 3 + norm_modules, gate_modules, up_modules = layer_group + + # LayerNorm should have 1 module (single per layer) + assert len(norm_modules) == 1 + assert isinstance(norm_modules[0], nn.LayerNorm) + + # Each projection should have 3 experts (all experts in the layer) + assert len(gate_modules) == 3 + assert len(up_modules) == 3 + assert all(isinstance(m, nn.Linear) for m in gate_modules) + assert all(isinstance(m, nn.Linear) for m in up_modules) def test_module_set_ordering(self): """Test that module sets maintain target ordering""" @@ -448,6 +598,7 @@ def test_module_set_ordering(self): for module_set in matches: # Check that modules are returned in target order (v, q, k) v_proj, q_proj, k_proj = module_set + v_proj, q_proj, k_proj = *v_proj, *q_proj, *k_proj # We can't easily check the exact modules, but can check they're all Linear assert all(isinstance(m, nn.Linear) for m in [v_proj, q_proj, k_proj]) @@ -456,18 +607,8 @@ def test_incomplete_set_error(self): model = DummyModel() targets = ["layer1", "nonexistent_module"] - with pytest.raises(ValueError, match="Unable to match targets into set"): - list(match_modules_set(model, targets)) - - def test_duplicate_match_error(self): - """Test error when same target matches multiple times before set completion""" - model = DummyModel() - # This should cause the same target to match multiple times - # before we can complete a set - targets = ["Linear", "Linear"] # Two identical targets - with pytest.raises( - ValueError, 
match="Matched a .* twice before completing set" + ValueError, match="Found a final incomplete set with matches found for keys" ): list(match_modules_set(model, targets))