diff --git a/rig-core/src/extractor.rs b/rig-core/src/extractor.rs index d4fbb1c47..d9ad02e3a 100644 --- a/rig-core/src/extractor.rs +++ b/rig-core/src/extractor.rs @@ -27,12 +27,28 @@ //! .await //! .expect("Failed to extract data from text"); //! ``` - -use std::marker::PhantomData; +//! # Hooks and Validation +//! +//! For advanced observability and validation, use the hook system: +//! +//! ```rust +//! use rig::extractor::{ExtractorWithHooks, ExtractorHook, ExtractorValidatorHook}; +//! +//! // Create hooks and validators +//! let hooks: Vec> = vec![Box::new(LoggingHook)]; +//! let validators: Vec>> = vec![Box::new(AgeValidator)]; +//! +//! // Use with any extractor - must import ExtractorWithHooks trait +//! let result = extractor.extract_with_hooks(text, hooks, validators).await?; +//! use schemars::{JsonSchema, schema_for}; use serde::{Deserialize, Serialize}; -use serde_json::json; +use serde_json::{Value, json}; +use std::boxed::Box; +use std::future::Future; +use std::marker::PhantomData; +use std::pin::Pin; use crate::{ agent::{Agent, AgentBuilder}, @@ -53,6 +69,271 @@ pub enum ExtractionError { #[error("CompletionError: {0}")] CompletionError(#[from] CompletionError), + + #[error("ValidationError: {0}")] + ValidationError(String), +} +/// A trait for observing and reacting to extraction lifecycle events. +/// +/// `ExtractorHook` provides a way to monitor the extraction process at key points, +/// enabling observability, logging, metrics collection, and debugging capabilities. +/// Hooks are called during extraction attempts and can be used to track progress, +/// measure performance, or implement custom monitoring logic. +/// To use hooks with extractors, import [`ExtractorWithHooks`] and call +/// extract_with_hooks +/// +/// # Lifecycle Events +/// +/// The hook methods are called in the following order during each extraction attempt: +/// 1. [`before_extract`] - Called before starting an extraction attempt +/// 2. [`after_parse`] - Called after successfully parsing JSON from the model response +/// 3. [`on_success`] OR [`on_error`] - Called when the attempt succeeds or fails +/// +/// # Thread Safety +/// +/// Implementations must be `Send + Sync + 'static` to be used across async boundaries +/// and stored in collections. +/// +/// # Example +/// +/// ```rust +/// use std::sync::{Arc, Mutex}; +/// use std::pin::Pin; +/// use std::boxed::Box; +/// use std::future::Future; +/// use rig::extractor::ExtractorWithHooks; +/// +/// #[derive(Clone)] +/// struct LoggingHook { +/// events: Arc>>, +/// } +/// +/// impl ExtractorHook for LoggingHook { +/// fn before_extract(&self, attempt: u64, text: &Message) -> Pin + Send + '_>> { +/// let events = Arc::clone(&self.events); +/// Box::pin(async move { +/// events.lock().unwrap().push(format!("Starting attempt {}", attempt)); +/// }) +/// } +/// +/// fn after_parse(&self, attempt: u64, data: &Value) -> Pin + Send + '_>> { +/// let events = Arc::clone(&self.events); +/// Box::pin(async move { +/// events.lock().unwrap().push(format!("Parsed data on attempt {}", attempt)); +/// }) +/// } +/// +/// fn on_error(&self, attempt: u64, error: &ExtractionError) -> Pin + Send + '_>> { +/// let events = Arc::clone(&self.events); +/// let error_msg = error.to_string(); +/// Box::pin(async move { +/// events.lock().unwrap().push(format!("Error on attempt {}: {}", attempt, error_msg)); +/// }) +/// } +/// +/// fn on_success(&self, attempt: u64, data: &Value) -> Pin + Send + '_>> { +/// let events = Arc::clone(&self.events); +/// Box::pin(async move { +/// events.lock().unwrap().push(format!("Success on attempt {}", attempt)); +/// }) +/// } +/// } +/// +/// let hooks: Vec> = vec![ +/// Box::new(LoggingHook { events: Arc::new(Mutex::new(Vec::new())) }), +/// ]; +/// +/// let validators: Vec>> = vec![ +/// Box::new(AgeValidator { min_age: 18, max_age: 120 }), +/// ]; +/// +/// let result = extractor.extract_with_hooks(text, hooks, validators).await?; +/// ``` +/// +/// [`before_extract`]: ExtractorHook::before_extract +/// [`after_parse`]: ExtractorHook::after_parse +/// [`on_success`]: ExtractorHook::on_success +/// [`on_error`]: ExtractorHook::on_error +pub trait ExtractorHook: Send + Sync + 'static { + /// Called before each extraction attempt begins. + /// + /// This method is invoked at the start of each extraction attempt, before any + /// communication with the language model. It provides an opportunity to perform + /// setup operations, start timers, increment counters, or log the beginning + /// of an extraction attempt. + fn before_extract( + &self, + attempt: u64, + text: &Message, + ) -> Pin + Send + '_>>; + + /// Called after successfully parsing JSON from the model response. + /// + /// This method is invoked when the language model has returned a response + /// and the JSON has been successfully parsed, but before any validation occurs. + /// It's useful for inspecting the raw extracted data or logging successful + /// parsing events. + fn after_parse( + &self, + attempt: u64, + data: &Value, + ) -> Pin + Send + '_>>; + + /// Called when an extraction attempt fails. + /// + /// This method is invoked whenever an extraction attempt fails, whether due to + /// model errors, parsing failures, validation errors, or other issues. It provides + /// an opportunity to log errors, update failure metrics, or perform cleanup. + fn on_error( + &self, + attempt: u64, + error: &ExtractionError, + ) -> Pin + Send + '_>>; + + /// Called when extraction succeeds completely. + /// + /// This method is invoked when an extraction attempt succeeds, meaning the model + /// response was parsed successfully and all validation passed. It's called after + /// all processing is complete and represents the final success of the extraction. + fn on_success( + &self, + attempt: u64, + data: &Value, + ) -> Pin + Send + '_>>; +} + +/// A trait for implementing custom validation logic on extracted data. +/// +/// `ExtractorValidatorHook` allows you to define custom validation rules that are +/// applied to extracted data after JSON parsing but before the extraction is considered +/// successful. When validation fails, the extraction attempt is retried (if retries are +/// configured), giving the language model an opportunity to self-correct based on +/// the validation error feedback. +/// +/// Validators are type-specific and work with the concrete extracted data structure, +/// enabling precise business rule validation, data quality checks, and domain-specific +/// constraints that go beyond basic JSON schema validation. +/// +/// # Validation Flow +/// +/// 1. Model extracts data and JSON is parsed successfully +/// 2. Each validator's [`validate`] method is called in sequence +/// 3. If any validator returns an error, the extraction attempt fails and may retry +/// 4. If all validators pass, the extraction succeeds +/// +/// # Error Handling +/// +/// Validation errors are converted to [`ExtractionError::ValidationError`] and fed back +/// into the retry loop, allowing the model to attempt self-correction on subsequent tries. +/// +/// # Example +/// To use validators with extractors, import [`ExtractorWithHooks`] and call +/// extract_with_hooks +/// ```rust +/// use std::pin::Pin; +/// use std::boxed::Box; +/// use std::future::Future; +/// use rig::extractor::ExtractorWithHooks; +/// +/// #[derive(serde::Deserialize, serde::Serialize, schemars::JsonSchema)] +/// struct Person { +/// name: String, +/// age: u8, +/// email: Option, +/// } +/// +/// #[derive(Clone)] +/// struct AgeValidator { +/// min_age: u8, +/// max_age: u8, +/// } +/// +/// impl ExtractorValidatorHook for AgeValidator { +/// fn validate(&self, person: &Person) -> Pin> + Send + '_>> { +/// let min_age = self.min_age; +/// let max_age = self.max_age; +/// let age = person.age; +/// Box::pin(async move { +/// if age < min_age { +/// return Err(ExtractionError::ValidationError( +/// format!("Age {} is below minimum of {}", age, min_age) +/// )); +/// } +/// if age > max_age { +/// return Err(ExtractionError::ValidationError( +/// format!("Age {} exceeds maximum of {}", age, max_age) +/// )); +/// } +/// Ok(()) +/// }) +/// } +/// } +/// ``` +/// You can chain multiple validators together. They are executed in order, and the first +/// validation failure will cause the extraction attempt to fail: +/// +/// ```rust +/// let validators: Vec>> = vec![ +/// Box::new(AgeValidator { min_age: 18, max_age: 120 }), +/// Box::new(EmailValidator), +/// Box::new(BusinessRuleValidator), +/// ]; +/// +/// let result = extractor.extract_with_hooks(text, vec![], validators).await?; +/// ``` +/// +/// [`validate`]: ExtractorValidatorHook::validate +pub trait ExtractorValidatorHook: Send + Sync +where + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, +{ + /// Validates the extracted data according to custom business rules. + fn validate( + &self, + data: &T, + ) -> Pin> + Send + '_>>; +} + +/// Extension trait that adds hook and validation capabilities to extractors. +/// +/// `ExtractorWithHooks` extends any extractor with advanced observability and validation +/// features through the [`extract_with_hooks`] method. +/// +/// # Usage +/// +/// To use this functionality, you must import the trait: +/// +/// ```rust +/// use rig::extractor::ExtractorWithHooks; +/// +/// let hooks: Vec> = vec![ +/// Box::new(LoggingHook::new()), +/// ]; +/// +/// let validators: Vec>> = vec![ +/// Box::new(AgeValidator { min_age: 18, max_age: 120 }), +/// ]; +/// +/// let result = extractor.extract_with_hooks(text, hooks, validators).await?; +/// ``` +/// +/// [`extract_with_hooks`]: ExtractorWithHooks::extract_with_hooks +pub trait ExtractorWithHooks +where + T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, +{ + /// Extracts structured data with observability hooks and custom validation. + /// + /// This method extends the basic extraction functionality with lifecycle hooks + /// for monitoring and custom validators for data quality assurance. It provides + /// the same extraction capabilities as [`Extractor::extract`] but with additional + /// observability and validation features. + fn extract_with_hooks( + &self, + text: impl Into + Send, + hooks: Vec>, + validators: Vec>>, + ) -> Pin> + Send + '_>>; } /// Extractor for structured data from text @@ -66,6 +347,63 @@ where retries: u64, } +impl ExtractorWithHooks for Extractor +where + M: CompletionModel, + T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, +{ + /// This implementation creates the complete extraction lifecycle with + /// observability and validation. It manages the retry loop, coordinates + /// hook calls, and ensures validators are executed in the correct sequence. + fn extract_with_hooks( + &self, + text: impl Into + Send, + hooks: Vec>, + validators: Vec>>, + ) -> Pin> + Send + '_>> { + let text_msg = text.into(); + + Box::pin(async move { + let mut last_error = None; + + for i in 0..=self.retries { + tracing::debug!( + "Attempting to extract Json. Retries left:{retries}", + retries = self.retries - i + ); + + for hook in &hooks { + hook.before_extract(i, &text_msg).await; + } + let attempt_t = text_msg.clone(); + match self + .extract_validated_json(attempt_t, i, &hooks, &validators) + .await + { + Ok(data) => { + let data_value = + serde_json::to_value(&data).unwrap_or(serde_json::Value::Null); + for hook in &hooks { + hook.on_success(i, &data_value).await; + } + return Ok(data); + } + Err(e) => { + tracing::warn!( + "Attempt number {i} to extract Json failed: {e:?}. Retrying..." + ); + for hook in &hooks { + hook.on_error(i, &e).await; + } + last_error = Some(e); + } + } + } + Err(last_error.unwrap_or(ExtractionError::NoData)) + }) + } +} + impl Extractor where M: CompletionModel, @@ -155,6 +493,49 @@ where Ok(serde_json::from_value(raw_data)?) } + async fn extract_validated_json( + &self, + text: impl Into + Send, + attempt: u64, + hooks: &[Box], + validators: &[Box>], + ) -> Result { + let response = self.agent.completion(text, vec![]).await?.send().await?; + + let args = response + .choice + .into_iter() + .filter_map(|content| { + if let AssistantContent::ToolCall(ToolCall { + function: ToolFunction { arguments, name }, + .. + }) = content + { + if name == SUBMIT_TOOL_NAME { + Some(arguments) + } else { + None + } + } else { + None + } + }) + .collect::>(); + let raw_data = args.into_iter().next().ok_or(ExtractionError::NoData)?; + + for hook in hooks { + hook.after_parse(attempt, &raw_data).await; + } + + let parsed_data: T = serde_json::from_value(raw_data.clone())?; + + for validator in validators { + if let Err(val_error) = validator.validate(&parsed_data).await { + return Err(ExtractionError::ValidationError(val_error.to_string())); + } + } + Ok(parsed_data) + } pub async fn get_inner(&self) -> &Agent { &self.agent } diff --git a/rig-core/tests/extractor_hooks.rs b/rig-core/tests/extractor_hooks.rs new file mode 100644 index 000000000..fca527039 --- /dev/null +++ b/rig-core/tests/extractor_hooks.rs @@ -0,0 +1,457 @@ +use rig::client::CompletionClient; +use rig::extractor::{ExtractionError, ExtractorHook, ExtractorValidatorHook, ExtractorWithHooks}; +use rig::providers::openai; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::boxed::Box; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::time::Instant; + +//************** For the first test ************** + +#[derive(Debug, Deserialize, Serialize, JsonSchema, PartialEq)] +struct TestData { + name: String, + count: u32, +} + +#[derive(Clone)] +struct CallCounter { + count: Arc>, +} + +impl CallCounter { + fn new() -> Self { + Self { + count: Arc::new(Mutex::new(0)), + } + } +} + +impl ExtractorHook for CallCounter { + fn before_extract( + &self, + _attempt: u64, + _text: &rig::message::Message, + ) -> Pin + Send + '_>> { + let count = Arc::clone(&self.count); + Box::pin(async move { + *count.lock().unwrap() += 1; + }) + } + + fn after_parse( + &self, + _attempt: u64, + _data: &Value, + ) -> Pin + Send + '_>> { + Box::pin(async move {}) + } + fn on_error( + &self, + _: u64, + _: &ExtractionError, + ) -> Pin + Send + '_>> { + Box::pin(async move {}) + } + + fn on_success(&self, _: u64, _: &Value) -> Pin + Send + '_>> { + Box::pin(async move {}) + } +} + +#[derive(Clone)] +struct PassValidator; + +impl ExtractorValidatorHook for PassValidator { + fn validate( + &self, + _data: &TestData, + ) -> Pin> + Send + '_>> { + Box::pin(async move { Ok(()) }) + } +} + +#[tokio::test] +#[ignore] +async fn test_hooks_called() { + let Ok(api_key) = std::env::var("OPENAI_API_KEY") else { + println!("skipping: api key not set"); + return; + }; + + let client = openai::Client::new(&api_key); + let extractor = client.extractor::(openai::GPT_4O_MINI).build(); + + let counter = CallCounter::new(); + let hooks: Vec> = vec![Box::new(counter.clone())]; + let validators: Vec>> = vec![]; + + let result = extractor + .extract_with_hooks("name: test, count: 42", hooks, validators) + .await; + + assert!( + *counter.count.lock().unwrap() > 0, + "Hook should have been called" + ); + + if let Ok(data) = result { + println!("Extraction result: {data:?}"); + } +} + +#[tokio::test] +#[ignore] +async fn test_validator_is_called() { + let Ok(api_key) = std::env::var("OPENAI_API_KEY") else { + println!("Skipping: OPENAI_API_KEY not set"); + return; + }; + + let client = openai::Client::new(&api_key); + let extractor = client.extractor::(openai::GPT_4O_MINI).build(); + + let hooks: Vec> = vec![]; + let validators: Vec>> = vec![Box::new(PassValidator)]; + + let result = extractor + .extract_with_hooks("name: hello, count: 123", hooks, validators) + .await; + + // test passes if validator was called successfully + println!("Validator test completed: {:?}", result); +} + +//************** For the Second test ************** +#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +struct SecondTestData { + name: String, + count: u32, +} + +#[derive(Clone)] +struct CycleTracker { + events: Arc>>, + start_time: Arc>>, +} + +impl CycleTracker { + fn new() -> Self { + Self { + events: Arc::new(Mutex::new(Vec::new())), + start_time: Arc::new(Mutex::new(None)), + } + } + + fn get_events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + fn log_event(&self, event: String) { + self.events.lock().unwrap().push(event); + } +} + +impl ExtractorHook for CycleTracker { + fn before_extract( + &self, + attempt: u64, + _text: &rig::message::Message, + ) -> Pin + Send + '_>> { + let tracker = self.clone(); + Box::pin(async move { + if attempt == 0 { + *tracker.start_time.lock().unwrap() = Some(Instant::now()); + } + tracker.log_event(format!("Before extract attempt {attempt}")); + }) + } + + fn after_parse( + &self, + attempt: u64, + data: &Value, + ) -> Pin + Send + '_>> { + let tracker = self.clone(); + let data = data.clone(); + Box::pin(async move { + tracker.log_event(format!("After parse attempt number {attempt}: {data}")); + }) + } + + fn on_error( + &self, + attempt: u64, + error: &ExtractionError, + ) -> Pin + Send + '_>> { + let tracker = self.clone(); + // need to format it otherwise wont be able to send it through async move + let error_msg = format!("{error:?}"); + Box::pin(async move { + tracker.log_event(format!("On error attempt number {attempt}: {error_msg}")); + }) + } + + fn on_success( + &self, + attempt: u64, + data: &Value, + ) -> Pin + Send + '_>> { + let tracker = self.clone(); + let data = data.clone(); + Box::pin(async move { + let elapsed = tracker + .start_time + .lock() + .unwrap() + .map(|s| s.elapsed()) + .unwrap_or_default(); + tracker.log_event(format!( + "On success attempt number {attempt}-> elapsed time:{elapsed:?}, data: {data:?}" + )); + }) + } +} + +#[derive(Clone)] +struct CountValidator { + max_count: u32, + call_count: Arc>, +} + +impl CountValidator { + fn new(max_count: u32) -> Self { + Self { + max_count, + call_count: Arc::new(Mutex::new(0)), + } + } + + fn get_call_count(&self) -> u32 { + *self.call_count.lock().unwrap() + } +} + +impl ExtractorValidatorHook for CountValidator { + fn validate( + &self, + data: &SecondTestData, + ) -> Pin> + Send + '_>> { + let max_count = self.max_count; + let call_count = Arc::clone(&self.call_count); + let count = data.count; + + Box::pin(async move { + *call_count.lock().unwrap() += 1; + if count > max_count { + return Err(ExtractionError::ValidationError(format!( + "Count {count} exceeds maximum {max_count}" + ))); + } + Ok(()) + }) + } +} + +#[derive(Clone)] +struct NameValidator { + forbidden_names: Vec, +} + +impl NameValidator { + fn new(forbidden_names: Vec<&str>) -> Self { + Self { + forbidden_names: forbidden_names.iter().map(|s| s.to_string()).collect(), + } + } +} + +impl ExtractorValidatorHook for NameValidator { + fn validate( + &self, + data: &SecondTestData, + ) -> Pin> + Send + '_>> { + let forbidden_names = self.forbidden_names.clone(); + let name = data.name.clone(); + + Box::pin(async move { + if forbidden_names.contains(&name) { + return Err(ExtractionError::ValidationError(format!( + "Name {name} is forbidden" + ))); + } + + Ok(()) + }) + } +} + +#[tokio::test] +#[ignore] +async fn test_validation_failure() { + let Ok(api_key) = std::env::var("OPENAI_API_KEY") else { + println!("Skipping: OPENAI_API_KEY not set"); + return; + }; + + let client = openai::Client::new(&api_key); + let extractor = client + .extractor::(openai::GPT_4O_MINI) + .retries(3) + .preamble("Extract name and count. If validation fails, try different values.") + .build(); + + let tracker = CycleTracker::new(); + let count_validator = CountValidator::new(50); + + let hooks: Vec> = vec![Box::new(tracker.clone())]; + let validators: Vec>> = + vec![Box::new(count_validator.clone())]; + + // using a count that is over limit + let result = extractor + .extract_with_hooks("name: TestUser, count: 999", hooks, validators) + .await; + + let events = tracker.get_events(); + println!("Cycle events: {:#?}", events); + println!( + "Validator called {} times", + count_validator.get_call_count() + ); + + assert!( + events + .iter() + .any(|e| e.starts_with("Before extract attempt 0")), + "Should have attempted extraction" + ); + + assert!( + count_validator.get_call_count() > 0, + "Validator should have been called" + ); + + match result { + Ok(data) => { + println!("SUCCESS: Model self-corrected to: {:?}", data); + assert!(data.count <= 50, "Final result should pass validation"); + assert!(events.iter().any(|e| e.starts_with("ON_SUCCESS"))); + } + Err(e) => { + println!("EXPECTED FAILURE: Validation failed after retries: {}", e); + assert!(events.iter().any(|e| e.starts_with("On error"))); + } + } +} + +#[tokio::test] +#[ignore] +async fn test_multiple_validators() { + let Ok(api_key) = std::env::var("OPENAI_API_KEY") else { + println!("Skipping: OPENAI_API_KEY not set"); + return; + }; + + let client = openai::Client::new(&api_key); + let extractor = client + .extractor::(openai::GPT_4O_MINI) + .retries(2) + .preamble("Extract name and count. Avoid forbidden names and keep count reasonable.") + .build(); + + let tracker = CycleTracker::new(); + let count_validator = CountValidator::new(100); + let name_validator = NameValidator::new(vec!["admin", "root", "test"]); + + let hooks: Vec> = vec![Box::new(tracker.clone())]; + let validators: Vec>> = + vec![Box::new(count_validator.clone()), Box::new(name_validator)]; + + //both name and max count should cause errors + let result = extractor + .extract_with_hooks("The admin user has a count of 500 items", hooks, validators) + .await; + + let events = tracker.get_events(); + println!("Complex scenario events: {:#?}", events); + + match result { + Ok(data) => { + println!("Model successfully completed validation: {:?}", data); + assert_ne!(data.name, "admin"); + assert!(data.count <= 100); + assert!(events.iter().any(|e| e.starts_with("On success"))); + } + Err(e) => { + println!("Validation correctly prevented extraction: {}", e); + assert!(events.iter().any(|e| e.starts_with("On error"))); + } + } + + assert!(count_validator.get_call_count() > 0); + assert!(events.len() >= 6); +} +#[tokio::test] +#[ignore] +async fn test_hook_timing_and_detailed_logging() { + let Ok(api_key) = std::env::var("OPENAI_API_KEY") else { + println!("Skipping: OPENAI_API_KEY not set"); + return; + }; + + let client = openai::Client::new(&api_key); + let extractor = client + .extractor::(openai::GPT_4O_MINI) + .retries(1) + .build(); + + let tracker = CycleTracker::new(); + let validator = CountValidator::new(1000); + + let hooks: Vec> = vec![Box::new(tracker.clone())]; + let validators: Vec>> = + vec![Box::new(validator)]; + + let result = extractor + .extract_with_hooks("name: Alice, count: 42", hooks, validators) + .await; + + let events = tracker.get_events(); + println!("Detailed timing events: {:#?}", events); + + let before_events: Vec<_> = events + .iter() + .filter(|e| e.starts_with("Before extract")) + .collect(); + let after_events: Vec<_> = events + .iter() + .filter(|e| e.starts_with("After parse")) + .collect(); + let success_events: Vec<_> = events + .iter() + .filter(|e| e.starts_with("On success")) + .collect(); + + assert!( + !before_events.is_empty(), + "Should have before_extract events" + ); + assert!(!after_events.is_empty(), "Should have after_parse events"); + + if result.is_ok() { + assert!(!success_events.is_empty(), "Should have success events"); + + let success_event = &success_events[0]; + assert!( + success_event.contains("elapsed time"), + "Should include timing information" + ); + } + + // sequence should make sense before -> after -> success/error + assert!(events.len() >= 2, "Should have multiple lifecycle events"); +}