Skip to content

Commit 406ef5e

Browse files
committed
feat: logprob / top_logprobs=1 handling
Signed-off-by: Greg Clark <[email protected]>
1 parent 72ec5f5 commit 406ef5e

File tree

16 files changed

+228
-53
lines changed

16 files changed

+228
-53
lines changed

components/backends/sglang/src/dynamo/sglang/worker/main.py

Lines changed: 33 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,8 @@
77
import signal
88
import socket
99
import sys
10-
from typing import Any, Dict, Optional, Union
10+
from operator import itemgetter
11+
from typing import Any, Optional
1112

1213
import sglang as sgl
1314
import uvloop
@@ -210,6 +211,7 @@ async def generate(self, request: dict):
210211
else request["batch_token_ids"],
211212
sampling_params=sampling_params,
212213
stream=True,
214+
return_logprob=True,
213215
bootstrap_host=bootstrap_host,
214216
bootstrap_port=bootstrap_port,
215217
bootstrap_room=bootstrap_room,
@@ -231,54 +233,49 @@ async def generate(self, request: dict):
231233
else request["batch_token_ids"],
232234
sampling_params=sampling_params,
233235
stream=True,
236+
return_logprob=True,
234237
)
235238

236239
async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
237240
yield out
238241

239242
async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
240-
# Initialize based on batch mode
241-
num_output_tokens_so_far: Union[Dict[int, int], int]
242-
if is_batch:
243-
num_output_tokens_so_far = {}
244-
else:
245-
num_output_tokens_so_far = 0
243+
assert not is_batch, "Batch processing is not supported."
244+
num_output_tokens_so_far = 0
246245

247246
async for res in stream_source:
248247
data = res.data() if unpack else res
249248
finish_reason = data["meta_info"]["finish_reason"]
250249

251-
if is_batch:
252-
# Handle batch response
253-
assert isinstance(num_output_tokens_so_far, dict)
254-
index = data.get("index", 0)
255-
if index not in num_output_tokens_so_far:
256-
num_output_tokens_so_far[index] = 0
257-
258-
if finish_reason:
259-
out = {
260-
"token_ids": [],
261-
"finish_reason": finish_reason["type"],
262-
"index": index,
263-
}
264-
else:
265-
next_total_toks = len(data["output_ids"])
266-
new_tokens = data["output_ids"][num_output_tokens_so_far[index] :]
267-
out = {
268-
"token_ids": new_tokens,
269-
"index": index,
270-
}
271-
num_output_tokens_so_far[index] = next_total_toks
250+
# Handle single response
251+
assert isinstance(num_output_tokens_so_far, int)
252+
if finish_reason:
253+
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
272254
else:
273-
# Handle single response
274-
assert isinstance(num_output_tokens_so_far, int)
275-
if finish_reason:
276-
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
277-
else:
278-
next_total_toks = len(data["output_ids"])
279-
out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
280-
num_output_tokens_so_far = next_total_toks
255+
next_total_toks = len(res["meta_info"]["output_token_logprobs"])
256+
new_tokens = list(
257+
map(
258+
itemgetter(1),
259+
res["meta_info"]["output_token_logprobs"][
260+
num_output_tokens_so_far:
261+
],
262+
)
263+
)
264+
new_logprobs = list(
265+
map(
266+
itemgetter(0),
267+
res["meta_info"]["output_token_logprobs"][
268+
num_output_tokens_so_far:
269+
],
270+
)
271+
)
272+
out = {
273+
"token_ids": new_tokens,
274+
"log_probs": new_logprobs,
275+
}
276+
num_output_tokens_so_far = next_total_toks
281277

278+
logging.debug(f"Generated output: {out}")
282279
yield out
283280

284281
async def _prefill_generator(self, prefill):

lib/engines/llamacpp/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -268,6 +268,7 @@ fn run_request(
268268
//text: if output.text.is_empty() { None } else { Some(output.text) },
269269
cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
270270
log_probs: None, // TODO output.logprobs
271+
top_logprobs: None,
271272
finish_reason: None,
272273
index: None,
273274
};

lib/engines/mistralrs/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -590,7 +590,7 @@ impl
590590
None => None,
591591
};
592592
#[allow(deprecated)]
593-
let inner = response_generator.create_choice(0, Some(from_assistant), None);
593+
let inner = response_generator.create_choice(0, Some(from_assistant), None, None);
594594
let ann = Annotated{
595595
id: None,
596596
data: Some(inner),

lib/llm/src/backend.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -218,12 +218,14 @@ impl
218218
//let mdcsum = self.mdcsum.clone();
219219
let stream = processed_stream.map(move |output| {
220220
output.map_data(|data| {
221+
log::info!("data: {:?}", data);
221222
Ok(BackendOutput {
222223
token_ids: data.token_ids,
223224
tokens: data.tokens.unwrap_or_default(),
224225
text: data.text,
225226
cum_log_probs: data.cum_log_probs,
226227
log_probs: data.log_probs,
228+
top_logprobs: data.top_logprobs,
227229
finish_reason: data.finish_reason,
228230
//mdcsum: mdcsum.clone(),
229231
index: data.index,

lib/llm/src/engines.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
102102
text: None,
103103
cum_log_probs: None,
104104
log_probs: None,
105+
top_logprobs: None,
105106
finish_reason: None,
106107
index: None,
107108
};
@@ -242,11 +243,11 @@ impl
242243
let mut id = 1;
243244
for c in chars_string.chars() {
244245
tokio::time::sleep(*TOKEN_ECHO_DELAY).await;
245-
let response = deltas.create_choice(0, Some(c.to_string()), None);
246+
let response = deltas.create_choice(0, Some(c.to_string()), None, None);
246247
yield Annotated{ id: Some(id.to_string()), data: Some(response), event: None, comment: None };
247248
id += 1;
248249
}
249-
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop));
250+
let response = deltas.create_choice(0, None, Some(async_openai::types::CompletionFinishReason::Stop), None);
250251
yield Annotated { id: Some(id.to_string()), data: Some(response), event: None, comment: None };
251252

252253
};

lib/llm/src/migration.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ impl RetryManager {
166166
#[cfg(test)]
167167
mod tests {
168168
use super::*;
169-
use crate::protocols::common::{SamplingOptions, StopConditions};
169+
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
170170
use dynamo_runtime::pipeline::context::Controller;
171171
use dynamo_runtime::pipeline::AsyncEngine;
172172
use std::sync::atomic::{AtomicU32, Ordering};
@@ -183,6 +183,7 @@ mod tests {
183183
..Default::default()
184184
},
185185
sampling_options: SamplingOptions::default(),
186+
output_options: OutputOptions::default(),
186187
eos_token_ids: vec![],
187188
mdc_sum: None,
188189
annotations: vec![],
@@ -198,6 +199,7 @@ mod tests {
198199
text: Some(format!("token_{}", token_id)),
199200
cum_log_probs: None,
200201
log_probs: None,
202+
top_logprobs: None,
201203
finish_reason: None,
202204
index: None,
203205
})

lib/llm/src/mocker/engine.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
405405
text: None,
406406
cum_log_probs: None,
407407
log_probs: None,
408+
top_logprobs: None,
408409
finish_reason: None,
409410
index: None,
410411
};
@@ -525,7 +526,7 @@ mod integration_tests {
525526
use super::*;
526527
use crate::kv_router::indexer::RouterEvent;
527528
use crate::kv_router::KV_EVENT_SUBJECT;
528-
use crate::protocols::common::{SamplingOptions, StopConditions};
529+
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
529530
use dynamo_runtime::{
530531
pipeline::Context,
531532
pipeline::{network::Ingress, PushRouter},
@@ -641,6 +642,7 @@ mod integration_tests {
641642
..Default::default()
642643
},
643644
sampling_options: SamplingOptions::default(),
645+
output_options: OutputOptions::default(),
644646
eos_token_ids: vec![],
645647
mdc_sum: None,
646648
annotations: vec![format!("dp_rank:{dp_rank}")],

lib/llm/src/preprocessor.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ use dynamo_runtime::pipeline::{
3333
use dynamo_runtime::protocols::annotated::{Annotated, AnnotationsProvider};
3434

3535
use crate::protocols::{
36-
common::{SamplingOptionsProvider, StopConditionsProvider},
36+
common::{OutputOptionsProvider, SamplingOptionsProvider, StopConditionsProvider},
3737
openai::{
3838
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
3939
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
@@ -146,6 +146,7 @@ impl OpenAIPreprocessor {
146146
+ AnnotationsProvider
147147
+ SamplingOptionsProvider
148148
+ StopConditionsProvider
149+
+ OutputOptionsProvider
149150
+ NvExtProvider,
150151
>(
151152
&self,
@@ -249,6 +250,7 @@ impl OpenAIPreprocessor {
249250

250251
builder.stop_conditions(stop_conditions);
251252
builder.sampling_options(request.extract_sampling_options()?);
253+
builder.output_options(request.extract_output_options()?);
252254
builder.annotations(request.annotations().unwrap_or_default());
253255
builder.mdc_sum(Some(self.mdcsum.clone()));
254256
builder.estimated_prefix_hit_num_blocks(None);

lib/llm/src/protocols/common.rs

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,10 @@ pub trait StopConditionsProvider {
4545
fn extract_stop_conditions(&self) -> Result<StopConditions>;
4646
}
4747

48+
pub trait OutputOptionsProvider {
49+
fn extract_output_options(&self) -> Result<OutputOptions>;
50+
}
51+
4852
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
4953
pub enum FinishReason {
5054
#[serde(rename = "eos")]
@@ -179,6 +183,9 @@ pub struct CompletionRequest {
179183
/// are needed.
180184
pub sampling_options: SamplingOptions,
181185

186+
#[builder(default)]
187+
pub output_options: OutputOptions,
188+
182189
/// The computed checksum of the Model Deployment Card (MDC).
183190
#[builder(default)]
184191
pub mdc_sum: Option<String>,

lib/llm/src/protocols/common/llm_backend.rs

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,15 @@ use dynamo_runtime::protocols::maybe_error::MaybeError;
2323
pub type TokenType = Option<String>;
2424
pub type LogProbs = Vec<f64>;
2525

26+
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
27+
pub struct TopLogprob {
28+
pub rank: u32,
29+
pub token_id: TokenIdType,
30+
pub token: TokenType,
31+
pub logprob: f64,
32+
}
33+
pub type TopLogprobs = Vec<Vec<TopLogprob>>; // num_tokens x top_logprobs
34+
2635
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
2736
pub struct BackendOutput {
2837
/// New token_ids generated from the LLM Engine
@@ -41,6 +50,8 @@ pub struct BackendOutput {
4150
/// Optional log probabilities
4251
pub log_probs: Option<LogProbs>,
4352

53+
pub top_logprobs: Option<TopLogprobs>,
54+
4455
// TODO: Enrich this with more information as can apply our first-level postprocessing
4556
// logic and return more detailed information
4657
pub finish_reason: Option<FinishReason>,
@@ -77,6 +88,8 @@ pub struct LLMEngineOutput {
7788
/// Optional log probabilities
7889
pub log_probs: Option<LogProbs>,
7990

91+
pub top_logprobs: Option<TopLogprobs>,
92+
8093
// TODO: Enrich this with more information as can apply our first-level postprocessing
8194
// logic and return more detailed information
8295
pub finish_reason: Option<FinishReason>,
@@ -93,6 +106,7 @@ impl LLMEngineOutput {
93106
text: None,
94107
cum_log_probs: None,
95108
log_probs: None,
109+
top_logprobs: None,
96110
finish_reason: Some(FinishReason::Cancelled),
97111
index: None,
98112
}
@@ -106,6 +120,7 @@ impl LLMEngineOutput {
106120
cum_log_probs: None,
107121
log_probs: None,
108122
finish_reason: Some(FinishReason::Stop),
123+
top_logprobs: None,
109124
index: None,
110125
}
111126
}
@@ -117,6 +132,7 @@ impl LLMEngineOutput {
117132
text: None,
118133
cum_log_probs: None,
119134
log_probs: None,
135+
top_logprobs: None,
120136
finish_reason: Some(FinishReason::Length),
121137
index: None,
122138
}
@@ -129,6 +145,7 @@ impl LLMEngineOutput {
129145
text: None,
130146
cum_log_probs: None,
131147
log_probs: None,
148+
top_logprobs: None,
132149
finish_reason: Some(FinishReason::Error(err_msg)),
133150
index: None,
134151
}

0 commit comments

Comments
 (0)