Skip to content

Commit c3193f8

Browse files
authored
Allow custom name for functions module (#2241)
* Allow custom name for functions module * update tests * reorder
1 parent d970449 commit c3193f8

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

autogen/coding/local_commandline_code_executor.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030

3131
class LocalCommandLineCodeExecutor(CodeExecutor):
3232
SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["bash", "shell", "sh", "pwsh", "powershell", "ps1", "python"]
33-
FUNCTIONS_MODULE: ClassVar[str] = "functions"
34-
FUNCTIONS_FILENAME: ClassVar[str] = "functions.py"
3533
FUNCTION_PROMPT_TEMPLATE: ClassVar[
3634
str
3735
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
@@ -45,6 +43,7 @@ def __init__(
4543
timeout: int = 60,
4644
work_dir: Union[Path, str] = Path("."),
4745
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [],
46+
functions_module: str = "functions",
4847
):
4948
"""(Experimental) A code executor class that executes code through a local command line
5049
environment.
@@ -76,6 +75,11 @@ def __init__(
7675
if isinstance(work_dir, str):
7776
work_dir = Path(work_dir)
7877

78+
if not functions_module.isidentifier():
79+
raise ValueError("Module name must be a valid Python identifier")
80+
81+
self._functions_module = functions_module
82+
7983
work_dir.mkdir(exist_ok=True)
8084

8185
self._timeout = timeout
@@ -104,10 +108,15 @@ def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEM
104108

105109
template = Template(prompt_template)
106110
return template.substitute(
107-
module_name=self.FUNCTIONS_MODULE,
111+
module_name=self._functions_module,
108112
functions="\n\n".join([to_stub(func) for func in self._functions]),
109113
)
110114

115+
@property
116+
def functions_module(self) -> str:
117+
"""(Experimental) The module name for the functions."""
118+
return self._functions_module
119+
111120
@property
112121
def functions(
113122
self,
@@ -154,7 +163,7 @@ def sanitize_command(lang: str, code: str) -> None:
154163

155164
def _setup_functions(self) -> None:
156165
func_file_content = _build_python_functions_file(self._functions)
157-
func_file = self._work_dir / self.FUNCTIONS_FILENAME
166+
func_file = self._work_dir / f"{self._functions_module}.py"
158167
func_file.write_text(func_file_content)
159168

160169
# Collect requirements

test/coding/test_user_defined_functions.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def function_missing_reqs() -> "pandas.DataFrame":
5454
def test_can_load_function_with_reqs(cls) -> None:
5555
with tempfile.TemporaryDirectory() as temp_dir:
5656
executor = cls(work_dir=temp_dir, functions=[load_data])
57-
code = f"""from {cls.FUNCTIONS_MODULE} import load_data
57+
code = f"""from {executor.functions_module} import load_data
5858
import pandas
5959
6060
# Get first row's name
@@ -74,7 +74,7 @@ def test_can_load_function_with_reqs(cls) -> None:
7474
def test_can_load_function(cls) -> None:
7575
with tempfile.TemporaryDirectory() as temp_dir:
7676
executor = cls(work_dir=temp_dir, functions=[add_two_numbers])
77-
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
77+
code = f"""from {executor.functions_module} import add_two_numbers
7878
print(add_two_numbers(1, 2))"""
7979

8080
result = executor.execute_code_blocks(
@@ -93,7 +93,7 @@ def test_can_load_function(cls) -> None:
9393
# def test_fails_for_missing_reqs(cls) -> None:
9494
# with tempfile.TemporaryDirectory() as temp_dir:
9595
# executor = cls(work_dir=temp_dir, functions=[function_missing_reqs])
96-
# code = f"""from {cls.FUNCTIONS_MODULE} import function_missing_reqs
96+
# code = f"""from {executor.functions_module} import function_missing_reqs
9797
# function_missing_reqs()"""
9898

9999
# with pytest.raises(ValueError):
@@ -109,7 +109,7 @@ def test_can_load_function(cls) -> None:
109109
def test_fails_for_function_incorrect_import(cls) -> None:
110110
with tempfile.TemporaryDirectory() as temp_dir:
111111
executor = cls(work_dir=temp_dir, functions=[function_incorrect_import])
112-
code = f"""from {cls.FUNCTIONS_MODULE} import function_incorrect_import
112+
code = f"""from {executor.functions_module} import function_incorrect_import
113113
function_incorrect_import()"""
114114

115115
with pytest.raises(ValueError):
@@ -125,7 +125,7 @@ def test_fails_for_function_incorrect_import(cls) -> None:
125125
def test_fails_for_function_incorrect_dep(cls) -> None:
126126
with tempfile.TemporaryDirectory() as temp_dir:
127127
executor = cls(work_dir=temp_dir, functions=[function_incorrect_dep])
128-
code = f"""from {cls.FUNCTIONS_MODULE} import function_incorrect_dep
128+
code = f"""from {executor.functions_module} import function_incorrect_dep
129129
function_incorrect_dep()"""
130130

131131
with pytest.raises(ValueError):
@@ -183,7 +183,7 @@ def add_two_numbers(a: int, b: int) -> int:
183183
)
184184

185185
executor = cls(work_dir=temp_dir, functions=[func])
186-
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
186+
code = f"""from {executor.functions_module} import add_two_numbers
187187
print(add_two_numbers(1, 2))"""
188188

189189
result = executor.execute_code_blocks(
@@ -219,10 +219,9 @@ def add_two_numbers(a: int, b: int) -> int:
219219
'''
220220
)
221221

222-
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
223-
print(add_two_numbers(object(), False))"""
224-
225222
executor = cls(work_dir=temp_dir, functions=[func])
223+
code = f"""from {executor.functions_module} import add_two_numbers
224+
print(add_two_numbers(object(), False))"""
226225

227226
result = executor.execute_code_blocks(
228227
code_blocks=[

0 commit comments

Comments
 (0)