@@ -3,6 +3,8 @@
 import subprocess
 import sys
 import unittest
+from typing import NamedTuple, Tuple
+from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -13,11 +15,13 @@
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
                                                              PeftCacheConfig,
                                                              PeftCacheManager)
+from tensorrt_llm.bindings import LayerType
 from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
 from tensorrt_llm.lora_helper import LoraConfig
+from tensorrt_llm.mapping import Mapping
 
 DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule
@@ -544,6 +548,148 @@ def test_adjust_window_sizes_for_vswa(self):
                          f"Actual: {adjusted_max_attention_window_vec}\n"
                          f"Expected: {expected_max_attention_window_vec}")
 
+    @staticmethod
+    def _create_model_config_for_kv_cache_manager() -> ModelConfigCpp:
+        """
+        Create a simple model config for the KVCacheManager tests.
+        """
+
+        model_config_params = {
+            "vocab_size": 0,
+            "num_layers": 4,
+            "num_attention_layers": 4,
+            "num_rnn_layers": 0,
+            "num_heads": 64,
+            "hidden_size": 64,
+            "data_type": DataType.HALF
+        }
+        num_kv_heads = 8
+
+        model_config = ModelConfigCpp(**model_config_params)
+        model_config.layer_types = [LayerType.ATTENTION
+                                    ] * model_config.num_attention_layers()
+        model_config.set_num_kv_heads(num_kv_heads)
+
+        return model_config
+
+    @staticmethod
+    def _create_kv_cache_config_for_kv_cache_manager(
+            params: dict) -> tllm.KvCacheConfig:
+        """
+        Create a KV cache config for the KVCacheManager tests.
+        """
+        return tllm.KvCacheConfig(**params)
+
+    def test_calculate_max_num_blocks_from_cpp(self):
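+        """
+        Check the pool sizing done by calculate_max_num_blocks_from_cpp during
+        KVCacheManager construction: an explicit max_gpu_total_bytes /
+        host_cache_size should be used as-is, otherwise the primary pool falls
+        back to free_gpu_memory_fraction of the free device memory and the
+        secondary (host) pool defaults to 0.
+        """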
+        # Construct a minimal mapping (single-rank, no TP/PP)
+        mapping = Mapping(world_size=1, tp_size=1, pp_size=1)
+
+        # Construct model config
+        model_config = TestResourceManager._create_model_config_for_kv_cache_manager(
+        )
+
+        # Construct KV cache config
+        free_gpu_memory_fraction = 0.1
+        max_attention_window = [64, 128]
+        max_gpu_total_bytes = 32 * 1024 * 1024  # 32MB
+        enable_block_reuse = False
+        host_cache_size = 32 * 1024 * 1024  # 32MB
+
+        # Mock torch.cuda.mem_get_info so the test sees fixed free/total memory
+        fixed_free_mem = 128 * 1024 * 1024  # 128MB
+        fixed_total_mem = 256 * 1024 * 1024  # 256MB
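+        # Patching torch.cuda.mem_get_info with these fixed values keeps the
+        # fraction-based expectation in case 2 deterministic, independent of
+        # the GPU the test actually runs on.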
+
+        class MemTestCase(NamedTuple):
+            case_name: str
+            kv_cache_config_params: dict
+            expected_memory_bytes: Tuple[
+                int,
+                int]  # (primary_pool_memory_bytes, secondary_pool_memory_bytes)
+
+        test_cases = [
+            # Case 1:
+            # max_gpu_total_bytes is set, so it takes precedence over
+            # free_gpu_memory_fraction; host_cache_size is set, so it is used
+            # for the secondary (host) pool.
+            MemTestCase(
+                case_name="max_gpu_total_bytes is set, host_cache_size is set",
+                kv_cache_config_params={
+                    "max_attention_window": max_attention_window,
+                    "free_gpu_memory_fraction": free_gpu_memory_fraction,
+                    "max_gpu_total_bytes": max_gpu_total_bytes,
+                    "enable_block_reuse": enable_block_reuse,
+                    "host_cache_size": host_cache_size,
+                },
+                expected_memory_bytes=(max_gpu_total_bytes, host_cache_size),
+            ),
+
+            # Case 2:
+            # max_gpu_total_bytes is not set, so free_gpu_memory_fraction is
+            # used; host_cache_size is not set, so the secondary pool is 0.
+            MemTestCase(
+                case_name=
+                "max_gpu_total_bytes is not set, host_cache_size is not set",
+                kv_cache_config_params={
+                    "max_attention_window": max_attention_window,
+                    "free_gpu_memory_fraction": free_gpu_memory_fraction,
+                    "enable_block_reuse": enable_block_reuse,
+                },
+                # NOTE: use np.float32 to avoid a precision mismatch between
+                # Python floats (double) and the C++ binding (float); the
+                # expected value is worked out right after this list.
+                expected_memory_bytes=(int(
+                    fixed_free_mem * np.float32(free_gpu_memory_fraction)), 0),
+            ),
+        ]
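+
+        # With the mocked values, case 2 expects
+        # int(134217728 * np.float32(0.1)) == 13421773 bytes (~10% of the
+        # 128 MiB of mocked free GPU memory) for the primary pool and 0 bytes
+        # for the secondary pool.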
+
+        tokens_per_block = 32
+        model_config.tokens_per_block = tokens_per_block
+        max_seq_len = max(max_attention_window)
+        max_batch_size = 1
+        max_beam_width = 1
+
+        for case_name, kv_cache_config_params, expected_memory_bytes in test_cases:
+            with self.subTest(case=case_name):
+                kv_cache_config = TestResourceManager._create_kv_cache_config_for_kv_cache_manager(
+                    kv_cache_config_params)
+                with patch('torch.cuda.mem_get_info',
+                           return_value=(fixed_free_mem, fixed_total_mem)):
+                    # Create a real KVCacheManager; it runs
+                    # calculate_max_num_blocks_from_cpp in its __init__
+                    manager = KVCacheManager(
+                        kv_cache_config=kv_cache_config,
+                        kv_cache_type=tensorrt_llm.bindings.internal.
+                        batch_manager.CacheType.SELF,
+                        num_layers=model_config.num_attention_layers(),
+                        num_kv_heads=model_config.num_kv_heads(
+                            0
+                        ),  # NOTE: assumes the same number of KV heads for all layers
+                        head_dim=model_config.head_size,
+                        tokens_per_block=tokens_per_block,
+                        max_seq_len=max_seq_len,
+                        max_batch_size=max_batch_size,
+                        mapping=mapping,
+                        dtype=model_config.data_type,
+                        model_config=model_config,
+                        max_beam_width=max_beam_width,
+                    )
+                try:
+                    expected_primary, expected_secondary = expected_memory_bytes
+                    self.assertEqual(
+                        manager._primary_pool_memory_bytes,
+                        expected_primary,
+                        f"Test case '{case_name}' failed.\n"
+                        f"Expected primary pool memory bytes: {expected_primary}\n"
+                        f"Actual primary pool memory bytes: {manager._primary_pool_memory_bytes}"
+                    )
+                    self.assertEqual(
+                        manager._secondary_pool_memory_bytes,
+                        expected_secondary,
+                        f"Test case '{case_name}' failed.\n"
+                        f"Expected secondary pool memory bytes: {expected_secondary}\n"
+                        f"Actual secondary pool memory bytes: {manager._secondary_pool_memory_bytes}"
+                    )
+                except Exception as e:
+                    self.fail(f"Test case '{case_name}' failed: {e}")
+                finally:
+                    manager.shutdown()
+
 
 if __name__ == "__main__":
     unittest.main()