Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(integration test): Neo4J #133

Merged
merged 13 commits into from
Dec 3, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions rig-neo4j/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ anyhow = "1.0.86"
tokio = { version = "1.38.0", features = ["macros"] }
textwrap = { version = "0.16.1"}
term_size = { version = "0.3.2"}
testcontainers = "0.23.1"
tracing-subscriber = "0.3.18"

[[example]]
name = "vector_search_simple"
required-features = ["rig-core/derive"]

[[test]]
name = "integration_tests"
required-features = ["rig-core/derive"]
22 changes: 9 additions & 13 deletions rig-neo4j/examples/vector_search_simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! 5. Returns the results
use std::env;

use futures::StreamExt;
use futures::{StreamExt, TryStreamExt};
use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
Expand All @@ -18,7 +18,7 @@ use rig::{
use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType};

#[derive(Embed, Clone, Debug)]
pub struct WordDefinition {
pub struct Word {
pub id: String,
#[embed]
pub definition: String,
Expand All @@ -41,22 +41,22 @@ async fn main() -> Result<(), anyhow::Error> {
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let embeddings = EmbeddingsBuilder::new(model.clone())
.document(WordDefinition {
.document(Word {
id: "doc0".to_string(),
definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
})?
.document(WordDefinition {
.document(Word {
id: "doc1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
})?
.document(WordDefinition {
.document(Word {
id: "doc2".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
})?
.build()
.await?;

let create_nodes = futures::stream::iter(embeddings)
futures::stream::iter(embeddings)
.map(|(doc, embeddings)| {
neo4j_client.graph.run(
neo4rs::query(
Expand All @@ -76,13 +76,9 @@ async fn main() -> Result<(), anyhow::Error> {
)
})
.buffer_unordered(3)
.collect::<Vec<_>>()
.await;

// Unwrap the results in the vector _create_nodes
for result in create_nodes {
result.unwrap(); // or handle the error appropriately
}
.try_collect::<Vec<_>>()
.await
.unwrap();

// Create a vector index on our vector store
println!("Creating vector index...");
Expand Down
12 changes: 6 additions & 6 deletions rig-neo4j/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@
//! .await
//! .unwrap();
//!
//! let index = client.index(
//! let index = client.get_index(
//! model,
//! IndexConfig::new("moviePlotsEmbedding"),
//! SearchParams::default(),
//! );
//! "moviePlotsEmbedding",
//! SearchParams::default()
//! ).await.unwrap();
//!
//! #[derive(Debug, Deserialize)]
//! struct Movie {
Expand Down Expand Up @@ -149,14 +149,14 @@ where
}

impl Neo4jClient {
const GET_INDEX_QUERY: &str = "
const GET_INDEX_QUERY: &'static str = "
SHOW VECTOR INDEXES
YIELD name, properties, options
WHERE name=$index_name
RETURN name, properties, options
";

const SHOW_INDEXES_QUERY: &str = "SHOW VECTOR INDEXES YIELD name RETURN name";
const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";

pub fn new(graph: Graph) -> Self {
Self { graph }
Expand Down
159 changes: 159 additions & 0 deletions rig-neo4j/tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use testcontainers::{
core::{IntoContainerPort, Mount, WaitFor},
runners::AsyncRunner,
GenericImage, ImageExt,
};

use futures::{StreamExt, TryStreamExt};
use rig::vector_store::VectorStoreIndex;
use rig::{
embeddings::{Embedding, EmbeddingsBuilder},
providers::openai,
Embed, OneOrMany,
};
use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType};

// Container port exposed for Bolt; the test resolves its host-side mapping to build the connection URI.
const BOLT_PORT: u16 = 7687;
// Container port exposed alongside Bolt; the test body shown here never connects to it.
const HTTP_PORT: u16 = 7474;

// Sample document for the integration test: an identifier plus a definition text.
#[derive(Embed, Clone, serde::Deserialize, Debug)]
struct Word {
// Written to the `id` property of each node the test creates.
id: String,
// Written to the `document` property of each node; `#[embed]` presumably marks this
// field as the text handed to the embedding model (confirm against rig's `Embed` derive).
#[embed]
definition: String,
}

/// End-to-end vector search test against a throwaway Neo4j container.
///
/// The test:
/// 1. starts a `neo4j` container with authentication disabled,
/// 2. embeds three sample definitions via `create_embeddings`,
/// 3. stores each one as a `DocumentEmbeddings` node,
/// 4. creates a cosine vector index named `vector_index` on the `embedding` property,
/// 5. asserts that the top hit for "What is a glarb?" is the `doc1` node.
///
/// Requirements visible in the body: a running Docker service, the
/// `GITHUB_WORKSPACE` environment variable, and whatever configuration
/// `openai::Client::from_env` reads.
///
/// # Panics
/// Panics if any setup step, query, or the final assertion fails.
#[tokio::test]
async fn vector_search_test() {
    let workspace = std::env::var("GITHUB_WORKSPACE")
        .expect("GITHUB_WORKSPACE must be set (used for the Neo4j container volume mount)");
    let mount = Mount::volume_mount("data", workspace);

    // Set up a local Neo4j container for testing. NOTE: docker service must be running.
    let container = GenericImage::new("neo4j", "latest")
        .with_wait_for(WaitFor::Duration {
            length: std::time::Duration::from_secs(5),
        })
        .with_exposed_port(BOLT_PORT.tcp())
        .with_exposed_port(HTTP_PORT.tcp())
        .with_mount(mount)
        .with_env_var("NEO4J_AUTH", "none")
        .start()
        .await
        .expect("Failed to start Neo 4J container");

    let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap();
    let host = container.get_host().await.unwrap().to_string();

    // Authentication is disabled on the container, so empty credentials are passed.
    let neo4j_client = Neo4jClient::connect(&format!("neo4j://{host}:{port}"), "", "")
        .await
        .unwrap();

    // Initialize OpenAI client
    let openai_client = openai::Client::from_env();

    // Select the embedding model and generate our embeddings
    let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);

    let embeddings = create_embeddings(model.clone()).await;

    // Insert one node per document, running up to three CREATE queries concurrently.
    // Continuation lines of the query are flush-left so that no indentation
    // whitespace ends up in the query text.
    futures::stream::iter(embeddings)
        .map(|(doc, embeddings)| {
            neo4j_client.graph.run(
                neo4rs::query(
                    "
CREATE
(document:DocumentEmbeddings {
id: $id,
document: $document,
embedding: $embedding})
RETURN document",
                )
                .param("id", doc.id)
                // Here we use the first embedding but we could use any of them.
                // Neo4j only takes primitive types or arrays as properties.
                .param("embedding", embeddings.first().vec.clone())
                .param("document", doc.definition.to_bolt_type()),
            )
        })
        .buffer_unordered(3)
        .try_collect::<Vec<_>>()
        .await
        .unwrap();

    // Create a vector index on our vector store
    println!("Creating vector index...");
    neo4j_client
        .graph
        .run(neo4rs::query(
            "CREATE VECTOR INDEX vector_index IF NOT EXISTS
FOR (m:DocumentEmbeddings)
ON m.embedding
OPTIONS { indexConfig: {
`vector.dimensions`: 1536,
`vector.similarity_function`: 'cosine'
}}",
        ))
        .await
        .unwrap();

    // ℹ️ The index name must be unique among both indexes and constraints.
    // A newly created index is not immediately available but is created in the background.

    // Wait for the index with db.awaitIndex(); the call times out if the index is not ready.
    let index_exists = neo4j_client
        .graph
        .run(neo4rs::query("CALL db.awaitIndex('vector_index')"))
        .await;
    if index_exists.is_err() {
        println!("Index not ready, waiting for index...");
        // Async sleep: yields to the runtime instead of blocking the executor thread.
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }

    println!("Index exists: {:?}", index_exists);

    // Get a handle to the vector index created above.
    // IMPORTANT: Reuse the same model that was used to generate the embeddings
    let index = neo4j_client
        .get_index(model, "vector_index", SearchParams::default())
        .await
        .unwrap();

    // Query the index
    let results = index
        .top_n::<serde_json::Value>("What is a glarb?", 1)
        .await
        .unwrap();

    let (_, _, value) = &results.first().unwrap();

    // The expected node carries a null `embedding`: the vector itself is not returned.
    assert_eq!(
        value,
        &serde_json::json!({
            "id": "doc1",
            "document": "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
            "embedding": serde_json::Value::Null
        })
    )
}

/// Builds the three sample `Word` documents and embeds them with `model`.
///
/// Returns each document paired with its embeddings, as produced by
/// `EmbeddingsBuilder`.
///
/// # Panics
/// Panics if adding the documents to the builder or building the embeddings fails.
async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(Word, OneOrMany<Embedding>)> {
    // (id, definition) pairs, in the order they are handed to the builder.
    let entries = [
        (
            "doc0",
            "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets",
        ),
        (
            "doc1",
            "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
        ),
        (
            "doc2",
            "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.",
        ),
    ];

    let documents: Vec<Word> = entries
        .into_iter()
        .map(|(id, definition)| Word {
            id: id.to_string(),
            definition: definition.to_string(),
        })
        .collect();

    let builder = EmbeddingsBuilder::new(model).documents(documents).unwrap();
    builder.build().await.unwrap()
}
10 changes: 5 additions & 5 deletions rig-qdrant/examples/qdrant_vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use rig::{
use rig_qdrant::QdrantVectorStore;

#[derive(Embed, serde::Deserialize, serde::Serialize, Debug)]
struct Definition {
struct Word {
id: String,
#[embed]
definition: String,
Expand Down Expand Up @@ -56,15 +56,15 @@ async fn main() -> Result<(), anyhow::Error> {
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let documents = EmbeddingsBuilder::new(model.clone())
.document(Definition {
.document(Word {
id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(),
definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
})?
.document(Definition {
.document(Word {
id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
})?
.document(Definition {
.document(Word {
id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
})?
Expand All @@ -91,7 +91,7 @@ async fn main() -> Result<(), anyhow::Error> {
let vector_store = QdrantVectorStore::new(client, model, query_params.build());

let results = vector_store
.top_n::<Definition>("What is a linglingdong?", 1)
.top_n::<Word>("What is a linglingdong?", 1)
.await?;

println!("Results: {:?}", results);
Expand Down
12 changes: 6 additions & 6 deletions rig-qdrant/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const QDRANT_PORT_SECONDARY: u16 = 6334;
const COLLECTION_NAME: &str = "rig-collection";

#[derive(Embed, Clone, serde::Deserialize, serde::Serialize, Debug)]
struct Definition {
struct Word {
id: String,
#[embed]
definition: String,
Expand Down Expand Up @@ -95,23 +95,23 @@ async fn vector_search_test() {
}

async fn create_points(model: openai::EmbeddingModel) -> Vec<PointStruct> {
let fake_definitions = vec![
Definition {
let words = vec![
Word {
id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(),
definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
},
Definition {
Word {
id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
},
Definition {
Word {
id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
}
];

let documents = EmbeddingsBuilder::new(model)
.documents(fake_definitions)
.documents(words)
.unwrap()
.build()
.await
Expand Down
Loading