diff --git a/src/it/scala/com/xebia/functional/vectorstores/db/PGVectorStoreSpec.scala b/src/it/scala/com/xebia/functional/vectorstores/db/PGVectorStoreSpec.scala index d1d4dc12e..954ee5604 100644 --- a/src/it/scala/com/xebia/functional/vectorstores/db/PGVectorStoreSpec.scala +++ b/src/it/scala/com/xebia/functional/vectorstores/db/PGVectorStoreSpec.scala @@ -27,80 +27,91 @@ class PGVectorStoreSpec extends DatabaseSuite: None ) - test("PGVectorStore - initialDbSetup") { + test("initialDbSetup should configure the DB properly") { val result: IO[Unit] = pg.initialDbSetup() assertIO(result, ()) } - test("PGVectorStore - addTexts should fail - collection not found".fail) { + test("addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB") { val result: IO[List[DocumentVectorId]] = pg.addTexts(TestData.texts) - assertIO(result, List.empty[DocumentVectorId]) + interceptMessageIO[PGErrors.CollectionNotFoundError]( + "Collection 'test_collection' not found" + )(result) } - test("PGVectorStore - createCollection") { + test("similaritySearch should fail with a CollectionNotFoundError if collection isn't present in the DB") { + val result: IO[List[Document]] = pg.similaritySearch("foo", 2) + interceptMessageIO[PGErrors.CollectionNotFoundError]( + "Collection 'test_collection' not found" + )(result) + } + + test("createCollection should create collection") { val result: IO[Int] = pg.createCollection assertIO(result, 1) } - test("PGVectorStore - addTexts should return a list of 2 elements") { + test("addTexts should return a list of 2 elements") { val result: IO[List[DocumentVectorId]] = pg.addTexts(TestData.texts) assertIO(result.map(_.length), 2) } - test("PGVectorStore - similaritySearchByVector should return both documents") { + test("similaritySearchByVector should return both documents") { val result: IO[List[Document]] = pg.similaritySearchByVector(TestData.barEmbedding, 2) assertIO(result.map(_.map(_.content)), List("bar", "foo")) } - 
test("PGVectorStore - addDocuments should return a list of 2 elements") { + test("addDocuments should return a list of 2 elements") { val result: IO[List[DocumentVectorId]] = pg.addDocuments(TestData.texts.map(Document.apply(_))) assertIO(result.map(_.length), 2) } - test("PGVectorStore - similaritySearch should return 2 documents") { + test("similaritySearch should return 2 documents") { val result: IO[List[Document]] = pg.similaritySearch("foo", 2) assertIO(result.map(_.length), 2) } - test("PGVectorStore - similaritySearch should fail when embedding vector is empty".fail) { + test("similaritySearch should fail when embedding vector is empty") { val result: IO[List[Document]] = pg.similaritySearch("baz", 2) - assertIO(result.map(_.length), 2) + interceptMessageIO[PGErrors.EmbeddingNotGeneratedError]( + "Embedding for text: 'baz', has not been properly generated" + )(result) } - test("PGVectorStore - similaritySearchByVector should return document for 'foo'") { + test("similaritySearchByVector should return document") { val result: IO[List[Document]] = pg.similaritySearchByVector(TestData.fooEmbedding, 1) assertIO(result.map(_.map(_.content)), List("foo")) } - test("PGVectorStore check query - addVectorExtension") { + test("check query - addVectorExtension") { check(PGSql.addVectorExtension) } - test("PGVectorStore check query - createCollectionsTable") { + test("check query - createCollectionsTable") { check(PGSql.createCollectionsTable) } - test("PGVectorStore check query - createEmbeddingTable") { + test("check query - createEmbeddingTable") { check(PGSql.createEmbeddingTable(3)) } - test("PGVectorStore check query - addNewCollection") { + test("check query - addNewCollection") { check(PGSql.addNewCollection(TestData.collectionUUID.id, TestData.collectionName)) } - test("PGVectorStore check query - getCollection") { + test("check query - getCollection") { check(PGSql.getCollection(TestData.collectionName)) } - test("PGVectorStore check query - 
getCollectionById") { + test("check query - getCollectionById") { check(PGSql.getCollectionById(TestData.collectionUUID.id)) } - test("PGVectorStore check query - deleteCollectionDocs") { + test("check query - deleteCollectionDocs") { check(PGSql.deleteCollectionDocs(TestData.collectionUUID.id)) } - test("PGVectorStore check query - deleteCollection") { + test("check query - deleteCollection") { check(PGSql.deleteCollection(TestData.collectionUUID.id)) } diff --git a/src/main/scala/com/xebia/functional/chains/models/Config.scala b/src/main/scala/com/xebia/functional/chains/models/Config.scala index af5c35b97..641e3a5e9 100644 --- a/src/main/scala/com/xebia/functional/chains/models/Config.scala +++ b/src/main/scala/com/xebia/functional/chains/models/Config.scala @@ -10,7 +10,7 @@ final case class Config( ) { def genInputs[F[_]: ApplicativeThrow](inputs: Map[String, String]): F[Map[String, String]] = ( - if ((inputKeys diff inputs.keySet).isEmpty) Some(inputs) else None + if ((inputKeys diff inputs.keySet).isEmpty && (inputs.keySet diff inputKeys).isEmpty) Some(inputs) else None ).liftTo[F](InvalidChainInputsError(inputKeys, inputs)) def genInputsFromString[F[_]: ApplicativeThrow](input: String): F[Map[String, String]] = diff --git a/src/main/scala/com/xebia/functional/vectorstores/postgres/PGErrors.scala b/src/main/scala/com/xebia/functional/vectorstores/postgres/PGErrors.scala index c1fa0751b..d713f099e 100644 --- a/src/main/scala/com/xebia/functional/vectorstores/postgres/PGErrors.scala +++ b/src/main/scala/com/xebia/functional/vectorstores/postgres/PGErrors.scala @@ -6,8 +6,8 @@ object PGErrors: class DatabaseSetupError(reason: String) extends Throwable with NoStackTrace: override def getMessage: String = s"Error while setting up the database: $reason" - class CollectionNotFound(collectionName: String) extends Throwable with NoStackTrace: - override def getMessage: String = s"Collection $collectionName not found" + class CollectionNotFoundError(collectionName: 
String) extends Throwable with NoStackTrace: + override def getMessage: String = s"Collection '$collectionName' not found" - class EmbeddingNotGenerated(text: String) extends Throwable with NoStackTrace: - override def getMessage(): String = s"Embedding for text $text has not been properly generated" + class EmbeddingNotGeneratedError(text: String) extends Throwable with NoStackTrace: + override def getMessage(): String = s"Embedding for text: '$text', has not been properly generated" diff --git a/src/main/scala/com/xebia/functional/vectorstores/postgres/PGSql.scala b/src/main/scala/com/xebia/functional/vectorstores/postgres/PGSql.scala index 5fe39e7ab..75d5b0911 100644 --- a/src/main/scala/com/xebia/functional/vectorstores/postgres/PGSql.scala +++ b/src/main/scala/com/xebia/functional/vectorstores/postgres/PGSql.scala @@ -29,9 +29,10 @@ object PGSql: """.update def createEmbeddingTable(vectorSize: Int): Update0 = - (fr"CREATE TABLE IF NOT EXISTS langchain4s_embeddings (uuid UUID PRIMARY KEY, collection_id UUID, embedding vector(" ++ Fragment.const( - vectorSize.show - ) ++ fr"), content VARCHAR)").update + (fr"CREATE TABLE IF NOT EXISTS langchain4s_embeddings (uuid UUID PRIMARY KEY, collection_id UUID references langchain4s_collections(uuid), embedding vector(" ++ Fragment + .const( + vectorSize.show + ) ++ fr"), content VARCHAR)").update def addNewCollection(uuid: UUID, collectionName: String): Update0 = sql""" @@ -76,7 +77,7 @@ object PGSql: VALUES ($uuid, $collectionId, ${embedding.data}::vector, $text) """.update - def searchSimilarDocument(e: Embedding, strategy: PGDistanceStrategy, k: Int): Query0[domain.Document] = - (fr"SELECT content FROM langchain4s_embeddings ORDER BY embedding " ++ Fragment + def searchSimilarDocument(e: Embedding, strategy: PGDistanceStrategy, collection: PGCollection, k: Int): Query0[domain.Document] = + (fr"SELECT content FROM langchain4s_embeddings WHERE collection_id = ${collection.uuid} ORDER BY embedding " ++ Fragment 
.const(strategy.strategy) ++ fr" ${e.data.map(_.toFloat)}::vector limit $k") .query[domain.Document] diff --git a/src/main/scala/com/xebia/functional/vectorstores/postgres/PGVectorStore.scala b/src/main/scala/com/xebia/functional/vectorstores/postgres/PGVectorStore.scala index 8719966a3..6ff6dd33c 100644 --- a/src/main/scala/com/xebia/functional/vectorstores/postgres/PGVectorStore.scala +++ b/src/main/scala/com/xebia/functional/vectorstores/postgres/PGVectorStore.scala @@ -35,7 +35,7 @@ class PGVectorStore[F[_]: Sync]( case true => for collection <- getCollection(collectionName).option - c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFound(collectionName)) + c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFoundError(collectionName)) _ <- deleteCollectionDocs(c.uuid).run _ <- deleteCollection(c.uuid).run yield () @@ -68,7 +68,7 @@ class PGVectorStore[F[_]: Sync]( us <- (for collection <- getCollection(collectionName).option - c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFound(collectionName)) + c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFoundError(collectionName)) uuids <- pairs.traverse((t, e) => memeid4s.UUID.V1.next @@ -88,14 +88,23 @@ class PGVectorStore[F[_]: Sync]( for qEmb <- embeddings.embedQuery(query, requestConfig) docs <- qEmb match - case e +: es => searchSimilarDocument(e, distanceStrategy, k).to[List].transact(xa) - case _ => Sync[F].raiseError(PGErrors.EmbeddingNotGenerated(query)) + case e +: es => + getCollection(collectionName).option + .flatMap { + case Some(col) => searchSimilarDocument(e, distanceStrategy, col, k).to[List] + case None => Sync[ConnectionIO].raiseError(PGErrors.CollectionNotFoundError(collectionName)) + }.transact(xa) + case _ => Sync[F].raiseError(PGErrors.EmbeddingNotGeneratedError(query)) yield docs } def similaritySearchByVector(embedding: Embedding, k: Int): F[List[Document]] = transactor.use { xa => - searchSimilarDocument(embedding, distanceStrategy, 
k).to[List].transact(xa) + getCollection(collectionName).option + .flatMap { + case Some(col) => searchSimilarDocument(embedding, distanceStrategy, col, k).to[List] + case None => Sync[ConnectionIO].raiseError(PGErrors.CollectionNotFoundError(collectionName)) + }.transact(xa) } object PGVectorStore: diff --git a/src/test/scala/com/xebia/functional/chains/VectorQAChainSpec.scala b/src/test/scala/com/xebia/functional/chains/VectorQAChainSpec.scala index 401b9472f..aa224ff35 100644 --- a/src/test/scala/com/xebia/functional/chains/VectorQAChainSpec.scala +++ b/src/test/scala/com/xebia/functional/chains/VectorQAChainSpec.scala @@ -36,7 +36,7 @@ class VectorQAChainSpec extends CatsEffectSuite: assertIO(result, TestData.outputIDK) } - test("run should fail with a InvalidChainInputsError if the inputs don't match the expected") { + test("run should fail with an InvalidChainInputsError if the inputs don't match the expected") { val vectorStore = VectorStoreMock.make val qa = VectorQAChain.makeWithDefaults[IO](OpenAIClientMock.make, vectorStore, "testing") val result = qa.run(Map("foo" -> "What do you think?")) @@ -45,3 +45,13 @@ class VectorQAChainSpec extends CatsEffectSuite: "The provided inputs (foo) do not match with chain's inputs (question)" )(result) } + + test("run should fail with an InvalidChainInputsError if the input is more than one") { + val vectorStore = VectorStoreMock.make + val qa = VectorQAChain.makeWithDefaults[IO](OpenAIClientMock.make, vectorStore, "testing") + val result = qa.run(Map("question" -> "bla bla bla", "foo" -> "What do you think?")) + + interceptMessageIO[InvalidChainInputsError]( + "The provided inputs (question, foo) do not match with chain's inputs (question)" + )(result) + }