
Commit 38e1bac

dump first patched cuda graph version
Signed-off-by: Shahar Mor <[email protected]>
1 parent 85af621 commit 38e1bac


7 files changed: +170 -13 lines


tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 5 additions & 0 deletions
@@ -34,6 +34,7 @@ def __init__(
         attn_metadata: AttentionMetadata,
         spec_metadata: Optional[SpecMetadata] = None,
         use_mrope: bool = False,
+        lora_params: Optional[dict] = None,
     ) -> None:
         """
         Stores a CUDA graph and its associated input buffers.
@@ -68,6 +69,7 @@ def __init__(

         self.attn_metadata = attn_metadata
         self.spec_metadata = spec_metadata
+        self.lora_params = lora_params
         self._output = None
         self._graph = None
         self.optional_extra_model_inputs = ["mrope_position_deltas"]
@@ -90,6 +92,9 @@ def capture(
             "mrope_position_deltas": self.mrope_position_deltas,
         }

+        if self.lora_params is not None:
+            inputs["lora_params"] = self.lora_params
+
         # We have to do warm up runs to initialize PyTorch's
         # internal states according to the docs:
         # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics

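For orientation, here is a minimal sketch of how the new lora_params argument reaches the runner; it mirrors the call site patched in model_engine.py below, and num_sequences_in_batch, attn_metadata, spec_metadata, use_mrope and lora_params are all assumed to be prepared by the model engine, so treat it as illustrative rather than a supported entry point.

# Illustrative only: mirrors the DecodingCUDAGraphRunner(...) call added in
# _maybe_get_cuda_graph (model_engine.py); all inputs are assumed to exist.
runner = DecodingCUDAGraphRunner(
    num_sequences_in_batch,  # batch size to capture
    "cuda",                  # device
    attn_metadata,
    spec_metadata,
    use_mrope,
    lora_params)             # dict or None; merged into `inputs` during capture()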
tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 99 additions & 5 deletions
@@ -26,7 +26,7 @@
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
-from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig
+from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantAlgo
 from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2
@@ -287,6 +287,16 @@ def __init__(
         )

         attn_backend = pytorch_backend_config.attn_backend
+
+        self.lora_manager: Optional[LoraManager] = None
+        if lora_config is not None:
+            self.lora_manager = LoraManager()
+
+        self.lora_prefetch_requests_list = None  # TODO smor - fix "LoRARequest" import
+        if lora_config is not None and lora_config.lora_request is not None:
+            self.lora_prefetch_requests_list = lora_config.lora_request
+        self.has_lora_prefetched = False
+
         self.model = self._load_model(
             model_path,
             mapping=self.mapping,
@@ -445,6 +455,27 @@ def set_lora_model_config(self, lora_target_modules: list[str],
             hidden_size=self.model.config.hidden_size,
             dtype=torch_dtype_to_str(self.model.config.torch_dtype))

+    def set_lora_manager_cpp_peft_cache_manager(
+            self, resource_manager: ResourceManager):
+        cpp_peft_cache_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.PEFT_CACHE_MANAGER)
+        if cpp_peft_cache_manager is not None and self.lora_manager is not None:
+            self.lora_manager.set_cpp_peft_cache_manager(
+                cpp_peft_cache_manager.impl)
+
+    def prefetch_lora_dirs(self):
+        if self.lora_prefetch_requests_list is None:
+            return
+
+        for request in self.lora_prefetch_requests_list:
+            self.lora_manager.load_from_ckpt(
+                [request.path],
+                model_config=self.lora_model_config,
+                runtime_mapping=None,
+                uids=[request.adapter_id])
+
+        self.has_lora_prefetched = True
+
     @property
     def use_mrope(self):
         use_mrope = False
@@ -503,6 +534,16 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self.cuda_graph_dummy_request = None

         def get_cuda_graph_warmup_request(batch_size, draft_len):
+            lora_config = None
+            if self.has_lora_prefetched:
+                # TODO smor currently I assume a single adapter with uid 0, change this
+                uid = 0
+                from tensorrt_llm.bindings import executor as tllm
+                lora_config = tllm.LoraConfig(
+                    task_id=uid,
+                    weights=self.lora_manager.cpp_lora_weights[uid],
+                    config=self.lora_manager.cpp_lora_config[uid])
+
             # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
             available_blocks = kv_cache_manager.get_num_free_blocks(
             ) // self.max_beam_width
@@ -516,7 +557,10 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 is_gen=True,
                 max_num_draft_tokens=draft_len,
                 use_mrope=use_mrope,
-                max_beam_width=self.max_beam_width)
+                max_beam_width=self.max_beam_width,
+                lora_request=
+                lora_config,  # TODO smor- tests assume BS1 then this will be ignored for now, need to resolve
+            )
             # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 draft_len)
@@ -530,7 +574,8 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 is_gen=True,
                 max_num_draft_tokens=draft_len,
                 use_mrope=use_mrope,
-                max_beam_width=self.max_beam_width)[0]
+                max_beam_width=self.max_beam_width,
+                lora_request=lora_config)[0]
             # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
             # This batch contains both the longest request and the shortest requests,
             # it also contains the maximum number of requests and the maximum token number,
@@ -926,6 +971,7 @@ def _round_up_batch_size(self, batch_size: int) -> int:
     def _maybe_get_cuda_graph(
         self,
         batch: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None
     ) -> Optional[DecodingCUDAGraphRunner]:
         """
         Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled
@@ -972,13 +1018,60 @@
         else:
             spec_metadata = None

+        lora_params = None
+        if self.has_lora_prefetched:
+            peft_cache_manager = resource_manager.get_resource_manager(
+                ResourceManagerType.PEFT_CACHE_MANAGER)
+
+            context_requests = batch.context_requests
+            generation_requests = batch.generation_requests
+
+            if len(context_requests) > 0 and len(generation_requests) > 0:
+                raise ValueError(
+                    "SMOR, non empty context and generation requests isn't tested yet"
+                )
+
+            if len(context_requests) > 0:
+                raise ValueError("SMOR, context requests isn't tested yet")
+
+            if len(generation_requests) > 1:
+                raise ValueError("SMOR, generation requests isn't tested yet")
+
+            generation_request = generation_requests[0]
+            # TODO smor I have no idea why this is happening
+            generation_request.lora_weights = generation_request.lora_weights.reshape(
+                [1] + list(generation_request.lora_weights.shape))
+            generation_request.lora_config = generation_request.lora_config.reshape(
+                [1] + list(generation_request.lora_config.shape))
+            peft_cache_manager.impl.add_request_peft(generation_request, True)
+
+            py_lora_task_layer_module_configs = peft_cache_manager.impl.ensure_batch(
+                context_requests, generation_requests, False)
+            for req in context_requests:
+                req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
+                    req.
+                    py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
+            for req in generation_requests:
+                req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
+                    req.
+                    py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
+
+            # TODO smor - look at get lora params from requests
+            # You need something that isn't scheduled requests
+            # It also appears that you should make sure resource manager is called, because prefetch
+            # has to be added to peftCacheManager as well. So it still shouldn't work
+
+            lora_params = self._get_lora_params_from_requests(
+                batch, attn_metadata)
+            print(f"SMOR, not failed on lora_params in maybe_get_cuda_graph")
+
         # Initialize nested dictionary if needed
         if batch_size not in self._cuda_graphs:
             self._cuda_graphs[batch_size] = {}

         self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner(
             num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
-            self.use_mrope)
+            self.use_mrope, lora_params)
         return self._cuda_graphs[batch_size][draft_len]

     def __del__(self) -> None:
@@ -2134,7 +2227,8 @@ def forward(
                                    gather_context_logits)
         with self._maybe_pad_batch(scheduled_requests, kv_cache_manager,
                                    spec_resource_manager) as scheduled_requests:
-            maybe_graph = self._maybe_get_cuda_graph(scheduled_requests)
+            maybe_graph = self._maybe_get_cuda_graph(
+                scheduled_requests, resource_manager=resource_manager)
             if maybe_graph is not None:
                 attn_metadata = maybe_graph.attn_metadata
                 spec_metadata = maybe_graph.spec_metadata

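Taken together, the additions above form a prefetch-then-capture flow. The sketch below shows the intended sequence using only names introduced in this diff; `engine` stands for the model engine instance and is an assumption of the sketch, not an object the patch exposes under that name.

# Sketch, not a public API: single adapter with uid 0, as the TODOs note.
engine.set_lora_manager_cpp_peft_cache_manager(resource_manager)
engine.prefetch_lora_dirs()  # loads each lora_request.path under uid=request.adapter_id

# warmup() then wraps the prefetched tensors into an executor LoraConfig and
# attaches it to the dummy request used for CUDA-graph capture:
from tensorrt_llm.bindings import executor as tllm
uid = 0
warmup_lora = tllm.LoraConfig(
    task_id=uid,
    weights=engine.lora_manager.cpp_lora_weights[uid],
    config=engine.lora_manager.cpp_lora_config[uid])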
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 6 additions & 0 deletions
@@ -211,6 +211,9 @@ def __init__(self,
         self.micro_batches: List[BatchStatePP
                                  | None] = [None] * self.num_micro_batches
         self.send_handles = [None] * self.num_micro_batches
+        self.model_engine.set_lora_manager_cpp_peft_cache_manager(
+            self.resource_manager)
+        self.model_engine.prefetch_lora_dirs()

         self.inflight_req_ids = ReqIdsSet()

@@ -274,6 +277,9 @@ def _event_loop_wrapper(self):
         finally:
             self._executor_loop_cleanup()

+    def get_lora_manager(self):
+        return self.model_engine.lora_manager
+
     def start_worker(self):
         self.worker_lock.acquire()
         try:

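Both calls run once in PyExecutor.__init__, before any request is scheduled, so the adapters listed in LoraConfig.lora_request are resident before the first CUDA-graph warmup. The new get_lora_manager() accessor is what executor/worker.py (below) now uses to share the engine's LoraManager instead of building its own; a one-line sketch of that consumer side, assuming `self.engine` is the PyExecutor owned by the worker:

# Consumer side (see the worker.py change below); `self.engine` is assumed.
self._lora_manager = self.engine.get_lora_manager()  # == model_engine.lora_manager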
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 16 additions & 1 deletion
@@ -400,6 +400,7 @@ def add_dummy_requests(
         max_num_draft_tokens: int = 0,
         use_mrope: bool = False,
         max_beam_width: int = 1,
+        lora_request=None,
     ):
         beam_width = max_beam_width
         requests = []
@@ -419,14 +420,28 @@
             # Using 1 instead of 0 prevents NaN during warmup in e.g. Deepseek
             mrope_position_deltas = torch.zeros(
                 1, device="cuda", dtype=torch.int32) if use_mrope else None
+
+            lora_task_id = None
+            lora_weights = None
+            lora_config = None
+
+            if lora_request is not None:
+                # TODO smor currently work with single adapter only, not sure how this should work with request ids
+                lora_task_id = lora_request.task_id
+                lora_weights = lora_request.weights
+                lora_config = lora_request.config
+
             req = LlmRequest(request_id=req_id,
                              max_new_tokens=1,
                              input_tokens=[1] * token_num,
                              sampling_config=SamplingConfig(
                                  sampling_params._get_sampling_config()),
                              is_streaming=False,
                              mrope_position_deltas=mrope_position_deltas,
-                             encoder_input_tokens=encoder_input_tokens)
+                             encoder_input_tokens=encoder_input_tokens,
+                             lora_task_id=lora_task_id,
+                             lora_weights=lora_weights,
+                             lora_config=lora_config)
             req.is_dummy_request = True
             req.paged_kv_block_ids = []
             if prepare_resource:

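A hypothetical call showing where the new keyword plugs in during warmup: the leading argument is assumed from the existing add_dummy_requests signature and is not part of this diff, while the keywords shown do appear in it; lora_request is expected to expose .task_id, .weights and .config (for example the tllm.LoraConfig built in model_engine.py's warmup path).

# Hypothetical call shape; the positional request-id list is an assumption.
reqs = kv_cache_manager.add_dummy_requests(
    [0],                       # request ids (assumed)
    is_gen=True,
    max_num_draft_tokens=0,
    use_mrope=False,
    max_beam_width=1,
    lora_request=warmup_lora)  # object carrying task_id / weights / config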
tensorrt_llm/executor/worker.py

Lines changed: 7 additions & 6 deletions
@@ -161,12 +161,13 @@ def _create_engine():

         if getattr(executor_config, "backend",
                    "") == "pytorch" and lora_config is not None:
-            from tensorrt_llm._torch.pyexecutor.resource_manager import \
-                ResourceManagerType
-            peft_cache_manager = self.engine.resource_manager.resource_managers.get(
-                ResourceManagerType.PEFT_CACHE_MANAGER)
-            self._lora_manager = LoraManager(
-                cpp_peft_cache_manager=peft_cache_manager.impl)
+            # from tensorrt_llm._torch.pyexecutor.resource_manager import \
+            #     ResourceManagerType
+            # peft_cache_manager = self.engine.resource_manager.resource_managers.get(
+            #     ResourceManagerType.PEFT_CACHE_MANAGER)
+            # self._lora_manager = LoraManager(
+            #     cpp_peft_cache_manager=peft_cache_manager.impl)
+            self._lora_manager = self.engine.get_lora_manager()
             lora_model_config = self.engine.model_engine.lora_model_config
             assert lora_model_config is not None
             self._lora_model_config = lora_model_config

tensorrt_llm/lora_manager.py

Lines changed: 7 additions & 1 deletion
@@ -8,7 +8,7 @@
 from dataclasses import dataclass, field
 from functools import lru_cache
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -241,6 +241,7 @@ class LoraConfig(DictConversion):
     trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)
     max_loras: int | None = None
     max_cpu_loras: int | None = None
+    lora_request: Optional[List[Any]] = None  # TODO smor fix

     def __post_init__(self):
         assert self.lora_ckpt_source in ["hf", "nemo"], (
@@ -747,6 +748,11 @@ def __init__(
         self._cpp_lora_weights: Dict[str, torch.Tensor] = {}  # on cpu
         self._cpp_lora_config: Dict[str, torch.Tensor] = {}  # on cpu
         self.lora_target_modules: List[str] = []
+        self._cpp_peft_cache_manager: Optional[tb_internal.batch_manager.PeftCacheManager] = None
+
+    def set_cpp_peft_cache_manager(
+            self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager
+    ):
         self._cpp_peft_cache_manager = cpp_peft_cache_manager

     def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:

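With the setter in place, the Python-side LoraManager can be constructed without a C++ PEFT cache and bound to one later, which is how model_engine.py uses it above. A minimal sketch, assuming `cpp_peft_cache_manager.impl` is the PeftCacheManager binding fetched from the resource manager:

# Two-step wiring introduced by this diff (construction takes no C++ cache).
manager = LoraManager()
manager.set_cpp_peft_cache_manager(cpp_peft_cache_manager.impl)
# Subsequent loads and cache queries can consult the bound C++ PEFT cache.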
tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 30 additions & 0 deletions
@@ -783,3 +783,33 @@ def test_gqa_nemo_lora(tmp_path):
                 f"got: {base_outputs[0].outputs[0].text}"
     finally:
         llm.shutdown()
+
+
+def test_lora_dir_with_graph():
+    lora_req = LoRARequest(
+        "task-0", 0, f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1")
+
+    lora_config = LoraConfig(
+        lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"],
+        max_lora_rank=8,
+        lora_request=[lora_req])
+
+    llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
+              lora_config=lora_config,
+              cuda_graph_config=None)
+    # cuda_graph_config=CudaGraphConfig(max_batch_size=1))
+
+    prompts = [
+        "美国的首都在哪里? \n答案:",
+    ]
+    references = [
+        "美国的首都是华盛顿。\n\n美国的",
+    ]
+    sampling_params = SamplingParams(max_tokens=20)
+    lora_request = [lora_req]
+
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
+
+    assert similar(outputs[0].outputs[0].text, references[0])
+    print(f"lora output: {outputs[0].outputs[0].text}")
+    print(f"ref output: {references[0]}")
