-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- 0.0.5-alpha.119
- 0.0.5-alpha.118
- 0.0.5-alpha.117
- 0.0.5-alpha.116
- 0.0.5-alpha.115
- 0.0.5-alpha.114
- 0.0.5-alpha.113
- 0.0.5-alpha.112
- 0.0.5-alpha.111
- 0.0.5-alpha.110
- 0.0.5-alpha.109
- 0.0.5-alpha.108
- 0.0.5-alpha.107
- 0.0.5-alpha.106
- 0.0.5-alpha.105
- 0.0.5-alpha.104
- 0.0.5-alpha.103
- 0.0.5-alpha.102
- 0.0.5-alpha.101
- 0.0.5-alpha.100
- 0.0.5-alpha.99
- 0.0.5-alpha.98
- 0.0.5-alpha.97
- 0.0.5-alpha.96
- 0.0.5-alpha.95
- 0.0.5-alpha.94
- 0.0.5-alpha.93
- 0.0.5-alpha.92
- 0.0.5-alpha.91
- 0.0.5-alpha.90
- 0.0.5-alpha.89
- 0.0.5-alpha.88
- 0.0.5-alpha.87
- 0.0.5-alpha.86
- 0.0.5-alpha.85
- 0.0.5-alpha.84
- 0.0.5-alpha.83
- 0.0.5-alpha.82
- 0.0.5-alpha.81
- 0.0.5-alpha.80
- 0.0.5-alpha.79
- 0.0.5-alpha.78
- 0.0.5-alpha.77
- 0.0.5-alpha.76
- 0.0.5-alpha.75
- 0.0.5-alpha.74
- 0.0.5-alpha.73
- 0.0.5-alpha.72
- 0.0.5-alpha.71
- 0.0.5-alpha.70
- 0.0.5-alpha.69
- 0.0.5-alpha.68
- 0.0.5-alpha.67
- 0.0.5-alpha.66
- 0.0.5-alpha.65
- 0.0.5-alpha.64
- 0.0.5-alpha.63
- 0.0.5-alpha.62
- 0.0.5-alpha.61
- 0.0.5-alpha.60
- 0.0.5-alpha.59
- 0.0.5-alpha.58
- 0.0.5-alpha.57
- 0.0.5-alpha.56
- 0.0.5-alpha.55
- 0.0.5-alpha.54
- 0.0.5-alpha.53
- 0.0.5-alpha.52
- 0.0.5-alpha.51
- 0.0.5-alpha.50
- 0.0.5-alpha.49
- 0.0.5-alpha.48
- 0.0.5-alpha.47
- 0.0.5-alpha.46
- 0.0.5-alpha.45
- 0.0.5-alpha.44
- 0.0.5-alpha.43
- 0.0.5-alpha.42
- 0.0.5-alpha.41
- 0.0.5-alpha.40
- 0.0.5-alpha.39
- 0.0.5-alpha.38
- 0.0.5-alpha.37
- 0.0.5-alpha.36
- 0.0.5-alpha.35
- 0.0.5-alpha.34
- 0.0.5-alpha.33
- 0.0.5-alpha.32
- 0.0.5-alpha.31
- 0.0.5-alpha.30
- 0.0.5-alpha.29
- 0.0.5-alpha.28
- 0.0.5-alpha.27
- 0.0.5-alpha.26
- 0.0.5-alpha.25
- 0.0.5-alpha.24
- 0.0.5-alpha.23
- 0.0.5-alpha.22
- 0.0.5-alpha.21
- 0.0.5-alpha.20
- 0.0.5-alpha.19
- 0.0.5-alpha.18
- 0.0.5-alpha.17
- 0.0.5-alpha.16
- 0.0.5-alpha.15
- 0.0.5-alpha.14
- 0.0.5-alpha.13
- 0.0.5-alpha.12
- 0.0.5-alpha.11
- 0.0.5-alpha.10
- 0.0.5-alpha.9
- 0.0.5-alpha.8
- 0.0.5-alpha.7
- 0.0.5-alpha.6
- 0.0.5-alpha.5
- 0.0.5-alpha.4
- 0.0.5-alpha.3
- 0.0.5-alpha.2
- 0.0.5-alpha.1
- 0.0.4
- 0.0.4-alpha.104
- 0.0.4-alpha.103
- 0.0.4-alpha.102
- 0.0.4-alpha.101
- 0.0.4-alpha.100
- 0.0.4-alpha.99
- 0.0.4-alpha.98
- 0.0.4-alpha.97
- 0.0.4-alpha.96
- 0.0.4-alpha.95
- 0.0.4-alpha.94
- 0.0.4-alpha.93
- 0.0.4-alpha.92
- 0.0.4-alpha.91
- 0.0.4-alpha.90
- 0.0.4-alpha.89
- 0.0.4-alpha.88
- 0.0.4-alpha.87
- 0.0.4-alpha.86
- 0.0.4-alpha.85
- 0.0.4-alpha.84
- 0.0.4-alpha.83
- 0.0.4-alpha.82
- 0.0.4-alpha.81
- 0.0.4-alpha.80
- 0.0.4-alpha.79
- 0.0.4-alpha.78
- 0.0.4-alpha.77
- 0.0.4-alpha.76
- 0.0.4-alpha.75
- 0.0.4-alpha.74
- 0.0.4-alpha.73
- 0.0.4-alpha.72
- 0.0.4-alpha.71
- 0.0.4-alpha.70
- 0.0.4-alpha.69
- 0.0.4-alpha.68
- 0.0.4-alpha.67
- 0.0.4-alpha.66
- 0.0.4-alpha.65
- 0.0.4-alpha.64
- 0.0.4-alpha.63
- 0.0.4-alpha.62
- 0.0.4-alpha.61
- 0.0.4-alpha.60
- 0.0.4-alpha.59
- 0.0.4-alpha.58
- 0.0.4-alpha.57
- 0.0.4-alpha.56
- 0.0.4-alpha.55
- 0.0.4-alpha.54
- 0.0.4-alpha.53
- 0.0.4-alpha.52
- 0.0.4-alpha.51
- 0.0.4-alpha.50
- 0.0.4-alpha.49
- 0.0.4-alpha.48
- 0.0.4-alpha.47
- 0.0.4-alpha.46
- 0.0.4-alpha.45
- 0.0.4-alpha.44
- 0.0.4-alpha.43
- 0.0.4-alpha.42
- 0.0.4-alpha.41
- 0.0.4-alpha.40
- 0.0.4-alpha.39
- 0.0.4-alpha.38
- 0.0.4-alpha.37
- 0.0.4-alpha.36
- 0.0.4-alpha.35
- 0.0.4-alpha.34
- 0.0.4-alpha.33
- 0.0.4-alpha.32
- 0.0.4-alpha.31
- 0.0.4-alpha.30
- 0.0.4-alpha.29
- 0.0.4-alpha.28
- 0.0.4-alpha.27
- 0.0.4-alpha.26
- 0.0.4-alpha.25
- 0.0.4-alpha.24
- 0.0.4-alpha.23
- 0.0.4-alpha.22
- 0.0.4-alpha.21
- 0.0.4-alpha.20
- 0.0.4-alpha.19
- 0.0.4-alpha.18
- 0.0.4-alpha.17
- 0.0.4-alpha.16
- 0.0.4-alpha.15
- 0.0.4-alpha.14
- 0.0.4-alpha.13
- 0.0.4-alpha.12
- 0.0.4-alpha.11
- 0.0.4-alpha.10
- 0.0.4-alpha.9
- 0.0.4-alpha.8
- 0.0.4-alpha.7
- 0.0.4-alpha.6
- 0.0.4-alpha.5
- 0.0.4-alpha.4
- 0.0.4-alpha.3
- 0.0.4-alpha.2
- 0.0.4-alpha.1
- 0.0.3
- 0.0.3-alpha.51
- 0.0.3-alpha.50
- 0.0.3-alpha.49
- 0.0.3-alpha.48
- 0.0.3-alpha.47
- 0.0.3-alpha.46
- 0.0.3-alpha.45
- 0.0.3-alpha.44
- 0.0.3-alpha.43
- 0.0.3-alpha.42
- 0.0.3-alpha.41
- 0.0.3-alpha.40
- 0.0.3-alpha.39
- 0.0.3-alpha.38
- 0.0.3-alpha.37
- 0.0.3-alpha.36
- 0.0.3-alpha.35
- 0.0.3-alpha.34
- 0.0.3-alpha.33
- 0.0.3-alpha.32
- 0.0.3-alpha.31
- 0.0.3-alpha.30
- 0.0.3-alpha.29
- 0.0.3-alpha.28
- 0.0.3-alpha.27
- 0.0.3-alpha.26
- 0.0.3-alpha.25
- 0.0.3-alpha.24
- 0.0.3-alpha.23
- 0.0.3-alpha.22
- 0.0.3-alpha.21
- 0.0.3-alpha.20
- 0.0.3-alpha.19
- 0.0.3-alpha.18
- 0.0.3-alpha.17
- 0.0.3-alpha.16
- 0.0.3-alpha.15
- 0.0.3-alpha.14
- 0.0.3-alpha.13
- 0.0.3-alpha.12
- 0.0.3-alpha.11
- 0.0.3-alpha.10
- 0.0.3-alpha.9
- 0.0.3-alpha.8
- 0.0.3-alpha.7
- 0.0.3-alpha.6
- 0.0.3-alpha.5
- 0.0.3-alpha.4
- 0.0.3-alpha.3
- 0.0.3-alpha.2
- 0.0.3-alpha.1
- 0.0.2
- 0.0.2-alpha.68
- 0.0.2-alpha.67
- 0.0.2-alpha.66
- 0.0.2-alpha.65
- 0.0.2-alpha.64
- 0.0.2-alpha.63
- 0.0.2-alpha.62
- 0.0.2-alpha.61
- 0.0.2-alpha.60
- 0.0.2-alpha.59
- 0.0.2-alpha.58
- 0.0.2-alpha.57
- 0.0.2-alpha.56
- 0.0.2-alpha.55
- 0.0.2-alpha.54
- 0.0.2-alpha.53
- 0.0.2-alpha.52
- 0.0.2-alpha.51
- 0.0.2-alpha.50
- 0.0.2-alpha.49
- 0.0.2-alpha.48
- 0.0.2-alpha.47
- 0.0.2-alpha.46
- 0.0.2-alpha.45
- 0.0.2-alpha.44
- 0.0.2-alpha.43
- 0.0.2-alpha.42
- 0.0.2-alpha.41
- 0.0.2-alpha.40
- 0.0.2-alpha.39
- 0.0.2-alpha.38
- 0.0.2-alpha.37
- 0.0.2-alpha.36
- 0.0.2-alpha.35
- 0.0.2-alpha.34
- 0.0.2-alpha.33
- 0.0.2-alpha.32
- 0.0.2-alpha.31
- 0.0.2-alpha.30
- 0.0.2-alpha.29
- 0.0.2-alpha.28
- 0.0.2-alpha.27
- 0.0.2-alpha.26
- 0.0.2-alpha.25
- 0.0.2-alpha.24
- 0.0.2-alpha.23
- 0.0.2-alpha.22
- 0.0.2-alpha.21
- 0.0.2-alpha.20
- 0.0.2-alpha.19
- 0.0.2-alpha.18
- 0.0.2-alpha.17
- 0.0.2-alpha.16
- 0.0.2-alpha.15
- 0.0.2-alpha.14
- 0.0.2-alpha.13
- 0.0.2-alpha.12
- 0.0.2-alpha.11
- 0.0.2-alpha.10
- 0.0.2-alpha.9
- 0.0.2-alpha.8
- 0.0.2-alpha.7
- 0.0.2-alpha.6
- 0.0.2-alpha.5
- 0.0.2-alpha.4
- 0.0.2-alpha.3
- 0.0.2-alpha.2
- 0.0.2-alpha.1
- 0.0.1
- 0.0.1-alpha.57
- 0.0.1-alpha.56
- 0.0.1-alpha.55
- 0.0.1-alpha.54
- 0.0.1-alpha.53
- 0.0.1-alpha.52
- 0.0.1-alpha.51
- 0.0.1-alpha.50
- 0.0.1-alpha.49
- 0.0.1-alpha.48
- 0.0.1-alpha.47
- 0.0.1-alpha.46
- 0.0.1-alpha.45
- 0.0.1-alpha.44
- 0.0.1-alpha.43
- 0.0.1-alpha.42
- 0.0.1-alpha.41
- 0.0.1-alpha.40
- 0.0.1-alpha.39
- 0.0.1-alpha.38
- 0.0.1-alpha.37
- 0.0.1-alpha.36
- 0.0.1-alpha.35
- 0.0.1-alpha.34
- 0.0.1-alpha.33
- 0.0.1-alpha.32
- 0.0.1-alpha.31
- 0.0.1-alpha.30
- 0.0.1-alpha.29
- 0.0.1-alpha.28
- 0.0.1-alpha.27
- 0.0.1-alpha.26
- 0.0.1-alpha.25
- 0.0.1-alpha.24
- 0.0.1-alpha.23
- 0.0.1-alpha.22
- 0.0.1-alpha.21
- 0.0.1-alpha.20
- 0.0.1-alpha.19
- 0.0.1-alpha.18
- 0.0.1-alpha.17
- 0.0.1-alpha.16
- 0.0.1-alpha.15
- 0.0.1-alpha.14
- 0.0.1-alpha.13
- 0.0.1-alpha.12
- 0.0.1-alpha.11
- 0.0.1-alpha.10
- 0.0.1-alpha.9
- 0.0.1-alpha.8
- 0.0.1-alpha.7
- 0.0.1-alpha.6
- 0.0.1-alpha.5
- 0.0.1-alpha.4
- 0.0.1-alpha.3
- 0.0.1-alpha.2
- 0.0.1-alpha.1
Showing
13 changed files
with
547 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
src/commonMain/kotlin/com/xebia/functional/embeddings/Embeddings.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package com.xebia.functional.embeddings | ||
|
||
import com.xebia.functional.llm.openai.RequestConfig | ||
|
||
data class Embedding(val data: List<Float>) | ||
|
||
interface Embeddings { | ||
suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding> | ||
suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> | ||
|
||
companion object | ||
} |
51 changes: 51 additions & 0 deletions
51
src/commonMain/kotlin/com/xebia/functional/embeddings/OpenAIEmbeddings.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package com.xebia.functional.embeddings | ||
|
||
import arrow.fx.coroutines.parMap | ||
import arrow.resilience.retry | ||
import com.xebia.functional.env.OpenAIConfig | ||
import com.xebia.functional.llm.openai.EmbeddingRequest | ||
import com.xebia.functional.llm.openai.OpenAIClient | ||
import com.xebia.functional.llm.openai.RequestConfig | ||
import io.github.oshai.KLogger | ||
import kotlin.time.ExperimentalTime | ||
|
||
@ExperimentalTime | ||
class OpenAIEmbeddings( | ||
private val config: OpenAIConfig, | ||
private val oaiClient: OpenAIClient, | ||
private val logger: KLogger | ||
) : Embeddings { | ||
|
||
override suspend fun embedDocuments( | ||
texts: List<String>, | ||
chunkSize: Int?, | ||
requestConfig: RequestConfig | ||
): List<Embedding> = | ||
chunkedEmbedDocuments(texts, chunkSize ?: config.chunkSize, requestConfig) | ||
|
||
override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> = | ||
if (text.isNotEmpty()) embedDocuments(listOf(text), null, requestConfig) else emptyList() | ||
|
||
private suspend fun chunkedEmbedDocuments( | ||
texts: List<String>, | ||
chunkSize: Int, | ||
requestConfig: RequestConfig | ||
): List<Embedding> = | ||
if (texts.isEmpty()) emptyList() | ||
else texts.chunked(chunkSize) | ||
.parMap { createEmbeddingWithRetry(it, requestConfig) } | ||
.flatten() | ||
|
||
private suspend fun createEmbeddingWithRetry(texts: List<String>, requestConfig: RequestConfig): List<Embedding> = | ||
kotlin.runCatching { | ||
config.retryConfig.schedule() | ||
.log { retriesSoFar, _ -> logger.warn { "Open AI call failed. So far we have retried $retriesSoFar times." } } | ||
.retry { | ||
oaiClient.createEmbeddings(EmbeddingRequest(requestConfig.model.name, texts, requestConfig.user.id)) | ||
.data.map { Embedding(it.embedding) } | ||
} | ||
}.getOrElse { | ||
logger.warn { "Open AI call failed. Giving up after ${config.retryConfig.maxRetries} retries" } | ||
throw it | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
package com.xebia.functional | ||
|
||
data class Document(val content: String) |
45 changes: 45 additions & 0 deletions
45
src/commonMain/kotlin/com/xebia/functional/vectorstores/VectorStore.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package com.xebia.functional.vectorstores | ||
|
||
import com.xebia.functional.Document | ||
import com.xebia.functional.embeddings.Embedding | ||
import kotlin.jvm.JvmInline | ||
import kotlinx.uuid.UUID | ||
|
||
@JvmInline | ||
value class DocumentVectorId(val id: UUID) | ||
|
||
interface VectorStore { | ||
/** | ||
* Add texts to the vector store after running them through the embeddings | ||
* | ||
* @param texts list of text to add to the vector store | ||
* @return a list of IDs from adding the texts to the vector store | ||
*/ | ||
suspend fun addTexts(texts: List<String>): List<DocumentVectorId> | ||
|
||
/** | ||
* Add documents to the vector store after running them through the embeddings | ||
* | ||
* @param documents list of Documents to add to the vector store | ||
* @return a list of IDs from adding the documents to the vector store | ||
*/ | ||
suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId> | ||
|
||
/** | ||
* Return the docs most similar to the query | ||
* | ||
* @param query text to use to search for similar documents | ||
* @param limit number of documents to return | ||
* @return a list of Documents most similar to query | ||
*/ | ||
suspend fun similaritySearch(query: String, limit: Int): List<Document> | ||
|
||
/** | ||
* Return the docs most similar to the embedding | ||
* | ||
* @param embedding embedding vector to use to search for similar documents | ||
* @param limit number of documents to return | ||
* @return list of Documents most similar to the embedding | ||
*/ | ||
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document> | ||
} |
76 changes: 76 additions & 0 deletions
76
src/commonMain/kotlin/com/xebia/functional/vectorstores/postgres.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package com.xebia.functional.vectorstores | ||
|
||
import kotlinx.uuid.UUID | ||
|
||
data class PGCollection(val uuid: UUID, val collectionName: String) | ||
|
||
enum class PGDistanceStrategy(val strategy: String) { | ||
Euclidean("<->"), InnerProduct("<#>"), CosineDistance("<=>") | ||
} | ||
|
||
val createCollections: String = | ||
"""CREATE TABLE langchain4k_collections ( | ||
uuid TEXT PRIMARY KEY, | ||
name TEXT UNIQUE NOT NULL | ||
);""".trimIndent() | ||
|
||
val createEmbeddings: String = | ||
"""CREATE TABLE langchain4k_embeddings ( | ||
uuid TEXT PRIMARY KEY, | ||
collection_id TEXT REFERENCES langchain4k_collections(uuid), | ||
embedding BLOB, | ||
content TEXT | ||
);""".trimIndent() | ||
|
||
val addVectorExtension: String = | ||
"CREATE EXTENSION IF NOT EXISTS vector;" | ||
|
||
val createCollectionsTable: String = | ||
"""CREATE TABLE IF NOT EXISTS langchain4k_collections ( | ||
uuid TEXT PRIMARY KEY, | ||
name TEXT UNIQUE NOT NULL | ||
);""".trimIndent() | ||
|
||
fun createEmbeddingTable(vectorSize: Int): String = | ||
"""CREATE TABLE IF NOT EXISTS langchain4k_embeddings ( | ||
uuid TEXT PRIMARY KEY, | ||
collection_id TEXT REFERENCES langchain4k_collections(uuid), | ||
embedding vector($vectorSize), | ||
content TEXT | ||
);""".trimIndent() | ||
|
||
val addNewCollection: String = | ||
"""INSERT INTO langchain4k_collections(uuid, name) | ||
VALUES (?, ?) | ||
ON CONFLICT DO NOTHING;""".trimIndent() | ||
|
||
val deleteCollection: String = | ||
"""DELETE FROM langchain4k_collections | ||
WHERE uuid = ?;""".trimIndent() | ||
|
||
val getCollection: String = | ||
"""SELECT * FROM langchain4k_collections | ||
WHERE name = ?;""".trimIndent() | ||
|
||
val getCollectionById: String = | ||
"""SELECT * FROM langchain4k_collections | ||
WHERE uuid = ?;""".trimIndent() | ||
|
||
val addNewDocument: String = | ||
"""INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content) | ||
VALUES (?, ?, ?, ?);""".trimIndent() | ||
|
||
val deleteCollectionDocs: String = | ||
"""DELETE FROM langchain4k_embeddings | ||
WHERE collection_id = ?;""".trimIndent() | ||
|
||
val addNewText: String = | ||
"""INSERT INTO langchain4k_embeddings(uuid, collection_id, embedding, content) | ||
VALUES (?, ?, ?::vector, ?);""".trimIndent() | ||
|
||
fun searchSimilarDocument(distance: PGDistanceStrategy): String = | ||
"""SELECT content FROM langchain4k_embeddings | ||
WHERE collection_id = ? | ||
ORDER BY embedding | ||
${distance.strategy} ?::vector | ||
LIMIT ?;""".trimIndent() |
23 changes: 23 additions & 0 deletions
23
src/commonTest/kotlin/com/xebia/functional/embeddings/Mock.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package com.xebia.functional.embeddings | ||
|
||
import com.xebia.functional.llm.openai.RequestConfig | ||
|
||
fun Embeddings.Companion.mock( | ||
embedDocuments: suspend (texts: List<String>, chunkSize: Int?, config: RequestConfig) -> List<Embedding> = { _, _, _ -> | ||
listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)), Embedding(listOf(4.0f, 5.0f, 6.0f))) | ||
}, | ||
embedQuery: suspend (text: String, config: RequestConfig) -> List<Embedding> = { text, _ -> | ||
when (text) { | ||
"foo" -> listOf(Embedding(listOf(1.0f, 2.0f, 3.0f))) | ||
"bar" -> listOf(Embedding(listOf(4.0f, 5.0f, 6.0f))) | ||
"baz" -> listOf() | ||
else -> listOf() | ||
} | ||
} | ||
): Embeddings = object : Embeddings { | ||
override suspend fun embedDocuments(texts: List<String>, chunkSize: Int?, requestConfig: RequestConfig): List<Embedding> = | ||
embedDocuments(texts, chunkSize, requestConfig) | ||
|
||
override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> = | ||
embedQuery(text, requestConfig) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
package com.xebia.functional | ||
|
||
import arrow.core.raise.NullableRaise | ||
import arrow.core.raise.nullable | ||
import arrow.fx.coroutines.ResourceScope | ||
import arrow.fx.coroutines.autoCloseable | ||
import arrow.fx.coroutines.resourceScope | ||
import java.sql.Connection | ||
import java.sql.PreparedStatement | ||
import java.sql.ResultSet | ||
import java.sql.Types | ||
import javax.sql.DataSource | ||
|
||
suspend fun <A> DataSource.connection(block: suspend JDBCSyntax.() -> A): A = | ||
resourceScope { | ||
val conn = autoCloseable { connection } | ||
JDBCSyntax(conn, this).block() | ||
} | ||
|
||
class JDBCSyntax(conn: Connection, resourceScope: ResourceScope) : ResourceScope by resourceScope, Connection by conn { | ||
|
||
suspend fun prepareStatement( | ||
sql: String, | ||
binders: (SqlPreparedStatement.() -> Unit)? = null | ||
): PreparedStatement = autoCloseable { | ||
prepareStatement(sql) | ||
.apply { if (binders != null) SqlPreparedStatement(this).binders() } | ||
} | ||
|
||
suspend fun update( | ||
sql: String, | ||
binders: (SqlPreparedStatement.() -> Unit)? = null, | ||
): Unit { | ||
val statement = prepareStatement(sql, binders) | ||
statement.executeUpdate() | ||
} | ||
|
||
suspend fun <A> queryOneOrNull( | ||
sql: String, | ||
binders: (SqlPreparedStatement.() -> Unit)? = null, | ||
mapper: NullableSqlCursor.() -> A | ||
): A? { | ||
val statement = prepareStatement(sql, binders) | ||
val rs = autoCloseable { statement.executeQuery() } | ||
return if (rs.next()) nullable { mapper(NullableSqlCursor(rs, this)) } | ||
else null | ||
} | ||
|
||
suspend fun <A> queryAsList( | ||
sql: String, | ||
binders: (SqlPreparedStatement.() -> Unit)? = null, | ||
mapper: NullableSqlCursor.() -> A? | ||
): List<A> { | ||
val statement = prepareStatement(sql, binders) | ||
val rs = autoCloseable { statement.executeQuery() } | ||
return buildList { | ||
while (rs.next()) { | ||
nullable { mapper(NullableSqlCursor(rs, this)) }?.let(::add) | ||
} | ||
} | ||
} | ||
|
||
class SqlPreparedStatement(private val preparedStatement: PreparedStatement) { | ||
private var index: Int = 1 | ||
|
||
fun bind(short: Short?): Unit = bind(short?.toLong()) | ||
fun bind(byte: Byte?): Unit = bind(byte?.toLong()) | ||
fun bind(int: Int?): Unit = bind(int?.toLong()) | ||
fun bind(char: Char?): Unit = bind(char?.toString()) | ||
|
||
fun bind(bytes: ByteArray?): Unit = | ||
if (bytes == null) preparedStatement.setNull(index++, Types.BLOB) | ||
else preparedStatement.setBytes(index++, bytes) | ||
|
||
fun bind(long: Long?): Unit = | ||
if (long == null) preparedStatement.setNull(index++, Types.INTEGER) | ||
else preparedStatement.setLong(index++, long) | ||
|
||
fun bind(double: Double?): Unit = | ||
if (double == null) preparedStatement.setNull(index++, Types.REAL) | ||
else preparedStatement.setDouble(index++, double) | ||
|
||
fun bind(string: String?): Unit = | ||
if (string == null) preparedStatement.setNull(index++, Types.VARCHAR) | ||
else preparedStatement.setString(index++, string) | ||
} | ||
|
||
class SqlCursor(private val resultSet: ResultSet) { | ||
private var index: Int = 1 | ||
fun int(): Int? = long()?.toInt() | ||
fun string(): String? = resultSet.getString(index++) | ||
fun bytes(): ByteArray? = resultSet.getBytes(index++) | ||
fun long(): Long? = resultSet.getLong(index++).takeUnless { resultSet.wasNull() } | ||
fun double(): Double? = resultSet.getDouble(index++).takeUnless { resultSet.wasNull() } | ||
fun nextRow(): Boolean = resultSet.next() | ||
} | ||
|
||
class NullableSqlCursor(private val resultSet: ResultSet, private val raise: NullableRaise) { | ||
private var index: Int = 1 | ||
fun int(): Int = long().toInt() | ||
fun string(): String = raise.ensureNotNull(resultSet.getString(index++)) | ||
fun bytes(): ByteArray = raise.ensureNotNull(resultSet.getBytes(index++)) | ||
fun long(): Long = raise.ensureNotNull(resultSet.getLong(index++).takeUnless { resultSet.wasNull() }) | ||
fun double(): Double = raise.ensureNotNull(resultSet.getDouble(index++).takeUnless { resultSet.wasNull() }) | ||
fun nextRow(): Boolean = resultSet.next() | ||
} | ||
} |
101 changes: 101 additions & 0 deletions
101
src/jvmMain/kotlin/com/xebia/functional/PGVectorStore.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package com.xebia.functional | ||
|
||
import com.xebia.functional.embeddings.Embedding | ||
import com.xebia.functional.embeddings.Embeddings | ||
import com.xebia.functional.llm.openai.RequestConfig | ||
import com.xebia.functional.vectorstores.DocumentVectorId | ||
import com.xebia.functional.vectorstores.PGCollection | ||
import com.xebia.functional.vectorstores.PGDistanceStrategy | ||
import com.xebia.functional.vectorstores.VectorStore | ||
import com.xebia.functional.vectorstores.addNewCollection | ||
import com.xebia.functional.vectorstores.addNewText | ||
import com.xebia.functional.vectorstores.addVectorExtension | ||
import com.xebia.functional.vectorstores.createCollectionsTable | ||
import com.xebia.functional.vectorstores.createEmbeddingTable | ||
import com.xebia.functional.vectorstores.deleteCollection | ||
import com.xebia.functional.vectorstores.deleteCollectionDocs | ||
import com.xebia.functional.vectorstores.getCollection | ||
import com.xebia.functional.vectorstores.searchSimilarDocument | ||
import javax.sql.DataSource | ||
import kotlinx.uuid.UUID | ||
import kotlinx.uuid.generateUUID | ||
|
||
class PGVectorStore( | ||
private val vectorSize: Int, | ||
private val dataSource: DataSource, | ||
private val embeddings: Embeddings, | ||
private val collectionName: String, | ||
private val distanceStrategy: PGDistanceStrategy, | ||
private val preDeleteCollection: Boolean, | ||
private val requestConfig: RequestConfig, | ||
private val chunckSize: Int? | ||
) : VectorStore { | ||
|
||
suspend fun JDBCSyntax.getCollection(collectionName: String): PGCollection = | ||
queryOneOrNull(getCollection, | ||
{ bind(collectionName) } | ||
) { PGCollection(UUID(string()), string()) } | ||
?: throw IllegalStateException("Collection '$collectionName' not found") | ||
|
||
suspend fun JDBCSyntax.deleteCollection() { | ||
if (preDeleteCollection) { | ||
val collection = getCollection(collectionName) | ||
update(deleteCollectionDocs) { bind(collection.uuid.toString()) } | ||
update(deleteCollection) { bind(collection.uuid.toString()) } | ||
} | ||
} | ||
|
||
suspend fun initialDbSetup(): Unit = dataSource.connection { | ||
update(addVectorExtension) | ||
update(createCollectionsTable) | ||
update(createEmbeddingTable(vectorSize)) | ||
deleteCollection() | ||
} | ||
|
||
suspend fun createCollection(): Unit = dataSource.connection { | ||
val xa = UUID.generateUUID() | ||
update(addNewCollection) { | ||
bind(xa.toString()) | ||
bind(collectionName) | ||
} | ||
} | ||
|
||
override suspend fun addTexts(texts: List<String>): List<DocumentVectorId> = dataSource.connection { | ||
val embeddings = embeddings.embedDocuments(texts, chunckSize, requestConfig) | ||
val collection = getCollection(collectionName) | ||
texts.zip(embeddings) { text, embedding -> | ||
val uuid = UUID.generateUUID() | ||
update(addNewText) { | ||
bind(uuid.toString()) | ||
bind(collection.uuid.toString()) | ||
bind(embedding.data.toString()) | ||
bind(text) | ||
} | ||
DocumentVectorId(uuid) | ||
} | ||
} | ||
|
||
override suspend fun addDocuments(documents: List<Document>): List<DocumentVectorId> = | ||
addTexts(documents.map(Document::content)) | ||
|
||
override suspend fun similaritySearch(query: String, limit: Int): List<Document> = dataSource.connection { | ||
val embeddings = embeddings.embedQuery(query, requestConfig) | ||
.ifEmpty { throw IllegalStateException("Embedding for text: '$query', has not been properly generated") } | ||
val collection = getCollection(collectionName) | ||
queryAsList(searchSimilarDocument(distanceStrategy), { | ||
bind(collection.uuid.toString()) | ||
bind(embeddings[0].data.toString()) | ||
bind(limit) | ||
}) { Document(string()) } | ||
} | ||
|
||
override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document> = | ||
dataSource.connection { | ||
val collection = getCollection(collectionName) | ||
queryAsList(searchSimilarDocument(distanceStrategy), { | ||
bind(collection.uuid.toString()) | ||
bind(embedding.data.toString()) | ||
bind(limit) | ||
}) { Document(string()) } | ||
} | ||
} |
91 changes: 91 additions & 0 deletions
91
src/jvmTest/kotlin/com/xebia/functional/PGVectorStoreSpec.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package com.xebia.functional | ||
|
||
import com.xebia.functional.embeddings.Embedding | ||
import com.xebia.functional.embeddings.Embeddings | ||
import com.xebia.functional.embeddings.mock | ||
import com.xebia.functional.llm.openai.EmbeddingModel | ||
import com.xebia.functional.llm.openai.RequestConfig | ||
import com.xebia.functional.vectorstores.PGDistanceStrategy | ||
import com.zaxxer.hikari.HikariConfig | ||
import com.zaxxer.hikari.HikariDataSource | ||
import io.kotest.core.extensions.install | ||
import io.kotest.core.spec.style.StringSpec | ||
import io.kotest.extensions.testcontainers.SharedTestContainerExtension | ||
import io.kotest.matchers.shouldBe | ||
import org.junit.jupiter.api.assertThrows | ||
import org.testcontainers.containers.PostgreSQLContainer | ||
import org.testcontainers.utility.DockerImageName | ||
|
||
val postgres: PostgreSQLContainer<Nothing> = | ||
PostgreSQLContainer(DockerImageName.parse("ankane/pgvector").asCompatibleSubstituteFor("postgres")) | ||
|
||
class PGVectorStoreSpec : StringSpec({ | ||
|
||
val container = install(SharedTestContainerExtension(postgres)) | ||
val dataSource = autoClose(HikariDataSource(HikariConfig().apply { | ||
jdbcUrl = container.jdbcUrl | ||
username = container.username | ||
password = container.password | ||
driverClassName = "org.postgresql.Driver" | ||
})) | ||
|
||
val pg = PGVectorStore( | ||
vectorSize = 3, | ||
dataSource = dataSource, | ||
embeddings = Embeddings.mock(), | ||
collectionName = "test_collection", | ||
distanceStrategy = PGDistanceStrategy.Euclidean, | ||
preDeleteCollection = false, | ||
requestConfig = RequestConfig(EmbeddingModel.TextEmbeddingAda002, RequestConfig.Companion.User("user")), | ||
chunckSize = null | ||
) | ||
|
||
"initialDbSetup should configure the DB properly" { | ||
pg.initialDbSetup() | ||
} | ||
|
||
"addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB" { | ||
assertThrows<IllegalStateException> { | ||
pg.addTexts(listOf("foo", "bar")) | ||
}.message shouldBe "Collection 'test_collection' not found" | ||
} | ||
|
||
"similaritySearch should fail with a CollectionNotFoundError if collection isn't present in the DB" { | ||
assertThrows<IllegalStateException> { | ||
pg.similaritySearch("foo", 2) | ||
}.message shouldBe "Collection 'test_collection' not found" | ||
} | ||
|
||
"createCollection should create collection" { | ||
pg.createCollection() | ||
} | ||
|
||
"addTexts should return a list of 2 elements" { | ||
pg.addTexts(listOf("foo", "bar")).size shouldBe 2 | ||
} | ||
|
||
"similaritySearchByVector should return both documents" { | ||
pg.similaritySearchByVector(Embedding(listOf(4.0f, 5.0f, 6.0f)), 2) shouldBe listOf( | ||
Document("bar"), | ||
Document("foo") | ||
) | ||
} | ||
|
||
"addDocuments should return a list of 2 elements" { | ||
pg.addDocuments(listOf(Document("foo"), Document("bar"))).size shouldBe 2 | ||
} | ||
|
||
"similaritySearch should return 2 documents" { | ||
pg.similaritySearch("foo", 2).size shouldBe 2 | ||
} | ||
|
||
"similaritySearch should fail when embedding vector is empty" { | ||
assertThrows<IllegalStateException> { | ||
pg.similaritySearch("baz", 2) | ||
}.message shouldBe "Embedding for text: 'baz', has not been properly generated" | ||
} | ||
|
||
"similaritySearchByVector should return document" { | ||
pg.similaritySearchByVector(Embedding(listOf(1.0f, 2.0f, 3.0f)), 1) shouldBe listOf(Document("foo")) | ||
} | ||
}) |