-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Port PromptTemplate #6
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package com.xebia.functional.prompt | ||
|
||
import okio.FileSystem | ||
import okio.Path | ||
import okio.buffer | ||
import okio.use | ||
|
||
interface PromptTemplate { | ||
val inputKeys: List<String> | ||
suspend fun format(variables: Map<String, String>): String | ||
|
||
companion object { | ||
operator fun invoke(config: Config): PromptTemplate = object : PromptTemplate { | ||
override val inputKeys: List<String> = config.inputVariables | ||
|
||
override suspend fun format(variables: Map<String, String>): String { | ||
val mergedArgs = mergePartialAndUserVariables(variables, config.inputVariables) | ||
return when (config.templateFormat) { | ||
TemplateFormat.Jinja2 -> TODO() | ||
TemplateFormat.FString -> { | ||
val sortedArgs = mergedArgs.toList().sortedBy { it.first } | ||
sortedArgs.fold(config.template) { acc, (k, v) -> acc.replace("{$k}", v) } | ||
} | ||
} | ||
} | ||
|
||
private fun mergePartialAndUserVariables( | ||
variables: Map<String, String>, | ||
inputVariables: List<String> | ||
): Map<String, String> = | ||
inputVariables.fold(variables) { acc, k -> | ||
if (!acc.containsKey(k)) acc + (k to "{$k}") else acc | ||
} | ||
} | ||
|
||
suspend fun fromExamples( | ||
examples: List<String>, | ||
suffix: String, | ||
inputVariables: List<String>, | ||
prefix: String | ||
): PromptTemplate { | ||
val template = """|$prefix | ||
| | ||
|${examples.joinToString(separator = "\n")} | ||
|$suffix""".trimMargin() | ||
return PromptTemplate(Config.orThrow(template, inputVariables)) | ||
} | ||
|
||
suspend fun fromTemplate(template: String, inputVariables: List<String>): PromptTemplate = | ||
PromptTemplate(Config.orThrow(template, inputVariables)) | ||
|
||
// TODO IO Dispatcher KMP ?? | ||
suspend fun fromFile( | ||
templateFile: Path, | ||
inputVariables: List<String>, | ||
fileSystem: FileSystem | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This only exists for JVM, Native & NodeJS. If we drop the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At the moment, we don't need to load templates from files. I think it's better to keep the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @raulraja it currently still supports browser, but from a browser sourcceSet you cannot provide a I'll provide overloads of this function from JVM, Native and NodeJS that remove that argument. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could only do this for native, and JVM for now. Cannot define a separate NodeJS sourceSet, since there is no commonizer between NodeJS and Browser. https://youtrack.jetbrains.com/issue/KT-47038 Super tiny issue though, still quite easy to call it from node by passing |
||
): PromptTemplate = | ||
fileSystem.source(templateFile).use { source -> | ||
source.buffer().use { buffer -> | ||
val template = buffer.readUtf8() | ||
val config = Config.orThrow(template, inputVariables) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All the smart-constructors for Perhaps better is Then we can turn the signature into something like WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed all the exceptions, and |
||
PromptTemplate(config) | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package com.xebia.functional.prompt | ||
|
||
import arrow.core.EitherNel | ||
import arrow.core.getOrElse | ||
import arrow.core.raise.Raise | ||
import arrow.core.raise.either | ||
import arrow.core.raise.ensure | ||
import arrow.core.raise.zipOrAccumulate | ||
|
||
enum class TemplateFormat(name: String) { | ||
Jinja2("jinja2"), | ||
FString("f-string") | ||
} | ||
|
||
data class InvalidTemplate(val reason: String) | ||
|
||
data class Config private constructor( | ||
val inputVariables: List<String>, | ||
val template: String, | ||
val templateFormat: TemplateFormat = TemplateFormat.FString | ||
) { | ||
companion object { | ||
operator fun invoke(template: String, inputVariables: List<String>): EitherNel<InvalidTemplate, Config> = | ||
either { | ||
val placeholders = placeholderValues(template) | ||
|
||
zipOrAccumulate( | ||
{ validate(template, inputVariables.toSet() - placeholders.toSet(), "unused") }, | ||
{ validate(template, placeholders.toSet() - inputVariables.toSet(), "missing") }, | ||
{ validateDuplicated(template, placeholders) } | ||
) { _, _, _ -> Config(inputVariables, template) } | ||
} | ||
|
||
fun orThrow(template: String, inputVariables: List<String>): Config = | ||
invoke(template, inputVariables).getOrElse { throw IllegalArgumentException(it.all.joinToString(transform = InvalidTemplate::reason)) } | ||
} | ||
} | ||
|
||
private fun Raise<InvalidTemplate>.validate(template: String, diffSet: Set<String>, msg: String): Unit = | ||
ensure(diffSet.isEmpty()) { | ||
InvalidTemplate("Template '$template' has $msg arguments: ${diffSet.joinToString(", ")}") | ||
} | ||
|
||
private fun Raise<InvalidTemplate>.validateDuplicated(template: String, placeholders: List<String>) { | ||
val args = placeholders.groupBy { it }.filter { it.value.size > 1 }.keys | ||
ensure(args.isEmpty()) { | ||
InvalidTemplate("Template '$template' has duplicate arguments: ${args.joinToString(", ")}") | ||
} | ||
} | ||
|
||
private fun placeholderValues(template: String): List<String> { | ||
val regex = Regex("""\{([^\{\}]+)\}""") | ||
return regex.findAll(template).toList().mapNotNull { it.groupValues.firstOrNull() } | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is currently not implemented in
lanchain4s
and is resulting inNoSuchElementException
, so I am keeping it here as TODO for now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed, this from the enum and raised a PR for Scala to do the same. I think it's better to not have this until we actually implement it.