
Commit 33cc07d

qixiang-99 committed
test: Add test to protect the memory allocation behavior.

- Changed primary and secondary pool memory allocation to use instance variables instead of local variables for better clarity and maintainability.
- Updated logging to reflect the new instance variable usage.
- Added unit tests to validate memory allocation behavior in KVCacheManager.

Signed-off-by: qixiang-99 <[email protected]>
1 parent 6837e60 commit 33cc07d

3 files changed, +153 -16 lines changed

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 7 additions & 8 deletions
@@ -264,7 +264,6 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
         assert isinstance(
             kv_cache_config, KvCacheConfigCpp
         ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp"
-
         blocks_per_window = self.calculate_max_num_blocks_from_cpp(
             kv_cache_config=kv_cache_config,
             model_config=model_config,
@@ -828,12 +827,12 @@ def calculate_max_num_blocks_from_cpp(
         free_mem, total_mem = torch.cuda.mem_get_info()
         # Respect max_gpu_total_bytes if provided
         free_gpu_memory_fraction = kv_cache_config.free_gpu_memory_fraction if kv_cache_config.free_gpu_memory_fraction else 0.9
-        primary_pool_memory_bytes = kv_cache_config.max_gpu_total_bytes if kv_cache_config.max_gpu_total_bytes > 0 else int(
+        self._primary_pool_memory_bytes = kv_cache_config.max_gpu_total_bytes if kv_cache_config.max_gpu_total_bytes > 0 else int(
             free_mem * free_gpu_memory_fraction)
-        secondary_pool_memory_bytes = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0
+        self._secondary_pool_memory_bytes = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0
         logger.debug(
-            f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \n"
-            f"secondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB"
+            f"primary_pool_memory_bytes is set to {self._primary_pool_memory_bytes/1024**3}GB, \n"
+            f"secondary_pool_memory_bytes is set to {self._secondary_pool_memory_bytes/1024**3}GB"
         )

         # Adjust the window sizes to fit the memory if even a single sequence
@@ -843,7 +842,7 @@ def calculate_max_num_blocks_from_cpp(
             max_attention_window_vec=self.max_attention_window_vec,
             model_config=model_config,
             kv_cache_config=kv_cache_config,
-            pool_memory_bytes=primary_pool_memory_bytes,
+            pool_memory_bytes=self._primary_pool_memory_bytes,
             kv_factor=self.kv_factor,
             dtype=self.dtype,
             is_cross_attention=is_cross_attention,
@@ -858,8 +857,8 @@ def calculate_max_num_blocks_from_cpp(
             model_config=model_config,
             world_config=world_config_cpp,
             window_size_to_layers=window_size_to_layers,
-            allotted_primary_mem_bytes=primary_pool_memory_bytes,
-            allotted_secondary_mem_bytes=secondary_pool_memory_bytes,
+            allotted_primary_mem_bytes=self._primary_pool_memory_bytes,
+            allotted_secondary_mem_bytes=self._secondary_pool_memory_bytes,
             extra_cost_memory=extra_cost_memory,
             kv_factor=self.kv_factor,
         )
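
A note on the change above: the only logic involved is the precedence between an explicit byte budget and the fraction-based fallback, now stored on `self`. A minimal standalone sketch of that selection logic (the free-function form and its name are ours, not TensorRT-LLM API):

    def select_pool_sizes(max_gpu_total_bytes, free_gpu_memory_fraction,
                          host_cache_size, free_mem):
        # An explicit GPU byte budget takes precedence over the fraction fallback.
        fraction = free_gpu_memory_fraction if free_gpu_memory_fraction else 0.9
        primary = (max_gpu_total_bytes if max_gpu_total_bytes > 0 else
                   int(free_mem * fraction))
        # The host (secondary) pool defaults to 0 bytes when host_cache_size is unset.
        secondary = host_cache_size if host_cache_size else 0
        return primary, secondary

Promoting the two results to instance attributes (`_primary_pool_memory_bytes`, `_secondary_pool_memory_bytes`) is what lets the new unit test below assert on them after the manager is constructed.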

tensorrt_llm/llmapi/llm_args.py

Lines changed: 0 additions & 8 deletions
@@ -1057,14 +1057,6 @@ def validate_max_attention_window(cls, v: Optional[List[int]]):
                 raise ValueError(
                     "kv_cache_config.max_attention_window values must be positive"
                 )
-
-        # Must not be a redundant repetition of a shorter pattern
-        n = len(v)
-        for k in range(1, n):
-            if n % k == 0 and v == v[:k] * (n // k):
-                raise ValueError(
-                    f"kv_cache_config.max_attention_window should contain only the minimal repeating pattern; use {v[:k]} instead of {v}"
-                )
         return v
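
For reference, the deleted validation can be reproduced as a standalone predicate (the helper name here is ours); it rejected any window list that is a whole-number tiling of a shorter prefix:

    def is_redundant_repetition(v):
        # True when v tiles a shorter prefix, e.g. [64, 128, 64, 128] == [64, 128] * 2.
        n = len(v)
        return any(n % k == 0 and v == v[:k] * (n // k) for k in range(1, n))

    assert is_redundant_repetition([64, 128, 64, 128])  # repeats [64, 128]
    assert not is_redundant_repetition([64, 128])       # already minimal

With the check removed, `validate_max_attention_window` accepts repeated patterns such as [64, 128, 64, 128] again.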

tests/unittest/_torch/executor/test_resource_manager.py

Lines changed: 146 additions & 0 deletions
@@ -3,6 +3,8 @@
 import subprocess
 import sys
 import unittest
+from typing import NamedTuple, Tuple
+from unittest.mock import patch

 import numpy as np
 import torch
@@ -13,11 +15,13 @@
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
                                                              PeftCacheConfig,
                                                              PeftCacheManager)
+from tensorrt_llm.bindings import LayerType
 from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
 from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.mapping import Mapping

 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule
@@ -544,6 +548,148 @@ def test_adjust_window_sizes_for_vswa(self):
                 f"Actual: {adjusted_max_attention_window_vec}\n"
                 f"Expected: {expected_max_attention_window_vec}")

+    @staticmethod
+    def _create_model_config_for_kv_cache_manager() -> ModelConfigCpp:
+        """
+        Create a simple model config for KVCacheManager test.
+        """
+
+        model_config_params = {
+            "vocab_size": 0,
+            "num_layers": 4,
+            "num_attention_layers": 4,
+            "num_rnn_layers": 0,
+            "num_heads": 64,
+            "hidden_size": 64,
+            "data_type": DataType.HALF
+        }
+        num_kv_heads = 8
+
+        model_config = ModelConfigCpp(**model_config_params)
+        model_config.layer_types = [LayerType.ATTENTION
+                                    ] * model_config.num_attention_layers()
+        model_config.set_num_kv_heads(num_kv_heads)
+
+        return model_config
+
+    @staticmethod
+    def _create_kv_cache_config_for_kv_cache_manager(
+            params: dict) -> tllm.KvCacheConfig:
+        """
+        Create a KV cache config for KVCacheManager test.
+        """
+        return tllm.KvCacheConfig(**params)
+
+    def test_calculate_max_num_blocks_from_cpp(self):
+        # Construct a minimal mapping (single-rank, no TP/PP)
+        mapping = Mapping(world_size=1, tp_size=1, pp_size=1)
+
+        # Construct model config
+        model_config = TestResourceManager._create_model_config_for_kv_cache_manager(
+        )
+
+        # Construct KV cache config
+        free_gpu_memory_fraction = 0.1
+        max_attention_window = [64, 128]
+        max_gpu_total_bytes = 32 * 1024 * 1024  # 32MB
+        enable_block_reuse = False
+        host_cache_size = 32 * 1024 * 1024  # 32MB
+
+        # mock values for torch.cuda.mem_get_info to return a fixed value
+        fixed_free_mem = 128 * 1024 * 1024  # 128MB
+        fixed_total_mem = 256 * 1024 * 1024  # 256MB
+
+        class MemTestCase(NamedTuple):
+            case_name: str
+            kv_cache_config_params: dict
+            expected_memory_bytes: Tuple[
+                int,
+                int]  # (primary_pool_memory_bytes, secondary_pool_memory_bytes)
+
+        test_cases = [
+            # Case 1:
+            # max_gpu_total_bytes is set, even if free_gpu_memory_fraction is set, we will use max_gpu_total_bytes
+            # host_cache_size is set, we will use host_cache_size
+            MemTestCase(
+                case_name="max_gpu_total_bytes is set, host_cache_size is set",
+                kv_cache_config_params={
+                    "max_attention_window": max_attention_window,
+                    "free_gpu_memory_fraction": free_gpu_memory_fraction,
+                    "max_gpu_total_bytes": max_gpu_total_bytes,
+                    "enable_block_reuse": enable_block_reuse,
+                    "host_cache_size": host_cache_size,
+                },
+                expected_memory_bytes=(max_gpu_total_bytes, host_cache_size),
+            ),
+            # Case 2:
+            # max_gpu_total_bytes is not set, we will use free_gpu_memory_fraction
+            # host_cache_size is not set, we will use 0
+            MemTestCase(
+                case_name=
+                "max_gpu_total_bytes is not set, host_cache_size is not set",
+                kv_cache_config_params={
+                    "max_attention_window": max_attention_window,
+                    "free_gpu_memory_fraction": free_gpu_memory_fraction,
+                    "enable_block_reuse": enable_block_reuse,
+                },
+                # NOTE: use np.float32 to avoid float precision issue between python(double in most cases) and cpp binding(float)
+                expected_memory_bytes=(int(
+                    fixed_free_mem * np.float32(free_gpu_memory_fraction)), 0),
+            ),
+        ]
+
+        tokens_per_block = 32
+        model_config.tokens_per_block = tokens_per_block
+        max_seq_len = max(max_attention_window)
+        max_batch_size = 1
+        max_beam_width = 1
+
+        for case_name, kv_cache_config_params, expected_memory_bytes in test_cases:
+            with self.subTest(case=case_name):
+                kv_cache_config = TestResourceManager._create_kv_cache_config_for_kv_cache_manager(
+                    kv_cache_config_params)
+                with patch('torch.cuda.mem_get_info',
+                           return_value=(fixed_free_mem, fixed_total_mem)):
+                    # Create a real KVCacheManager, it will run calculate_max_num_blocks_from_cpp in __init__
+                    manager = KVCacheManager(
+                        kv_cache_config=kv_cache_config,
+                        kv_cache_type=tensorrt_llm.bindings.internal.
+                        batch_manager.CacheType.SELF,
+                        num_layers=model_config.num_attention_layers(),
+                        num_kv_heads=model_config.num_kv_heads(
+                            0
+                        ),  # NOTE: assume same number of kv heads for all layers
+                        head_dim=model_config.head_size,
+                        tokens_per_block=tokens_per_block,
+                        max_seq_len=max_seq_len,
+                        max_batch_size=max_batch_size,
+                        mapping=mapping,
+                        dtype=model_config.data_type,
+                        model_config=model_config,
+                        max_beam_width=max_beam_width,
+                    )
+                try:
+                    expected_primary, expected_secondary = expected_memory_bytes
+                    self.assertEqual(
+                        manager._primary_pool_memory_bytes,
+                        expected_primary,
+                        f"Test case '{case_name}' failed.\n"
+                        f"Expected primary pool memory bytes: {expected_primary}\n"
+                        f"Actual primary pool memory bytes: {manager._primary_pool_memory_bytes}"
+                    )
+                    self.assertEqual(
+                        manager._secondary_pool_memory_bytes,
+                        expected_secondary,
+                        f"Test case '{case_name}' failed.\n"
+                        f"Expected secondary pool memory bytes: {expected_secondary}\n"
+                        f"Actual secondary pool memory bytes: {manager._secondary_pool_memory_bytes}"
+                    )
+                except Exception as e:
+                    self.fail(f"Test case '{case_name}' failed: {e}")
+                finally:
+                    manager.shutdown()
+

 if __name__ == "__main__":
     unittest.main()
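
One subtlety in Case 2 above is the np.float32 cast in the expected value: the C++ binding computes the fraction in single precision, while a bare Python float is a double, and the two can differ by one after int() truncation. A small illustration with the mocked numbers (the printed values assume IEEE-754 rounding and are worth re-checking locally):

    import numpy as np

    free_mem = 128 * 1024 * 1024  # the mocked free memory, 128 MiB
    fraction = 0.1

    print(int(free_mem * fraction))              # 13421772 (double arithmetic)
    print(int(free_mem * np.float32(fraction)))  # 13421773 (float32, as the binding computes)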

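To run only the new test locally, an invocation along these lines should work from the repository root (pytest collects unittest-style tests; adjust to your environment):

    python -m pytest tests/unittest/_torch/executor/test_resource_manager.py -k test_calculate_max_num_blocks_from_cpp -q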