Skip to content

Commit 378db71

Browse files
committed
Merge branch 'main' into server-embeddings-endpoint
2 parents 383b218 + de40a7b commit 378db71

File tree

23 files changed

+93
-136
lines changed

23 files changed

+93
-136
lines changed

core/src/commonMain/kotlin/com/xebia/functional/xef/embeddings/Embeddings.kt

-17
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,28 @@
11
package com.xebia.functional.xef.llm
22

3+
import arrow.fx.coroutines.parMap
4+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
35
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
46
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
7+
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
58

69
interface Embeddings : LLM {
710
suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult
11+
12+
suspend fun embedDocuments(
13+
texts: List<String>,
14+
requestConfig: RequestConfig,
15+
chunkSize: Int?
16+
): List<Embedding> =
17+
if (texts.isEmpty()) emptyList()
18+
else
19+
texts
20+
.chunked(chunkSize ?: 400)
21+
.parMap { createEmbeddings(EmbeddingRequest(name, texts, requestConfig.user.id)).data }
22+
.flatten()
23+
24+
suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
25+
if (text.isNotEmpty()) embedDocuments(listOf(text), requestConfig, null) else emptyList()
26+
27+
companion object
828
}
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
package com.xebia.functional.xef.llm.models.embeddings
22

3-
class Embedding(val `object`: String, val embedding: List<Float>, val index: Int)
3+
class Embedding(val embedding: List<Float>)

core/src/commonMain/kotlin/com/xebia/functional/xef/store/CombinedVectorStore.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package com.xebia.functional.xef.store
22

3-
import com.xebia.functional.xef.embeddings.Embedding
3+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
44

55
/**
66
* A way of composing two [VectorStore] instances together, this class will **first search** [top],

core/src/commonMain/kotlin/com/xebia/functional/xef/store/LocalVectorStore.kt

+6-7
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ package com.xebia.functional.xef.store
33
import arrow.atomic.Atomic
44
import arrow.atomic.getAndUpdate
55
import arrow.atomic.update
6-
import com.xebia.functional.xef.embeddings.Embedding
7-
import com.xebia.functional.xef.embeddings.Embeddings
6+
import com.xebia.functional.xef.llm.Embeddings
7+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
88
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
99
import kotlin.math.sqrt
1010

@@ -54,8 +54,7 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
5454
}
5555

5656
override suspend fun addTexts(texts: List<String>) {
57-
val embeddingsList =
58-
embeddings.embedDocuments(texts, chunkSize = null, requestConfig = requestConfig)
57+
val embeddingsList = embeddings.embedDocuments(texts, requestConfig = requestConfig, null)
5958
state.getAndUpdate { prevState ->
6059
val newEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList)
6160
State(prevState.orderedMemories, prevState.documents + texts, newEmbeddings)
@@ -80,9 +79,9 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
8079
}
8180

8281
private fun Embedding.cosineSimilarity(other: Embedding): Double {
83-
val dotProduct = this.data.zip(other.data).sumOf { (a, b) -> (a * b).toDouble() }
84-
val magnitudeA = sqrt(this.data.sumOf { (it * it).toDouble() })
85-
val magnitudeB = sqrt(other.data.sumOf { (it * it).toDouble() })
82+
val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b).toDouble() }
83+
val magnitudeA = sqrt(this.embedding.sumOf { (it * it).toDouble() })
84+
val magnitudeB = sqrt(other.embedding.sumOf { (it * it).toDouble() })
8685
return dotProduct / (magnitudeA * magnitudeB)
8786
}
8887
}

core/src/commonMain/kotlin/com/xebia/functional/xef/store/VectorStore.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package com.xebia.functional.xef.store
22

3-
import com.xebia.functional.xef.embeddings.Embedding
3+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
44
import kotlin.jvm.JvmStatic
55

66
interface VectorStore {
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
package com.xebia.functional.xef.data
22

3-
import com.xebia.functional.xef.embeddings.Embedding
4-
import com.xebia.functional.xef.embeddings.Embeddings
3+
import com.xebia.functional.xef.llm.Embeddings
4+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
5+
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
6+
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
57
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
8+
import com.xebia.functional.xef.llm.models.usage.Usage
69

710
class TestEmbeddings : Embeddings {
11+
12+
override val name: String
13+
get() = "test-embeddings"
14+
815
override suspend fun embedDocuments(
916
texts: List<String>,
10-
chunkSize: Int?,
11-
requestConfig: RequestConfig
17+
requestConfig: RequestConfig,
18+
chunkSize: Int?
1219
): List<Embedding> = emptyList()
1320

1421
override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
1522
emptyList()
23+
24+
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
25+
EmbeddingResult(emptyList(), Usage.ZERO)
1626
}

examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/streaming/OpenAIStreamingExample.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ package com.xebia.functional.xef.conversation.streaming
22

33
import com.xebia.functional.xef.conversation.Conversation
44
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
5-
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
65
import com.xebia.functional.xef.llm.Chat
76
import com.xebia.functional.xef.prompt.Prompt
87
import com.xebia.functional.xef.store.LocalVectorStore
98

109
suspend fun main() {
1110
val chat: Chat = OpenAI().DEFAULT_CHAT
12-
val embeddings = OpenAIEmbeddings(OpenAI().DEFAULT_EMBEDDING)
11+
val embeddings = OpenAI().DEFAULT_EMBEDDING
1312
val scope = Conversation(LocalVectorStore(embeddings))
1413
chat.promptStreaming(prompt = Prompt("What is the meaning of life?"), scope = scope).collect {
1514
print(it)

examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/streaming/SpaceCraftLocal.kt

+1-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package com.xebia.functional.xef.conversation.streaming
22

33
import com.xebia.functional.xef.conversation.Conversation
44
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
5-
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
65
import com.xebia.functional.xef.llm.StreamedFunction
76
import com.xebia.functional.xef.prompt.Prompt
87
import com.xebia.functional.xef.store.LocalVectorStore
@@ -17,8 +16,7 @@ suspend fun main() {
1716

1817
val model = OpenAI(host = "http://localhost:8081/").DEFAULT_SERIALIZATION
1918

20-
val scope =
21-
Conversation(LocalVectorStore(OpenAIEmbeddings(OpenAI.FromEnvironment.DEFAULT_EMBEDDING)))
19+
val scope = Conversation(LocalVectorStore(OpenAI.FromEnvironment.DEFAULT_EMBEDDING))
2220

2321
model
2422
.promptStreaming(

gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/Conversation.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ import com.xebia.functional.xef.store.LocalVectorStore
55
import com.xebia.functional.xef.store.VectorStore
66

77
suspend inline fun <A> conversation(
8-
store: VectorStore = LocalVectorStore(HuggingFaceLocalEmbeddings.DEFAULT),
9-
noinline block: suspend Conversation.() -> A
8+
store: VectorStore = LocalVectorStore(HuggingFaceLocalEmbeddings.DEFAULT),
9+
noinline block: suspend Conversation.() -> A
1010
): A = block(Conversation(store))

gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt

+7-11
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
package com.xebia.functional.gpt4all
22

33
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
4-
import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding
5-
import com.xebia.functional.xef.embeddings.Embeddings
4+
import com.xebia.functional.xef.llm.Embeddings
65
import com.xebia.functional.xef.llm.models.embeddings.Embedding
76
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
87
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
98
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
109
import com.xebia.functional.xef.llm.models.usage.Usage
1110

12-
class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.functional.xef.llm.Embeddings, Embeddings {
11+
class HuggingFaceLocalEmbeddings(name: String, artifact: String) : Embeddings {
1312

1413
private val tokenizer = HuggingFaceTokenizer.newInstance("$name/$artifact")
1514

@@ -18,20 +17,17 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
1817
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
1918
val embedings = tokenizer.batchEncode(request.input)
2019
return EmbeddingResult(
21-
data = embedings.mapIndexed { n, em -> Embedding("embedding", em.ids.map { it.toFloat() }, n) },
20+
data = embedings.map { Embedding(it.ids.map { it.toFloat() }) },
2221
usage = Usage.ZERO
2322
)
2423
}
2524

2625
override suspend fun embedDocuments(
2726
texts: List<String>,
28-
chunkSize: Int?,
29-
requestConfig: RequestConfig
30-
): List<XefEmbedding> =
31-
tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) }
32-
33-
override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<XefEmbedding> =
34-
embedDocuments(listOf(text), null, requestConfig)
27+
requestConfig: RequestConfig,
28+
chunkSize: Int?
29+
): List<Embedding> =
30+
tokenizer.batchEncode(texts).map { em -> Embedding(em.ids.map { it.toFloat() }) } // TODO we need to remove the index
3531

3632
companion object {
3733
@JvmField

integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GCP.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: St
6161

6262
@JvmSynthetic
6363
suspend fun <A> conversation(block: suspend Conversation.() -> A): A =
64-
block(conversation(LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))))
64+
block(conversation(LocalVectorStore(FromEnvironment.DEFAULT_EMBEDDING)))
6565

6666
@JvmStatic
6767
@JvmOverloads
6868
fun conversation(
69-
store: VectorStore = LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))
69+
store: VectorStore = LocalVectorStore(FromEnvironment.DEFAULT_EMBEDDING)
7070
): PlatformConversation = Conversation(store)
7171
}
7272
}
7373

7474
suspend inline fun <A> GCP.conversation(noinline block: suspend Conversation.() -> A): A =
75-
block(Conversation(LocalVectorStore(GcpEmbeddings(DEFAULT_EMBEDDING))))
75+
block(Conversation(LocalVectorStore(DEFAULT_EMBEDDING)))

integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpEmbeddings.kt

-27
This file was deleted.

integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpModel.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,12 @@ class GcpModel(modelId: String, config: GcpConfig) : Chat, Completion, AutoClose
9595
}
9696

9797
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
98-
fun requestToEmbedding(index: Int, it: GcpClient.EmbeddingPredictions): Embedding =
99-
Embedding("embedding", it.embeddings.values.map(Double::toFloat), index = index)
98+
fun requestToEmbedding(it: GcpClient.EmbeddingPredictions): Embedding =
99+
Embedding(it.embeddings.values.map(Double::toFloat))
100100

101101
val response = client.embeddings(request)
102102
return EmbeddingResult(
103-
data = response.predictions.mapIndexed(::requestToEmbedding),
103+
data = response.predictions.map(::requestToEmbedding),
104104
usage = usage(response),
105105
)
106106
}

integrations/lucene/src/main/kotlin/com/xebia/functional/xef/store/Lucene.kt

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.xebia.functional.xef.store
22

3-
import com.xebia.functional.xef.embeddings.Embedding
4-
import com.xebia.functional.xef.embeddings.Embeddings
3+
import com.xebia.functional.xef.llm.Embeddings
54
import com.xebia.functional.xef.llm.models.chat.Message
65
import com.xebia.functional.xef.llm.models.chat.Role
6+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
77
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
88
import org.apache.lucene.analysis.standard.StandardAnalyzer
99
import org.apache.lucene.document.Document
@@ -89,7 +89,7 @@ open class Lucene(
8989

9090
override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
9191
requireNotNull(embeddings) { "no embeddings were computed for this model" }
92-
val luceneQuery = KnnFloatVectorQuery("embedding", embedding.data.toFloatArray(), limit)
92+
val luceneQuery = KnnFloatVectorQuery("embedding", embedding.embedding.toFloatArray(), limit)
9393
val searcher = IndexSearcher(DirectoryReader.open(writer))
9494
return searcher.search(luceneQuery, limit).extract(searcher)
9595
}
@@ -150,4 +150,4 @@ fun InMemoryLuceneBuilder(
150150
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity)
151151
}
152152

153-
fun List<Embedding>.toFloatArray(): FloatArray = flatMap { it.data }.toFloatArray()
153+
fun List<Embedding>.toFloatArray(): FloatArray = flatMap { it.embedding }.toFloatArray()

integrations/postgresql/src/main/kotlin/com/xebia/functional/xef/store/PostgreSQLVectorStore.kt

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.xebia.functional.xef.store
22

3-
import com.xebia.functional.xef.embeddings.Embedding
4-
import com.xebia.functional.xef.embeddings.Embeddings
3+
import com.xebia.functional.xef.llm.Embeddings
54
import com.xebia.functional.xef.llm.models.chat.Message
65
import com.xebia.functional.xef.llm.models.chat.Role
6+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
77
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
88
import com.xebia.functional.xef.store.postgresql.*
99
import kotlinx.uuid.UUID
@@ -95,14 +95,14 @@ class PGVectorStore(
9595

9696
override suspend fun addTexts(texts: List<String>): Unit =
9797
dataSource.connection {
98-
val embeddings = embeddings.embedDocuments(texts, chunkSize, requestConfig)
98+
val embeddings = embeddings.embedDocuments(texts, requestConfig, chunkSize)
9999
val collection = getCollection(collectionName)
100100
texts.zip(embeddings) { text, embedding ->
101101
val uuid = UUID.generateUUID()
102102
update(addNewText) {
103103
bind(uuid.toString())
104104
bind(collection.uuid.toString())
105-
bind(embedding.data.toString())
105+
bind(embedding.embedding.toString())
106106
bind(text)
107107
}
108108
}
@@ -121,7 +121,7 @@ class PGVectorStore(
121121
searchSimilarDocument(distanceStrategy),
122122
{
123123
bind(collection.uuid.toString())
124-
bind(embeddings[0].data.toString())
124+
bind(embeddings[0].embedding.toString())
125125
bind(limit)
126126
}
127127
) {
@@ -136,7 +136,7 @@ class PGVectorStore(
136136
searchSimilarDocument(distanceStrategy),
137137
{
138138
bind(collection.uuid.toString())
139-
bind(embedding.data.toString())
139+
bind(embedding.embedding.toString())
140140
bind(limit)
141141
}
142142
) {

integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package xef
22

3-
import com.xebia.functional.xef.embeddings.Embedding
4-
import com.xebia.functional.xef.embeddings.Embeddings
3+
import com.xebia.functional.xef.llm.Embeddings
4+
import com.xebia.functional.xef.llm.models.embeddings.Embedding
55
import com.xebia.functional.xef.llm.models.chat.Message
66
import com.xebia.functional.xef.llm.models.chat.Role
7+
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
8+
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
79
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
810
import com.xebia.functional.xef.store.ConversationId
911
import com.xebia.functional.xef.store.Memory
@@ -123,10 +125,17 @@ private fun Embeddings.Companion.mock(
123125
object : Embeddings {
124126
override suspend fun embedDocuments(
125127
texts: List<String>,
126-
chunkSize: Int?,
127-
requestConfig: RequestConfig
128-
): List<Embedding> = embedDocuments(texts, chunkSize, requestConfig)
128+
requestConfig: RequestConfig,
129+
chunkSize: Int?
130+
): List<Embedding> = embedDocuments(texts, requestConfig, chunkSize)
129131

130132
override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
131133
embedQuery(text, requestConfig)
134+
135+
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
136+
createEmbeddings(request)
137+
138+
139+
override val name: String
140+
get() = "embeddings"
132141
}

kotlin/src/commonMain/kotlin/com/xebia/functional/xef/conversation/DSLExtensions.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package com.xebia.functional.xef.conversation
22

3-
import com.xebia.functional.xef.embeddings.Embeddings
3+
import com.xebia.functional.xef.llm.Embeddings
44
import com.xebia.functional.xef.store.LocalVectorStore
55
import com.xebia.functional.xef.store.VectorStore
66

0 commit comments

Comments
 (0)