diff --git a/rust/src/server/src/routes/inference/generate.rs b/rust/src/server/src/routes/inference/generate.rs index ff7ea3c6302c..b256b7721f9c 100644 --- a/rust/src/server/src/routes/inference/generate.rs +++ b/rust/src/server/src/routes/inference/generate.rs @@ -3,23 +3,33 @@ mod types; mod validate; use std::collections::HashMap; +use std::convert::Infallible; +use std::result::Result; use std::sync::Arc; +use asynk_strim_attr::{TryYielder, try_stream}; use axum::Json; use axum::extract::State; use axum::http::HeaderMap; +use axum::response::sse::{Event, Sse}; use axum::response::{IntoResponse, Response}; +use futures::{Stream, StreamExt as _, pin_mut}; use thiserror_ext::AsReport as _; -use tracing::info; +use tracing::{error, info, trace}; use tracing_futures::Instrument as _; use vllm_engine_core_client::protocol::logprobs::{Logprobs, PositionLogprobs}; -use vllm_llm::{CollectedGenerateOutput, GenerateOutputStreamExt as _}; +use vllm_llm::{ + CollectedGenerateOutput, FinishReason, GenerateOutput, GenerateOutputStreamExt as _, +}; use self::convert::prepare_generate_request; -use self::types::{GenerateLogprob, GenerateRequest, GenerateResponse, GenerateResponseChoice}; -use crate::error::{ApiError, server_error}; +use self::types::{ + GenerateLogprob, GenerateRequest, GenerateResponse, GenerateResponseChoice, + GenerateResponseStreamChoice, GenerateStreamResponse, +}; +use crate::error::{ApiError, bail_server_error, server_error}; use crate::routes::openai::utils::logprobs::clamp_logprob; -use crate::routes::openai::utils::types::{ChatLogProbs, ChatLogProbsContent, TopLogProb}; +use crate::routes::openai::utils::types::{ChatLogProbs, ChatLogProbsContent, TopLogProb, Usage}; use crate::routes::openai::utils::validated_json::ValidatedJson; use crate::state::AppState; use crate::utils::resolve_request_context; @@ -46,6 +56,7 @@ pub async fn generate( let log_request = state.enable_log_requests; let include_logprobs = prepared.include_logprobs; let include_prompt_logprobs = prepared.include_prompt_logprobs; + let stream = prepared.stream; let raw_stream = match state .chat @@ -64,6 +75,20 @@ pub async fn generate( } }; + if stream { + let chunk_stream = generate_chunk_stream( + raw_stream, + prepared.request_id, + log_request, + prepared.include_usage, + prepared.include_continuous_usage, + include_logprobs, + ); + let sse_stream = generate_sse_stream(chunk_stream).instrument(request_span); + + return Sse::new(sse_stream).into_response(); + } + let collected = match raw_stream.collect_output().instrument(request_span.clone()).await { Ok(collected) => collected, Err(error) => { @@ -98,6 +123,102 @@ pub async fn generate( Json(response).into_response() } +#[try_stream] +async fn generate_chunk_stream( + stream: impl Stream>, + request_id: String, + log_request: bool, + include_usage: bool, + include_continuous_usage: bool, + include_logprobs: bool, + mut y: TryYielder, +) -> Result<(), ApiError> { + pin_mut!(stream); + let mut prompt_tokens: Option = None; + let mut output_tokens = 0_u32; + + while let Some(next) = stream.next().await { + match next { + Ok(output) => { + if prompt_tokens.is_none() { + prompt_tokens = + output.prompt_info.as_ref().map(|info| info.prompt_token_ids.len() as u32); + } + let usage_prompt_tokens = prompt_tokens.unwrap_or_default(); + + let token_ids = output.token_ids; + output_tokens = output_tokens.saturating_add(token_ids.len() as u32); + let finish_reason = output.finish_reason; + + if matches!(finish_reason.as_ref(), Some(FinishReason::Error)) { + bail_server_error!("Internal server error"); + } + + if let Some(finish_reason) = finish_reason.as_ref() + && log_request + { + info!( + stream = true, + prompt_tokens = usage_prompt_tokens, + output_tokens, + finish_reason = finish_reason.as_str(), + "generate finished" + ); + } + + if token_ids.is_empty() && finish_reason.is_none() { + continue; + } + + let logprobs = if include_logprobs && !token_ids.is_empty() { + let logprobs = output.logprobs.as_ref().ok_or_else(|| { + server_error!( + "raw generate stream requested logprobs but generation returned none" + ) + })?; + Some(raw_logprobs_to_openai_chat(logprobs)?) + } else { + None + }; + + y.yield_ok(GenerateStreamResponse { + request_id: request_id.clone(), + choices: vec![GenerateResponseStreamChoice { + index: 0, + logprobs, + finish_reason: finish_reason.map(|reason| reason.as_str().to_string()), + token_ids, + }], + usage: include_continuous_usage + .then(|| Usage::from_counts(usage_prompt_tokens, output_tokens)), + }) + .await; + } + Err(error) => { + error!( + error = %error.as_report(), + "raw generate stream failed" + ); + bail_server_error!("{}", error.to_report_string()); + } + } + } + + if include_usage { + y.yield_ok(GenerateStreamResponse { + request_id, + choices: Vec::new(), + usage: Some(Usage::from_counts( + prompt_tokens.unwrap_or_default(), + output_tokens, + )), + }) + .await; + } + + Ok(()) +} + fn collect_generate( collected: CollectedGenerateOutput, request_id: String, @@ -213,3 +334,94 @@ fn position_to_logprob_map(position: &PositionLogprobs) -> HashMap String { format!("token_id:{token_id}") } + +/// Convert one raw-generate chunk stream into SSE events. +#[try_stream] +async fn generate_sse_stream( + stream: impl Stream>, + mut y: TryYielder, +) -> Result<(), Infallible> { + pin_mut!(stream); + + while let Some(next) = stream.next().await { + match next { + Ok(chunk) => y.yield_ok(to_sse_event(&chunk)).await, + Err(error) => { + y.yield_ok(to_error_sse_event(&error)).await; + break; + } + } + } + + y.yield_ok(done_sse_event()).await; + Ok(()) +} + +fn to_sse_event(chunk: &GenerateStreamResponse) -> Event { + let payload = serde_json::to_string(chunk).expect("generate chunk must serialize to JSON"); + trace!(payload, "generate emitting chunk"); + Event::default().data(payload) +} + +fn to_error_sse_event(error: &ApiError) -> Event { + let payload = serde_json::to_string(&error.to_error_response()) + .expect("ErrorResponse must serialize to JSON"); + trace!(payload, "generate emitting error"); + Event::default().data(payload) +} + +fn done_sse_event() -> Event { + trace!("generate emitting done"); + Event::default().data("[DONE]") +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use futures::{TryStreamExt as _, stream}; + use vllm_llm::GeneratePromptInfo; + + use super::*; + + #[tokio::test] + async fn generate_chunk_stream_captures_late_prompt_info() { + let stream = stream::iter(vec![ + Ok(GenerateOutput { + request_id: String::new(), + prompt_info: None, + token_ids: Vec::new(), + logprobs: None, + finish_reason: None, + kv_transfer_params: None, + }), + Ok(GenerateOutput { + request_id: String::new(), + prompt_info: Some(GeneratePromptInfo { + prompt_token_ids: Arc::from([11_u32, 22_u32]), + prompt_logprobs: None, + }), + token_ids: vec![33], + logprobs: None, + finish_reason: Some(FinishReason::stop_eos()), + kv_transfer_params: None, + }), + ]); + + let chunks: Vec<_> = + generate_chunk_stream(stream, "raw-stream".to_string(), false, true, true, false) + .try_collect() + .await + .expect("collect chunks"); + + assert_eq!(chunks.len(), 2); + assert_eq!( + chunks[0].usage.as_ref().expect("chunk usage").prompt_tokens, + 2 + ); + assert_eq!( + chunks[1].usage.as_ref().expect("final usage").prompt_tokens, + 2 + ); + } +} diff --git a/rust/src/server/src/routes/inference/generate/convert.rs b/rust/src/server/src/routes/inference/generate/convert.rs index 70374c2f1eb8..844a606bcfdf 100644 --- a/rust/src/server/src/routes/inference/generate/convert.rs +++ b/rust/src/server/src/routes/inference/generate/convert.rs @@ -10,6 +10,9 @@ use crate::utils::{ResolvedRequestContext, merge_kv_transfer_params}; pub struct PreparedRequest { pub request_id: String, pub text_request: TextRequest, + pub stream: bool, + pub include_usage: bool, + pub include_continuous_usage: bool, pub include_logprobs: bool, pub include_prompt_logprobs: bool, } @@ -23,6 +26,18 @@ pub fn prepare_generate_request( ) -> Result { validate::validate_request_compat(&request, served_model_names)?; + let stream = request.stream; + let include_usage = request + .stream_options + .as_ref() + .and_then(|options| options.include_usage) + .unwrap_or(false); + let include_continuous_usage = include_usage + && request + .stream_options + .as_ref() + .and_then(|options| options.continuous_usage_stats) + .unwrap_or(false); let include_logprobs = request.sampling_params.logprobs.is_some(); let include_prompt_logprobs = request.sampling_params.prompt_logprobs.is_some(); let mut sampling_params = request.sampling_params; @@ -47,6 +62,9 @@ pub fn prepare_generate_request( Ok(PreparedRequest { request_id: ctx.request_id, text_request, + stream, + include_usage, + include_continuous_usage, include_logprobs, include_prompt_logprobs, }) @@ -109,4 +127,28 @@ mod tests { Some(json!({"connector": "x"})) ); } + + #[test] + fn prepare_generate_request_gates_continuous_usage_on_include_usage() { + let request: GenerateRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [11, 22], + "stream": true, + "stream_options": { + "continuous_usage_stats": true + }, + "sampling_params": {} + })) + .expect("parse request"); + + let prepared = prepare_generate_request( + request, + &["Qwen/Qwen1.5-0.5B-Chat".to_string()], + ResolvedRequestContext::default(), + ) + .expect("prepare"); + + assert!(!prepared.include_usage); + assert!(!prepared.include_continuous_usage); + } } diff --git a/rust/src/server/src/routes/inference/generate/types.rs b/rust/src/server/src/routes/inference/generate/types.rs index de7a196c3c6b..d4567c44aa6d 100644 --- a/rust/src/server/src/routes/inference/generate/types.rs +++ b/rust/src/server/src/routes/inference/generate/types.rs @@ -5,7 +5,7 @@ use serde_json::{Map, Value}; use validator::Validate; use vllm_text::SamplingParams; -use crate::routes::openai::utils::types::{ChatLogProbs, Normalizable}; +use crate::routes::openai::utils::types::{ChatLogProbs, Normalizable, StreamOptions, Usage}; /// vLLM-compatible request type for the token-in/token-out generate API. #[serde_with::skip_serializing_none] @@ -17,6 +17,7 @@ pub struct GenerateRequest { pub sampling_params: SamplingParams, #[serde(default)] pub stream: bool, + pub stream_options: Option, pub cache_salt: Option, #[serde(default)] pub priority: i32, @@ -37,6 +38,25 @@ pub(super) struct GenerateResponseChoice { pub token_ids: Vec, } +/// Mirrors the Python vLLM `GenerateResponseStreamChoice` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct GenerateResponseStreamChoice { + pub index: u32, + pub logprobs: Option, + pub finish_reason: Option, + pub token_ids: Vec, +} + +/// Mirrors the Python vLLM `GenerateStreamResponse` class. +#[serde_with::skip_serializing_none] +#[derive(Debug, Clone, Serialize)] +pub(super) struct GenerateStreamResponse { + pub request_id: String, + pub choices: Vec, + pub usage: Option, +} + /// Mirrors the Python vLLM `GenerateResponse` class. #[serde_with::skip_serializing_none] #[derive(Debug, Clone, Serialize)] diff --git a/rust/src/server/src/routes/inference/generate/validate.rs b/rust/src/server/src/routes/inference/generate/validate.rs index 74a5bbb690ae..43347c60b574 100644 --- a/rust/src/server/src/routes/inference/generate/validate.rs +++ b/rust/src/server/src/routes/inference/generate/validate.rs @@ -13,8 +13,11 @@ pub(super) fn validate_request_compat( return Err(ApiError::model_not_found(model.clone())); } - if request.stream { - bail_invalid_request!(param = "stream", "stream=true is not supported."); + if request.stream_options.is_some() && !request.stream { + bail_invalid_request!( + param = "stream_options", + "stream_options are only supported when stream=true." + ); } if request.token_ids.is_empty() { @@ -65,11 +68,24 @@ mod tests { } #[test] - fn validate_request_compat_rejects_streaming() { + fn validate_request_compat_accepts_streaming() { let request = GenerateRequest { stream: true, ..base_request() }; + assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_ok()); + } + + #[test] + fn validate_request_compat_rejects_stream_options_without_streaming() { + let request: GenerateRequest = serde_json::from_value(json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "token_ids": [11, 22], + "stream": false, + "stream_options": {"include_usage": true}, + "sampling_params": {} + })) + .expect("parse request"); assert!(validate_request_compat(&request, &served(&["Qwen/Qwen1.5-0.5B-Chat"])).is_err()); } diff --git a/rust/src/server/src/routes/tests.rs b/rust/src/server/src/routes/tests.rs index d166b8004470..b1a4f0705fd4 100644 --- a/rust/src/server/src/routes/tests.rs +++ b/rust/src/server/src/routes/tests.rs @@ -2425,8 +2425,72 @@ async fn non_stream_raw_generate_returns_token_output_envelope() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] -async fn raw_generate_rejects_streaming() { - let mut app = test_app().await; +async fn stream_raw_generate_returns_sse_chunks_and_usage() { + let ipc = IpcNamespace::new().expect("create ipc namespace"); + let handshake_address = ipc.handshake_endpoint(); + let engine_id = b"engine-raw-generate-stream".to_vec(); + + let engine_task = MockEngineTask::new(spawn_mock_engine_task( + handshake_address.clone(), + engine_id.clone(), + |dealer, push| { + boxed_test_future(async move { + let add = recv_engine_message(dealer).await; + let request: EngineCoreRequest = + rmp_serde::from_slice(&add[1]).expect("decode request"); + assert_eq!(request.prompt_token_ids.as_deref(), Some(&[11, 22][..])); + assert_eq!(request.external_req_id.as_deref(), Some("raw-stream")); + + send_outputs( + push, + EngineCoreOutputs { + engine_index: 0, + outputs: vec![ + request_output_with_logprobs( + &request.request_id, + vec![33], + None, + None, + Some(sample_logprobs_for_token(33, 34)), + None, + ), + request_output_with_logprobs( + &request.request_id, + vec![44], + Some(EngineCoreFinishReason::Stop), + None, + Some(sample_logprobs_for_token(44, 45)), + None, + ), + ], + scheduler_stats: None, + timestamp: 0.0, + utility_output: None, + finished_requests: None, + wave_complete: None, + start_wave: None, + }, + ) + .await; + }) + }, + )); + + let client = EngineCoreClient::connect( + EngineCoreClientConfig::new_single(handshake_address) + .with_model_name("test-model") + .with_local_input_output_addresses( + Some(ipc.input_endpoint()), + Some(ipc.output_endpoint()), + ), + ) + .await + .expect("connect client"); + let chat = ChatLlm::from_shared_backend(Llm::new(client), Arc::new(FakeChatBackend::new())); + let mut app = build_router(Arc::new(AppState::new( + vec!["Qwen/Qwen1.5-0.5B-Chat".to_string()], + chat, + ))); let response = app .call( @@ -2437,9 +2501,17 @@ async fn raw_generate_rejects_streaming() { .body(Body::from( json!({ "model": "Qwen/Qwen1.5-0.5B-Chat", + "request_id": "raw-stream", "token_ids": [11, 22], "stream": true, - "sampling_params": {} + "stream_options": { + "include_usage": true, + "continuous_usage_stats": true + }, + "sampling_params": { + "max_tokens": 2, + "logprobs": 1 + } }) .to_string(), )) @@ -2448,10 +2520,196 @@ async fn raw_generate_rejects_streaming() { .await .expect("call app"); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").and_then(|value| value.to_str().ok()), + Some("text/event-stream") + ); + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); - let json: serde_json::Value = serde_json::from_slice(&body).expect("decode json"); - assert_eq!(json["error"]["param"], "stream"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + assert_eq!(payloads.len(), 4, "{text}"); + + let first: serde_json::Value = serde_json::from_str(payloads[0]).expect("first chunk json"); + assert_eq!(first["request_id"], "raw-stream"); + assert_eq!(first["choices"][0]["index"], 0); + assert_eq!(first["choices"][0]["token_ids"], json!([33])); + assert_eq!( + first["choices"][0]["logprobs"]["content"][0]["token"], + "token_id:33" + ); + assert_eq!(first["usage"]["prompt_tokens"], 2); + assert_eq!(first["usage"]["completion_tokens"], 1); + + let second: serde_json::Value = serde_json::from_str(payloads[1]).expect("second chunk json"); + assert_eq!(second["choices"][0]["token_ids"], json!([44])); + assert_eq!(second["choices"][0]["finish_reason"], "stop"); + assert_eq!(second["usage"]["completion_tokens"], 2); + + let usage: serde_json::Value = serde_json::from_str(payloads[2]).expect("usage chunk json"); + assert_eq!(usage["choices"], json!([])); + assert_eq!(usage["usage"]["prompt_tokens"], 2); + assert_eq!(usage["usage"]["completion_tokens"], 2); + assert_eq!(usage["usage"]["total_tokens"], 4); + assert_eq!(payloads[3], "[DONE]"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_raw_generate_emits_final_usage_without_continuous_usage() { + let (mut app, engine_task) = test_app_with_stream_output_specs(vec![ + (vec![33], None), + (vec![44], Some(EngineCoreFinishReason::Stop)), + ]) + .await; + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "request_id": "raw-stream-final-usage", + "token_ids": [11, 22], + "stream": true, + "stream_options": { + "include_usage": true + }, + "sampling_params": { + "max_tokens": 2 + } + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + assert_eq!(payloads.len(), 4, "{text}"); + + let first: serde_json::Value = serde_json::from_str(payloads[0]).expect("first chunk json"); + assert_eq!(first["choices"][0]["token_ids"], json!([33])); + assert!(first.get("usage").is_none()); + + let second: serde_json::Value = serde_json::from_str(payloads[1]).expect("second chunk json"); + assert_eq!(second["choices"][0]["token_ids"], json!([44])); + assert_eq!(second["choices"][0]["finish_reason"], "stop"); + assert!(second.get("usage").is_none()); + + let usage: serde_json::Value = serde_json::from_str(payloads[2]).expect("usage chunk json"); + assert_eq!(usage["choices"], json!([])); + assert_eq!(usage["usage"]["prompt_tokens"], 2); + assert_eq!(usage["usage"]["completion_tokens"], 2); + assert_eq!(usage["usage"]["total_tokens"], 4); + assert_eq!(payloads[3], "[DONE]"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_raw_generate_emits_empty_finish_chunk() { + let (mut app, engine_task) = test_app_with_stream_output_specs(vec![ + (vec![33], None), + (vec![], Some(EngineCoreFinishReason::Stop)), + ]) + .await; + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "request_id": "raw-stream-empty-finish", + "token_ids": [11, 22], + "stream": true, + "sampling_params": { + "max_tokens": 2 + } + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + let payloads = sse_data_payloads(&text); + assert_eq!(payloads.len(), 3, "{text}"); + + let first: serde_json::Value = serde_json::from_str(payloads[0]).expect("first chunk json"); + assert_eq!(first["choices"][0]["token_ids"], json!([33])); + assert!(first["choices"][0].get("finish_reason").is_none()); + + let second: serde_json::Value = serde_json::from_str(payloads[1]).expect("second chunk json"); + assert_eq!(second["choices"][0]["token_ids"], json!([])); + assert_eq!(second["choices"][0]["finish_reason"], "stop"); + assert_eq!(payloads[2], "[DONE]"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[serial] +async fn stream_raw_generate_error_finish_returns_sse_error() { + let (mut app, engine_task) = + test_app_with_stream_output_specs(vec![(vec![], Some(EngineCoreFinishReason::Error))]) + .await; + + let response = app + .call( + Request::builder() + .method("POST") + .uri("/inference/v1/generate") + .header("content-type", "application/json") + .body(Body::from( + json!({ + "model": "Qwen/Qwen1.5-0.5B-Chat", + "request_id": "raw-stream-error", + "token_ids": [11, 22], + "stream": true, + "stream_options": { + "include_usage": true + }, + "sampling_params": { + "max_tokens": 2 + } + }) + .to_string(), + )) + .expect("build request"), + ) + .await + .expect("call app"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = to_bytes(response.into_body(), usize::MAX).await.expect("read body"); + engine_task.await.expect("mock engine task"); + let text = String::from_utf8(body.to_vec()).expect("utf8 body"); + + assert!(text.contains("\"type\":\"server_error\""), "{text}"); + assert!(text.contains("Internal server error"), "{text}"); + assert!(!text.contains("\"finish_reason\":\"error\""), "{text}"); + assert!(!text.contains("\"usage\":"), "{text}"); + assert!(text.trim_end().ends_with("data: [DONE]"), "{text}"); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)]