From 79e65341bd6154c6f726efb51b0b4b08b8cacf3c Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:42 +0000 Subject: [PATCH 01/16] Interleave mlp_gate_up LoRA weights for TP, move the mapping argument from LoraManager load methods to its constructor, remove other TP references in load_from_model_dir as until now it always received default mapping with TP=1 as CPP expects to work on full non-split weights Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 10 ++- tensorrt_llm/lora_manager.py | 83 ++++++++----------- tensorrt_llm/runtime/enc_dec_model_runner.py | 8 +- tensorrt_llm/runtime/model_runner.py | 6 +- tensorrt_llm/runtime/model_runner_cpp.py | 4 +- 5 files changed, 51 insertions(+), 60 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 6e4f4a98497..8a93bc4df62 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1169,7 +1169,14 @@ def __init__(self, lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size, binding_to_str_dtype(model_config.data_type), lora_config.swap_gate_up_proj_lora_b_weight) - self._lora_manager = LoraManager() + mapping = Mapping( + world_size=world_config.size, + rank=world_config.rank, + tp_size=world_config.tensor_parallelism, + pp_size=world_config.pipeline_parallelism, + gpus_per_node=world_config.gpus_per_node, + ) + self._lora_manager = LoraManager(mapping=mapping) def add_request_peft(self, request: LlmRequest): if request.lora_task_id is not None: @@ -1183,7 +1190,6 @@ def add_request_peft(self, request: LlmRequest): self._lora_manager.load_from_ckpt( [request.py_lora_path], model_config=self._lora_model_config, - runtime_mapping=None, uids=[request.lora_task_id], ckpt_source=self._lora_config.lora_ckpt_source) request.lora_weights = self._lora_manager.cpp_lora_weights[ diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index c7dc6f28bc9..3e554ab766a 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1,4 +1,5 @@ import io +import itertools import json import logging import re @@ -660,7 +661,10 @@ class LoraManager(object): } def __init__( - self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None + self, + *, + mapping: Mapping, + cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None, ): """Constructor. 
@@ -704,6 +708,7 @@ def __init__( self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu self.lora_target_modules: List[str] = [] + self._mapping = mapping self._cpp_peft_cache_manager = cpp_peft_cache_manager def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool: @@ -730,7 +735,6 @@ def load_from_ckpt( self, model_dirs_or_files: List[str], model_config: Union["ModelConfig", LoraModelConfig], - runtime_mapping: Optional[Mapping] = None, uids: Optional[List[str]] = None, ckpt_source: str = "hf", ) -> List[str]: @@ -743,7 +747,6 @@ def load_from_ckpt( return self.load_from_hf( model_dirs=model_dirs_or_files, model_config=model_config, - runtime_mapping=runtime_mapping, uids=uids, ) elif ckpt_source == "nemo": @@ -754,7 +757,6 @@ def load_from_ckpt( return self.load_from_nemo( model_files=nemo_files, model_config=model_config, - runtime_mapping=runtime_mapping, uids=uids, ) else: @@ -764,7 +766,6 @@ def load_from_nemo( self, model_files: List[str], model_config: Union["ModelConfig", LoraModelConfig], - runtime_mapping: Optional[Mapping] = None, uids: Optional[List[str]] = None, ) -> List[str]: """Returns the adapter UIDs that were loaded by this call. @@ -772,11 +773,6 @@ def load_from_nemo( Note that when an adapter was already loaded before this call, it would not be included in the returned list of UIDs. """ - if runtime_mapping is None: - runtime_mapping = Mapping() - tp_size = runtime_mapping.tp_size - tp_rank = runtime_mapping.tp_rank - if uids is None: uids = [self._generate_uid() for _ in range(len(model_files))] assert len(uids) == len(model_files) @@ -829,10 +825,6 @@ def load_from_model_file(uid, model_file): t_in = all_lora_weights[layer_idx]["in"] t_out = all_lora_weights[layer_idx]["out"] - assert t_out.shape[0] % tp_size == 0 - t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[ - tp_rank - ].contiguous() else: t_in = None t_out = None @@ -882,7 +874,6 @@ def load_from_hf( self, model_dirs: List[str], model_config: Union["ModelConfig", LoraModelConfig], - runtime_mapping: Optional[Mapping] = None, uids: Optional[List[str]] = None, component: Optional[str] = None, ) -> List[str]: @@ -939,11 +930,6 @@ def load_from_hf( ... 
""" - if runtime_mapping is None: - runtime_mapping = Mapping() - tp_size = runtime_mapping.tp_size - tp_rank = runtime_mapping.tp_rank - if uids is None: uids = [self._generate_uid() for _ in range(len(model_dirs))] assert len(uids) == len(model_dirs) @@ -1060,36 +1046,37 @@ def load_from_model_dir(uid, model_dir, hf_config): t_mag = module_weights.get("magnitude", None) is_dora = t_mag is not None + rank_dim = 1 if has_expert_indices else 0 - if lora_module in ["moe_router", "mlp_router"]: - pass - elif "moe" in lora_module and runtime_mapping.has_moe_ep(): - pass - elif lora_module in [ - "attn_dense", - "cross_attn_dense", - "mlp_4h_to_h", - "moe_4h_to_h", - ]: - # split by row - dim = 2 if has_expert_indices else 1 - assert t_in.shape[dim] % tp_size == 0 - t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[ - tp_rank - ].contiguous() - else: - # split by column - dim = 1 if has_expert_indices else 0 - assert t_out.shape[dim] % tp_size == 0 - t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[ - tp_rank - ].contiguous() - if dim == 0 and is_dora and t_mag is not None: - t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[ - tp_rank - ].contiguous() + if lora_module in ["mlp_gate_up"]: + # Special handling for fused module mlp_gate_up: + # HF stores each part's weights sequentially, whereas we need to interleave them for TP. + # We have to concatenate them all back after interleaving, as the CPP expects the full + # non-split weights. + assert t_out.shape[rank_dim] % 2 == 0 + half_size = t_out.shape[rank_dim] // 2 + tp_size = self._mapping.tp_size + assert half_size % tp_size == 0 + + first_half = t_out.narrow(rank_dim, 0, half_size) + second_half = t_out.narrow(rank_dim, half_size, half_size) + tp_parts_a = [ + torch.split( + first_half, first_half.shape[rank_dim] // tp_size, dim=rank_dim + )[r] + for r in range(tp_size) + ] + tp_parts_b = [ + torch.split( + second_half, second_half.shape[rank_dim] // tp_size, dim=rank_dim + )[r] + for r in range(tp_size) + ] + interleaved_parts = list( + itertools.chain.from_iterable(zip(tp_parts_a, tp_parts_b)) + ) + t_out = torch.cat(interleaved_parts, dim=rank_dim) - rank_dim = 1 if has_expert_indices else 0 effective_rank = t_in.shape[rank_dim] t_in = t_in.cuda().contiguous() diff --git a/tensorrt_llm/runtime/enc_dec_model_runner.py b/tensorrt_llm/runtime/enc_dec_model_runner.py index f2f482a2505..3e1aa586678 100644 --- a/tensorrt_llm/runtime/enc_dec_model_runner.py +++ b/tensorrt_llm/runtime/enc_dec_model_runner.py @@ -174,12 +174,12 @@ def engine_setup(component): # encoder lora manager setup if self.encoder_model_config.lora_plugin: - self.encoder_lora_manager = LoraManager() + self.encoder_lora_manager = LoraManager( + mapping=self.encoder_runtime_mapping) # TODO: this is only for bart self.encoder_lora_manager.load_from_hf( model_dirs=lora_dir, model_config=self.encoder_model_config, - runtime_mapping=self.encoder_runtime_mapping, component='encoder', ) else: @@ -197,12 +197,12 @@ def engine_setup(component): # decoder lora manager setup if self.decoder_model_config.lora_plugin: - self.decoder_lora_manager = LoraManager() + self.decoder_lora_manager = LoraManager( + mapping=self.decoder_runtime_mapping) # TODO: this is only for bart self.decoder_lora_manager.load_from_hf( model_dirs=lora_dir, model_config=self.decoder_model_config, - runtime_mapping=self.decoder_runtime_mapping, component='decoder', ) else: diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index 
ee35da3ef0e..c7045854b84 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -611,11 +611,10 @@ def from_engine( session.runtime._set_weight_streaming(gpu_weights_percent) if session.use_lora_plugin: - lora_manager = LoraManager() + lora_manager = LoraManager(mapping=runtime_mapping) if lora_dir is not None: lora_manager.load_from_ckpt(lora_dir, model_config=model_config, - runtime_mapping=runtime_mapping, ckpt_source=lora_ckpt_source) else: lora_manager = None @@ -720,11 +719,10 @@ def from_dir( debug_mode=debug_mode, stream=stream) if session.use_lora_plugin: - lora_manager = LoraManager() + lora_manager = LoraManager(mapping=runtime_mapping) if lora_dir is not None: lora_manager.load_from_ckpt(lora_dir, model_config=model_config, - runtime_mapping=runtime_mapping, ckpt_source=lora_ckpt_source) else: lora_manager = None diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index b701f245f6f..d3cd6dff965 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -277,7 +277,8 @@ def from_dir( engine_config = EngineConfig.from_json_file(f"{engine_dir}/config.json") if model_config.use_lora_plugin and rank == 0: - lora_manager = LoraManager() + lora_manager = LoraManager( + mapping=_world_config_to_mapping(world_config)) if lora_dir is None: config_lora_dir = engine_config.build_config.lora_config.lora_dir if len(config_lora_dir) > 0: @@ -292,7 +293,6 @@ def from_dir( # For Executor, only rank 0 can enqueue requests, and should hold all lora weights lora_manager.load_from_ckpt(lora_dir, model_config=runtime_model_config, - runtime_mapping=None, ckpt_source=lora_ckpt_source) else: raise RuntimeError( From 56583122f6dbaa9949f901ef1916ac3c00bd8b64 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:43 +0000 Subject: [PATCH 02/16] Refactor fused LoRA module weight interleaving logic into its own function, add support for LorA fused attn QKV, pass ModelConfig to LoraManager ctor for fused QKV LoRA adapter support Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 24 +++++- tensorrt_llm/executor/base_worker.py | 11 ++- tensorrt_llm/executor/worker.py | 1 + tensorrt_llm/lora_helper.py | 1 + tensorrt_llm/lora_manager.py | 77 +++++++++++++------ 5 files changed, 86 insertions(+), 28 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 8a93bc4df62..50eea487fb6 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -13,6 +13,7 @@ from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig +from tensorrt_llm.runtime import ModelConfig as ModelConfigRuntime from tensorrt_llm.sampling_params import SamplingParams from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range @@ -1176,7 +1177,28 @@ def __init__(self, pp_size=world_config.pipeline_parallelism, gpus_per_node=world_config.gpus_per_node, ) - self._lora_manager = LoraManager(mapping=mapping) + self._lora_manager = LoraManager( + mapping=mapping, + model_config=self._model_config_binding_to_model_config_runtime( + model_config)) + + @staticmethod + def _model_config_binding_to_model_config_runtime( + 
model_config_binding: ModelConfig) -> ModelConfigRuntime: + # TODO ZUKER: Init the rest of the fields as well? + return ModelConfigRuntime( + max_batch_size=model_config_binding.max_batch_size, + max_beam_width=model_config_binding.max_beam_width, + vocab_size=model_config_binding.vocab_size, + num_layers=model_config_binding.num_layers( + ), # TODO ZUKER: Should num_layers get the PP args from mapping? + num_heads=model_config_binding.num_heads, + num_kv_heads=model_config_binding.num_kv_heads(0), + hidden_size=model_config_binding.hidden_size, + head_size=model_config_binding.head_size, + gpt_attention_plugin=model_config_binding.use_gpt_attention_plugin, + dtype=binding_to_str_dtype(model_config_binding.data_type), + ) def add_request_peft(self, request: LlmRequest): if request.lora_task_id is not None: diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 498f976ecce..ecf2473a84c 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -10,6 +10,7 @@ import torch from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping from .._torch.pyexecutor.llm_request import LlmResponse from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, @@ -205,7 +206,12 @@ def _create_engine(executor_config): # point in the TRT flow is currently not supported (it's at the CPP # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA # optimization is not available in TRT-python flow. - self._lora_manager = LoraManager(cpp_peft_cache_manager=None) + self._lora_manager = LoraManager( + # TODO ZUKER: Somehow create Mapping with real info when self.llm_args is None + mapping=Mapping() if self.llm_args is None else + self.llm_args.parallel_config.to_mapping(), + model_config=self._runtime_model_config, + cpp_peft_cache_manager=None) if engine_config.build_config.max_prompt_embedding_table_size > 0: self._prompt_adapter_manager = PromptAdapterManager() @@ -216,8 +222,7 @@ def _create_engine(executor_config): ResourceManagerType peft_cache_manager = self.engine.resource_manager.resource_managers.get( ResourceManagerType.PEFT_CACHE_MANAGER) - self._lora_manager = LoraManager( - cpp_peft_cache_manager=peft_cache_manager.impl) + self._lora_manager = peft_cache_manager._lora_manager lora_model_config = self.engine.model_engine.lora_model_config assert lora_model_config is not None self._lora_model_config = lora_model_config diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 59e3fca19fc..623066ed120 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -10,6 +10,7 @@ import zmq from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping from .._utils import mpi_comm, mpi_rank from ..bindings import executor as tllm diff --git a/tensorrt_llm/lora_helper.py b/tensorrt_llm/lora_helper.py index 719df510794..b9c14232724 100644 --- a/tensorrt_llm/lora_helper.py +++ b/tensorrt_llm/lora_helper.py @@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules(): "attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", + "attn_qkv": "qkv_proj", "attn_dense": "o_proj", "mlp_h_to_4h": "gate_proj", "mlp_4h_to_h": "down_proj", diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 3e554ab766a..6730a196a9e 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -664,6 +664,7 @@ def __init__( self, *, mapping: Mapping, + model_config: "ModelConfig", 
cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None, ): """Constructor. @@ -709,6 +710,7 @@ def __init__( self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu self.lora_target_modules: List[str] = [] self._mapping = mapping + self._model_config = model_config self._cpp_peft_cache_manager = cpp_peft_cache_manager def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool: @@ -969,6 +971,34 @@ def preprocess_lora_weights(lora_model, model_config): lora_model[key] = value return lora_model + def interleave_fused_lora_weights_for_tp( + weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: list[int] + ) -> list[torch.Tensor]: + assert weight.shape[rank_dim] == sum(part_sizes) + + # Split the weights into their respective parts. e.g. weight -> [q, k, v] for attn_qkv. + weight_parts = [ + weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i]) + for i in range(len(part_sizes)) + ] + for i in range(len(part_sizes)): + assert weight_parts[i].shape[rank_dim] % tp_size == 0 + + # Split each part into tp_size chunks. + # e.g. [q, k, v] -> [[q_rank0, ..., q_rankN], [k_rank0, ..., k_rankN], [v_rank0, ..., v_rankN]] + # where N is TP size, for attn_qkv. + weight_parts_tp_weights = [ + torch.split( + weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim + ) + for i in range(len(part_sizes)) + ] + + # Interleave the parts across TP ranks and flatten the list of lists into a single list. + # e.g. [[q_rank0, ..., q_rankN], [k_rank0, ..., k_rankN], [v_rank0, ..., v_rankN]] + # -> [q_rank0, k_rank0, v_rank0, ..., q_rankN, k_rankN, v_rankN] where N is TP size, for attn_qkv. + return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights))) + def load_from_model_dir(uid, model_dir, hf_config): if uid not in self._cpp_lora_weights: self._cpp_lora_weights[uid] = [] # Will be converted to tensor later @@ -1048,33 +1078,32 @@ def load_from_model_dir(uid, model_dir, hf_config): is_dora = t_mag is not None rank_dim = 1 if has_expert_indices else 0 - if lora_module in ["mlp_gate_up"]: - # Special handling for fused module mlp_gate_up: - # HF stores each part's weights sequentially, whereas we need to interleave them for TP. - # We have to concatenate them all back after interleaving, as the CPP expects the full - # non-split weights. + # Prepare fused modules weights for TP + # For fused modules, HF stores the parts weights sequentially, whereas with TP>1 we need them to be + # interleaved. + # e.g. 
Convert [q, k, v] to [q_rank0, k_rank0, v_rank0, ..., q_rankN, k_rankN, v_rankN] where + # N=TP size + tp_size = self._mapping.tp_size + interleaved_parts = [] + if lora_module == "mlp_gate_up": assert t_out.shape[rank_dim] % 2 == 0 half_size = t_out.shape[rank_dim] // 2 - tp_size = self._mapping.tp_size - assert half_size % tp_size == 0 - - first_half = t_out.narrow(rank_dim, 0, half_size) - second_half = t_out.narrow(rank_dim, half_size, half_size) - tp_parts_a = [ - torch.split( - first_half, first_half.shape[rank_dim] // tp_size, dim=rank_dim - )[r] - for r in range(tp_size) - ] - tp_parts_b = [ - torch.split( - second_half, second_half.shape[rank_dim] // tp_size, dim=rank_dim - )[r] - for r in range(tp_size) - ] - interleaved_parts = list( - itertools.chain.from_iterable(zip(tp_parts_a, tp_parts_b)) + interleaved_parts = interleave_fused_lora_weights_for_tp( + t_out, rank_dim, tp_size, [half_size, half_size] + ) + elif lora_module == "attn_qkv": + q_size = ( + self._model_config.head_size * self._model_config.num_heads * tp_size + ) + kv_size = ( + self._model_config.head_size * self._model_config.num_kv_heads * tp_size + ) + interleaved_parts = interleave_fused_lora_weights_for_tp( + t_out, rank_dim, tp_size, [q_size, kv_size, kv_size] ) + # We have to concatenate them all back after interleaving, as the CPP expects the full non-split + # weights. + if interleaved_parts: t_out = torch.cat(interleaved_parts, dim=rank_dim) effective_rank = t_in.shape[rank_dim] From 9c9030d7aeb29d22abb6e0142dc89c3a7f0d405e Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:44 +0000 Subject: [PATCH 03/16] Remove small code duplication Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 6730a196a9e..8a28cf7b12e 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1088,9 +1088,7 @@ def load_from_model_dir(uid, model_dir, hf_config): if lora_module == "mlp_gate_up": assert t_out.shape[rank_dim] % 2 == 0 half_size = t_out.shape[rank_dim] // 2 - interleaved_parts = interleave_fused_lora_weights_for_tp( - t_out, rank_dim, tp_size, [half_size, half_size] - ) + part_sizes = [half_size, half_size] elif lora_module == "attn_qkv": q_size = ( self._model_config.head_size * self._model_config.num_heads * tp_size @@ -1098,12 +1096,14 @@ def load_from_model_dir(uid, model_dir, hf_config): kv_size = ( self._model_config.head_size * self._model_config.num_kv_heads * tp_size ) + part_sizes = [q_size, kv_size, kv_size] + + if part_sizes: interleaved_parts = interleave_fused_lora_weights_for_tp( - t_out, rank_dim, tp_size, [q_size, kv_size, kv_size] + t_out, rank_dim, tp_size, part_sizes ) - # We have to concatenate them all back after interleaving, as the CPP expects the full non-split - # weights. - if interleaved_parts: + # We have to concatenate them all back after interleaving, as the CPP expects the full non-split + # weights. 
t_out = torch.cat(interleaved_parts, dim=rank_dim) effective_rank = t_in.shape[rank_dim] From 43174b96b7a806cf22d40546ad99639df57f05a3 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:45 +0000 Subject: [PATCH 04/16] Minor fix Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 8a28cf7b12e..8a2f69a89be 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1084,7 +1084,7 @@ def load_from_model_dir(uid, model_dir, hf_config): # e.g. Convert [q, k, v] to [q_rank0, k_rank0, v_rank0, ..., q_rankN, k_rankN, v_rankN] where # N=TP size tp_size = self._mapping.tp_size - interleaved_parts = [] + part_sizes = [] if lora_module == "mlp_gate_up": assert t_out.shape[rank_dim] % 2 == 0 half_size = t_out.shape[rank_dim] // 2 From f7000d176d86a937b28b9f9b99d98e2a79e514e3 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:49 +0000 Subject: [PATCH 05/16] Added test_lora_fused_modules_output_on_tp2_identical_to_tp1 to pytorch flow, minor fixes & improvements Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 37 ++++--------- tensorrt_llm/executor/base_worker.py | 12 +++-- tensorrt_llm/lora_manager.py | 14 ++--- tensorrt_llm/runtime/enc_dec_model_runner.py | 8 ++- tensorrt_llm/runtime/generation.py | 35 ++++++++++++- tensorrt_llm/runtime/model_runner.py | 6 ++- tensorrt_llm/runtime/model_runner_cpp.py | 10 ++-- tests/unittest/llmapi/lora_test_utils.py | 52 +++++++++++++++++++ .../llmapi/test_llm_multi_gpu_pytorch.py | 10 +++- 9 files changed, 138 insertions(+), 46 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 50eea487fb6..ff3dfcf09c6 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -13,7 +13,7 @@ from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig -from tensorrt_llm.runtime import ModelConfig as ModelConfigRuntime +from tensorrt_llm.runtime import ModelConfig as ModelConfigPython from tensorrt_llm.sampling_params import SamplingParams from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range @@ -33,7 +33,7 @@ KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType -ModelConfig = tensorrt_llm.bindings.ModelConfig +ModelConfigCpp = tensorrt_llm.bindings.ModelConfig DataType = tensorrt_llm.bindings.DataType KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager RequestList = list[LlmRequest] @@ -161,7 +161,7 @@ def __init__( spec_config: Optional["DecodingBaseConfig"] = None, layer_mask: Optional[List[bool]] = None, max_num_tokens: int = 8192, - model_config: Optional[ModelConfig] = None, + model_config: Optional[ModelConfigCpp] = None, max_beam_width: int = 1, is_draft: bool = False, kv_connector_manager: Optional[KvCacheConnectorManager] = None, @@ -372,7 +372,7 @@ def shutdown(self): @classmethod def 
from_model_config(cls, - model_config: ModelConfig, + model_config: ModelConfigCpp, kv_cache_config: KvCacheConfigCpp, mapping: Mapping, kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF, @@ -773,7 +773,7 @@ def adjust_window_sizes_for_vswa( window_size_to_layers: Dict[int, List[int]], max_attention_window_vec: List[int], kv_cache_config: KvCacheConfigCpp, - model_config: ModelConfig, + model_config: ModelConfigCpp, pool_memory_bytes: int, kv_factor: int, dtype: DataType, @@ -888,7 +888,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int: def calculate_max_num_blocks_from_cpp( self, kv_cache_config: KvCacheConfigCpp, - model_config: ModelConfig, + model_config: ModelConfigCpp, extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]: """ This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks. @@ -1134,7 +1134,7 @@ class PeftCacheManager(BaseResourceManager): def __init__(self, peft_cache_config: PeftCacheConfig, lora_config: LoraConfig, - model_config: ModelConfig, + model_config: ModelConfigCpp, world_config: WorldConfig | None = None): import tensorrt_llm.bindings as _tb @@ -1179,26 +1179,11 @@ def __init__(self, ) self._lora_manager = LoraManager( mapping=mapping, - model_config=self._model_config_binding_to_model_config_runtime( - model_config)) + model_config=ModelConfigPython.from_model_config_cpp( + model_config, mapping)) - @staticmethod - def _model_config_binding_to_model_config_runtime( - model_config_binding: ModelConfig) -> ModelConfigRuntime: - # TODO ZUKER: Init the rest of the fields as well? - return ModelConfigRuntime( - max_batch_size=model_config_binding.max_batch_size, - max_beam_width=model_config_binding.max_beam_width, - vocab_size=model_config_binding.vocab_size, - num_layers=model_config_binding.num_layers( - ), # TODO ZUKER: Should num_layers get the PP args from mapping? - num_heads=model_config_binding.num_heads, - num_kv_heads=model_config_binding.num_kv_heads(0), - hidden_size=model_config_binding.hidden_size, - head_size=model_config_binding.head_size, - gpt_attention_plugin=model_config_binding.use_gpt_attention_plugin, - dtype=binding_to_str_dtype(model_config_binding.data_type), - ) + def get_lora_manager(self) -> LoraManager: + return self._lora_manager def add_request_peft(self, request: LlmRequest): if request.lora_task_id is not None: diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index ecf2473a84c..2d63cadbe99 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -206,10 +206,14 @@ def _create_engine(executor_config): # point in the TRT flow is currently not supported (it's at the CPP # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA # optimization is not available in TRT-python flow. + + # NOTE: In TRT-python flow, llm_args is not set, so we create Mapping() with default values here. + # this would cause incorrect weight split on TP>1 for fused LoRA modules. 
+ mapping = Mapping( + ) if self.llm_args is None else self.llm_args.parallel_config.to_mapping( + ) self._lora_manager = LoraManager( - # TODO ZUKER: Somehow create Mapping with real info when self.llm_args is None - mapping=Mapping() if self.llm_args is None else - self.llm_args.parallel_config.to_mapping(), + mapping=mapping, model_config=self._runtime_model_config, cpp_peft_cache_manager=None) if engine_config.build_config.max_prompt_embedding_table_size > 0: @@ -222,7 +226,7 @@ def _create_engine(executor_config): ResourceManagerType peft_cache_manager = self.engine.resource_manager.resource_managers.get( ResourceManagerType.PEFT_CACHE_MANAGER) - self._lora_manager = peft_cache_manager._lora_manager + self._lora_manager = peft_cache_manager.get_lora_manager() lora_model_config = self.engine.model_engine.lora_model_config assert lora_model_config is not None self._lora_model_config = lora_model_config diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 8a2f69a89be..c872a98f3dd 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -670,6 +670,8 @@ def __init__( """Constructor. Args: + mapping (Mapping): Parallelism related information. + model_config (ModelConfig): model configuration python class (not CPP binding). cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when the adapter is already loaded in the LoRA CPU cache. @@ -976,7 +978,7 @@ def interleave_fused_lora_weights_for_tp( ) -> list[torch.Tensor]: assert weight.shape[rank_dim] == sum(part_sizes) - # Split the weights into their respective parts. e.g. weight -> [q, k, v] for attn_qkv. + # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv. weight_parts = [ weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i]) for i in range(len(part_sizes)) @@ -985,7 +987,7 @@ def interleave_fused_lora_weights_for_tp( assert weight_parts[i].shape[rank_dim] % tp_size == 0 # Split each part into tp_size chunks. - # e.g. [q, k, v] -> [[q_rank0, ..., q_rankN], [k_rank0, ..., k_rankN], [v_rank0, ..., v_rankN]] + # e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]] # where N is TP size, for attn_qkv. weight_parts_tp_weights = [ torch.split( @@ -995,8 +997,8 @@ def interleave_fused_lora_weights_for_tp( ] # Interleave the parts across TP ranks and flatten the list of lists into a single list. - # e.g. [[q_rank0, ..., q_rankN], [k_rank0, ..., k_rankN], [v_rank0, ..., v_rankN]] - # -> [q_rank0, k_rank0, v_rank0, ..., q_rankN, k_rankN, v_rankN] where N is TP size, for attn_qkv. + # e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]] + # -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv. return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights))) def load_from_model_dir(uid, model_dir, hf_config): @@ -1081,8 +1083,8 @@ def load_from_model_dir(uid, model_dir, hf_config): # Prepare fused modules weights for TP # For fused modules, HF stores the parts weights sequentially, whereas with TP>1 we need them to be # interleaved. - # e.g. Convert [q, k, v] to [q_rank0, k_rank0, v_rank0, ..., q_rankN, k_rankN, v_rankN] where - # N=TP size + # e.g. 
Convert [Wq, Wk, Wv] to [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] + # where N=TP size, for attn_qkv. tp_size = self._mapping.tp_size part_sizes = [] if lora_module == "mlp_gate_up": diff --git a/tensorrt_llm/runtime/enc_dec_model_runner.py b/tensorrt_llm/runtime/enc_dec_model_runner.py index 3e1aa586678..57ed27ae091 100644 --- a/tensorrt_llm/runtime/enc_dec_model_runner.py +++ b/tensorrt_llm/runtime/enc_dec_model_runner.py @@ -175,7 +175,9 @@ def engine_setup(component): # encoder lora manager setup if self.encoder_model_config.lora_plugin: self.encoder_lora_manager = LoraManager( - mapping=self.encoder_runtime_mapping) + mapping=self.encoder_runtime_mapping, + model_config=self.encoder_model_config, + ) # TODO: this is only for bart self.encoder_lora_manager.load_from_hf( model_dirs=lora_dir, @@ -198,7 +200,9 @@ def engine_setup(component): # decoder lora manager setup if self.decoder_model_config.lora_plugin: self.decoder_lora_manager = LoraManager( - mapping=self.decoder_runtime_mapping) + mapping=self.decoder_runtime_mapping, + model_config=self.decoder_model_config, + ) # TODO: this is only for bart self.decoder_lora_manager.load_from_hf( model_dirs=lora_dir, diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index bf6e228f769..2c6a0e54089 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -40,8 +40,8 @@ PoolsKVCacheManager from tensorrt_llm.runtime.redrafter_utils import * -from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy, - trt_dtype_to_torch) +from .._utils import (binding_to_str_dtype, pad_vocab_size, str_dtype_to_torch, + torch_to_numpy, trt_dtype_to_torch) from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free from ..layers import LanguageAdapterConfig from ..logger import logger @@ -653,6 +653,37 @@ class ModelConfig: # language adapter language_adapter_config: Optional[LanguageAdapterConfig] = None + @classmethod + def from_model_config_cpp(cls, model_config_cpp, + mapping: Mapping) -> 'ModelConfig': + """Create a partially initialized ModelConfigPython from a given ModelConfigCpp. + + Note that each of these classes have fields that don't exist in the other, so the created ModelConfigPython + won't have all of its fields initialized. + """ + return cls( + max_batch_size=model_config_cpp.max_batch_size, + max_beam_width=model_config_cpp.max_beam_width, + vocab_size=model_config_cpp.vocab_size, + num_layers=model_config_cpp.num_layers( + pipeline_parallelism=mapping.pp_size, + pipeline_parallelism_rank=mapping.pp_rank, + ), + num_heads=model_config_cpp.num_heads, + num_kv_heads=model_config_cpp.num_kv_heads(0), + hidden_size=model_config_cpp.hidden_size, + kv_cache_type=model_config_cpp.kv_cache_type, + cross_attention=model_config_cpp.use_cross_attention, + head_size=model_config_cpp.head_size, + max_prompt_embedding_table_size=model_config_cpp. 
+ max_prompt_embedding_table_size, + gpt_attention_plugin=model_config_cpp.use_gpt_attention_plugin, + dtype=binding_to_str_dtype(model_config_cpp.data_type), + num_kv_heads_per_layer=model_config_cpp.num_kv_heads_per_layer, + tokens_per_block=model_config_cpp.tokens_per_block, + lora_plugin=model_config_cpp.use_lora_plugin, + ) + @dataclass class SamplingConfig: diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index c7045854b84..94965e66d2f 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -611,7 +611,8 @@ def from_engine( session.runtime._set_weight_streaming(gpu_weights_percent) if session.use_lora_plugin: - lora_manager = LoraManager(mapping=runtime_mapping) + lora_manager = LoraManager(mapping=runtime_mapping, + model_config=model_config) if lora_dir is not None: lora_manager.load_from_ckpt(lora_dir, model_config=model_config, @@ -719,7 +720,8 @@ def from_dir( debug_mode=debug_mode, stream=stream) if session.use_lora_plugin: - lora_manager = LoraManager(mapping=runtime_mapping) + lora_manager = LoraManager(mapping=runtime_mapping, + model_config=model_config) if lora_dir is not None: lora_manager.load_from_ckpt(lora_dir, model_config=model_config, diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index d3cd6dff965..0e253e68986 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -32,8 +32,9 @@ from ..layers import MropeParams from ..logger import logger from ..mapping import Mapping -from .generation import (LogitsProcessor, LoraManager, SamplingConfig, - StoppingCriteria) +from .generation import LogitsProcessor, LoraManager +from .generation import ModelConfig as ModelConfigPython +from .generation import SamplingConfig, StoppingCriteria from .model_runner import ModelRunnerMixin, _engine_config_to_model_config _bindings_dtype_to_torch_dtype_dict = { @@ -277,8 +278,11 @@ def from_dir( engine_config = EngineConfig.from_json_file(f"{engine_dir}/config.json") if model_config.use_lora_plugin and rank == 0: + mapping = _world_config_to_mapping(world_config) lora_manager = LoraManager( - mapping=_world_config_to_mapping(world_config)) + mapping=mapping, + model_config=ModelConfigPython.from_model_config_cpp( + model_config, mapping)) if lora_dir is None: config_lora_dir = engine_config.build_config.lora_config.lora_dir if len(config_lora_dir) > 0: diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 58673aa0699..0cf5e708e45 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -11,6 +11,58 @@ from tensorrt_llm import SamplingParams from tensorrt_llm.executor.request import LoRARequest from tensorrt_llm.llmapi.llm import BaseLLM +from tensorrt_llm.lora_helper import LoraConfig + +_RU_LORA_ADAPTER_PROMPTS = [ + "Назови главную площадь в центре Москвы.", + "Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".", + "Что означает выражение \"водить за нос\"? Объясни в двух словах.", +] + + +def _generate_llm_response_lora_fused_modules(llm_class: Type[BaseLLM], + prompts: list[str], + **extra_llm_kwargs) -> list[str]: + """Generates responses with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter. + The used LoRA adapter has fused attention QKV and fused MLP gate up proj modules. + Returns the generated texts. 
+ """ # noqa: D205 + hf_model_dir = f"{llm_models_root()}/Phi-3/Phi-3-mini-4k-instruct" + hf_lora_dir = f"{llm_models_root()}/lora/phi/Phi-3-mini-4k-instruct-ru-lora" + + lora_req = LoRARequest("ru-lora", 0, hf_lora_dir) + sampling_params = SamplingParams(max_tokens=20) + + lora_config = LoraConfig(lora_dir=[hf_lora_dir], + max_lora_rank=16, + max_loras=2, + max_cpu_loras=2) + + lora_requests = [lora_req] * len(prompts) + with llm_class(hf_model_dir, lora_config=lora_config, + **extra_llm_kwargs) as llm: + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests) + + return [output.outputs[0].text for output in outputs] + + +def check_lora_fused_modules_output_tp2_identical_to_tp1( + llm_class: Type[BaseLLM], **extra_llm_kwargs) -> None: + """Tests the output with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter with TP=2 is identical to + the output with TP=1. + That LoRA adapter has fused attention QKV and fused MLP gate up proj modules. + """ # noqa: D205 + extra_llm_kwargs["tensor_parallel_size"] = 1 + outputs_tp1 = _generate_llm_response_lora_fused_modules( + llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs) + + extra_llm_kwargs["tensor_parallel_size"] = 2 + outputs_tp2 = _generate_llm_response_lora_fused_modules( + llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs) + + assert outputs_tp1 == outputs_tp2 def check_llama_7b_multi_unique_lora_adapters_from_request( diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index b145122d176..1e26c869356 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -5,7 +5,7 @@ from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.lora_helper import LoraConfig -from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness +from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_lora_fused_modules_output_tp2_identical_to_tp1 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness from .test_llm import _test_llm_capture_request_error from utils.util import skip_ray @@ -62,6 +62,14 @@ def test_llama_7b_multi_lora_tp2(): cuda_graph_config=None) +def test_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None: + check_lora_fused_modules_output_tp2_identical_to_tp1( + LLM, + # Disable CUDA graph + # TODO: remove this once we have a proper fix for CUDA graph in LoRA + cuda_graph_config=None) + + @pytest.mark.skip(reason="https://nvbugs/5560921") @skip_ray @pytest.mark.gpu2 From 7b1dff0ff75825263d751a0a30434fb098399e99 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:50 +0000 Subject: [PATCH 06/16] Add inline comment Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index c872a98f3dd..77b65590d79 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1092,6 +1092,8 @@ def load_from_model_dir(uid, model_dir, hf_config): half_size = t_out.shape[rank_dim] // 2 part_sizes = [half_size, half_size] elif lora_module == "attn_qkv": + # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already + # divided by tp_size q_size = ( self._model_config.head_size * self._model_config.num_heads * tp_size ) 
From 8e379572488005f38ee4fd1f5c837c93dd303ddb Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:51 +0000 Subject: [PATCH 07/16] Using Python 3.8 type hints Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 4 ++-- tests/unittest/llmapi/lora_test_utils.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 77b65590d79..71ec437d9c6 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -974,8 +974,8 @@ def preprocess_lora_weights(lora_model, model_config): return lora_model def interleave_fused_lora_weights_for_tp( - weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: list[int] - ) -> list[torch.Tensor]: + weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int] + ) -> List[torch.Tensor]: assert weight.shape[rank_dim] == sum(part_sizes) # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv. diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 0cf5e708e45..5396c3d8fb4 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -2,7 +2,7 @@ import tarfile import tempfile from pathlib import Path -from typing import OrderedDict, Type +from typing import List, OrderedDict, Type import torch from utils.llm_data import llm_models_root @@ -21,8 +21,8 @@ def _generate_llm_response_lora_fused_modules(llm_class: Type[BaseLLM], - prompts: list[str], - **extra_llm_kwargs) -> list[str]: + prompts: List[str], + **extra_llm_kwargs) -> List[str]: """Generates responses with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter. The used LoRA adapter has fused attention QKV and fused MLP gate up proj modules. Returns the generated texts. @@ -66,7 +66,7 @@ def check_lora_fused_modules_output_tp2_identical_to_tp1( def check_llama_7b_multi_unique_lora_adapters_from_request( - lora_adapter_count_per_call: list[int], repeat_calls: int, + lora_adapter_count_per_call: List[int], repeat_calls: int, repeats_per_call: int, llm_class: Type[BaseLLM], **llm_kwargs): """Calls llm.generate s.t. for each C in lora_adapter_count_per_call, llm.generate is called with C requests repeated 'repeats_per_call' times, where each request is configured with a unique LoRA adapter ID. 
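Since the series moves mapping (and now model_config) from the LoraManager load methods to its constructor, caller code changes shape slightly. Below is a minimal usage sketch under stated assumptions: the adapter path is a placeholder, and `model_config` is assumed to be an already-built tensorrt_llm.runtime.ModelConfig, which the runtime normally constructs for you.

```python
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.mapping import Mapping

# Assumption: `model_config` is a pre-built tensorrt_llm.runtime.ModelConfig
# (or LoraModelConfig) describing the target model.
mapping = Mapping(world_size=2, rank=0, tp_size=2, pp_size=1)
lora_manager = LoraManager(mapping=mapping, model_config=model_config)

# load_from_ckpt no longer takes a runtime_mapping argument; the manager uses
# the mapping it was constructed with to interleave fused-module weights.
uids = lora_manager.load_from_ckpt(
    ["/path/to/hf_lora_adapter"],  # placeholder adapter directory
    model_config=model_config,
    ckpt_source="hf",
)
```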
From 5e0c767715d789a39aee201a11a3aab500a70346 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:52 +0000 Subject: [PATCH 08/16] Remove unnecessary import Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/executor/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 623066ed120..59e3fca19fc 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -10,7 +10,6 @@ import zmq from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping from .._utils import mpi_comm, mpi_rank from ..bindings import executor as tllm From 8b7dda84295f2558efe97185da657c636cc16403 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:53 +0000 Subject: [PATCH 09/16] Remove passing of removed runtime_mapping argument in call to lora_manager.load_from_ckpt Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/executor/base_worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 2d63cadbe99..bd16c1a582d 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -311,7 +311,6 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: [lora_request.path], model_config=self._runtime_model_config if self._runtime_model_config is not None else self._lora_model_config, - runtime_mapping=None, uids=[adapter_id], ckpt_source=lora_request.ckpt_source) return adapter_id in newly_loaded_uids From 00da932c7ad9056fbf5b00055dcdee5c1b919c46 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:54 +0000 Subject: [PATCH 10/16] Pass mapping correctly in TRT-python flow as well Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/executor/base_worker.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index bd16c1a582d..572d8a531d6 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -10,7 +10,6 @@ import torch from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping from .._torch.pyexecutor.llm_request import LlmResponse from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, @@ -206,11 +205,7 @@ def _create_engine(executor_config): # point in the TRT flow is currently not supported (it's at the CPP # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA # optimization is not available in TRT-python flow. - - # NOTE: In TRT-python flow, llm_args is not set, so we create Mapping() with default values here. - # this would cause incorrect weight split on TP>1 for fused LoRA modules. 
- mapping = Mapping( - ) if self.llm_args is None else self.llm_args.parallel_config.to_mapping( + mapping = engine_config.pretrained_config.mapping if self.llm_args is None else self.llm_args.parallel_config.to_mapping( ) self._lora_manager = LoraManager( mapping=mapping, From 2fcb0409b7c1b46d2ff4395f52ffa0da5c5a6f96 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:54 +0000 Subject: [PATCH 11/16] In TRT-python flow use mapping only from engine_config.pretrained_config.mapping Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/executor/base_worker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 572d8a531d6..f2655cafb4e 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -205,10 +205,8 @@ def _create_engine(executor_config): # point in the TRT flow is currently not supported (it's at the CPP # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA # optimization is not available in TRT-python flow. - mapping = engine_config.pretrained_config.mapping if self.llm_args is None else self.llm_args.parallel_config.to_mapping( - ) self._lora_manager = LoraManager( - mapping=mapping, + mapping=engine_config.pretrained_config.mapping, model_config=self._runtime_model_config, cpp_peft_cache_manager=None) if engine_config.build_config.max_prompt_embedding_table_size > 0: From 446716ab0ff113f3e83228864414c95a0870f2e7 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:57 +0000 Subject: [PATCH 12/16] Refactor 'prepare fused lora modules for TP' logic into its own inner function, pass cpp_peft_cache_manager to LoraManager creation, add 'phi3' to relevant test names Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 3 +- tensorrt_llm/lora_manager.py | 65 ++++++++++--------- tests/unittest/llmapi/lora_test_utils.py | 12 ++-- .../llmapi/test_llm_multi_gpu_pytorch.py | 6 +- 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index ff3dfcf09c6..cd3a5e18ad0 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1180,7 +1180,8 @@ def __init__(self, self._lora_manager = LoraManager( mapping=mapping, model_config=ModelConfigPython.from_model_config_cpp( - model_config, mapping)) + model_config, mapping), + cpp_peft_cache_manager=self.impl) def get_lora_manager(self) -> LoraManager: return self._lora_manager diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 71ec437d9c6..dcbe6f85c97 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -671,7 +671,7 @@ def __init__( Args: mapping (Mapping): Parallelism related information. - model_config (ModelConfig): model configuration python class (not CPP binding). + model_config (ModelConfig): model configuration python class instance. cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when the adapter is already loaded in the LoRA CPU cache. 
@@ -1001,6 +1001,38 @@ def interleave_fused_lora_weights_for_tp( # -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv. return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights))) + def prepare_fused_lora_modules_for_tp( + lora_module: str, t_out: torch.Tensor, rank_dim: int + ) -> torch.Tensor: + """Interleaves fused LoRA modules weights for TP. This is required since HF stores the parts weights + sequentially, whereas with TP>1 we need them to be interleaved. + e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to + torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN]) + where N=TP size. + """ # noqa: D205 + tp_size = self._mapping.tp_size + if tp_size == 1: + return t_out + part_sizes = [] + if lora_module == "mlp_gate_up": + assert t_out.shape[rank_dim] % 2 == 0 + half_size = t_out.shape[rank_dim] // 2 + part_sizes = [half_size, half_size] + elif lora_module == "attn_qkv": + # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already + # divided by tp_size + q_size = self._model_config.head_size * self._model_config.num_heads * tp_size + kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size + part_sizes = [q_size, kv_size, kv_size] + + if part_sizes: + interleaved_parts = interleave_fused_lora_weights_for_tp( + t_out, rank_dim, tp_size, part_sizes + ) + # Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights. + t_out = torch.cat(interleaved_parts, dim=rank_dim) + return t_out + def load_from_model_dir(uid, model_dir, hf_config): if uid not in self._cpp_lora_weights: self._cpp_lora_weights[uid] = [] # Will be converted to tensor later @@ -1079,36 +1111,7 @@ def load_from_model_dir(uid, model_dir, hf_config): is_dora = t_mag is not None rank_dim = 1 if has_expert_indices else 0 - - # Prepare fused modules weights for TP - # For fused modules, HF stores the parts weights sequentially, whereas with TP>1 we need them to be - # interleaved. - # e.g. Convert [Wq, Wk, Wv] to [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] - # where N=TP size, for attn_qkv. - tp_size = self._mapping.tp_size - part_sizes = [] - if lora_module == "mlp_gate_up": - assert t_out.shape[rank_dim] % 2 == 0 - half_size = t_out.shape[rank_dim] // 2 - part_sizes = [half_size, half_size] - elif lora_module == "attn_qkv": - # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already - # divided by tp_size - q_size = ( - self._model_config.head_size * self._model_config.num_heads * tp_size - ) - kv_size = ( - self._model_config.head_size * self._model_config.num_kv_heads * tp_size - ) - part_sizes = [q_size, kv_size, kv_size] - - if part_sizes: - interleaved_parts = interleave_fused_lora_weights_for_tp( - t_out, rank_dim, tp_size, part_sizes - ) - # We have to concatenate them all back after interleaving, as the CPP expects the full non-split - # weights. 
- t_out = torch.cat(interleaved_parts, dim=rank_dim) + t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim) effective_rank = t_in.shape[rank_dim] diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 5396c3d8fb4..a123df495b9 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -20,9 +20,9 @@ ] -def _generate_llm_response_lora_fused_modules(llm_class: Type[BaseLLM], - prompts: List[str], - **extra_llm_kwargs) -> List[str]: +def _generate_phi3_response_lora_fused_modules(llm_class: Type[BaseLLM], + prompts: List[str], + **extra_llm_kwargs) -> List[str]: """Generates responses with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter. The used LoRA adapter has fused attention QKV and fused MLP gate up proj modules. Returns the generated texts. @@ -48,18 +48,18 @@ def _generate_llm_response_lora_fused_modules(llm_class: Type[BaseLLM], return [output.outputs[0].text for output in outputs] -def check_lora_fused_modules_output_tp2_identical_to_tp1( +def check_phi3_lora_fused_modules_output_tp2_identical_to_tp1( llm_class: Type[BaseLLM], **extra_llm_kwargs) -> None: """Tests the output with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter with TP=2 is identical to the output with TP=1. That LoRA adapter has fused attention QKV and fused MLP gate up proj modules. """ # noqa: D205 extra_llm_kwargs["tensor_parallel_size"] = 1 - outputs_tp1 = _generate_llm_response_lora_fused_modules( + outputs_tp1 = _generate_phi3_response_lora_fused_modules( llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs) extra_llm_kwargs["tensor_parallel_size"] = 2 - outputs_tp2 = _generate_llm_response_lora_fused_modules( + outputs_tp2 = _generate_phi3_response_lora_fused_modules( llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs) assert outputs_tp1 == outputs_tp2 diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index 1e26c869356..d5a59e95efb 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -5,7 +5,7 @@ from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.lora_helper import LoraConfig -from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_lora_fused_modules_output_tp2_identical_to_tp1 +from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_phi3_lora_fused_modules_output_tp2_identical_to_tp1 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness from .test_llm import _test_llm_capture_request_error from utils.util import skip_ray @@ -62,8 +62,8 @@ def test_llama_7b_multi_lora_tp2(): cuda_graph_config=None) -def test_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None: - check_lora_fused_modules_output_tp2_identical_to_tp1( +def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None: + check_phi3_lora_fused_modules_output_tp2_identical_to_tp1( LLM, # Disable CUDA graph # TODO: remove this once we have a proper fix for CUDA graph in LoRA From d532a7994721c3a59de29f8367f97ae57f51e471 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:57 +0000 Subject: [PATCH 13/16] Convert more fields of ModelConfig class from binding Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/_utils.py | 6 +++++- 
tensorrt_llm/runtime/generation.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index baef7e79a6c..4c696511dc3 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -42,7 +42,7 @@ import tensorrt as trt # isort: on -from tensorrt_llm.bindings import DataType, GptJsonConfig +from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE from tensorrt_llm.logger import logger @@ -198,6 +198,10 @@ def str_dtype_to_torch(dtype): } +def binding_layer_type_to_str(layer_type: LayerType) -> str: + return layer_type.name.lower() + + def binding_to_str_dtype(binding_dtype) -> str: ret = _binding_to_str_dtype.get(binding_dtype) assert ret is not None, f'Unsupported binding dtype: {binding_dtype}' diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 2c6a0e54089..0e17bb5891d 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -40,8 +40,9 @@ PoolsKVCacheManager from tensorrt_llm.runtime.redrafter_utils import * -from .._utils import (binding_to_str_dtype, pad_vocab_size, str_dtype_to_torch, - torch_to_numpy, trt_dtype_to_torch) +from .._utils import (binding_layer_type_to_str, binding_to_str_dtype, + pad_vocab_size, str_dtype_to_torch, torch_to_numpy, + trt_dtype_to_torch) from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free from ..layers import LanguageAdapterConfig from ..logger import logger @@ -672,16 +673,24 @@ def from_model_config_cpp(cls, model_config_cpp, num_heads=model_config_cpp.num_heads, num_kv_heads=model_config_cpp.num_kv_heads(0), hidden_size=model_config_cpp.hidden_size, + remove_input_padding=model_config_cpp.use_packed_input, kv_cache_type=model_config_cpp.kv_cache_type, cross_attention=model_config_cpp.use_cross_attention, head_size=model_config_cpp.head_size, max_prompt_embedding_table_size=model_config_cpp. 
max_prompt_embedding_table_size, + quant_mode=QuantMode(model_config_cpp.quant_mode.value), + gather_context_logits=model_config_cpp.compute_context_logits, + gather_generation_logits=model_config_cpp.compute_generation_logits, gpt_attention_plugin=model_config_cpp.use_gpt_attention_plugin, dtype=binding_to_str_dtype(model_config_cpp.data_type), num_kv_heads_per_layer=model_config_cpp.num_kv_heads_per_layer, tokens_per_block=model_config_cpp.tokens_per_block, lora_plugin=model_config_cpp.use_lora_plugin, + layer_types=[ + binding_layer_type_to_str(lt) + for lt in model_config_cpp.layer_types + ], ) From 98bd0567d9983a52d859d1faad58507a13ff46af Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:58 +0000 Subject: [PATCH 14/16] Improve comments Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 2 +- tensorrt_llm/runtime/generation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index dcbe6f85c97..d141c6843b6 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1020,7 +1020,7 @@ def prepare_fused_lora_modules_for_tp( part_sizes = [half_size, half_size] elif lora_module == "attn_qkv": # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already - # divided by tp_size + # divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config q_size = self._model_config.head_size * self._model_config.num_heads * tp_size kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size part_sizes = [q_size, kv_size, kv_size] diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 0e17bb5891d..67cdbbe3876 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -657,7 +657,7 @@ class ModelConfig: @classmethod def from_model_config_cpp(cls, model_config_cpp, mapping: Mapping) -> 'ModelConfig': - """Create a partially initialized ModelConfigPython from a given ModelConfigCpp. + """Create a partially initialized ModelConfig instance from a given ModelConfig CPP binding instance. Note that each of these classes have fields that don't exist in the other, so the created ModelConfigPython won't have all of its fields initialized. 
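For reference, the LayerType-to-string conversion added to _utils.py in this patch follows the usual Python enum-name pattern. Below is a minimal standalone sketch of that pattern; the enum members shown are stand-ins for illustration only and are not taken from the real tensorrt_llm.bindings.LayerType.

    from enum import Enum

    class LayerType(Enum):
        # Stand-in for tensorrt_llm.bindings.LayerType; the real member set may differ.
        ATTENTION = 0
        RECURRENT = 1
        LINEAR = 2

    def binding_layer_type_to_str(layer_type: LayerType) -> str:
        # Same idea as the helper added above: enum member name, lowercased.
        return layer_type.name.lower()

    # Converting a list of binding layer types into the strings kept on the Python ModelConfig:
    layer_types = [LayerType.ATTENTION, LayerType.ATTENTION, LayerType.LINEAR]
    assert [binding_layer_type_to_str(lt) for lt in layer_types] == ["attention", "attention", "linear"]

This keeps the Python-side layer_types field as plain strings while preserving the layer order reported by the C++ config binding.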
From ac03c28cdfb0d091cd70cb3c3d22d72d8cef9343 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:15:59 +0000 Subject: [PATCH 15/16] Convert model_config.num_layers from binding class to python class without dividing by PP size Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 3 +-- tensorrt_llm/runtime/generation.py | 8 ++------ tensorrt_llm/runtime/model_runner_cpp.py | 2 +- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index cd3a5e18ad0..bc2804584af 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -1179,8 +1179,7 @@ def __init__(self, ) self._lora_manager = LoraManager( mapping=mapping, - model_config=ModelConfigPython.from_model_config_cpp( - model_config, mapping), + model_config=ModelConfigPython.from_model_config_cpp(model_config), cpp_peft_cache_manager=self.impl) def get_lora_manager(self) -> LoraManager: diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index 67cdbbe3876..36cdbf0aca5 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -655,8 +655,7 @@ class ModelConfig: language_adapter_config: Optional[LanguageAdapterConfig] = None @classmethod - def from_model_config_cpp(cls, model_config_cpp, - mapping: Mapping) -> 'ModelConfig': + def from_model_config_cpp(cls, model_config_cpp) -> 'ModelConfig': """Create a partially initialized ModelConfig instance from a given ModelConfig CPP binding instance. Note that each of these classes have fields that don't exist in the other, so the created ModelConfigPython @@ -666,10 +665,7 @@ def from_model_config_cpp(cls, model_config_cpp, max_batch_size=model_config_cpp.max_batch_size, max_beam_width=model_config_cpp.max_beam_width, vocab_size=model_config_cpp.vocab_size, - num_layers=model_config_cpp.num_layers( - pipeline_parallelism=mapping.pp_size, - pipeline_parallelism_rank=mapping.pp_rank, - ), + num_layers=model_config_cpp.num_layers(), num_heads=model_config_cpp.num_heads, num_kv_heads=model_config_cpp.num_kv_heads(0), hidden_size=model_config_cpp.hidden_size, diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index 0e253e68986..96895268074 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -282,7 +282,7 @@ def from_dir( lora_manager = LoraManager( mapping=mapping, model_config=ModelConfigPython.from_model_config_cpp( - model_config, mapping)) + model_config)) if lora_dir is None: config_lora_dir = engine_config.build_config.lora_config.lora_dir if len(config_lora_dir) > 0: From 30e6fb55eb2a1dc42ae875213e83f00008422330 Mon Sep 17 00:00:00 2001 From: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:16:00 +0000 Subject: [PATCH 16/16] Mark test to run only when there are 2 GPUs, improve documentation Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- tensorrt_llm/lora_manager.py | 14 +++++++++----- .../unittest/llmapi/test_llm_multi_gpu_pytorch.py | 1 + 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index d141c6843b6..4fe0d0b44cb 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py 
@@ -976,6 +976,11 @@ def preprocess_lora_weights(lora_model, model_config): def interleave_fused_lora_weights_for_tp( weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int] ) -> List[torch.Tensor]: + """Interleaves fused LoRA modules weights for TP. + e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to + torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN]) + where N=TP size. + """ # noqa: D205 assert weight.shape[rank_dim] == sum(part_sizes) # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv. @@ -1004,11 +1009,10 @@ def interleave_fused_lora_weights_for_tp( def prepare_fused_lora_modules_for_tp( lora_module: str, t_out: torch.Tensor, rank_dim: int ) -> torch.Tensor: - """Interleaves fused LoRA modules weights for TP. This is required since HF stores the parts weights - sequentially, whereas with TP>1 we need them to be interleaved. - e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to - torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN]) - where N=TP size. + """Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights + sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly. + + See interleave_fused_lora_weights_for_tp for more details. """ # noqa: D205 tp_size = self._mapping.tp_size if tp_size == 1: diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index d5a59e95efb..f4fa75e7da0 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -62,6 +62,7 @@ def test_llama_7b_multi_lora_tp2(): cuda_graph_config=None) +@pytest.mark.gpu2 def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None: check_phi3_lora_fused_modules_output_tp2_identical_to_tp1( LLM,
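For reference, the reordering described by the docstrings above can be sketched as a standalone function. This is an illustrative version consistent with the documented behavior, not the in-tree implementation, and the sizes in the usage example (head_size, head counts, TP size) are hypothetical; it only demonstrates why interleaving makes a plain per-rank split of the fused LoRA weight (t_out) land on the correct per-rank (q, k, v) slices.

    from typing import List
    import torch

    def interleave_fused_lora_weights_for_tp(
            weight: torch.Tensor, rank_dim: int, tp_size: int,
            part_sizes: List[int]) -> List[torch.Tensor]:
        # weight is the fused LoRA weight, e.g. torch.cat([Wq, Wk, Wv]) along rank_dim.
        assert weight.shape[rank_dim] == sum(part_sizes)
        # Split into the fused parts, e.g. [Wq, Wk, Wv] for attn_qkv.
        parts = torch.split(weight, part_sizes, dim=rank_dim)
        # Split every part into tp_size shards along the same dim.
        assert all(p.shape[rank_dim] % tp_size == 0 for p in parts)
        shards = [torch.split(p, p.shape[rank_dim] // tp_size, dim=rank_dim) for p in parts]
        # Regroup by rank: [Wq_r0, Wk_r0, Wv_r0, Wq_r1, Wk_r1, Wv_r1, ...].
        return [shard for per_rank in zip(*shards) for shard in per_rank]

    # Hypothetical attn_qkv sizes: q part of 16 rows, two kv parts of 8 rows each, TP=2,
    # with the LoRA output dimension on dim 0.
    q, kv, tp = 16, 8, 2
    fused = torch.randn(q + kv + kv, 8)
    interleaved = torch.cat(
        interleave_fused_lora_weights_for_tp(fused, 0, tp, [q, kv, kv]), dim=0)
    # A plain even split across ranks now yields each rank's own (q, k, v) slices.
    rank0 = torch.split(interleaved, interleaved.shape[0] // tp, dim=0)[0]
    assert torch.equal(rank0, torch.cat([fused[:q // tp],
                                         fused[q:q + kv // tp],
                                         fused[q + kv:q + kv + kv // tp]], dim=0))

As the patch's comments note, part_sizes is [half, half] for mlp_gate_up and [q_size, kv_size, kv_size] for attn_qkv, with the attn_qkv sizes multiplied back by tp_size because num_heads and num_kv_heads in the model config were already divided by the TP size.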