Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port PromptTemplate #6

Merged
merged 3 commits into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ repositories {
}

plugins {
base
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.spotless)
alias(libs.plugins.kotlinx.serialization)
base
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.spotless)
alias(libs.plugins.kotlinx.serialization)
}

java {
Expand Down Expand Up @@ -41,31 +41,33 @@ kotlin {
val hostOs = System.getProperty("os.name")
val isMingwX64 = hostOs.startsWith("Windows")
when {
hostOs == "Mac OS X" -> macosX64("native")
hostOs == "Linux" -> linuxX64("native")
isMingwX64 -> mingwX64("native")
else -> throw GradleException("Host OS is not supported in Kotlin/Native.")
}
hostOs == "Mac OS X" -> macosX64("native")
hostOs == "Linux" -> linuxX64("native")
isMingwX64 -> mingwX64("native")
else -> throw GradleException("Host OS is not supported in Kotlin/Native.")
}



sourceSets {
commonMain {
dependencies {
implementation(libs.arrow.fx)
implementation(libs.kotlinx.serialization.json)
implementation(libs.bundles.ktor.client)
}
}
commonTest {
dependencies {
implementation(kotlin("test"))
}
}

implementation("com.squareup.okio:okio:3.3.0")
}
}
commonTest {
dependencies {
implementation(kotlin("test"))
}
}
}
}

spotless {
kotlin {
ktfmt().googleStyle()
}
kotlin {
ktfmt().googleStyle()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package com.xebia.functional.prompt

import okio.FileSystem
import okio.Path
import okio.buffer
import okio.use

interface PromptTemplate {
val inputKeys: List<String>
suspend fun format(variables: Map<String, String>): String

companion object {
operator fun invoke(config: Config): PromptTemplate = object : PromptTemplate {
override val inputKeys: List<String> = config.inputVariables

override suspend fun format(variables: Map<String, String>): String {
val mergedArgs = mergePartialAndUserVariables(variables, config.inputVariables)
return when (config.templateFormat) {
TemplateFormat.Jinja2 -> TODO()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is currently not implemented in lanchain4s and is resulting in NoSuchElementException, so I am keeping it here as TODO for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed, this from the enum and raised a PR for Scala to do the same. I think it's better to not have this until we actually implement it.

TemplateFormat.FString -> {
val sortedArgs = mergedArgs.toList().sortedBy { it.first }
sortedArgs.fold(config.template) { acc, (k, v) -> acc.replace("{$k}", v) }
}
}
}

private fun mergePartialAndUserVariables(
variables: Map<String, String>,
inputVariables: List<String>
): Map<String, String> =
inputVariables.fold(variables) { acc, k ->
if (!acc.containsKey(k)) acc + (k to "{$k}") else acc
}
}

suspend fun fromExamples(
examples: List<String>,
suffix: String,
inputVariables: List<String>,
prefix: String
): PromptTemplate {
val template = """|$prefix
|
|${examples.joinToString(separator = "\n")}
|$suffix""".trimMargin()
return PromptTemplate(Config.orThrow(template, inputVariables))
}

suspend fun fromTemplate(template: String, inputVariables: List<String>): PromptTemplate =
PromptTemplate(Config.orThrow(template, inputVariables))

// TODO IO Dispatcher KMP ??
suspend fun fromFile(
templateFile: Path,
inputVariables: List<String>,
fileSystem: FileSystem
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only exists for JVM, Native & NodeJS.

If we drop the browser target then we can provide a default argument of FileSystem.SYSTEM in this function, the Okio library also provides the test instance as well. We can also keep it as is, since browser cannot provide an instance for this it cannot call the method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment, we don't need to load templates from files. I think it's better to keep the browser support I suspect in the browser there is some kind of API to access local storage if you have been granted permission which we could use in place of speaking in terms of File

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@raulraja it currently still supports browser, but from a browser sourcceSet you cannot provide a FileSystem instance. square/okio#1070.

I'll provide overloads of this function from JVM, Native and NodeJS that remove that argument.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could only do this for native, and JVM for now. Cannot define a separate NodeJS sourceSet, since there is no commonizer between NodeJS and Browser. https://youtrack.jetbrains.com/issue/KT-47038

Super tiny issue though, still quite easy to call it from node by passing FileSystem.SYSTEM.

): PromptTemplate =
fileSystem.source(templateFile).use { source ->
source.buffer().use { buffer ->
val template = buffer.readUtf8()
val config = Config.orThrow(template, inputVariables)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the smart-constructors for PromptTemplate are turning NonEmptyList<InvalidTemplate> into an exception, we can keep this as Raise<NonEmptyList<InvalidTemplate>>.

Perhaps better is Raise<InvalidTemplate>, where the Config smart-constructors is already formatting the final message from NonEmptyList<String>.

Then we can turn the signature into something like Raise<PromptTemplate.IncorrectConfig>, and create an error hierarchy if we need to with more errors.

WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed all the exceptions, and orThrow and pushed them to the smart-constructors. You can always wrap in either { }.getOrNull() or recover({ }) { null } or we can expose a Either<InvalidTemplate, A>.getOrThrow extension within the module so dealing with the error at least becomes explicit.

PromptTemplate(config)
}
}
}
}
54 changes: 54 additions & 0 deletions src/commonMain/kotlin/com/xebia/functional/prompt/models.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.xebia.functional.prompt

import arrow.core.EitherNel
import arrow.core.getOrElse
import arrow.core.raise.Raise
import arrow.core.raise.either
import arrow.core.raise.ensure
import arrow.core.raise.zipOrAccumulate

enum class TemplateFormat(name: String) {
Jinja2("jinja2"),
FString("f-string")
}

data class InvalidTemplate(val reason: String)

data class Config private constructor(
val inputVariables: List<String>,
val template: String,
val templateFormat: TemplateFormat = TemplateFormat.FString
) {
companion object {
operator fun invoke(template: String, inputVariables: List<String>): EitherNel<InvalidTemplate, Config> =
either {
val placeholders = placeholderValues(template)

zipOrAccumulate(
{ validate(template, inputVariables.toSet() - placeholders.toSet(), "unused") },
{ validate(template, placeholders.toSet() - inputVariables.toSet(), "missing") },
{ validateDuplicated(template, placeholders) }
) { _, _, _ -> Config(inputVariables, template) }
}

fun orThrow(template: String, inputVariables: List<String>): Config =
invoke(template, inputVariables).getOrElse { throw IllegalArgumentException(it.all.joinToString(transform = InvalidTemplate::reason)) }
}
}

private fun Raise<InvalidTemplate>.validate(template: String, diffSet: Set<String>, msg: String): Unit =
ensure(diffSet.isEmpty()) {
InvalidTemplate("Template '$template' has $msg arguments: ${diffSet.joinToString(", ")}")
}

private fun Raise<InvalidTemplate>.validateDuplicated(template: String, placeholders: List<String>) {
val args = placeholders.groupBy { it }.filter { it.value.size > 1 }.keys
ensure(args.isEmpty()) {
InvalidTemplate("Template '$template' has duplicate arguments: ${args.joinToString(", ")}")
}
}

private fun placeholderValues(template: String): List<String> {
val regex = Regex("""\{([^\{\}]+)\}""")
return regex.findAll(template).toList().mapNotNull { it.groupValues.firstOrNull() }
}