1
1
package com.xebia.functional.gpt4all
2
2
3
3
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
4
- import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding
5
- import com.xebia.functional.xef.embeddings.Embeddings
4
+ import com.xebia.functional.xef.llm.Embeddings
6
5
import com.xebia.functional.xef.llm.models.embeddings.Embedding
7
6
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
8
7
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
9
8
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
10
9
import com.xebia.functional.xef.llm.models.usage.Usage
11
10
12
- class HuggingFaceLocalEmbeddings (name : String , artifact : String ) : com.xebia.functional.xef.llm.Embeddings, Embeddings {
11
+ class HuggingFaceLocalEmbeddings (name : String , artifact : String ) : Embeddings {
13
12
14
13
private val tokenizer = HuggingFaceTokenizer .newInstance(" $name /$artifact " )
15
14
@@ -18,20 +17,17 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
18
17
override suspend fun createEmbeddings (request : EmbeddingRequest ): EmbeddingResult {
19
18
val embedings = tokenizer.batchEncode(request.input)
20
19
return EmbeddingResult (
21
- data = embedings.mapIndexed { n, em -> Embedding (" embedding " , em .ids.map { it.toFloat() }, n ) },
20
+ data = embedings.map { Embedding (it .ids.map { it.toFloat() }) },
22
21
usage = Usage .ZERO
23
22
)
24
23
}
25
24
26
25
override suspend fun embedDocuments (
27
26
texts : List <String >,
28
- chunkSize : Int? ,
29
- requestConfig : RequestConfig
30
- ): List <XefEmbedding > =
31
- tokenizer.batchEncode(texts).map { em -> XefEmbedding (em.ids.map { it.toFloat() }) }
32
-
33
- override suspend fun embedQuery (text : String , requestConfig : RequestConfig ): List <XefEmbedding > =
34
- embedDocuments(listOf (text), null , requestConfig)
27
+ requestConfig : RequestConfig ,
28
+ chunkSize : Int?
29
+ ): List <Embedding > =
30
+ tokenizer.batchEncode(texts).map { em -> Embedding (em.ids.map { it.toFloat() }) } // TODO we need to remove the index
35
31
36
32
companion object {
37
33
@JvmField
0 commit comments