Skip to content

Commit ff49b65

Browse files
serrasfranciscodr
and authored
Mock AI (#180)
* Mock AI response * Two ways to mock * Fix NotImplementErrors * Make mock embeddings return 0 * Format code --------- Co-authored-by: Francisco Diaz <[email protected]>
1 parent 66c00eb commit ff49b65

File tree

3 files changed

+92
-1
lines changed

3 files changed

+92
-1
lines changed

core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt

+31
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import com.xebia.functional.xef.embeddings.Embeddings
88
import com.xebia.functional.xef.embeddings.OpenAIEmbeddings
99
import com.xebia.functional.xef.env.OpenAIConfig
1010
import com.xebia.functional.xef.llm.openai.KtorOpenAIClient
11+
import com.xebia.functional.xef.llm.openai.MockOpenAIClient
1112
import com.xebia.functional.xef.llm.openai.OpenAIClient
13+
import com.xebia.functional.xef.llm.openai.simpleMockAIClient
1214
import com.xebia.functional.xef.vectorstores.CombinedVectorStore
1315
import com.xebia.functional.xef.vectorstores.LocalVectorStore
1416
import com.xebia.functional.xef.vectorstores.VectorStore
@@ -53,6 +55,21 @@ suspend fun <A> AIScope(block: suspend AIScope.() -> A, orElse: suspend (AIError
5355
orElse(e)
5456
}
5557

58+
/**
 * Runs [block] inside an [AIScope] whose OpenAI client is the supplied [mockClient],
 * backed by a fresh [LocalVectorStore]. Any [AIError] raised by [block] is routed
 * to [orElse] instead of propagating.
 */
@OptIn(ExperimentalTime::class)
suspend fun <A> MockAIScope(
  mockClient: MockOpenAIClient,
  block: suspend AIScope.() -> A,
  orElse: suspend (AIError) -> A
): A =
  try {
    // Embeddings are served by the mock client too, so no network calls happen.
    val embeddingsProvider = OpenAIEmbeddings(OpenAIConfig(), mockClient)
    block(AIScope(mockClient, LocalVectorStore(embeddingsProvider), embeddingsProvider))
  } catch (error: AIError) {
    orElse(error)
  }
72+
5673
/**
5774
* Run the [AI] value to produce _either_ an [AIError], or [A]. this method initialises all the
5875
* dependencies required to run the [AI] value and once it finishes it closes all the resources.
@@ -64,6 +81,20 @@ suspend fun <A> AIScope(block: suspend AIScope.() -> A, orElse: suspend (AIError
6481
suspend inline fun <reified A> AI<A>.toEither(): Either<AIError, A> =
6582
ai { invoke().right() }.getOrElse { it.left() }
6683

84+
/**
 * Run the [AI] value to produce _either_ an [AIError], or [A]. This method uses the [mockAI] to
 * compute the different responses.
 */
suspend fun <A> AI<A>.mock(mockAI: MockOpenAIClient): Either<AIError, A> =
  MockAIScope(
    mockClient = mockAI,
    block = { invoke().right() },
    orElse = { error -> error.left() },
  )
90+
91+
/**
 * Run the [AI] value to produce _either_ an [AIError], or [A]. This method uses the [mockAI] to
 * compute the different responses.
 */
suspend fun <A> AI<A>.mock(mockAI: (String) -> String): Either<AIError, A> =
  MockAIScope(
    mockClient = simpleMockAIClient(mockAI),
    block = { invoke().right() },
    orElse = { error -> error.left() },
  )
97+
6798
/**
6899
* Run the [AI] value to produce [A]. this method initialises all the dependencies required to run
69100
* the [AI] value and once it finishes it closes all the resources.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package com.xebia.functional.xef.llm.openai
2+
3+
/**
 * [OpenAIClient] test double whose endpoints are supplied as constructor lambdas.
 *
 * Any endpoint without an explicit handler throws [NotImplementedError] when invoked,
 * with the exception of embeddings, which default to [nullEmbeddings].
 */
class MockOpenAIClient(
  private val completion: (CompletionRequest) -> CompletionResult = {
    throw NotImplementedError("completion not implemented")
  },
  private val chatCompletion: (ChatCompletionRequest) -> ChatCompletionResponse = {
    throw NotImplementedError("chat completion not implemented")
  },
  private val embeddings: (EmbeddingRequest) -> EmbeddingResult = ::nullEmbeddings,
  private val images: (ImagesGenerationRequest) -> ImagesGenerationResponse = {
    throw NotImplementedError("images not implemented")
  },
) : OpenAIClient {

  // Every override forwards straight to the corresponding handler lambda.

  override suspend fun createChatCompletion(
    request: ChatCompletionRequest
  ): ChatCompletionResponse = chatCompletion(request)

  override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
    completion(request)

  override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
    embeddings(request)

  override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse =
    images(request)
}
28+
29+
/**
 * Produces an [EmbeddingResult] for [request] in which every input string is mapped to a
 * single-component `0F` vector, with zeroed [Usage]. Useful as a no-op embeddings handler.
 */
fun nullEmbeddings(request: EmbeddingRequest): EmbeddingResult =
  EmbeddingResult(
    request.model,
    "",
    request.input.mapIndexed { ix, text -> Embedding(text, listOf(0F), ix) },
    Usage.ZERO,
  )
33+
34+
/**
 * Builds a [MockOpenAIClient] whose completion and chat-completion endpoints both route every
 * prompt through [execute], synthesizing fake ids, finish reasons, and token [Usage] counts
 * (tokens approximated by whitespace-splitting).
 */
fun simpleMockAIClient(execute: (String) -> String): MockOpenAIClient =
  MockOpenAIClient(
    completion = { request ->
      val prompt = "${request.prompt.orEmpty()} ${request.suffix.orEmpty()}"
      val answer = execute(prompt)
      val promptTokens = prompt.split(' ').size.toLong()
      val answerTokens = answer.split(' ').size.toLong()
      CompletionResult(
        "FakeID123",
        "",
        0,
        request.model,
        listOf(CompletionChoice(answer, 0, null, "end")),
        Usage(promptTokens, answerTokens, promptTokens + answerTokens),
      )
    },
    chatCompletion = { request ->
      // One choice per incoming message, each answered independently by [execute].
      val choices =
        request.messages.mapIndexed { index, message ->
          Choice(Message(message.role, execute(message.content)), "end", index)
        }
      val promptTokens = request.messages.sumOf { it.content.split(' ').size.toLong() }
      val answerTokens = choices.sumOf { it.message.content.split(' ').size.toLong() }
      ChatCompletionResponse(
        "FakeID123",
        "",
        0,
        request.model,
        Usage(promptTokens, answerTokens, promptTokens + answerTokens),
        choices,
      )
    }
  )

core/src/commonMain/kotlin/com/xebia/functional/xef/llm/openai/models.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ data class Usage(
114114
@SerialName("prompt_tokens") val promptTokens: Long,
115115
@SerialName("completion_tokens") val completionTokens: Long? = null,
116116
@SerialName("total_tokens") val totalTokens: Long
117-
)
117+
) {
118+
companion object {
119+
val ZERO: Usage = Usage(0, 0, 0)
120+
}
121+
}
118122

119123
data class LLMModel(val name: String, val kind: Kind, val modelType: ModelType) {
120124
enum class Kind {

0 commit comments

Comments
 (0)