Skip to content

Commit

Permalink
Use more sensible default for websocket connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Leon0402 committed Dec 28, 2024
1 parent 72c8eda commit 661d670
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ def _get_ws_base_url(self) -> str:
return f"ws://{self._connection_info.host}{port}"

def list_kernel_specs(self) -> dict[str, dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
response = self._session.get(
f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()
)
return cast(dict[str, dict[str, str]], response.json())

def list_kernels(self) -> list[dict[str, str]]:
response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
response = self._session.get(
f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()
)
return cast(list[dict[str, str]], response.json())

def start_kernel(self, kernel_spec_name: str) -> str:
Expand All @@ -78,21 +82,30 @@ def start_kernel(self, kernel_spec_name: str) -> str:

def delete_kernel(self, kernel_id: str) -> None:
response = self._session.delete(
f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
f"{self._get_api_base_url()}/api/kernels/{kernel_id}",
headers=self._get_headers(),
)
response.raise_for_status()

def restart_kernel(self, kernel_id: str) -> None:
response = self._session.post(
f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart",
headers=self._get_headers(),
)
response.raise_for_status()

async def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient:
ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
headers = self._get_headers()
headers["Cookie"] = self._get_cookies()
websocket = await connect(ws_url, additional_headers=headers)
websocket = await connect(
ws_url,
additional_headers=headers,
max_size=2**24,
open_timeout=120,
ping_timeout=30,
close_timeout=30,
)
return JupyterKernelClient(websocket)


Expand All @@ -114,7 +127,9 @@ def __init__(self, websocket: ClientConnection):
self._session_id: str = uuid.uuid4().hex
self._websocket = websocket

async def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str:
async def _send_message(
self, *, content: dict[str, Any], channel: str, message_type: str
) -> str:
timestamp = datetime.datetime.now().isoformat()
message_id = uuid.uuid4().hex
message = {
Expand All @@ -136,7 +151,9 @@ async def _send_message(self, *, content: dict[str, Any], channel: str, message_
return message_id

async def wait_for_ready(self) -> None:
message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request")
message_id = await self._send_message(
content={}, channel="shell", message_type="kernel_info_request"
)

async for message in self._receive_message(message_id):
if message["msg_type"] == "kernel_info_reply":
Expand Down Expand Up @@ -167,8 +184,14 @@ async def execute(self, code: str) -> ExecutionResult:
match data_type:
case "text/plain":
text_output.append(data)
case type if type.startswith("image/") or type == "text/html":
data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data))
case type if type.startswith(
"image/"
) or type == "text/html":
data_output.append(
self.ExecutionResult.DataItem(
mime_type=data_type, data=data
)
)
case _:
text_output.append(json.dumps(data))
case "stream":
Expand All @@ -186,14 +209,19 @@ async def execute(self, code: str) -> ExecutionResult:
break

return JupyterKernelClient.ExecutionResult(
is_ok=True, output="\n".join([output for output in text_output]), data_items=data_output
is_ok=True,
output="\n".join([output for output in text_output]),
data_items=data_output,
)

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self._websocket.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path

import pytest
import websockets
from autogen_core import CancellationToken
from autogen_core.code_executor import CodeBlock
from autogen_ext.code_executors.jupyter import JupyterCodeExecutor, JupyterCodeResult, LocalJupyterServer
Expand Down

0 comments on commit 661d670

Please sign in to comment.