
Commit 7faae02

A whole bunch of unit tests
1 parent 63fa1ba commit 7faae02

File tree: 7 files changed, +159 / -61 lines


cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h

Lines changed: 0 additions & 2 deletions
@@ -26,8 +26,6 @@
 using SizeType32 = tensorrt_llm::runtime::SizeType32;
 using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType;
 
-using namespace tensorrt_llm::batch_manager;
-
 namespace tensorrt_llm::batch_manager::kv_connector
 {

cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline
 public:
     using KvCacheConnectorManager::KvCacheConnectorManager;
 
-    SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) override
+    SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
     {
         PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens",
             getNumNewMatchedTokens, request, numComputedTokens);

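The trampoline change above is what lets get_num_new_matched_tokens be supplied from Python: PYBIND11_OVERRIDE_PURE_NAME forwards the C++ virtual call to a Python method named "get_num_new_matched_tokens". A minimal sketch of a Python-side override; the import path is an assumption for illustration only, and the real binding module may live elsewhere:

# Hypothetical import path, shown only to illustrate what the trampoline enables.
from tensorrt_llm.bindings.internal.batch_manager import KvCacheConnectorManager


class NoopConnectorManager(KvCacheConnectorManager):
    """Toy Python subclass of the bound C++ KvCacheConnectorManager."""

    def get_num_new_matched_tokens(self, request, num_computed_tokens):
        # The C++ side routes its pure-virtual call here via PYBIND11_OVERRIDE_PURE_NAME.
        # Illustrative behavior only: claim no externally matched tokens.
        return 0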
cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 3 additions & 3 deletions
@@ -98,7 +98,7 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
 
     void addSequence(tb::LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
         tensorrt_llm::common::OptionalRef<tb::LlmRequest> llmRequest = std::nullopt,
-        tensorrt_llm::common::OptionalRef<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager
+        tensorrt_llm::common::OptionalRef<tb::kv_connector::KvCacheConnectorManager> kvCacheConnectorManager
         = std::nullopt) override
     {
         PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, addSequence, requestId, inputLength, beamWidth,
@@ -238,10 +238,10 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
         PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, flushIterationEvents);
     }
 
-    kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override
+    [[nodiscard]] tb::kv_connector::KvCacheConnectorPoolsData getKvCacheConnectorPoolsData() const override
     {
         PYBIND11_OVERLOAD_PURE(
-            kv_connector::KvCacheConnectorPoolsData, tbk::BaseKVCacheManager, getKvCacheConnectorPoolsData);
+            tb::kv_connector::KvCacheConnectorPoolsData, tbk::BaseKVCacheManager, getKvCacheConnectorPoolsData);
     }
 };

tensorrt_llm/_torch/pyexecutor/connector.py

Lines changed: 1 addition & 3 deletions
@@ -275,7 +275,7 @@ def get_num_new_matched_tokens(self, request: LlmRequest,
 
         # TODO(jthomson04): This part is a bit ugly.
         # When the connector indicates that a request will be loaded asynchronously, we need to suspend it's execution.
-        # This is problematic, since at this point when this function is called, the request has already been scheduled!
+        # This is problematic, since at the point when this function is called, the request has already been scheduled!
         # Because of this, we need to remove it from our list of scheduled requests (see `take_scheduled_requests_pending_load`).
         if load_kv_async:
             self.new_async_requests.loading[request.request_id] = request
@@ -308,8 +308,6 @@ def take_scheduled_requests_pending_load(
         # Update the list of scheduled requests.
         scheduled_requests.context_requests = allowed_context_requests
 
-        return scheduled_requests
-
     def build_connector_meta(self) -> object:
         metadata = self._run_on_leader(
             lambda: self.scheduler.build_connector_meta(self._scheduler_output))

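With the return statement removed, take_scheduled_requests_pending_load now communicates purely by mutating the ScheduledRequests object it receives. A minimal sketch of that in-place filtering pattern, written against stand-in names; only context_requests and new_async_requests.loading appear in the hunks above, everything else here is an assumption rather than the actual implementation:

def take_scheduled_requests_pending_load_sketch(manager, scheduled_requests):
    # Requests still waiting on an async KV load (populated in
    # get_num_new_matched_tokens, per the first hunk above).
    pending_ids = set(manager.new_async_requests.loading)

    # Keep only the context requests that are not pending a load.
    allowed_context_requests = [
        req for req in scheduled_requests.context_requests
        if req.request_id not in pending_ids
    ]

    # Mutate the caller's object in place; nothing is returned.
    scheduled_requests.context_requests = allowed_context_requests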
tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 1 addition & 1 deletion
@@ -924,7 +924,7 @@ def _executor_loop(self):
                 # We have to run this after we've run the KV cache manager (via the resource manager).
                 # This takes requests that are pending an async load, and removes them from the scheduled context batch.
                 if self.kv_connector_manager:
-                    scheduled_batch = self.kv_connector_manager.take_scheduled_requests_pending_load(
+                    self.kv_connector_manager.take_scheduled_requests_pending_load(
                         scheduled_batch)
 
                 if scheduled_batch.batch_size > 0 or (
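The call-site edit above follows directly from the removed return: keeping the old assignment would rebind scheduled_batch to None. A toy illustration of that pitfall, using stand-in classes rather than the real executor types:

class _Batch:
    def __init__(self, context_requests):
        self.context_requests = context_requests


def _take_pending_load(batch):
    # In-place filter that, like the updated connector method, returns None.
    batch.context_requests = [r for r in batch.context_requests if r != "pending"]


good = _Batch(["ready", "pending"])
_take_pending_load(good)        # rely on the in-place mutation
assert good.context_requests == ["ready"]

bad = _Batch(["ready", "pending"])
bad = _take_pending_load(bad)   # the old pattern would now bind None
assert bad is None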
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
import pickle
import sys
from unittest.mock import MagicMock

import cloudpickle
import mpi4py
import pytest

from tensorrt_llm import mpi_rank
from tensorrt_llm._torch.pyexecutor.connector import KvCacheConnectorManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests

cloudpickle.register_pickle_by_value(sys.modules[__name__])
mpi4py.MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)


def run_across_mpi(executor, fun, num_ranks):
    return list(executor.starmap(fun, [() for i in range(num_ranks)]))


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_get_finished_allgather(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
            scheduler.request_finished.return_value = True
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()
        req.request_id = 42

        manager.request_finished(req)

        # To start, make both workers return nothing.
        worker.get_finished.return_value = ([], [])

        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([42], [])

        worker.get_finished.reset_mock()

        # Now, only return the request id on one worker.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([42], [])
        else:
            worker.get_finished.return_value = ([], [])

        # It should still return nothing, since rank 1 is still saving.
        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([], [])

        # Now, also return it on worker 1.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([], [])
        else:
            worker.get_finished.return_value = ([42], [])

        assert manager.get_finished() == [req]

    run_across_mpi(mpi_pool_executor, test, 2)


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_num_matched_tokens(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
            scheduler.get_num_new_matched_tokens.return_value = (16, True)
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()
        req.request_id = 42

        assert manager.get_num_new_matched_tokens(req, 32) == 16
        assert req.is_kv_cache_connector_async_onboard

        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req, 32)

    run_across_mpi(mpi_pool_executor, test, 2)


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_take_scheduled_requests(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        scheduled_requests = ScheduledRequests()

        req0 = MagicMock()
        req0.request_id = 0

        req1 = MagicMock()
        req1.request_id = 1

        if mpi_rank() == 0:
            scheduler.get_num_new_matched_tokens.return_value = (16, True)

        assert manager.get_num_new_matched_tokens(req0, 0) == 16
        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0, 0)

            scheduler.get_num_new_matched_tokens.reset_mock()
            scheduler.get_num_new_matched_tokens.return_value = (32, False)

        assert manager.get_num_new_matched_tokens(req1, 0) == 32
        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1, 0)

        scheduled_requests.context_requests = [req0, req1]

        manager.take_scheduled_requests_pending_load(scheduled_requests)

        assert scheduled_requests.context_requests == [req1]

    run_across_mpi(mpi_pool_executor, test, 2)

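The new tests ship the same locally defined closure to every rank of an MPI pool, which only works because the cloudpickle/mpi4py setup at the top of the file makes such closures picklable. A stripped-down sketch of that pattern on its own, assuming mpi4py.futures is available and the script is launched under MPI (the test suite obtains its pool from the mpi_pool_executor fixture instead):

import pickle
import sys

import cloudpickle
import mpi4py
from mpi4py.futures import MPIPoolExecutor

# Serialize objects from this module by value so locally defined functions
# and lambdas survive the trip to the worker ranks.
cloudpickle.register_pickle_by_value(sys.modules[__name__])
mpi4py.MPI.pickle.__init__(cloudpickle.dumps, cloudpickle.loads,
                           pickle.HIGHEST_PROTOCOL)


def run_across_mpi(executor, fun, num_ranks):
    # starmap with empty argument tuples runs `fun` once per rank.
    return list(executor.starmap(fun, [() for _ in range(num_ranks)]))


if __name__ == "__main__":
    with MPIPoolExecutor(max_workers=2) as executor:
        print(run_across_mpi(executor, lambda: "hello from a worker rank", 2))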
tests/unittest/bindings/test_connector_bindings.py

Lines changed: 0 additions & 51 deletions
This file was deleted.
