|
1 | 1 | package com.xebia.functional.chains
|
2 | 2 |
|
3 |
| -import arrow.core.Either |
4 | 3 | import arrow.core.raise.either
|
5 | 4 | import com.xebia.functional.llm.openai.CompletionChoice
|
6 | 5 | import com.xebia.functional.llm.openai.CompletionRequest
|
7 | 6 | import com.xebia.functional.llm.openai.EmbeddingRequest
|
8 | 7 | import com.xebia.functional.llm.openai.EmbeddingResult
|
9 | 8 | import com.xebia.functional.llm.openai.OpenAIClient
|
10 | 9 | import com.xebia.functional.prompt.PromptTemplate
|
| 10 | +import io.kotest.assertions.arrow.core.shouldBeLeft |
| 11 | +import io.kotest.assertions.arrow.core.shouldBeRight |
11 | 12 | import io.kotest.core.spec.style.StringSpec
|
12 |
| -import io.kotest.matchers.shouldBe |
13 | 13 |
|
14 | 14 | class LLMChainSpec : StringSpec({
|
15 | 15 | "LLMChain should return a prediction with just the output" {
|
16 | 16 | val template = "Tell me {foo}."
|
17 | 17 | either {
|
18 | 18 | val prompt = PromptTemplate(template, listOf("foo"))
|
19 |
| - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, true) |
| 19 | + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) |
20 | 20 | chain.run("a joke").bind()
|
21 |
| - } shouldBe Either.Right( |
22 |
| - mapOf("answer" to "I'm not good at jokes") |
23 |
| - ) |
| 21 | + } shouldBeRight mapOf("answer" to "I'm not good at jokes") |
24 | 22 | }
|
25 | 23 |
|
26 | 24 | "LLMChain should return a prediction with both output and inputs" {
|
27 | 25 | val template = "Tell me {foo}."
|
28 | 26 | either {
|
29 | 27 | val prompt = PromptTemplate(template, listOf("foo"))
|
30 |
| - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) |
| 28 | + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput) |
31 | 29 | chain.run("a joke").bind()
|
32 |
| - } shouldBe Either.Right( |
33 |
| - mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") |
34 |
| - ) |
| 30 | + } shouldBeRight mapOf("foo" to "a joke", "answer" to "I'm not good at jokes") |
35 | 31 | }
|
36 | 32 |
|
37 | 33 | "LLMChain should return a prediction with a more complex template" {
|
38 | 34 | val template = "My name is {name} and I'm {age} years old"
|
39 | 35 | either {
|
40 | 36 | val prompt = PromptTemplate(template, listOf("name", "age"))
|
41 |
| - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) |
| 37 | + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, Chain.ChainOutput.InputAndOutput) |
42 | 38 | chain.run(mapOf("age" to "28", "name" to "foo")).bind()
|
43 |
| - } shouldBe Either.Right( |
44 |
| - mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") |
45 |
| - ) |
| 39 | + } shouldBeRight mapOf("age" to "28", "name" to "foo", "answer" to "Hello there! Nice to meet you foo") |
46 | 40 | }
|
47 | 41 |
|
48 | 42 | "LLMChain should fail when inputs are not the expected ones from the PromptTemplate" {
|
49 | 43 | val template = "My name is {name} and I'm {age} years old"
|
50 | 44 | either {
|
51 | 45 | val prompt = PromptTemplate(template, listOf("name", "age"))
|
52 |
| - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) |
| 46 | + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) |
53 | 47 | chain.run(mapOf("age" to "28", "brand" to "foo")).bind()
|
54 |
| - } shouldBe Either.Left( |
55 |
| - Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}") |
56 |
| - ) |
| 48 | + } shouldBeLeft Chain.InvalidInputs("The provided inputs: {age}, {brand} do not match with chain's inputs: {name}, {age}") |
57 | 49 | }
|
58 | 50 |
|
59 | 51 | "LLMChain should fail when using just one input but expecting more" {
|
60 | 52 | val template = "My name is {name} and I'm {age} years old"
|
61 | 53 | either {
|
62 | 54 | val prompt = PromptTemplate(template, listOf("name", "age"))
|
63 |
| - val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0, false) |
| 55 | + val chain = LLMChain(llm, prompt, "davinci", "testing", false, 1, 0.0) |
64 | 56 | chain.run("foo").bind()
|
65 |
| - } shouldBe Either.Left( |
66 |
| - Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") |
67 |
| - ) |
| 57 | + } shouldBeLeft Chain.InvalidInputs("The expected inputs are more than one: {name}, {age}") |
68 | 58 | }
|
69 | 59 | })
|
70 | 60 |
|
|
0 commit comments