Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gcp runtime #371

Merged
merged 12 commits into from
Sep 1, 2023
Original file line number Diff line number Diff line change
@@ -1,34 +1,16 @@
package com.xebia.functional.xef.conversation.gpc

import arrow.core.nonEmptyListOf
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpChat
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.GcpEmbeddings
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.gcp.GCP
import com.xebia.functional.xef.gcp.promptMessage
import com.xebia.functional.xef.prompt.Prompt

suspend fun main() {
OpenAI.conversation {
val token =
getenv("GCP_TOKEN") ?: throw AIError.Env.GCP(nonEmptyListOf("missing GCP_TOKEN env var"))

val gcp =
GcpChat("codechat-bison@001", GcpConfig(token, "xefdemo", "us-central1")).let(::autoClose)
val gcpEmbeddingModel =
GcpChat("codechat-bison@001", GcpConfig(token, "xefdemo", "us-central1")).let(::autoClose)

val embeddingResult =
GcpEmbeddings(gcpEmbeddingModel)
.embedQuery("strawberry donuts", RequestConfig(RequestConfig.Companion.User("user")))
println(embeddingResult)

GCP.conversation {
while (true) {
print("\n🤖 Enter your question: ")
val userInput = readlnOrNull() ?: break
val answer = gcp.promptMessage(Prompt(userInput))
if (userInput == "exit") break
val answer = promptMessage(Prompt(userInput))
println("\n🤖 $answer")
}
println("\n🤖 Done")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package com.xebia.functional.xef.conversation.gpc
import com.xebia.functional.gpt4all.conversation
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.VertexAIRegion
import com.xebia.functional.xef.gcp.pipelines.GcpPipelinesClient

suspend fun main() {
conversation {
val token = getenv("GCP_TOKEN") ?: error("missing gcp token")
val pipelineClient = autoClose(GcpPipelinesClient(GcpConfig(token, "xefdemo", "us-central1")))
val pipelineClient =
autoClose(GcpPipelinesClient(GcpConfig(token, "xefdemo", VertexAIRegion.US_CENTRAL1)))
val answer = pipelineClient.list()
println("\n🤖 $answer")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.xebia.functional.xef.gcp

import arrow.core.nonEmptyListOf
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.PlatformConversation
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.store.LocalVectorStore
import com.xebia.functional.xef.store.VectorStore
import kotlin.jvm.JvmField
import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic
import kotlin.jvm.JvmSynthetic

private const val GCP_TOKEN_ENV_VAR = "GCP_TOKEN"
private const val GCP_PROJECT_ID_VAR = "GCP_PROJECT_ID"
private const val GCP_LOCATION_VAR = "GCP_LOCATION"

/**
 * Entry point for Google Cloud Vertex AI support in xef.
 *
 * Every constructor argument is optional; a `null` value falls back to the
 * corresponding environment variable ([GCP_TOKEN_ENV_VAR], [GCP_PROJECT_ID_VAR],
 * [GCP_LOCATION_VAR]). A missing variable, or a [GCP_LOCATION_VAR] value that does
 * not match a [VertexAIRegion.officialName], raises [AIError.Env.GCP].
 *
 * Models ([CODECHAT], [TEXT_EMBEDDING_GECKO]) are created lazily on first access,
 * so no client is built until a model is actually used.
 *
 * @param projectId GCP project id; read from [GCP_PROJECT_ID_VAR] when null.
 * @param location Vertex AI region; read from [GCP_LOCATION_VAR] when null.
 * @param token OAuth bearer token; read from [GCP_TOKEN_ENV_VAR] when null.
 */
class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: String? = null) {
  // Resolved eagerly at construction: explicit arguments win, env vars fill the gaps.
  private val config =
    GcpConfig(
      token = token ?: tokenFromEnv(),
      projectId = projectId ?: projectIdFromEnv(),
      location = location ?: locationFromEnv(),
    )

  private fun tokenFromEnv(): String = fromEnv(GCP_TOKEN_ENV_VAR)

  private fun projectIdFromEnv(): String = fromEnv(GCP_PROJECT_ID_VAR)

  // Maps the raw env value onto a VertexAIRegion by its official name; an
  // unrecognized value fails with a message listing every valid region.
  private fun locationFromEnv(): VertexAIRegion =
    fromEnv(GCP_LOCATION_VAR).let { envVar ->
      VertexAIRegion.entries.find { it.officialName == envVar }
    }
      ?: throw AIError.Env.GCP(
        nonEmptyListOf(
          "invalid value for $GCP_LOCATION_VAR - valid values are ${VertexAIRegion.entries.map(VertexAIRegion::officialName)}"
        )
      )

  // Reads a required env var or fails with AIError.Env.GCP naming the variable.
  private fun fromEnv(name: String): String =
    getenv(name) ?: throw AIError.Env.GCP(nonEmptyListOf("missing $name env var"))

  // Lazily constructed models bound to this instance's config.
  val CODECHAT by lazy { GcpModel("codechat-bison@001", config) }
  val TEXT_EMBEDDING_GECKO by lazy { GcpModel("textembedding-gecko", config) }

  // Defaults used by the promptMessage/promptStreaming extensions and the
  // companion's conversation() factories.
  @JvmField val DEFAULT_CHAT = CODECHAT
  @JvmField val DEFAULT_EMBEDDING = TEXT_EMBEDDING_GECKO

  /** All models this provider currently exposes. */
  fun supportedModels(): List<GcpModel> = listOf(CODECHAT, TEXT_EMBEDDING_GECKO)

  companion object {

    // NOTE(review): this @JvmField is initialized eagerly when the companion is
    // loaded, which appears to require GCP_TOKEN/GCP_PROJECT_ID/GCP_LOCATION to be
    // set even for callers that pass explicit constructor arguments — TODO confirm
    // and consider a lazy property (would change the Java field into a getter).
    @JvmField val FromEnvironment: GCP = GCP()

    /** Runs [block] in a conversation backed by the given [store]. */
    @JvmSynthetic
    suspend inline fun <A> conversation(
      store: VectorStore,
      noinline block: suspend Conversation.() -> A
    ): A = block(conversation(store))

    /**
     * Runs [block] in a conversation backed by an in-memory vector store that
     * embeds with the environment-configured default embedding model.
     */
    @JvmSynthetic
    suspend fun <A> conversation(block: suspend Conversation.() -> A): A =
      block(conversation(LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))))

    /** Java-friendly factory: creates a conversation over [store] (local store by default). */
    @JvmStatic
    @JvmOverloads
    fun conversation(
      store: VectorStore = LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))
    ): PlatformConversation = Conversation(store)
  }
}

/**
 * Runs [block] inside a fresh [Conversation] whose in-memory vector store embeds
 * documents with this instance's [GCP.DEFAULT_EMBEDDING] model.
 */
suspend inline fun <A> GCP.conversation(noinline block: suspend Conversation.() -> A): A {
  val embeddings = GcpEmbeddings(DEFAULT_EMBEDDING)
  val freshConversation = Conversation(LocalVectorStore(embeddings))
  return block(freshConversation)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,12 @@ import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import io.ktor.client.*
import io.ktor.client.HttpClient
import io.ktor.client.call.*
import io.ktor.client.call.body
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.*
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
Expand Down Expand Up @@ -85,7 +79,7 @@ class GcpClient(
)
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand Down Expand Up @@ -137,7 +131,7 @@ class GcpClient(
)
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/publishers/google/models/$modelId:predict"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ package com.xebia.functional.xef.gcp
data class GcpConfig(
val token: String,
val projectId: String,
/** https://cloud.google.com/vertex-ai/docs/general/locations */
val location: String, // Supported us-central1 or europe-west4
/** [GCP locations](https://cloud.google.com/vertex-ai/docs/general/locations) */
val location: VertexAIRegion, // Supported us-central1 or europe-west4
)

enum class VertexAIRegion(val officialName: String) {
Intex32 marked this conversation as resolved.
Show resolved Hide resolved
US_CENTRAL1("us-central1"),
EU_WEST4("europe-west4"),
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID

@OptIn(ExperimentalStdlibApi::class)
class GcpChat(modelId: String, config: GcpConfig) : Chat, Completion, AutoCloseable, Embeddings {
class GcpModel(modelId: String, config: GcpConfig) : Chat, Completion, AutoCloseable, Embeddings {
private val client: GcpClient = GcpClient(modelId, config)

override val name: String = client.modelId
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xebia.functional.xef.gcp

import com.xebia.functional.xef.conversation.AiDsl
import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import kotlinx.coroutines.flow.Flow

/**
 * Sends [prompt] to [model] within this conversation and returns the model's reply.
 *
 * @param prompt the prompt to send.
 * @param model chat model to use; defaults to the shared, environment-configured
 *   [GCP.FromEnvironment] chat model instead of constructing a new [GCP] (and a new
 *   underlying model/client) on every call with the default argument.
 */
@AiDsl
suspend fun Conversation.promptMessage(
  prompt: Prompt,
  model: Chat = GCP.FromEnvironment.DEFAULT_CHAT
): String = model.promptMessage(prompt, this)

/**
 * Sends [prompt] to [model] within this conversation and returns the reply as a
 * stream of partial results.
 *
 * @param prompt the prompt to send.
 * @param model chat model to use; defaults to the shared, environment-configured
 *   [GCP.FromEnvironment] chat model instead of constructing a new [GCP] (and a new
 *   underlying model/client) on every call with the default argument.
 */
@AiDsl
fun Conversation.promptStreaming(
  prompt: Prompt,
  model: Chat = GCP.FromEnvironment.DEFAULT_CHAT
): Flow<String> = model.promptStreaming(prompt, this)
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class GcpPipelinesClient(
suspend fun list(): List<PipelineJob> {
val response =
http.get(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -92,7 +92,7 @@ class GcpPipelinesClient(
suspend fun get(pipelineJobName: String): PipelineJob? {
val response =
http.get(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs/$pipelineJobName"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -105,7 +105,7 @@ class GcpPipelinesClient(
suspend fun create(pipelineJobId: String?, pipelineJob: CreatePipelineJob): PipelineJob? {
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -120,7 +120,7 @@ class GcpPipelinesClient(
suspend fun cancel(pipelineJobName: String): Unit {
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName:cancel"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs/$pipelineJobName:cancel"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand All @@ -133,7 +133,7 @@ class GcpPipelinesClient(
suspend fun delete(pipelineJobName: String): Operation {
val response =
http.delete(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName"
"https://${config.location.officialName}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location.officialName}/pipelineJobs/$pipelineJobName"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@ import kotlin.jvm.JvmOverloads
import kotlin.jvm.JvmStatic
import kotlin.jvm.JvmSynthetic

private const val KEY_ENV_VAR = "OPENAI_TOKEN"
private const val HOST_ENV_VAR = "OPENAI_HOST"

class OpenAI(internal var token: String? = null, internal var host: String? = null) :
AutoCloseable, AutoClose by autoClose() {

private fun openAITokenFromEnv(): String {
return getenv("OPENAI_TOKEN")
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing OPENAI_TOKEN env var"))
return getenv(KEY_ENV_VAR)
?: throw AIError.Env.OpenAI(nonEmptyListOf("missing $KEY_ENV_VAR env var"))
}

private fun openAIHostFromEnv(): String? {
return getenv("OPENAI_HOST")
return getenv(HOST_ENV_VAR)
}

fun getToken(): String {
Expand Down