Skip to content

Commit f3d784f

Browse files
authored
feat: query instance_id based on routing strategy (#1787)
1 parent 13560ab commit f3d784f

File tree

2 files changed

+62
-50
lines changed

2 files changed

+62
-50
lines changed

lib/llm/src/kv_router.rs

Lines changed: 61 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -313,69 +313,81 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
313313
InstanceSource::Dynamic(_) => {
314314
// Extract context ID for request tracking
315315
let context_id = request.context().id().to_string();
316-
317316
let (instance_id, overlap_amount) = self
318317
.chooser
319318
.find_best_match(&context_id, &request.token_ids)
320319
.await?;
320+
let query_instance_id = request.has_annotation("query_instance_id");
321+
// Extract context information before moving the request
322+
let stream_context = request.context().clone();
321323
// Update the request with the estimated prefix hit blocks
322324
let (mut backend_input, context) = request.into_parts();
323325
let isl = backend_input.token_ids.len();
324326
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
325327
let updated_request = context.map(|_| backend_input);
328+
// if request has the annotation "query_instance_id", for example
329+
// curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
330+
// request will not be routed to worker immediately
331+
if query_instance_id {
332+
let instance_id_str = instance_id.to_string();
333+
let response =
334+
Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
335+
let stream = stream::iter(vec![response]);
336+
Ok(ResponseStream::new(Box::pin(stream), stream_context))
337+
} else {
338+
// Get the response stream from the worker
339+
let mut response_stream =
340+
self.inner.direct(updated_request, instance_id).await?;
341+
342+
// Wrap the stream to track tokens
343+
let stream_context = response_stream.context();
344+
let chooser = self.chooser.clone();
345+
let request_id = context_id.clone();
346+
let block_size = chooser.block_size() as usize;
347+
348+
let wrapped_stream = Box::pin(async_stream::stream! {
349+
let mut accumulated_tokens = Vec::new();
350+
let mut total_output_length = 0usize;
351+
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
352+
let mut first_push_done = false;
353+
354+
while let Some(item) = response_stream.next().await {
355+
// Track tokens if they exist in the response
356+
let Some(ref output) = item.data else {
357+
yield item;
358+
continue;
359+
};
360+
if output.token_ids.is_empty() {
361+
yield item;
362+
continue;
363+
}
326364

327-
// Get the response stream from the worker
328-
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
329-
330-
// Wrap the stream to track tokens
331-
let stream_context = response_stream.context();
332-
let chooser = self.chooser.clone();
333-
let request_id = context_id.clone();
334-
let block_size = chooser.block_size() as usize;
335-
336-
let wrapped_stream = Box::pin(async_stream::stream! {
337-
let mut accumulated_tokens = Vec::new();
338-
let mut total_output_length = 0usize;
339-
let mut last_block_index = (isl.saturating_sub(1)) / block_size;
340-
let mut first_push_done = false;
365+
// Add tokens to accumulator
366+
accumulated_tokens.extend_from_slice(&output.token_ids);
367+
total_output_length += output.token_ids.len();
368+
369+
// Always push for the first generated token (to mark prefill done)
370+
// or when we've moved to a new block
371+
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
372+
let should_push = (!first_push_done && total_output_length >= 1) ||
373+
(first_push_done && current_block_index > last_block_index);
374+
375+
if should_push {
376+
chooser.push(&request_id, &accumulated_tokens).await;
377+
accumulated_tokens.clear();
378+
last_block_index = current_block_index;
379+
if !first_push_done {
380+
first_push_done = true;
381+
}
382+
}
341383

342-
while let Some(item) = response_stream.next().await {
343-
// Track tokens if they exist in the response
344-
let Some(ref output) = item.data else {
345-
yield item;
346-
continue;
347-
};
348-
if output.token_ids.is_empty() {
349384
yield item;
350-
continue;
351385
}
352386

353-
// Add tokens to accumulator
354-
accumulated_tokens.extend_from_slice(&output.token_ids);
355-
total_output_length += output.token_ids.len();
356-
357-
// Always push for the first generated token (to mark prefill done)
358-
// or when we've moved to a new block
359-
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
360-
let should_push = (!first_push_done && total_output_length >= 1) ||
361-
(first_push_done && current_block_index > last_block_index);
362-
363-
if should_push {
364-
chooser.push(&request_id, &accumulated_tokens).await;
365-
accumulated_tokens.clear();
366-
last_block_index = current_block_index;
367-
if !first_push_done {
368-
first_push_done = true;
369-
}
370-
}
371-
372-
yield item;
373-
}
374-
375-
chooser.free(&request_id).await;
376-
});
377-
378-
Ok(ResponseStream::new(wrapped_stream, stream_context))
387+
chooser.free(&request_id).await;
388+
});
389+
Ok(ResponseStream::new(wrapped_stream, stream_context))
390+
}
379391
}
380392
}
381393
}

lib/llm/src/preprocessor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,8 +397,8 @@ impl OpenAIPreprocessor {
397397
// Only set event if not already set to avoid overriding existing events (like errors)
398398
if response.event.is_none() {
399399
response.event = metrics_annotated.event;
400+
response.comment = metrics_annotated.comment;
400401
}
401-
response.comment = metrics_annotated.comment;
402402
}
403403

404404
tracing::trace!(

0 commit comments

Comments
 (0)