Skip to content

Commit 5a0d5de

Browse files
committed
Enhance KVCacheManager to maintain adjusted max attention window sizes. Introduced an `adjusted_dict` that maps each original window size to its clamped token count, and updated the logic to apply these mappings to the max attention window vector. Updated unit tests to validate the new behavior and ensure expected outputs for various memory configurations.
Signed-off-by: qixiang-99 <[email protected]>
1 parent d05e9e6 commit 5a0d5de

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int:
670670

671671
accum_max_tokens = 0
672672
prev_window_size = 0
673+
adjusted_dict = {}
673674
adjusted_max_attention_window_vec = max_attention_window_vec.copy()
674675

675676
for window_size in sorted(window_size_to_layers):
@@ -712,12 +713,14 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int:
712713

713714
if accum_max_tokens not in adjusted_window_size_to_layers:
714715
adjusted_window_size_to_layers[accum_max_tokens] = layers.copy()
715-
# also update adjusted_max_attention_window_vec
716-
for i, v in enumerate(adjusted_max_attention_window_vec):
717-
if v == window_size:
718-
adjusted_max_attention_window_vec[i] = accum_max_tokens
719716
else:
720717
adjusted_window_size_to_layers[accum_max_tokens].extend(layers)
718+
adjusted_dict[window_size] = accum_max_tokens
719+
# also update adjusted_max_attention_window_vec
720+
adjusted_max_attention_window_vec = [
721+
adjusted_dict.get(v, v)
722+
for v in adjusted_max_attention_window_vec
723+
]
721724

722725
remaining_layers -= set(layers)
723726
prev_window_size = window_size

tests/unittest/_torch/test_resource_manager.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ def test_adjust_window_sizes_for_vswa(self):
433433
200: [4, 5, 6],
434434
7000: [7, 8],
435435
}
436+
max_attention_window_vec = [100] * 4 + [200] * 3 + [7000] * 2
436437

437438
model_config = self.MockModelConfig()
438439
model_config.num_attention_heads = 2
@@ -460,6 +461,7 @@ def test_adjust_window_sizes_for_vswa(self):
460461
100: [0, 1, 2, 3],
461462
130: [4, 5, 6, 7, 8],
462463
},
464+
[100] * 4 + [130] * 5,
463465
None,
464466
"limited_memory_clamped_windows"),
465467
(
@@ -471,6 +473,7 @@ def test_adjust_window_sizes_for_vswa(self):
471473
200: [4, 5, 6],
472474
1017: [7, 8],
473475
},
476+
[100] * 4 + [200] * 3 + [1017] * 2,
474477
None,
475478
"less_limited_memory_clamped_windows"),
476479
(
@@ -482,6 +485,7 @@ def test_adjust_window_sizes_for_vswa(self):
482485
200: [4, 5, 6],
483486
7000: [7, 8],
484487
},
488+
[100] * 4 + [200] * 3 + [7000] * 2,
485489
None,
486490
"sufficient_memory_no_clamping"),
487491
(
@@ -490,6 +494,7 @@ def test_adjust_window_sizes_for_vswa(self):
490494
{
491495
51: [0, 1, 2, 3, 4, 5, 6, 7, 8],
492496
},
497+
[51] * 9,
493498
None,
494499
"very_limited_memory_all_clamped"),
495500
(
@@ -501,15 +506,17 @@ def test_adjust_window_sizes_for_vswa(self):
501506
100: [0, 1, 2, 3],
502507
134: [4, 5, 6, 7, 8],
503508
},
509+
[100] * 4 + [134] * 5,
504510
134,
505511
"less_limited_memory_but_clamped_by_max_tokens"),
506512
]
507513

508-
for memory_bytes, expected_window_sizes, max_tokens, description in test_cases:
514+
for memory_bytes, expected_window_sizes, expected_max_attention_window_vec, max_tokens, description in test_cases:
509515
with self.subTest(case=description, memory_bytes=memory_bytes):
510516
kv_cache_config = tllm.KvCacheConfig(max_tokens=max_tokens)
511-
adjusted = KVCacheManager.adjust_window_sizes_for_vswa(
517+
adjusted, adjusted_max_attention_window_vec = KVCacheManager.adjust_window_sizes_for_vswa(
512518
window_size_to_layers=window_size_to_layers,
519+
max_attention_window_vec=max_attention_window_vec,
513520
model_config=model_config,
514521
kv_cache_config=kv_cache_config,
515522
pool_memory_bytes=memory_bytes,
@@ -524,6 +531,13 @@ def test_adjust_window_sizes_for_vswa(self):
524531
f"Memory bytes: {memory_bytes}\n"
525532
f"Actual: {adjusted}\n"
526533
f"Expected: {expected_window_sizes}")
534+
self.assertEqual(
535+
adjusted_max_attention_window_vec,
536+
expected_max_attention_window_vec,
537+
f"Test case '{description}' failed.\n"
538+
f"Memory bytes: {memory_bytes}\n"
539+
f"Actual: {adjusted_max_attention_window_vec}\n"
540+
f"Expected: {expected_max_attention_window_vec}")
527541

528542

529543
if __name__ == "__main__":

0 commit comments

Comments
 (0)