import re
from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Union

from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig, TConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM


class BaseWeightMapper(ABC):
    """Maps weights from a source state dict (e.g., Hugging Face) onto the
    TRT-LLM model's modules."""

    def __init__(self):
        self._callbacks: list[Callable] = []
        self._mapping: dict = {}
        self._skip_modules: List[str] = []
        self._model: Optional[Union[nn.Module, DecoderModelForCausalLM]] = None
        self._config: Optional[TConfig] = None

    def init_model_and_config(self, model: Union[nn.Module,
                                                 DecoderModelForCausalLM],
                              config: TConfig):
        # Validate the model before mutating any state.
        if not hasattr(model, 'model_config') or not isinstance(
                model.model_config, ModelConfig):
            raise ValueError("model must have a model_config attribute")
        if not hasattr(model, 'config'):
            raise ValueError("model must have a config attribute")

        self._model = model
        self._config = config

        # With attention DP enabled, weights are not TP-sharded, so treat the
        # TP size as 1.
        mapping = model.model_config.mapping
        self._tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
        # Fall back to num_attention_heads (MHA) when the config defines no
        # separate KV-head count.
        num_kv_heads = getattr(model.config, 'num_key_value_heads', None)
        self._num_kv_heads = (num_kv_heads if num_kv_heads is not None
                              else model.config.num_attention_heads)

        self.map_weights()

    def cleanup(self) -> None:
        self._model = None
        self._config = None

    @abstractmethod
    def map_weights(self) -> None:
        """
        Maps TRT-LLM module names to the corresponding entries of a source
        state dict (e.g., Hugging Face).
        """
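        # For example, a concrete implementation might populate the mapping
        # like this (illustrative module names, not a fixed contract):
        #   self._mapping = {'qkv_proj': ['q_proj', 'k_proj', 'v_proj']}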

    @abstractmethod
    def apply_callbacks(self, module: nn.Module, module_name: str,
                        module_names_breakdown: list[str],
                        weights: dict) -> list[dict]:
        """
        Applies a series of transformation functions to an internal
        representation of the weights or to guide the mapping process. The
        exact behavior depends on the implementation (e.g., callbacks may be
        stored and applied later).

        Args:
            module: The module to apply the callbacks to
            module_name: The specific module name (e.g., 'qkv_proj', 'gate_up_proj')
            module_names_breakdown: List of module path components for building full paths
            weights: The weights dictionary to process

        Returns:
            A list of processed weight dicts.
        """

    def rename_by_params_map(self, params_map: dict[str, str],
                             weights: dict) -> dict:
        """
        Rename weight keys according to regex pattern matching.

        Args:
            params_map: A dictionary mapping regex patterns to replacement
                strings; each key is an HF name pattern and its value is the
                corresponding TRT-LLM name pattern. The patterns are matched
                against keys in the weights dict and replaced according to the
                replacement string, which can use regex backreferences.
                Example:
                    HF name: vision_model.encoder.layers.1.self_attn.out_proj.{weight,bias}
                    TRT-LLM name: vision_model.encoder.layers.1.self_attn.o_proj.{weight,bias}
                    Then params_map could be:
                        params_map = {
                            r'(.*?)out_proj(.*)': r'\1o_proj\2'
                        }
            weights: A dictionary of weights

        Returns:
            A dictionary of weights with renamed keys
        """
        renamed_weights = {}

        for key, value in weights.items():
            # Apply the first matching pattern; remaining patterns are skipped.
            for pattern, replacement in params_map.items():
                if re.match(pattern, key):
                    renamed_weights[re.sub(pattern, replacement, key)] = value
                    break
            else:
                # No pattern matched; keep the key as is.
                renamed_weights[key] = value

        return renamed_weights

    def preprocess_weights(self, weights: dict) -> dict:
        """
        Preprocess weights before starting the loading process. The base
        implementation is a no-op that returns the weights unchanged.
        """
        return weights

    def handle_manual_copy(self, module_name: str, module_weights: dict, n: str,
                           p: nn.Parameter) -> None:
        p.data.copy_(module_weights[n][:])

    def does_require_special_handling(self, module_name: str) -> bool:
        return module_name in self.mapping

    def is_special_instance_module(self, module: nn.Module) -> bool:
        """Returns True if the module needs custom loading via
        handle_special_instance_module; the base implementation flags none."""
        return False

    def handle_special_instance_module(self, module: nn.Module,
                                       module_name: str,
                                       module_weights: dict) -> None:
        raise NotImplementedError()

    @property
    def skip_modules(self) -> List[str]:
        return self._skip_modules

    def add_skip_modules(self, value: List[str]) -> None:
        self._skip_modules.extend(value)

    def should_skip_module(self, module_name: str) -> bool:
        return any(skip_module in module_name
                   for skip_module in self._skip_modules)

    def filter_weights(self, prefix: str, weights: dict) -> dict:
        result = {}
        for k, v in weights.items():
            if k.startswith(prefix):
                # Strip the prefix plus the '.' separator that follows it.
                new_k = k[len(prefix) + 1:]
                result[new_k] = v
        return result
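
    # Example (hypothetical keys, for illustration):
    #   filter_weights('model.layers.0.mlp',
    #                  {'model.layers.0.mlp.gate_proj.weight': w})
    #   -> {'gate_proj.weight': w}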

    @property
    def mapping(self) -> dict:
        return self._mapping

    @property
    def config(self) -> TConfig:
        if self._config is None:
            raise RuntimeError("Weight mapper is not initialized")
        return self._config

    @property
    def model(self) -> Union[nn.Module, DecoderModelForCausalLM]:
        if self._model is None:
            raise RuntimeError("Weight mapper is not initialized")
        return self._model
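

# A minimal sketch of a concrete mapper, assuming a loader that calls
# init_model_and_config() and then consults mapping / apply_callbacks() to
# fuse per-submodule source weights. The class, the fused module names, and
# the HF sub-module names below are illustrative assumptions, not part of the
# actual TRT-LLM implementation.
class _ExampleWeightMapper(BaseWeightMapper):

    def map_weights(self) -> None:
        # TRT-LLM fused module name -> source (HF) sub-module names.
        self._mapping = {
            'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
            'gate_up_proj': ['gate_proj', 'up_proj'],
        }

    def apply_callbacks(self, module: nn.Module, module_name: str,
                        module_names_breakdown: list[str],
                        weights: dict) -> list[dict]:
        # Collect the source weights for every sub-module that feeds the fused
        # TRT-LLM module, e.g. q/k/v projections for 'qkv_proj'.
        collected = []
        for source_name in self._mapping[module_name]:
            prefix = '.'.join(module_names_breakdown + [source_name])
            collected.append(self.filter_weights(prefix, weights))
        return collected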