diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
index 4e26923dd..e97b36518 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/LLM.kt
@@ -12,6 +12,12 @@ sealed interface LLM : AutoCloseable {
   val name
     get() = modelType.name
 
+  /**
+   * Copies this instance and uses [modelType] for [LLM.modelType]. Has to return the most specific
+   * type of this instance!
+   */
+  fun copy(modelType: ModelType): LLM
+
   fun tokensFromMessages(
     messages: List<Message>
   ): Int { // TODO: naive implementation with magic numbers
diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestEmbeddings.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestEmbeddings.kt
index f4e81d7d0..2cc0d7989 100644
--- a/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestEmbeddings.kt
+++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestEmbeddings.kt
@@ -12,6 +12,8 @@ class TestEmbeddings : Embeddings {
 
   override val modelType: ModelType = ModelType.TODO("test-embeddings")
 
+  override fun copy(modelType: ModelType) = TestEmbeddings()
+
   override suspend fun embedDocuments(
     texts: List<String>,
     requestConfig: RequestConfig,
diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestFunctionsModel.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestFunctionsModel.kt
index 89a3a666a..eb04ff174 100644
--- a/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestFunctionsModel.kt
+++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestFunctionsModel.kt
@@ -18,6 +18,8 @@ class TestFunctionsModel(
 
   var requests: MutableList<ChatCompletionRequest> = mutableListOf()
 
+  override fun copy(modelType: ModelType) = TestFunctionsModel(modelType, responses)
+
   override fun tokensFromMessages(messages: List<Message>): Int {
     return messages.sumOf { it.content.length }
   }
diff --git a/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestModel.kt b/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestModel.kt
index c3efc8d4b..39d25569e 100644
--- a/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestModel.kt
+++ b/core/src/commonTest/kotlin/com/xebia/functional/xef/data/TestModel.kt
@@ -16,6 +16,8 @@ class TestModel(
 
   var requests: MutableList<ChatCompletionRequest> = mutableListOf()
 
+  override fun copy(modelType: ModelType) = TestModel(modelType, responses)
+
   override suspend fun createChatCompletion(
     request: ChatCompletionRequest
   ): ChatCompletionResponse {
diff --git a/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/finetuning/FineTunedModelChat.kt b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/finetuning/FineTunedModelChat.kt
new file mode 100644
index 000000000..010a43d39
--- /dev/null
+++ b/examples/kotlin/src/main/kotlin/com/xebia/functional/xef/conversation/finetuning/FineTunedModelChat.kt
@@ -0,0 +1,23 @@
+package com.xebia.functional.xef.conversation.finetuning
+
+import com.xebia.functional.xef.conversation.llm.openai.OpenAI
+import com.xebia.functional.xef.env.getenv
+import com.xebia.functional.xef.prompt.Prompt
+
+suspend fun main() {
+  val spawnModelId =
+    getenv("OPENAI_FINE_TUNED_MODEL_ID")
+      ?: error("Please set the OPENAI_FINE_TUNED_MODEL_ID environment variable.")
+
+  val OAI = OpenAI()
+  val model = OAI.spawnModel(spawnModelId, OAI.GPT_3_5_TURBO)
+  OpenAI.conversation {
+    while (true) {
+      print("> ")
+      val question = readlnOrNull() ?: break
+      val answer = model.promptStreaming(Prompt(question), this)
+      answer.collect(::print)
+      println()
+    }
+  }
+}
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
index 8462c9126..e24fb2c37 100644
--- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
@@ -66,6 +66,9 @@ interface GPT4All : AutoCloseable, Chat, Completion {
 
       val llModel = LLModel(path)
 
+      override fun copy(modelType: ModelType) =
+        GPT4All(url, path)
+
       override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
         with(request) {
           val config = LLModel.config()
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt
index a7a7a7ab4..679294158 100644
--- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/HuggingFaceLocalEmbeddings.kt
@@ -11,13 +11,16 @@ import com.xebia.functional.xef.llm.models.usage.Usage
 
 class HuggingFaceLocalEmbeddings(
   override val modelType: ModelType,
-  artifact: String,
+  private val artifact: String,
 ) : Embeddings {
 
   private val tokenizer = HuggingFaceTokenizer.newInstance("${modelType.name}/$artifact")
 
   override val name: String = HuggingFaceLocalEmbeddings::class.java.canonicalName
 
+  override fun copy(modelType: ModelType) =
+    HuggingFaceLocalEmbeddings(modelType, artifact)
+
   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
     val embedings = tokenizer.batchEncode(request.input)
     return EmbeddingResult(
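The test doubles and the GPT4All/Hugging Face implementations above all satisfy the new `copy` contract by returning their own concrete type rather than `LLM`. A standalone sketch (not xef code, names are illustrative) of why the covariant return matters:

```kotlin
// Standalone sketch (not xef code): why copy must return the most specific type.
interface Model {
  val name: String
  fun copy(name: String): Model // contract: implementations return their own type
}

class ChatModel(override val name: String) : Model {
  // Covariant override narrows the return type from Model to ChatModel,
  // so code that copies a ChatModel keeps access to chat-specific members.
  override fun copy(name: String): ChatModel = ChatModel(name)

  fun chatOnlyFeature() = "streaming for $name"
}

fun main() {
  val fineTuned: ChatModel = ChatModel("gpt-x").copy("ft:gpt-x:custom") // no cast needed
  println(fineTuned.chatOnlyFeature())
}
```

Because the override narrows the return type, callers never need a downcast after copying, which is what `spawnModel` relies on further down in this diff.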
diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GCP.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GCP.kt
index 5b5934ef6..81ef15a23 100644
--- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GCP.kt
+++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GCP.kt
@@ -47,10 +47,8 @@ class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: String? = null
 
   val defaultClient = GcpClient(config)
 
-  val CODECHAT by lazy { GcpChat(ModelType.TODO("codechat-bison@001"), defaultClient) }
-  val TEXT_EMBEDDING_GECKO by lazy {
-    GcpEmbeddings(ModelType.TODO("textembedding-gecko"), defaultClient)
-  }
+  val CODECHAT by lazy { GcpChat(this, ModelType.TODO("codechat-bison@001")) }
+  val TEXT_EMBEDDING_GECKO by lazy { GcpEmbeddings(this, ModelType.TODO("textembedding-gecko")) }
 
   @JvmField val DEFAULT_CHAT = CODECHAT
   @JvmField val DEFAULT_EMBEDDING = TEXT_EMBEDDING_GECKO
diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt
index f4c611dbc..221ef095e 100644
--- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt
+++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt
@@ -79,7 +79,7 @@ class GcpClient(
       )
     val response =
       http.post(
-        "https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
+        "https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/publishers/google/models/$modelId:predict"
       ) {
         header("Authorization", "Bearer ${config.token}")
         contentType(ContentType.Application.Json)
diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpChat.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpChat.kt
index a77746db2..f805597ab 100644
--- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpChat.kt
+++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpChat.kt
@@ -1,7 +1,7 @@
 package com.xebia.functional.xef.gcp.models
 
 import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.xef.gcp.GcpClient
+import com.xebia.functional.xef.gcp.GCP
 import com.xebia.functional.xef.llm.Chat
 import com.xebia.functional.xef.llm.models.chat.*
 import com.xebia.functional.xef.llm.models.usage.Usage
@@ -13,10 +13,14 @@ import kotlinx.uuid.UUID
 import kotlinx.uuid.generateUUID
 
 class GcpChat(
+  private val provider: GCP, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: GcpClient,
 ) : Chat {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = GcpChat(provider, modelType)
+
   override suspend fun createChatCompletion(
     request: ChatCompletionRequest
   ): ChatCompletionResponse {
diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpCompletion.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpCompletion.kt
index 3e1a36727..db9b3e7f8 100644
--- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpCompletion.kt
+++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpCompletion.kt
@@ -1,7 +1,7 @@
 package com.xebia.functional.xef.gcp.models
 
 import com.xebia.functional.tokenizer.ModelType
-import com.xebia.functional.xef.gcp.GcpClient
+import com.xebia.functional.xef.gcp.GCP
 import com.xebia.functional.xef.llm.Completion
 import com.xebia.functional.xef.llm.models.text.CompletionChoice
 import com.xebia.functional.xef.llm.models.text.CompletionRequest
@@ -12,10 +12,14 @@ import kotlinx.uuid.UUID
 import kotlinx.uuid.generateUUID
 
 class GcpCompletion(
+  private val provider: GCP, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: GcpClient,
 ) : Completion {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = GcpCompletion(provider, modelType)
+
   override suspend fun createCompletion(request: CompletionRequest): CompletionResult {
     val response: String =
       client.promptMessage(
diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpEmbeddings.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpEmbeddings.kt
index e9472f7a5..d46015694 100644
--- a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpEmbeddings.kt
+++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/models/GcpEmbeddings.kt
@@ -1,6 +1,7 @@
 package com.xebia.functional.xef.gcp.models
 
 import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.gcp.GCP
 import com.xebia.functional.xef.gcp.GcpClient
 import com.xebia.functional.xef.llm.Embeddings
 import com.xebia.functional.xef.llm.models.embeddings.Embedding
@@ -9,10 +10,14 @@ import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
 import com.xebia.functional.xef.llm.models.usage.Usage
 
 class GcpEmbeddings(
+  private val provider: GCP, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: GcpClient,
 ) : Embeddings {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = GcpEmbeddings(provider, modelType)
+
   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
     fun requestToEmbedding(it: GcpClient.EmbeddingPredictions): Embedding =
       Embedding(it.embeddings.values.map(Double::toFloat))
diff --git a/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt b/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
index d63bac7a5..8e95d55d1 100644
--- a/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
+++ b/integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
@@ -110,6 +110,9 @@ class PGVectorStoreSpec :
   })
 
 class TestLLM(override val modelType: ModelType = ModelType.ADA) : Chat, AutoCloseable {
+  override fun copy(modelType: ModelType) =
+    TestLLM(modelType)
+
   override fun tokensFromMessages(messages: List<Message>): Int =
     messages.map { calculateTokens(it) }.sum()
 
   private fun calculateTokens(message: Message): Int = message.content.split(" ").size + 2 // 2 is the role and name
@@ -145,6 +148,9 @@ private fun Embeddings.Companion.mock(
     }
 ): Embeddings =
   object : Embeddings {
+    override fun copy(modelType: ModelType): LLM {
+      throw NotImplementedError()
+    }
     override suspend fun embedDocuments(
       texts: List<String>,
       requestConfig: RequestConfig,
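The GCP models now receive the provider itself and derive the HTTP client from it, which is what makes the generic `copy(modelType)` workable: a model can be re-created against the same provider with only the model type swapped. A minimal standalone sketch of this pattern, with hypothetical names standing in for the real client:

```kotlin
// Standalone sketch of the provider-backed pattern (hypothetical names, not xef code).
class Provider(val defaultClient: String /* stands in for a real HTTP client */)

class ProviderChat(
  private val provider: Provider, // kept so copy() can rebuild against the same provider
  val modelType: String,
) {
  private val client = provider.defaultClient // derived from the provider, not injected

  // Only the model type varies; the provider (and thus client and credentials) is preserved.
  fun copy(modelType: String) = ProviderChat(provider, modelType)

  fun describe() = "$modelType via $client"
}

fun main() {
  val gcp = Provider("gcp-client(us-central1)")
  val chat = ProviderChat(gcp, "codechat-bison@001")
  println(chat.copy("codechat-bison@002").describe())
}
```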
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAI.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAI.kt
index 17fd7311f..ce34ca426 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAI.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/OpenAI.kt
@@ -1,7 +1,9 @@
 package com.xebia.functional.xef.conversation.llm.openai
 
 import arrow.core.nonEmptyListOf
+import com.aallam.openai.api.exception.InvalidRequestException
 import com.aallam.openai.api.logging.LogLevel
+import com.aallam.openai.api.model.ModelId
 import com.aallam.openai.client.LoggingConfig
 import com.aallam.openai.client.OpenAI as OpenAIClient
 import com.aallam.openai.client.OpenAIHost
@@ -57,7 +59,7 @@ class OpenAI(internal var token: String? = null, internal var host: String? = null
     }
   }
 
-  val defaultClient =
+  internal val defaultClient =
     OpenAIClient(
         host = getHost()?.let { OpenAIHost(it) } ?: OpenAIHost.OpenAI,
         token = getToken(),
@@ -66,51 +68,43 @@ class OpenAI(internal var token: String? = null, internal var host: String? = null
       )
       .let { autoClose(it) }
 
-  val GPT_4 by lazy { autoClose(OpenAIChat(ModelType.GPT_4, defaultClient)) }
+  val GPT_4 by lazy { autoClose(OpenAIChat(this, ModelType.GPT_4)) }
 
   val GPT_4_0314 by lazy {
-    autoClose(OpenAIFunChat(ModelType.GPT_4_0314, defaultClient)) // legacy
+    autoClose(OpenAIFunChat(this, ModelType.GPT_4_0314)) // legacy
   }
 
-  val GPT_4_32K by lazy { autoClose(OpenAIChat(ModelType.GPT_4_32K, defaultClient)) }
+  val GPT_4_32K by lazy { autoClose(OpenAIChat(this, ModelType.GPT_4_32K)) }
 
-  val GPT_3_5_TURBO by lazy { autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, defaultClient)) }
+  val GPT_3_5_TURBO by lazy { autoClose(OpenAIChat(this, ModelType.GPT_3_5_TURBO)) }
 
-  val GPT_3_5_TURBO_16K by lazy {
-    autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO_16_K, defaultClient))
-  }
+  val GPT_3_5_TURBO_16K by lazy { autoClose(OpenAIChat(this, ModelType.GPT_3_5_TURBO_16_K)) }
 
   val GPT_3_5_TURBO_FUNCTIONS by lazy {
-    autoClose(OpenAIFunChat(ModelType.GPT_3_5_TURBO_FUNCTIONS, defaultClient))
+    autoClose(OpenAIFunChat(this, ModelType.GPT_3_5_TURBO_FUNCTIONS))
   }
 
   val GPT_3_5_TURBO_0301 by lazy {
-    autoClose(OpenAIChat(ModelType.GPT_3_5_TURBO, defaultClient)) // legacy
+    autoClose(OpenAIChat(this, ModelType.GPT_3_5_TURBO)) // legacy
   }
 
-  val TEXT_DAVINCI_003 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_003, defaultClient))
-  }
+  val TEXT_DAVINCI_003 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_DAVINCI_003)) }
 
-  val TEXT_DAVINCI_002 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_DAVINCI_002, defaultClient))
-  }
+  val TEXT_DAVINCI_002 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_DAVINCI_002)) }
 
   val TEXT_CURIE_001 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_SIMILARITY_CURIE_001, defaultClient))
+    autoClose(OpenAICompletion(this, ModelType.TEXT_SIMILARITY_CURIE_001))
   }
 
-  val TEXT_BABBAGE_001 by lazy {
-    autoClose(OpenAICompletion(ModelType.TEXT_BABBAGE_001, defaultClient))
-  }
+  val TEXT_BABBAGE_001 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_BABBAGE_001)) }
 
-  val TEXT_ADA_001 by lazy { autoClose(OpenAICompletion(ModelType.TEXT_ADA_001, defaultClient)) }
+  val TEXT_ADA_001 by lazy { autoClose(OpenAICompletion(this, ModelType.TEXT_ADA_001)) }
 
   val TEXT_EMBEDDING_ADA_002 by lazy {
-    autoClose(OpenAIEmbeddings(ModelType.TEXT_EMBEDDING_ADA_002, defaultClient))
+    autoClose(OpenAIEmbeddings(this, ModelType.TEXT_EMBEDDING_ADA_002))
   }
 
-  val DALLE_2 by lazy { autoClose(OpenAIImages(ModelType.GPT_3_5_TURBO, defaultClient)) }
+  val DALLE_2 by lazy { autoClose(OpenAIImages(this, ModelType.GPT_3_5_TURBO)) }
 
   @JvmField val DEFAULT_CHAT = GPT_3_5_TURBO_16K
 
@@ -120,8 +114,8 @@ class OpenAI(internal var token: String? = null, internal var host: String? = null
   @JvmField val DEFAULT_IMAGES = DALLE_2
 
-  fun supportedModels(): List<LLM> =
-    listOf(
+  fun supportedModels(): List<LLM> = // TODO: impl of abstract provider function
+    listOf(
       GPT_4,
       GPT_4_0314,
       GPT_4_32K,
@@ -138,6 +132,28 @@ class OpenAI(internal var token: String? = null, internal var host: String? = null
       DALLE_2,
     )
 
+  suspend fun findModel(modelId: String): Any? { // TODO: impl of abstract provider function
+    val model =
+      try {
+        defaultClient.model(ModelId(modelId))
+      } catch (e: InvalidRequestException) {
+        when (e.error.detail?.code) {
+          "model_not_found" -> return null
+          else -> throw e
+        }
+      }
+    return ModelType.TODO(model.id.id)
+  }
+
+  suspend fun <T : LLM> spawnModel(
+    modelId: String,
+    baseModel: T
+  ): T { // TODO: impl of abstract provider function
+    if (findModel(modelId) == null) error("model not found")
+    return baseModel.copy(ModelType.FineTunedModel(modelId, baseModel = baseModel.modelType)) as? T
+      ?: error("${baseModel::class} does not follow contract to return the most specific type")
+  }
+
   companion object {
 
     @JvmField val FromEnvironment: OpenAI = OpenAI()
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIChat.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIChat.kt
index 7b6cc182a..47d559d5f 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIChat.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIChat.kt
@@ -4,8 +4,8 @@ import com.aallam.openai.api.chat.ChatChoice
 import com.aallam.openai.api.chat.ChatMessage
 import com.aallam.openai.api.chat.chatCompletionRequest
 import com.aallam.openai.api.model.ModelId
-import com.aallam.openai.client.OpenAI
 import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.conversation.llm.openai.OpenAI
 import com.xebia.functional.xef.conversation.llm.openai.toInternal
 import com.xebia.functional.xef.conversation.llm.openai.toOpenAI
 import com.xebia.functional.xef.llm.Chat
@@ -14,10 +14,14 @@ import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.map
 
 class OpenAIChat(
+  private val provider: OpenAI, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: OpenAI,
 ) : Chat {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = OpenAIChat(provider, modelType)
+
   override suspend fun createChatCompletion(
     request: ChatCompletionRequest
   ): ChatCompletionResponse {
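`findModel` and `spawnModel` together let callers bind a fine-tuned model id to an existing base-model wrapper, as the `FineTunedModelChat` example earlier in this diff demonstrates. A condensed usage sketch; the fine-tuned id below is a placeholder, not a real model:

```kotlin
import com.xebia.functional.xef.conversation.llm.openai.OpenAI

suspend fun main() {
  val openAI = OpenAI()
  // spawnModel first calls findModel, so unknown ids fail fast with an error.
  // On success it copies GPT_3_5_TURBO with a ModelType.FineTunedModel that
  // wraps the base model's tokenizer settings.
  val model = openAI.spawnModel("ft:gpt-3.5-turbo:my-org::abc123", openAI.GPT_3_5_TURBO)
  println(model.name) // reports the fine-tuned id via ModelType.FineTunedModel
}
```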
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAICompletion.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAICompletion.kt
index 930b327d7..e09a179b5 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAICompletion.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAICompletion.kt
@@ -4,8 +4,8 @@ import com.aallam.openai.api.LegacyOpenAI
 import com.aallam.openai.api.completion.Choice
 import com.aallam.openai.api.completion.completionRequest
 import com.aallam.openai.api.model.ModelId
-import com.aallam.openai.client.OpenAI
 import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.conversation.llm.openai.OpenAI
 import com.xebia.functional.xef.conversation.llm.openai.toInternal
 import com.xebia.functional.xef.llm.Completion
 import com.xebia.functional.xef.llm.models.text.CompletionChoice
@@ -13,10 +13,14 @@ import com.xebia.functional.xef.llm.models.text.CompletionRequest
 import com.xebia.functional.xef.llm.models.text.CompletionResult
 
 class OpenAICompletion(
+  private val provider: OpenAI, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: OpenAI,
 ) : Completion {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = OpenAICompletion(provider, modelType)
+
   @OptIn(LegacyOpenAI::class)
   override suspend fun createCompletion(request: CompletionRequest): CompletionResult {
     fun toInternal(it: Choice): CompletionChoice =
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIEmbeddings.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIEmbeddings.kt
index 3151c66d3..7a144b3ef 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIEmbeddings.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIEmbeddings.kt
@@ -3,8 +3,8 @@ package com.xebia.functional.xef.conversation.llm.openai.models
 import com.aallam.openai.api.embedding.Embedding as OpenAIEmbedding
 import com.aallam.openai.api.embedding.embeddingRequest
 import com.aallam.openai.api.model.ModelId
-import com.aallam.openai.client.OpenAI
 import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.conversation.llm.openai.OpenAI
 import com.xebia.functional.xef.conversation.llm.openai.toInternal
 import com.xebia.functional.xef.llm.Embeddings
 import com.xebia.functional.xef.llm.models.embeddings.Embedding
@@ -12,10 +12,14 @@ import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
 import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
 
 class OpenAIEmbeddings(
+  private val provider: OpenAI, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: OpenAI,
 ) : Embeddings {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = OpenAIEmbeddings(provider, modelType)
+
   override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
     val clientRequest = embeddingRequest {
       model = ModelId(request.model)
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIFunChat.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIFunChat.kt
index 1ce376bb3..d7e0bd437 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIFunChat.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIFunChat.kt
@@ -2,8 +2,8 @@ package com.xebia.functional.xef.conversation.llm.openai.models
 
 import com.aallam.openai.api.chat.*
 import com.aallam.openai.api.model.ModelId
-import com.aallam.openai.client.OpenAI
 import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.conversation.llm.openai.OpenAI
 import com.xebia.functional.xef.conversation.llm.openai.toInternal
 import com.xebia.functional.xef.conversation.llm.openai.toOpenAI
 import com.xebia.functional.xef.llm.ChatWithFunctions
@@ -17,10 +17,14 @@ import kotlinx.coroutines.flow.map
 import kotlinx.serialization.json.Json
 
 class OpenAIFunChat(
+  private val provider: OpenAI, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: OpenAI,
 ) : ChatWithFunctions {
 
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = OpenAIFunChat(provider, modelType)
+
   override suspend fun createChatCompletionWithFunctions(
     request: FunChatCompletionRequest
   ): ChatCompletionResponseWithFunctions {
diff --git a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIImages.kt b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIImages.kt
index ab5567c51..b98a90fbc 100644
--- a/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIImages.kt
+++ b/openai/src/commonMain/kotlin/com/xebia/functional/xef/conversation/llm/openai/models/OpenAIImages.kt
@@ -4,8 +4,8 @@ import com.aallam.openai.api.BetaOpenAI
 import com.aallam.openai.api.image.ImageCreation
 import com.aallam.openai.api.image.ImageSize
 import com.aallam.openai.api.image.imageCreation
-import com.aallam.openai.client.OpenAI
 import com.xebia.functional.tokenizer.ModelType
+import com.xebia.functional.xef.conversation.llm.openai.OpenAI
 import com.xebia.functional.xef.llm.Images
 import com.xebia.functional.xef.llm.models.chat.Message
 import com.xebia.functional.xef.llm.models.images.ImageGenerationUrl
@@ -13,9 +13,14 @@ import com.xebia.functional.xef.llm.models.images.ImagesGenerationRequest
 import com.xebia.functional.xef.llm.models.images.ImagesGenerationResponse
 
 class OpenAIImages(
+  private val provider: OpenAI, // TODO: use context receiver
   override val modelType: ModelType,
-  private val client: OpenAI,
 ) : Images {
+
+  private val client = provider.defaultClient
+
+  override fun copy(modelType: ModelType) = OpenAIImages(provider, modelType)
+
   @OptIn(BetaOpenAI::class)
   override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse {
     val clientRequest: ImageCreation = imageCreation {
diff --git a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
index e69847507..b13a8ccf3 100644
--- a/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
+++ b/tokenizer/src/commonMain/kotlin/com/xebia/functional/tokenizer/ModelType.kt
@@ -84,6 +84,18 @@ sealed class ModelType(
   object CODE_SEARCH_BABBAGE_CODE_001 : ModelType("code-search-babbage-code-001", R50K_BASE, 2046)
   object CODE_SEARCH_ADA_CODE_001 : ModelType("code-search-ada-code-001", R50K_BASE, 2046)
 
+  class FineTunedModel(
+    name: String,
+    val baseModel: ModelType,
+  ) :
+    ModelType(
+      name = name,
+      encodingType = baseModel.encodingType,
+      maxContextLength = baseModel.maxContextLength,
+      tokensPerMessage = baseModel.tokensPerMessage,
+      tokensPerName = baseModel.tokensPerName,
+      tokenPadding = baseModel.tokenPadding,
+    )
+
   /**
    * Currently as of September 2023,
    * [ModelType] has only implementations for OpenAI.
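`FineTunedModel` forwards every tokenizer-related field to its base model, so token counting and context-window checks behave exactly as they do for the base model while only the reported name changes. A quick illustration, assuming the constructor shown in this diff (the fine-tuned id is a made-up example):

```kotlin
import com.xebia.functional.tokenizer.ModelType

fun main() {
  val base = ModelType.GPT_3_5_TURBO
  val tuned = ModelType.FineTunedModel("ft:gpt-3.5-turbo:acme::abc123", baseModel = base)

  // The fine-tuned model reports its own name but inherits the base model's limits.
  println(tuned.name)                                      // ft:gpt-3.5-turbo:acme::abc123
  println(tuned.maxContextLength == base.maxContextLength) // true
  println(tuned.encodingType == base.encodingType)         // true
}
```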