diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 34ed11c53..5c8f84272 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -51,6 +51,8 @@ kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutin ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" } ktor-http = { module = "io.ktor:ktor-http", version.ref = "ktor" } ktor-client ={ module = "io.ktor:ktor-client-core", version.ref = "ktor" } +ktor-client-content-negotiation ={ module = "io.ktor:ktor-client-content-negotiation", version.ref = "ktor" } +ktor-client-serialization = { module = "io.ktor:ktor-serialization-kotlinx-json", version.ref = "ktor" } ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" } ktor-client-js = { module = "io.ktor:ktor-client-js", version.ref = "ktor" } ktor-client-winhttp = { module = "io.ktor:ktor-client-winhttp", version.ref = "ktor" } @@ -94,6 +96,11 @@ jackson-schema-jakarta = { module = "com.github.victools:jsonschema-module-jakar jakarta-validation = { module = "jakarta.validation:jakarta.validation-api", version.ref = "jakarta" } [bundles] +ktor-client = [ + "ktor-client", + "ktor-client-content-negotiation", + "ktor-client-serialization" +] arrow = [ "arrow-core", "arrow-fx-coroutines" diff --git a/integrations/gcp/build.gradle.kts b/integrations/gcp/build.gradle.kts index e6eb03aa4..aaf91acfe 100644 --- a/integrations/gcp/build.gradle.kts +++ b/integrations/gcp/build.gradle.kts @@ -1,51 +1,89 @@ plugins { - id(libs.plugins.kotlin.multiplatform.get().pluginId) - id(libs.plugins.kotlinx.serialization.get().pluginId) - alias(libs.plugins.spotless) - alias(libs.plugins.arrow.gradle.publish) - alias(libs.plugins.semver.gradle) + id(libs.plugins.kotlin.multiplatform.get().pluginId) + id(libs.plugins.kotlinx.serialization.get().pluginId) + alias(libs.plugins.spotless) + alias(libs.plugins.arrow.gradle.publish) + alias(libs.plugins.semver.gradle) } repositories { - mavenCentral() + mavenCentral() } java { - sourceCompatibility = JavaVersion.VERSION_11 - targetCompatibility = JavaVersion.VERSION_11 - toolchain { - languageVersion = JavaLanguageVersion.of(11) - } + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 + toolchain { + languageVersion = JavaLanguageVersion.of(11) + } } kotlin { - jvm() - js(IR) { - browser() - nodejs() + jvm() + js(IR) { + browser() + nodejs() + } + + linuxX64() + macosX64() + macosArm64() + mingwX64() + + sourceSets { + val commonMain by getting { + dependencies { + api(projects.xefCore) + implementation(libs.bundles.ktor.client) + } + } + + val jvmMain by getting { + dependencies { + implementation(libs.logback) + api(libs.ktor.client.cio) + } + } + + val jsMain by getting { + dependencies { + api(libs.ktor.client.js) + } } - linuxX64() - macosX64() - macosArm64() - mingwX64() - - sourceSets { - val commonMain by getting { - dependencies { - api(projects.xefCore) - } - } + val linuxX64Main by getting { + dependencies { + api(libs.ktor.client.cio) + } } + + val macosX64Main by getting { + dependencies { + api(libs.ktor.client.cio) + } + } + + val macosArm64Main by getting { + dependencies { + api(libs.ktor.client.cio) + } + } + + val mingwX64Main by getting { + dependencies { + api(libs.ktor.client.winhttp) + } + } + } } spotless { - kotlin { - target("**/*.kt") - ktfmt().googleStyle() - } + kotlin { + target("**/*.kt") + ktfmt().googleStyle() + } } tasks.withType { - dependsOn(tasks.withType()) + dependsOn(tasks.withType()) } diff --git a/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt new file mode 100644 index 000000000..a2f6d8878 --- /dev/null +++ b/integrations/gcp/src/commonMain/kotlin/com/xebia/functional/xef/gcp/GcpClient.kt @@ -0,0 +1,116 @@ +package com.xebia.functional.xef.gcp + +import com.xebia.functional.xef.AIError +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.HttpRequestRetry +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.http.isSuccess +import io.ktor.serialization.kotlinx.json.json +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +@OptIn(ExperimentalStdlibApi::class) +class GcpClient( + private val apiEndpoint: String, + private val projectId: String, + private val modelId: String, + private val token: String +) : AutoCloseable { + private val http: HttpClient = HttpClient { + install(HttpTimeout) + install(HttpRequestRetry) + install(ContentNegotiation) { + json( + Json { + encodeDefaults = false + isLenient = true + } + ) + } + } + + @Serializable + private data class Prompt(val instances: List, val parameters: Parameters? = null) + + @Serializable + private data class Instance( + val context: String? = null, + val examples: List? = null, + val messages: List, + ) + + @Serializable data class Example(val input: String, val output: String) + + @Serializable private data class Message(val author: String, val content: String) + + @Serializable + private class Parameters( + val temperature: Double? = null, + val maxOutputTokens: Int? = null, + val topK: Int? = null, + val topP: Double? = null + ) + + @Serializable data class Response(val predictions: List) + + @Serializable + data class SafetyAttributes( + val blocked: Boolean, + val scores: List, + val categories: List + ) + + @Serializable data class CitationMetadata(val citations: List) + + @Serializable data class Candidates(val author: String?, val content: String?) + + @Serializable + data class Predictions( + val safetyAttributes: List, + val citationMetadata: List, + val candidates: List + ) + + suspend fun promptMessage( + prompt: String, + temperature: Double? = null, + maxOutputTokens: Int? = null, + topK: Int? = null, + topP: Double? = null + ): String { + val body = + Prompt( + listOf(Instance(messages = listOf(Message(author = "user", content = prompt)))), + Parameters(temperature, maxOutputTokens, topK, topP) + ) + val response = + http.post( + "https://$apiEndpoint/v1/projects/$projectId/locations/us-central1/publishers/google/models/$modelId:predict" + ) { + header("Authorization", "Bearer $token") + contentType(ContentType.Application.Json) + setBody(body) + } + + return if (response.status.isSuccess()) + response.body().predictions.firstOrNull()?.candidates?.firstOrNull()?.content + ?: throw AIError.NoResponse() + else throw GcpClientException(response.status, response.bodyAsText()) + } + + class GcpClientException(val httpStatusCode: HttpStatusCode, val error: String) : + IllegalStateException("$httpStatusCode: $error") + + override fun close() { + http.close() + } +}