
Commit e0836f9

[TRTLLM-5493] Add core infrastructure to enable loading of custom checkpoint formats (#5372)
Signed-off-by: Shahar Mor <[email protected]>
1 parent 9354114 commit e0836f9

39 files changed (+1202 −441 lines)
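The new package introduces pluggable checkpoint-loading components: a checkpoint loader that bundles a config loader, a weight loader, and a weight mapper. As a minimal, hedged usage sketch (assuming the default Hugging Face loader added in this commit can be constructed without arguments; the concrete class may also accept custom component overrides):

from tensorrt_llm._torch.models.checkpoints import HfCheckpointLoader

# Illustrative sketch only: build the default HF checkpoint loader and use the
# BaseCheckpointLoader interface to read a checkpoint directory.
loader = HfCheckpointLoader()  # no-argument construction is an assumption
model_config = loader.load_config("/path/to/hf/checkpoint")
weights = loader.load_weights("/path/to/hf/checkpoint")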

tensorrt_llm/_torch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 from .llm import LLM
 from .model_config import MoeLoadBalancerConfig
+from .models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
 
-__all__ = ["LLM", "MoeLoadBalancerConfig"]
+__all__ = ["LLM", "MoeLoadBalancerConfig", "BaseCheckpointLoader"]
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from .base_checkpoint_loader import BaseCheckpointLoader
+from .hf.checkpoint_loader import HfCheckpointLoader
+from .hf.config_loader import HfConfigLoader
+from .hf.gemma3_weight_mapper import Gemma3HfWeightMapper
+from .hf.llama4_weight_mapper import Llama4HfWeightMapper
+from .hf.mixtral_weight_mapper import MixtralHfWeightMapper
+from .hf.nemotron_h_weight_mapper import NemotronHHfWeightMapper
+from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
+from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
+from .hf.weight_loader import HfWeightLoader
+from .hf.weight_mapper import HfWeightMapper
+
+__all__ = [
+    "HfConfigLoader", "HfWeightLoader", "HfWeightMapper",
+    "BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
+    "Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
+    "Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper"
+]
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from typing import Optional
+
+from tensorrt_llm._torch.models.modeling_utils import MODEL_CLASS_MAPPER_MAPPING
+
+
+class AutoCheckpointMapper():
+
+    @staticmethod
+    def get(format: str, name: Optional[str] = None) -> "BaseWeightMapper":
+        if name is not None:
+            try:
+                return MODEL_CLASS_MAPPER_MAPPING[f'{name}_{format}']()
+            except KeyError:  # no mapper for this model architecture, resort to default
+                # TODO smor- a potential bug here, if the class isn't added to __init__, it will return the default mapper
+                return MODEL_CLASS_MAPPER_MAPPING[format]()
+        else:
+            return MODEL_CLASS_MAPPER_MAPPING[format]()
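For illustration, a hedged sketch of the lookup behavior above: the registry is keyed by '{architecture}_{format}' for model-specific mappers, with the bare format key as the fallback. The concrete key strings below ("HF", "LlamaForCausalLM") are assumptions for the example, not taken from this diff.

from tensorrt_llm._torch.models.checkpoints.auto_mapper import AutoCheckpointMapper

# Architecture-specific mapper first; falls back to the format-level default
# mapper if no 'LlamaForCausalLM_HF' entry is registered.
mapper = AutoCheckpointMapper.get("HF", "LlamaForCausalLM")
default_mapper = AutoCheckpointMapper.get("HF")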
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from torch import nn
+
+from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm._torch.models.checkpoints.auto_mapper import \
+    AutoCheckpointMapper
+from tensorrt_llm._torch.models.checkpoints.base_config_loader import \
+    BaseConfigLoader
+from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
+    BaseWeightLoader
+from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
+    BaseWeightMapper
+from tensorrt_llm._torch.models.modeling_utils import \
+    CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING
+
+
+class BaseCheckpointLoader(ABC):
+
+    @abstractmethod
+    def get_default_weight_loader(self) -> BaseWeightLoader:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_default_config_loader(self) -> BaseConfigLoader:
+        raise NotImplementedError
+
+    @abstractmethod
+    def cleanup(self) -> None:
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def weight_loader(self) -> BaseWeightLoader:
+        ...
+
+    @property
+    @abstractmethod
+    def weight_mapper(self) -> BaseWeightMapper:
+        ...
+
+    @property
+    @abstractmethod
+    def config_loader(self) -> BaseConfigLoader:
+        ...
+
+    @property
+    @abstractmethod
+    def checkpoint_format(self) -> str:
+        ...
+
+    def load_config(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
+        return self.config_loader.load(checkpoint_dir, **kwargs)
+
+    def load_weights(self, checkpoint_dir: str, **kwargs) -> dict[str, Any]:
+        return self.weight_loader.load_weights(checkpoint_dir, **kwargs)
+
+    @classmethod
+    def get(cls, checkpoint_format: str, **kwargs) -> "BaseCheckpointLoader":
+        try:
+            return CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING[checkpoint_format](
+                **kwargs)
+        except KeyError:
+            raise ValueError(
+                f"Checkpoint loader for format {checkpoint_format} not found, "
+                f"available formats are: {CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING.keys()}"
+            )
+
+    def get_initilized_weight_mapper(self, model: nn.Module,
+                                     config: ModelConfig) -> BaseWeightMapper:
+        weight_mapper = None
+        if self.weight_mapper is not None:
+            self.weight_mapper.init_model_and_config(model, config)
+            return self.weight_mapper
+        else:
+            # The name of the registered mapper should be the model architecture
+            if config.pretrained_config and config.pretrained_config.architectures:
+                model_arch = config.pretrained_config.architectures[0]
+            else:
+                raise ValueError(
+                    "Cannot determine model architecture from config")
+            weight_mapper = AutoCheckpointMapper.get(self.checkpoint_format,
+                                                     model_arch)
+            weight_mapper.init_model_and_config(model, config)
+            self.weight_mapper = weight_mapper
+            return weight_mapper
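To show how the interface above is meant to be filled in, here is a minimal sketch of a custom checkpoint loader. MyConfigLoader, MyWeightLoader, and the "MY_FORMAT" key are hypothetical placeholders, not part of this commit; a real loader would also need to be registered in CHECKPOINT_LOADER_FORMAT_DEFAULT_MAPPING so that BaseCheckpointLoader.get() can find it.

class MyCheckpointLoader(BaseCheckpointLoader):
    """Hypothetical loader for a custom on-disk checkpoint format."""

    def __init__(self):
        self._weight_loader = self.get_default_weight_loader()
        self._config_loader = self.get_default_config_loader()
        self._weight_mapper = None  # resolved lazily via get_initilized_weight_mapper

    def get_default_weight_loader(self) -> BaseWeightLoader:
        return MyWeightLoader()  # hypothetical BaseWeightLoader subclass

    def get_default_config_loader(self) -> BaseConfigLoader:
        return MyConfigLoader()  # hypothetical BaseConfigLoader subclass

    def cleanup(self) -> None:
        self._weight_loader.cleanup()
        self._config_loader.cleanup()

    @property
    def weight_loader(self) -> BaseWeightLoader:
        return self._weight_loader

    @property
    def weight_mapper(self) -> BaseWeightMapper:
        return self._weight_mapper

    @weight_mapper.setter
    def weight_mapper(self, value: BaseWeightMapper) -> None:
        # The base class assigns to this property in get_initilized_weight_mapper,
        # so concrete loaders need a setter.
        self._weight_mapper = value

    @property
    def config_loader(self) -> BaseConfigLoader:
        return self._config_loader

    @property
    def checkpoint_format(self) -> str:
        return "MY_FORMAT"  # hypothetical format key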
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+
+from tensorrt_llm._torch.model_config import ModelConfig
+
+
+class BaseConfigLoader(ABC):
+
+    @abstractmethod
+    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
+        pass
+
+    def cleanup(self) -> None:
+        pass
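A hedged sketch of a concrete config loader for the interface above. It assumes the pretrained config can be read with transformers' AutoConfig and wrapped in ModelConfig via its pretrained_config field; the HF loaders added in this commit may construct the config differently.

from transformers import AutoConfig

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.base_config_loader import BaseConfigLoader


class MyConfigLoader(BaseConfigLoader):
    """Hypothetical config loader that wraps a Hugging Face config.json."""

    def load(self, checkpoint_dir: str, **kwargs) -> ModelConfig:
        # Assumption: ModelConfig accepts the parsed HF config as pretrained_config.
        pretrained_config = AutoConfig.from_pretrained(checkpoint_dir, **kwargs)
        return ModelConfig(pretrained_config=pretrained_config)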
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class BaseWeightLoader(ABC):
+
+    @abstractmethod
+    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
+        """
+        Loads weights from a checkpoint directory.
+
+        Args:
+            checkpoint_dir: A path to the checkpoint directory.
+
+        Returns:
+            A dictionary where keys are tensor names and values are the tensors.
+        """
+
+    def cleanup(self) -> None:
+        pass
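As a complement to the contract documented above, a hedged sketch of a concrete weight loader that reads safetensors shards from the checkpoint directory. The class name is hypothetical; the HF weight loader added in this commit may shard, filter, or prefetch differently.

import glob
import os
from typing import Any

from safetensors.torch import load_file

from tensorrt_llm._torch.models.checkpoints.base_weight_loader import BaseWeightLoader


class MySafetensorsWeightLoader(BaseWeightLoader):
    """Hypothetical weight loader that reads every *.safetensors shard in a directory."""

    def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
        weights: dict[str, Any] = {}
        for shard in sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors"))):
            # Each shard maps tensor names to tensors; keys are expected to be
            # disjoint across shards in a well-formed checkpoint.
            weights.update(load_file(shard))
        return weights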
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+from abc import ABC, abstractmethod
+from typing import Callable, List, Union
+
+from torch import nn
+
+from tensorrt_llm._torch.model_config import ModelConfig, TConfig
+from tensorrt_llm._torch.models.modeling_utils import DecoderModelForCausalLM
+
+
+class BaseWeightMapper(ABC):
+
+    def __init__(self):
+        self._callbacks: list[Callable] = []
+        self._mapping: dict = {}
+        self._skip_modules = []
+        self._model: Union[nn.Module, DecoderModelForCausalLM] | None = None
+        self._config: TConfig | None = None
+
+    def init_model_and_config(self, model: Union[nn.Module,
+                                                  DecoderModelForCausalLM],
+                              config: TConfig):
+        self._model = model
+        self._config = config
+
+        if not hasattr(model, 'model_config') or not isinstance(
+                model.model_config, ModelConfig):
+            raise ValueError("model must have a model_config attribute")
+        if not hasattr(model, 'config'):
+            raise ValueError("model must have a config attribute")
+
+        self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
+        self._num_kv_heads = model.config.num_key_value_heads if hasattr(
+            model.config, 'num_key_value_heads'
+        ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads
+
+        self.map_weights()
+
+    def cleanup(self) -> None:
+        self._model = None
+        self._config = None
+
+    @abstractmethod
+    def map_weights(self) -> None:
+        """
+        Maps weights from TRT-LLM to a source state dictionary (e.g., Hugging Face)
+        """
+
+    @abstractmethod
+    def apply_callbacks(self, module: nn.Module, module_name: str,
+                        module_names_breakdown: list[str],
+                        weights: dict) -> list[dict]:
+        """
+        Applies a series of transformation functions to an internal representation
+        of weights or to guide the mapping process. The exact behavior might depend
+        on the implementation (e.g., storing callbacks to be applied later).
+
+        Args:
+            module: The module to apply the callbacks to
+            module_name: The specific module name (e.g., 'qkv_proj', 'gate_up_proj')
+            module_names_breakdown: List of module path components for building full paths
+            weights: The weights dictionary to process
+        """
+
+    def rename_by_params_map(self, params_map: dict[str, str],
+                             weights: dict) -> dict:
+        """
+        Rename weight keys according to regex pattern matching.
+
+        Args:
+            pattern_mapping: A dictionary mapping regex patterns to replacement strings. The key is HF name pattern, and the value is corresponding TRT-LLM name pattern.
+                The patterns will be used to match keys in the weights dict and replace
+                them according to the replacement string, which can use regex backreferences.
+                Example:
+                    HF name: vision_model.encoder.layers.1.self_attn.out_proj.{weight,bias}
+                    TRT-LLM name: vision_model.encoder.layers.1.self_attn.o_proj.{weight,bias}
+                    Then the pattern_mapping could be:
+                    pattern_mapping = {
+                        r'(.*?)out_proj(.*)': r'\1o_proj\2'
+                    }
+            weights: A dictionary of weights
+
+        Returns:
+            A dictionary of weights with renamed keys
+        """
+        import re
+
+        # Create a new dictionary to store the renamed weights
+        renamed_weights = {}
+
+        # Keep track of keys that have been matched by a pattern
+        matched_keys = set()
+
+        # Process each key in the weights dictionary
+        for key in list(weights.keys()):
+            # Check each pattern for a match
+            for pattern, replacement in params_map.items():
+                if re.match(pattern, key):
+                    # Create the new key by applying the regex replacement
+                    new_key = re.sub(pattern, replacement, key)
+                    # Store the weight with the new key
+                    renamed_weights[new_key] = weights[key]
+                    matched_keys.add(key)
+                    break
+
+            # If the key wasn't matched by any pattern, keep it as is
+            if key not in matched_keys:
+                renamed_weights[key] = weights[key]
+
+        return renamed_weights
+
+    def preprocess_weights(self, weights: dict) -> dict:
+        """
+        Preprocess weights before starting the loading process.
+        """
+        ...
+
+    def handle_manual_copy(self, module_name: str, module_weights: dict, n: str,
+                           p: nn.Parameter) -> None:
+        p.data.copy_(module_weights[n][:])
+
+    def does_require_special_handling(self, module_name: str) -> bool:
+        return module_name in self.mapping
+
+    def is_special_instance_module(self, module: nn.Module) -> bool:
+        return False
+
+    def handle_special_instance_module(self, module: nn.Module,
+                                       module_name: str,
+                                       module_weights: dict) -> None:
+        raise NotImplementedError()
+
+    @property
+    def skip_modules(self) -> List[str]:
+        return self._skip_modules
+
+    def add_skip_modules(self, value: List[str]) -> None:
+        self._skip_modules.extend(value)
+
+    def should_skip_module(self, module_name: str) -> bool:
+        return any(skip_module in module_name
+                   for skip_module in self._skip_modules)
+
+    def filter_weights(self, prefix: str, weights: dict) -> dict:
+        result = {}
+        for k, v in weights.items():
+            if k.startswith(prefix):
+                new_k = k[len(prefix) + 1:]
+                result[new_k] = v
+        return result
+
+    @property
+    def mapping(self) -> dict:
+        return self._mapping
+
+    @property
+    def config(self) -> TConfig:
+        if self._config is None:
+            raise RuntimeError("Weight mapper is not initialized")
+        return self._config
+
+    @property
+    def model(self) -> Union[nn.Module, DecoderModelForCausalLM]:
+        if self._model is None:
+            raise RuntimeError("Weight mapper is not initialized")
+        return self._model
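To make the renaming helper above concrete, a minimal sketch with a trivial mapper subclass. The no-op map_weights/apply_callbacks exist only so the abstract class can be instantiated, and the pattern is the one from the docstring example; none of this is part of the commit itself.

import torch

from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import BaseWeightMapper


class TrivialWeightMapper(BaseWeightMapper):
    """Hypothetical mapper used only to demonstrate rename_by_params_map."""

    def map_weights(self) -> None:
        pass  # no TRT-LLM <-> source renames registered

    def apply_callbacks(self, module, module_name, module_names_breakdown,
                        weights) -> list[dict]:
        return []  # no per-module weight transformations


mapper = TrivialWeightMapper()
weights = {"vision_model.encoder.layers.1.self_attn.out_proj.weight": torch.zeros(2, 2)}
renamed = mapper.rename_by_params_map({r'(.*?)out_proj(.*)': r'\1o_proj\2'}, weights)
# renamed now holds the tensor under "vision_model.encoder.layers.1.self_attn.o_proj.weight"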

tensorrt_llm/_torch/models/checkpoints/hf/__init__.py

Whitespace-only changes.
