Skip to content

Commit b538dab

Browse files
Shunkang
authored and committed
Add Test
Signed-off-by: Shunkang <[email protected]>
1 parent a9f8c5a commit b538dab

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/unittest/others/test_kv_cache_transceiver.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import time
2+
13
import pytest
24
import torch
35

@@ -123,3 +125,46 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
123125
assert torch.equal(
124126
kv_cache_manager_gen.get_buffers(0),
125127
kv_cache_manager_ctx.get_buffers(0)), "different kv-cache values"
128+
129+
130+
@pytest.mark.timeout(120)
@pytest.mark.parametrize("attention_type",
                         [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                         ids=["mha", "mla"])
def test_cancel_request_in_transmission_ctx(attention_type):
    """Cancel a context-only request while its KV cache is in transmission.

    Sets up a single-rank KV-cache manager and cache transceiver, starts an
    async send of a context request's KV cache, then cancels the request and
    asserts the transceiver reports the cancellation as successful.

    Parametrized over MHA (DEFAULT) and MLA attention types.
    """
    # Init kv_cache manager and cache transceiver
    mapping = Mapping(world_size=1, rank=0)
    ctx_kv_cache_dtype = DataType.HALF
    kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)

    # 512-token buffer with the DEFAULT backend; buffer sizing presumably
    # keeps the 256-token request's transfer in flight long enough to cancel
    # — TODO confirm.
    cache_transceiver_config = trtllm.CacheTransceiverConfig(
        backend=trtllm.CacheTransceiverBackendType.DEFAULT,
        max_tokens_in_buffer=512)

    kv_cache_transceiver_ctx = create_kv_cache_transceiver(
        mapping, kv_cache_manager_ctx, attention_type, cache_transceiver_config)

    # Populate the cache with data so there is something to transmit.
    fill_kv_cache_buffer(kv_cache_manager_ctx)

    # init ctx request: a 256-token, context-only (disaggregated prefill)
    # request that produces a single new token.
    sampling_params = SamplingParams()
    ctx_request = LlmRequest(
        request_id=0,
        max_new_tokens=1,
        input_tokens=list(range(256)),
        sampling_config=tensorrt_llm.bindings.SamplingConfig(
            sampling_params._get_sampling_config()),
        is_streaming=False,
        llm_request_type=LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY)

    # Register the request's sequence (beam width 1) with the cache manager
    # so the transceiver has blocks to send.
    kv_cache_manager_ctx.impl.add_sequence(ctx_request.py_request_id,
                                           ctx_request.prompt_len, 1,
                                           ctx_request)
    # send ctx request
    kv_cache_transceiver_ctx.respond_and_send_async(ctx_request)

    # NOTE(review): fixed 10 s wait before cancelling; presumably this leaves
    # the request parked awaiting a (never-arriving) receiver so cancel_request
    # has something to cancel — confirm, and consider polling transceiver
    # state instead of sleeping to cut test wall time.
    time.sleep(10)

    # cancel ctx request
    is_cancelled = kv_cache_transceiver_ctx.cancel_request(ctx_request)
    assert is_cancelled

0 commit comments

Comments
 (0)