Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 217 additions & 5 deletions rust/src/server/src/routes/inference/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,33 @@ mod types;
mod validate;

use std::collections::HashMap;
use std::convert::Infallible;
use std::result::Result;
use std::sync::Arc;

use asynk_strim_attr::{TryYielder, try_stream};
use axum::Json;
use axum::extract::State;
use axum::http::HeaderMap;
use axum::response::sse::{Event, Sse};
use axum::response::{IntoResponse, Response};
use futures::{Stream, StreamExt as _, pin_mut};
use thiserror_ext::AsReport as _;
use tracing::info;
use tracing::{error, info, trace};
use tracing_futures::Instrument as _;
use vllm_engine_core_client::protocol::logprobs::{Logprobs, PositionLogprobs};
use vllm_llm::{CollectedGenerateOutput, GenerateOutputStreamExt as _};
use vllm_llm::{
CollectedGenerateOutput, FinishReason, GenerateOutput, GenerateOutputStreamExt as _,
};

use self::convert::prepare_generate_request;
use self::types::{GenerateLogprob, GenerateRequest, GenerateResponse, GenerateResponseChoice};
use crate::error::{ApiError, server_error};
use self::types::{
GenerateLogprob, GenerateRequest, GenerateResponse, GenerateResponseChoice,
GenerateResponseStreamChoice, GenerateStreamResponse,
};
use crate::error::{ApiError, bail_server_error, server_error};
use crate::routes::openai::utils::logprobs::clamp_logprob;
use crate::routes::openai::utils::types::{ChatLogProbs, ChatLogProbsContent, TopLogProb};
use crate::routes::openai::utils::types::{ChatLogProbs, ChatLogProbsContent, TopLogProb, Usage};
use crate::routes::openai::utils::validated_json::ValidatedJson;
use crate::state::AppState;
use crate::utils::resolve_request_context;
Expand All @@ -46,6 +56,7 @@ pub async fn generate(
let log_request = state.enable_log_requests;
let include_logprobs = prepared.include_logprobs;
let include_prompt_logprobs = prepared.include_prompt_logprobs;
let stream = prepared.stream;

let raw_stream = match state
.chat
Expand All @@ -64,6 +75,20 @@ pub async fn generate(
}
};

if stream {
let chunk_stream = generate_chunk_stream(
raw_stream,
prepared.request_id,
log_request,
prepared.include_usage,
prepared.include_continuous_usage,
include_logprobs,
);
let sse_stream = generate_sse_stream(chunk_stream).instrument(request_span);

return Sse::new(sse_stream).into_response();
}

let collected = match raw_stream.collect_output().instrument(request_span.clone()).await {
Ok(collected) => collected,
Err(error) => {
Expand Down Expand Up @@ -98,6 +123,102 @@ pub async fn generate(
Json(response).into_response()
}

#[try_stream]
async fn generate_chunk_stream(
stream: impl Stream<Item = vllm_llm::Result<GenerateOutput>>,
request_id: String,
log_request: bool,
include_usage: bool,
include_continuous_usage: bool,
include_logprobs: bool,
mut y: TryYielder<GenerateStreamResponse, ApiError>,
) -> Result<(), ApiError> {
pin_mut!(stream);
let mut prompt_tokens: Option<u32> = None;
let mut output_tokens = 0_u32;

while let Some(next) = stream.next().await {
match next {
Ok(output) => {
if prompt_tokens.is_none() {
prompt_tokens =
output.prompt_info.as_ref().map(|info| info.prompt_token_ids.len() as u32);
}
let usage_prompt_tokens = prompt_tokens.unwrap_or_default();

let token_ids = output.token_ids;
output_tokens = output_tokens.saturating_add(token_ids.len() as u32);
let finish_reason = output.finish_reason;

if matches!(finish_reason.as_ref(), Some(FinishReason::Error)) {
bail_server_error!("Internal server error");
}

if let Some(finish_reason) = finish_reason.as_ref()
&& log_request
{
info!(
stream = true,
prompt_tokens = usage_prompt_tokens,
output_tokens,
finish_reason = finish_reason.as_str(),
"generate finished"
);
}

if token_ids.is_empty() && finish_reason.is_none() {
continue;
}

let logprobs = if include_logprobs && !token_ids.is_empty() {
let logprobs = output.logprobs.as_ref().ok_or_else(|| {
server_error!(
"raw generate stream requested logprobs but generation returned none"
)
})?;
Some(raw_logprobs_to_openai_chat(logprobs)?)
} else {
None
};

y.yield_ok(GenerateStreamResponse {
request_id: request_id.clone(),
choices: vec![GenerateResponseStreamChoice {
index: 0,
logprobs,
finish_reason: finish_reason.map(|reason| reason.as_str().to_string()),
Comment thread
Xunzhuo marked this conversation as resolved.
token_ids,
}],
usage: include_continuous_usage
.then(|| Usage::from_counts(usage_prompt_tokens, output_tokens)),
})
.await;
}
Err(error) => {
error!(
error = %error.as_report(),
"raw generate stream failed"
);
bail_server_error!("{}", error.to_report_string());
}
}
}

if include_usage {
y.yield_ok(GenerateStreamResponse {
request_id,
choices: Vec::new(),
usage: Some(Usage::from_counts(
prompt_tokens.unwrap_or_default(),
output_tokens,
)),
})
.await;
}

Ok(())
}

fn collect_generate(
collected: CollectedGenerateOutput,
request_id: String,
Expand Down Expand Up @@ -213,3 +334,94 @@ fn position_to_logprob_map(position: &PositionLogprobs) -> HashMap<u32, Generate
fn format_token_id(token_id: u32) -> String {
format!("token_id:{token_id}")
}

/// Convert one raw-generate chunk stream into SSE events.
#[try_stream]
async fn generate_sse_stream(
stream: impl Stream<Item = Result<GenerateStreamResponse, ApiError>>,
mut y: TryYielder<Event, Infallible>,
) -> Result<(), Infallible> {
pin_mut!(stream);

while let Some(next) = stream.next().await {
match next {
Ok(chunk) => y.yield_ok(to_sse_event(&chunk)).await,
Err(error) => {
y.yield_ok(to_error_sse_event(&error)).await;
break;
}
}
}

y.yield_ok(done_sse_event()).await;
Ok(())
}

fn to_sse_event(chunk: &GenerateStreamResponse) -> Event {
let payload = serde_json::to_string(chunk).expect("generate chunk must serialize to JSON");
trace!(payload, "generate emitting chunk");
Event::default().data(payload)
}

fn to_error_sse_event(error: &ApiError) -> Event {
let payload = serde_json::to_string(&error.to_error_response())
.expect("ErrorResponse must serialize to JSON");
trace!(payload, "generate emitting error");
Event::default().data(payload)
}

fn done_sse_event() -> Event {
trace!("generate emitting done");
Event::default().data("[DONE]")
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use futures::{TryStreamExt as _, stream};
use vllm_llm::GeneratePromptInfo;

use super::*;

#[tokio::test]
async fn generate_chunk_stream_captures_late_prompt_info() {
let stream = stream::iter(vec![
Ok(GenerateOutput {
request_id: String::new(),
prompt_info: None,
token_ids: Vec::new(),
logprobs: None,
finish_reason: None,
kv_transfer_params: None,
}),
Ok(GenerateOutput {
request_id: String::new(),
prompt_info: Some(GeneratePromptInfo {
prompt_token_ids: Arc::from([11_u32, 22_u32]),
prompt_logprobs: None,
}),
token_ids: vec![33],
logprobs: None,
finish_reason: Some(FinishReason::stop_eos()),
kv_transfer_params: None,
}),
]);

let chunks: Vec<_> =
generate_chunk_stream(stream, "raw-stream".to_string(), false, true, true, false)
.try_collect()
.await
.expect("collect chunks");

assert_eq!(chunks.len(), 2);
assert_eq!(
chunks[0].usage.as_ref().expect("chunk usage").prompt_tokens,
2
);
assert_eq!(
chunks[1].usage.as_ref().expect("final usage").prompt_tokens,
2
);
}
}
42 changes: 42 additions & 0 deletions rust/src/server/src/routes/inference/generate/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use crate::utils::{ResolvedRequestContext, merge_kv_transfer_params};
pub struct PreparedRequest {
pub request_id: String,
pub text_request: TextRequest,
pub stream: bool,
pub include_usage: bool,
pub include_continuous_usage: bool,
pub include_logprobs: bool,
pub include_prompt_logprobs: bool,
}
Expand All @@ -23,6 +26,18 @@ pub fn prepare_generate_request(
) -> Result<PreparedRequest, ApiError> {
validate::validate_request_compat(&request, served_model_names)?;

let stream = request.stream;
let include_usage = request
.stream_options
.as_ref()
.and_then(|options| options.include_usage)
.unwrap_or(false);
let include_continuous_usage = include_usage
&& request
.stream_options
.as_ref()
.and_then(|options| options.continuous_usage_stats)
.unwrap_or(false);
let include_logprobs = request.sampling_params.logprobs.is_some();
let include_prompt_logprobs = request.sampling_params.prompt_logprobs.is_some();
let mut sampling_params = request.sampling_params;
Expand All @@ -47,6 +62,9 @@ pub fn prepare_generate_request(
Ok(PreparedRequest {
request_id: ctx.request_id,
text_request,
stream,
include_usage,
include_continuous_usage,
include_logprobs,
include_prompt_logprobs,
})
Expand Down Expand Up @@ -109,4 +127,28 @@ mod tests {
Some(json!({"connector": "x"}))
);
}

#[test]
fn prepare_generate_request_gates_continuous_usage_on_include_usage() {
let request: GenerateRequest = serde_json::from_value(json!({
"model": "Qwen/Qwen1.5-0.5B-Chat",
"token_ids": [11, 22],
"stream": true,
"stream_options": {
"continuous_usage_stats": true
},
"sampling_params": {}
}))
.expect("parse request");

let prepared = prepare_generate_request(
request,
&["Qwen/Qwen1.5-0.5B-Chat".to_string()],
ResolvedRequestContext::default(),
)
.expect("prepare");

assert!(!prepared.include_usage);
assert!(!prepared.include_continuous_usage);
}
}
22 changes: 21 additions & 1 deletion rust/src/server/src/routes/inference/generate/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde_json::{Map, Value};
use validator::Validate;
use vllm_text::SamplingParams;

use crate::routes::openai::utils::types::{ChatLogProbs, Normalizable};
use crate::routes::openai::utils::types::{ChatLogProbs, Normalizable, StreamOptions, Usage};

/// vLLM-compatible request type for the token-in/token-out generate API.
#[serde_with::skip_serializing_none]
Expand All @@ -17,6 +17,7 @@ pub struct GenerateRequest {
pub sampling_params: SamplingParams,
#[serde(default)]
pub stream: bool,
pub stream_options: Option<StreamOptions>,
pub cache_salt: Option<String>,
#[serde(default)]
pub priority: i32,
Expand All @@ -37,6 +38,25 @@ pub(super) struct GenerateResponseChoice {
pub token_ids: Vec<u32>,
}

/// Mirrors the Python vLLM `GenerateResponseStreamChoice` class.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize)]
pub(super) struct GenerateResponseStreamChoice {
pub index: u32,
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
pub token_ids: Vec<u32>,
}

/// Mirrors the Python vLLM `GenerateStreamResponse` class.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize)]
pub(super) struct GenerateStreamResponse {
pub request_id: String,
pub choices: Vec<GenerateResponseStreamChoice>,
pub usage: Option<Usage>,
}

/// Mirrors the Python vLLM `GenerateResponse` class.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize)]
Expand Down
Loading
Loading