Skip to content

Commit

Permalink
JDBC Postgres VectorStore (#11)
Browse files Browse the repository at this point in the history
nomisRev authored Apr 27, 2023
1 parent 216682a commit 24f4d60
Showing 13 changed files with 547 additions and 5 deletions.
10 changes: 10 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
@@ -51,6 +51,8 @@ kotlin {
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
implementation(libs.okio)
implementation(libs.uuid)
implementation(libs.klogging)
}
}

@@ -63,9 +65,17 @@ kotlin {
implementation(libs.kotest.assertions.arrow)
}
}
val jvmMain by getting {
dependencies {
implementation(libs.hikari)
implementation(libs.postgresql)
}
}
val jvmTest by getting {
dependencies {
implementation(libs.kotest.junit5)
implementation(libs.kotest.testcontainers)
implementation(libs.testcontainers.postgresql)
}
}
}
12 changes: 12 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
@@ -7,7 +7,13 @@ ktor = "2.2.2"
spotless = "6.18.0"
okio = "3.3.0"
kotest = "5.5.4"
kotest-testcontainers = "1.3.4"
kotest-arrow = "1.3.0"
klogging = "4.0.0-beta-22"
uuid = "0.0.18"
postgresql = "42.5.1"
testcontainers = "1.17.6"
hikari = "5.0.1"

[libraries]
arrow-fx = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref = "arrow" }
@@ -23,7 +29,13 @@ kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref =
kotest-framework = { module = "io.kotest:kotest-framework-engine", version.ref = "kotest" }
kotest-property = { module = "io.kotest:kotest-property", version.ref = "kotest" }
kotest-junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" }
kotest-testcontainers = { module = "io.kotest.extensions:kotest-extensions-testcontainers", version.ref = "kotest-testcontainers" }
kotest-assertions-arrow = { module = "io.kotest.extensions:kotest-assertions-arrow", version.ref = "kotest-arrow" }
uuid = { module = "app.softwork:kotlinx-uuid-core", version.ref = "uuid" }
klogging = { module = "io.github.oshai:kotlin-logging", version.ref = "klogging" }
hikari = { module = "com.zaxxer:HikariCP", version.ref = "hikari" }
postgresql = { module = "org.postgresql:postgresql", version.ref = "postgresql" }
testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" }

[bundles]
ktor-client = [
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.xebia.functional.embeddings

import com.xebia.functional.llm.openai.RequestConfig

data class Embedding(val data: List<Float>)

interface Embeddings {
suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding>
suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding>

companion object
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.xebia.functional.embeddings

import arrow.fx.coroutines.parMap
import arrow.resilience.retry
import com.xebia.functional.env.OpenAIConfig
import com.xebia.functional.llm.openai.EmbeddingRequest
import com.xebia.functional.llm.openai.OpenAIClient
import com.xebia.functional.llm.openai.RequestConfig
import io.github.oshai.KLogger
import kotlin.time.ExperimentalTime

@ExperimentalTime
class OpenAIEmbeddings(
private val config: OpenAIConfig,
private val oaiClient: OpenAIClient,
private val logger: KLogger
) : Embeddings {

override suspend fun embedDocuments(
texts: List<String>,
chunkSize: Int?,
requestConfig: RequestConfig
): List<Embedding> =
chunkedEmbedDocuments(texts, chunkSize ?: config.chunkSize, requestConfig)

override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
if (text.isNotEmpty()) embedDocuments(listOf(text), null, requestConfig) else emptyList()

private suspend fun chunkedEmbedDocuments(
texts: List<String>,
chunkSize: Int,
requestConfig: RequestConfig
): List<Embedding> =
if (texts.isEmpty()) emptyList()
else texts.chunked(chunkSize)
.parMap { createEmbeddingWithRetry(it, requestConfig) }
.flatten()

private suspend fun createEmbeddingWithRetry(texts: List<String>, requestConfig: RequestConfig): List<Embedding> =
kotlin.runCatching {
config.retryConfig.schedule()
.log { retriesSoFar, _ -> logger.warn { "Open AI call failed. So far we have retried $retriesSoFar times." } }
.retry {
oaiClient.createEmbeddings(EmbeddingRequest(requestConfig.model.name, texts, requestConfig.user.id))
.data.map { Embedding(it.embedding) }
}
}.getOrElse {
logger.warn { "Open AI call failed. Giving up after ${config.retryConfig.maxRetries} retries" }
throw it
}
}
6 changes: 2 additions & 4 deletions src/commonMain/kotlin/com/xebia/functional/env/config.kt
Original file line number Diff line number Diff line change
@@ -17,11 +17,9 @@ data class Env(val openAI: OpenAIConfig, val huggingFace: HuggingFaceConfig)
data class OpenAIConfig(val token: String, val baseUrl: KUrl, val chunkSize: Int, val retryConfig: RetryConfig)

data class RetryConfig(val backoff: Duration, val maxRetries: Long) {
fun schedule(): Schedule<Throwable, Unit> =
fun schedule(): Schedule<Throwable, Long> =
Schedule.recurs<Throwable>(maxRetries)
.and(Schedule.exponential(backoff))
.jittered(0.75, 1.25)
.map { }
.zipLeft(Schedule.exponential<Throwable>(backoff).jittered(0.75, 1.25))
}

data class HuggingFaceConfig(val token: String, val baseUrl: KUrl)
15 changes: 14 additions & 1 deletion src/commonMain/kotlin/com/xebia/functional/llm/openai/models.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
package com.xebia.functional.llm.openai

import kotlin.jvm.JvmInline
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

enum class EmbeddingModel(name: String) {
TextEmbeddingAda002("text-embedding-ada-002")
}

data class RequestConfig(val model: EmbeddingModel, val user: User) {
companion object {
@JvmInline
value class User(val id: String)
}
}


@Serializable
data class CompletionChoice(val text: String, val index: Int, val finishReason: String)

@@ -38,7 +51,7 @@ data class EmbeddingResult(
)

@Serializable
class Embedding(val `object`: String, val embedding: List<Double>, val index: Int)
class Embedding(val `object`: String, val embedding: List<Float>, val index: Int)

@Serializable
data class Usage(
3 changes: 3 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/model.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.xebia.functional

data class Document(val content: String)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.xebia.functional.vectorstores

import com.xebia.functional.Document
import com.xebia.functional.embeddings.Embedding
import kotlin.jvm.JvmInline
import kotlinx.uuid.UUID

@JvmInline
value class DocumentVectorId(val id: UUID)

interface VectorStore {
/**
* Add texts to the vector store after running them through the embeddings
*
* @param texts list of text to add to the vector store
* @return a list of IDs from adding the texts to the vector store
*/
suspend fun addTexts(texts: List<String>): List<DocumentVectorId>

/**
* Add documents to the vector store after running them through the embeddings
*
* @param documents list of Documents to add to the vector store
* @return a list of IDs from adding the documents to the vector store
*/
suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId>

/**
* Return the docs most similar to the query
*
* @param query text to use to search for similar documents
* @param limit number of documents to return
* @return a list of Documents most similar to query
*/
suspend fun similaritySearch(query: String, limit: Int): List<Document>

/**
* Return the docs most similar to the embedding
*
* @param embedding embedding vector to use to search for similar documents
* @param limit number of documents to return
* @return list of Documents most similar to the embedding
*/
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.xebia.functional.vectorstores

import kotlinx.uuid.UUID

data class PGCollection(val uuid: UUID, val collectionName: String)

enum class PGDistanceStrategy(val strategy: String) {
Euclidean("<->"), InnerProduct("<#>"), CosineDistance("<=>")
}

val createCollections: String =
"""CREATE TABLE langchain4k_collections (
uuid TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL
);""".trimIndent()

val createEmbeddings: String =
"""CREATE TABLE langchain4k_embeddings (
uuid TEXT PRIMARY KEY,
collection_id TEXT REFERENCES langchain4k_collections(uuid),
embedding BLOB,
content TEXT
);""".trimIndent()

val addVectorExtension: String =
"CREATE EXTENSION IF NOT EXISTS vector;"

val createCollectionsTable: String =
"""CREATE TABLE IF NOT EXISTS langchain4k_collections (
uuid TEXT PRIMARY KEY,
name TEXT UNIQUE NOT NULL
);""".trimIndent()

fun createEmbeddingTable(vectorSize: Int): String =
"""CREATE TABLE IF NOT EXISTS langchain4k_embeddings (
uuid TEXT PRIMARY KEY,
collection_id TEXT REFERENCES langchain4k_collections(uuid),
embedding vector($vectorSize),
content TEXT
);""".trimIndent()

val addNewCollection: String =
"""INSERT INTO langchain4k_collections(uuid, name)
VALUES (?, ?)
ON CONFLICT DO NOTHING;""".trimIndent()

val deleteCollection: String =
"""DELETE FROM langchain4k_collections
WHERE uuid = ?;""".trimIndent()

val getCollection: String =
"""SELECT * FROM langchain4k_collections
WHERE name = ?;""".trimIndent()

val getCollectionById: String =
"""SELECT * FROM langchain4k_collections
WHERE uuid = ?;""".trimIndent()

val addNewDocument: String =
"""INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
VALUES (?, ?, ?, ?);""".trimIndent()

val deleteCollectionDocs: String =
"""DELETE FROM langchain4k_embeddings
WHERE collection_id = ?;""".trimIndent()

val addNewText: String =
"""INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content)
VALUES (?, ?, ?::vector, ?);""".trimIndent()

fun searchSimilarDocument(distance: PGDistanceStrategy): String =
"""SELECT content FROM langchain4k_embeddings
WHERE collection_id = ?
ORDER BY embedding
${distance.strategy} ?::vector
LIMIT ?;""".trimIndent()
23 changes: 23 additions & 0 deletions src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.xebia.functional.embeddings

import com.xebia.functional.llm.openai.RequestConfig

fun Embeddings.Companion.mock(
embedDocuments: suspend (texts: List<String>, chunkSize: Int?, config: RequestConfig) -> List<Embedding> = { _, _, _ ->
listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)), Embedding(listOf(4.0f, 5.0f, 6.0f)))
},
embedQuery: suspend (text: String, config: RequestConfig) -> List<Embedding> = { text, _ ->
when (text) {
"foo" -> listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)))
"bar" -> listOf(Embedding(listOf(4.0f, 5.0f, 6.0f)))
"baz" -> listOf()
else -> listOf()
}
}
): Embeddings = object : Embeddings {
override suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding> =
embedDocuments(texts, chunkSize, requestConfig)

override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
embedQuery(text, requestConfig)
}
Loading

0 comments on commit 24f4d60

Please sign in to comment.