Skip to content

Commit b19deaf

Browse files
messiaentedzhouhk
authored andcommitted
fix: aggregate logprobs (#2928)
Signed-off-by: Greg Clark <[email protected]> Signed-off-by: hongkuanz <[email protected]>
1 parent 1a412eb commit b19deaf

File tree

2 files changed

+102
-9
lines changed

2 files changed

+102
-9
lines changed

lib/llm/src/protocols/openai/chat_completions/aggregator.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ impl DeltaAggregator {
129129
text: "".to_string(),
130130
role: choice.delta.role,
131131
finish_reason: None,
132-
logprobs: choice.logprobs,
132+
logprobs: None,
133133
tool_calls: None,
134134
reasoning_content: None,
135135
});
@@ -150,6 +150,28 @@ impl DeltaAggregator {
150150
if let Some(finish_reason) = choice.finish_reason {
151151
state_choice.finish_reason = Some(finish_reason);
152152
}
153+
154+
// Update logprobs
155+
if let Some(logprobs) = &choice.logprobs {
156+
let state_lps = state_choice.logprobs.get_or_insert(
157+
dynamo_async_openai::types::ChatChoiceLogprobs {
158+
content: None,
159+
refusal: None,
160+
},
161+
);
162+
if let Some(content_lps) = &logprobs.content {
163+
state_lps
164+
.content
165+
.get_or_insert(Vec::new())
166+
.extend(content_lps.clone());
167+
}
168+
if let Some(refusal_lps) = &logprobs.refusal {
169+
state_lps
170+
.refusal
171+
.get_or_insert(Vec::new())
172+
.extend(refusal_lps.clone());
173+
}
174+
}
153175
}
154176
}
155177
aggregator
@@ -305,6 +327,7 @@ mod tests {
305327
text: &str,
306328
role: Option<dynamo_async_openai::types::Role>,
307329
finish_reason: Option<dynamo_async_openai::types::FinishReason>,
330+
logprob: Option<f32>,
308331
) -> Annotated<NvCreateChatCompletionStreamResponse> {
309332
// ALLOW: function_call is deprecated
310333
let delta = dynamo_async_openai::types::ChatCompletionStreamResponseDelta {
@@ -315,11 +338,22 @@ mod tests {
315338
refusal: None,
316339
reasoning_content: None,
317340
};
341+
let logprobs = logprob.map(|lp| dynamo_async_openai::types::ChatChoiceLogprobs {
342+
content: Some(vec![
343+
dynamo_async_openai::types::ChatCompletionTokenLogprob {
344+
token: text.to_string(),
345+
logprob: lp,
346+
bytes: None,
347+
top_logprobs: vec![],
348+
},
349+
]),
350+
refusal: None,
351+
});
318352
let choice = dynamo_async_openai::types::ChatChoiceStream {
319353
index,
320354
delta,
321355
finish_reason,
322-
logprobs: None,
356+
logprobs,
323357
};
324358

325359
let data = NvCreateChatCompletionStreamResponse {
@@ -372,6 +406,7 @@ mod tests {
372406
"Hello,",
373407
Some(dynamo_async_openai::types::Role::User),
374408
None,
409+
None,
375410
);
376411

377412
// Create a stream
@@ -409,12 +444,14 @@ mod tests {
409444
"Hello,",
410445
Some(dynamo_async_openai::types::Role::User),
411446
None,
447+
Some(-0.1),
412448
);
413449
let annotated_delta2 = create_test_delta(
414450
0,
415451
" world!",
416452
None,
417453
Some(dynamo_async_openai::types::FinishReason::Stop),
454+
Some(-0.2),
418455
);
419456

420457
// Create a stream
@@ -438,6 +475,25 @@ mod tests {
438475
Some(dynamo_async_openai::types::FinishReason::Stop)
439476
);
440477
assert_eq!(choice.message.role, dynamo_async_openai::types::Role::User);
478+
assert_eq!(
479+
choice
480+
.logprobs
481+
.as_ref()
482+
.unwrap()
483+
.content
484+
.as_ref()
485+
.unwrap()
486+
.len(),
487+
2
488+
);
489+
assert_eq!(
490+
choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[0].logprob,
491+
-0.1
492+
);
493+
assert_eq!(
494+
choice.logprobs.as_ref().unwrap().content.as_ref().unwrap()[1].logprob,
495+
-0.2
496+
);
441497
}
442498

443499
#[allow(deprecated)]
@@ -538,6 +594,7 @@ mod tests {
538594
tool_call_json,
539595
Some(dynamo_async_openai::types::Role::Assistant),
540596
Some(dynamo_async_openai::types::FinishReason::ToolCalls),
597+
None,
541598
);
542599
let data = annotated_delta.data.unwrap();
543600

@@ -598,6 +655,7 @@ mod tests {
598655
tool_call_json,
599656
Some(dynamo_async_openai::types::Role::Assistant),
600657
Some(dynamo_async_openai::types::FinishReason::ToolCalls),
658+
None,
601659
);
602660
let data = annotated_delta.data.unwrap();
603661

lib/llm/src/protocols/openai/completions/aggregator.rs

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ impl DeltaAggregator {
9595
index: choice.index,
9696
text: "".to_string(),
9797
finish_reason: None,
98-
logprobs: choice.logprobs,
98+
logprobs: None,
9999
});
100100

101101
state_choice.text.push_str(&choice.text);
@@ -115,6 +115,24 @@ impl DeltaAggregator {
115115
) => Some(FinishReason::ContentFilter),
116116
None => None,
117117
};
118+
119+
// Update logprobs
120+
if let Some(logprobs) = &choice.logprobs {
121+
let state_lps = state_choice.logprobs.get_or_insert(
122+
dynamo_async_openai::types::Logprobs {
123+
tokens: Vec::new(),
124+
token_logprobs: Vec::new(),
125+
top_logprobs: Vec::new(),
126+
text_offset: Vec::new(),
127+
},
128+
);
129+
state_lps.tokens.extend(logprobs.tokens.clone());
130+
state_lps
131+
.token_logprobs
132+
.extend(logprobs.token_logprobs.clone());
133+
state_lps.top_logprobs.extend(logprobs.top_logprobs.clone());
134+
state_lps.text_offset.extend(logprobs.text_offset.clone());
135+
}
118136
}
119137
}
120138
aggregator
@@ -196,6 +214,7 @@ mod tests {
196214
index: u32,
197215
text: &str,
198216
finish_reason: Option<String>,
217+
logprob: Option<f32>,
199218
) -> Annotated<NvCreateCompletionResponse> {
200219
// This will silently discard invalid_finish reason values and fall back
201220
// to None - totally fine since this is test code
@@ -204,6 +223,20 @@ mod tests {
204223
.and_then(|s| FinishReason::from_str(s).ok())
205224
.map(Into::into);
206225

226+
let logprobs = logprob.map(|lp| dynamo_async_openai::types::Logprobs {
227+
tokens: vec![text.to_string()],
228+
token_logprobs: vec![Some(lp)],
229+
top_logprobs: vec![
230+
serde_json::to_value(dynamo_async_openai::types::TopLogprobs {
231+
token: text.to_string(),
232+
logprob: lp,
233+
bytes: None,
234+
})
235+
.unwrap(),
236+
],
237+
text_offset: vec![0],
238+
});
239+
207240
let inner = dynamo_async_openai::types::CreateCompletionResponse {
208241
id: "test_id".to_string(),
209242
model: "meta/llama-3.1-8b".to_string(),
@@ -214,7 +247,7 @@ mod tests {
214247
index,
215248
text: text.to_string(),
216249
finish_reason,
217-
logprobs: None,
250+
logprobs,
218251
}],
219252
object: "text_completion".to_string(),
220253
};
@@ -253,7 +286,7 @@ mod tests {
253286
#[tokio::test]
254287
async fn test_single_delta() {
255288
// Create a sample delta
256-
let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()));
289+
let annotated_delta = create_test_delta(0, "Hello,", Some("length".to_string()), None);
257290

258291
// Create a stream
259292
let stream = Box::pin(stream::iter(vec![annotated_delta]));
@@ -291,8 +324,9 @@ mod tests {
291324
// Create multiple deltas with the same choice index
292325
// One will have a MessageRole and no FinishReason,
293326
// the other will have a FinishReason and no MessageRole
294-
let annotated_delta1 = create_test_delta(0, "Hello,", None);
295-
let annotated_delta2 = create_test_delta(0, " world!", Some("stop".to_string()));
327+
let annotated_delta1 = create_test_delta(0, "Hello,", None, Some(-0.1));
328+
let annotated_delta2 =
329+
create_test_delta(0, " world!", Some("stop".to_string()), Some(-0.2));
296330

297331
// Create a stream
298332
let annotated_deltas = vec![annotated_delta1, annotated_delta2];
@@ -314,9 +348,10 @@ mod tests {
314348
choice.finish_reason,
315349
Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
316350
);
351+
assert_eq!(choice.logprobs.as_ref().unwrap().tokens.len(), 2);
317352
assert_eq!(
318-
choice.finish_reason,
319-
Some(dynamo_async_openai::types::CompletionFinishReason::Stop)
353+
choice.logprobs.as_ref().unwrap().token_logprobs,
354+
vec![Some(-0.1), Some(-0.2)]
320355
);
321356
}
322357

0 commit comments

Comments
 (0)