diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 7475b1eb96..05f65873bc 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -125,6 +125,7 @@ prepend_agent_framework_to_user_agent, ) from ._tools import ( + SKIP_PARSING, FunctionInvocationConfiguration, FunctionInvocationLayer, FunctionTool, @@ -258,6 +259,7 @@ "GROUP_INDEX_KEY", "GROUP_KIND_KEY", "GROUP_TOKEN_COUNT_KEY", + "SKIP_PARSING", "SUMMARIZED_BY_SUMMARY_ID_KEY", "SUMMARY_OF_GROUP_IDS_KEY", "SUMMARY_OF_MESSAGE_IDS_KEY", diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 5f5e91b656..3f15472a5a 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -94,6 +94,33 @@ ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]") ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) + +class _SkipParsingSentinel: + """Sentinel signaling that :meth:`FunctionTool.invoke` should return the raw value. + + When passed as ``result_parser`` to :class:`FunctionTool` (or the ``@tool`` decorator), + the default :meth:`FunctionTool.parse_result` is bypassed and the wrapped function's + return value is returned unchanged from :meth:`FunctionTool.invoke`. Callers may also + request the raw value on a per-call basis by passing ``skip_parsing=True`` to + :meth:`FunctionTool.invoke`. + + Use the module-level ``SKIP_PARSING`` singleton — do not instantiate this class. + """ + + _instance: ClassVar[_SkipParsingSentinel | None] = None + + def __new__(cls) -> _SkipParsingSentinel: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __repr__(self) -> str: + return "SKIP_PARSING" + + +SKIP_PARSING: Final[_SkipParsingSentinel] = _SkipParsingSentinel() +"""Sentinel for ``FunctionTool(result_parser=...)`` meaning "do not parse the result".""" + # region Helpers @@ -279,7 +306,7 @@ def __init__( additional_properties: dict[str, Any] | None = None, func: Callable[..., Any] | None = None, input_model: type[BaseModel] | Mapping[str, Any] | None = None, - result_parser: Callable[[Any], str | list[Content]] | None = None, + result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None, **kwargs: Any, ) -> None: """Initialize the FunctionTool. @@ -327,9 +354,11 @@ def __init__( result_parser: An optional callable with signature ``Callable[[Any], str]`` that overrides the default result parsing behavior. When provided, this callable is used to convert the raw function return value to a string instead of the - built-in :meth:`parse_result` logic. Depending on your function, it may be - easiest to just do the serialization directly in the function body rather - than providing a custom ``result_parser``. + built-in :meth:`parse_result` logic. Pass the :data:`SKIP_PARSING` sentinel + instead of a callable to opt out of parsing entirely; in that case + :meth:`invoke` returns the wrapped function's raw return value. Depending + on your function, it may be easiest to just do the serialization directly + in the function body rather than providing a custom ``result_parser``. **kwargs: Additional keyword arguments. """ # Core attributes (formerly from BaseTool) @@ -508,31 +537,65 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: self.invocation_exception_count += 1 raise + @overload async def invoke( self, *, arguments: BaseModel | Mapping[str, Any] | None = None, context: FunctionInvocationContext | None = None, tool_call_id: str | None = None, + skip_parsing: Literal[True], **kwargs: Any, - ) -> list[Content]: + ) -> Any: ... + + @overload + async def invoke( + self, + *, + arguments: BaseModel | Mapping[str, Any] | None = None, + context: FunctionInvocationContext | None = None, + tool_call_id: str | None = None, + skip_parsing: Literal[False] = False, + **kwargs: Any, + ) -> list[Content]: ... + + async def invoke( + self, + *, + arguments: BaseModel | Mapping[str, Any] | None = None, + context: FunctionInvocationContext | None = None, + tool_call_id: str | None = None, + skip_parsing: bool = False, + **kwargs: Any, + ) -> list[Content] | Any: """Run the AI function with the provided arguments as a Pydantic model. The raw return value of the wrapped function is automatically parsed into a ``list[Content]`` using :meth:`parse_result` or the custom ``result_parser`` - if one was provided. Every result — text, rich media, or serialized objects — - is represented uniformly as Content items. + configured on the tool. Every result — text, rich media, or serialized + objects — is represented uniformly as Content items. + + Parsing can be skipped in two ways: configure the tool with + ``result_parser=SKIP_PARSING`` to always skip parsing, or pass + ``skip_parsing=True`` per call. Either way the wrapped function's raw value + is returned. This is intended for callers (e.g. sandboxed runtimes) that + consume the value from Python directly and would otherwise undo the + ``Content`` wrapping. Keyword Args: arguments: A mapping or model instance containing the arguments for the function. context: Explicit function invocation context carrying runtime kwargs. tool_call_id: Optional tool call identifier used for telemetry and tracing. + skip_parsing: When ``True``, bypass parsing and return the wrapped function's + raw value instead of a ``list[Content]``. Defaults to ``False``. kwargs: Direct function argument values. When provided, every keyword must match a declared tool parameter. Runtime data must be passed via ``context``. Returns: - A list of Content items representing the tool output. + ``list[Content]`` by default. The raw function return value (``Any``) when + ``skip_parsing=True`` (or the tool was constructed with + ``result_parser=SKIP_PARSING``). Raises: TypeError: If arguments is not mapping-like or fails schema checks. @@ -544,7 +607,9 @@ async def invoke( from ._types import Content from .observability import OBSERVABILITY_SETTINGS - parser = self.result_parser or FunctionTool.parse_result + configured_parser = self.result_parser + skip_parsing = skip_parsing or configured_parser is SKIP_PARSING + parser = configured_parser if callable(configured_parser) else FunctionTool.parse_result parameter_names = set(self.parameters().get("properties", {}).keys()) direct_argument_kwargs = ( @@ -616,6 +681,10 @@ async def invoke( logger.debug(f"Function arguments: {observable_kwargs}") res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res + if skip_parsing: + logger.info(f"Function {self.name} succeeded.") + logger.debug(f"Function result: {type(result).__name__}") + return result try: parsed = parser(result) except Exception: @@ -671,6 +740,13 @@ async def invoke( logger.error(f"Function failed. Error: {exception}") raise else: + if skip_parsing: + logger.info(f"Function {self.name} succeeded.") + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] + result_str = str(result) + span.set_attribute(OtelAttr.TOOL_RESULT, result_str) + logger.debug(f"Function result: {result_str}") + return result try: parsed = parser(result) except Exception: @@ -1067,7 +1143,7 @@ def tool( max_invocations: int | None = None, max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, - result_parser: Callable[[Any], str | list[Content]] | None = None, + result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None, ) -> FunctionTool: ... @@ -1083,7 +1159,7 @@ def tool( max_invocations: int | None = None, max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, - result_parser: Callable[[Any], str | list[Content]] | None = None, + result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None, ) -> Callable[[Callable[..., Any]], FunctionTool]: ... @@ -1098,7 +1174,7 @@ def tool( max_invocations: int | None = None, max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, - result_parser: Callable[[Any], str | list[Content]] | None = None, + result_parser: Callable[[Any], str | list[Content]] | _SkipParsingSentinel | None = None, ) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]: """Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically. diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 6fa7172295..b3762bf4ef 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from agent_framework import ( + SKIP_PARSING, Content, FunctionTool, tool, @@ -1300,4 +1301,165 @@ def __len__(self) -> int: assert normalized[1] is standalone +# region SKIP_PARSING sentinel & skip_parsing + + +async def test_invoke_skip_parsing_returns_native_value() -> None: + """invoke(skip_parsing=True) returns the wrapped function's raw value.""" + + @tool + def get_weather(city: str) -> dict[str, Any]: + """Get the weather.""" + return {"city": city, "temperature_c": 21.5, "conditions": "partly cloudy"} + + raw = await get_weather.invoke(arguments={"city": "Seattle"}, skip_parsing=True) + + assert isinstance(raw, dict) + assert raw == {"city": "Seattle", "temperature_c": 21.5, "conditions": "partly cloudy"} + + +async def test_invoke_skip_parsing_passes_through_custom_objects() -> None: + """skip_parsing must not call str()/repr() on the result.""" + + class Custom: # noqa: B903 + def __init__(self, value: int) -> None: + self.value = value + + @tool + def make() -> Custom: + """Make a custom object.""" + return Custom(42) + + raw = await make.invoke(skip_parsing=True) + + assert isinstance(raw, Custom) + assert raw.value == 42 + + +async def test_invoke_skip_parsing_awaits_async_functions() -> None: + @tool + async def slow(x: int) -> int: + """Async tool.""" + return x * 2 + + raw = await slow.invoke(arguments={"x": 21}, skip_parsing=True) + assert raw == 42 + + +async def test_invoke_skip_parsing_bypasses_configured_result_parser() -> None: + """The tool's own result_parser is bypassed when skip_parsing=True is requested.""" + parser_calls: list[Any] = [] + + def parser(value: Any) -> str: + parser_calls.append(value) + return "PARSED" + + @tool(result_parser=parser) + def make_dict() -> dict[str, int]: + """Returns a dict.""" + return {"a": 1} + + raw = await make_dict.invoke(skip_parsing=True) + assert raw == {"a": 1} + assert parser_calls == [] + + # Sanity: omitting skip_parsing still applies the configured parser. + parsed = await make_dict.invoke() + assert parsed[0].type == "text" + assert parsed[0].text == "PARSED" + + +async def test_constructor_skip_parsing_sentinel_returns_raw_by_default() -> None: + """Constructing a tool with result_parser=SKIP_PARSING makes invoke return the raw value.""" + + @tool(result_parser=SKIP_PARSING) + def make_dict() -> dict[str, int]: + """Returns a dict.""" + return {"a": 1} + + raw = await make_dict.invoke() + assert raw == {"a": 1} + + +async def test_invoke_skip_parsing_validates_arguments() -> None: + """Argument validation is shared with the default path.""" + + @tool + def adder(x: int, y: int) -> int: + """Add.""" + return x + y + + with pytest.raises(TypeError): + await adder.invoke(arguments={"x": "not-an-int", "y": 1}, skip_parsing=True) + + +async def test_invoke_skip_parsing_rejects_unexpected_runtime_kwargs() -> None: + @tool + async def echo(message: str) -> str: + """Echo.""" + return message + + with pytest.raises(TypeError, match="Unexpected keyword argument"): + await echo.invoke(arguments={"message": "hi"}, skip_parsing=True, api_token="secret") + + +async def test_invoke_skip_parsing_raises_for_declaration_only_tool() -> None: + declared = FunctionTool(name="dummy", description="declaration only") + + from agent_framework.exceptions import ToolException + + with pytest.raises(ToolException): + await declared.invoke(arguments={}, skip_parsing=True) + + +async def test_invoke_skip_parsing_records_telemetry(span_exporter: InMemorySpanExporter) -> None: + """skip_parsing participates in OTEL spans and records str(raw) as TOOL_RESULT.""" + + @tool(name="raw_tool", description="raw tool") + def returns_dict(x: int) -> dict[str, int]: + """Returns a dict.""" + return {"value": x} + + span_exporter.clear() + raw = await returns_dict.invoke(arguments={"x": 5}, tool_call_id="raw_call", skip_parsing=True) + + assert raw == {"value": 5} + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + assert span.attributes[OtelAttr.TOOL_NAME] == "raw_tool" + assert span.attributes[OtelAttr.TOOL_CALL_ID] == "raw_call" + assert span.attributes[OtelAttr.TOOL_RESULT] == "{'value': 5}" + + +async def test_invoke_default_path_records_parsed_telemetry( + span_exporter: InMemorySpanExporter, +) -> None: + """Regression: omitting skip_parsing still records the parsed result in telemetry.""" + + def parser(value: Any) -> str: + return f"parsed:{value}" + + @tool(name="parsed_tool", description="parsed", result_parser=parser) + def returns_int() -> int: + """Returns an int.""" + return 7 + + span_exporter.clear() + parsed = await returns_int.invoke(tool_call_id="parsed_call") + + assert parsed[0].text == "parsed:7" + spans = span_exporter.get_finished_spans() + assert len(spans) == 1 + assert spans[0].attributes[OtelAttr.TOOL_RESULT] == "parsed:7" + + +def test_skip_parsing_is_singleton() -> None: + """SKIP_PARSING is a singleton; instantiation returns the same object.""" + from agent_framework._tools import _SkipParsingSentinel + + assert _SkipParsingSentinel() is SKIP_PARSING + assert repr(SKIP_PARSING) == "SKIP_PARSING" + + # endregion diff --git a/python/packages/hyperlight/README.md b/python/packages/hyperlight/README.md index 1b1bc1e0ce..afc36f9365 100644 --- a/python/packages/hyperlight/README.md +++ b/python/packages/hyperlight/README.md @@ -130,3 +130,9 @@ codeact = HyperlightCodeActProvider( - `allowed_domains` accepts a single string target such as `"github.com"` to allow all backend-supported methods, an explicit `(target, method_or_methods)` tuple such as `("github.com", "GET")`, or an `AllowedDomain` named tuple. +- Tools registered with the sandbox return their native Python value + (`dict`, `list`, primitives, or custom objects) directly to the guest via the + Hyperlight FFI. Any `result_parser` configured on a `FunctionTool` is + intended for LLM-facing consumers and does not run on the sandbox path — + apply formatting inside the tool function itself if you need it for + in-sandbox consumers. diff --git a/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py b/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py index a46707ac0d..f91eeb5215 100644 --- a/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py +++ b/python/packages/hyperlight/agent_framework_hyperlight/_execute_code_tool.py @@ -2,42 +2,45 @@ from __future__ import annotations -import ast import asyncio -import copy import mimetypes import shutil import threading import time from collections.abc import Callable, Sequence -from dataclasses import dataclass +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import suppress +from copy import copy +from dataclasses import dataclass, field from pathlib import Path, PurePosixPath from tempfile import TemporaryDirectory -from typing import Annotated, Any, Protocol, TypeGuard, cast +from typing import Any, Protocol, TypeGuard, TypeVar, cast from urllib.parse import urlparse from agent_framework import Content, FunctionTool from agent_framework._tools import ApprovalMode, normalize_tools -from pydantic import BaseModel, Field from ._instructions import build_codeact_instructions, build_execute_code_description from ._types import AllowedDomain, AllowedDomainInput, FileMount, FileMountHostPath, FileMountInput DEFAULT_HYPERLIGHT_BACKEND = "wasm" DEFAULT_HYPERLIGHT_MODULE = "python_guest.path" -EXECUTE_CODE_INPUT_DESCRIPTION = "Python code to execute in an isolated Hyperlight sandbox." +EXECUTE_CODE_TOOL_DESCRIPTION = "Execute Python in an isolated Hyperlight sandbox." OUTPUT_FILE_RETRY_ATTEMPTS = 10 OUTPUT_FILE_RETRY_DELAY_SECONDS = 0.1 - -class _ExecuteCodeInput(BaseModel): - code: Annotated[str, Field(description=EXECUTE_CODE_INPUT_DESCRIPTION)] - - -@dataclass(frozen=True, slots=True) -class _StoredFileMount: - host_path: Path - mount_path: str +EXECUTE_CODE_INPUT_SCHEMA: dict[str, Any] = { + "type": "object", + "title": "_ExecuteCodeInput", + "properties": { + "code": { + "type": "string", + "title": "Code", + "description": "Python code to execute in an isolated Hyperlight sandbox.", + }, + }, + "required": ["code"], +} @dataclass(frozen=True, slots=True) @@ -85,13 +88,43 @@ class SandboxRuntime(Protocol): def execute(self, *, config: _RunConfig, code: str) -> list[Content]: ... +_T = TypeVar("_T") + + +class _SandboxWorker: + """Single-threaded executor that confines all sandbox operations to one OS thread. + + The Hyperlight ``WasmSandbox`` is declared ``unsendable`` in PyO3, meaning it can only be + accessed from the OS thread that created it; touching it from any other thread triggers a + Rust panic that cannot be caught from Python. Every cached :class:`_SandboxEntry` therefore + owns its own ``_SandboxWorker``, and *all* lifecycle and execution calls against the + underlying sandbox object must be routed through :meth:`submit`/:meth:`run`. + """ + + __slots__ = ("_executor",) + + def __init__(self, *, name: str = "hl-sandbox") -> None: + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=name) + + def submit(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> Future[_T]: + return self._executor.submit(fn, *args, **kwargs) + + def run(self, fn: Callable[..., _T], /, *args: Any, **kwargs: Any) -> _T: + return self._executor.submit(fn, *args, **kwargs).result() + + def shutdown(self) -> None: + # Do not block on shutdown; stop accepting new tasks, but allow the currently running + # task and any already-queued tasks to finish before the worker thread exits. + self._executor.shutdown(wait=False, cancel_futures=False) + + @dataclass class _SandboxEntry: sandbox: Any snapshot: Any input_dir: TemporaryDirectory[str] | None output_dir: TemporaryDirectory[str] | None - lock: threading.RLock + worker: _SandboxWorker = field(default_factory=_SandboxWorker) def _load_sandbox_class() -> type[Any]: @@ -106,10 +139,6 @@ def _load_sandbox_class() -> type[Any]: return Sandbox -def _passthrough_result_parser(result: Any) -> str: - return repr(result) - - def _collect_tools(*tool_groups: Any) -> list[FunctionTool]: tools_by_name: dict[str, FunctionTool] = {} @@ -166,7 +195,7 @@ def _is_file_mount_pair(value: Any) -> TypeGuard[FileMount | tuple[FileMountHost return isinstance(host_path, (str, Path)) and isinstance(mount_path, str) -def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount: +def _normalize_file_mount_input(file_mount: FileMountInput) -> FileMount: host_path: FileMountHostPath mount_path: str if isinstance(file_mount, str): @@ -176,7 +205,7 @@ def _normalize_file_mount_input(file_mount: FileMountInput) -> _StoredFileMount: host_path = file_mount[0] mount_path = file_mount[1] - return _StoredFileMount( + return FileMount( host_path=_resolve_existing_path(host_path), mount_path=_normalize_mount_path(mount_path), ) @@ -445,18 +474,13 @@ def _build_execution_contents( def _make_sandbox_callback(tool_obj: FunctionTool) -> Callable[..., Any]: - sandbox_tool = copy.copy(tool_obj) - # Auto-assign a passthrough parser so the raw return value round-trips through - # `ast.literal_eval` in the sandbox callback below. User-supplied parsers are - # left in place so callers can customize how results are exposed to the guest. - if sandbox_tool.result_parser is None: - sandbox_tool.result_parser = _passthrough_result_parser + sandbox_tool = copy(tool_obj) def _callback(**kwargs: Any) -> Any: - async def _invoke() -> list[Content]: - return await sandbox_tool.invoke(arguments=kwargs) + async def _invoke() -> Any: + return await sandbox_tool.invoke(arguments=kwargs, skip_parsing=True) - # FunctionTool.invoke() is always async. The real Hyperlight backend invokes + # FunctionTool.invoke() is async. The real Hyperlight backend invokes # registered callbacks synchronously via FFI, so this must be a sync function. # We run the async call on a dedicated thread to avoid conflicts with any # event loop that may be running on the current thread. @@ -474,22 +498,11 @@ def _run() -> None: worker.join() if error_box: raise error_box[0] - contents: list[Content] = result_box[0] - - values: list[Any] = [] - for content in contents: - if content.type == "text" and content.text is not None: - try: - values.append(ast.literal_eval(content.text)) - except (SyntaxError, ValueError): - values.append(content.text) - continue - - values.append(content.to_dict()) - - if len(values) == 1: - return values[0] - return values + # Return the raw value. The Hyperlight FFI marshals primitives (dict, list, + # str, int, float, bool, None) natively into the guest, and falls back to + # repr()/str() for unsupported types — so the guest receives real Python + # objects without a lossy host-side serialization round-trip. + return result_box[0] return _callback @@ -509,7 +522,7 @@ def _clear_directory(output_dir: TemporaryDirectory[str] | None) -> None: pass -class _SandboxRegistry: +class _SandboxRegistry(SandboxRuntime): def __init__(self) -> None: self._entries: dict[tuple[Any, ...], _SandboxEntry] = {} self._entries_lock = threading.RLock() @@ -517,28 +530,54 @@ def __init__(self) -> None: def execute(self, *, config: _RunConfig, code: str) -> list[Content]: """Execute code in a cached sandbox matching the given config. - Entries are keyed by ``config.cache_key()``. Concurrent calls with the same - key are serialized by the entry lock so they never race, but they share the - same sandbox instance. For true parallel execution, use distinct provider - instances or configs that produce different cache keys. + Entries are keyed by ``config.cache_key()``. All operations against the underlying + sandbox object are routed through the entry's dedicated single-threaded worker, which + both serializes concurrent callers and satisfies the PyO3 ``unsendable`` invariant + that the sandbox can only be touched from the thread that created it. """ + entry = self._get_or_create_entry(config) + return entry.worker.run(self._run_on_worker, entry, code) + + @staticmethod + def _run_on_worker(entry: _SandboxEntry, code: str) -> list[Content]: + entry.sandbox.restore(entry.snapshot) + _clear_directory(entry.output_dir) + result = entry.sandbox.run(code=code) + return _build_execution_contents( + result=result, + sandbox=entry.sandbox, + output_dir=entry.output_dir, + code=code, + ) + + def _get_or_create_entry(self, config: _RunConfig) -> _SandboxEntry: cache_key = config.cache_key() with self._entries_lock: entry = self._entries.get(cache_key) if entry is None: entry = self._create_entry(config) self._entries[cache_key] = entry + return entry - with entry.lock: - entry.sandbox.restore(entry.snapshot) - _clear_directory(entry.output_dir) - result = entry.sandbox.run(code=code) - return _build_execution_contents( - result=result, - sandbox=entry.sandbox, - output_dir=entry.output_dir, - code=code, - ) + def close(self) -> None: + """Shut down all per-entry worker threads and release per-entry resources. + + Safe to call multiple times. Runs any sandbox close hook on the entry's + own worker thread to honor the PyO3 ``unsendable`` invariant. + """ + with self._entries_lock: + entries = list(self._entries.values()) + self._entries.clear() + for entry in entries: + close_hook = getattr(entry.sandbox, "close", None) or getattr(entry.sandbox, "shutdown", None) + if callable(close_hook): + with suppress(Exception): + entry.worker.run(close_hook) + entry.worker.shutdown() + for tmp_dir in (entry.input_dir, entry.output_dir): + if tmp_dir is not None: + with suppress(Exception): + tmp_dir.cleanup() def _create_entry(self, config: _RunConfig) -> _SandboxEntry: input_dir_handle = TemporaryDirectory() if config.filesystem_enabled else None @@ -578,26 +617,37 @@ def _configure_sandbox(*, sandbox: Any, expand_missing_scheme: bool) -> None: methods=list(allowed_domain.methods) if allowed_domain.methods is not None else None, ) - sandbox = _create_sandbox() - _configure_sandbox(sandbox=sandbox, expand_missing_scheme=False) - - try: - sandbox.run("None") - except RuntimeError as exc: - if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains): - raise + worker = _SandboxWorker() + def _build_sandbox() -> tuple[Any, Any]: sandbox = _create_sandbox() - _configure_sandbox(sandbox=sandbox, expand_missing_scheme=True) - sandbox.run("None") + _configure_sandbox(sandbox=sandbox, expand_missing_scheme=False) + + try: + sandbox.run("None") + except RuntimeError as exc: + if not _should_retry_allowed_domain_registration(error=exc, allowed_domains=config.allowed_domains): + raise + + sandbox = _create_sandbox() + _configure_sandbox(sandbox=sandbox, expand_missing_scheme=True) + sandbox.run("None") + + snapshot = sandbox.snapshot() + return sandbox, snapshot + + try: + sandbox, snapshot = worker.run(_build_sandbox) + except BaseException: + worker.shutdown() + raise - snapshot = sandbox.snapshot() return _SandboxEntry( sandbox=sandbox, snapshot=snapshot, input_dir=input_dir_handle, output_dir=output_dir_handle, - lock=threading.RLock(), + worker=worker, ) @@ -619,10 +669,10 @@ def __init__( ) -> None: super().__init__( name="execute_code", - description=EXECUTE_CODE_INPUT_DESCRIPTION, + description=EXECUTE_CODE_TOOL_DESCRIPTION, approval_mode="never_require", func=self._run_code, - input_model=_ExecuteCodeInput, + input_model=EXECUTE_CODE_INPUT_SCHEMA, ) self._state_lock = threading.RLock() self._registry = _registry or _SandboxRegistry() @@ -632,7 +682,7 @@ def __init__( self._module: str | None = module self._module_path: str | None = module_path self._managed_tools: list[FunctionTool] = [] - self._file_mounts: dict[str, _StoredFileMount] = {} + self._file_mounts: dict[str, FileMount] = {} self._allowed_domains: dict[str, AllowedDomain] = {} if tools is not None: @@ -648,7 +698,7 @@ def __init__( def description(self) -> str: state_lock = getattr(self, "_state_lock", None) if state_lock is None: - return str(self.__dict__.get("description", EXECUTE_CODE_INPUT_DESCRIPTION)) + return str(self.__dict__.get("description", EXECUTE_CODE_TOOL_DESCRIPTION)) with state_lock: allowed_domains = sorted(self._allowed_domains.values(), key=lambda value: value.target) @@ -841,9 +891,9 @@ def _build_run_config(self) -> _RunConfig: workspace_signature = _path_tree_signature(workspace_root) if workspace_root is not None else () normalized_mounts = tuple( _NormalizedFileMount( - host_path=mount.host_path, + host_path=Path(mount.host_path), mount_path=mount.mount_path, - path_signature=_path_tree_signature(mount.host_path), + path_signature=_path_tree_signature(Path(mount.host_path)), ) for mount in stored_mounts ) diff --git a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py index ab6a3f7c78..e41a5a6ee1 100644 --- a/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py +++ b/python/packages/hyperlight/tests/hyperlight/test_hyperlight_codeact.py @@ -937,3 +937,191 @@ async def _concurrent_task(): assert concurrent_ran, "Event loop was blocked during sandbox execution" assert result[0].type == "text" + + +class _ThreadAffinityFakeSandbox(_FakeSandbox): + """Fake sandbox that records the OS thread of every method invocation. + + Mirrors the PyO3 ``unsendable`` invariant of ``hyperlight_sandbox.WasmSandbox``: + if ``__init__``, ``register_tool``, ``allow_domain``, ``run``, ``snapshot`` or ``restore`` + are ever called from more than one thread for a given instance, the test fails. + """ + + affinity_failures: list[str] = [] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._owner_thread = threading.get_ident() + self.thread_ids: set[int] = {self._owner_thread} + + def _record(self, method: str) -> None: + ident = threading.get_ident() + self.thread_ids.add(ident) + if ident != self._owner_thread: + _ThreadAffinityFakeSandbox.affinity_failures.append( + f"{method} called from thread {ident}, expected {self._owner_thread}" + ) + + def register_tool(self, name_or_tool: Any, callback: Any | None = None) -> None: + self._record("register_tool") + super().register_tool(name_or_tool, callback) + + def allow_domain(self, target: str, methods: list[str] | None = None) -> None: + self._record("allow_domain") + super().allow_domain(target, methods) + + def run(self, code: str) -> _FakeResult: + self._record("run") + return super().run(code) + + def snapshot(self) -> str: + self._record("snapshot") + return super().snapshot() + + def restore(self, snapshot: Any) -> None: + self._record("restore") + super().restore(snapshot) + + +async def test_sandbox_calls_are_pinned_to_owning_worker_thread( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Regression: WasmSandbox is unsendable; every sandbox call must run on its owner thread.""" + _ThreadAffinityFakeSandbox.instances.clear() + _ThreadAffinityFakeSandbox.affinity_failures.clear() + monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox) + + execute_code = HyperlightExecuteCodeTool() + + # Invoke many times concurrently; asyncio.to_thread will spread these across the default + # executor's worker threads, which previously caused PyO3 to panic when a different thread + # touched the cached sandbox. + results = await asyncio.gather(*[execute_code.invoke(arguments={"code": "None"}) for _ in range(8)]) + for result in results: + assert result[0].type == "text" + + assert _ThreadAffinityFakeSandbox.affinity_failures == [] + assert len(_ThreadAffinityFakeSandbox.instances) == 1 + sandbox = _ThreadAffinityFakeSandbox.instances[0] + # All sandbox-touching calls must have stayed on a single owning thread, distinct from the + # caller thread that asyncio.to_thread used for dispatch. + assert sandbox.thread_ids == {sandbox._owner_thread} + assert sandbox._owner_thread != threading.get_ident() + + +async def test_sandbox_owner_thread_persists_across_dispatch_threads( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Sequential calls landing on different dispatch threads still share one sandbox thread.""" + _ThreadAffinityFakeSandbox.instances.clear() + _ThreadAffinityFakeSandbox.affinity_failures.clear() + monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ThreadAffinityFakeSandbox) + + execute_code = HyperlightExecuteCodeTool() + + for _ in range(5): + result = await execute_code.invoke(arguments={"code": "None"}) + assert result[0].type == "text" + + assert _ThreadAffinityFakeSandbox.affinity_failures == [] + assert len(_ThreadAffinityFakeSandbox.instances) == 1 + + +def test_sandbox_registry_close_shuts_down_workers(monkeypatch: pytest.MonkeyPatch) -> None: + _FakeSandbox.instances.clear() + monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _FakeSandbox) + + registry = execute_code_module._SandboxRegistry() + execute_code = HyperlightExecuteCodeTool(_registry=registry) + asyncio.run(execute_code.invoke(arguments={"code": "None"})) + + entries = list(registry._entries.values()) + assert len(entries) == 1 + worker = entries[0].worker + + registry.close() + + assert registry._entries == {} + # Submitting after shutdown must fail; this proves the executor was actually torn down. + with pytest.raises(RuntimeError): + worker.submit(lambda: None) + + +def test_sandbox_registry_close_releases_per_entry_resources(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + """close() must invoke any sandbox close hook and release temp directories.""" + + close_calls: list[int] = [] + + class _ClosableFakeSandbox(_FakeSandbox): + def close(self) -> None: + close_calls.append(1) + + _FakeSandbox.instances.clear() + monkeypatch.setattr(execute_code_module, "_load_sandbox_class", lambda: _ClosableFakeSandbox) + + workspace = tmp_path / "workspace" + workspace.mkdir() + registry = execute_code_module._SandboxRegistry() + execute_code = HyperlightExecuteCodeTool(workspace_root=workspace, _registry=registry) + asyncio.run(execute_code.invoke(arguments={"code": "None"})) + + entries = list(registry._entries.values()) + assert len(entries) == 1 + entry = entries[0] + assert entry.input_dir is not None and entry.output_dir is not None + input_path = Path(entry.input_dir.name) + output_path = Path(entry.output_dir.name) + assert input_path.exists() and output_path.exists() + + registry.close() + + assert close_calls == [1] + assert not input_path.exists() + assert not output_path.exists() + + +async def test_make_sandbox_callback_returns_native_dict() -> None: + """Host tool returning a dict must be forwarded as a native dict (no repr round-trip).""" + + @tool + def get_weather(city: str) -> dict[str, Any]: + """Get weather.""" + return {"city": city, "temp_c": 21.5} + + callback = execute_code_module._make_sandbox_callback(get_weather) + result = callback(city="Seattle") + + assert isinstance(result, dict) + assert result == {"city": "Seattle", "temp_c": 21.5} + + +async def test_make_sandbox_callback_bypasses_user_result_parser() -> None: + """Documented behavior change: result_parser is bypassed in the sandbox path.""" + + parser_calls: list[Any] = [] + + def parser(value: Any) -> str: + parser_calls.append(value) + return "PARSED" + + @tool(result_parser=parser) + def make_payload() -> dict[str, int]: + """Returns a dict.""" + return {"a": 1, "b": 2} + + callback = execute_code_module._make_sandbox_callback(make_payload) + result = callback() + + assert result == {"a": 1, "b": 2} + assert parser_calls == [], "result_parser must not run on the sandbox path" + + +async def test_make_sandbox_callback_propagates_exceptions() -> None: + @tool + def boom(x: int) -> int: + """Always fails.""" + raise RuntimeError("nope") + + callback = execute_code_module._make_sandbox_callback(boom) + with pytest.raises(RuntimeError, match="nope"): + callback(x=1)