From 6e67c55f835dd02faa818aa9deae47c844d802f2 Mon Sep 17 00:00:00 2001
From: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Date: Thu, 7 Aug 2025 09:05:36 +0300
Subject: [PATCH] [TRTLLM-6683][feat] Support LoRA reload CPU cache evicted adapter (#6510)

Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
---
 .../tensorrt_llm/batch_manager/llmRequest.h   |   3 +
 cpp/tensorrt_llm/batch_manager/llmRequest.cpp |   6 +
 .../batch_manager/peftCacheManager.cpp        |  13 +-
 cpp/tensorrt_llm/executor/loraConfig.cpp      |  27 ++--
 .../nanobind/batch_manager/bindings.cpp       |   3 +-
 .../pybind/batch_manager/bindings.cpp         |   3 +-
 .../unit_tests/executor/loraConfigTest.cpp    |   9 +-
 tensorrt_llm/_torch/pyexecutor/_util.py       |   1 +
 tensorrt_llm/_torch/pyexecutor/llm_request.py |   2 +
 .../_torch/pyexecutor/resource_manager.py     |  39 +++++-
 tensorrt_llm/_utils.py                        |   7 +
 tensorrt_llm/executor/worker.py               |   6 +-
 .../unittest/_torch/test_resource_manager.py  |  17 ++-
 tests/unittest/llmapi/test_llm.py             |  52 ++++++--
 tests/unittest/llmapi/test_llm_pytorch.py     | 122 ++++++++----------
 tests/unittest/utils/util.py                  |  93 +------------
 16 files changed, 185 insertions(+), 218 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 0d087d96c0f..aedac8c2ac7 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -2347,6 +2347,9 @@ class LlmRequest : public GenericLlmRequest
     void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager);
 
     void moveLoraWeightsToGpu(runtime::BufferManager const& manager);
+
+    // Remove LoRA weights and LoRA config tensors
+    void removeLoraTensors();
 };
 
 } // namespace tensorrt_llm::batch_manager
diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
index a9a4aec5dfc..dcebc9c3ac6 100644
--- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
+++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager)
     mLoraWeights = gpuLoraWeights;
 }
 
+void LlmRequest::removeLoraTensors()
+{
+    mLoraWeights.reset();
+    mLoraConfig.reset();
+}
+
 } // namespace tensorrt_llm::batch_manager
diff --git a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
index f513f2a3a10..cc62bd3eb04 100644
--- a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
+++ b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp
@@ -591,10 +591,9 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr llmRe
     TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     if (llmRequest->getLoraTaskId().has_value())
     {
-        auto taskId = llmRequest->getLoraTaskId().value();
         try
         {
-            return mHostLoraCache->determineNumPages(taskId);
+            return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
         }
         catch (std::runtime_error& e)
         {
@@ -602,16 +601,6 @@
             {
                 return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
             }
-            if (!llmRequest->getLoraWeights().has_value())
-            {
-                auto const reqId = llmRequest->mRequestId;
-                std::string errMsg
-                    = "Request ID " + std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task "
-                    + std::to_string(taskId) + " that's not found in LoRA CPU cache."
- " Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization," - " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported."; - throw PeftTaskNotCachedException(errMsg); - } throw; } } diff --git a/cpp/tensorrt_llm/executor/loraConfig.cpp b/cpp/tensorrt_llm/executor/loraConfig.cpp index 058b1a86710..c8499f36d4d 100644 --- a/cpp/tensorrt_llm/executor/loraConfig.cpp +++ b/cpp/tensorrt_llm/executor/loraConfig.cpp @@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional weights, std::option , mWeights(std::move(weights)) , mConfig(std::move(config)) { - if (mWeights.has_value() || mConfig.has_value()) + if (mConfig.has_value()) { - TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(), - "Request for LoRA inference must have both lora weights and lora config"); - - SizeType32 constexpr expectedWeightsDims = 2; SizeType32 constexpr expectedConfigDims = 2; - - TLLM_CHECK_WITH_INFO( - mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); TLLM_CHECK_WITH_INFO( mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions"); - TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU - && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU && mConfig.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); + "Expected lora config to be in CPU memory"); TLLM_CHECK_WITH_INFO( mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32"); + } + if (mWeights.has_value()) + { + SizeType32 constexpr expectedWeightsDims = 2; + TLLM_CHECK_WITH_INFO( + mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config"); + + TLLM_CHECK_WITH_INFO( + mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); + + TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU + && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, + "Expected lora weights to be in CPU memory"); TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0], "Expected dim 0 of lora weights and lora config to have the same size"); diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index 56fdbf14e9b..2ac069616e0 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -375,7 +375,8 @@ void initBindings(nb::module_& m) .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, nb::arg("manager")) .def("finish_by_reason", &tb::LlmRequest::finishByReason, nb::arg("finish_reason")) .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime) - .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")); + .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, nb::arg("iter_counter")) + .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors); nb::class_(m, "SequenceSlotManager") .def(nb::init(), nb::arg("max_num_slots"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 
index f0d74f4f99e..04faa90c2ff 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -381,7 +381,8 @@ void initBindings(pybind11::module_& m)
         .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
         .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"))
         .def("set_first_scheduled_time", &tb::LlmRequest::setFirstScheduledTime)
-        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"));
+        .def("update_perf_metrics", &tb::LlmRequest::updatePerfMetrics, py::arg("iter_counter"))
+        .def("remove_lora_tensors", &tb::LlmRequest::removeLoraTensors);
 
     py::classh(m, "SequenceSlotManager")
         .def(py::init(), py::arg("max_num_slots"),
diff --git a/cpp/tests/unit_tests/executor/loraConfigTest.cpp b/cpp/tests/unit_tests/executor/loraConfigTest.cpp
index 2859739f6e4..6ce56cccbdf 100644
--- a/cpp/tests/unit_tests/executor/loraConfigTest.cpp
+++ b/cpp/tests/unit_tests/executor/loraConfigTest.cpp
@@ -53,13 +53,12 @@ TEST(LoraConfigTest, invalidInputs)
 
     // This should work
     auto loraConfig = LoraConfig(1, weights, config);
+    // Having config only without weights is allowed
+    loraConfig = LoraConfig(1, std::nullopt, config);
 
     {
-        // Only one specified
-        testInvalid(1, std::nullopt, config, "must have both");
-
-        // Only one specified
-        testInvalid(1, weights, std::nullopt, "must have both");
+        // Only weights specified without config - not allowed
+        testInvalid(1, weights, std::nullopt, "lora weights must also have lora config");
     }
 
     {
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index b7204afeb43..21fa9f91c1d 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -502,6 +502,7 @@ def create_py_executor_instance(
         )
         peft_cache_manager = PeftCacheManager(
             peft_cache_config=executor_config.peft_cache_config,
+            lora_config=lora_config,
            model_config=model_binding_config,
            world_config=world_config,
        )
diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 0fb1f06e964..8aa263bb039 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -285,6 +285,7 @@ def __init__(
 
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", None)
+        self.py_lora_path: str | None = kwargs.pop("py_lora_path", None)
         # Multimodal data
         self.py_multimodal_data = kwargs.pop("py_multimodal_data", None)
         if llm_request is not None:
@@ -490,6 +491,7 @@ def executor_request_to_llm_request(
         if executor_request.lora_config is not None else None,
         lora_config=executor_request.lora_config.config
         if executor_request.lora_config is not None else None,
+        py_lora_path=getattr(executor_request, "py_lora_path", None),
         mrope_rotary_cos_sin=mrope_rotary_cos_sin,
         mrope_position_deltas=mrope_position_deltas,
         lookahead_config=None,
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 9f44649b494..eb33f8aa5b9 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -10,9 +10,10 @@
 import tensorrt_llm
 import tensorrt_llm.bindings
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.lora_manager import LoraConfig, LoraManager, LoraModelConfig
 from tensorrt_llm.sampling_params import SamplingParams
 
-from ..._utils import binding_dtype_size, nvtx_range
+from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
 from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
@@ -1170,6 +1171,7 @@ class PeftCacheManager(BaseResourceManager):
 
     def __init__(self,
                  peft_cache_config: PeftCacheConfig,
+                 lora_config: LoraConfig,
                  model_config: ModelConfig,
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb
@@ -1200,8 +1202,36 @@ def __init__(self,
             model_config=model_config,
             world_config=world_config,
             buffer_manager=buffer_manager)
+        self._lora_config = lora_config
+        self._lora_model_config = LoraModelConfig(
+            lora_config.lora_target_modules,
+            lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
+            binding_to_str_dtype(model_config.data_type))
+        self._lora_manager = LoraManager()
 
     def add_request_peft(self, request: LlmRequest):
+        if request.lora_task_id is not None:
+            is_task_cached = self.impl.is_task_cached(request.lora_task_id)
+            if is_task_cached:
+                # PeftCacheManager::addRequestPeft in CPP doesn't allow having only one of [config tensor, weights
+                # tensor] without the other. Since there's no need for any of them when the LoRA adapter is already
+                # cached, we can safely remove both from the request.
+                request.remove_lora_tensors()
+            elif request.lora_weights is None and request.py_lora_path:
+                self._lora_manager.load_from_ckpt(
+                    [request.py_lora_path],
+                    model_config=self._lora_model_config,
+                    runtime_mapping=None,
+                    uids=[request.lora_task_id],
+                    ckpt_source=self._lora_config.lora_ckpt_source)
+                request.lora_weights = self._lora_manager.cpp_lora_weights[
+                    request.lora_task_id]
+
+        # PeftCacheManager CPP implementation expects an extra dim at index 0
+        if request.lora_weights is not None:
+            request.lora_weights = request.lora_weights.unsqueeze(0)
+        if request.lora_config is not None:
+            request.lora_config = request.lora_config.unsqueeze(0)
         self.impl.add_request_peft(request, True)
 
     def ensure_batch(self,
@@ -1221,12 +1251,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_batch = scheduled_batch.context_requests
         generation_batch = scheduled_batch.generation_requests
         for req in context_batch:
-            if req.lora_weights is not None and req.lora_config is not None:
-                req.lora_weights = req.lora_weights.reshape(
-                    [1] + list(req.lora_weights.shape))
-                req.lora_config = req.lora_config.reshape(
-                    [1] + list(req.lora_config.shape))
-            self.impl.add_request_peft(req, True)
+            self.add_request_peft(req)
 
         py_lora_task_layer_module_configs = self.impl.ensure_batch(
             context_batch, generation_batch, False)
diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py
index 75be2727918..d6cce437761 100644
--- a/tensorrt_llm/_utils.py
+++ b/tensorrt_llm/_utils.py
@@ -180,6 +180,7 @@ def str_dtype_to_torch(dtype):
     bool=DataType.BOOL,
     fp8=DataType.FP8,
 )
+_binding_to_str_dtype = {v: k for k, v in _str_to_binding_dtype_dict.items()}
 
 _binding_dtype_size = {
     DataType.INT64: 8,
@@ -194,6 +195,12 @@ def str_dtype_to_torch(dtype):
 }
 
 
+def binding_to_str_dtype(binding_dtype) -> str:
+    ret = _binding_to_str_dtype.get(binding_dtype)
+    assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'
+    return ret
+
+
 def binding_dtype_size(dtype: DataType):
     return _binding_dtype_size[dtype]
 
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index 33ed146c9c6..db8d84fcc89 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -372,6 +372,7 @@ def _load_prompt_adapter(self,
 
     def _enqueue_request(self, request: GenerationRequest) -> int:
         assert request.id is not None
+        py_lora_path = None
         if self._lora_manager is not None and request.lora_request is not None:
             adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
                 request.lora_request.adapter_id)
@@ -381,8 +382,8 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
                 task_id=request.lora_request.adapter_id,
                 weights=self._lora_manager.cpp_lora_weights[uid]
                 if not adapter_in_cache else None,
-                config=self._lora_manager.cpp_lora_config[uid]
-                if not adapter_in_cache else None)
+                config=self._lora_manager.cpp_lora_config[uid])
+            py_lora_path = request.lora_request.lora_path
         else:
             lora_config = None
 
@@ -497,6 +498,7 @@ def _deduce_max_tokens(request: GenerationRequest,
             kv_cache_retention_config=request.kv_cache_retention_config,
             context_phase_params=context_phase_params,
             type=request_type)
+        executor_request.py_lora_path = py_lora_path
 
         if self._is_pytorch_backend and request.multimodal_params is not None:
             if request.multimodal_params.multimodal_data is not None:
diff --git a/tests/unittest/_torch/test_resource_manager.py b/tests/unittest/_torch/test_resource_manager.py
index da1dae84ba1..21edd013da1 100644
--- a/tests/unittest/_torch/test_resource_manager.py
+++ b/tests/unittest/_torch/test_resource_manager.py
@@ -5,11 +5,11 @@
 import unittest
 
 import numpy as np
-import pytest
 import torch
 
 import tensorrt_llm
 import tensorrt_llm.bindings
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
                                                              PeftCacheConfig,
                                                              PeftCacheManager)
@@ -17,6 +17,7 @@
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
+from tensorrt_llm.lora_manager import LoraConfig
 
 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule
@@ -247,7 +248,7 @@ def _create_request(self,
             lora_config = torch.from_numpy(lora_config)
 
         input_tokens = [i + 1 for i in range(max_new_tokens)]
-        request = tensorrt_llm.bindings.internal.batch_manager.LlmRequest(
+        request = LlmRequest(
             request_id=request_id,
             max_new_tokens=max_new_tokens,
             input_tokens=input_tokens,
@@ -261,15 +262,13 @@ def _create_request(self,
         return request
 
     def get_lora_data(self):
-        """Create mock LoRA weights and config that match the C++ validation expectations.
+        """Create mock LoRA weights and config.
 
         Returns:
-            tuple: (weights tensor, config tensor) formatted correctly for the C++ implementation.
+            tuple: (weights tensor, config tensor).
""" lora_weights = np.load(self.TP1_WEIGHTS_PATH).astype(np.float16) - lora_weights = np.expand_dims(lora_weights, axis=0) lora_config = np.load(self.TP1_CONFIG_PATH) - lora_config = np.expand_dims(lora_config, axis=0) return lora_weights, lora_config def test_successful_mocked_peft_cache_manager_initialization(self): @@ -277,6 +276,7 @@ def test_successful_mocked_peft_cache_manager_initialization(self): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -290,6 +290,7 @@ def test_add_request_peft_empty_weights_config(self): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -307,6 +308,7 @@ def test_add_request_peft_empty_batch(self): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -322,6 +324,7 @@ def test_add_request_peft(self): peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) @@ -349,13 +352,13 @@ def test_add_request_peft(self): self.assertEqual(len(peft_table), self.num_lora_modules) - @pytest.mark.skip(reason="https://nvbugs/5324252") def test_put_get(self): """Test adding a request with properly configured LoRA weights and config.""" peft_cache_config = self.create_peft_cache_config() peft_cache_manager = PeftCacheManager( peft_cache_config=peft_cache_config, + lora_config=LoraConfig(), model_config=self.model_config, ) diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 2b7c606bf41..5e82d10b43c 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1459,20 +1459,7 @@ def llama_v2_13b_lora_from_dir_test_harness(**llm_kwargs): assert similar(output.outputs[0].text, ref) -@pytest.mark.parametrize( - "lora_adapter_count_per_call, max_loras, max_cpu_loras, repeat_calls, repeats_per_call", - [ - # Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single - # llm.generate call, that's repeated twice. - ([ - 2, - ], 1, 2, 2, 3), - # Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU - # cache size < LoRA CPU cache size - ([2, 2, 2], 1, 3, 1, 1), - ]) -@skip_gpu_memory_less_than_40gb -def test_llama_7b_multi_lora_evict_load_new_adapters( +def _check_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call: list[int], max_loras: int, max_cpu_loras: int, repeat_calls: int, repeats_per_call: int): # For LoRA checkpoints without finetuned embedding and lm_head, we can either: @@ -1493,6 +1480,43 @@ def test_llama_7b_multi_lora_evict_load_new_adapters( fast_build=True) +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(): + """Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single + llm.generate call, that's repeated twice. + """ # noqa: D205 + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[2], + max_loras=1, + max_cpu_loras=2, + repeat_calls=2, + repeats_per_call=3) + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(): + """Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU + cache size < LoRA CPU cache size. 
+ """ # noqa: D205 + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[2, 2, 2], + max_loras=1, + max_cpu_loras=3, + repeat_calls=1, + repeats_per_call=1) + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora_read_from_cache_after_insert(): + """Test that loading and then using the same adapters loaded in cache works.""" + _check_llama_7b_multi_lora_evict_load_new_adapters( + lora_adapter_count_per_call=[3], + max_loras=3, + max_cpu_loras=3, + repeat_calls=2, + repeats_per_call=1) + + def test_llama_7b_peft_cache_config_affects_peft_cache_size(): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 13708aae3c1..518772d6f60 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -20,9 +20,7 @@ run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler, tinyllama_logits_processor_test_harness) -from utils.util import (EnvVarsContextManager, force_ampere, - run_function_in_sub_process, similar, - skip_gpu_memory_less_than_40gb, +from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_gpu_memory_less_than_80gb, skip_gpu_memory_less_than_138gb) from utils.llm_data import llm_models_root @@ -313,20 +311,7 @@ def test_llama_7b_lora_default_modules() -> None: llm.shutdown() -@pytest.mark.parametrize( - "lora_adapter_count_per_call, max_loras, max_cpu_loras, repeat_calls, repeats_per_call", - [ - # Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single - # llm.generate call, that's repeated twice. - ([ - 2, - ], 1, 2, 2, 3), - # Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU - # cache size < LoRA CPU cache size - ([2, 2, 2], 1, 3, 1, 1), - ]) -@skip_gpu_memory_less_than_40gb -def test_llama_7b_multi_lora_evict_load_new_adapters( +def _check_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call: list[int], max_loras: int, max_cpu_loras: int, repeat_calls: int, repeats_per_call: int): # For LoRA checkpoints without finetuned embedding and lm_head, we can either: @@ -347,60 +332,66 @@ def test_llama_7b_multi_lora_evict_load_new_adapters( cuda_graph_config=None) -@pytest.mark.parametrize( - "lora_adapter_count_per_call, max_loras, max_cpu_loras, repeat_calls, repeats_per_call", - [ - # Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU - # cache over multiple llm.generate call repeated twice (two calls with the same requests): - # At the end of the 1st llm.generate call: - # The LoRA caches should contain adapters 1, 2 and shouldn't contain adapter 0 (it should have been evicted). - # So in the 2nd call, the worker should: - # - Send req0 with adapter 0 weights (because it was previously evicted) - # - Send the other two requests without their adapter weights as they're already in LoRA CPU cache - # Then, handling of req0 that has weights but not in the cache should evict one of the other two adapters from - # the cache, causing that evicted adapter's request to fail because its weights aren't with the request and - # aren't in LoRA cache. 
-        ([
-            3,
-        ], 2, 2, 2, 1),
-    ])
 @skip_gpu_memory_less_than_40gb
-def test_llama_7b_multi_lora_load_previously_cpu_cache_evicted_adapter_fails(
-        lora_adapter_count_per_call: list[int], max_loras: int,
-        max_cpu_loras: int, repeat_calls: int, repeats_per_call: int):
-    """Tests that trying to load a LoRA adapter after it was evicted from CPU cache fails with the expected
-    message, as this feature is currently not supported in favor of the performance improvement of not
-    sending the LoRA weights with every request after the first time.
-    NOTE: This test assumes the requests are handled in the order they're sent, if that's not true, then this test
-    may not get any error at all, which would cause it to fail.
+def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache():
+    """Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single
+    llm.generate call that's repeated twice.
     """  # noqa: D205
+    _check_llama_7b_multi_lora_evict_load_new_adapters(
+        lora_adapter_count_per_call=[2],
+        max_loras=1,
+        max_cpu_loras=2,
+        repeat_calls=2,
+        repeats_per_call=3)
 
-    def _check_contains_expected_message(stdout: str, stderr: str):
-        note_in_message = "Note that currently a request with LoRA task that was already loaded is sent" \
-                          " without its LoRA weights to save its serialization, copy and deserialization," \
-                          " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported."
-        return note_in_message in stderr
-
-    lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
-                             max_lora_rank=8,
-                             max_loras=max_loras,
-                             max_cpu_loras=max_cpu_loras)
-    with EnvVarsContextManager({"TLLM_WORKER_USE_SINGLE_PROCESS": "1"}):
-        child_stdout, child_stderr = run_function_in_sub_process(
-            target=check_llama_7b_multi_unique_lora_adapters_from_request,
-            args=(lora_adapter_count_per_call, repeat_calls, repeats_per_call,
-                  LLM),
-            kwargs={
-                "lora_config": lora_config,
-                # Disable CUDA graph
-                # TODO: remove this once we have a proper fix for CUDA graph in LoRA
-                "cuda_graph_config": None
-            },
-            stop_waiting_criteria=_check_contains_expected_message)
-
-    assert _check_contains_expected_message(child_stdout, child_stderr)
 
+@skip_gpu_memory_less_than_40gb
+def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache():
+    """Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU
+    cache size < LoRA CPU cache size.
+    """  # noqa: D205
+    _check_llama_7b_multi_lora_evict_load_new_adapters(
+        lora_adapter_count_per_call=[2, 2, 2],
+        max_loras=1,
+        max_cpu_loras=3,
+        repeat_calls=1,
+        repeats_per_call=1)
+
+
+@skip_gpu_memory_less_than_40gb
+def test_llama_7b_multi_lora_read_from_cache_after_insert():
+    """Test that loading and then using the same adapters loaded in cache works."""
+    _check_llama_7b_multi_lora_evict_load_new_adapters(
+        lora_adapter_count_per_call=[3],
+        max_loras=3,
+        max_cpu_loras=3,
+        repeat_calls=2,
+        repeats_per_call=1)
+
+
+@skip_gpu_memory_less_than_40gb
+def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_cache(
+):
+    """Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU
+    cache over multiple llm.generate calls repeated twice (two calls with the same requests):
+    At the end of the 1st llm.generate call:
+    The LoRA caches should contain adapters 1, 2 and shouldn't contain adapter 0 (it should have been evicted).
+    So in the 2nd call, the worker should:
+    - Send req0 with adapter 0 weights (because it was previously evicted)
+    - Send the other two requests without their adapter weights as they're already in LoRA CPU cache
+    Then, handling of req0 that has weights but not in the cache should evict one of the other two adapters from
+    the cache, causing that evicted adapter's request to load its weights again from the file system, as they
+    aren't with the request and aren't in the LoRA cache.
+    """  # noqa: D205
+    _check_llama_7b_multi_lora_evict_load_new_adapters(
+        lora_adapter_count_per_call=[3],
+        max_loras=2,
+        max_cpu_loras=2,
+        repeat_calls=2,
+        repeats_per_call=1)
+
+
+@skip_gpu_memory_less_than_40gb
 def test_llama_7b_peft_cache_config_affects_peft_cache_size():
     """Tests that LLM arg of peft_cache_config affects the peft cache sizes.
 
@@ -436,6 +427,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size():
         cuda_graph_config=None)
 
 
+@skip_gpu_memory_less_than_40gb
 def test_llama_7b_lora_config_overrides_peft_cache_config():
     """Tests that cache size args in lora_config LLM arg override the cache size
     parameters in peft_cache_config LLM arg.
diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py
index 7d5c90833a1..cbb483b6087 100644
--- a/tests/unittest/utils/util.py
+++ b/tests/unittest/utils/util.py
@@ -1,13 +1,9 @@
-import multiprocessing
 import os
-import sys
-import time
 import unittest
 from contextlib import contextmanager
 from difflib import SequenceMatcher
-from multiprocessing.connection import Connection
 from pathlib import Path
-from typing import Any, Callable, Generator, Mapping, Tuple
+from typing import Any, Generator
 
 import pynvml
 import pytest
@@ -425,90 +421,3 @@ def duplicate_list_to_length(list: list[Any], target_length: int) -> list[Any]:
     if remain != 0:
         duplicated_list += list[:remain]
     return duplicated_list
-
-
-def _target_wrapper(target: Callable, stdout_pipe: Connection,
-                    stderr_pipe: Connection, *args, **kwargs) -> None:
-
-    class PipeWriter:
-
-        def __init__(self, conn: Connection):
-            self.conn = conn
-
-        def write(self, s: str):
-            self.conn.send_bytes(s.encode("UTF8"))
-
-        def flush(self):
-            pass
-
-    sys.stdout = PipeWriter(stdout_pipe)
-    sys.stderr = PipeWriter(stderr_pipe)
-    target(*args, **kwargs)
-
-
-def run_function_in_sub_process(target: Callable,
-                                args: tuple,
-                                kwargs: Mapping[str, Any],
-                                stop_waiting_criteria: Callable,
-                                poll_interval_seconds: int = 5,
-                                timeout_seconds: int = 240) -> Tuple[str, str]:
-    multiprocessing.set_start_method("spawn", force=True)
-    parent_stdout_pipe, child_stdout_pipe = multiprocessing.Pipe()
-    parent_stderr_pipe, child_stderr_pipe = multiprocessing.Pipe()
-    child_process = multiprocessing.Process(
-        target=_target_wrapper,
-        args=[target, child_stdout_pipe, child_stderr_pipe] + list(args),
-        kwargs=kwargs)
-    child_process.start()
-    child_stdout_pipe.close()
-    child_stderr_pipe.close()
-
-    def _read_from_pipe(pipe: Connection):
-        out = ""
-        while pipe.poll(timeout=0.1):
-            try:
-                out += pipe.recv_bytes().decode("UTF8")
-            except Exception:
-                break
-        return out
-
-    child_stdout = ""
-    child_stderr = ""
-    try:
-        total_waiting_seconds = 0
-        while child_process.is_alive(
-        ) and total_waiting_seconds < timeout_seconds:
-            child_stdout += _read_from_pipe(parent_stdout_pipe)
-            child_stderr += _read_from_pipe(parent_stderr_pipe)
-            if stop_waiting_criteria(child_stdout, child_stderr):
-                break
-            time.sleep(poll_interval_seconds)
-            total_waiting_seconds += poll_interval_seconds
-    finally:
-        parent_stdout_pipe.close()
-        parent_stderr_pipe.close()
-        if child_process.is_alive():
-            child_process.terminate()
-
-    assert total_waiting_seconds < timeout_seconds, "Reached timeout while waiting for target"
-    return child_stdout, child_stderr
-
-
-class EnvVarsContextManager:
-
-    def __init__(self, new_env_vars: dict[str, str]):
-        self._env_vars = new_env_vars
-        self._original_value = None
-
-    def __enter__(self):
-        self._original_vars = {
-            var_name: os.environ[var_name]
-            for var_name in self._env_vars.keys() if var_name in os.environ
-        }
-        os.environ.update(self._env_vars)
-
-    def __exit__(self, type, value, traceback):
-        os.environ.update(self._original_vars)
-        for var_name in self._env_vars.keys():
-            if var_name not in self._original_vars:
-                os.environ.pop(var_name)
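
For reviewers, a minimal standalone sketch of the request-preparation flow that the patched add_request_peft in resource_manager.py implements (cached task: drop tensors; CPU-cache-evicted task with a known checkpoint path: reload weights; always add the leading dimension the C++ cache expects). This is not part of the patch: FakeRequest, FakeHostCache and load_adapter_from_path are hypothetical stand-ins for LlmRequest, the C++ PEFT cache binding and LoraManager.load_from_ckpt.

# Illustrative sketch only -- the Fake* types and load_adapter_from_path are
# hypothetical stand-ins, not TensorRT-LLM APIs.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeRequest:
    lora_task_id: Optional[int] = None
    py_lora_path: Optional[str] = None
    lora_weights: Optional[torch.Tensor] = None
    lora_config: Optional[torch.Tensor] = None


class FakeHostCache:
    """Stand-in for the C++ host PEFT cache's is_task_cached query."""

    def __init__(self):
        self.cached_task_ids = set()

    def is_task_cached(self, task_id: int) -> bool:
        return task_id in self.cached_task_ids


def load_adapter_from_path(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for reloading the adapter checkpoint from disk.
    return torch.zeros(4, 8), torch.zeros(4, 3, dtype=torch.int32)


def prepare_request_peft(cache: FakeHostCache, request: FakeRequest) -> None:
    """Mirror of the decision flow the patch adds to add_request_peft."""
    if request.lora_task_id is not None:
        if cache.is_task_cached(request.lora_task_id):
            # Adapter already resident: neither weights nor config are needed.
            request.lora_weights = None
            request.lora_config = None
        elif request.lora_weights is None and request.py_lora_path:
            # Adapter was evicted from the CPU cache and the request carries no
            # weights: reload it from its checkpoint path instead of failing.
            request.lora_weights, request.lora_config = load_adapter_from_path(
                request.py_lora_path)
    # The C++ cache expects an extra leading dimension on both tensors.
    if request.lora_weights is not None:
        request.lora_weights = request.lora_weights.unsqueeze(0)
    if request.lora_config is not None:
        request.lora_config = request.lora_config.unsqueeze(0)


if __name__ == "__main__":
    cache = FakeHostCache()
    req = FakeRequest(lora_task_id=0, py_lora_path="/tmp/adapter")
    prepare_request_peft(cache, req)
    # Evicted/unknown task with a path: tensors are reloaded and unsqueezed.
    print(req.lora_weights.shape, req.lora_config.shape)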