Skip to content

Commit ea8db91

Browse files
authored
Support string based UDFs (microsoft#2195)
1 parent 4bedecf commit ea8db91

File tree

3 files changed

+165
-9
lines changed

3 files changed

+165
-9
lines changed

autogen/coding/func_with_reqs.py

+71-4
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
from typing_extensions import ParamSpec
66
from textwrap import indent, dedent
77
from dataclasses import dataclass, field
8+
from importlib.abc import SourceLoader
9+
import importlib
810

911
T = TypeVar("T")
1012
P = ParamSpec("P")
1113

1214

13-
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T]]) -> str:
15+
def _to_code(func: Union[FunctionWithRequirements[T, P], Callable[P, T], FunctionWithRequirementsStr]) -> str:
16+
if isinstance(func, FunctionWithRequirementsStr):
17+
return func.func
18+
1419
code = inspect.getsource(func)
1520
# Strip the decorator
1621
if code.startswith("@"):
@@ -50,6 +55,57 @@ def to_str(i: Union[str, Alias]) -> str:
5055
return f"from {im.module} import {imports}"
5156

5257

58+
class _StringLoader(SourceLoader):
59+
def __init__(self, data: str):
60+
self.data = data
61+
62+
def get_source(self, fullname: str) -> str:
63+
return self.data
64+
65+
def get_data(self, path: str) -> bytes:
66+
return self.data.encode("utf-8")
67+
68+
def get_filename(self, fullname: str) -> str:
69+
return "<not a real path>/" + fullname + ".py"
70+
71+
72+
@dataclass
73+
class FunctionWithRequirementsStr:
74+
func: str
75+
_compiled_func: Callable[..., Any]
76+
_func_name: str
77+
python_packages: List[str] = field(default_factory=list)
78+
global_imports: List[Import] = field(default_factory=list)
79+
80+
def __init__(self, func: str, python_packages: List[str] = [], global_imports: List[Import] = []):
81+
self.func = func
82+
self.python_packages = python_packages
83+
self.global_imports = global_imports
84+
85+
module_name = "func_module"
86+
loader = _StringLoader(func)
87+
spec = importlib.util.spec_from_loader(module_name, loader)
88+
if spec is None:
89+
raise ValueError("Could not create spec")
90+
module = importlib.util.module_from_spec(spec)
91+
if spec.loader is None:
92+
raise ValueError("Could not create loader")
93+
94+
try:
95+
spec.loader.exec_module(module)
96+
except Exception as e:
97+
raise ValueError(f"Could not compile function: {e}") from e
98+
99+
functions = inspect.getmembers(module, inspect.isfunction)
100+
if len(functions) != 1:
101+
raise ValueError("The string must contain exactly one function")
102+
103+
self._func_name, self._compiled_func = functions[0]
104+
105+
def __call__(self, *args: Any, **kwargs: Any) -> None:
106+
raise NotImplementedError("String based function with requirement objects are not directly callable")
107+
108+
53109
@dataclass
54110
class FunctionWithRequirements(Generic[T, P]):
55111
func: Callable[P, T]
@@ -62,6 +118,12 @@ def from_callable(
62118
) -> FunctionWithRequirements[T, P]:
63119
return cls(python_packages=python_packages, global_imports=global_imports, func=func)
64120

121+
@staticmethod
122+
def from_str(
123+
func: str, python_packages: List[str] = [], global_imports: List[Import] = []
124+
) -> FunctionWithRequirementsStr:
125+
return FunctionWithRequirementsStr(func=func, python_packages=python_packages, global_imports=global_imports)
126+
65127
# Type this based on F
66128
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
67129
return self.func(*args, **kwargs)
@@ -91,11 +153,13 @@ def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]:
91153
return wrapper
92154

93155

94-
def _build_python_functions_file(funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any]]]) -> str:
156+
def _build_python_functions_file(
157+
funcs: List[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]]
158+
) -> str:
95159
# First collect all global imports
96160
global_imports = set()
97161
for func in funcs:
98-
if isinstance(func, FunctionWithRequirements):
162+
if isinstance(func, (FunctionWithRequirements, FunctionWithRequirementsStr)):
99163
global_imports.update(func.global_imports)
100164

101165
content = "\n".join(map(_import_to_str, global_imports)) + "\n\n"
@@ -106,7 +170,7 @@ def _build_python_functions_file(funcs: List[Union[FunctionWithRequirements[Any,
106170
return content
107171

108172

109-
def to_stub(func: Callable[..., Any]) -> str:
173+
def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str:
110174
"""Generate a stub for a function as a string
111175
112176
Args:
@@ -115,6 +179,9 @@ def to_stub(func: Callable[..., Any]) -> str:
115179
Returns:
116180
str: The stub for the function
117181
"""
182+
if isinstance(func, FunctionWithRequirementsStr):
183+
return to_stub(func._compiled_func)
184+
118185
content = f"def {func.__name__}{inspect.signature(func)}:\n"
119186
docstring = func.__doc__
120187

autogen/coding/local_commandline_code_executor.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
import warnings
77
from typing import Any, Callable, ClassVar, List, TypeVar, Union, cast
88
from typing_extensions import ParamSpec
9-
from autogen.coding.func_with_reqs import FunctionWithRequirements, _build_python_functions_file, to_stub
9+
from autogen.coding.func_with_reqs import (
10+
FunctionWithRequirements,
11+
FunctionWithRequirementsStr,
12+
_build_python_functions_file,
13+
to_stub,
14+
)
1015

1116
from ..code_utils import TIMEOUT_MSG, WIN32, _cmd
1217
from .base import CodeBlock, CodeExecutor, CodeExtractor, CommandLineCodeResult
@@ -39,7 +44,7 @@ def __init__(
3944
self,
4045
timeout: int = 60,
4146
work_dir: Union[Path, str] = Path("."),
42-
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]] = [],
47+
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [],
4348
):
4449
"""(Experimental) A code executor class that executes code through a local command line
4550
environment.
@@ -104,7 +109,9 @@ def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEM
104109
)
105110

106111
@property
107-
def functions(self) -> List[Union[FunctionWithRequirements[Any, A], Callable[..., Any]]]:
112+
def functions(
113+
self,
114+
) -> List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]]:
108115
"""(Experimental) The functions that are available to the code executor."""
109116
return self._functions
110117

test/coding/test_user_defined_functions.py

+84-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
else:
1111
skip = False
1212

13-
from autogen.coding.func_with_reqs import with_requirements
13+
from autogen.coding.func_with_reqs import FunctionWithRequirements, with_requirements
1414

1515
classes_to_test = [LocalCommandLineCodeExecutor]
1616

@@ -137,7 +137,6 @@ def test_fails_for_function_incorrect_dep(cls) -> None:
137137

138138

139139
@pytest.mark.parametrize("cls", classes_to_test)
140-
@pytest.mark.skipif(skip, reason="pandas not installed")
141140
def test_formatted_prompt(cls) -> None:
142141
with tempfile.TemporaryDirectory() as temp_dir:
143142
executor = cls(work_dir=temp_dir, functions=[add_two_numbers])
@@ -149,3 +148,86 @@ def test_formatted_prompt(cls) -> None:
149148
'''
150149
in result
151150
)
151+
152+
153+
@pytest.mark.parametrize("cls", classes_to_test)
154+
def test_formatted_prompt_str_func(cls) -> None:
155+
with tempfile.TemporaryDirectory() as temp_dir:
156+
func = FunctionWithRequirements.from_str(
157+
'''
158+
def add_two_numbers(a: int, b: int) -> int:
159+
"""Add two numbers together."""
160+
return a + b
161+
'''
162+
)
163+
executor = cls(work_dir=temp_dir, functions=[func])
164+
165+
result = executor.format_functions_for_prompt()
166+
assert (
167+
'''def add_two_numbers(a: int, b: int) -> int:
168+
"""Add two numbers together."""
169+
'''
170+
in result
171+
)
172+
173+
174+
@pytest.mark.parametrize("cls", classes_to_test)
175+
def test_can_load_str_function_with_reqs(cls) -> None:
176+
with tempfile.TemporaryDirectory() as temp_dir:
177+
func = FunctionWithRequirements.from_str(
178+
'''
179+
def add_two_numbers(a: int, b: int) -> int:
180+
"""Add two numbers together."""
181+
return a + b
182+
'''
183+
)
184+
185+
executor = cls(work_dir=temp_dir, functions=[func])
186+
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
187+
print(add_two_numbers(1, 2))"""
188+
189+
result = executor.execute_code_blocks(
190+
code_blocks=[
191+
CodeBlock(language="python", code=code),
192+
]
193+
)
194+
assert result.output == "3\n"
195+
assert result.exit_code == 0
196+
197+
198+
@pytest.mark.parametrize("cls", classes_to_test)
199+
def test_cant_load_broken_str_function_with_reqs(cls) -> None:
200+
201+
with pytest.raises(ValueError):
202+
_ = FunctionWithRequirements.from_str(
203+
'''
204+
invaliddef add_two_numbers(a: int, b: int) -> int:
205+
"""Add two numbers together."""
206+
return a + b
207+
'''
208+
)
209+
210+
211+
@pytest.mark.parametrize("cls", classes_to_test)
212+
def test_cant_run_broken_str_function_with_reqs(cls) -> None:
213+
with tempfile.TemporaryDirectory() as temp_dir:
214+
func = FunctionWithRequirements.from_str(
215+
'''
216+
def add_two_numbers(a: int, b: int) -> int:
217+
"""Add two numbers together."""
218+
return a + b
219+
'''
220+
)
221+
222+
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
223+
print(add_two_numbers(object(), False))"""
224+
225+
executor = cls(work_dir=temp_dir, functions=[func])
226+
227+
result = executor.execute_code_blocks(
228+
code_blocks=[
229+
CodeBlock(language="python", code=code),
230+
]
231+
)
232+
assert "TypeError: unsupported operand type(s) for +:" in result.output
233+
assert result.exit_code == 1

0 commit comments

Comments
 (0)