@@ -5,7 +5,7 @@
 import tarfile
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
@@ -16,8 +16,13 @@
 
 from tensorrt_llm.bindings import internal as tb_internal
 
-from ._utils import DictConversion, pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
+from ._utils import pad_vocab_size, release_gc, str_dtype_to_torch, torch_to_numpy
 from .layers.linear import ColumnLinear
+from .lora_helper import (
+    LoraConfig,
+    get_default_trtllm_modules_to_hf_modules,
+    get_missing_qkv_modules,
+)
 from .mapping import Mapping
 from .models.convert_utils import get_model_path, load_state_dict, split_matrix_tp
 
@@ -232,26 +237,6 @@ def norm_dora_magnitude(
     return norm_m
 
 
-@dataclass
-class LoraConfig(DictConversion):
-    lora_dir: List[str] = field(default_factory=list)
-    lora_ckpt_source: str = "hf"
-    max_lora_rank: int = 64
-    lora_target_modules: List[str] = field(default_factory=list)
-    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
-    max_loras: int | None = None
-    max_cpu_loras: int | None = None
-
-    def __post_init__(self):
-        assert self.lora_ckpt_source in ["hf", "nemo"], (
-            f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"
-        )
-
-    @property
-    def missing_qkv_modules(self) -> List[str]:
-        return LoraManager.get_missing_qkv_modules(self.lora_target_modules)
-
-
 @dataclass
 class LoraModelConfig:
     lora_target_modules: list[str]
@@ -430,23 +415,6 @@ def load_nemo_lora(model, lora_config: LoraConfig):
         lora_config.lora_target_modules = lora_loader.lora_target_modules
 
 
-def get_default_trtllm_modules_to_hf_modules():
-    return {
-        "attn_q": "q_proj",
-        "attn_k": "k_proj",
-        "attn_v": "v_proj",
-        "attn_dense": "o_proj",
-        "mlp_h_to_4h": "gate_proj",
-        "mlp_4h_to_h": "down_proj",
-        "mlp_gate": "up_proj",
-        "mlp_gate_up": "gate_up_proj",
-        "moe_h_to_4h": "w1",
-        "moe_4h_to_h": "w2",
-        "moe_gate": "w3",
-        "moe_router": "gate",
-    }
-
-
 def load_torch_hf_lora(lora_config: LoraConfig):
     """This is a shortned version of load_hf_lora that is used for torch models.
 
@@ -628,19 +596,6 @@ def load_hf_lora(
             ).to(torch_dtype)
 
 
-def use_lora(
-    model,
-    lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
-):
-    if lora_config.lora_ckpt_source == "nemo":
-        load_nemo_lora(model, lora_config)
-    elif lora_config.lora_ckpt_source == "hf":
-        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
-    else:
-        raise ValueError(f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")
-
-
 def unpack_nemo_weights(nemo_archive_path: str) -> Tuple[Dict, Dict[str, torch.Tensor]]:
     """Unpack model config and weights from a NeMo .nemo archive file.
 
@@ -763,20 +718,7 @@ def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
 
     @staticmethod
     def get_missing_qkv_modules(lora_target_modules):
-        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or
-        # all disabled at the same time.
-        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor
-        # to fill the missing ones.
-        missing_qkv_modules = []
-        if any(x in lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]):
-            for lora_module in ["attn_q", "attn_k", "attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        if any(x in lora_target_modules for x in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]):
-            for lora_module in ["cross_attn_q", "cross_attn_k", "cross_attn_v"]:
-                if lora_module not in lora_target_modules:
-                    missing_qkv_modules.append(lora_module)
-        return missing_qkv_modules
+        return get_missing_qkv_modules(lora_target_modules)
 
     @property
     def missing_qkv_modules(self) -> List[str]:
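For reference, a minimal usage sketch of the helpers relocated by this diff, assuming they are importable as tensorrt_llm.lora_helper and that LoraConfig keeps the fields shown in the removed block; the adapter directory name is a made-up placeholder, not something from this change:

# Sketch only: import path and directory name are assumptions, not confirmed by the diff.
from tensorrt_llm.lora_helper import (
    LoraConfig,
    get_default_trtllm_modules_to_hf_modules,
    get_missing_qkv_modules,
)

# Build a LoRA config for a hypothetical HF-format adapter checkpoint.
config = LoraConfig(lora_dir=["./my_hf_lora"], lora_ckpt_source="hf", max_lora_rank=64)

# Default TRT-LLM -> HF module-name mapping, e.g. "attn_q" -> "q_proj".
mapping = get_default_trtllm_modules_to_hf_modules()

# When only some of attn_q/attn_k/attn_v are targeted, the remaining ones are reported
# as missing so they can be filled with zero tensors (same behavior as the removed inline code).
print(get_missing_qkv_modules(["attn_q"]))  # -> ["attn_k", "attn_v"]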