Commit d7a24c8

refactoring to avoid circular import when importing torch models
Signed-off-by: Rakib Hasan <[email protected]>
1 parent 3b2dd40 commit d7a24c8

File tree

16 files changed: +130 −87 lines changed
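For context on the commit message: a circular import arises when `tensorrt_llm/__init__.py` begins importing `tensorrt_llm._torch.models`, which (directly or transitively) pulls in `tensorrt_llm.lora_manager`, which in turn imports modules that re-enter the still-initializing package root. The fix is the standard one: move the shared, dependency-light definitions (`LoraConfig`, `use_lora`, the module-name helpers) into a leaf module, `tensorrt_llm/lora_helper.py`, that imports nothing caught in the cycle. A minimal sketch of the pattern, with hypothetical module names (`pkg`, not the real TRT-LLM layout):

```python
# pkg/__init__.py
#   import pkg.torch_models as torch_models    # starts the chain
#
# pkg/torch_models.py
#   from pkg.lora_manager import LoraConfig    # pulls in the heavy module...
#
# pkg/lora_manager.py
#   from pkg.models import load_state_dict     # ...which re-enters pkg/__init__.py
#                                              # while it is only partially built
#                                              # -> ImportError / AttributeError
#
# Fix: move the shared definition into a leaf module with no intra-package
# imports, and have both sides depend on it instead.

# pkg/lora_helper.py
from dataclasses import dataclass, field
from typing import List


@dataclass
class LoraConfig:
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = "hf"


# pkg/torch_models.py and pkg/lora_manager.py both switch to:
#   from pkg.lora_helper import LoraConfig
# The import graph becomes acyclic, so the import order in pkg/__init__.py
# no longer matters.
```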

tensorrt_llm/__init__.py (+2 −0)

```diff
@@ -33,6 +33,7 @@ def _add_trt_llm_dll_directory():
 # otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
 import xgrammar  # noqa
 
+import tensorrt_llm._torch.models as torch_models
 import tensorrt_llm.functional as functional
 import tensorrt_llm.math_utils as math_utils
 import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ def _add_trt_llm_dll_directory():
     'default_trtnet',
     'precision',
     'net_guard',
+    'torch_models',
     'Network',
     'Mapping',
     'MnnvlMemory',
```
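After this change the torch model definitions are importable (and re-exported) from the package root. A quick check, assuming a working install:

```python
import tensorrt_llm

# `torch_models` is now bound in tensorrt_llm/__init__.py and listed in __all__:
print(tensorrt_llm.torch_models)  # -> <module 'tensorrt_llm._torch.models' ...>
```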

tensorrt_llm/_torch/pyexecutor/_util.py (+3 −3)

```diff
@@ -13,9 +13,9 @@
 from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
 from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       load_torch_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_manager import load_torch_lora
 from tensorrt_llm.mapping import Mapping
 
 from ..model_config import ModelConfig
```

tensorrt_llm/llmapi/build_cache.py (+1 −1)

```diff
@@ -12,7 +12,7 @@
 import filelock
 
 import tensorrt_llm
-from tensorrt_llm import BuildConfig
+from tensorrt_llm.builder import BuildConfig
 from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
 from tensorrt_llm.logger import logger
```
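The `build_cache.py` change applies the same leaf-import principle in the other direction: importing `BuildConfig` from its defining module means this file no longer requires `tensorrt_llm/__init__.py` to have finished executing. Roughly:

```python
# Cycle-prone: resolving the name via the package root requires
# tensorrt_llm/__init__.py to be fully executed, which may not hold mid-import.
# from tensorrt_llm import BuildConfig

# Safe: goes straight to the defining module.
from tensorrt_llm.builder import BuildConfig
```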

tensorrt_llm/llmapi/llm_args.py (+2 −2)

```diff
@@ -19,8 +19,8 @@
 from strenum import StrEnum
 from transformers import PreTrainedTokenizerBase
 
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules)
 
 from .._utils import mpi_rank
 from ..auto_parallel import AutoParallelConfig, infer_cluster_config
```

tensorrt_llm/lora_helper.py (new file, +99 −0)

```diff
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional
+
+from ._utils import DictConversion
+
+
+def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
+    """Get missing QKV modules from LoRA target modules.
+
+    In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
+    all disabled at the same time. However, some lora checkpoint (e.g. BART) only contain two of them,
+    so we use zero tensor to fill the missing ones.
+    """
+    missing_qkv_modules = []
+    if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
+        for lora_module in ["attn_q", "attn_k", "attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    if any(x in lora_target_modules
+           for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
+        for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
+            if lora_module not in lora_target_modules:
+                missing_qkv_modules.append(lora_module)
+    return missing_qkv_modules
+
+
+def get_default_trtllm_modules_to_hf_modules():
+    """Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
+    return {
+        "attn_q": "q_proj",
+        "attn_k": "k_proj",
+        "attn_v": "v_proj",
+        "attn_dense": "o_proj",
+        "mlp_h_to_4h": "gate_proj",
+        "mlp_4h_to_h": "down_proj",
+        "mlp_gate": "up_proj",
+        "mlp_gate_up": "gate_up_proj",
+        "moe_h_to_4h": "w1",
+        "moe_4h_to_h": "w2",
+        "moe_gate": "w3",
+        "moe_router": "gate",
+    }
+
+
+def use_lora(
+    model,
+    lora_config: "LoraConfig",
+    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
+):
+    """Use LoRA with the given model and configuration.
+
+    This function is a wrapper that delegates to the appropriate loading function
+    based on the LoRA checkpoint source.
+    """
+    if lora_config.lora_ckpt_source == "nemo":
+        from .lora_manager import load_nemo_lora
+        load_nemo_lora(model, lora_config)
+    elif lora_config.lora_ckpt_source == "hf":
+        from .lora_manager import load_hf_lora
+        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
+    else:
+        raise ValueError(
+            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
+
+
+@dataclass
+class LoraConfig(DictConversion):
+    lora_dir: List[str] = field(default_factory=list)
+    lora_ckpt_source: str = "hf"
+    max_lora_rank: int = 64
+    lora_target_modules: List[str] = field(default_factory=list)
+    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
+    max_loras: int | None = None
+    max_cpu_loras: int | None = None
+
+    def __post_init__(self):
+        assert self.lora_ckpt_source in [
+            "hf", "nemo"
+        ], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
+            )
+
+    @property
+    def missing_qkv_modules(self) -> List[str]:
+        return get_missing_qkv_modules(self.lora_target_modules)
```
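A short usage sketch of the relocated helpers (the path and target modules below are illustrative):

```python
from tensorrt_llm.lora_helper import (LoraConfig,
                                      get_default_trtllm_modules_to_hf_modules,
                                      get_missing_qkv_modules)

# A BART-style checkpoint that only provides two of the three QKV adapters:
config = LoraConfig(lora_dir=["./my_lora_ckpt"],
                    lora_ckpt_source="hf",
                    lora_target_modules=["attn_q", "attn_v"])

# The missing "attn_k" is reported so callers can zero-fill it downstream:
assert config.missing_qkv_modules == ["attn_k"]
assert get_missing_qkv_modules(["attn_q", "attn_v"]) == ["attn_k"]

# Default TRT-LLM -> HuggingFace module-name mapping:
assert get_default_trtllm_modules_to_hf_modules()["attn_q"] == "q_proj"
```

Note that `use_lora` defers its `load_nemo_lora`/`load_hf_lora` imports into the function body, so `lora_helper` never imports `lora_manager` at module load time; that is what keeps this module a leaf in the import graph.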

tensorrt_llm/lora_manager.py (+8 −66)

```diff
@@ -5,7 +5,7 @@
 import tarfile
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@
 
 from tensorrt_llm.bindings import internal as tb_internal
 
-from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
+from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
+from .lora_helper import (
+    LoraConfig,
+    get_default_trtllm_modules_to_hf_modules,
+    get_missing_qkv_modules,
+)
 from .mapping import Mapping
 from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
@@ -232,26 +237,6 @@ def norm_dora_magnitude(
     return norm_m
 
 
-@dataclass
-class LoraConfig(DictConversion):
-    lora_dir: List[str] = field(default_factory=list)
-    lora_ckpt_source: str = "hf"
-    max_lora_rank: int = 64
-    lora_target_modules: List[str] = field(default_factory=list)
-    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int | None = None
-    max_cpu_loras: int | None = None
-
-    def __post_init__(self):
-        assert self.lora_ckpt_source in ["hf", "nemo"], (
-            f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
-        )
-
-    @property
-    def missing_qkv_modules(self) -> List[str]:
-        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
-
-
 @dataclass
 class LoraModelConfig:
     lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
     lora_config.lora_target_modules = lora_loader.lora_target_modules
 
 
-def get_default_trtllm_modules_to_hf_modules():
-    return {
-        "attn_q": "q_proj",
-        "attn_k": "k_proj",
-        "attn_v": "v_proj",
-        "attn_dense": "o_proj",
-        "mlp_h_to_4h": "gate_proj",
-        "mlp_4h_to_h": "down_proj",
-        "mlp_gate": "up_proj",
-        "mlp_gate_up": "gate_up_proj",
-        "moe_h_to_4h": "w1",
-        "moe_4h_to_h": "w2",
-        "moe_gate": "w3",
-        "moe_router": "gate",
-    }
-
-
 def load_torch_hf_lora(lora_config: LoraConfig):
     """This is a shortned version of load_hf_lora that is used for torch models.
 
@@ -628,19 +596,6 @@ def load_hf_lora(
     ).to(torch_dtype)
 
 
-def use_lora(
-    model,
-    lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
-):
-    if lora_config.lora_ckpt_source == "nemo":
-        load_nemo_lora(model, lora_config)
-    elif lora_config.lora_ckpt_source == "hf":
-        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
-    else:
-        raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
-
-
 def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
     """Unpack model config and weights from a NeMo .nemo archive file.
 
@@ -763,20 +718,7 @@ def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
 
     @staticmethod
    def get_missing_qkv_modules(lora_target_modules):
-        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
-        # all disabled at the same time.
-        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
-        # to fill the missing ones.
-        missing_qkv_modules = []
-        if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
-            for lora_module in ["attn_q", "attn_k", "attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
-            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        return missing_qkv_modules
+        return get_missing_qkv_modules(lora_target_modules)
 
     @property
     def missing_qkv_modules(self) -> List[str]:
```
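Existing call sites keep working: `lora_manager` re-imports `LoraConfig`, `get_default_trtllm_modules_to_hf_modules`, and `get_missing_qkv_modules` from `lora_helper`, and the `LoraManager.get_missing_qkv_modules` staticmethod is now a thin wrapper. A sketch of the equivalence:

```python
from tensorrt_llm.lora_helper import get_missing_qkv_modules
from tensorrt_llm.lora_manager import LoraManager

targets = ["cross_attn_q", "cross_attn_k"]
# The staticmethod delegates to the free function in lora_helper:
assert (LoraManager.get_missing_qkv_modules(targets)
        == get_missing_qkv_modules(targets)
        == ["cross_attn_v"])
```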

tensorrt_llm/models/enc_dec/model.py (+3 −3)

```diff
@@ -35,9 +35,9 @@
                         LanguageAdapterConfig, LayerNorm, LoraParams,
                         PromptTuningEmbedding, RmsNorm)
 # yapf: enable
-from tensorrt_llm.lora_manager import (LoraConfig,
-                                       get_default_trtllm_modules_to_hf_modules,
-                                       use_lora)
+from tensorrt_llm.lora_helper import (LoraConfig,
+                                      get_default_trtllm_modules_to_hf_modules,
+                                      use_lora)
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
 from tensorrt_llm.module import Module, ModuleList
```

tensorrt_llm/models/gemma/model.py (+1 −1)

```diff
@@ -28,7 +28,7 @@
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
```

tensorrt_llm/models/gpt/model.py (+1 −1)

```diff
@@ -21,7 +21,7 @@
 from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, GatedMLP, LayerNorm, MoeConfig,
                        PositionEmbeddingType)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ...quantization import QuantMode
```

tensorrt_llm/models/grok/model.py (+1 −1)

```diff
@@ -18,7 +18,7 @@
 from ...functional import Tensor, recv, send
 from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
                        Embedding, MoeConfig, PositionEmbeddingType, RmsNorm)
-from ...lora_manager import LoraConfig, use_lora
+from ...lora_helper import LoraConfig, use_lora
 from ...mapping import Mapping
 from ...module import Module
 from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
```
