Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLflow Gateway models #507

Merged
merged 8 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies {
implementation(projects.xefOpenai)
implementation(projects.xefReasoning)
implementation(projects.xefOpentelemetry)
implementation(projects.xefMlflow)
implementation(libs.kotlinx.serialization.json)
implementation(libs.logback)
implementation(libs.klogging)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.xebia.functional.xef.conversation.mlflow

import com.xebia.functional.xef.mlflow.*
import com.xebia.functional.xef.mlflow.MlflowClient
import io.ktor.client.*

/**
 * Interactive console demo for the MLflow Gateway client.
 *
 * Lists the routes returned by the gateway at `http://localhost:5000`, then repeatedly asks for a
 * route name and a question, sends the question through [MlflowClient.chat] and prints the content
 * of the first candidate. The loop ends when the user answers anything other than `y`.
 */
suspend fun main() {
  val gatewayUri = "http://localhost:5000"
  val httpClient = HttpClient()

  // try/finally so the HTTP client is closed even when a gateway call throws.
  try {
    val client = MlflowClient(gatewayUri, httpClient)

    // readlnOrNull() only returns once a whole line is entered, hence "Enter" rather than "any key".
    println("MLflow Gateway client created. Press Enter to continue...")
    readlnOrNull()

    println("Searching available models...")
    println()
    val routes = client.searchRoutes()

    println(
      """
      |######### Routes found #########
      |${routes.joinToString(separator = "\n") { printRoute(it) }}
      |
      """
        .trimMargin()
    )
    println()

    while (true) {
      println("Select the route you want to interact with")
      // Use the default both on end-of-input (null) and when the user just presses Enter.
      val requestedRoute = readlnOrNull()?.trim()?.takeIf { it.isNotEmpty() } ?: "chat"

      // `response` stays null when no route with the requested name exists.
      val response =
        client.getRoute(requestedRoute)?.name?.let { routeName ->
          println("Route found: $routeName. What do you want to ask?")
          val question =
            readlnOrNull()?.trim()?.takeIf { it.isNotEmpty() }
              ?: "What's the best day of the week and why?"

          client.chat(
            routeName,
            listOf(
              ChatMessage(ChatRole.SYSTEM, "You are a helpful assistant. Be concise"),
              ChatMessage(ChatRole.USER, question),
            ),
            temperature = 0.7,
            maxTokens = 200
          )
        }

      // firstOrNull() instead of get(0): an empty candidate list must not crash the demo.
      val chatResponse = response?.candidates?.firstOrNull()?.message?.content

      if (response == null) {
        println("Route '$requestedRoute' was not found.")
      } else {
        println("Chat GPT response was: \n\n$chatResponse")
      }
      println()
      println("Do you want to continue? (y/N)")
      val userInput = readlnOrNull()?.trim() ?: ""
      if (!userInput.equals("y", ignoreCase = true)) break
    }
  } finally {
    httpClient.close()
  }
}

/** Renders the name and provider of [model] as `(name = '...', provider = '...')`. */
private fun printModel(model: RouteModel): String {
  val name = model.name
  val provider = model.provider
  return "(name = '$name', provider = '$provider')"
}

/**
 * Renders [r] as a multi-line listing entry: a `Name:` line followed by bullet lines for the route
 * type, the route url and the model (formatted by [printModel]).
 */
private fun printRoute(r: RouteDefinition): String {
  val modelSummary = printModel(r.model)
  val entry =
    """
    |Name: ${r.name}
    | * Route type: ${r.routeType}
    | * Route url: ${r.routeUrl}
    | * Model: $modelSummary"""
  // The `|` margin marks where each line starts; everything before it is stripped.
  return entry.trimMargin()
}
123 changes: 123 additions & 0 deletions integrations/mlflow/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Gradle plugins applied to this module; every entry is looked up in the `libs` version catalog.
plugins {
// These two are applied by plugin id only: just the `pluginId` of the catalog entry is used.
id(libs.plugins.kotlin.multiplatform.get().pluginId)
id(libs.plugins.kotlinx.serialization.get().pluginId)
// The remaining plugins are applied through their full catalog alias.
alias(libs.plugins.spotless)
alias(libs.plugins.arrow.gradle.publish)
alias(libs.plugins.semver.gradle)
alias(libs.plugins.detekt)
}


// Project-level dependencies: registers the `:detekt-rules` project as a detekt plugin.
dependencies {
detektPlugins(project(":detekt-rules"))
}

// detekt static-analysis configuration for this module.
detekt {
// detekt version to run, pinned here as a literal.
toolVersion = "1.23.1"
// Only the common and JVM source directories are analysed.
source = files("src/commonMain/kotlin", "src/jvmMain/kotlin")
// Rule configuration file, given as a path relative to this module's directory.
config.setFrom("../../config/detekt/detekt.yml")
autoCorrect = true
}


// Dependencies of this module are resolved from Maven Central.
repositories {
mavenCentral()
}

// Java settings: source/target compatibility and the toolchain are all set to Java 11.
java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
toolchain {
languageVersion = JavaLanguageVersion.of(11)
}
}

// Kotlin Multiplatform setup: the targets this module builds for and the dependencies of each
// source set.
kotlin {
  jvm()
  js(IR) {
    browser()
    nodejs()
  }

  linuxX64()
  macosX64()
  macosArm64()
  mingwX64()

  sourceSets {
    // Shared code: xef core is exposed as API; the Ktor client bundle, uuid and datetime stay
    // implementation details.
    getByName("commonMain") {
      dependencies {
        api(projects.xefCore)
        implementation(libs.bundles.ktor.client)
        implementation(libs.uuid)
        implementation(libs.kotlinx.datetime)
      }
    }

    getByName("jvmMain") {
      dependencies {
        implementation(libs.logback)
        api(libs.ktor.client.cio)
      }
    }

    getByName("jsMain") {
      dependencies {
        api(libs.ktor.client.js)
      }
    }

    // The Linux and macOS source sets all expose the same Ktor client engine dependency (CIO).
    listOf("linuxX64Main", "macosX64Main", "macosArm64Main").forEach { sourceSetName ->
      getByName(sourceSetName) {
        dependencies {
          api(libs.ktor.client.cio)
        }
      }
    }

    getByName("mingwX64Main") {
      dependencies {
        api(libs.ktor.client.winhttp)
      }
    }
  }
}

// Spotless formatting: every Kotlin file under this module is formatted with ktfmt (Google style).
spotless {
kotlin {
target("**/*.kt")
ktfmt().googleStyle().configure {
// Unused imports are removed as part of formatting.
it.setRemoveUnusedImport(true)
}
}
}

// Task wiring for detekt and publishing.
tasks{
// Every Detekt task depends on `:detekt-rules:assemble` and runs with auto-correction enabled.
withType<io.gitlab.arturbosch.detekt.Detekt>().configureEach {
dependsOn(":detekt-rules:assemble")
autoCorrect = true
}
// Makes `build` depend on `detektJvmMain`.
// NOTE(review): the `build` dependency is added from inside this task's configuration block, so it
// presumably only takes effect once `detektJvmMain` itself gets configured - confirm.
named("detektJvmMain") {
dependsOn(":detekt-rules:assemble")
getByName("build").dependsOn(this)
}
// Same wiring for the aggregate `detekt` task.
named("detekt") {
dependsOn(":detekt-rules:assemble")
getByName("build").dependsOn(this)
}
// Every Maven publish task depends on all Sign tasks.
// NOTE(review): unlike the Detekt block above this does not use `configureEach` - confirm whether
// the eager form is intentional.
withType<AbstractPublishToMaven> {
dependsOn(withType<Sign>())
}

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.xebia.functional.xef.mlflow

import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.json.Json

/**
 * Client for the HTTP API of an MLflow Gateway.
 *
 * Routes are listed through `GET {gatewayUrl}/api/2.0/gateway/routes/` and queried through
 * `POST {gatewayUrl}/gateway/{route}/invocations`.
 *
 * @param gatewayUrl base URL of the gateway; trailing slashes are ignored.
 * @param client HTTP client to derive the internal client from. A copy of its configuration is made
 *   with JSON content negotiation installed.
 */
class MlflowClient(private val gatewayUrl: String, client: HttpClient) : AutoClose by autoClose() {

  /** [gatewayUrl] without trailing slashes, so that appended paths never start with `//`. */
  private val baseUrl: String
    get() = gatewayUrl.trimEnd('/')

  private val internal =
    client.config {
      install(ContentNegotiation) {
        json(
          Json {
            encodeDefaults = false
            isLenient = true
            ignoreUnknownKeys = true
          }
        )
      }
    }

  /** Decoder used only for the routes listing, which is read as text and parsed explicitly. */
  private val json = Json { ignoreUnknownKeys = true }

  /**
   * Fetches every route defined in the gateway.
   *
   * @throws MLflowClientUnexpectedError when the gateway answers with a non-success status.
   */
  private suspend fun routes(): List<RouteDefinition> {
    val response = internal.get("$baseUrl/api/2.0/gateway/routes/")
    if (!response.status.isSuccess()) {
      throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
    }
    val textResponse = response.bodyAsText()
    return json.decodeFromString<RoutesResponse>(textResponse).routes
  }

  /**
   * Posts [payload] to the invocations endpoint of [route] and decodes the answer.
   *
   * Shared by [prompt], [chat] and [embeddings]. Both type parameters are reified so that the
   * request body and the response are (de)serialized with their concrete types.
   *
   * @throws MLflowValidationError on HTTP 422, carrying the first validation message of the
   *   response (or "Unknown error" when there is none).
   * @throws MLflowClientUnexpectedError on any other non-success status.
   */
  private suspend inline fun <reified Req, reified Res> invokeRoute(
    route: String,
    payload: Req
  ): Res {
    val response =
      internal.post("$baseUrl/gateway/$route/invocations") {
        accept(ContentType.Application.Json)
        contentType(ContentType.Application.Json)
        setBody(payload)
      }

    return when {
      response.status.isSuccess() -> response.body<Res>()
      response.status.value == HttpStatusCode.UnprocessableEntity.value ->
        throw MLflowValidationError(
          response.status,
          response.body<ValidationError>().detail?.firstOrNull()?.msg ?: "Unknown error"
        )
      else -> throw MLflowClientUnexpectedError(response.status, response.bodyAsText())
    }
  }

  /** Returns every route defined in the gateway. */
  suspend fun searchRoutes(): List<RouteDefinition> = routes()

  /**
   * Returns the route called [name], or `null` when the gateway has no such route.
   *
   * The lookup fetches the full route list and searches it by name.
   */
  suspend fun getRoute(name: String): RouteDefinition? = routes().find { it.name == name }

  /**
   * Sends a text-completion request to [route].
   *
   * @param route name of the gateway route to invoke.
   * @param prompt text to complete.
   * @param candidateCount optional request parameter, forwarded as is.
   * @param temperature optional request parameter, forwarded as is.
   * @param maxTokens optional request parameter, forwarded as is.
   * @param stop optional request parameter, forwarded as is.
   * @throws MLflowValidationError when the gateway answers HTTP 422.
   * @throws MLflowClientUnexpectedError on any other non-success status.
   */
  suspend fun prompt(
    route: String,
    prompt: String,
    candidateCount: Int? = null,
    temperature: Double? = null,
    maxTokens: Int? = null,
    stop: List<String>? = null
  ): PromptResponse =
    invokeRoute(route, Prompt(prompt, temperature, candidateCount, stop, maxTokens))

  /**
   * Sends a chat request made of [messages] to [route].
   *
   * @param route name of the gateway route to invoke.
   * @param messages conversation to send.
   * @param candidateCount optional request parameter, forwarded as is.
   * @param temperature optional request parameter, forwarded as is.
   * @param maxTokens optional request parameter, forwarded as is.
   * @param stop optional request parameter, forwarded as is.
   * @throws MLflowValidationError when the gateway answers HTTP 422.
   * @throws MLflowClientUnexpectedError on any other non-success status.
   */
  suspend fun chat(
    route: String,
    messages: List<ChatMessage>,
    candidateCount: Int? = null,
    temperature: Double? = null,
    maxTokens: Int? = null,
    stop: List<String>? = null
  ): ChatResponse = invokeRoute(route, Chat(messages, temperature, candidateCount, stop, maxTokens))

  /**
   * Requests embeddings for [text] from [route].
   *
   * @throws MLflowValidationError when the gateway answers HTTP 422.
   * @throws MLflowClientUnexpectedError on any other non-success status.
   */
  suspend fun embeddings(route: String, text: List<String>): EmbeddingsResponse =
    invokeRoute(route, Embeddings(text))

  /** Raised when the gateway rejects a request with HTTP 422. */
  class MLflowValidationError(httpStatusCode: HttpStatusCode, error: String) :
    IllegalStateException("$httpStatusCode: $error")

  /** Raised for any non-success gateway answer that is not a validation error. */
  class MLflowClientUnexpectedError(httpStatusCode: HttpStatusCode, error: String) :
    IllegalStateException("$httpStatusCode: $error")
}
Loading