From bd286283aafa7a2a793170f0b3264b70f90ebd68 Mon Sep 17 00:00:00 2001 From: Joshua Mo Date: Mon, 11 Aug 2025 01:28:16 +0100 Subject: [PATCH] feat(rig-863): add retries to extractor tool --- rig-core/examples/extractor.rs | 3 +++ rig-core/src/extractor.rs | 40 +++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/rig-core/examples/extractor.rs b/rig-core/examples/extractor.rs index 54356e660..645afffc8 100644 --- a/rig-core/examples/extractor.rs +++ b/rig-core/examples/extractor.rs @@ -7,10 +7,13 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, JsonSchema, Serialize)] struct Person { /// The person's first name, if provided (null otherwise) + #[schemars(required)] pub first_name: Option, /// The person's last name, if provided (null otherwise) + #[schemars(required)] pub last_name: Option, /// The person's job, if provided (null otherwise) + #[schemars(required)] pub job: Option, } diff --git a/rig-core/src/extractor.rs b/rig-core/src/extractor.rs index ee8419768..6c49dc288 100644 --- a/rig-core/src/extractor.rs +++ b/rig-core/src/extractor.rs @@ -59,13 +59,43 @@ pub enum ExtractionError { pub struct Extractor Deserialize<'a> + Send + Sync> { agent: Agent, _t: PhantomData, + retries: u64, } impl Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor where M: Sync, { + /// Attempts to extract data from the given text with a number of retries. + /// + /// The function will retry the extraction if the initial attempt fails or + /// if the model does not call the `submit` tool. + /// + /// The number of retries is determined by the `retries` field on the Extractor struct. pub async fn extract(&self, text: impl Into + Send) -> Result { + let mut last_error = None; + let text_message = text.into(); + + for i in 0..=self.retries { + tracing::debug!( + "Attempting to extract JSON. Retries left: {retries}", + retries = self.retries - i + ); + let attempt_text = text_message.clone(); + match self.extract_json(attempt_text).await { + Ok(data) => return Ok(data), + Err(e) => { + tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying..."); + last_error = Some(e); + } + } + } + + // If the loop finishes without a successful extraction, return the last error encountered. + Err(last_error.unwrap_or(ExtractionError::NoData)) + } + + async fn extract_json(&self, text: impl Into + Send) -> Result { let response = self.agent.completion(text, vec![]).await?.send().await?; if !response.choice.iter().any(|x| { @@ -136,6 +166,7 @@ pub struct ExtractorBuilder< > { agent_builder: AgentBuilder, _t: PhantomData, + retries: Option, } impl Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel> @@ -151,7 +182,7 @@ impl Deserialize<'a> + Serialize + Send + Sync, M: Compl Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!. ") .tool(SubmitTool:: {_t: PhantomData}), - + retries: None, _t: PhantomData, } } @@ -181,11 +212,18 @@ impl Deserialize<'a> + Serialize + Send + Sync, M: Compl self } + /// Set the maximum number of retries for the extractor. + pub fn retries(mut self, retries: u64) -> Self { + self.retries = Some(retries); + self + } + /// Build the Extractor pub fn build(self) -> Extractor { Extractor { agent: self.agent_builder.build(), _t: PhantomData, + retries: self.retries.unwrap_or(0), } } }