29 changes: 21 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -13,6 +13,7 @@
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
from tensorrt_llm.sampling_params import SamplingParams

from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range
@@ -32,7 +33,7 @@
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
ModelConfig = tensorrt_llm.bindings.ModelConfig
ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
DataType = tensorrt_llm.bindings.DataType
KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
RequestList = list[LlmRequest]
@@ -160,7 +161,7 @@ def __init__(
spec_config: Optional["DecodingBaseConfig"] = None,
layer_mask: Optional[List[bool]] = None,
max_num_tokens: int = 8192,
model_config: Optional[ModelConfig] = None,
model_config: Optional[ModelConfigCpp] = None,
max_beam_width: int = 1,
is_draft: bool = False,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
@@ -371,7 +372,7 @@ def shutdown(self):

@classmethod
def from_model_config(cls,
model_config: ModelConfig,
model_config: ModelConfigCpp,
kv_cache_config: KvCacheConfigCpp,
mapping: Mapping,
kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,
@@ -772,7 +773,7 @@ def adjust_window_sizes_for_vswa(
window_size_to_layers: Dict[int, List[int]],
max_attention_window_vec: List[int],
kv_cache_config: KvCacheConfigCpp,
model_config: ModelConfig,
model_config: ModelConfigCpp,
pool_memory_bytes: int,
kv_factor: int,
dtype: DataType,
@@ -887,7 +888,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int:
def calculate_max_num_blocks_from_cpp(
self,
kv_cache_config: KvCacheConfigCpp,
model_config: ModelConfig,
model_config: ModelConfigCpp,
extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
"""
This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
@@ -1133,7 +1134,7 @@ class PeftCacheManager(BaseResourceManager):
def __init__(self,
peft_cache_config: PeftCacheConfig,
lora_config: LoraConfig,
model_config: ModelConfig,
model_config: ModelConfigCpp,
world_config: WorldConfig | None = None):
import tensorrt_llm.bindings as _tb

@@ -1169,7 +1170,20 @@ def __init__(self,
lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
binding_to_str_dtype(model_config.data_type),
lora_config.swap_gate_up_proj_lora_b_weight)
self._lora_manager = LoraManager()
mapping = Mapping(
world_size=world_config.size,
rank=world_config.rank,
tp_size=world_config.tensor_parallelism,
pp_size=world_config.pipeline_parallelism,
gpus_per_node=world_config.gpus_per_node,
)
self._lora_manager = LoraManager(
mapping=mapping,
model_config=ModelConfigPython.from_model_config_cpp(model_config),
cpp_peft_cache_manager=self.impl)

def get_lora_manager(self) -> LoraManager:
return self._lora_manager

def add_request_peft(self, request: LlmRequest):
if request.lora_task_id is not None:
@@ -1183,7 +1197,6 @@ def add_request_peft(self, request: LlmRequest):
self._lora_manager.load_from_ckpt(
[request.py_lora_path],
model_config=self._lora_model_config,
runtime_mapping=None,
uids=[request.lora_task_id],
ckpt_source=self._lora_config.lora_ckpt_source)
request.lora_weights = self._lora_manager.cpp_lora_weights[
6 changes: 5 additions & 1 deletion tensorrt_llm/_utils.py
@@ -42,7 +42,7 @@
import tensorrt as trt
# isort: on

from tensorrt_llm.bindings import DataType, GptJsonConfig
from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
from tensorrt_llm.logger import logger

@@ -198,6 +198,10 @@ def str_dtype_to_torch(dtype):
}


def binding_layer_type_to_str(layer_type: LayerType) -> str:
return layer_type.name.lower()


def binding_to_str_dtype(binding_dtype) -> str:
ret = _binding_to_str_dtype.get(binding_dtype)
assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'
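For reference, a minimal sketch of how the new binding_layer_type_to_str helper can be exercised; it simply lowercases the name of a LayerType binding enum member, and iterating over __members__ avoids assuming any particular member name:

from tensorrt_llm._utils import binding_layer_type_to_str
from tensorrt_llm.bindings import LayerType

# Print each LayerType binding member alongside its lowercase string form.
for name, layer_type in LayerType.__members__.items():
    print(name, "->", binding_layer_type_to_str(layer_type))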
9 changes: 5 additions & 4 deletions tensorrt_llm/executor/base_worker.py
@@ -205,7 +205,10 @@ def _create_engine(executor_config):
# point in the TRT flow is currently not supported (it's at the CPP
# Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
# optimization is not available in TRT-python flow.
self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
self._lora_manager = LoraManager(
mapping=engine_config.pretrained_config.mapping,
model_config=self._runtime_model_config,
cpp_peft_cache_manager=None)
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()

@@ -216,8 +219,7 @@ def _create_engine(executor_config):
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
ResourceManagerType.PEFT_CACHE_MANAGER)
self._lora_manager = LoraManager(
cpp_peft_cache_manager=peft_cache_manager.impl)
self._lora_manager = peft_cache_manager.get_lora_manager()
lora_model_config = self.engine.model_engine.lora_model_config
assert lora_model_config is not None
self._lora_model_config = lora_model_config
@@ -302,7 +304,6 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
[lora_request.path],
model_config=self._runtime_model_config if
self._runtime_model_config is not None else self._lora_model_config,
runtime_mapping=None,
uids=[adapter_id],
ckpt_source=lora_request.ckpt_source)
return adapter_id in newly_loaded_uids
1 change: 1 addition & 0 deletions tensorrt_llm/lora_helper.py
@@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules():
"attn_q": "q_proj",
"attn_k": "k_proj",
"attn_v": "v_proj",
"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_proj",
"mlp_4h_to_h": "down_proj",
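As a quick illustration of the lora_helper.py change, the default TRT-LLM-to-HF module mapping now also covers the fused QKV LoRA module; a minimal check, assuming only the function and the new entry shown above:

from tensorrt_llm.lora_helper import get_default_trtllm_modules_to_hf_modules

modules_map = get_default_trtllm_modules_to_hf_modules()
# The new entry maps TRT-LLM's fused attn_qkv LoRA module to HF's fused qkv_proj module.
assert modules_map["attn_qkv"] == "qkv_proj"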
125 changes: 76 additions & 49 deletions tensorrt_llm/lora_manager.py
@@ -1,4 +1,5 @@
import io
import itertools
import json
import logging
import re
@@ -660,11 +661,17 @@ class LoraManager(object):
}

def __init__(
self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
self,
*,
mapping: Mapping,
model_config: "ModelConfig",
cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None,
):
"""Constructor.

Args:
mapping (Mapping): Parallelism-related information.
model_config (ModelConfig): The Python ModelConfig instance.
cpp_peft_cache_manager (PeftCacheManager, optional): Used by the is_adapter_in_cpu_cache method, which enables
a LoRA performance optimization: the adapter weights are not sent with every LLM request when the adapter is
already loaded in the LoRA CPU cache.
Expand Down Expand Up @@ -704,6 +711,8 @@ def __init__(
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
self.lora_target_modules: List[str] = []
self._mapping = mapping
self._model_config = model_config
self._cpp_peft_cache_manager = cpp_peft_cache_manager

def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
@@ -730,7 +739,6 @@ def load_from_ckpt(
self,
model_dirs_or_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
ckpt_source: str = "hf",
) -> List[str]:
@@ -743,7 +751,6 @@
return self.load_from_hf(
model_dirs=model_dirs_or_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids,
)
elif ckpt_source == "nemo":
@@ -754,7 +761,6 @@
return self.load_from_nemo(
model_files=nemo_files,
model_config=model_config,
runtime_mapping=runtime_mapping,
uids=uids,
)
else:
@@ -764,19 +770,13 @@ def load_from_nemo(
self,
model_files: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
) -> List[str]:
"""Returns the adapter UIDs that were loaded by this call.

Note that when an adapter was already loaded before this call, it would not be
included in the returned list of UIDs.
"""
if runtime_mapping is None:
runtime_mapping = Mapping()
tp_size = runtime_mapping.tp_size
tp_rank = runtime_mapping.tp_rank

if uids is None:
uids = [self._generate_uid() for _ in range(len(model_files))]
assert len(uids) == len(model_files)
@@ -829,10 +829,6 @@ def load_from_model_file(uid, model_file):

t_in = all_lora_weights[layer_idx]["in"]
t_out = all_lora_weights[layer_idx]["out"]
assert t_out.shape[0] % tp_size == 0
t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
tp_rank
].contiguous()
else:
t_in = None
t_out = None
@@ -882,7 +878,6 @@ def load_from_hf(
self,
model_dirs: List[str],
model_config: Union["ModelConfig", LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None,
component: Optional[str] = None,
) -> List[str]:
@@ -939,11 +934,6 @@ def load_from_hf(
...

"""
if runtime_mapping is None:
runtime_mapping = Mapping()
tp_size = runtime_mapping.tp_size
tp_rank = runtime_mapping.tp_rank

if uids is None:
uids = [self._generate_uid() for _ in range(len(model_dirs))]
assert len(uids) == len(model_dirs)
Expand Down Expand Up @@ -983,6 +973,70 @@ def preprocess_lora_weights(lora_model, model_config):
lora_model[key] = value
return lora_model

def interleave_fused_lora_weights_for_tp(
weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
) -> List[torch.Tensor]:
"""Interleaves fused LoRA modules weights for TP.
e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
where N=TP size.
""" # noqa: D205
assert weight.shape[rank_dim] == sum(part_sizes)

# Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
weight_parts = [
weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
for i in range(len(part_sizes))
]
for i in range(len(part_sizes)):
assert weight_parts[i].shape[rank_dim] % tp_size == 0

# Split each part into tp_size chunks.
# e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
# where N is TP size, for attn_qkv.
weight_parts_tp_weights = [
torch.split(
weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim
)
for i in range(len(part_sizes))
]

# Interleave the parts across TP ranks and flatten the list of lists into a single list.
# e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
# -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv.
return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights)))

def prepare_fused_lora_modules_for_tp(
lora_module: str, t_out: torch.Tensor, rank_dim: int
) -> torch.Tensor:
"""Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights
sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly.

See interleave_fused_lora_weights_for_tp for more details.
""" # noqa: D205
tp_size = self._mapping.tp_size
if tp_size == 1:
return t_out
part_sizes = []
if lora_module == "mlp_gate_up":
assert t_out.shape[rank_dim] % 2 == 0
half_size = t_out.shape[rank_dim] // 2
part_sizes = [half_size, half_size]
elif lora_module == "attn_qkv":
# The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
# divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
q_size = self._model_config.head_size * self._model_config.num_heads * tp_size
kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size
part_sizes = [q_size, kv_size, kv_size]

if part_sizes:
interleaved_parts = interleave_fused_lora_weights_for_tp(
t_out, rank_dim, tp_size, part_sizes
)
# Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
t_out = torch.cat(interleaved_parts, dim=rank_dim)
return t_out

def load_from_model_dir(uid, model_dir, hf_config):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
@@ -1060,36 +1114,9 @@ def load_from_model_dir(uid, model_dir, hf_config):
t_mag = module_weights.get("magnitude", None)

is_dora = t_mag is not None

if lora_module in ["moe_router", "mlp_router"]:
pass
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
pass
elif lora_module in [
"attn_dense",
"cross_attn_dense",
"mlp_4h_to_h",
"moe_4h_to_h",
]:
# split by row
dim = 2 if has_expert_indices else 1
assert t_in.shape[dim] % tp_size == 0
t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
tp_rank
].contiguous()
else:
# split by column
dim = 1 if has_expert_indices else 0
assert t_out.shape[dim] % tp_size == 0
t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
tp_rank
].contiguous()
if dim == 0 and is_dora and t_mag is not None:
t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
tp_rank
].contiguous()

rank_dim = 1 if has_expert_indices else 0
t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim)

effective_rank = t_in.shape[rank_dim]

t_in = t_in.cuda().contiguous()
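To make the interleaving concrete, here is a small self-contained sketch (toy sizes, with a hypothetical helper name mirroring the inner interleave_fused_lora_weights_for_tp above) showing how a fused QKV LoRA "out" weight laid out as [Wq, Wk, Wv] is reordered so that a plain per-rank column split gives each TP rank its own Q, K and V slices:

import itertools
from typing import List

import torch


def interleave_for_tp(weight: torch.Tensor, rank_dim: int, tp_size: int,
                      part_sizes: List[int]) -> torch.Tensor:
    # Same transformation as interleave_fused_lora_weights_for_tp, followed by the final concat.
    parts = [weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
             for i in range(len(part_sizes))]
    chunks = [torch.split(p, p.shape[rank_dim] // tp_size, dim=rank_dim) for p in parts]
    return torch.cat(list(itertools.chain.from_iterable(zip(*chunks))), dim=rank_dim)


# Toy fused QKV "out" weight along rank_dim=0: rows 0-3 are Q, rows 4-5 are K, rows 6-7 are V.
w = torch.arange(8).unsqueeze(1)
out = interleave_for_tp(w, rank_dim=0, tp_size=2, part_sizes=[4, 2, 2])
print(out.squeeze(1).tolist())  # [0, 1, 4, 6, 2, 3, 5, 7]
# After the usual per-rank column split of the interleaved tensor,
# rank 0 gets rows [0, 1, 4, 6] (its Q, K, V slices) and rank 1 gets [2, 3, 5, 7].

In the real code the part sizes come from the model config (head_size * num_heads * tp_size for Q, head_size * num_kv_heads * tp_size for K and V), since those head counts are stored per rank, as the comment in prepare_fused_lora_modules_for_tp notes.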