Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rig-core/examples/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ use serde::{Deserialize, Serialize};
// Example extraction target: every field is optional information pulled out of free text.
// NOTE(review): the `///` lines below are presumably picked up by the `JsonSchema` derive as
// field descriptions in the generated schema, so their wording is likely model-facing text —
// confirm before rewording them.
#[derive(Debug, Deserialize, JsonSchema, Serialize)]
struct Person {
/// The person's first name, if provided (null otherwise)
// `required` on an `Option` field: presumably keeps the key mandatory in the schema so the
// value must be sent explicitly (as null when unknown) — confirm against the schemars docs.
#[schemars(required)]
pub first_name: Option<String>,
/// The person's last name, if provided (null otherwise)
// Same `required` treatment as `first_name`.
#[schemars(required)]
pub last_name: Option<String>,
/// The person's job, if provided (null otherwise)
// Same `required` treatment as `first_name`.
#[schemars(required)]
pub job: Option<String>,
}

Expand Down
40 changes: 39 additions & 1 deletion rig-core/src/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,43 @@ pub enum ExtractionError {
pub struct Extractor<M: CompletionModel, T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync> {
/// Agent that `extract_json` sends the input to in order to obtain a completion.
agent: Agent<M>,
/// Zero-sized marker binding this extractor to its output type `T`; no `T` value is stored.
_t: PhantomData<T>,
/// Number of additional attempts `extract` makes after the first one fails
/// (`extract` loops over `0..=retries`, i.e. `retries + 1` attempts in total).
retries: u64,
}

impl<T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, M: CompletionModel> Extractor<M, T>
where
M: Sync,
{
/// Attempts to extract data from the given text, retrying on failure.
///
/// The input is converted into a [`Message`] once and a clone of it is handed to
/// every attempt. An attempt is repeated whenever `extract_json` returns an error,
/// which includes the case where the model does not call the `submit` tool.
///
/// The number of retries is taken from the `retries` field on the Extractor struct,
/// so at most `retries + 1` attempts are made in total.
///
/// # Errors
///
/// Returns the error produced by the final attempt when every attempt fails.
pub async fn extract(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
    let text_message = text.into();
    let mut last_error = None;

    for attempt in 0..=self.retries {
        let retries_left = self.retries - attempt;
        tracing::debug!(
            "Attempting to extract JSON. Retries left: {retries}",
            retries = retries_left
        );

        match self.extract_json(text_message.clone()).await {
            Ok(data) => return Ok(data),
            Err(e) => {
                // Only announce a retry when one will actually happen; the final
                // failure used to be logged as "Retrying..." as well.
                if retries_left > 0 {
                    tracing::warn!("Attempt {attempt} to extract JSON failed: {e:?}. Retrying...");
                } else {
                    tracing::warn!(
                        "Attempt {attempt} to extract JSON failed: {e:?}. No retries left."
                    );
                }
                last_error = Some(e);
            }
        }
    }

    // The range `0..=retries` is never empty, so `last_error` is always `Some` here;
    // `NoData` is kept purely as a defensive fallback.
    Err(last_error.unwrap_or(ExtractionError::NoData))
}

async fn extract_json(&self, text: impl Into<Message> + Send) -> Result<T, ExtractionError> {
let response = self.agent.completion(text, vec![]).await?.send().await?;

if !response.choice.iter().any(|x| {
Expand Down Expand Up @@ -136,6 +166,7 @@ pub struct ExtractorBuilder<
> {
agent_builder: AgentBuilder<M>,
_t: PhantomData<T>,
retries: Option<u64>,
}

impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: CompletionModel>
Expand All @@ -151,7 +182,7 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: Compl
Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!.
")
.tool(SubmitTool::<T> {_t: PhantomData}),

retries: None,
_t: PhantomData,
}
}
Expand Down Expand Up @@ -181,11 +212,18 @@ impl<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, M: Compl
self
}

/// Set the maximum number of retries for the extractor.
pub fn retries(mut self, retries: u64) -> Self {
self.retries = Some(retries);
self
}

/// Build the Extractor
pub fn build(self) -> Extractor<M, T> {
Extractor {
agent: self.agent_builder.build(),
_t: PhantomData,
retries: self.retries.unwrap_or(0),
}
}
}
Expand Down