Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safely spawn fine tuned model #494

Merged
merged 5 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
package com.xebia.functional.xef.conversation.finetuning

import arrow.core.getOrElse
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.prompt.Prompt

suspend fun main() {
val spawnModelId =
getenv("OPENAI_FINE_TUNED_MODEL_ID")
?: error("Please set the OPENAI_FINE_TUNED_MODEL_ID environment variable.")

val OAI = OpenAI()
val model = OAI.spawnModel(spawnModelId, OAI.GPT_3_5_TURBO)
val baseModel = OAI.GPT_3_5_TURBO

val fineTunedModelId = getenv("OPENAI_FINE_TUNED_MODEL_ID")
val fineTuneJobId = getenv("OPENAI_FINE_TUNE_JOB_ID")

val model =
when {
fineTunedModelId != null -> OAI.spawnModel(fineTunedModelId, baseModel)
fineTuneJobId != null -> OAI.spawnFineTunedModel(fineTuneJobId, baseModel)
else ->
error(
"Please set the OPENAI_FINE_TUNED_MODEL_ID or OPENAI_FINE_TUNE_JOB_ID environment variable."
)
}.getOrElse { error(it) }

OpenAI.conversation {
while (true) {
print("> ")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package com.xebia.functional.xef.conversation.llm.openai

import arrow.core.nonEmptyListOf
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.ensureNotNull
import com.aallam.openai.api.exception.InvalidRequestException
import com.aallam.openai.api.finetuning.FineTuningId
import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.LoggingConfig
Expand Down Expand Up @@ -137,6 +141,7 @@ class OpenAI(

@JvmField val DEFAULT_IMAGES = DALLE_2

/** Returns a list of all publicly available, supported models. */
fun supportedModels(): List<LLM> = // TODO: impl of abstract provider function
listOf(
GPT_4,
Expand All @@ -155,26 +160,53 @@ class OpenAI(
DALLE_2,
)

suspend fun <T : LLM> spawnModel(
modelId: String,
baseModel: T
): T { // TODO: impl of abstract provider function
if (findModel(modelId) == null) error("model not found")
return baseModel.copy(ModelType.FineTunedModel(modelId, baseModel = baseModel.modelType)) as? T
?: error("${baseModel::class} does not follow contract to return the most specific type")
/**
* Spawns a model by its [modelId]. It should have the same capabilities as [baseModel]. The model
* to spawn can i.e. be a fine-tuned model which is not known to the public.
*
* Warning: Throws an error at runtime during querying if the model does not provide the same
* capabilities as [baseModel].
*/
suspend fun <T : LLM> spawnModel(modelId: String, baseModel: T) =
either { // TODO: impl of abstract provider function
ensure(modelExists(modelId)) { "model $modelId not found" }
@Suppress("UNCHECKED_CAST")
baseModel.copy(ModelType.FineTunedModel(modelId, baseModel = baseModel.modelType)) as? T
?: error("${baseModel::class} does not follow contract to return the most specific type")
}

/**
* Spawns a model based off a [fineTuningJobId]. It should have the same capabilities as
* [baseModel].
*
* This function is safer than [spawnModel] because it checks if the base model the fine-tuned
* model was derived from matches [baseModel].
*/
suspend fun <T : LLM> spawnFineTunedModel(fineTuningJobId: String, baseModel: T) = either {
val job = defaultClient.fineTuningJob(FineTuningId(fineTuningJobId))
ensureNotNull(job) { "job $fineTuningJobId not found" }
val fineTunedModel = job.fineTunedModel
ensureNotNull(fineTunedModel) { "fine tuned model not available, status ${job.status}" }
ensure(baseModel.modelType.name == job.model.id) {
"base model instance does not match the job's base model"
}
spawnModel(fineTunedModel.id, baseModel).bind()
}

private suspend fun findModel(modelId: String): Any? { // TODO: impl of abstract provider function
/** Checks if the model exists. */
private suspend fun modelExists(
modelId: String
): Boolean { // TODO: impl of abstract provider function
val model =
try {
defaultClient.model(ModelId(modelId))
} catch (e: InvalidRequestException) {
when (e.error.detail?.code) {
"model_not_found" -> return null
"model_not_found" -> return false
else -> throw e
}
}
return ModelType.TODO(model.id.id)
return true
}

companion object {
Expand Down