Skip to content

Commit 1b2b891

Browse files
committed
chore: added utility to detect possible tool call start for a chunk
Signed-off-by: ayushag <[email protected]>
1 parent 1477f6e commit 1b2b891

File tree

8 files changed

+391
-18
lines changed

8 files changed

+391
-18
lines changed

lib/parsers/src/tool_calling/harmony/harmony_parser.rs

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,7 @@ pub fn parse_tool_calls_harmony(
2929
// Check if tool call start tokens are present, if not return everything as normal text
3030
// Start Token: "<|start|>assistant<|channel|>commentary" should be present in the text if tool calls are present
3131
// End Token: "<|call|>"
32-
if !config
33-
.tool_call_start_tokens
34-
.iter()
35-
.any(|token| trimmed.contains(token))
36-
{
32+
if !detect_tool_call_start_harmony(text, config).unwrap_or(false) {
3733
return Ok((vec![], Some(trimmed)));
3834
}
3935

@@ -158,6 +154,24 @@ pub fn parse_tool_calls_harmony(
158154
Ok((res, Some(normal_text.to_string())))
159155
}
160156

157+
pub fn detect_tool_call_start_harmony(
158+
chunk: &str,
159+
config: &JsonParserConfig,
160+
) -> anyhow::Result<bool> {
161+
let trimmed = chunk.trim();
162+
if trimmed.is_empty() {
163+
return Ok(false);
164+
}
165+
if config
166+
.tool_call_start_tokens
167+
.iter()
168+
.any(|token| trimmed.contains(token))
169+
{
170+
return Ok(true);
171+
}
172+
Ok(false)
173+
}
174+
161175
#[cfg(test)]
162176
mod tests {
163177
use super::*;
@@ -270,3 +284,32 @@ mod tests {
270284
assert_eq!(args["unit"], "celsius");
271285
}
272286
}
287+
288+
#[cfg(test)]
289+
mod detect_parser_tests {
290+
use super::*;
291+
292+
#[test]
293+
fn test_detect_tool_call_start_harmony_chunk_with_tool_call_start_token() {
294+
let text = r#"<|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json"#;
295+
let config = JsonParserConfig {
296+
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
297+
tool_call_end_tokens: vec!["<|call|>".to_string()],
298+
..Default::default()
299+
};
300+
let result = detect_tool_call_start_harmony(text, &config).unwrap();
301+
assert!(result);
302+
}
303+
304+
#[test]
305+
fn test_detect_tool_call_start_harmony_chunk_without_tool_call_start_token() {
306+
let text = r#"<|channel|>commentary to=functions.get_current_weather"#;
307+
let config = JsonParserConfig {
308+
tool_call_start_tokens: vec!["<|start|>assistant<|channel|>commentary".to_string()],
309+
tool_call_end_tokens: vec!["<|call|>".to_string()],
310+
..Default::default()
311+
};
312+
let result = detect_tool_call_start_harmony(text, &config).unwrap();
313+
assert!(!result);
314+
}
315+
}

lib/parsers/src/tool_calling/harmony/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
pub mod harmony_parser;
55

66
pub use super::{config, response};
7-
pub use harmony_parser::parse_tool_calls_harmony;
7+
pub use harmony_parser::{detect_tool_call_start_harmony, parse_tool_calls_harmony};

lib/parsers/src/tool_calling/json/base_json_parser.rs

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,141 @@ pub fn try_tool_call_parse_basic_json(
306306

307307
Ok((vec![], Some(trimmed.to_string())))
308308
}
309+
310+
pub fn detect_tool_call_start_basic_json(
311+
chunk: &str,
312+
config: &JsonParserConfig,
313+
) -> anyhow::Result<bool> {
314+
// Case 1: If there is any of the start tokens in the chunk, return true
315+
if config
316+
.tool_call_start_tokens
317+
.iter()
318+
.any(|token| chunk.contains(token))
319+
{
320+
return Ok(true);
321+
}
322+
323+
// Case 2: If there is any "{" or "[" in the chunk, return true
324+
// This case will lead to false positives for those models which does not emit tool call start tokens
325+
if chunk.contains("{") || chunk.contains("[") {
326+
return Ok(true);
327+
}
328+
Ok(false)
329+
}
330+
331+
#[cfg(test)]
332+
mod detect_parser_tests {
333+
use super::*;
334+
335+
#[test]
336+
fn detect_tool_call_start_basic_json_chunk_with_tool_call_start_token_hermes() {
337+
let text =
338+
r#"<tool_call>{"name": "search", "parameters": { "query": "rust" } }</tool_call>"#;
339+
let config = JsonParserConfig {
340+
tool_call_start_tokens: vec!["<tool_call>".to_string()],
341+
tool_call_end_tokens: vec!["</tool_call>".to_string()],
342+
..Default::default()
343+
};
344+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
345+
assert!(result);
346+
}
347+
348+
#[test]
349+
fn detect_tool_call_start_basic_json_chunk_without_tool_call_start_token() {
350+
let text = r#"{"name": "search", "parameters": { "query": "rust" } }"#;
351+
let config = JsonParserConfig {
352+
tool_call_start_tokens: vec!["<tool_call>".to_string()],
353+
tool_call_end_tokens: vec!["</tool_call>".to_string()],
354+
..Default::default()
355+
};
356+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
357+
assert!(result);
358+
}
359+
360+
#[test]
361+
fn detect_tool_call_start_basic_json_chunk_without_tool_call_start_token_with_normal_text() {
362+
let text = r#"Here it is {"name": "#;
363+
let config = JsonParserConfig {
364+
tool_call_start_tokens: vec!["<tool_call>".to_string()],
365+
tool_call_end_tokens: vec!["</tool_call>".to_string()],
366+
..Default::default()
367+
};
368+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
369+
assert!(result);
370+
}
371+
372+
#[test]
373+
fn detect_tool_call_start_basic_json_chunk_with_square_brackets() {
374+
// These kind of false positives are expected when calling this function for stream=True
375+
let text = r#"Here it is [{"name": "search","#;
376+
let config = JsonParserConfig {
377+
tool_call_start_tokens: vec!["<tool_call>".to_string()],
378+
tool_call_end_tokens: vec!["</tool_call>".to_string()],
379+
..Default::default()
380+
};
381+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
382+
assert!(result);
383+
}
384+
385+
#[test]
386+
fn detect_tool_call_start_basic_json_chunk_false_positive() {
387+
// These kind of false positives are expected when calling this function for stream=True
388+
let text = r#"Here it is { Whats up"#;
389+
let config = JsonParserConfig {
390+
tool_call_start_tokens: vec!["<tool_call>".to_string()],
391+
tool_call_end_tokens: vec!["</tool_call>".to_string()],
392+
..Default::default()
393+
};
394+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
395+
assert!(result);
396+
}
397+
398+
#[test]
399+
fn detect_tool_call_start_basic_json_chunk_with_tool_call_start_token_nemotron_deci() {
400+
let text =
401+
r#"<TOOLCALL>[{"name": "search", "parameters": { "query": "rust" } }]</TOOLCALL>"#;
402+
let config = JsonParserConfig {
403+
tool_call_start_tokens: vec!["<TOOLCALL>".to_string()],
404+
tool_call_end_tokens: vec!["</TOOLCALL>".to_string()],
405+
..Default::default()
406+
};
407+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
408+
assert!(result);
409+
}
410+
411+
#[test]
412+
fn detect_tool_call_start_basic_json_chunk_with_lllama3_json_token() {
413+
let text = r#"<|python_tag|>{ "name": }"#;
414+
let config = JsonParserConfig {
415+
tool_call_start_tokens: vec!["<|python_tag|>".to_string()],
416+
tool_call_end_tokens: vec!["".to_string()],
417+
..Default::default()
418+
};
419+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
420+
assert!(result);
421+
}
422+
423+
#[test]
424+
fn detect_tool_call_start_basic_json_chunk_mistral_token() {
425+
let text = r#"Hello Yo ! [TOOL_CALLS]{"name": "search", "#;
426+
let config = JsonParserConfig {
427+
tool_call_start_tokens: vec!["[TOOL_CALLS]".to_string()],
428+
tool_call_end_tokens: vec!["".to_string()],
429+
..Default::default()
430+
};
431+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
432+
assert!(result);
433+
}
434+
435+
#[test]
436+
fn detect_tool_call_start_basic_json_chunk_phi4_token() {
437+
let text = r#"functools{"name": "search", "#;
438+
let config = JsonParserConfig {
439+
tool_call_start_tokens: vec!["functools".to_string()],
440+
tool_call_end_tokens: vec!["".to_string()],
441+
..Default::default()
442+
};
443+
let result = detect_tool_call_start_basic_json(text, &config).unwrap();
444+
assert!(result);
445+
}
446+
}

lib/parsers/src/tool_calling/json/deepseek_parser.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,7 @@ pub fn parse_tool_calls_deepseek_v3_1(
4343
}
4444

4545
// If tool call start token is not present then, no tool calls are there, return empty tool calls and the original trimmed string
46-
if let Some(start_token) = tool_call_start_tokens.first() {
47-
if !trimmed.contains(start_token) {
48-
return Ok((vec![], Some(trimmed.to_string())));
49-
}
50-
} else {
51-
// Invalid start token
46+
if !detect_tool_call_start_deepseek_v3_1(trimmed, config).unwrap_or(false) {
5247
return Ok((vec![], Some(trimmed.to_string())));
5348
}
5449

@@ -106,6 +101,21 @@ pub fn parse_tool_calls_deepseek_v3_1(
106101
Ok((tool_calls, Some(normal_text)))
107102
}
108103

104+
pub fn detect_tool_call_start_deepseek_v3_1(
105+
chunk: &str,
106+
config: &JsonParserConfig,
107+
) -> anyhow::Result<bool> {
108+
// if chunk contains tool_call_start_tokens then return true
109+
if config
110+
.tool_call_start_tokens
111+
.iter()
112+
.any(|token| chunk.contains(token))
113+
{
114+
return Ok(true);
115+
}
116+
Ok(false)
117+
}
118+
109119
#[cfg(test)]
110120
mod tests {
111121
use super::*;
@@ -220,3 +230,43 @@ mod tests {
220230
assert_eq!(result.len(), 0);
221231
}
222232
}
233+
234+
#[cfg(test)]
235+
mod detect_parser_tests {
236+
use super::*;
237+
#[test]
238+
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token() {
239+
let text = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
240+
let config = JsonParserConfig {
241+
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
242+
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
243+
..Default::default()
244+
};
245+
let result = detect_tool_call_start_deepseek_v3_1(text, &config).unwrap();
246+
assert!(result);
247+
}
248+
249+
#[test]
250+
fn test_detect_tool_call_start_deepseek_v3_1_chunk_without_tool_call_start_token() {
251+
let text = r#"<|tool▁call▁begin|>get_current_weather宽带}"#;
252+
let config = JsonParserConfig {
253+
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
254+
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
255+
..Default::default()
256+
};
257+
let result = detect_tool_call_start_deepseek_v3_1(text, &config).unwrap();
258+
assert!(!result);
259+
}
260+
261+
#[test]
262+
fn test_detect_tool_call_start_deepseek_v3_1_chunk_with_tool_call_start_token_in_middle() {
263+
let text = r#"The following tool calls retrieve weather information: <|tool▁calls▁begin|><|tool▁call▁begin|>get_current_weather宽带}"#;
264+
let config = JsonParserConfig {
265+
tool_call_start_tokens: vec!["<|tool▁calls▁begin|>".to_string()],
266+
tool_call_end_tokens: vec!["<|tool▁calls▁end|>".to_string()],
267+
..Default::default()
268+
};
269+
let result = detect_tool_call_start_deepseek_v3_1(text, &config).unwrap();
270+
assert!(result);
271+
}
272+
}

lib/parsers/src/tool_calling/json/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ pub mod base_json_parser;
55
pub mod deepseek_parser;
66

77
pub use super::{config, response};
8-
pub use base_json_parser::try_tool_call_parse_basic_json;
9-
pub use deepseek_parser::parse_tool_calls_deepseek_v3_1;
8+
pub use base_json_parser::{detect_tool_call_start_basic_json, try_tool_call_parse_basic_json};
9+
pub use deepseek_parser::{detect_tool_call_start_deepseek_v3_1, parse_tool_calls_deepseek_v3_1};
1010

1111
pub use super::config::JsonParserConfig;
1212
pub use super::response::ToolCallResponse;
@@ -34,3 +34,10 @@ pub fn try_tool_call_parse_json(
3434
JsonParserType::DeepseekV31 => parse_tool_calls_deepseek_v3_1(message, config),
3535
}
3636
}
37+
38+
pub fn detect_tool_call_start_json(chunk: &str, config: &JsonParserConfig) -> anyhow::Result<bool> {
39+
match config.parser_type {
40+
JsonParserType::Basic => detect_tool_call_start_basic_json(chunk, config),
41+
JsonParserType::DeepseekV31 => detect_tool_call_start_deepseek_v3_1(chunk, config),
42+
}
43+
}

0 commit comments

Comments
 (0)