2 changes: 2 additions & 0 deletions tensorrt_llm/__init__.py
@@ -33,6 +33,7 @@ def _add_trt_llm_dll_directory():
# otherwise `MemoryError: std::bad_alloc` pattern error will be raised.
import xgrammar # noqa

import tensorrt_llm._torch.models as torch_models
import tensorrt_llm.functional as functional
import tensorrt_llm.math_utils as math_utils
import tensorrt_llm.models as models
@@ -82,6 +83,7 @@ def _add_trt_llm_dll_directory():
'default_trtnet',
'precision',
'net_guard',
'torch_models',
'Network',
'Mapping',
'MnnvlMemory',
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -13,9 +13,9 @@
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
load_torch_lora)
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from tensorrt_llm.lora_manager import load_torch_lora
from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
2 changes: 1 addition & 1 deletion tensorrt_llm/llmapi/build_cache.py
@@ -12,7 +12,7 @@
import filelock

import tensorrt_llm
from tensorrt_llm import BuildConfig
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.llmapi.utils import enable_llm_debug, print_colored
from tensorrt_llm.logger import logger

4 changes: 2 additions & 2 deletions tensorrt_llm/llmapi/llm_args.py
@@ -19,8 +19,8 @@
from strenum import StrEnum
from transformers import PreTrainedTokenizerBase

from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)

from .._utils import mpi_rank
from ..auto_parallel import AutoParallelConfig, infer_cluster_config
99 changes: 99 additions & 0 deletions tensorrt_llm/lora_helper.py
@@ -0,0 +1,99 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from ._utils import DictConversion


def get_missing_qkv_modules(lora_target_modules: List[str]) -> List[str]:
"""Get missing QKV modules from LoRA target modules.
In the current design, q_lora_params, k_lora_params and v_lora_params should all be enabled or
all disabled at the same time. However, some LoRA checkpoints (e.g. BART) contain only two of them,
so zero tensors are used to fill in the missing ones.
"""
missing_qkv_modules = []
if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in ["attn_q", "attn_k", "attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules
for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
return missing_qkv_modules


def get_default_trtllm_modules_to_hf_modules():
"""Get default mapping from TensorRT-LLM module names to HuggingFace module names."""
return {
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",
"mlp_gate": "up_proj",
"mlp_gate_up": "gate_up_proj",
"moe_h_to_4h": "w1",
"moe_4h_to_h": "w2",
"moe_gate": "w3",
"moe_router": "gate",
}


def use_lora(
model,
lora_config: "LoraConfig",
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
"""Use LoRA with the given model and configuration.
This function is a wrapper that delegates to the appropriate loading function
based on the LoRA checkpoint source.
"""
if lora_config.lora_ckpt_source == "nemo":
from .lora_manager import load_nemo_lora
load_nemo_lora(model, lora_config)
elif lora_config.lora_ckpt_source == "hf":
from .lora_manager import load_hf_lora
load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
else:
raise ValueError(
f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


@dataclass
class LoraConfig(DictConversion):
lora_dir: List[str] = field(default_factory=list)
lora_ckpt_source: str = "hf"
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
max_loras: int | None = None
max_cpu_loras: int | None = None

def __post_init__(self):
assert self.lora_ckpt_source in [
"hf", "nemo"
], (f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
)

@property
def missing_qkv_modules(self) -> List[str]:
return get_missing_qkv_modules(self.lora_target_modules)
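
A minimal usage sketch of the new lora_helper module, assuming an already-built TensorRT-LLM model object and a hypothetical HF adapter directory (only LoraConfig and use_lora are taken from the file added above):

    from tensorrt_llm.lora_helper import LoraConfig, use_lora

    # Hypothetical HF LoRA adapter directory; lora_ckpt_source defaults to "hf".
    lora_config = LoraConfig(lora_dir=["/path/to/hf_lora_adapter"], max_lora_rank=64)

    # Placeholder: an already-constructed TensorRT-LLM model instance.
    model = ...

    # use_lora delegates to load_hf_lora (or load_nemo_lora when
    # lora_ckpt_source == "nemo").
    use_lora(model, lora_config)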
74 changes: 8 additions & 66 deletions tensorrt_llm/lora_manager.py
@@ -5,7 +5,7 @@
import tarfile
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,8 @@

from tensorrt_llm.bindings import internal as tb_internal

from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
from .layers.linear import ColumnLinear
from .lora_helper import (
LoraConfig,
get_default_trtllm_modules_to_hf_modules,
get_missing_qkv_modules,
)
from .mapping import Mapping
from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp

@@ -232,26 +237,6 @@ def norm_dora_magnitude(
return norm_m


@dataclass
class LoraConfig(DictConversion):
lora_dir: List[str] = field(default_factory=list)
lora_ckpt_source: str = "hf"
max_lora_rank: int = 64
lora_target_modules: List[str] = field(default_factory=list)
trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
max_loras: int | None = None
max_cpu_loras: int | None = None

def __post_init__(self):
assert self.lora_ckpt_source in ["hf", "nemo"], (
f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
)

@property
def missing_qkv_modules(self) -> List[str]:
return LoraManager.get_missing_qkv_modules(self.lora_target_modules)


@dataclass
class LoraModelConfig:
lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
lora_config.lora_target_modules = lora_loader.lora_target_modules


def get_default_trtllm_modules_to_hf_modules():
return {
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",
"mlp_gate": "up_proj",
"mlp_gate_up": "gate_up_proj",
"moe_h_to_4h": "w1",
"moe_4h_to_h": "w2",
"moe_gate": "w3",
"moe_router": "gate",
}


def load_torch_hf_lora(lora_config: LoraConfig):
"""This is a shortned version of load_hf_lora that is used for torch models.

@@ -628,19 +596,6 @@ def load_hf_lora(
).to(torch_dtype)


def use_lora(
model,
lora_config: LoraConfig,
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
if lora_config.lora_ckpt_source == "nemo":
load_nemo_lora(model, lora_config)
elif lora_config.lora_ckpt_source == "hf":
load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
else:
raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
"""Unpack model config and weights from a NeMo .nemo archive file.

@@ -763,20 +718,7 @@ def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:

@staticmethod
def get_missing_qkv_modules(lora_target_modules):
# In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
# all disabled at the same time.
# However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
# to fill the missing ones.
missing_qkv_modules = []
if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
for lora_module in ["attn_q", "attn_k", "attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
if lora_module not in lora_target_modules:
missing_qkv_modules.append(lora_module)
return missing_qkv_modules
return get_missing_qkv_modules(lora_target_modules)

@property
def missing_qkv_modules(self) -> List[str]:
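
Since lora_manager.py now re-imports LoraConfig, get_default_trtllm_modules_to_hf_modules and get_missing_qkv_modules from lora_helper (see the import hunk above), existing code that imports these names from lora_manager should keep resolving to the same objects. A minimal check, assuming both modules are importable:

    # Both import paths are expected to yield the same class after this change.
    from tensorrt_llm.lora_helper import LoraConfig as helper_lora_config
    from tensorrt_llm.lora_manager import LoraConfig as manager_lora_config

    assert helper_lora_config is manager_lora_config  # re-imported, not redefined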
6 changes: 3 additions & 3 deletions tensorrt_llm/models/enc_dec/model.py
@@ -35,9 +35,9 @@
LanguageAdapterConfig, LayerNorm, LoraParams,
PromptTuningEmbedding, RmsNorm)
# yapf: enable
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
use_lora)
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
use_lora)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import PretrainedConfig, PretrainedModel
from tensorrt_llm.module import Module, ModuleList
2 changes: 1 addition & 1 deletion tensorrt_llm/models/gemma/model.py
@@ -28,7 +28,7 @@
from ...layers import (Attention, AttentionMaskType, AttentionParams,
ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
LoraParams, PositionEmbeddingType, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...lora_helper import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
2 changes: 1 addition & 1 deletion tensorrt_llm/models/gpt/model.py
@@ -21,7 +21,7 @@
from ...layers import (MLP, MOE, Attention, AttentionMaskType, ColumnLinear,
Embedding, GatedMLP, LayerNorm, MoeConfig,
PositionEmbeddingType)
from ...lora_manager import LoraConfig, use_lora
from ...lora_helper import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ...quantization import QuantMode
2 changes: 1 addition & 1 deletion tensorrt_llm/models/grok/model.py
@@ -18,7 +18,7 @@
from ...functional import Tensor, recv, send
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
Embedding, MoeConfig, PositionEmbeddingType, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...lora_helper import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
2 changes: 1 addition & 1 deletion tensorrt_llm/models/llama/model.py
@@ -25,7 +25,7 @@
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
Embedding, FusedGatedMLP, GatedMLP,
PositionEmbeddingType, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...lora_helper import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ...quantization.functional import fused_layernorm
6 changes: 3 additions & 3 deletions tensorrt_llm/models/mllama/model.py
@@ -32,9 +32,9 @@
ColumnLinear, Embedding, FusedGatedMLP,
GatedMLP, GroupNorm, KeyValueCacheParams,
LayerNorm, LoraParams, RmsNorm)
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
use_lora)
from tensorrt_llm.lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
use_lora)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader
from tensorrt_llm.models.modeling_utils import PretrainedModel, QuantConfig
2 changes: 1 addition & 1 deletion tensorrt_llm/models/phi/model.py
@@ -20,7 +20,7 @@
from ...functional import Tensor
from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
Embedding, LayerNorm)
from ...lora_manager import LoraConfig, use_lora
from ...lora_helper import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
2 changes: 1 addition & 1 deletion tensorrt_llm/models/phi3/model.py
@@ -8,7 +8,7 @@
from ...layers import (MLP, MOE, Attention, AttentionMaskType,
BlockSparseAttnParams, ColumnLinear, Embedding,
LayerNorm, MoeConfig, RmsNorm)
from ...lora_manager import LoraConfig, use_lora
from ...lora_helper import LoraConfig, use_lora
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import (DecoderLayerList, DecoderModelForCausalLM,
4 changes: 2 additions & 2 deletions tensorrt_llm/models/qwen/model.py
@@ -26,8 +26,8 @@
Embedding, GatedMLP, RmsNorm, SharedMoE)
from ...layers.moe import MOEWeightWrapper
from ...logger import logger
from ...lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules, use_lora)
from ...lora_helper import (LoraConfig,
get_default_trtllm_modules_to_hf_modules, use_lora)
from ...mapping import Mapping
from ...module import Module
from ...quantization import QuantAlgo
2 changes: 1 addition & 1 deletion tensorrt_llm/top_model_mixin.py
@@ -15,7 +15,7 @@

from typing import Optional

from .lora_manager import LoraConfig
from .lora_helper import LoraConfig
from .mapping import Mapping
from .plugin.plugin import PluginConfig
