Skip to content

Commit

Permalink
fix(rig-lancedb): rag embedding filtering (#104)
Browse files Browse the repository at this point in the history
* fix(lancedb): use embedding field name to filter out embeddings from vector search result

* feat: add tracing info for token usage for all providers

* PR: make requested changes

* style: change formatting of info tracing

* fix: openai info tracing

* fix: change mutable borrow of value

* fix: remove mut from json value in tests

* fix: limit pub scope of filter function

* fix: change visibility of certain modules and traits
  • Loading branch information
marieaurore123 authored Nov 15, 2024
1 parent 2665bb0 commit f4452fc
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 11 deletions.
27 changes: 26 additions & 1 deletion rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ pub struct Usage {
pub output_tokens: u64,
}

/// Human-readable rendering of Anthropic token-usage statistics.
/// Optional cache counters render as "n/a" when the API did not report them.
impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Input tokens: {}", self.input_tokens)?;
        match self.cache_read_input_tokens {
            Some(tokens) => writeln!(f, "Cache read input tokens: {}", tokens)?,
            None => writeln!(f, "Cache read input tokens: n/a")?,
        }
        match self.cache_creation_input_tokens {
            Some(tokens) => writeln!(f, "Cache creation input tokens: {}", tokens)?,
            None => writeln!(f, "Cache creation input tokens: n/a")?,
        }
        // Last line intentionally has no trailing newline.
        write!(f, "Output tokens: {}", self.output_tokens)
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
pub name: String,
Expand Down Expand Up @@ -214,7 +233,13 @@ impl completion::CompletionModel for CompletionModel {

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Message(completion) => completion.try_into(),
ApiResponse::Message(completion) => {
tracing::info!(target: "rig",
"Anthropic completion token usage: {}",
completion.usage
);
completion.try_into()
}
ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Expand Down
22 changes: 21 additions & 1 deletion rig-core/src/providers/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ pub struct ApiVersion {
pub is_experimental: Option<bool>,
}

#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
#[serde(default)]
pub input_tokens: u32,
Expand All @@ -182,6 +182,16 @@ pub struct BilledUnits {
pub classifications: u32,
}

/// Human-readable rendering of Cohere billed-unit counts, one field per line.
impl std::fmt::Display for BilledUnits {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Input tokens: {}", self.input_tokens)?;
        writeln!(f, "Output tokens: {}", self.output_tokens)?;
        writeln!(f, "Search units: {}", self.search_units)?;
        // Last line intentionally has no trailing newline.
        write!(f, "Classifications: {}", self.classifications)
    }
}

#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
Expand Down Expand Up @@ -217,6 +227,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
match response.meta {
Some(meta) => tracing::info!(target: "rig",
"Cohere embeddings billed units: {}",
meta.billed_units,
),
None => tracing::info!(target: "rig",
"Cohere embeddings billed units: n/a",
),
};

if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(format!(
"Expected {} embeddings, got {}",
Expand Down
26 changes: 26 additions & 0 deletions rig-core/src/providers/gemini/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ impl completion::CompletionModel for CompletionModel {
.json::<GenerateContentResponse>()
.await?;

match response.usage_metadata {
Some(ref usage) => tracing::info!(target: "rig",
"Gemini completion token usage: {}",
usage
),
None => tracing::info!(target: "rig",
"Gemini completion token usage: n/a",
),
}

tracing::debug!("Received response");

completion::CompletionResponse::try_from(response)
Expand Down Expand Up @@ -359,6 +369,22 @@ pub mod gemini_api_types {
pub total_token_count: i32,
}

/// Human-readable rendering of Gemini token-usage metadata.
/// The optional cached-content count renders as "n/a" when absent.
impl std::fmt::Display for UsageMetadata {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Prompt token count: {}", self.prompt_token_count)?;
        match self.cached_content_token_count {
            Some(count) => writeln!(f, "Cached content token count: {}", count)?,
            None => writeln!(f, "Cached content token count: n/a")?,
        }
        writeln!(f, "Candidates token count: {}", self.candidates_token_count)?;
        // Last line intentionally has no trailing newline.
        write!(f, "Total token count: {}", self.total_token_count)
    }
}

/// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand Down
23 changes: 22 additions & 1 deletion rig-core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ pub struct Usage {
pub total_tokens: usize,
}

/// Human-readable rendering of OpenAI token-usage statistics.
impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Prompt tokens: {}", self.prompt_tokens)?;
        // Last line intentionally has no trailing newline.
        write!(f, "Total tokens: {}", self.total_tokens)
    }
}

#[derive(Clone)]
pub struct EmbeddingModel {
client: Client,
Expand Down Expand Up @@ -258,6 +268,11 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"OpenAI embedding token usage: {}",
response.usage
);

if response.data.len() != documents.len() {
return Err(EmbeddingError::ResponseError(
"Response data length does not match input length".into(),
Expand Down Expand Up @@ -517,7 +532,13 @@ impl completion::CompletionModel for CompletionModel {

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"OpenAI completion token usage: {:?}",
response.usage
);
response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Expand Down
18 changes: 17 additions & 1 deletion rig-core/src/providers/perplexity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ pub struct Usage {
pub total_tokens: u32,
}

/// Human-readable rendering of Perplexity token-usage statistics.
impl std::fmt::Display for Usage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Prompt tokens: {}", self.prompt_tokens)?;
        writeln!(f, "Completion tokens: {}", self.completion_tokens)?;
        // Last line intentionally has no trailing newline.
        write!(f, "Total tokens: {}", self.total_tokens)
    }
}

impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
type Error = CompletionError;

Expand Down Expand Up @@ -235,7 +245,13 @@ impl completion::CompletionModel for CompletionModel {

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(completion) => Ok(completion.try_into()?),
ApiResponse::Ok(completion) => {
tracing::info!(target: "rig",
"Perplexity completion token usage: {}",
completion.usage
);
Ok(completion.try_into()?)
}
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Expand Down
11 changes: 7 additions & 4 deletions rig-lancedb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rig::{
};
use serde::Deserialize;
use serde_json::Value;
use utils::QueryToJson;
use utils::{FilterEmbeddings, QueryToJson};

mod utils;

Expand Down Expand Up @@ -250,16 +250,19 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for LanceDbV
.into_iter()
.enumerate()
.map(|(i, value)| {
let filtered_value = value
.filter(self.search_params.column.clone())
.map_err(serde_to_rig_error)?;
Ok((
match value.get("_distance") {
match filtered_value.get("_distance") {
Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(),
_ => 0.0,
},
match value.get(self.id_field.clone()) {
match filtered_value.get(self.id_field.clone()) {
Some(Value::String(id)) => id.to_string(),
_ => format!("unknown{i}"),
},
serde_json::from_value(value).map_err(serde_to_rig_error)?,
serde_json::from_value(filtered_value).map_err(serde_to_rig_error)?,
))
})
.collect()
Expand Down
2 changes: 1 addition & 1 deletion rig-lancedb/src/utils/deserializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError {

/// Trait used to deserialize data returned from LanceDB queries into a serde_json::Value vector.
/// Data returned by LanceDB is a vector of `RecordBatch` items.
pub trait RecordBatchDeserializer {
pub(crate) trait RecordBatchDeserializer {
    /// Deserializes the LanceDB `RecordBatch` data into a vector of
    /// `serde_json::Value` items.
    fn deserialize(&self) -> Result<Vec<serde_json::Value>, VectorStoreError>;
}

Expand Down
63 changes: 61 additions & 2 deletions rig-lancedb/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
pub mod deserializer;
mod deserializer;

use deserializer::RecordBatchDeserializer;
use futures::TryStreamExt;
use lancedb::query::ExecutableQuery;
use rig::vector_store::VectorStoreError;
use serde::de::Error;

use crate::lancedb_to_rig_error;

/// Trait that facilitates the conversion of columnar data returned by a lanceDb query to serde_json::Value.
/// Used whenever a lanceDb table is queried.
pub trait QueryToJson {
pub(crate) trait QueryToJson {
    /// Executes the query and converts the returned columnar data into a
    /// vector of `serde_json::Value` rows.
    async fn execute_query(&self) -> Result<Vec<serde_json::Value>, VectorStoreError>;
}

Expand All @@ -26,3 +27,61 @@ impl QueryToJson for lancedb::query::VectorQuery {
record_batches.deserialize()
}
}

/// Trait used to strip the embeddings column from a JSON value returned by a
/// LanceDB vector search, so raw vectors are not passed on to callers.
pub(crate) trait FilterEmbeddings {
    /// Consumes the value and returns it with the embeddings column removed.
    /// `embeddings_col` overrides the default column name, "embedding".
    fn filter(self, embeddings_col: Option<String>) -> serde_json::Result<serde_json::Value>;
}

impl FilterEmbeddings for serde_json::Value {
    /// Removes the embeddings column (named `embeddings_col`, or "embedding"
    /// when `None`) from this JSON object and returns the modified value.
    ///
    /// # Errors
    /// Returns a `serde_json::Error` if `self` is not a JSON object.
    fn filter(mut self, embeddings_col: Option<String>) -> serde_json::Result<serde_json::Value> {
        match self.as_object_mut() {
            Some(obj) => {
                // `as_deref().unwrap_or(...)` avoids allocating the default
                // column name when a custom one is supplied (and vice versa);
                // `Map::remove` accepts a `&str` key.
                obj.remove(embeddings_col.as_deref().unwrap_or("embedding"));
                // The map was mutated in place, so we can return `self`
                // directly instead of re-serializing it with
                // `serde_json::to_value(obj)`.
                Ok(self)
            }
            None => Err(serde_json::Error::custom(format!(
                "{} is not an object",
                self
            ))),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::utils::FilterEmbeddings;

    /// 20-element dummy embedding shared by both tests.
    fn dummy_vector() -> Vec<f64> {
        vec![
            0.3889, 0.6987, 0.7758, 0.7750, 0.7289, 0.3380, 0.1165, 0.1551, 0.3783, 0.1458,
            0.3060, 0.2155, 0.8966, 0.5498, 0.7419, 0.8120, 0.2306, 0.5155, 0.9947, 0.0805,
        ]
    }

    /// Filtering with `None` removes the default "embedding" column.
    #[test]
    fn test_filter_default() {
        let document = serde_json::json!({
            "id": "doc0",
            "text": "Hello world",
            "embedding": dummy_vector(),
        });

        assert_eq!(
            document.filter(None).unwrap(),
            serde_json::json!({"id": "doc0", "text": "Hello world"})
        );
    }

    /// Filtering with an explicit column name removes that column instead.
    #[test]
    fn test_filter_non_default() {
        let document = serde_json::json!({
            "id": "doc0",
            "text": "Hello world",
            "vectors": dummy_vector(),
        });

        assert_eq!(
            document.filter(Some("vectors".to_string())).unwrap(),
            serde_json::json!({"id": "doc0", "text": "Hello world"})
        );
    }
}

0 comments on commit f4452fc

Please sign in to comment.