Skip to content

Commit

Permalink
JDBC Postgres VectorStore (#11)
Browse files Browse the repository at this point in the history
nomisRev authored Apr 27, 2023
1 parent 216682a commit 24f4d60
Showing 13 changed files with 547 additions and 5 deletions.
10 changes: 10 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
@@ -51,6 +51,8 @@ kotlin {
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
implementation(libs.okio)
implementation(libs.uuid)
implementation(libs.klogging)
}
}

@@ -63,9 +65,17 @@ kotlin {
implementation(libs.kotest.assertions.arrow)
}
}
val jvmMain by getting {
dependencies {
implementation(libs.hikari)
implementation(libs.postgresql)
}
}
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
implementation(libs.kotest.testcontainers)
implementation(libs.testcontainers.postgresql)
}
}
}
12 changes: 12 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
@@ -7,7 +7,13 @@ ktor = "2.2.2"
spotless = "6.18.0"
okio = "3.3.0"
kotest = "5.5.4"
kotest-testcontainers = "1.3.4"
kotest-arrow = "1.3.0"
klogging = "4.0.0-beta-22"
uuid = "0.0.18"
postgresql = "42.5.1"
testcontainers = "1.17.6"
hikari = "5.0.1"

[libraries]
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
@@ -23,7 +29,13 @@ kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref =
kotest-framework = { module = "io.kotest:kotest-framework-engine", version.ref = "kotest" }
kotest-property = { module = "io.kotest:kotest-property", version.ref = "kotest" }
kotest-junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" }
kotest-testcontainers = { module = "io.kotest.extensions:kotest-extensions-testcontainers", version.ref = "kotest-testcontainers" }
kotest-assertions-arrow = { module = "io.kotest.extensions:kotest-assertions-arrow", version.ref = "kotest-arrow" }
uuid = { module = "app.softwork:kotlinx-uuid-core", version.ref = "uuid" }
klogging = { module = "io.github.oshai:kotlin-logging", version.ref = "klogging" }
hikari = { module = "com.zaxxer:HikariCP", version.ref = "hikari" }
postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" }
testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }

[bundles]
ktor-client = [
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.xebia.functional.embeddings

import com.xebia.functional.llm.openai.RequestConfig

/**
 * An embedding vector.
 *
 * @property data the components of the vector, in order
 */
data class Embedding(
    val data: List<Float>,
)

/** Produces [Embedding] vectors for pieces of text. */
interface Embeddings {
    /**
     * Embeds a batch of texts.
     *
     * @param texts the texts to embed
     * @param chunkSize optional batch size; how `null` is handled is left to the implementation
     * @param requestConfig per-request settings handed to the implementation
     * @return the embeddings produced for [texts]
     */
    suspend fun embedDocuments(
        texts: List<String>,
        chunkSize: Int?,
        requestConfig: RequestConfig,
    ): List<Embedding>

    /**
     * Embeds a single query text.
     *
     * @param text the query to embed
     * @param requestConfig per-request settings handed to the implementation
     * @return the embeddings produced for [text]; the return type allows an empty list
     */
    suspend fun embedQuery(
        text: String,
        requestConfig: RequestConfig,
    ): List<Embedding>

    /** Empty companion; lets extension functions be declared on `Embeddings.Companion`. */
    companion object
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.xebia.functional.embeddings

import arrow.fx.coroutines.parMap
import arrow.resilience.retry
import com.xebia.functional.env.OpenAIConfig
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.llm.openai.RequestConfig
import io.github.oshai.KLogger
import kotlin.time.ExperimentalTime

/**
 * [Embeddings] implementation backed by [OpenAIClient.createEmbeddings].
 *
 * Input texts are split into chunks, each chunk is sent as its own request (chunks are
 * processed with `parMap`), and each request is retried following `config.retryConfig`.
 */
@ExperimentalTime
class OpenAIEmbeddings(
    private val config: OpenAIConfig,
    private val oaiClient: OpenAIClient,
    private val logger: KLogger
) : Embeddings {

    /**
     * Embeds [texts] in chunks.
     *
     * @param texts the texts to embed; an empty list yields an empty result without any request
     * @param chunkSize number of texts per request, or `null` to use `config.chunkSize`
     * @param requestConfig model and user sent with every request
     * @return the embeddings of all chunks, flattened in chunk order
     */
    override suspend fun embedDocuments(
        texts: List<String>,
        chunkSize: Int?,
        requestConfig: RequestConfig
    ): List<Embedding> =
        chunkedEmbedDocuments(texts, chunkSize ?: config.chunkSize, requestConfig)

    /**
     * Embeds a single query.
     *
     * @return an empty list when [text] is empty, otherwise the result of embedding `listOf(text)`
     */
    override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
        if (text.isNotEmpty()) embedDocuments(listOf(text), null, requestConfig) else emptyList()

    /** Splits [texts] into chunks of [chunkSize], embeds each chunk and flattens the results. */
    private suspend fun chunkedEmbedDocuments(
        texts: List<String>,
        chunkSize: Int,
        requestConfig: RequestConfig
    ): List<Embedding> =
        if (texts.isEmpty()) emptyList()
        else texts.chunked(chunkSize)
            .parMap { createEmbeddingWithRetry(it, requestConfig) }
            .flatten()

    /**
     * Requests embeddings for one chunk, retrying per `config.retryConfig.schedule()`.
     *
     * The final failure is logged and rethrown unchanged. Cancellation is rethrown without the
     * "giving up" log line, because a cancelled call is not a failed OpenAI call.
     */
    private suspend fun createEmbeddingWithRetry(texts: List<String>, requestConfig: RequestConfig): List<Embedding> =
        try {
            config.retryConfig.schedule()
                // Schedule.log passes (input, output): the input is the Throwable that triggered the
                // retry, the output is the schedule's output (the retry count for a Long-valued schedule).
                .log { _, retriesSoFar -> logger.warn { "Open AI call failed. So far we have retried $retriesSoFar times." } }
                .retry {
                    // NOTE(review): `model.name` is the Kotlin enum constant name (e.g. "TextEmbeddingAda002"),
                    // not the string passed to the enum constructor — confirm which one the API expects.
                    oaiClient.createEmbeddings(EmbeddingRequest(requestConfig.model.name, texts, requestConfig.user.id))
                        .data.map { Embedding(it.embedding) }
                }
        } catch (cancellation: kotlin.coroutines.cancellation.CancellationException) {
            throw cancellation
        } catch (failure: Throwable) {
            logger.warn { "Open AI call failed. Giving up after ${config.retryConfig.maxRetries} retries" }
            throw failure
        }
}
6 changes: 2 additions & 4 deletions src/commonMain/kotlin/com/xebia/functional/env/config.kt
Original file line number Diff line number Diff line change
@@ -17,11 +17,9 @@ data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)
/**
 * Settings for the OpenAI integration.
 *
 * @property token the configured token string
 * @property baseUrl the configured base URL
 * @property chunkSize the configured chunk size
 * @property retryConfig retry settings
 */
data class OpenAIConfig(
    val token: String,
    val baseUrl: KUrl,
    val chunkSize: Int,
    val retryConfig: RetryConfig,
)

/**
 * Retry settings.
 *
 * @property backoff base delay handed to `Schedule.exponential`
 * @property maxRetries number of recurrences handed to `Schedule.recurs`
 */
data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
    /**
     * Builds the retry schedule: `recurs(maxRetries)` combined via `zipLeft` with a jittered
     * (factor 0.75 to 1.25) exponential backoff.
     *
     * `zipLeft` keeps the output of the left-hand schedule, so the resulting schedule's output
     * is the `Long` produced by `recurs`.
     */
    fun schedule(): Schedule<Throwable, Long> =
        Schedule.recurs<Throwable>(maxRetries)
            .zipLeft(Schedule.exponential<Throwable>(backoff).jittered(0.75, 1.25))
}

/**
 * Settings for the HuggingFace integration.
 *
 * @property token the configured token string
 * @property baseUrl the configured base URL
 */
data class HuggingFaceConfig(
    val token: String,
    val baseUrl: KUrl,
)
15 changes: 14 additions & 1 deletion src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
package com.xebia.functional.llm.openai

import kotlin.jvm.JvmInline
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
 * Embedding models that can be requested.
 *
 * @property modelName the identifier string given to the constant, e.g. `"text-embedding-ada-002"`.
 *   It is distinct from the built-in [name], which is always the Kotlin constant name
 *   (`"TextEmbeddingAda002"`); a plain constructor parameter called `name` is not stored
 *   and cannot replace the built-in property.
 */
enum class EmbeddingModel(val modelName: String) {
    TextEmbeddingAda002("text-embedding-ada-002")
}

/**
 * Per-request settings for an embedding call.
 *
 * @property model the embedding model to use
 * @property user the user on whose behalf the request is made
 */
data class RequestConfig(
    val model: EmbeddingModel,
    val user: User,
) {
    companion object {
        /** Inline wrapper around a raw user id string. */
        @JvmInline
        value class User(val id: String)
    }
}


/**
 * One choice of a completion response.
 *
 * NOTE(review): `finishReason` carries no `@SerialName`; confirm the JSON key matches the
 * property name or that a naming strategy is configured where this is decoded.
 */
@Serializable
data class CompletionChoice(
    val text: String,
    val index: Int,
    val finishReason: String,
)

@@ -38,7 +51,7 @@ data class EmbeddingResult(
)

/**
 * One serialized embedding entry.
 *
 * @property object the value of the `object` field
 * @property embedding the vector components, as `Float`s
 * @property index the value of the `index` field
 */
@Serializable
class Embedding(val `object`: String, val embedding: List<Float>, val index: Int)

@Serializable
data class Usage(
3 changes: 3 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/model.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.xebia.functional

/**
 * A piece of text handled by the vector stores.
 *
 * @property content the text of the document
 */
data class Document(
    val content: String,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.xebia.functional.vectorstores

import com.xebia.functional.Document
import com.xebia.functional.embeddings.Embedding
import kotlin.jvm.JvmInline
import kotlinx.uuid.UUID

/**
 * Identifier returned by a vector store for a stored text.
 *
 * @property id the underlying UUID
 */
@JvmInline
value class DocumentVectorId(
    val id: UUID,
)

/** Storage for embedded texts that can be queried by similarity. */
interface VectorStore {
    /**
     * Embeds the given texts and stores them.
     *
     * @param texts the texts to store
     * @return the ids assigned to the stored texts
     */
    suspend fun addTexts(texts: List<String>): List<DocumentVectorId>

    /**
     * Embeds the given documents and stores them.
     *
     * @param documents the documents to store
     * @return the ids assigned to the stored documents
     */
    suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId>

    /**
     * Finds the stored documents most similar to a query text.
     *
     * @param query the text to compare stored documents against
     * @param limit the number of documents to return
     * @return the documents most similar to [query]
     */
    suspend fun similaritySearch(query: String, limit: Int): List<Document>

    /**
     * Finds the stored documents most similar to an embedding vector.
     *
     * @param embedding the vector to compare stored documents against
     * @param limit the number of documents to return
     * @return the documents most similar to [embedding]
     */
    suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.xebia.functional.vectorstores

import kotlinx.uuid.UUID

/**
 * A row of the `langchain4k_collections` table.
 *
 * @property uuid the collection's id
 * @property collectionName the collection's name
 */
data class PGCollection(
    val uuid: UUID,
    val collectionName: String,
)

/**
 * Distance operators that can be used to order similarity queries.
 *
 * @property strategy the SQL operator text inserted into the query
 */
enum class PGDistanceStrategy(val strategy: String) {
    Euclidean("<->"),
    InnerProduct("<#>"),
    CosineDistance("<=>")
}

/** DDL that creates the `langchain4k_collections` table (fails if it already exists). */
val createCollections: String =
    """
    CREATE TABLE langchain4k_collections (
    uuid TEXT PRIMARY KEY,
    name TEXT UNIQUE NOT NULL
    );
    """.trimIndent()

/**
 * DDL that creates the `langchain4k_embeddings` table with a `BLOB` embedding column.
 *
 * NOTE(review): PostgreSQL has no `BLOB` column type (binary data uses `BYTEA`), and this
 * statement is not among those imported by `PGVectorStore` in this view — confirm whether
 * it is still needed.
 */
val createEmbeddings: String =
    """
    CREATE TABLE langchain4k_embeddings (
    uuid TEXT PRIMARY KEY,
    collection_id TEXT REFERENCES langchain4k_collections(uuid),
    embedding BLOB,
    content TEXT
    );
    """.trimIndent()

/** Statement that enables the `vector` extension when it is not installed yet. */
val addVectorExtension: String = "CREATE EXTENSION IF NOT EXISTS vector;"

/** Idempotent DDL for the `langchain4k_collections` table. */
val createCollectionsTable: String =
    """
    CREATE TABLE IF NOT EXISTS langchain4k_collections (
    uuid TEXT PRIMARY KEY,
    name TEXT UNIQUE NOT NULL
    );
    """.trimIndent()

/**
 * Idempotent DDL for the `langchain4k_embeddings` table.
 *
 * @param vectorSize the dimension written into the `vector(...)` column type
 * @return the `CREATE TABLE IF NOT EXISTS` statement
 */
fun createEmbeddingTable(vectorSize: Int): String =
    """
    CREATE TABLE IF NOT EXISTS langchain4k_embeddings (
    uuid TEXT PRIMARY KEY,
    collection_id TEXT REFERENCES langchain4k_collections(uuid),
    embedding vector($vectorSize),
    content TEXT
    );
    """.trimIndent()

/** Inserts a collection row (`uuid`, `name`); a conflicting row is left untouched. */
val addNewCollection: String =
    """
    INSERT INTO langchain4k_collections(uuid, name)
    VALUES (?, ?)
    ON CONFLICT DO NOTHING;
    """.trimIndent()

/** Deletes the collection row with the given `uuid`. */
val deleteCollection: String =
    """
    DELETE FROM langchain4k_collections
    WHERE uuid = ?;
    """.trimIndent()

/** Selects the collection row with the given `name`. */
val getCollection: String =
    """
    SELECT * FROM langchain4k_collections
    WHERE name = ?;
    """.trimIndent()

/** Selects the collection row with the given `uuid`. */
val getCollectionById: String =
    """
    SELECT * FROM langchain4k_collections
    WHERE uuid = ?;
    """.trimIndent()

/**
 * Inserts an embedding row (`uuid`, `collection_id`, `embedding`, `content`).
 *
 * NOTE(review): unlike `addNewText`, the embedding parameter has no `::vector` cast, and this
 * statement is not among those imported by `PGVectorStore` in this view — confirm its use.
 */
val addNewDocument: String =
    """
    INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
    VALUES (?, ?, ?, ?);
    """.trimIndent()

/** Deletes every embedding row that belongs to the given `collection_id`. */
val deleteCollectionDocs: String =
    """
    DELETE FROM langchain4k_embeddings
    WHERE collection_id = ?;
    """.trimIndent()

/**
 * Inserts an embedding row (`uuid`, `collection_id`, `embedding`, `content`), casting the third
 * parameter to `vector`.
 */
val addNewText: String =
    """
    INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
    VALUES (?, ?, ?::vector, ?);
    """.trimIndent()

/**
 * Builds the similarity query for one collection.
 *
 * The statement takes three parameters, in order: the collection id, the query vector
 * (cast to `vector`) and the row limit.
 *
 * @param distance supplies the operator placed between the `embedding` column and the query vector
 * @return the `SELECT content ...` statement ordered by that operator
 */
fun searchSimilarDocument(distance: PGDistanceStrategy): String =
    """
    SELECT content FROM langchain4k_embeddings
    WHERE collection_id = ?
    ORDER BY embedding
    ${distance.strategy} ?::vector
    LIMIT ?;
    """.trimIndent()
23 changes: 23 additions & 0 deletions src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.xebia.functional.embeddings

import com.xebia.functional.llm.openai.RequestConfig

/**
 * Creates an [Embeddings] test double whose two operations delegate to the given lambdas.
 *
 * @param embedDocuments handler for [Embeddings.embedDocuments]; by default it ignores its
 *   arguments and returns the two fixed vectors `[1, 2, 3]` and `[4, 5, 6]`
 * @param embedQuery handler for [Embeddings.embedQuery]; by default `"foo"` maps to `[1, 2, 3]`,
 *   `"bar"` maps to `[4, 5, 6]`, and every other text (including `"baz"`) maps to an empty list
 */
fun Embeddings.Companion.mock(
    embedDocuments: suspend (texts: List<String>, chunkSize: Int?, config: RequestConfig) -> List<Embedding> =
        { _, _, _ ->
            val first = Embedding(listOf(1.0f, 2.0f, 3.0f))
            val second = Embedding(listOf(4.0f, 5.0f, 6.0f))
            listOf(first, second)
        },
    embedQuery: suspend (text: String, config: RequestConfig) -> List<Embedding> =
        { text, _ ->
            when (text) {
                "foo" -> listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)))
                "bar" -> listOf(Embedding(listOf(4.0f, 5.0f, 6.0f)))
                // "baz" and any unknown text both produce no embedding.
                else -> listOf()
            }
        }
): Embeddings =
    object : Embeddings {
        // The calls below are meant to invoke the lambda parameters of the same name.
        override suspend fun embedDocuments(
            texts: List<String>,
            chunkSize: Int?,
            requestConfig: RequestConfig
        ): List<Embedding> = embedDocuments(texts, chunkSize, requestConfig)

        override suspend fun embedQuery(
            text: String,
            requestConfig: RequestConfig
        ): List<Embedding> = embedQuery(text, requestConfig)
    }
107 changes: 107 additions & 0 deletions src/jvmMain/kotlin/com/xebia/functional/JDBCSyntax.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package com.xebia.functional

import arrow.core.raise.NullableRaise
import arrow.core.raise.nullable
import arrow.fx.coroutines.ResourceScope
import arrow.fx.coroutines.autoCloseable
import arrow.fx.coroutines.resourceScope
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.sql.Types
import javax.sql.DataSource

/**
 * Obtains a connection from this [DataSource], runs [block] with a [JDBCSyntax] wrapping it, and
 * returns the block's result.
 *
 * The connection is registered with the enclosing `resourceScope` via `autoCloseable`, as are
 * the resources that [block] registers through the [JDBCSyntax] receiver.
 *
 * @param block the work to perform with the connection
 * @return whatever [block] returns
 */
suspend fun <A> DataSource.connection(block: suspend JDBCSyntax.() -> A): A =
    resourceScope {
        val jdbcConnection = autoCloseable { getConnection() }
        val syntax = JDBCSyntax(jdbcConnection, this)
        block(syntax)
    }

/**
 * A JDBC [Connection] combined with a [ResourceScope], plus small helpers for running
 * parameterised statements.
 *
 * Both interfaces are implemented by delegation. Statements and result sets created by the
 * helpers are registered with the scope through `autoCloseable`.
 */
class JDBCSyntax(conn: Connection, resourceScope: ResourceScope) :
    ResourceScope by resourceScope,
    Connection by conn {

    /**
     * Prepares [sql], applies [binders] when given, and registers the statement with the scope.
     *
     * @param sql the statement text
     * @param binders optional block that binds the statement parameters in order
     * @return the prepared statement
     */
    suspend fun prepareStatement(
        sql: String,
        binders: (SqlPreparedStatement.() -> Unit)? = null
    ): PreparedStatement = autoCloseable {
        // Single-argument call: this is the delegated Connection.prepareStatement(String).
        val prepared = prepareStatement(sql)
        if (binders != null) {
            SqlPreparedStatement(prepared).binders()
        }
        prepared
    }

    /**
     * Runs [sql] through `executeUpdate`, discarding the update count.
     *
     * @param sql the statement text
     * @param binders optional block that binds the statement parameters in order
     */
    suspend fun update(
        sql: String,
        binders: (SqlPreparedStatement.() -> Unit)? = null,
    ): Unit {
        prepareStatement(sql, binders).executeUpdate()
    }

    /**
     * Runs [sql] and maps the first row.
     *
     * @param sql the query text
     * @param binders optional block that binds the query parameters in order
     * @param mapper reads the columns of the row, left to right
     * @return the mapped row, or `null` when there is no row or [mapper] hit a SQL `NULL`
     */
    suspend fun <A> queryOneOrNull(
        sql: String,
        binders: (SqlPreparedStatement.() -> Unit)? = null,
        mapper: NullableSqlCursor.() -> A
    ): A? {
        val query = prepareStatement(sql, binders)
        val rows = autoCloseable { query.executeQuery() }
        if (!rows.next()) return null
        return nullable { mapper(NullableSqlCursor(rows, this)) }
    }

    /**
     * Runs [sql] and maps every row.
     *
     * @param sql the query text
     * @param binders optional block that binds the query parameters in order
     * @param mapper reads the columns of one row, left to right
     * @return the mapped rows; rows for which [mapper] yields `null` or hits a SQL `NULL` are left out
     */
    suspend fun <A> queryAsList(
        sql: String,
        binders: (SqlPreparedStatement.() -> Unit)? = null,
        mapper: NullableSqlCursor.() -> A?
    ): List<A> {
        val query = prepareStatement(sql, binders)
        val rows = autoCloseable { query.executeQuery() }
        return buildList {
            while (rows.next()) {
                // A fresh cursor per row restarts column reading at index 1.
                val mapped = nullable { mapper(NullableSqlCursor(rows, this)) }
                if (mapped != null) add(mapped)
            }
        }
    }

    /**
     * Binds parameters of a [PreparedStatement] positionally: every `bind` call fills the next
     * parameter index, starting at 1.
     */
    class SqlPreparedStatement(private val preparedStatement: PreparedStatement) {
        private var index: Int = 1

        /** Binds a short by widening it to `Long`. */
        fun bind(short: Short?) {
            bind(short?.toLong())
        }

        /** Binds a byte by widening it to `Long`. */
        fun bind(byte: Byte?) {
            bind(byte?.toLong())
        }

        /** Binds an int by widening it to `Long`. */
        fun bind(int: Int?) {
            bind(int?.toLong())
        }

        /** Binds a char as a one-character string. */
        fun bind(char: Char?) {
            bind(char?.toString())
        }

        /** Binds a byte array, or SQL `NULL` of type `BLOB`. */
        fun bind(bytes: ByteArray?) {
            val slot = index++
            if (bytes != null) preparedStatement.setBytes(slot, bytes)
            else preparedStatement.setNull(slot, Types.BLOB)
        }

        /** Binds a long, or SQL `NULL` of type `INTEGER`. */
        fun bind(long: Long?) {
            val slot = index++
            if (long != null) preparedStatement.setLong(slot, long)
            else preparedStatement.setNull(slot, Types.INTEGER)
        }

        /** Binds a double, or SQL `NULL` of type `REAL`. */
        fun bind(double: Double?) {
            val slot = index++
            if (double != null) preparedStatement.setDouble(slot, double)
            else preparedStatement.setNull(slot, Types.REAL)
        }

        /** Binds a string, or SQL `NULL` of type `VARCHAR`. */
        fun bind(string: String?) {
            val slot = index++
            if (string != null) preparedStatement.setString(slot, string)
            else preparedStatement.setNull(slot, Types.VARCHAR)
        }
    }

    /**
     * Reads the columns of the current [ResultSet] row positionally, starting at column 1.
     * SQL `NULL` is reported as Kotlin `null`.
     */
    class SqlCursor(private val resultSet: ResultSet) {
        private var index: Int = 1

        fun int(): Int? = long()?.toInt()

        fun string(): String? {
            val column = index++
            return resultSet.getString(column)
        }

        fun bytes(): ByteArray? {
            val column = index++
            return resultSet.getBytes(column)
        }

        fun long(): Long? {
            val value = resultSet.getLong(index++)
            // getLong cannot return null itself, so wasNull() tells whether the column was NULL.
            return if (resultSet.wasNull()) null else value
        }

        fun double(): Double? {
            val value = resultSet.getDouble(index++)
            return if (resultSet.wasNull()) null else value
        }

        /** Advances the underlying result set; the column index is not reset. */
        fun nextRow(): Boolean = resultSet.next()
    }

    /**
     * Like [SqlCursor], but a SQL `NULL` short-circuits through [raise] instead of being
     * returned, so every accessor has a non-null type.
     */
    class NullableSqlCursor(private val resultSet: ResultSet, private val raise: NullableRaise) {
        private var index: Int = 1

        fun int(): Int = long().toInt()

        fun string(): String {
            val text = resultSet.getString(index++)
            return raise.ensureNotNull(text)
        }

        fun bytes(): ByteArray {
            val blob = resultSet.getBytes(index++)
            return raise.ensureNotNull(blob)
        }

        fun long(): Long {
            val value = resultSet.getLong(index++)
            return raise.ensureNotNull(if (resultSet.wasNull()) null else value)
        }

        fun double(): Double {
            val value = resultSet.getDouble(index++)
            return raise.ensureNotNull(if (resultSet.wasNull()) null else value)
        }

        /** Advances the underlying result set; the column index is not reset. */
        fun nextRow(): Boolean = resultSet.next()
    }
}
101 changes: 101 additions & 0 deletions src/jvmMain/kotlin/com/xebia/functional/PGVectorStore.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package com.xebia.functional

import com.xebia.functional.embeddings.Embedding
import com.xebia.functional.embeddings.Embeddings
import com.xebia.functional.llm.openai.RequestConfig
import com.xebia.functional.vectorstores.DocumentVectorId
import com.xebia.functional.vectorstores.PGCollection
import com.xebia.functional.vectorstores.PGDistanceStrategy
import com.xebia.functional.vectorstores.VectorStore
import com.xebia.functional.vectorstores.addNewCollection
import com.xebia.functional.vectorstores.addNewText
import com.xebia.functional.vectorstores.addVectorExtension
import com.xebia.functional.vectorstores.createCollectionsTable
import com.xebia.functional.vectorstores.createEmbeddingTable
import com.xebia.functional.vectorstores.deleteCollection
import com.xebia.functional.vectorstores.deleteCollectionDocs
import com.xebia.functional.vectorstores.getCollection
import com.xebia.functional.vectorstores.searchSimilarDocument
import javax.sql.DataSource
import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID

/**
 * [VectorStore] that keeps texts and their embeddings in PostgreSQL tables
 * (`langchain4k_collections` and `langchain4k_embeddings`) and searches them with a
 * [PGDistanceStrategy] operator.
 *
 * Call [initialDbSetup] and [createCollection] before adding or searching texts; the other
 * operations throw [IllegalStateException] when the collection row is missing.
 *
 * @param vectorSize dimension used for the `vector(...)` column when the table is created
 * @param dataSource source of JDBC connections
 * @param embeddings used to embed texts and queries
 * @param collectionName name of the collection this store reads and writes
 * @param distanceStrategy operator used to order search results
 * @param preDeleteCollection when `true`, [initialDbSetup] removes an existing collection and its rows
 * @param requestConfig passed to every [embeddings] call
 * @param chunckSize chunk size passed to [Embeddings.embedDocuments] (parameter name kept as is
 *   because callers pass it by name)
 */
class PGVectorStore(
    private val vectorSize: Int,
    private val dataSource: DataSource,
    private val embeddings: Embeddings,
    private val collectionName: String,
    private val distanceStrategy: PGDistanceStrategy,
    private val preDeleteCollection: Boolean,
    private val requestConfig: RequestConfig,
    private val chunckSize: Int?
) : VectorStore {

    /** Looks up the collection called [name]; `null` when no such row exists. */
    private suspend fun JDBCSyntax.findCollection(name: String): PGCollection? =
        queryOneOrNull(getCollection,
            { bind(name) }
        ) { PGCollection(UUID(string()), string()) }

    /**
     * Looks up the collection called [collectionName].
     *
     * @throws IllegalStateException when no such collection exists
     */
    suspend fun JDBCSyntax.getCollection(collectionName: String): PGCollection =
        findCollection(collectionName)
            ?: throw IllegalStateException("Collection '$collectionName' not found")

    /**
     * When [preDeleteCollection] is set, deletes the collection's embedding rows and then the
     * collection row itself. A collection that does not exist yet is simply skipped, so the
     * first setup against an empty database does not fail.
     */
    suspend fun JDBCSyntax.deleteCollection() {
        if (!preDeleteCollection) return
        val collection = findCollection(collectionName) ?: return
        val collectionId = collection.uuid.toString()
        // Rows referencing the collection go first because of the foreign key.
        update(deleteCollectionDocs) { bind(collectionId) }
        update(deleteCollection) { bind(collectionId) }
    }

    /** Enables the vector extension, creates both tables if needed and applies the pre-delete. */
    suspend fun initialDbSetup(): Unit = dataSource.connection {
        update(addVectorExtension)
        update(createCollectionsTable)
        update(createEmbeddingTable(vectorSize))
        deleteCollection()
    }

    /** Inserts a row for [collectionName] with a fresh UUID; a conflicting row is left as is. */
    suspend fun createCollection(): Unit = dataSource.connection {
        val xa = UUID.generateUUID()
        update(addNewCollection) {
            bind(xa.toString())
            bind(collectionName)
        }
    }

    /**
     * Embeds [texts] and inserts one row per (text, embedding) pair.
     *
     * The embeddings are computed before a connection is taken, so no connection is held
     * while the embedding call runs. Texts and embeddings are paired with `zip`, so a size
     * mismatch silently drops the surplus elements.
     *
     * @return the ids of the inserted rows, in input order
     * @throws IllegalStateException when the collection does not exist
     */
    override suspend fun addTexts(texts: List<String>): List<DocumentVectorId> {
        val vectors = embeddings.embedDocuments(texts, chunckSize, requestConfig)
        return dataSource.connection {
            val collection = getCollection(collectionName)
            texts.zip(vectors) { text, vector ->
                val uuid = UUID.generateUUID()
                update(addNewText) {
                    bind(uuid.toString())
                    bind(collection.uuid.toString())
                    // List<Float>.toString() renders "[1.0, 2.0, ...]"; the SQL casts it with ::vector.
                    bind(vector.data.toString())
                    bind(text)
                }
                DocumentVectorId(uuid)
            }
        }
    }

    /** Stores the content of every document through [addTexts]. */
    override suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId> =
        addTexts(documents.map(Document::content))

    /**
     * Embeds [query] and searches with the first resulting vector.
     *
     * @throws IllegalStateException when no embedding was produced for [query], or when the
     *   collection does not exist
     */
    override suspend fun similaritySearch(query: String, limit: Int): List<Document> {
        val queryEmbedding = embeddings.embedQuery(query, requestConfig).firstOrNull()
            ?: throw IllegalStateException("Embedding for text: '$query', has not been properly generated")
        return similaritySearchByVector(queryEmbedding, limit)
    }

    /**
     * Returns up to [limit] documents of the collection, ordered by [distanceStrategy] against
     * [embedding].
     *
     * @throws IllegalStateException when the collection does not exist
     */
    override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document> =
        dataSource.connection {
            val collection = getCollection(collectionName)
            queryAsList(searchSimilarDocument(distanceStrategy), {
                bind(collection.uuid.toString())
                bind(embedding.data.toString())
                bind(limit)
            }) { Document(string()) }
        }
}
91 changes: 91 additions & 0 deletions src/jvmTest/kotlin/com/xebia/functional/PGVectorStoreSpec.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package com.xebia.functional

import com.xebia.functional.embeddings.Embedding
import com.xebia.functional.embeddings.Embeddings
import com.xebia.functional.embeddings.mock
import com.xebia.functional.llm.openai.EmbeddingModel
import com.xebia.functional.llm.openai.RequestConfig
import com.xebia.functional.vectorstores.PGDistanceStrategy
import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
import io.kotest.core.extensions.install
import io.kotest.core.spec.style.StringSpec
import io.kotest.extensions.testcontainers.SharedTestContainerExtension
import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.assertThrows
import org.testcontainers.containers.PostgreSQLContainer
import org.testcontainers.utility.DockerImageName

/** PostgreSQL test container built from the `ankane/pgvector` image, declared as a `postgres` substitute. */
val postgres: PostgreSQLContainer<Nothing> =
    PostgreSQLContainer(
        DockerImageName
            .parse("ankane/pgvector")
            .asCompatibleSubstituteFor("postgres")
    )

/**
 * Exercises [PGVectorStore] against the shared PostgreSQL container.
 *
 * The tests build on each other's database state (setup, then missing-collection failures,
 * then collection creation, then inserts and searches), so their declaration order matters.
 */
class PGVectorStoreSpec : StringSpec({

    val pgContainer = install(SharedTestContainerExtension(postgres))

    val hikariConfig = HikariConfig()
    hikariConfig.jdbcUrl = pgContainer.jdbcUrl
    hikariConfig.username = pgContainer.username
    hikariConfig.password = pgContainer.password
    hikariConfig.driverClassName = "org.postgresql.Driver"
    val hikari = autoClose(HikariDataSource(hikariConfig))

    val store = PGVectorStore(
        vectorSize = 3,
        dataSource = hikari,
        embeddings = Embeddings.mock(),
        collectionName = "test_collection",
        distanceStrategy = PGDistanceStrategy.Euclidean,
        preDeleteCollection = false,
        requestConfig = RequestConfig(EmbeddingModel.TextEmbeddingAda002, RequestConfig.Companion.User("user")),
        chunckSize = null
    )

    "initialDbSetup should configure the DB properly" {
        store.initialDbSetup()
    }

    // The collection has not been created yet at this point.
    "addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB" {
        val failure = assertThrows<IllegalStateException> {
            store.addTexts(listOf("foo", "bar"))
        }
        failure.message shouldBe "Collection 'test_collection' not found"
    }

    "similaritySearch should fail with a CollectionNotFoundError if collection isn't present in the DB" {
        val failure = assertThrows<IllegalStateException> {
            store.similaritySearch("foo", 2)
        }
        failure.message shouldBe "Collection 'test_collection' not found"
    }

    "createCollection should create collection" {
        store.createCollection()
    }

    "addTexts should return a list of 2 elements" {
        val ids = store.addTexts(listOf("foo", "bar"))
        ids.size shouldBe 2
    }

    "similaritySearchByVector should return both documents" {
        val probe = Embedding(listOf(4.0f, 5.0f, 6.0f))
        store.similaritySearchByVector(probe, 2) shouldBe listOf(
            Document("bar"),
            Document("foo")
        )
    }

    "addDocuments should return a list of 2 elements" {
        val ids = store.addDocuments(listOf(Document("foo"), Document("bar")))
        ids.size shouldBe 2
    }

    "similaritySearch should return 2 documents" {
        val found = store.similaritySearch("foo", 2)
        found.size shouldBe 2
    }

    // The mock returns no embedding for "baz".
    "similaritySearch should fail when embedding vector is empty" {
        val failure = assertThrows<IllegalStateException> {
            store.similaritySearch("baz", 2)
        }
        failure.message shouldBe "Embedding for text: 'baz', has not been properly generated"
    }

    "similaritySearchByVector should return document" {
        val probe = Embedding(listOf(1.0f, 2.0f, 3.0f))
        store.similaritySearchByVector(probe, 1) shouldBe listOf(Document("foo"))
    }
})

0 comments on commit 24f4d60

Please sign in to comment.