Skip to content

Commit

Permalink
Bugs in conversations (#260)
Browse files Browse the repository at this point in the history
* Bugs in conversations

* Using trim marging on Conversation example

* Bug fixed when storing messages on LocalVectorStore and tests

* victorcrrd comments addressed

* CombinedVectorStore fixed and tests

* PGVectorStore tests (#272)

* format

* syntax error in createMemoryTable query

* being able to connect to db

* added missing test

* added documentation to indicate that timestamp is measured in millis

* modeling timestamp in postgres as a BIGINT instead of a TIMESTAMP to avoid problems

* messages retrieved from postgres were assigned always the same name "role", now it matches the name stored in the database

* getting postgres messages in the right order: from oldest to newest (in the form of a chat)

* added simple test

* specifying more the test

* substituting localhost by IP 0.0.0.0 is fine for this test

* Sorting messages by timestamp

* Removing categoriesSelection

---------

Co-authored-by: Raúl Raja Martínez <[email protected]>
Co-authored-by: Victor Carrillo-Redondo <[email protected]>
  • Loading branch information
3 people authored Jul 27, 2023
1 parent bb07d43 commit 7b74452
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ interface Chat : LLM {
}

private fun messages(memories: List<Memory>, promptWithContext: String): List<Message> =
memories.reversed().map { it.content } +
listOf(Message(Role.USER, promptWithContext, Role.USER.name))
memories.map { it.content } + listOf(Message(Role.USER, promptWithContext, Role.USER.name))

private suspend fun memories(
conversationId: ConversationId?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class CombinedVectorStore(private val top: VectorStore, private val bottom: Vect
VectorStore by top {

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> {
val bottomResults = bottom.memories(conversationId, limit)
val topResults = top.memories(conversationId, limit)
val bottomResults = bottom.memories(conversationId, limit - topResults.size)
return topResults + bottomResults
return (topResults + bottomResults).sortedBy { it.timestamp }.takeLast(limit)
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,23 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
override suspend fun addMemories(memories: List<Memory>) {
state.update { prevState ->
prevState.copy(
orderedMemories = prevState.orderedMemories + memories.groupBy { it.conversationId }
orderedMemories =
memories
.groupBy { it.conversationId }
.let { memories ->
(prevState.orderedMemories.keys + memories.keys).associateWith { key ->
val l1 = prevState.orderedMemories[key] ?: emptyList()
val l2 = memories[key] ?: emptyList()
l1 + l2
}
}
)
}
}

override suspend fun memories(conversationId: ConversationId, limit: Int): List<Memory> {
val memories = state.get().orderedMemories[conversationId]
return memories?.take(limit).orEmpty()
return memories?.takeLast(limit).orEmpty().sortedBy { it.timestamp }
}

override suspend fun addTexts(texts: List<String>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,11 @@ package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.llm.models.chat.Message

/**
* Representation of the memory of a message in a conversation.
*
* @property content message sent.
* @property conversationId uniquely identifies the conversation in which the message took place.
* @property timestamp in milliseconds.
*/
data class Memory(val conversationId: ConversationId, val content: Message, val timestamp: Long)
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.xebia.functional.xef.vectorstores

import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class CombinedVectorStoreSpec :
StringSpec({
"memories function should return all of messages combined in the right order" {
val topMessages = generateRandomMessages(4, append = "top", startTimestamp = 1000)
val bottomMessages = generateRandomMessages(4, append = "bottom", startTimestamp = 2000)

val combinedVectorStore = topMessages.combine(bottomMessages)

val messages = combinedVectorStore.memories(defaultConversationId, Int.MAX_VALUE)

val messagesExpected = topMessages + bottomMessages

messages shouldBe messagesExpected
}

"memories function should return the last n combined messages in the right order" {
val topMessages = generateRandomMessages(4, append = "top", startTimestamp = 1000)
val bottomMessages = generateRandomMessages(4, append = "bottom", startTimestamp = 2000)

val combinedVectorStore = topMessages.combine(bottomMessages)

val messages = combinedVectorStore.memories(defaultConversationId, 6 * 2)

val messagesExpected = topMessages.takeLast(2 * 2) + bottomMessages

messages shouldBe messagesExpected
}

"memories function should return the messages with common conversation id combined in the right order" {
val topId = ConversationId("top-id")
val bottomId = ConversationId("bottom-id")
val commonId = ConversationId("common-id")

val topMessages =
generateRandomMessages(4, append = "top", conversationId = topId, startTimestamp = 1000)
val commonTopMessages =
generateRandomMessages(
4,
append = "common-top",
conversationId = commonId,
startTimestamp = 2000
)

val bottomMessages =
generateRandomMessages(
4,
append = "bottom",
conversationId = bottomId,
startTimestamp = 3000
)
val commonBottomMessages =
generateRandomMessages(
4,
append = "common-bottom",
conversationId = commonId,
startTimestamp = 4000
)

val combinedVectorStore =
(topMessages + commonTopMessages).combine(bottomMessages + commonBottomMessages)

val messages = combinedVectorStore.memories(commonId, Int.MAX_VALUE)

val messagesExpected = commonTopMessages + commonBottomMessages

messages shouldBe messagesExpected
}

"adding messages to a combined vector store" {
val topId = ConversationId("top-id")
val bottomId = ConversationId("bottom-id")
val commonId = ConversationId("common-id")

val topMessages =
generateRandomMessages(4, append = "top", conversationId = topId, startTimestamp = 1000)
val commonTopMessages =
generateRandomMessages(
4,
append = "common-top",
conversationId = commonId,
startTimestamp = 2000
)

val bottomMessages =
generateRandomMessages(
4,
append = "bottom",
conversationId = bottomId,
startTimestamp = 3000
)
val commonBottomMessages =
generateRandomMessages(
4,
append = "common-bottom",
conversationId = commonId,
startTimestamp = 4000
)

val combinedVectorStore =
(topMessages + commonTopMessages).combine(bottomMessages + commonBottomMessages)

val newCommonMessages =
generateRandomMessages(4, append = "new", conversationId = commonId, startTimestamp = 5000)
combinedVectorStore.addMemories(newCommonMessages)

combinedVectorStore.memories(commonId, 4 * 2) shouldBe newCommonMessages
}
})

suspend fun List<Memory>.combine(bottomMessages: List<Memory>): CombinedVectorStore {
val top = LocalVectorStore(FakeEmbeddings())
top.addMemories(this)

val bottom = LocalVectorStore(FakeEmbeddings())
bottom.addMemories(bottomMessages)

return CombinedVectorStore(top, bottom)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig

class FakeEmbeddings : Embeddings {
override suspend fun embedDocuments(
texts: List<String>,
chunkSize: Int?,
requestConfig: RequestConfig
): List<Embedding> = emptyList()

override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
emptyList()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.xebia.functional.xef.vectorstores

import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe

class LocalVectorStoreSpec :
StringSpec({
"memories function should return all of messages in the right order when the limit is greater than the number of stored messages" {
val localVectorStore = LocalVectorStore(FakeEmbeddings())

val messages1 = generateRandomMessages(4, startTimestamp = 1000)
val messages2 = generateRandomMessages(3, startTimestamp = 2000)

localVectorStore.addMemories(messages1)
localVectorStore.addMemories(messages2)

val messages = localVectorStore.memories(defaultConversationId, Int.MAX_VALUE)

val messagesExpected = messages1 + messages2

messages shouldBe messagesExpected
}

"memories function should return the last n messages in the right order" {
val localVectorStore = LocalVectorStore(FakeEmbeddings())

val limit = 3 * 2 // 3 couples of messages

val messages1 = generateRandomMessages(4, startTimestamp = 1000)
val messages2 = generateRandomMessages(3, startTimestamp = 2000)

localVectorStore.addMemories(messages1)
localVectorStore.addMemories(messages2)

val messages = localVectorStore.memories(defaultConversationId, limit)

val messagesExpected = (messages1 + messages2).takeLast(limit)

messages shouldBe messagesExpected
}

"memories function should return the last n messages in the right order for a specific conversation id" {
val localVectorStore = LocalVectorStore(FakeEmbeddings())

val limit = 3 * 2

val firstId = ConversationId("first-id")
val secondId = ConversationId("second-id")

val messages1 = generateRandomMessages(4, conversationId = firstId, startTimestamp = 1000)
val messages2 = generateRandomMessages(3, conversationId = secondId, startTimestamp = 2000)

localVectorStore.addMemories(messages1 + messages2)

val messagesFirstId = localVectorStore.memories(firstId, limit)
val messagesFirstIdExpected = messages1.takeLast(limit)

val messagesSecondId = localVectorStore.memories(secondId, limit)
val messagesSecondIdExpected = messages2.takeLast(limit)

messagesFirstId shouldBe messagesFirstIdExpected
messagesSecondId shouldBe messagesSecondIdExpected
}
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role

val defaultConversationId = ConversationId("default-id")

fun generateRandomMessages(
n: Int,
append: String? = null,
conversationId: ConversationId = defaultConversationId,
startTimestamp: Long = 0
): List<Memory> =
(0 until n).flatMap {
listOf(
Memory(
conversationId,
Message(Role.USER, "Question $it${append?.let { ": $it" } ?: ""}", "USER"),
startTimestamp + (it * 10)
),
Memory(
conversationId,
Message(Role.ASSISTANT, "Response $it${append?.let { ": $it" } ?: ""}", "ASSISTANT"),
startTimestamp + (it * 10) + 1
),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.xebia.functional.xef.auto

import com.xebia.functional.xef.auto.llm.openai.getOrElse
import com.xebia.functional.xef.auto.llm.openai.promptMessage

suspend fun main() {
ai {

val emailMessage = """
|You are a Marketing Responsible and have the information about different products. You have to prepare
|an email template with the personal information
""".trimMargin()

val email: String = promptMessage(emailMessage)

println("Prompt:\n $emailMessage")
println("Response:\n $email")

val summarizePrompt = """
|You are a Marketing Responsible and have the information about the best rated products.
|Summarize the next information:
|Love this product and so does my husband! He tried it because his face gets chapped and red from
|working outside. It actually helped by about 60%! I love it cuz it's lightweight and smells so yummy!
|After applying makeup, it doesn't leave streaks like other moisturizers cause. i would definitely use
|this!
|
|I've been using this for 10+yrs now. I don't have any noticeable
|wrinkles at all. I use Estée Lauder's micro essence then advance repair serum before I apply this
|lotion. A little goes a long way! It does feel greasier than most face lotions that I've tried
|previously but I don't apply much. I enjoy cucumber like scent of the face lotion. I have combination
|skin and never broke out using this. This is my daily skincare product with or without makeup. And it
|has SPF but I also apply Kravebeauty SPF on top as well for extra protection
""".trimMargin()

val summarize: String = promptMessage(summarizePrompt)

println("Prompt:\n $summarizePrompt}")
println("Response:\n $summarize")

val meaningPrompt = """
|What is the meaning of life?
""".trimMargin()

val meaning: String = promptMessage(meaningPrompt)

println("Prompt:\n $meaningPrompt}")
println("Response:\n $meaning")


}.getOrElse { println(it) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PGVectorStore(
content = Message(
role = Role.valueOf(role.uppercase()),
content = content,
name = "role",
name = role,
),
timestamp = timestamp,
)
Expand Down Expand Up @@ -83,9 +83,9 @@ class PGVectorStore(

fun createCollection(): Unit =
dataSource.connection {
val xa = UUID.generateUUID()
val uuid = UUID.generateUUID()
update(addNewCollection) {
bind(xa.toString())
bind(uuid.toString())
bind(collectionName)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ val createMemoryTable: String =
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT UNIQUE NOT NULL,
timestamp TIMESTAMP NOT NULL,
timestamp BIGINT NOT NULL
);"""
.trimIndent()

Expand Down Expand Up @@ -97,7 +97,7 @@ val getCollectionById: String =
val getMemoriesByConversationId: String =
"""SELECT * FROM xef_memory
WHERE conversation_id = ?
ORDER BY timestamp DESC LIMIT ?;"""
ORDER BY timestamp ASC LIMIT ?;"""
.trimIndent()

val addNewDocument: String =
Expand Down
Loading

0 comments on commit 7b74452

Please sign in to comment.