|
| 1 | +package openai.data |
| 2 | + |
| 3 | +import com.xebia.functional.tokenizer.ModelType |
| 4 | +import com.xebia.functional.xef.llm.ChatWithFunctions |
| 5 | +import com.xebia.functional.xef.llm.Embeddings |
| 6 | +import com.xebia.functional.xef.llm.models.chat.* |
| 7 | +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest |
| 8 | +import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult |
| 9 | +import com.xebia.functional.xef.llm.models.functions.FunctionCall |
| 10 | +import com.xebia.functional.xef.llm.models.usage.Usage |
| 11 | +import kotlinx.coroutines.flow.Flow |
| 12 | + |
/**
 * Test double for a chat model supporting function calling and embeddings.
 *
 * Replies are looked up in [responses], keyed by the content of the last message of the
 * incoming request; prompts with no matching entry receive a fixed `"fake-content"` reply.
 * Every completion request is recorded in [requests] so tests can assert on what was sent.
 */
class TestFunctionsModel(
  override val modelType: ModelType,
  override val name: String,
  val responses: Map<String, String> = emptyMap(),
) : ChatWithFunctions, Embeddings, AutoCloseable {

  // Chronological log of every completion request seen by this fake; inspected by tests.
  var requests: MutableList<ChatCompletionRequest> = mutableListOf()

  // Shared canned-response lookup for both completion entry points. Uses lastOrNull() so an
  // empty message list yields the fallback instead of throwing NoSuchElementException.
  private fun chosenResponse(request: ChatCompletionRequest): String =
    request.messages.lastOrNull()?.content?.let { responses[it] } ?: "fake-content"

  /**
   * Records [request] and returns a single-choice response whose content is the canned
   * reply for the last message (or `"fake-content"`).
   */
  override suspend fun createChatCompletion(
    request: ChatCompletionRequest
  ): ChatCompletionResponse {
    requests.add(request)
    return ChatCompletionResponse(
      id = "fake-id",
      `object` = "fake-object",
      created = 0,
      model = "fake-model",
      choices =
        listOf(
          Choice(
            message =
              Message(
                role = Role.USER,
                content = chosenResponse(request),
                name = Role.USER.name
              ),
            finishReason = "fake-finish-reason",
            index = 0
          )
        ),
      usage = Usage.ZERO
    )
  }

  /** Streaming completions are not exercised by the current tests. */
  override suspend fun createChatCompletions(
    request: ChatCompletionRequest
  ): Flow<ChatCompletionChunk> {
    throw NotImplementedError()
  }

  // Cheap deterministic token estimate: one "token" per character of message content.
  override fun tokensFromMessages(messages: List<Message>): Int =
    messages.sumOf { it.content.length }

  /** Returns an empty embedding result; embedding content is irrelevant to these tests. */
  override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
    EmbeddingResult(data = emptyList(), usage = Usage.ZERO)

  /**
   * Records [request] and returns a single-choice response carrying the canned reply both
   * as plain content and as the argument of a fake function call.
   */
  override suspend fun createChatCompletionWithFunctions(
    request: ChatCompletionRequest
  ): ChatCompletionResponseWithFunctions {
    requests.add(request)
    val response = chosenResponse(request)
    return ChatCompletionResponseWithFunctions(
      id = "fake-id",
      `object` = "fake-object",
      created = 0,
      model = "fake-model",
      choices =
        listOf(
          ChoiceWithFunctions(
            message =
              MessageWithFunctionCall(
                role = Role.USER.name,
                content = response,
                functionCall = FunctionCall("fake-function-name", response),
                name = Role.USER.name
              ),
            finishReason = "fake-finish-reason",
            index = 0
          )
        ),
      usage = Usage.ZERO
    )
  }
}
0 commit comments