|
| 1 | +import time |
| 2 | + |
1 | 3 | import pytest |
2 | 4 | import torch |
3 | 5 |
|
@@ -123,3 +125,46 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype, |
123 | 125 | assert torch.equal( |
124 | 126 | kv_cache_manager_gen.get_buffers(0), |
125 | 127 | kv_cache_manager_ctx.get_buffers(0)), "different kv-cache values" |
| 128 | + |
| 129 | + |
@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
def test_cancel_request_in_transmission_ctx(attention_type):
    """Verify that a context-only request whose KV cache is mid-transmission
    can be cancelled on the context side."""
    # Single-process mapping: this one rank plays the context role.
    single_rank_mapping = Mapping(world_size=1, rank=0)
    manager_ctx = create_kv_cache_manager(single_rank_mapping, DataType.HALF)

    transceiver_cfg = trtllm.CacheTransceiverConfig(
        backend=trtllm.CacheTransceiverBackendType.DEFAULT,
        max_tokens_in_buffer=512)
    transceiver_ctx = create_kv_cache_transceiver(single_rank_mapping,
                                                  manager_ctx, attention_type,
                                                  transceiver_cfg)

    fill_kv_cache_buffer(manager_ctx)

    # Build a context-only request carrying a 256-token prompt.
    sampling_cfg = tensorrt_llm.bindings.SamplingConfig(
        SamplingParams()._get_sampling_config())
    request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=sampling_cfg,
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)

    manager_ctx.impl.add_sequence(request.py_request_id, request.prompt_len, 1,
                                  request)

    # Kick off the asynchronous KV-cache send for the context request.
    transceiver_ctx.respond_and_send_async(request)

    # NOTE(review): no generation peer is created, so presumably the transfer
    # stays pending; the sleep gives it time to reach the in-transmission
    # state before the cancel is issued — confirm against transceiver docs.
    time.sleep(10)

    assert transceiver_ctx.cancel_request(request)
0 commit comments