Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugs in conversations #260

Merged
merged 16 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ interface Chat : LLM {
}

private fun messages(memories: List<Memory>, promptWithContext: String): List<Message> =
memories.reversed().map { it.content } +
listOf(Message(Role.USER, promptWithContext, Role.USER.name))
memories.map { it.content } + listOf(Message(Role.USER, promptWithContext, Role.USER.name))

private suspend fun memories(
conversationId: ConversationId?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class CombinedVectorStore(private val top: VectorStore, private val bottom: Vect
VectorStore by top {

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> {
val bottomResults = bottom.memories(conversationId, limit)
val topResults = top.memories(conversationId, limit)
val bottomResults = bottom.memories(conversationId, limit - topResults.size)
return topResults + bottomResults
return (topResults + bottomResults).sortedBy { it.timestamp }.takeLast(limit)
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
override suspend fun addMemories(memories: List<Memory>) {
javipacheco marked this conversation as resolved.
Show resolved Hide resolved
state.update { prevState ->
prevState.copy(
orderedMemories = prevState.orderedMemories + memories.groupBy { it.conversationId }
orderedMemories =
memories
.groupBy { it.conversationId }
.let { memories ->
(prevState.orderedMemories.keys + memories.keys).associateWith { key ->
val l1 = prevState.orderedMemories[key] ?: emptyList()
val l2 = memories[key] ?: emptyList()
l1 + l2
}
}
)
}
}

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> {
val memories = state.get().orderedMemories[conversationId]
return memories?.take(limit).orEmpty()
return memories?.takeLast(limit).orEmpty().sortedBy { it.timestamp }
}

override suspend fun addTexts(texts: List<String>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@ package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.llm.models.chat.Message

/**
* Representation of the memory of a message in a conversation.
*
* @property content message sent.
* @property conversationId uniquely identifies the conversation in which the message took place.
* @property timestamp in milliseconds.
*/
data class Memory(val conversationId: ConversationId, val content: Message, val timestamp: Long)
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.xebia.functional.xef.vectorstores

import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class CombinedVectorStoreSpec :
StringSpec({
"memories function should return all of messages combined in the right order" {
val topMessages = generateRandomMessages(4, append = "top", startTimestamp = 1000)
val bottomMessages = generateRandomMessages(4, append = "bottom", startTimestamp = 2000)

val combinedVectorStore = topMessages.combine(bottomMessages)

val messages = combinedVectorStore.memories(defaultConversationId, Int.MAX_VALUE)

val messagesExpected = topMessages + bottomMessages

messages shouldBe messagesExpected
}

"memories function should return the last n combined messages in the right order" {
val topMessages = generateRandomMessages(4, append = "top", startTimestamp = 1000)
val bottomMessages = generateRandomMessages(4, append = "bottom", startTimestamp = 2000)

val combinedVectorStore = topMessages.combine(bottomMessages)

val messages = combinedVectorStore.memories(defaultConversationId, 6 * 2)

val messagesExpected = topMessages.takeLast(2 * 2) + bottomMessages

messages shouldBe messagesExpected
}

"memories function should return the messages with common conversation id combined in the right order" {
val topId = ConversationId("top-id")
val bottomId = ConversationId("bottom-id")
val commonId = ConversationId("common-id")

val topMessages =
generateRandomMessages(4, append = "top", conversationId = topId, startTimestamp = 1000)
val commonTopMessages =
generateRandomMessages(
4,
append = "common-top",
conversationId = commonId,
startTimestamp = 2000
)

val bottomMessages =
generateRandomMessages(
4,
append = "bottom",
conversationId = bottomId,
startTimestamp = 3000
)
val commonBottomMessages =
generateRandomMessages(
4,
append = "common-bottom",
conversationId = commonId,
startTimestamp = 4000
)

val combinedVectorStore =
(topMessages + commonTopMessages).combine(bottomMessages + commonBottomMessages)

val messages = combinedVectorStore.memories(commonId, Int.MAX_VALUE)

val messagesExpected = commonTopMessages + commonBottomMessages

messages shouldBe messagesExpected
}

"adding messages to a combined vector store" {
val topId = ConversationId("top-id")
val bottomId = ConversationId("bottom-id")
val commonId = ConversationId("common-id")

val topMessages =
generateRandomMessages(4, append = "top", conversationId = topId, startTimestamp = 1000)
val commonTopMessages =
generateRandomMessages(
4,
append = "common-top",
conversationId = commonId,
startTimestamp = 2000
)

val bottomMessages =
generateRandomMessages(
4,
append = "bottom",
conversationId = bottomId,
startTimestamp = 3000
)
val commonBottomMessages =
generateRandomMessages(
4,
append = "common-bottom",
conversationId = commonId,
startTimestamp = 4000
)

val combinedVectorStore =
(topMessages + commonTopMessages).combine(bottomMessages + commonBottomMessages)

val newCommonMessages =
generateRandomMessages(4, append = "new", conversationId = commonId, startTimestamp = 5000)
combinedVectorStore.addMemories(newCommonMessages)

combinedVectorStore.memories(commonId, 4 * 2) shouldBe newCommonMessages
}
})

suspend fun List<Memory>.combine(bottomMessages: List<Memory>): CombinedVectorStore {
val top = LocalVectorStore(FakeEmbeddings())
top.addMemories(this)

val bottom = LocalVectorStore(FakeEmbeddings())
bottom.addMemories(bottomMessages)

return CombinedVectorStore(top, bottom)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig

class FakeEmbeddings : Embeddings {
override suspend fun embedDocuments(
texts: List<String>,
chunkSize: Int?,
requestConfig: RequestConfig
): List<Embedding> = emptyList()

override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
emptyList()
}
javipacheco marked this conversation as resolved.
Show resolved Hide resolved
raulraja marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.xebia.functional.xef.vectorstores

import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class LocalVectorStoreSpec :
StringSpec({
"memories function should return all of messages in the right order when the limit is greater than the number of stored messages" {
val localVectorStore = LocalVectorStore(FakeEmbeddings())

val messages1 = generateRandomMessages(4, startTimestamp = 1000)
val messages2 = generateRandomMessages(3, startTimestamp = 2000)

localVectorStore.addMemories(messages1)
localVectorStore.addMemories(messages2)

val messages = localVectorStore.memories(defaultConversationId, Int.MAX_VALUE)

val messagesExpected = messages1 + messages2

messages shouldBe messagesExpected
}

"memories function should return the last n messages in the right order" {
val localVectorStore = LocalVectorStore(FakeEmbeddings())

val limit = 3 * 2 // 3 couples of messages

val messages1 = generateRandomMessages(4, startTimestamp = 1000)
val messages2 = generateRandomMessages(3, startTimestamp = 2000)

localVectorStore.addMemories(messages1)
localVectorStore.addMemories(messages2)

val messages = localVectorStore.memories(defaultConversationId, limit)

val messagesExpected = (messages1 + messages2).takeLast(limit)

messages shouldBe messagesExpected
}

"memories function should return the last n messages in the right order for a specific conversation id" {
val localVectorStore = LocalVectorStore(FakeEmbeddings())

val limit = 3 * 2

val firstId = ConversationId("first-id")
val secondId = ConversationId("second-id")

val messages1 = generateRandomMessages(4, conversationId = firstId, startTimestamp = 1000)
val messages2 = generateRandomMessages(3, conversationId = secondId, startTimestamp = 2000)

localVectorStore.addMemories(messages1 + messages2)

val messagesFirstId = localVectorStore.memories(firstId, limit)
val messagesFirstIdExpected = messages1.takeLast(limit)

val messagesSecondId = localVectorStore.memories(secondId, limit)
val messagesSecondIdExpected = messages2.takeLast(limit)

messagesFirstId shouldBe messagesFirstIdExpected
messagesSecondId shouldBe messagesSecondIdExpected
}
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role

val defaultConversationId = ConversationId("default-id")

fun generateRandomMessages(
n: Int,
append: String? = null,
conversationId: ConversationId = defaultConversationId,
startTimestamp: Long = 0
): List<Memory> =
(0 until n).flatMap {
listOf(
Memory(
conversationId,
Message(Role.USER, "Question $it${append?.let { ": $it" } ?: ""}", "USER"),
startTimestamp + (it * 10)
),
Memory(
conversationId,
Message(Role.ASSISTANT, "Response $it${append?.let { ": $it" } ?: ""}", "ASSISTANT"),
startTimestamp + (it * 10) + 1
),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.xebia.functional.xef.auto

import com.xebia.functional.xef.auto.llm.openai.getOrElse
import com.xebia.functional.xef.auto.llm.openai.promptMessage

suspend fun main() {
ai {

val emailMessage = """
|You are a Marketing Responsible and have the information about different products. You have to prepare
|an email template with the personal information
""".trimMargin()

val email: String = promptMessage(emailMessage)

println("Prompt:\n $emailMessage")
println("Response:\n $email")

val summarizePrompt = """
|You are a Marketing Responsible and have the information about the best rated products.
|Summarize the next information:
|Love this product and so does my husband! He tried it because his face gets chapped and red from
|working outside. It actually helped by about 60%! I love it cuz it's lightweight and smells so yummy!
|After applying makeup, it doesn't leave streaks like other moisturizers cause. i would definitely use
|this!
|
|I've been using this for 10+yrs now. I don't have any noticeable
|wrinkles at all. I use Estée Lauder's micro essence then advance repair serum before I apply this
|lotion. A little goes a long way! It does feel greasier than most face lotions that I've tried
|previously but I don't apply much. I enjoy cucumber like scent of the face lotion. I have combination
|skin and never broke out using this. This is my daily skincare product with or without makeup. And it
|has SPF but I also apply Kravebeauty SPF on top as well for extra protection
""".trimMargin()

val summarize: String = promptMessage(summarizePrompt)

println("Prompt:\n $summarizePrompt}")
println("Response:\n $summarize")

val meaningPrompt = """
|What is the meaning of life?
""".trimMargin()

val meaning: String = promptMessage(meaningPrompt)

println("Prompt:\n $meaningPrompt}")
println("Response:\n $meaning")


}.getOrElse { println(it) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PGVectorStore(
content = Message(
role = Role.valueOf(role.uppercase()),
content = content,
name = "role",
name = role,
),
timestamp = timestamp,
)
Expand Down Expand Up @@ -83,9 +83,9 @@ class PGVectorStore(

fun createCollection(): Unit =
dataSource.connection {
val xa = UUID.generateUUID()
val uuid = UUID.generateUUID()
update(addNewCollection) {
bind(xa.toString())
bind(uuid.toString())
bind(collectionName)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ val createMemoryTable: String =
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT UNIQUE NOT NULL,
timestamp TIMESTAMP NOT NULL,
timestamp BIGINT NOT NULL
);"""
.trimIndent()

Expand Down Expand Up @@ -97,7 +97,7 @@ val getCollectionById: String =
val getMemoriesByConversationId: String =
"""SELECT * FROM xef_memory
WHERE conversation_id = ?
ORDER BY timestamp DESC LIMIT ?;"""
ORDER BY timestamp ASC LIMIT ?;"""
.trimIndent()

val addNewDocument: String =
Expand Down
Loading