Skip to content

Commit

Permalink
Adding test task for Postgres PGVectorStore (#509)
Browse files Browse the repository at this point in the history
* adding test task

* Bugs fixed on PGVectorStoreSpec

* Default vectorSize in PGVectorStore set to the OpenAI embedding size (1536)

---------

Co-authored-by: Javi Pacheco <[email protected]>
  • Loading branch information
Montagon and javipacheco authored Oct 27, 2023
1 parent d164266 commit 122bf81
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 21 deletions.
4 changes: 4 additions & 0 deletions integrations/postgresql/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ tasks {

withType<AbstractPublishToMaven> { dependsOn(withType<Sign>()) }
}

// Run this module's tests on the JUnit Platform.
tasks.named<Test>("test") {
    useJUnitPlatform()
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fun createEmbeddingTable(vectorSize: Int): String =

val addNewMemory: String =
"""INSERT INTO xef_memory(uuid, conversation_id, role, content, index)
VALUES (?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT DO NOTHING;"""
.trimIndent()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ CREATE TABLE IF NOT EXISTS xef_memory (
CREATE TABLE IF NOT EXISTS xef_embeddings (
uuid TEXT PRIMARY KEY,
collection_id TEXT REFERENCES xef_collections(uuid),
embedding vector(3),
embedding vector(1536),
content TEXT
);

27 changes: 27 additions & 0 deletions integrations/postgresql/src/test/kotlin/xef/MemoryData.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package xef

import arrow.atomic.AtomicInt
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.store.ConversationId
import com.xebia.functional.xef.store.Memory

/**
 * Test fixture that builds [Memory] entries for a conversation.
 *
 * Every generated memory takes its index from [atomicInt], so indices keep
 * increasing across repeated calls on the same instance.
 */
class MemoryData {
    /** Conversation id used when the caller does not supply one. */
    val defaultConversationId = ConversationId("default-id")

    /** Counter that supplies the index of each generated [Memory]. */
    val atomicInt = AtomicInt(0)

    /**
     * Builds `n` question/response pairs, i.e. `2 * n` memories, in order.
     *
     * Despite the name, the content is deterministic: pair `i` consists of a
     * user message "Question i" followed by an assistant message "Response i".
     *
     * @param n number of question/response pairs to generate.
     * @param append optional text added to every message as ": <append>".
     * @param conversationId conversation every generated memory belongs to.
     * @return the generated memories, user question first within each pair.
     */
    fun generateRandomMessages(
        n: Int,
        append: String? = null,
        conversationId: ConversationId = defaultConversationId
    ): List<Memory> {
        // These values do not depend on the pair number, so compute them once.
        val suffix = if (append != null) ": $append" else ""
        val userName = Role.USER.toString().lowercase()
        val assistantName = Role.ASSISTANT.toString().lowercase()

        return (0 until n).flatMap { pairNumber ->
            val question = Message(Role.USER, "Question $pairNumber$suffix", userName)
            val response = Message(Role.ASSISTANT, "Response $pairNumber$suffix", assistantName)
            listOf(
                Memory(conversationId, question, atomicInt.addAndGet(1)),
                Memory(conversationId, response, atomicInt.addAndGet(1)),
            )
        }
    }
}
26 changes: 8 additions & 18 deletions integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.store.ConversationId
import com.xebia.functional.xef.store.Memory
import com.xebia.functional.xef.llm.models.usage.Usage
import com.xebia.functional.xef.store.PGVectorStore
import com.xebia.functional.xef.store.postgresql.PGDistanceStrategy
import com.zaxxer.hikari.HikariConfig
Expand All @@ -20,8 +19,6 @@ import io.kotest.core.spec.style.StringSpec
import io.kotest.extensions.testcontainers.ContainerExtension
import io.kotest.matchers.shouldBe
import kotlinx.coroutines.flow.Flow
import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID
import org.junit.jupiter.api.assertThrows
import org.testcontainers.containers.PostgreSQLContainer
import org.testcontainers.utility.DockerImageName
Expand Down Expand Up @@ -93,19 +90,11 @@ class PGVectorStoreSpec :
}

"the added memories sorted by index should be obtained in the same order" {
val messages = 10
val memoryData = MemoryData()
val llm = TestLLM()
val conversationId = ConversationId(UUID.generateUUID().toString())
val memories = (0 until messages).flatMap {
val m1 = Message(Role.USER, "question $it", "user")
val m2 = Message(Role.ASSISTANT, "answer $it", "assistant")
listOf(
Memory(conversationId, m1, 1),
Memory(conversationId, m2, 2)
)
}
val memories = memoryData.generateRandomMessages(10)
pg.addMemories(memories)
memories shouldBe pg.memories(llm, conversationId, 1000)
memories shouldBe pg.memories(llm, memoryData.defaultConversationId, 1000)
}
})

Expand All @@ -130,11 +119,9 @@ class TestLLM(override val modelType: ModelType = ModelType.ADA) : Chat, AutoClo
}
}



private fun Embeddings.Companion.mock(
embedDocuments:
suspend (texts: List<String>, chunkSize: Int?, config: RequestConfig) -> List<Embedding> =
suspend (texts: List<String>, config: RequestConfig, chunkSize: Int?) -> List<Embedding> =
{ _, _, _ ->
listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)), Embedding(listOf(4.0f, 5.0f, 6.0f)))
},
Expand All @@ -145,6 +132,9 @@ private fun Embeddings.Companion.mock(
"baz" -> listOf()
else -> listOf()
}
},
createEmbeddings: suspend (request: EmbeddingRequest) -> EmbeddingResult = { _ ->
EmbeddingResult(listOf(Embedding(listOf(1.0f, 2.0f, 3.0f)), Embedding(listOf(4.0f, 5.0f, 6.0f))), Usage.ZERO)
}
): Embeddings =
object : Embeddings {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object PostgreSQLXef {

data class PGVectorStoreConfig(
val dbConfig: DBConfig,
val vectorSize: Int = 3,
val vectorSize: Int = 1536, // OpenAI default
val collectionName: String = "xef_collection",
val preDeleteCollection: Boolean = false,
val chunkSize: Int? = null,
Expand Down

0 comments on commit 122bf81

Please sign in to comment.