components/frontend/src/dynamo/frontend/main.py
+7 lines changed: 7 additions & 0 deletions
@@ -112,6 +112,12 @@ def parse_args():
         help=" KV Router. Disable KV events.",
     )
     parser.set_defaults(use_kv_events=True)
+    parser.add_argument(
+        "--router-replica-sync",
+        action="store_true",
+        default=False,
+        help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.",
+    )
docs/architecture/kv_cache_routing.md
+57 −32 lines changed: 57 additions & 32 deletions
@@ -17,12 +17,13 @@ For performance testing, compare a typical workload with `--router-mode random|r
 The KV-aware routing arguments:

-- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks).
+- `--kv-overlap-score-weight`: Sets the amount of weighting on overlaps with prefix caches, which directly contributes to the prefill cost. A large weight is expected to yield a better TTFT (at the expense of worse ITL). When set to 0, prefix caches are not considered at all (falling back to pure load balancing behavior on the active blocks). Defaults to 1.

-- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 recovers the deterministic behavior where the min logit is picked.
+- `--router-temperature`: Sets the temperature when randomly selecting workers to route to via softmax sampling on the router cost logits. Setting it to 0 (the default) recovers the deterministic behavior where the min logit is picked.

-- `--use-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true, then we use the `KvIndexer` to listen to the block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the kv cache of historical prompts exists for fixed time durations (hard-coded to 120s), is used to predict the kv cache hit ratio in each engine. Set false if your backend engine does not emit KV events.
+- `--use-kv-events`/`--no-kv-events`: Sets whether to listen to KV events for maintaining the global view of cached blocks. If true (the default), we use the `KvIndexer` to listen to block creation and deletion events. If false, `ApproxKvIndexer`, which assumes the KV cache of historical prompts exists for a fixed time duration (hard-coded to 120s), is used to predict the KV cache hit ratio on each engine. Set this to false if your backend engine does not emit KV events.
+
+- `--router-replica-sync`: Enables state synchronization between multiple router replicas via NATS. Disabled by default; enable it by passing the flag. When enabled, router replicas share their view of KV cache distribution and active sequences, allowing all routers to make optimal routing decisions even when requests are distributed across multiple router instances. This improves fault tolerance and routing accuracy in multi-router deployments.
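To illustrate the `--router-temperature` behavior described above, here is a small sketch (editor's example with hypothetical cost values, not Dynamo's implementation) of softmax sampling over per-worker cost logits, where a temperature of 0 reduces to picking the minimum-cost worker:

```python
import math
import random

def select_worker(costs: dict[str, float], temperature: float, rng=random) -> str:
    """Pick a worker from {worker_id: cost}; lower cost is better."""
    if temperature == 0:
        # Deterministic: the worker with the minimum cost logit wins.
        return min(costs, key=costs.get)
    # Softmax over negated costs so that lower cost => higher probability.
    logits = {w: -c / temperature for w, c in costs.items()}
    max_logit = max(logits.values())
    weights = {w: math.exp(l - max_logit) for w, l in logits.items()}
    total = sum(weights.values())
    r = rng.random() * total
    for worker, weight in weights.items():
        r -= weight
        if r <= 0:
            return worker
    return worker  # fallback for floating-point rounding

# Hypothetical costs for three workers.
costs = {"worker-1": 0.15, "worker-2": 0.0, "worker-3": 0.05}
print(select_worker(costs, temperature=0))    # always "worker-2"
print(select_worker(costs, temperature=0.5))  # usually "worker-2", sometimes others
```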
 ## Architecture
@@ -45,6 +46,22 @@ We can then use the default routing methods exposed by the client class to send
 KV Cache routing uses direct routing with a special worker selection algorithm.
+## Serving Two Router Replicas
+
+For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use a different HTTP port for each instance.
+
+To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend.
+
+When `--router-replica-sync` is enabled, the router replicas communicate with each other via NATS to maintain consistent state across instances. This allows both routers to have a complete view of the KV cache distribution and make optimal routing decisions, even when requests are distributed across multiple router instances.
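As a rough sketch of a two-replica launch (the `python -m dynamo.frontend` entrypoint and the `--http-port`/`--router-mode` flags are assumptions to check against your version's `--help`; only `--router-replica-sync` comes from this change):

```python
import subprocess

# Two frontend + router replicas on different HTTP ports, with replica sync
# enabled so both routers share KV cache and active-sequence state via NATS.
# Flag names other than --router-replica-sync are assumptions for illustration.
common = ["python", "-m", "dynamo.frontend", "--router-mode", "kv", "--router-replica-sync"]

replica_a = subprocess.Popen(common + ["--http-port", "8000"])
replica_b = subprocess.Popen(common + ["--http-port", "8001"])

replica_a.wait()
replica_b.wait()
```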
 ## Understanding KV Cache

 The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching).
@@ -88,30 +105,46 @@ Further details can be found for: [TRT-LLM](https://developer.nvidia.com/blog/in
 Load balancing in LLM serving becomes complex when enabling KV Cache reuse. While KV Cache reuse can save significant computation, if the routing strategy is not aware of the unique KV states of each worker we can:

-- miss opportunities for KV Cache reuse if routing to the “wrong” node
+- miss opportunities for KV Cache reuse if routing to the "wrong" node
 - get into an imbalanced state where a few workers are processing many requests, lowering throughput of entire system

-The best way to solve these issues is for the router to have a global view of KV Cache and load. With this view, the router can use a cost function to score the workers and make decisions to maximize cache hits while keeping the system balanced and throughput high.
+The router uses a cost function that considers both the prefill cost (influenced by cached blocks) and the decode load to make optimal routing decisions:
+
+### Cost Calculation
+
+1. **Prefill blocks**: The number of tokens that need to be processed during prefill is predicted based on the request's input tokens and the cached blocks available on each worker. This is divided by the block size to get the effective "prefill blocks". This prediction is updated when the first output token is produced, signaling prefill completion.

-In the above image, our cost function is (KV match - Load) so we select Worker 2 even though Worker 3 would offer the best KV match.
-- Worker 1 = (0.15 - 0.30) = -0.15
-- **Worker 2 = (0.50 - 0.50) = 0**
-- Worker 3 = (0.75 - 0.80) = -0.05
+2. **Decode blocks**: The number of blocks needed during the decode phase is predicted based on the request's input tokens and the current active sequences on each worker. This is updated when the request is freed (blocks are dereferenced or freed).
+
+- The `overlap_score_weight` parameter controls the importance of cache hits vs. load balancing
+- A higher weight prioritizes cache reuse (better TTFT) while a lower weight prioritizes load distribution (better ITL)
+
+### Worker Selection
+
+The router selects the worker with the lowest cost. When `router_temperature` is set to a non-zero value, the router uses softmax sampling on the normalized cost logits to introduce randomness in the selection, which can help with load distribution.
+
+Example calculation with `overlap_score_weight = 1.0`:
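For instance, assuming a cost of the form `overlap_score_weight * prefill_blocks + decode_blocks` (an editor's simplification; the exact formula may differ), a worked sketch with hypothetical worker loads:

```python
# Hypothetical snapshot of three workers for one incoming request.
# prefill_blocks: blocks still to prefill after cache hits; decode_blocks: active decode load.
workers = {
    "worker-1": {"prefill_blocks": 12, "decode_blocks": 4},
    "worker-2": {"prefill_blocks": 6, "decode_blocks": 8},
    "worker-3": {"prefill_blocks": 3, "decode_blocks": 14},
}

overlap_score_weight = 1.0

costs = {
    name: overlap_score_weight * w["prefill_blocks"] + w["decode_blocks"]
    for name, w in workers.items()
}
best = min(costs, key=costs.get)

print(costs)  # {'worker-1': 16.0, 'worker-2': 14.0, 'worker-3': 17.0}
print(best)   # 'worker-2': lowest combined prefill + decode cost
```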
-In Dynamo, we want to support KV Cache Routing and load balancing for many backends that have different implementations of KV Cache and record different metrics. To that end, we built a KVPublisher that can be plugged into any framework to publish KV Events and a WorkerMetricsPublisher that can publish Metric Events.
+In Dynamo, we support KV Cache Routing for many backends that have different implementations of KV Cache. To enable this, we built a KVPublisher that can be plugged into any framework to publish KV Events.

-On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree and a KvMetricsAggregator which aggregates metric events by worker.
+On the receiving side we have a KVIndexer which accepts events from the KVPublisher and puts them into a global prefix tree for tracking cached blocks across all workers.

 ```text
 +----------------+ +-----------------+
 ```
@@ -121,13 +154,8 @@ On the receiving side we have a KVIndexer which accepts events from the KVPublis
@@ -144,18 +172,15 @@ The KVIndexer builds and maintains a global view of cached blocks in a prefix tr
 The KVIndexer has a method `find_matches_for_request`, which takes in tokens and returns a dictionary with keys of worker id and values of the number of matched KV Blocks.
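A minimal sketch of the idea behind such a global prefix tree (the class, block hashing, and event handling here are illustrative assumptions, not Dynamo's actual `KVIndexer` API):

```python
from collections import defaultdict

class PrefixTreeIndexer:
    """Toy global index: maps block-hash prefixes to the workers caching them."""

    def __init__(self, block_size: int = 16):
        self.block_size = block_size
        # For simplicity, key each cached prefix by a tuple of block hashes.
        self.prefix_to_workers: dict[tuple, set[str]] = defaultdict(set)

    def _block_hashes(self, tokens: list[int]) -> list[int]:
        usable = len(tokens) - len(tokens) % self.block_size
        blocks = [tuple(tokens[i:i + self.block_size]) for i in range(0, usable, self.block_size)]
        return [hash(b) for b in blocks]

    def apply_store_event(self, worker_id: str, tokens: list[int]) -> None:
        """Record that worker_id now caches every block prefix of tokens."""
        hashes = self._block_hashes(tokens)
        for i in range(1, len(hashes) + 1):
            self.prefix_to_workers[tuple(hashes[:i])].add(worker_id)

    def find_matches_for_request(self, tokens: list[int]) -> dict[str, int]:
        """Return {worker_id: number of matched KV blocks} for the request tokens."""
        hashes = self._block_hashes(tokens)
        matches: dict[str, int] = defaultdict(int)
        for i in range(1, len(hashes) + 1):
            for worker in self.prefix_to_workers.get(tuple(hashes[:i]), ()):
                matches[worker] = i
        return dict(matches)

indexer = PrefixTreeIndexer(block_size=4)
indexer.apply_store_event("worker-1", list(range(8)))     # caches 2 blocks
indexer.apply_store_event("worker-2", list(range(4)))     # caches 1 block
print(indexer.find_matches_for_request(list(range(12))))  # {'worker-1': 2, 'worker-2': 1}
```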
-### WorkerMetricsPublisher
-We added a KvMetrics Publisher which sends the following metrics to the KvMetricsAggregator:
-- num_requests_waiting
-- gpu_cache_usage_perc
-- gpu_prefix_cache_hit_rate
-- request_active_slots
-- request_total_slots
-- kv_active_blocks
-- kv_total_blocks
+### Inter-Router Communication
+
+In multi-router deployments, each router only observes a subset of requests. To maintain a consistent global view of active sequences and KV cache states, routers broadcast their local actions to other replicas through three synchronization events:
+
+1. **AddRequest**: Published when assigning a request to a worker, containing the request ID, worker ID, token sequence blocks, and overlap score. This updates other routers' tracking of which blocks are in use.
+
+2. **MarkPrefillCompleted**: Published when a request transitions from prefill to decode phase, signaling that prefill tokens should no longer count toward the worker's active prefill load.

-Currently, the WorkerMetricsPublisher exists as a Python binding.
+3. **Free**: Published when a request completes and its resources are released, allowing other routers to update their block reference counts.

-### KvMetricsAggregator
-The KvMetricsAggregator receives these metrics and aggregates them. It has a method `get_metrics` which returns an object of `AggregatedMetrics`.
+Each event includes a unique router ID to prevent processing of self-generated events. This asynchronous communication ensures all routers maintain synchronized KV cache state for optimal routing decisions despite handling different request streams.
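A hedged sketch of how a replica might publish and filter these events (editor's illustration; the event schema and NATS transport are assumptions not shown in this diff):

```python
import json
import uuid
from dataclasses import dataclass, asdict

ROUTER_ID = str(uuid.uuid4())  # unique per router replica

@dataclass
class AddRequest:            # request assigned to a worker
    router_id: str
    request_id: str
    worker_id: str
    block_hashes: list[int]
    overlap_score: float

@dataclass
class MarkPrefillCompleted:  # request moved from prefill to decode
    router_id: str
    request_id: str

@dataclass
class Free:                  # request finished; blocks dereferenced
    router_id: str
    request_id: str

def publish(event) -> str:
    """Serialize an event for the replica-sync channel (transport not shown)."""
    return json.dumps({"type": type(event).__name__, **asdict(event)})

def handle(raw: str, local_state: dict) -> None:
    """Apply a remote router's event, skipping events we published ourselves."""
    event = json.loads(raw)
    if event["router_id"] == ROUTER_ID:
        return  # ignore self-generated events
    if event["type"] == "AddRequest":
        local_state[event["request_id"]] = event["worker_id"]
    elif event["type"] == "Free":
        local_state.pop(event["request_id"], None)
    # MarkPrefillCompleted handling omitted for brevity.

state: dict[str, str] = {}
msg = publish(AddRequest(ROUTER_ID, "req-1", "worker-2", [101, 202], 0.5))
handle(msg, state)  # no-op: this replica published it
print(state)        # {}
```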