Skip to content

Commit

Permalink
Bug fix and minor changes (#19)
Browse files Browse the repository at this point in the history
* Fixes bug in  method

* Improves  method and test coverage

* Attempt to run  and  in the CI pipeline

* Fixes format

* Apply suggestions from code review - removing redundant syntax

Co-authored-by: Juan Pedro Moreno <[email protected]>

---------

Co-authored-by: Juan Pedro Moreno <[email protected]>
  • Loading branch information
jrodrip and juanpedromoreno committed May 2, 2023
1 parent 3fe0a76 commit c8375c9
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,80 +27,91 @@ class PGVectorStoreSpec extends DatabaseSuite:
None
)

test("PGVectorStore - initialDbSetup") {
test("initialDbSetup should configure the DB properly") {
val result: IO[Unit] = pg.initialDbSetup()
assertIO(result, ())
}

test("PGVectorStore - addTexts should fail - collection not found".fail) {
test("addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB") {
val result: IO[List[DocumentVectorId]] = pg.addTexts(TestData.texts)
assertIO(result, List.empty[DocumentVectorId])
interceptMessageIO[PGErrors.CollectionNotFoundError](
"Collection 'test_collection' not found"
)(result)
}

test("PGVectorStore - createCollection") {
test("similaritySearch shoul fail with a CollectionNotFoundError if collection isn't present in the DB") {
val result: IO[List[Document]] = pg.similaritySearch("foo", 2)
interceptMessageIO[PGErrors.CollectionNotFoundError](
"Collection 'test_collection' not found"
)(result)
}

test("createCollection should create collection") {
val result: IO[Int] = pg.createCollection
assertIO(result, 1)
}

test("PGVectorStore - addTexts should return a list of 2 elements") {
test("addTexts should return a list of 2 elements") {
val result: IO[List[DocumentVectorId]] = pg.addTexts(TestData.texts)
assertIO(result.map(_.length), 2)
}

test("PGVectorStore - similaritySearchByVector should return both documents") {
test("similaritySearchByVector should return both documents") {
val result: IO[List[Document]] = pg.similaritySearchByVector(TestData.barEmbedding, 2)
assertIO(result.map(_.map(_.content)), List("bar", "foo"))
}

test("PGVectorStore - addDocuments should return a list of 2 elements") {
test("addDocuments should return a list of 2 elements") {
val result: IO[List[DocumentVectorId]] = pg.addDocuments(TestData.texts.map(Document.apply(_)))
assertIO(result.map(_.length), 2)
}

test("PGVectorStore - similaritySearch should return 2 documents") {
test("similaritySearch should return 2 documents") {
val result: IO[List[Document]] = pg.similaritySearch("foo", 2)
assertIO(result.map(_.length), 2)
}

test("PGVectorStore - similaritySearch should fail when embedding vector is empty".fail) {
test("similaritySearch should fail when embedding vector is empty") {
val result: IO[List[Document]] = pg.similaritySearch("baz", 2)
assertIO(result.map(_.length), 2)
interceptMessageIO[PGErrors.EmbeddingNotGeneratedError](
"Embedding for text: 'baz', has not been properly generated"
)(result)
}

test("PGVectorStore - similaritySearchByVector should return document for 'foo'") {
test("similaritySearchByVector should return document") {
val result: IO[List[Document]] = pg.similaritySearchByVector(TestData.fooEmbedding, 1)
assertIO(result.map(_.map(_.content)), List("foo"))
}

test("PGVectorStore check query - addVectorExtension") {
test("check query - addVectorExtension") {
check(PGSql.addVectorExtension)
}

test("PGVectorStore check query - createCollectionsTable") {
test("check query - createCollectionsTable") {
check(PGSql.createCollectionsTable)
}

test("PGVectorStore check query - createEmbeddingTable") {
test("check query - createEmbeddingTable") {
check(PGSql.createEmbeddingTable(3))
}

test("PGVectorStore check query - addNewCollection") {
test("check query - addNewCollection") {
check(PGSql.addNewCollection(TestData.collectionUUID.id, TestData.collectionName))
}

test("PGVectorStore check query - getCollection") {
test("check query - getCollection") {
check(PGSql.getCollection(TestData.collectionName))
}

test("PGVectorStore check query - getCollectionById") {
test("check query - getCollectionById") {
check(PGSql.getCollectionById(TestData.collectionUUID.id))
}

test("PGVectorStore check query - deleteCollectionDocs") {
test("check query - deleteCollectionDocs") {
check(PGSql.deleteCollectionDocs(TestData.collectionUUID.id))
}

test("PGVectorStore check query - deleteCollection") {
test("check query - deleteCollection") {
check(PGSql.deleteCollection(TestData.collectionUUID.id))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ final case class Config(
) {
def genInputs[F[_]: ApplicativeThrow](inputs: Map[String, String]): F[Map[String, String]] =
(
if ((inputKeys diff inputs.keySet).isEmpty) Some(inputs) else None
if ((inputKeys diff inputs.keySet).isEmpty && (inputs.keySet diff inputKeys).isEmpty) Some(inputs) else None
).liftTo[F](InvalidChainInputsError(inputKeys, inputs))

def genInputsFromString[F[_]: ApplicativeThrow](input: String): F[Map[String, String]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ object PGErrors:
class DatabaseSetupError(reason: String) extends Throwable with NoStackTrace:
override def getMessage: String = s"Error while setting up the database: $reason"

class CollectionNotFound(collectionName: String) extends Throwable with NoStackTrace:
override def getMessage: String = s"Collection $collectionName not found"
class CollectionNotFoundError(collectionName: String) extends Throwable with NoStackTrace:
override def getMessage: String = s"Collection '$collectionName' not found"

class EmbeddingNotGenerated(text: String) extends Throwable with NoStackTrace:
override def getMessage(): String = s"Embedding for text $text has not been properly generated"
class EmbeddingNotGeneratedError(text: String) extends Throwable with NoStackTrace:
override def getMessage(): String = s"Embedding for text: '$text', has not been properly generated"
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ object PGSql:
""".update

def createEmbeddingTable(vectorSize: Int): Update0 =
(fr"CREATE TABLE IF NOT EXISTS langchain4s_embeddings (uuid UUID PRIMARY KEY, collection_id UUID, embedding vector(" ++ Fragment.const(
vectorSize.show
) ++ fr"), content VARCHAR)").update
(fr"CREATE TABLE IF NOT EXISTS langchain4s_embeddings (uuid UUID PRIMARY KEY, collection_id UUID references langchain4s_collections(uuid), embedding vector(" ++ Fragment
.const(
vectorSize.show
) ++ fr"), content VARCHAR)").update

def addNewCollection(uuid: UUID, collectionName: String): Update0 =
sql"""
Expand Down Expand Up @@ -76,7 +77,7 @@ object PGSql:
VALUES ($uuid, $collectionId, ${embedding.data}::vector, $text)
""".update

def searchSimilarDocument(e: Embedding, strategy: PGDistanceStrategy, k: Int): Query0[domain.Document] =
(fr"SELECT content FROM langchain4s_embeddings ORDER BY embedding " ++ Fragment
def searchSimilarDocument(e: Embedding, strategy: PGDistanceStrategy, collection: PGCollection, k: Int): Query0[domain.Document] =
(fr"SELECT content FROM langchain4s_embeddings WHERE collection_id = ${collection.uuid} ORDER BY embedding " ++ Fragment
.const(strategy.strategy) ++ fr" ${e.data.map(_.toFloat)}::vector limit $k")
.query[domain.Document]
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class PGVectorStore[F[_]: Sync](
case true =>
for
collection <- getCollection(collectionName).option
c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFound(collectionName))
c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFoundError(collectionName))
_ <- deleteCollectionDocs(c.uuid).run
_ <- deleteCollection(c.uuid).run
yield ()
Expand Down Expand Up @@ -68,7 +68,7 @@ class PGVectorStore[F[_]: Sync](
us <-
(for
collection <- getCollection(collectionName).option
c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFound(collectionName))
c <- collection.liftTo[ConnectionIO](PGErrors.CollectionNotFoundError(collectionName))
uuids <-
pairs.traverse((t, e) =>
memeid4s.UUID.V1.next
Expand All @@ -88,14 +88,23 @@ class PGVectorStore[F[_]: Sync](
for
qEmb <- embeddings.embedQuery(query, requestConfig)
docs <- qEmb match
case e +: es => searchSimilarDocument(e, distanceStrategy, k).to[List].transact(xa)
case _ => Sync[F].raiseError(PGErrors.EmbeddingNotGenerated(query))
case e +: es =>
getCollection(collectionName).option
.flatMap {
case Some(col) => searchSimilarDocument(e, distanceStrategy, col, k).to[List]
case None => Sync[ConnectionIO].raiseError(PGErrors.CollectionNotFoundError(collectionName))
}.transact(xa)
case _ => Sync[F].raiseError(PGErrors.EmbeddingNotGeneratedError(query))
yield docs
}

def similaritySearchByVector(embedding: Embedding, k: Int): F[List[Document]] =
transactor.use { xa =>
searchSimilarDocument(embedding, distanceStrategy, k).to[List].transact(xa)
getCollection(collectionName).option
.flatMap {
case Some(col) => searchSimilarDocument(embedding, distanceStrategy, col, k).to[List]
case None => Sync[ConnectionIO].raiseError(PGErrors.CollectionNotFoundError(collectionName))
}.transact(xa)
}

object PGVectorStore:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class VectorQAChainSpec extends CatsEffectSuite:
assertIO(result, TestData.outputIDK)
}

test("run should fail with a InvalidChainInputsError if the inputs don't match the expected") {
test("run should fail with an InvalidChainInputsError if the inputs don't match the expected") {
val vectorStore = VectorStoreMock.make
val qa = VectorQAChain.makeWithDefaults[IO](OpenAIClientMock.make, vectorStore, "testing")
val result = qa.run(Map("foo" -> "What do you think?"))
Expand All @@ -45,3 +45,13 @@ class VectorQAChainSpec extends CatsEffectSuite:
"The provided inputs (foo) do not match with chain's inputs (question)"
)(result)
}

test("run should fail with an InvalidChainInputsError if the input is more than one") {
val vectorStore = VectorStoreMock.make
val qa = VectorQAChain.makeWithDefaults[IO](OpenAIClientMock.make, vectorStore, "testing")
val result = qa.run(Map("question" -> "bla bla bla", "foo" -> "What do you think?"))

interceptMessageIO[InvalidChainInputsError](
"The provided inputs (question, foo) do not match with chain's inputs (question)"
)(result)
}

0 comments on commit c8375c9

Please sign in to comment.