|
| 1 | +package com.xebia.functional.xef.llm.openai |
| 2 | + |
| 3 | +class MockOpenAIClient( |
| 4 | + private val completion: (CompletionRequest) -> CompletionResult = { |
| 5 | + throw NotImplementedError("completion not implemented") |
| 6 | + }, |
| 7 | + private val chatCompletion: (ChatCompletionRequest) -> ChatCompletionResponse = { |
| 8 | + throw NotImplementedError("chat completion not implemented") |
| 9 | + }, |
| 10 | + private val embeddings: (EmbeddingRequest) -> EmbeddingResult = ::nullEmbeddings, |
| 11 | + private val images: (ImagesGenerationRequest) -> ImagesGenerationResponse = { |
| 12 | + throw NotImplementedError("images not implemented") |
| 13 | + }, |
| 14 | +) : OpenAIClient { |
| 15 | + override suspend fun createCompletion(request: CompletionRequest): CompletionResult = |
| 16 | + completion(request) |
| 17 | + |
| 18 | + override suspend fun createChatCompletion( |
| 19 | + request: ChatCompletionRequest |
| 20 | + ): ChatCompletionResponse = chatCompletion(request) |
| 21 | + |
| 22 | + override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult = |
| 23 | + embeddings(request) |
| 24 | + |
| 25 | + override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse = |
| 26 | + images(request) |
| 27 | +} |
| 28 | + |
| 29 | +fun nullEmbeddings(request: EmbeddingRequest): EmbeddingResult { |
| 30 | + val results = request.input.mapIndexed { index, s -> Embedding(s, listOf(0F), index) } |
| 31 | + return EmbeddingResult(request.model, "", results, Usage.ZERO) |
| 32 | +} |
| 33 | + |
| 34 | +fun simpleMockAIClient(execute: (String) -> String): MockOpenAIClient = |
| 35 | + MockOpenAIClient( |
| 36 | + completion = { req -> |
| 37 | + val request = "${req.prompt.orEmpty()} ${req.suffix.orEmpty()}" |
| 38 | + val response = execute(request) |
| 39 | + val result = CompletionChoice(response, 0, null, "end") |
| 40 | + val requestTokens = request.split(' ').size.toLong() |
| 41 | + val responseTokens = response.split(' ').size.toLong() |
| 42 | + val usage = Usage(requestTokens, responseTokens, requestTokens + responseTokens) |
| 43 | + CompletionResult("FakeID123", "", 0, req.model, listOf(result), usage) |
| 44 | + }, |
| 45 | + chatCompletion = { req -> |
| 46 | + val responses = |
| 47 | + req.messages.mapIndexed { ix, msg -> |
| 48 | + val response = execute(msg.content) |
| 49 | + Choice(Message(msg.role, response), "end", ix) |
| 50 | + } |
| 51 | + val requestTokens = req.messages.sumOf { it.content.split(' ').size.toLong() } |
| 52 | + val responseTokens = responses.sumOf { it.message.content.split(' ').size.toLong() } |
| 53 | + val usage = Usage(requestTokens, responseTokens, requestTokens + responseTokens) |
| 54 | + ChatCompletionResponse("FakeID123", "", 0, req.model, usage, responses) |
| 55 | + } |
| 56 | + ) |
0 commit comments