Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing VectorStoreService dependency in Xef Routes #459

Merged
merged 1 commit into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions server/src/main/kotlin/com/xebia/functional/xef/server/Server.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,8 @@ import com.xebia.functional.xef.server.db.psql.XefDatabaseConfig
import com.xebia.functional.xef.server.db.psql.XefVectorStoreConfig
import com.xebia.functional.xef.server.db.psql.XefVectorStoreConfig.Companion.getVectorStoreService
import com.xebia.functional.xef.server.exceptions.exceptionsHandler
import com.xebia.functional.xef.server.http.routes.genAIRoutes
import com.xebia.functional.xef.server.http.routes.organizationRoutes
import com.xebia.functional.xef.server.http.routes.projectsRoutes
import com.xebia.functional.xef.server.http.routes.userRoutes
import com.xebia.functional.xef.server.services.OrganizationRepositoryService
import com.xebia.functional.xef.server.services.ProjectRepositoryService
import com.xebia.functional.xef.server.http.routes.*
import com.xebia.functional.xef.server.services.RepositoryService
import com.xebia.functional.xef.server.services.UserRepositoryService
import io.ktor.client.*
import io.ktor.client.engine.cio.*
import io.ktor.client.plugins.auth.*
Expand Down Expand Up @@ -85,10 +79,8 @@ object Server {
}
exceptionsHandler()
routing {
genAIRoutes(ktorClient, vectorStoreService)
userRoutes(UserRepositoryService(logger))
organizationRoutes(OrganizationRepositoryService(logger))
projectsRoutes(ProjectRepositoryService(logger))
xefRoutes(logger)
aiRoutes(ktorClient)
}
}
awaitCancellation()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package com.xebia.functional.xef.server.http.routes

import com.aallam.openai.api.BetaOpenAI
import com.xebia.functional.xef.server.models.Token
import com.xebia.functional.xef.server.models.exceptions.XefExceptions
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.utils.io.jvm.javaio.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonPrimitive

enum class Provider {
OPENAI, GPT4ALL, GCP
}

fun String.toProvider(): Provider? = when (this) {
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> Provider.OPENAI
}

@OptIn(BetaOpenAI::class)
fun Routing.aiRoutes(
client: HttpClient
) {
val openAiUrl = "https://api.openai.com/v1"

authenticate("auth-bearer") {
post("/chat/completions") {
val token = call.getToken()
val body = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(body)

val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false

if (!isStream) {
client.makeRequest(call, "$openAiUrl/chat/completions", body, token)
} else {
client.makeStreaming(call, "$openAiUrl/chat/completions", body, token)
}
}

post("/embeddings") {
val token = call.getToken()
val context = call.receive<String>()
client.makeRequest(call, "$openAiUrl/embeddings", context, token)
}
}
}

private suspend fun HttpClient.makeRequest(
call: ApplicationCall,
url: String,
body: String,
token: Token
) {
val response = this.request(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}
call.response.headers.copyFrom(response.headers)
call.respond(response.status, response.body<String>())
}

private suspend fun HttpClient.makeStreaming(
call: ApplicationCall,
url: String,
body: String,
token: Token
) {
this.preparePost(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}.execute { httpResponse ->
call.response.headers.copyFrom(httpResponse.headers)
call.respondOutputStream {
httpResponse
.bodyAsChannel()
.copyTo(this@respondOutputStream)
}
}
}

private fun ResponseHeaders.copyFrom(headers: Headers) = headers
.entries()
.filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception
.forEach { (key, values) ->
values.forEach { value -> this.appendIfAbsent(key, value) }
}

private fun ApplicationCall.getProvider(): Provider =
request.headers["xef-provider"]?.toProvider()
?: Provider.OPENAI

fun ApplicationCall.getToken(): Token =
principal<UserIdPrincipal>()?.name?.let { Token(it) } ?: throw XefExceptions.AuthorizationException("No token found")

fun ApplicationCall.getId(): Int = getInt("id")

fun ApplicationCall.getInt(field: String): Int =
this.parameters[field]?.toInt() ?: throw XefExceptions.ValidationException("Invalid $field")

Original file line number Diff line number Diff line change
@@ -1,123 +1,16 @@
package com.xebia.functional.xef.server.http.routes

import com.aallam.openai.api.BetaOpenAI
import com.xebia.functional.xef.server.models.Token
import com.xebia.functional.xef.server.models.exceptions.XefExceptions
import com.xebia.functional.xef.server.services.VectorStoreService
import com.xebia.functional.xef.server.services.OrganizationRepositoryService
import com.xebia.functional.xef.server.services.ProjectRepositoryService
import com.xebia.functional.xef.server.services.UserRepositoryService
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.server.application.*
import io.ktor.server.auth.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.utils.io.jvm.javaio.*
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.boolean
import kotlinx.serialization.json.jsonPrimitive
import org.slf4j.Logger

enum class Provider {
OPENAI, GPT4ALL, GCP
}

fun String.toProvider(): Provider? = when (this) {
"openai" -> Provider.OPENAI
"gpt4all" -> Provider.GPT4ALL
"gcp" -> Provider.GCP
else -> Provider.OPENAI
}

@OptIn(BetaOpenAI::class)
fun Routing.genAIRoutes(
client: HttpClient,
vectorStoreService: VectorStoreService
) {
val openAiUrl = "https://api.openai.com/v1"

authenticate("auth-bearer") {
post("/chat/completions") {
val token = call.getToken()
val body = call.receive<String>()
val data = Json.decodeFromString<JsonObject>(body)

val isStream = data["stream"]?.jsonPrimitive?.boolean ?: false

if (!isStream) {
client.makeRequest(call, "$openAiUrl/chat/completions", body, token)
} else {
client.makeStreaming(call, "$openAiUrl/chat/completions", body, token)
}
}

post("/embeddings") {
val token = call.getToken()
val context = call.receive<String>()
client.makeRequest(call, "$openAiUrl/embeddings", context, token)
}
}
}

private suspend fun HttpClient.makeRequest(
call: ApplicationCall,
url: String,
body: String,
token: Token
fun Routing.xefRoutes(
logger: Logger
) {
val response = this.request(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}
call.response.headers.copyFrom(response.headers)
call.respond(response.status, response.body<String>())
userRoutes(UserRepositoryService(logger))
organizationRoutes(OrganizationRepositoryService(logger))
projectsRoutes(ProjectRepositoryService(logger))
}

private suspend fun HttpClient.makeStreaming(
call: ApplicationCall,
url: String,
body: String,
token: Token
) {
this.preparePost(url) {
headers {
bearerAuth(token.value)
}
contentType(ContentType.Application.Json)
method = HttpMethod.Post
setBody(body)
}.execute { httpResponse ->
call.response.headers.copyFrom(httpResponse.headers)
call.respondOutputStream {
httpResponse
.bodyAsChannel()
.copyTo(this@respondOutputStream)
}
}
}

private fun ResponseHeaders.copyFrom(headers: Headers) = headers
.entries()
.filter { (key, _) -> !HttpHeaders.isUnsafe(key) } // setting unsafe headers results in exception
.forEach { (key, values) ->
values.forEach { value -> this.appendIfAbsent(key, value) }
}

private fun ApplicationCall.getProvider(): Provider =
request.headers["xef-provider"]?.toProvider()
?: Provider.OPENAI

fun ApplicationCall.getToken(): Token =
principal<UserIdPrincipal>()?.name?.let { Token(it) } ?: throw XefExceptions.AuthorizationException("No token found")

fun ApplicationCall.getId(): Int = getInt("id")

fun ApplicationCall.getInt(field: String): Int =
this.parameters[field]?.toInt() ?: throw XefExceptions.ValidationException("Invalid $field")