Skip to content

Commit 7530402

Browse files
authored
fix: make dynamic enums work as outputs in Ruby (#972)
Sorbet's type enforcement was breaking it, so as a workaround, all BAML enums become `enum | string` in the Sorbet type system
1 parent 02b495d commit 7530402

File tree

14 files changed

+325
-35
lines changed

14 files changed

+325
-35
lines changed

engine/language_client_codegen/src/ruby/field_type.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ impl ToRuby for FieldType {
66
fn to_ruby(&self) -> String {
77
match self {
88
FieldType::Class(name) => format!("Baml::Types::{}", name.clone()),
9-
FieldType::Enum(name) => format!("Baml::Types::{}", name.clone()),
9+
FieldType::Enum(name) => format!("T.any(Baml::Types::{}, String)", name.clone()),
1010
// https://sorbet.org/docs/stdlib-generics
1111
FieldType::List(inner) => format!("T::Array[{}]", inner.to_ruby()),
1212
FieldType::Map(key, value) => {

integ-tests/baml_src/test-files/dynamic/dynamic.baml

+10
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,13 @@ function MyFunc(input: string) -> DynamicOutput {
8181
"#
8282
}
8383

84+
function ClassifyDynEnumTwo(input: string) -> DynEnumTwo {
85+
client GPT35
86+
prompt #"
87+
Given a string, extract info using the schema:
88+
89+
{{ input}}
90+
91+
{{ ctx.output_format }}
92+
"#
93+
}

integ-tests/openapi/baml_client/openapi.yaml

+26
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ paths:
4545
title: AudioInputResponse
4646
type: string
4747
operationId: AudioInput
48+
/call/ClassifyDynEnumTwo:
49+
post:
50+
requestBody:
51+
$ref: '#/components/requestBodies/ClassifyDynEnumTwo'
52+
responses:
53+
'200':
54+
description: Successful operation
55+
content:
56+
application/json:
57+
schema:
58+
title: ClassifyDynEnumTwoResponse
59+
$ref: '#/components/schemas/DynEnumTwo'
60+
operationId: ClassifyDynEnumTwo
4861
/call/ClassifyMessage:
4962
post:
5063
requestBody:
@@ -1073,6 +1086,19 @@ components:
10731086
required:
10741087
- aud
10751088
additionalProperties: false
1089+
ClassifyDynEnumTwo:
1090+
required: true
1091+
content:
1092+
application/json:
1093+
schema:
1094+
title: ClassifyDynEnumTwoRequest
1095+
type: object
1096+
properties:
1097+
input:
1098+
type: string
1099+
required:
1100+
- input
1101+
additionalProperties: false
10761102
ClassifyMessage:
10771103
required: true
10781104
content:

integ-tests/python/baml_client/async_client.py

+57
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,30 @@ async def AudioInput(
107107
mdl = create_model("AudioInputReturnType", inner=(str, ...))
108108
return coerce(mdl, raw.parsed())
109109

110+
async def ClassifyDynEnumTwo(
111+
self,
112+
input: str,
113+
baml_options: BamlCallOptions = {},
114+
) -> Union[types.DynEnumTwo, str]:
115+
__tb__ = baml_options.get("tb", None)
116+
if __tb__ is not None:
117+
tb = __tb__._tb
118+
else:
119+
tb = None
120+
__cr__ = baml_options.get("client_registry", None)
121+
122+
raw = await self.__runtime.call_function(
123+
"ClassifyDynEnumTwo",
124+
{
125+
"input": input,
126+
},
127+
self.__ctx_manager.get(),
128+
tb,
129+
__cr__,
130+
)
131+
mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(Union[types.DynEnumTwo, str], ...))
132+
return coerce(mdl, raw.parsed())
133+
110134
async def ClassifyMessage(
111135
self,
112136
input: str,
@@ -1984,6 +2008,39 @@ def AudioInput(
19842008
self.__ctx_manager.get(),
19852009
)
19862010

2011+
def ClassifyDynEnumTwo(
2012+
self,
2013+
input: str,
2014+
baml_options: BamlCallOptions = {},
2015+
) -> baml_py.BamlStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]]:
2016+
__tb__ = baml_options.get("tb", None)
2017+
if __tb__ is not None:
2018+
tb = __tb__._tb
2019+
else:
2020+
tb = None
2021+
__cr__ = baml_options.get("client_registry", None)
2022+
2023+
raw = self.__runtime.stream_function(
2024+
"ClassifyDynEnumTwo",
2025+
{
2026+
"input": input,
2027+
},
2028+
None,
2029+
self.__ctx_manager.get(),
2030+
tb,
2031+
__cr__,
2032+
)
2033+
2034+
mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(Union[types.DynEnumTwo, str], ...))
2035+
partial_mdl = create_model("ClassifyDynEnumTwoPartialReturnType", inner=(Optional[Union[types.DynEnumTwo, str]], ...))
2036+
2037+
return baml_py.BamlStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]](
2038+
raw,
2039+
lambda x: coerce(partial_mdl, x),
2040+
lambda x: coerce(mdl, x),
2041+
self.__ctx_manager.get(),
2042+
)
2043+
19872044
def ClassifyMessage(
19882045
self,
19892046
input: str,

integ-tests/python/baml_client/inlinedbaml.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.",
3232
"test-files/descriptions/descriptions.baml": "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 Nested | string @description(#\"\n write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}",
3333
"test-files/dynamic/client-registry.baml": "// Intentionally use a bad key\nclient<llm> BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n",
34-
"test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\n",
34+
"test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}",
3535
"test-files/functions/input/named-args/single/named-audio.baml": "function AudioInput(aud: audio) -> string{\n client Gemini\n prompt #\"\n {{ _.role(\"user\") }}\n\n Does this sound like a roar? Yes or no? One word no other characters.\n \n {{ aud }}\n \"#\n}\n\n\ntest TestURLAudioInput{\n functions [AudioInput]\n args {\n aud{ \n url https://actions.google.com/sounds/v1/emergency/beeper_emergency_call.ogg\n }\n } \n}\n\n\n",
3636
"test-files/functions/input/named-args/single/named-boolean.baml": "\n\nfunction TestFnNamedArgsSingleBool(myBool: bool) -> string{\n client GPT35\n prompt #\"\n Return this value back to me: {{myBool}}\n \"#\n}\n\ntest TestFnNamedArgsSingleBool {\n functions [TestFnNamedArgsSingleBool]\n args {\n myBool true\n }\n}",
3737
"test-files/functions/input/named-args/single/named-class-list.baml": "\n\n\nfunction TestFnNamedArgsSingleStringList(myArg: string[]) -> string{\n client GPT35\n prompt #\"\n Return this value back to me: {{myArg}}\n \"#\n}\n\ntest TestFnNamedArgsSingleStringList {\n functions [TestFnNamedArgsSingleStringList]\n args {\n myArg [\"hello\", \"world\"]\n }\n}",

integ-tests/python/baml_client/sync_client.py

+57
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,30 @@ def AudioInput(
105105
mdl = create_model("AudioInputReturnType", inner=(str, ...))
106106
return coerce(mdl, raw.parsed())
107107

108+
def ClassifyDynEnumTwo(
109+
self,
110+
input: str,
111+
baml_options: BamlCallOptions = {},
112+
) -> Union[types.DynEnumTwo, str]:
113+
__tb__ = baml_options.get("tb", None)
114+
if __tb__ is not None:
115+
tb = __tb__._tb
116+
else:
117+
tb = None
118+
__cr__ = baml_options.get("client_registry", None)
119+
120+
raw = self.__runtime.call_function_sync(
121+
"ClassifyDynEnumTwo",
122+
{
123+
"input": input,
124+
},
125+
self.__ctx_manager.get(),
126+
tb,
127+
__cr__,
128+
)
129+
mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(Union[types.DynEnumTwo, str], ...))
130+
return coerce(mdl, raw.parsed())
131+
108132
def ClassifyMessage(
109133
self,
110134
input: str,
@@ -1983,6 +2007,39 @@ def AudioInput(
19832007
self.__ctx_manager.get(),
19842008
)
19852009

2010+
def ClassifyDynEnumTwo(
2011+
self,
2012+
input: str,
2013+
baml_options: BamlCallOptions = {},
2014+
) -> baml_py.BamlSyncStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]]:
2015+
__tb__ = baml_options.get("tb", None)
2016+
if __tb__ is not None:
2017+
tb = __tb__._tb
2018+
else:
2019+
tb = None
2020+
__cr__ = baml_options.get("client_registry", None)
2021+
2022+
raw = self.__runtime.stream_function_sync(
2023+
"ClassifyDynEnumTwo",
2024+
{
2025+
"input": input,
2026+
},
2027+
None,
2028+
self.__ctx_manager.get(),
2029+
tb,
2030+
__cr__,
2031+
)
2032+
2033+
mdl = create_model("ClassifyDynEnumTwoReturnType", inner=(Union[types.DynEnumTwo, str], ...))
2034+
partial_mdl = create_model("ClassifyDynEnumTwoPartialReturnType", inner=(Optional[Union[types.DynEnumTwo, str]], ...))
2035+
2036+
return baml_py.BamlSyncStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]](
2037+
raw,
2038+
lambda x: coerce(partial_mdl, x),
2039+
lambda x: coerce(mdl, x),
2040+
self.__ctx_manager.get(),
2041+
)
2042+
19862043
def ClassifyMessage(
19872044
self,
19882045
input: str,

integ-tests/ruby/Rakefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
Rake::TestTask.new do |t|
1111
t.libs << "../../engine/language_client_ruby/lib"
1212
t.libs << "baml_client"
13-
# t.test_files = FileList["test_filtered.rb"]
14-
t.test_files = FileList["test_*.rb"]
13+
t.test_files = FileList["test_filtered.rb"]
14+
# t.test_files = FileList["test_*.rb"]
1515
t.options = '--verbose'
1616
end

0 commit comments

Comments
 (0)