Remove Arrow Resource from Core (#178)
Co-authored-by: yago <[email protected]>
nomisRev and Yawolf authored Jun 12, 2023
1 parent dd386b8 commit c4f28cb
Showing 12 changed files with 130 additions and 128 deletions.
4 changes: 4 additions & 0 deletions core/build.gradle.kts
@@ -56,6 +56,10 @@ kotlin {
mingwX64()

sourceSets {
all {
languageSettings.optIn("kotlin.ExperimentalStdlibApi")
}

val commonMain by getting {
dependencies {
api(libs.bundles.arrow)
54 changes: 18 additions & 36 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/auto/AI.kt
@@ -2,10 +2,7 @@ package com.xebia.functional.xef.auto

import arrow.core.Either
import arrow.core.left
import arrow.core.raise.Raise
import arrow.core.right
import arrow.fx.coroutines.ResourceScope
import arrow.fx.coroutines.resourceScope
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.embeddings.OpenAIEmbeddings
@@ -14,18 +11,15 @@ import com.xebia.functional.xef.llm.openai.KtorOpenAIClient
import com.xebia.functional.xef.llm.openai.OpenAIClient
import com.xebia.functional.xef.vectorstores.CombinedVectorStore
import com.xebia.functional.xef.vectorstores.LocalVectorStore
import com.xebia.functional.xef.vectorstores.LocalVectorStoreBuilder
import com.xebia.functional.xef.vectorstores.VectorStore
import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlin.jvm.JvmName
import kotlin.time.ExperimentalTime

@DslMarker annotation class AiDsl

/**
* An [AI] value represents an action relying on artificial intelligence that can be run to produce
* an [A]. This value is _lazy_ and can be combined with other `AI` values using [AIScope.invoke],
* an `A`. This value is _lazy_ and can be combined with other `AI` values using [AIScope.invoke],
* and thus forms a monadic DSL.
*
* All [AI] actions that are composed together using [AIScope.invoke] share the same [VectorStore],
@@ -45,16 +39,14 @@ inline fun <A> ai(noinline block: suspend AIScope.() -> A): AI<A> = block
suspend inline fun <A> AI<A>.getOrElse(crossinline orElse: suspend (AIError) -> A): A =
AIScope(this) { orElse(it) }
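A minimal usage sketch of the builder and handler above (illustrative only; it assumes the declarations and imports of this file are in scope, and the sample strings and function name are made up):

suspend fun searchOrFallback(): String =
  ai {
    // `context` is the VectorStore shared by this AIScope
    context.addTexts(listOf("An example document stored for this run"))
    context.similaritySearch("example document", limit = 1).firstOrNull() ?: "no results"
  }.getOrElse { error -> "recovered from $error" }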

@OptIn(ExperimentalTime::class)
@OptIn(ExperimentalTime::class, ExperimentalStdlibApi::class)
suspend fun <A> AIScope(block: suspend AIScope.() -> A, orElse: suspend (AIError) -> A): A =
try {
resourceScope {
val openAIConfig = OpenAIConfig()
val openAiClient: OpenAIClient = KtorOpenAIClient(openAIConfig)
val logger = KotlinLogging.logger("AutoAI")
val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient, logger)
val openAIConfig = OpenAIConfig()
KtorOpenAIClient(openAIConfig).use { openAiClient ->
val embeddings = OpenAIEmbeddings(openAIConfig, openAiClient)
val vectorStore = LocalVectorStore(embeddings)
val scope = AIScope(openAiClient, vectorStore, embeddings, logger, this)
val scope = AIScope(openAiClient, vectorStore, embeddings)
block(scope)
}
} catch (e: AIError) {
@@ -78,7 +70,7 @@ suspend inline fun <reified A> AI<A>.toEither(): Either<AIError, A> =
*
* This operator is **terminal** meaning it runs and completes the _chain_ of `AI` actions.
*
* @throws AIException in case something went wrong.
* @throws AIError in case something went wrong.
* @see getOrElse for an operator that allow directly handling the [AIError] case instead of
* throwing.
*/
@@ -87,17 +79,12 @@ suspend inline fun <reified A> AI<A>.getOrThrow(): A = getOrElse { throw it }
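A terminal-run sketch for comparison (hypothetical helper names, assuming this file's imports are in scope): getOrThrow rethrows the AIError, while toEither keeps it as a value.

suspend fun runOrThrow(): String = ai { "ready" }.getOrThrow()
suspend fun runToEither(): Either<AIError, String> = ai { "ready" }.toEither()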
/**
* The [AIScope] is the context in which [AI] values are run. It encapsulates all the dependencies
* required to run [AI] values, and provides convenient syntax for writing [AI] based programs.
*
* It exposes the [ResourceScope] so you can easily add your own resources with the scope of the
* [AI] program, and [Raise] of [AIError] in case you want to compose any [Raise] based actions.
*/
class AIScope(
val openAIClient: OpenAIClient,
val context: VectorStore,
internal val embeddings: Embeddings,
private val logger: KLogger,
resourceScope: ResourceScope
) : ResourceScope by resourceScope {
val embeddings: Embeddings
) {

/**
* Allows invoking [AI] values in the context of this [AIScope].
@@ -140,28 +127,23 @@ class AIScope(
}

/**
* Creates a new scoped [VectorStore] using [store], which is scoped to the [block] lambda. The
* [block] also runs on a _nested_ [resourceScope], meaning that all additional resources created
* within [block] will be finalized after [block] finishes.
* Creates a nested scope that combines the provided [store] with the outer _store_. This is done
* using [CombinedVectorStore].
*
* **Note:** if the [VectorStore] implementation relies on external resources, you are responsible
* for closing them yourself.
*/
@AiDsl
suspend fun <A> contextScope(
store: suspend ResourceScope.(Embeddings) -> VectorStore,
block: AI<A>
): A = resourceScope {
val newStore = store(this@AIScope.embeddings)
suspend fun <A> contextScope(store: VectorStore, block: AI<A>): A =
AIScope(
this@AIScope.openAIClient,
CombinedVectorStore(newStore, this@AIScope.context),
this@AIScope.embeddings,
this@AIScope.logger,
this
CombinedVectorStore(store, this@AIScope.context),
this@AIScope.embeddings
)
.block()
}

@AiDsl
suspend fun <A> contextScope(block: AI<A>): A = contextScope(LocalVectorStoreBuilder, block)
suspend fun <A> contextScope(block: AI<A>): A = contextScope(LocalVectorStore(embeddings), block)
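A call-site sketch of the new signature (illustrative only, assuming this file's declarations are in scope); the store passed in is searched before the outer context, as CombinedVectorStore documents below:

suspend fun layeredSearch(): List<String> =
  ai {
    contextScope(LocalVectorStore(embeddings)) {
      // inside the nested scope, `context` is the CombinedVectorStore over both stores
      context.addTexts(listOf("a fact that only lives in this nested scope"))
      context.similaritySearch("nested scope fact", limit = 1)
    }
  }.getOrThrow()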

/** Add new [docs] to the [context], and then executes the [block]. */
@AiDsl
@@ -5,15 +5,11 @@ import com.xebia.functional.xef.env.OpenAIConfig
import com.xebia.functional.xef.llm.openai.EmbeddingRequest
import com.xebia.functional.xef.llm.openai.OpenAIClient
import com.xebia.functional.xef.llm.openai.RequestConfig
import io.github.oshai.kotlinlogging.KLogger
import kotlin.time.ExperimentalTime

@ExperimentalTime
class OpenAIEmbeddings(
private val config: OpenAIConfig,
private val oaiClient: OpenAIClient,
private val logger: KLogger
) : Embeddings {
class OpenAIEmbeddings(private val config: OpenAIConfig, private val oaiClient: OpenAIClient) :
Embeddings {

override suspend fun embedDocuments(
texts: List<String>,
19 changes: 0 additions & 19 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/ktor.kt
@@ -1,32 +1,13 @@
package com.xebia.functional.xef

import arrow.fx.coroutines.ResourceScope
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpRequestRetry
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.defaultRequest
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.setBody
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json

inline fun <reified A> HttpRequestBuilder.configure(token: String, request: A): Unit {
header("Authorization", "Bearer $token")
contentType(ContentType.Application.Json)
setBody(request)
}

suspend fun ResourceScope.httpClient(baseUrl: String): HttpClient =
install({
HttpClient {
install(HttpTimeout)
install(ContentNegotiation) { json() }
install(HttpRequestRetry)
defaultRequest { url(baseUrl) }
}
}) { client, _ ->
client.close()
}
@@ -1,25 +1,28 @@
package com.xebia.functional.xef.llm.huggingface

import arrow.fx.coroutines.ResourceScope
import com.xebia.functional.xef.configure
import com.xebia.functional.xef.env.HuggingFaceConfig
import com.xebia.functional.xef.httpClient
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.defaultRequest
import io.ktor.client.request.post
import io.ktor.http.path
import io.ktor.serialization.kotlinx.json.json

interface HuggingFaceClient {
suspend fun generate(request: InferenceRequest, model: Model): List<Generation>
}

suspend fun ResourceScope.KtorHuggingFaceClient(config: HuggingFaceConfig): HuggingFaceClient =
KtorHuggingFaceClient(httpClient(config.baseUrl), config)
class KtorHuggingFaceClient(private val config: HuggingFaceConfig) :
HuggingFaceClient, AutoCloseable {

private class KtorHuggingFaceClient(
private val httpClient: HttpClient,
private val config: HuggingFaceConfig
) : HuggingFaceClient {
private val httpClient: HttpClient = HttpClient {
install(HttpTimeout)
install(ContentNegotiation) { json() }
defaultRequest { url(config.baseUrl) }
}

override suspend fun generate(request: InferenceRequest, model: Model): List<Generation> {
val response =
@@ -29,4 +32,8 @@ private class KtorHuggingFaceClient(
}
return response.body()
}

override fun close() {
httpClient.close()
}
}
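A call-site sketch for the new AutoCloseable variant. Request and model construction is elided and the types are assumed to be in scope; the common `use` extension is covered by the kotlin.ExperimentalStdlibApi opt-in added to core/build.gradle.kts above.

@OptIn(ExperimentalStdlibApi::class) // only needed where the module-wide opt-in does not apply
suspend fun generateOnce(
  config: HuggingFaceConfig,
  request: InferenceRequest,
  model: Model
): List<Generation> =
  KtorHuggingFaceClient(config).use { client -> client.generate(request, model) }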
@@ -1,18 +1,20 @@
package com.xebia.functional.xef.llm.openai

import arrow.fx.coroutines.ResourceScope
import com.xebia.functional.xef.configure
import com.xebia.functional.xef.env.OpenAIConfig
import com.xebia.functional.xef.httpClient
import io.github.oshai.kotlinlogging.KLogger
import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.defaultRequest
import io.ktor.client.plugins.timeout
import io.ktor.client.request.post
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsText
import io.ktor.http.HttpStatusCode
import io.ktor.http.path
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
@@ -38,13 +40,14 @@ data class ImagesGenerationResponse(val created: Long, val data: List<ImageGener

@Serializable data class ImageGenerationUrl(val url: String)

suspend fun ResourceScope.KtorOpenAIClient(config: OpenAIConfig): OpenAIClient =
KtorOpenAIClient(httpClient(config.baseUrl), config)
@OptIn(ExperimentalStdlibApi::class)
class KtorOpenAIClient(private val config: OpenAIConfig) : OpenAIClient, AutoCloseable {

private class KtorOpenAIClient(
private val httpClient: HttpClient,
private val config: OpenAIConfig
) : OpenAIClient {
private val httpClient: HttpClient = HttpClient {
install(HttpTimeout)
install(ContentNegotiation) { json() }
defaultRequest { url(config.baseUrl) }
}

private val logger: KLogger = KotlinLogging.logger {}

@@ -101,6 +104,8 @@ private class KtorOpenAIClient(
timeout { requestTimeoutMillis = config.requestTimeoutMillis }
}
.bodyOrError()

override fun close() = httpClient.close()
}

private suspend inline fun <reified T> HttpResponse.bodyOrError(): T =
@@ -2,11 +2,15 @@ package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.embeddings.Embedding

/**
* A way of composing two [VectorStore] instances together, this class will **first search** [top],
* and then [bottom].
*
* If all results can be found in [top] it will skip searching [bottom].
*/
class CombinedVectorStore(private val top: VectorStore, private val bottom: VectorStore) :
VectorStore by top {

fun pop(): VectorStore = bottom

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
val topResults = top.similaritySearch(query, limit)
return when {
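A small sketch of the lookup order described in the KDoc above (the query string and limit are arbitrary; `top` and `bottom` are assumed to be pre-built stores, with the types from this package in scope):

suspend fun searchLayered(top: VectorStore, bottom: VectorStore): List<String> =
  // hits come from `top` first; `bottom` is only consulted when `top` returns fewer than `limit` results
  CombinedVectorStore(top, bottom).similaritySearch("how is context layered?", limit = 4)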
@@ -1,6 +1,5 @@
package com.xebia.functional.xef.vectorstores

import arrow.fx.coroutines.ResourceScope
import arrow.fx.stm.TMap
import arrow.fx.stm.TVar
import arrow.fx.stm.atomically
@@ -10,10 +9,6 @@ import com.xebia.functional.xef.llm.openai.EmbeddingModel
import com.xebia.functional.xef.llm.openai.RequestConfig
import kotlin.math.sqrt

val LocalVectorStoreBuilder: suspend ResourceScope.(Embeddings) -> LocalVectorStore = { e ->
LocalVectorStore(e)
}

class LocalVectorStore
private constructor(
private val embeddings: Embeddings,
@@ -1,6 +1,7 @@
package com.xebia.functional.xef.vectorstores

import com.xebia.functional.xef.embeddings.Embedding
import kotlin.jvm.JvmStatic

interface VectorStore {
/**
@@ -32,6 +33,7 @@ interface VectorStore {
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String>

companion object {
@JvmStatic
val EMPTY: VectorStore =
object : VectorStore {
override suspend fun addTexts(texts: List<String>) {}
@@ -1,7 +1,5 @@
package com.xebia.functional.xef.vectorstores

import arrow.fx.coroutines.ResourceScope
import arrow.fx.coroutines.autoCloseable
import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.openai.EmbeddingModel
@@ -91,8 +89,8 @@ fun InMemoryLuceneBuilder(
useAIEmbeddings: Boolean = true,
writerConfig: IndexWriterConfig = IndexWriterConfig(),
similarity: VectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN
): suspend ResourceScope.(Embeddings) -> DirectoryLucene = { embeddings ->
autoCloseable { InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity) }
): suspend (Embeddings) -> DirectoryLucene = { embeddings ->
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity)
}

fun List<Embedding>.toFloatArray(): FloatArray = flatMap { it.data }.toFloatArray()
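With the ResourceScope wrapper gone, finalization is now the caller's job. A sketch under two assumptions: that DirectoryLucene still exposes close() (as the removed autoCloseable wrapper implied), and that this file's imports are in scope.

suspend fun withLucene(
  builder: suspend (Embeddings) -> DirectoryLucene,
  embeddings: Embeddings
) {
  val lucene = builder(embeddings)
  try {
    // index or query documents here
  } finally {
    lucene.close() // assumed close(); previously handled by Arrow's autoCloseable
  }
}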
@@ -19,29 +19,29 @@ class PGVectorStore(
private val chunckSize: Int?
) : VectorStore {

suspend fun JDBCSyntax.getCollection(collectionName: String): PGCollection =
private fun JDBCSyntax.getCollection(collectionName: String): PGCollection =
queryOneOrNull(getCollection, { bind(collectionName) }) {
PGCollection(UUID(string()), string())
}
?: throw IllegalStateException("Collection '$collectionName' not found")

suspend fun JDBCSyntax.deleteCollection() {
private fun JDBCSyntax.deleteCollection() {
if (preDeleteCollection) {
val collection = getCollection(collectionName)
update(deleteCollectionDocs) { bind(collection.uuid.toString()) }
update(deleteCollection) { bind(collection.uuid.toString()) }
}
}

suspend fun initialDbSetup(): Unit =
fun initialDbSetup(): Unit =
dataSource.connection {
update(addVectorExtension)
update(createCollectionsTable)
update(createEmbeddingTable(vectorSize))
deleteCollection()
}

suspend fun createCollection(): Unit =
fun createCollection(): Unit =
dataSource.connection {
val xa = UUID.generateUUID()
update(addNewCollection) {