diff --git a/autogen/io/__init__.py b/autogen/io/__init__.py index 20d6d5a578f5..6bb8a35680f4 100644 --- a/autogen/io/__init__.py +++ b/autogen/io/__init__.py @@ -3,6 +3,7 @@ from .websockets import IOWebsockets # Set the default input/output stream to the console -IOStream._default_io_stream.set(IOConsole()) +IOStream.set_global_default(IOConsole()) +IOStream.set_default(IOConsole()) __all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets") diff --git a/autogen/io/base.py b/autogen/io/base.py index 857d532e4f56..db25560e0e82 100644 --- a/autogen/io/base.py +++ b/autogen/io/base.py @@ -1,9 +1,12 @@ from contextlib import contextmanager from contextvars import ContextVar +import logging from typing import Any, Iterator, Optional, Protocol, runtime_checkable __all__ = ("OutputStream", "InputStream", "IOStream") +logger = logging.getLogger(__name__) + @runtime_checkable class OutputStream(Protocol): @@ -39,6 +42,31 @@ def input(self, prompt: str = "", *, password: bool = False) -> str: class IOStream(InputStream, OutputStream, Protocol): """A protocol for input/output streams.""" + # ContextVar must be used in multithreaded or async environments + _default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None) + _default_io_stream.set(None) + _global_default: Optional["IOStream"] = None + + @staticmethod + def set_global_default(stream: "IOStream") -> None: + """Set the default input/output stream. + + Args: + stream (IOStream): The input/output stream to set as the default. + """ + IOStream._global_default = stream + + @staticmethod + def get_global_default() -> "IOStream": + """Get the default input/output stream. + + Returns: + IOStream: The default input/output stream. + """ + if IOStream._global_default is None: + raise RuntimeError("No global default IOStream has been set") + return IOStream._global_default + @staticmethod def get_default() -> "IOStream": """Get the default input/output stream. @@ -48,13 +76,10 @@ def get_default() -> "IOStream": """ iostream = IOStream._default_io_stream.get() if iostream is None: - raise RuntimeError("No default IOStream has been set") + logger.warning("No default IOStream has been set, defaulting to IOConsole.") + return IOStream.get_global_default() return iostream - # ContextVar must be used in multithreaded or async environments - _default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream") - _default_io_stream.set(None) - @staticmethod @contextmanager def set_default(stream: Optional["IOStream"]) -> Iterator[None]: diff --git a/test/io/test_base.py b/test/io/test_base.py index a96f8cb3c794..ba05955164c3 100644 --- a/test/io/test_base.py +++ b/test/io/test_base.py @@ -1,4 +1,5 @@ -from typing import Any +from threading import Thread +from typing import Any, List from autogen.io import IOConsole, IOStream, IOWebsockets @@ -26,3 +27,23 @@ def input(self, prompt: str = "", *, password: bool = False) -> str: assert isinstance(IOStream.get_default(), MyIOStream) assert isinstance(IOStream.get_default(), IOConsole) + + def test_get_default_on_new_thread(self) -> None: + exceptions: List[Exception] = [] + + def on_new_thread(exceptions: List[Exception] = exceptions) -> None: + try: + assert isinstance(IOStream.get_default(), IOConsole) + except Exception as e: + exceptions.append(e) + + # create a new thread and run the function + thread = Thread(target=on_new_thread) + + thread.start() + + # get exception from the thread + thread.join() + + if exceptions: + raise exceptions[0]