Skip to content

Commit

Permalink
Feat(integration test): Neo4J (#133)
Browse files Browse the repository at this point in the history
* setup: integration test

* test: add to integration test

* fix(test): fix bugs while testing

* fix: fix bugs after merging with main

* cargo fmt

* fix: fix error in comment

* refactor: rename fakedefinition to definition

* fix: make PR requested change

* fix: cargo clippy

* fix: rename struct to Word

* fix: rename struct to Word
  • Loading branch information
marieaurore123 authored Dec 3, 2024
1 parent cda8102 commit 96e78be
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 30 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions rig-neo4j/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ anyhow = "1.0.86"
tokio = { version = "1.38.0", features = ["macros"] }
textwrap = { version = "0.16.1"}
term_size = { version = "0.3.2"}
testcontainers = "0.23.1"
tracing-subscriber = "0.3.18"

[[example]]
name = "vector_search_simple"
required-features = ["rig-core/derive"]

[[test]]
name = "integration_tests"
required-features = ["rig-core/derive"]
22 changes: 9 additions & 13 deletions rig-neo4j/examples/vector_search_simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! 5. Returns the results
use std::env;

use futures::StreamExt;
use futures::{StreamExt, TryStreamExt};
use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
Expand All @@ -18,7 +18,7 @@ use rig::{
use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType};

#[derive(Embed, Clone, Debug)]
pub struct WordDefinition {
pub struct Word {
pub id: String,
#[embed]
pub definition: String,
Expand All @@ -41,22 +41,22 @@ async fn main() -> Result<(), anyhow::Error> {
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let embeddings = EmbeddingsBuilder::new(model.clone())
.document(WordDefinition {
.document(Word {
id: "doc0".to_string(),
definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
})?
.document(WordDefinition {
.document(Word {
id: "doc1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
})?
.document(WordDefinition {
.document(Word {
id: "doc2".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
})?
.build()
.await?;

let create_nodes = futures::stream::iter(embeddings)
futures::stream::iter(embeddings)
.map(|(doc, embeddings)| {
neo4j_client.graph.run(
neo4rs::query(
Expand All @@ -76,13 +76,9 @@ async fn main() -> Result<(), anyhow::Error> {
)
})
.buffer_unordered(3)
.collect::<Vec<_>>()
.await;

// Unwrap the results in the vector _create_nodes
for result in create_nodes {
result.unwrap(); // or handle the error appropriately
}
.try_collect::<Vec<_>>()
.await
.unwrap();

// Create a vector index on our vector store
println!("Creating vector index...");
Expand Down
12 changes: 6 additions & 6 deletions rig-neo4j/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@
//! .await
//! .unwrap();
//!
//! let index = client.index(
//! let index = client.get_index(
//! model,
//! IndexConfig::new("moviePlotsEmbedding"),
//! SearchParams::default(),
//! );
//! "moviePlotsEmbedding",
//! SearchParams::default()
//! ).await.unwrap();
//!
//! #[derive(Debug, Deserialize)]
//! struct Movie {
Expand Down Expand Up @@ -149,14 +149,14 @@ where
}

impl Neo4jClient {
const GET_INDEX_QUERY: &str = "
const GET_INDEX_QUERY: &'static str = "
SHOW VECTOR INDEXES
YIELD name, properties, options
WHERE name=$index_name
RETURN name, properties, options
";

const SHOW_INDEXES_QUERY: &str = "SHOW VECTOR INDEXES YIELD name RETURN name";
const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name";

pub fn new(graph: Graph) -> Self {
Self { graph }
Expand Down
159 changes: 159 additions & 0 deletions rig-neo4j/tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use testcontainers::{
core::{IntoContainerPort, Mount, WaitFor},
runners::AsyncRunner,
GenericImage, ImageExt,
};

use futures::{StreamExt, TryStreamExt};
use rig::vector_store::VectorStoreIndex;
use rig::{
embeddings::{Embedding, EmbeddingsBuilder},
providers::openai,
Embed, OneOrMany,
};
use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType};

const BOLT_PORT: u16 = 7687;
const HTTP_PORT: u16 = 7474;

/// A word and its textual definition, used as the test corpus.
///
/// `Embed` (from rig's derive feature) generates embedding extraction for the
/// field marked `#[embed]`; `Deserialize` lets query results map back into this type.
#[derive(Embed, Clone, serde::Deserialize, Debug)]
struct Word {
    // Stable document id ("doc0", "doc1", ...) stored on the Neo4j node.
    id: String,
    // The text that gets embedded and searched over.
    #[embed]
    definition: String,
}

/// End-to-end vector-search test against a throwaway Neo4j container:
/// starts Neo4j via testcontainers, embeds three documents with OpenAI,
/// stores them as nodes, builds a vector index, and asserts the top hit
/// for a "glarb" query is doc1.
///
/// Requires a running Docker daemon, `GITHUB_WORKSPACE` in the environment
/// (set on GitHub Actions runners), and OpenAI credentials for
/// `openai::Client::from_env()`.
#[tokio::test]
async fn vector_search_test() {
    let mount = Mount::volume_mount("data", std::env::var("GITHUB_WORKSPACE").unwrap());
    // Setup a local Neo4j container for testing. NOTE: docker service must be running.
    let container = GenericImage::new("neo4j", "latest")
        // Fixed grace period; there is no readiness log line being matched here.
        .with_wait_for(WaitFor::Duration {
            length: std::time::Duration::from_secs(5),
        })
        .with_exposed_port(BOLT_PORT.tcp())
        .with_exposed_port(HTTP_PORT.tcp())
        .with_mount(mount)
        // Disable auth so we can connect with empty credentials below.
        .with_env_var("NEO4J_AUTH", "none")
        .start()
        .await
        .expect("Failed to start Neo 4J container");

    // The container maps BOLT_PORT to a random host port; resolve it.
    let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap();
    let host = container.get_host().await.unwrap().to_string();

    let neo4j_client = Neo4jClient::connect(&format!("neo4j://{host}:{port}"), "", "")
        .await
        .unwrap();

    // Initialize OpenAI client
    let openai_client = openai::Client::from_env();

    // Select the embedding model and generate our embeddings
    let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);

    let embeddings = create_embeddings(model.clone()).await;

    // Insert one DocumentEmbeddings node per document, with at most 3
    // concurrent writes in flight; fail the test on the first write error.
    futures::stream::iter(embeddings)
        .map(|(doc, embeddings)| {
            neo4j_client.graph.run(
                neo4rs::query(
                    "
                    CREATE
                        (document:DocumentEmbeddings {
                            id: $id,
                            document: $document,
                            embedding: $embedding})
                    RETURN document",
                )
                .param("id", doc.id)
                // Here we use the first embedding but we could use any of them.
                // Neo4j only takes primitive types or arrays as properties.
                .param("embedding", embeddings.first().vec.clone())
                .param("document", doc.definition.to_bolt_type()),
            )
        })
        .buffer_unordered(3)
        .try_collect::<Vec<_>>()
        .await
        .unwrap();

    // Create a vector index on our vector store
    println!("Creating vector index...");
    neo4j_client
        .graph
        .run(neo4rs::query(
            "CREATE VECTOR INDEX vector_index IF NOT EXISTS
            FOR (m:DocumentEmbeddings)
            ON m.embedding
            OPTIONS { indexConfig: {
                `vector.dimensions`: 1536,
                `vector.similarity_function`: 'cosine'
            }}",
        ))
        .await
        .unwrap();

    // ℹ️ The index name must be unique among both indexes and constraints.
    // A newly created index is not immediately available but is created in the background.

    // Check if the index exists with db.awaitIndex(), the call timeouts if the index is not ready
    let index_exists = neo4j_client
        .graph
        .run(neo4rs::query("CALL db.awaitIndex('vector_index')"))
        .await;
    if index_exists.is_err() {
        println!("Index not ready, waiting for index...");
        // FIX: use the async sleep instead of std::thread::sleep, which would
        // block the tokio runtime worker thread for the whole duration.
        // NOTE(review): requires tokio's `time` feature — confirm it is enabled.
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }

    println!("Index exists: {:?}", index_exists);

    // Create a vector index handle on our vector store.
    // IMPORTANT: Reuse the same model that was used to generate the embeddings
    let index = neo4j_client
        .get_index(model, "vector_index", SearchParams::default())
        .await
        .unwrap();

    // Query the index
    let results = index
        .top_n::<serde_json::Value>("What is a glarb?", 1)
        .await
        .unwrap();

    let (_, _, value) = &results.first().unwrap();

    // The glarb-glarb definition (doc1) must be the top hit; the embedding
    // property is not deserialized back, hence the explicit Null.
    assert_eq!(
        value,
        &serde_json::json!({
            "id": "doc1",
            "document": "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.",
            "embedding": serde_json::Value::Null
        })
    )
}

/// Builds the three-document test corpus and embeds each `Word`'s definition
/// with the given model, returning each document paired with its embeddings.
/// Panics if the builder or the embedding request fails (fine in a test helper).
async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(Word, OneOrMany<Embedding>)> {
    // (id, definition) pairs for the corpus; kept as a data table so adding a
    // document is a one-line change.
    let corpus: [(&str, &str); 3] = [
        ("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets"),
        ("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land."),
        ("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans."),
    ];

    let words: Vec<Word> = corpus
        .iter()
        .map(|&(id, definition)| Word {
            id: id.to_string(),
            definition: definition.to_string(),
        })
        .collect();

    EmbeddingsBuilder::new(model)
        .documents(words)
        .unwrap()
        .build()
        .await
        .unwrap()
}
10 changes: 5 additions & 5 deletions rig-qdrant/examples/qdrant_vector_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use rig::{
use rig_qdrant::QdrantVectorStore;

#[derive(Embed, serde::Deserialize, serde::Serialize, Debug)]
struct Definition {
struct Word {
id: String,
#[embed]
definition: String,
Expand Down Expand Up @@ -56,15 +56,15 @@ async fn main() -> Result<(), anyhow::Error> {
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let documents = EmbeddingsBuilder::new(model.clone())
.document(Definition {
.document(Word {
id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(),
definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
})?
.document(Definition {
.document(Word {
id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
})?
.document(Definition {
.document(Word {
id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
})?
Expand All @@ -91,7 +91,7 @@ async fn main() -> Result<(), anyhow::Error> {
let vector_store = QdrantVectorStore::new(client, model, query_params.build());

let results = vector_store
.top_n::<Definition>("What is a linglingdong?", 1)
.top_n::<Word>("What is a linglingdong?", 1)
.await?;

println!("Results: {:?}", results);
Expand Down
12 changes: 6 additions & 6 deletions rig-qdrant/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const QDRANT_PORT_SECONDARY: u16 = 6334;
const COLLECTION_NAME: &str = "rig-collection";

#[derive(Embed, Clone, serde::Deserialize, serde::Serialize, Debug)]
struct Definition {
struct Word {
id: String,
#[embed]
definition: String,
Expand Down Expand Up @@ -95,23 +95,23 @@ async fn vector_search_test() {
}

async fn create_points(model: openai::EmbeddingModel) -> Vec<PointStruct> {
let fake_definitions = vec![
Definition {
let words = vec![
Word {
id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(),
definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
},
Definition {
Word {
id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
},
Definition {
Word {
id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
}
];

let documents = EmbeddingsBuilder::new(model)
.documents(fake_definitions)
.documents(words)
.unwrap()
.build()
.await
Expand Down

0 comments on commit 96e78be

Please sign in to comment.