Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix iostream on new thread #2181

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion autogen/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .websockets import IOWebsockets

# Set the default input/output stream to the console
IOStream._default_io_stream.set(IOConsole())
IOStream.set_global_default(IOConsole())
IOStream.set_default(IOConsole())

__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")
35 changes: 30 additions & 5 deletions autogen/io/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from contextlib import contextmanager
from contextvars import ContextVar
import logging
from typing import Any, Iterator, Optional, Protocol, runtime_checkable

__all__ = ("OutputStream", "InputStream", "IOStream")

logger = logging.getLogger(__name__)


@runtime_checkable
class OutputStream(Protocol):
Expand Down Expand Up @@ -39,6 +42,31 @@ def input(self, prompt: str = "", *, password: bool = False) -> str:
class IOStream(InputStream, OutputStream, Protocol):
"""A protocol for input/output streams."""

# ContextVar must be used in multithreaded or async environments
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None)
_default_io_stream.set(None)
_global_default: Optional["IOStream"] = None

@staticmethod
def set_global_default(stream: "IOStream") -> None:
"""Set the default input/output stream.

Args:
stream (IOStream): The input/output stream to set as the default.
"""
IOStream._global_default = stream

@staticmethod
def get_global_default() -> "IOStream":
"""Get the default input/output stream.

Returns:
IOStream: The default input/output stream.
"""
if IOStream._global_default is None:
raise RuntimeError("No global default IOStream has been set")
return IOStream._global_default

@staticmethod
def get_default() -> "IOStream":
"""Get the default input/output stream.
Expand All @@ -48,13 +76,10 @@ def get_default() -> "IOStream":
"""
iostream = IOStream._default_io_stream.get()
if iostream is None:
raise RuntimeError("No default IOStream has been set")
logger.warning("No default IOStream has been set, defaulting to IOConsole.")
return IOStream.get_global_default()
return iostream

# ContextVar must be used in multithreaded or async environments
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
_default_io_stream.set(None)

@staticmethod
@contextmanager
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:
Expand Down
23 changes: 22 additions & 1 deletion test/io/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any
from threading import Thread
from typing import Any, List

from autogen.io import IOConsole, IOStream, IOWebsockets

Expand Down Expand Up @@ -26,3 +27,23 @@ def input(self, prompt: str = "", *, password: bool = False) -> str:
assert isinstance(IOStream.get_default(), MyIOStream)

assert isinstance(IOStream.get_default(), IOConsole)

def test_get_default_on_new_thread(self) -> None:
exceptions: List[Exception] = []

def on_new_thread(exceptions: List[Exception] = exceptions) -> None:
try:
assert isinstance(IOStream.get_default(), IOConsole)
except Exception as e:
exceptions.append(e)

# create a new thread and run the function
thread = Thread(target=on_new_thread)

thread.start()

# get exception from the thread
thread.join()

if exceptions:
raise exceptions[0]
Loading