diff --git a/python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/_jupyter_client.py b/python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/_jupyter_client.py index b8ae4058c992..874010c9a2c5 100644 --- a/python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/_jupyter_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/code_executors/jupyter/_jupyter_client.py @@ -52,11 +52,15 @@ def _get_ws_base_url(self) -> str: return f"ws://{self._connection_info.host}{port}" def list_kernel_specs(self) -> dict[str, dict[str, str]]: - response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()) + response = self._session.get( + f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers() + ) return cast(dict[str, dict[str, str]], response.json()) def list_kernels(self) -> list[dict[str, str]]: - response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()) + response = self._session.get( + f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers() + ) return cast(list[dict[str, str]], response.json()) def start_kernel(self, kernel_spec_name: str) -> str: @@ -78,13 +82,15 @@ def start_kernel(self, kernel_spec_name: str) -> str: def delete_kernel(self, kernel_id: str) -> None: response = self._session.delete( - f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers() + f"{self._get_api_base_url()}/api/kernels/{kernel_id}", + headers=self._get_headers(), ) response.raise_for_status() def restart_kernel(self, kernel_id: str) -> None: response = self._session.post( - f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers() + f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", + headers=self._get_headers(), ) response.raise_for_status() @@ -92,7 +98,14 @@ async def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient: ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels" headers = self._get_headers() headers["Cookie"] = self._get_cookies() - websocket = await connect(ws_url, additional_headers=headers) + websocket = await connect( + ws_url, + additional_headers=headers, + max_size=2**24, + open_timeout=120, + ping_timeout=30, + close_timeout=30, + ) return JupyterKernelClient(websocket) @@ -114,7 +127,9 @@ def __init__(self, websocket: ClientConnection): self._session_id: str = uuid.uuid4().hex self._websocket = websocket - async def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str: + async def _send_message( + self, *, content: dict[str, Any], channel: str, message_type: str + ) -> str: timestamp = datetime.datetime.now().isoformat() message_id = uuid.uuid4().hex message = { @@ -136,7 +151,9 @@ async def _send_message(self, *, content: dict[str, Any], channel: str, message_ return message_id async def wait_for_ready(self) -> None: - message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request") + message_id = await self._send_message( + content={}, channel="shell", message_type="kernel_info_request" + ) async for message in self._receive_message(message_id): if message["msg_type"] == "kernel_info_reply": @@ -167,8 +184,14 @@ async def execute(self, code: str) -> ExecutionResult: match data_type: case "text/plain": text_output.append(data) - case type if type.startswith("image/") or type == "text/html": - data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data)) + case type if type.startswith( + "image/" + ) or type == "text/html": + data_output.append( + self.ExecutionResult.DataItem( + mime_type=data_type, data=data + ) + ) case _: text_output.append(json.dumps(data)) case "stream": @@ -186,14 +209,19 @@ async def execute(self, code: str) -> ExecutionResult: break return JupyterKernelClient.ExecutionResult( - is_ok=True, output="\n".join([output for output in text_output]), data_items=data_output + is_ok=True, + output="\n".join([output for output in text_output]), + data_items=data_output, ) async def __aenter__(self) -> Self: return self async def __aexit__( - self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: await self._websocket.close() diff --git a/python/packages/autogen-ext/tests/code_executors/test_jupyter_code_executor.py b/python/packages/autogen-ext/tests/code_executors/test_jupyter_code_executor.py index 43430cb878e2..578e8a3bebb3 100644 --- a/python/packages/autogen-ext/tests/code_executors/test_jupyter_code_executor.py +++ b/python/packages/autogen-ext/tests/code_executors/test_jupyter_code_executor.py @@ -3,7 +3,6 @@ from pathlib import Path import pytest -import websockets from autogen_core import CancellationToken from autogen_core.code_executor import CodeBlock from autogen_ext.code_executors.jupyter import JupyterCodeExecutor, JupyterCodeResult, LocalJupyterServer