Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 164 additions & 35 deletions src/compressed_tensors/utils/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import logging
import os
import re
from collections import defaultdict
from collections.abc import Generator
from typing import Iterable, List, Mapping, Optional, Tuple, Union

Expand All @@ -29,6 +31,7 @@
"match_named_parameters",
"match_targets",
"match_modules_set",
"get_lowest_common_ancestor_name",
"is_match",
"is_narrow_match",
]
Expand Down Expand Up @@ -157,68 +160,194 @@ def match_targets(
return matched_targets


def get_lowest_common_ancestor_name(names: list[str | None]) -> str:
"""
Given a list of names, returns the lowest-scope common name ignoring Nones.

Implementation is a small alteration of os.path.commonprefix
https://docs.python.org/3/library/os.path.html#os.path.commonprefix

([s1, s2]->prefix->result)
# case 0: multiple modules: [abc.a., abc.b.] -> .abc. -> abc
# case 1: single module: [abc.] -> .abc. -> abc
# case 2: substring modules: [abc., ab.] -> .ab -> ""
# case 3: parent & child: [ab., ab.a.] -> .ab. -> ab
"""
names = [name for name in names if name is not None]
if len(names) == 0:
return ""

# 1) find longest shared prefix
s1 = "." + min(names) + "."
s2 = "." + max(names) + "."
common_prefix = os.path.commonprefix([s1, s2])
# 2) throw away right most dot and name fragment, throw away leftmost char
# ".keep.thro" -> "keep", "." -> ""
return common_prefix[1 : common_prefix.rfind(".")]


def match_modules_set(
model: torch.nn.Module,
targets: Optional[Iterable[str]],
ignore: Optional[Iterable[str]] = None,
) -> Generator[Iterable[torch.nn.Module]]:
error_on_module_rematch: bool = True,
) -> Generator[List[List[torch.nn.Module]]]:
"""
Yields modules grouped with the same order and size as `targets`.
Values are returned in order of `model.named_modules()`
Yields modules grouped by parent context.

We group by parent context so that we can return ALL matches of a
specific target that can be paired with another target. This is most
relevant in the case of MoE modules with multiple modules for each
expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj,
mlp.expert.N.up_proj for all N. The parent context will differ from
one layer to another while being the same for one expert to another.

E.g. the following targets would yield module belonging to the following layers:
Each returned group is a list (of lists) with the same size
and order as `targets` while all matches for each target and
the overall order of the groups are ordered in the same way
as `model.named_modules`


E.g. the following targets would yield modules belonging to the following layers:
```python3
match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
(
`model.layers.0.self_attn.q_proj`,
`model.layers.0.self_attn.k_proj`,
`model.layers.0.self_attn.v_proj`,
),
(
`model.layers.1.self_attn.q_proj`,
`model.layers.1.self_attn.k_proj`,
`model.layers.1.self_attn.v_proj`,
),
[
[`layers.0.self_attn.q_proj`],
[`layers.0.self_attn.k_proj`],
[`layers.0.self_attn.v_proj`],
],
[
[`layers.1.self_attn.q_proj`],
[`layers.1.self_attn.k_proj`],
[`layers.1.self_attn.v_proj`],
],
...
(
`model.layers.32.self_attn.q_proj`,
`model.layers.32.self_attn.k_proj`,
`model.layers.32.self_attn.v_proj`,
),
)
```

This can be used to match layers to their corresponding downstream counterparts.
For example, matching layer norms to their subsequent linear layers
```python3
for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
fuse_norm_linears(norm, [q, k, v])
fuse_norm_linears(*norm, [*q, *k, *v])
Comment thread
HDCharles marked this conversation as resolved.
```

Alternatively for MoE you would get multiple matches
per target per group, E.g.

```python3

targets = [
"post_attention_layernorm",
"up_proj",
"down_proj"
]
match_modules_set(model, targets) == (
[
[layers.0.post_attention_layernorm],
[
`layers.0.mlp.experts.0.up_proj`,
`layers.0.mlp.experts.1.up_proj`,
...
],
[
`layers.0.mlp.experts.0.down_proj`,
`layers.0.mlp.experts.1.down_proj`,
...

]
], # <- first yield
[
[layers.1.post_attention_layernorm],
[
`layers.1.mlp.experts.0.up_proj`,
`layers.1.mlp.experts.1.up_proj`,
...
],
[
`layers.1.mlp.experts.0.down_proj`,
`layers.1.mlp.experts.1.down_proj`,
...
]
],
...
)
```

:param model: model containing modules to match against
:param targets: target strings, potentially containing "re:" prefixes
:param ignore: targets to ignore, potentially containing "re:" prefixes
:param error_on_module_rematch: if True, errors when a module gets
matched to multiple targets, if False, no error. (Defaults to True)
"""
targets = targets or []
ignore = ignore or []

matches = dict.fromkeys(targets, None)
# as we iterate through modules and try to match them with targets,
# the algorithm can be in 2 possible states:
# 0) unmatched_targets > 0, i.e. some of the targets haven't been matched.
# Keep matching until all targets have at least one match
# 1) unmatched_targets == 0 i.e. we have at least one match for each target.
# At this point we are unsure if we have a full set or if we need to add
# more matches.
# There are 3 things that can happen once were in state 1:
# A) found a new match with same parent_context,
# (add it to matches and keep going)
# B) found a new match with different parent_context, i.e. we found a match
# that requires a deeper parent context, this indicates that this match
# should be part of a new set.
# (yield current set [not including newest match] and go back to state 0)
# C) ran out of modules, we will always yield the final remaining set when
# we we've iterated through all the modules in the model.
# (yield final set then exit.)
# Note: its possible to iterate through all the modules in the model while
# not having a full matched set if the user specified a bad matching, in
# that case something has gone wrong and we error
matches = defaultdict(list)
parent_context = None
unmatched_targets = set(targets)

for name, module in model.named_modules():
# match until we get a full set
matched_targets_for_cur_module = set()
for target in targets:
if is_match(name, module, target, ignore):
if matches[target] is not None:
raise ValueError(f"Matched a {target} twice before completing set")
matches[target] = module

# once we have a full set, yield and reset
if targets and all((matches[target] is not None for target in targets)):
yield [matches[target] for target in targets] # ensure correct ordering
matches = dict.fromkeys(targets, None)

# check that none are left over
unmatched_keys = [match for match, value in matches.items() if value is not None]
if len(unmatched_keys):
raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
new_parent_context = get_lowest_common_ancestor_name(
[name, parent_context]
)

# code for (B)
if not unmatched_targets and new_parent_context != parent_context:
yield [matches[target] for target in targets]
matches = defaultdict(list)
new_parent_context = name
unmatched_targets = set(targets)

matches[target].append(module)
parent_context = new_parent_context
unmatched_targets -= {target}
matched_targets_for_cur_module |= {target}

if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch:
raise ValueError(
f"module: {name} was matched with multiple targets: "
f"{matched_targets_for_cur_module} which is unexpected "
"disable this check by setting `error_on_module_rematch = False`"
)

# never found anything
if unmatched_targets == set(targets):
return

# code for (C)
if not unmatched_targets: # have a full matching
yield [matches[target] for target in targets]
return

raise ValueError(
f"Found a final incomplete set with matches found for keys: "
f"{set(targets) - unmatched_targets} "
f"but no matches found for keys: {unmatched_targets}"
)
Comment thread
kylesayrs marked this conversation as resolved.


def is_match(
Expand Down
Loading