Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions tests/unittest/llmapi/test_llm_kv_cache_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,79 +160,80 @@ def check_events(llm,
scheduling_params=scheduling_params)
time.sleep(1)
events = llm.get_kv_cache_events(5)

# Created or stored event
total_stored_blocks = 0
if attention_dp_rank is None:
event = events.pop(0) # created event
assert event["event_id"] == 0
assert event["data"]["type"] == "created"
while events:
event = events.pop(0)
if event:
assert event["event_id"] == 1
assert event["data"]["type"] == "stored"
assert len(event["data"]["blocks"]) == 5
assert event["event_id"] > 0
total_stored_blocks += len(event["data"]["blocks"])
else:
while events:
event = events.pop(0)
assert "attention_dp_rank" in event
if event and event["attention_dp_rank"] == attention_dp_rank:
assert event["event_id"] in [0, 1]
assert event["data"]["type"] in ["created", "stored"]
if event["data"]["type"] == "created":
assert event["event_id"] == 0
if event["data"]["type"] == "stored":
assert event["event_id"] == 1
assert len(event["data"]["blocks"]) == 5
assert event["event_id"] > 0
total_stored_blocks += len(event["data"]["blocks"])

assert total_stored_blocks == 5 # Should have 5 blocks in total

_ = llm.generate(requests[1],
sampling_params=sampling_params,
scheduling_params=scheduling_params)
time.sleep(1)
events2 = llm.get_kv_cache_events(5)

total_stored_blocks = 0
has_removed_event = False
while events2:
event = events2.pop(0)
if event and (attention_dp_rank is None
or event.get("attention_dp_rank") == attention_dp_rank):
if event["event_id"] == 2:
# 2 removed events needed
# should be a removed event to make space for context block
assert event["data"]["type"] == "removed"
assert event["data"]["block_hashes"]
elif event["event_id"] == 3:
assert event["data"]["type"] == "removed"
if event["data"]["type"] == "removed":
has_removed_event = True
assert event["data"]["block_hashes"]
# stored event for 2nd request
elif event["event_id"] == 4:
assert event["data"]["type"] == "stored"
assert len(event["data"]["blocks"]) == 5
# stored events
elif event["data"]["type"] == "stored":
total_stored_blocks += len(event["data"]["blocks"])

assert total_stored_blocks == 5 # Should have 5 blocks in total
assert has_removed_event

_ = llm.generate(requests[2],
sampling_params=sampling_params,
scheduling_params=scheduling_params)
time.sleep(1)
events3 = llm.get_kv_cache_events(5)

total_stored_blocks = 0
has_removed_event = False
while events3:
event = events3.pop(0)
if event and (attention_dp_rank is None
or event.get("attention_dp_rank") == attention_dp_rank):
if event["event_id"] == 5:
assert event["data"]["type"] == "removed"
assert event["data"]["block_hashes"]
elif event["event_id"] == 6:
assert event["data"]["type"] == "removed"

if event["data"]["type"] == "removed":
has_removed_event = True
assert event["data"]["block_hashes"]
elif event["event_id"] == 7:
assert event["data"]["type"] == "stored"
assert len(event["data"]["blocks"]) == 5
elif event["data"]["type"] == "stored":
total_stored_blocks += len(event["data"]["blocks"])

assert total_stored_blocks == 5 # Should have 5 blocks in total
assert has_removed_event

# no more events after request is finished
assert not llm.get_kv_cache_events(5)


@pytest.mark.skip(reason="https://nvbugs/5445001")
def test_llm_kv_events_api():
llm = create_llm()
sampling_params = SamplingParams(max_tokens=6,
Expand All @@ -247,7 +248,6 @@ def test_llm_kv_events_api():
check_events(llm, requests, sampling_params)


@pytest.mark.skip(reason="https://nvbugs/5451407")
@skip_single_gpu
@pytest.mark.threadleak(enabled=False)
def test_llm_api_attention_dp_kv_events():
Expand Down
Loading