Skip to content

Commit 21a7eb3

Browse files
davorrunjesonichi
andauthored
Fix iostream on new thread (#2181)
* fixed get_stream in new thread by introducing a global default * fixed get_stream in new thread by introducing a global default --------- Co-authored-by: Chi Wang <[email protected]>
1 parent f467f21 commit 21a7eb3

File tree

3 files changed

+54
-7
lines changed

3 files changed

+54
-7
lines changed

autogen/io/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .websockets import IOWebsockets
44

55
# Set the default input/output stream to the console
6-
IOStream._default_io_stream.set(IOConsole())
6+
IOStream.set_global_default(IOConsole())
7+
IOStream.set_default(IOConsole())
78

89
__all__ = ("IOConsole", "IOStream", "InputStream", "OutputStream", "IOWebsockets")

autogen/io/base.py

+30-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from contextlib import contextmanager
22
from contextvars import ContextVar
3+
import logging
34
from typing import Any, Iterator, Optional, Protocol, runtime_checkable
45

56
__all__ = ("OutputStream", "InputStream", "IOStream")
67

8+
logger = logging.getLogger(__name__)
9+
710

811
@runtime_checkable
912
class OutputStream(Protocol):
@@ -39,6 +42,31 @@ def input(self, prompt: str = "", *, password: bool = False) -> str:
3942
class IOStream(InputStream, OutputStream, Protocol):
4043
"""A protocol for input/output streams."""
4144

45+
# ContextVar must be used in multithreaded or async environments
46+
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream", default=None)
47+
_default_io_stream.set(None)
48+
_global_default: Optional["IOStream"] = None
49+
50+
@staticmethod
51+
def set_global_default(stream: "IOStream") -> None:
52+
"""Set the default input/output stream.
53+
54+
Args:
55+
stream (IOStream): The input/output stream to set as the default.
56+
"""
57+
IOStream._global_default = stream
58+
59+
@staticmethod
60+
def get_global_default() -> "IOStream":
61+
"""Get the default input/output stream.
62+
63+
Returns:
64+
IOStream: The default input/output stream.
65+
"""
66+
if IOStream._global_default is None:
67+
raise RuntimeError("No global default IOStream has been set")
68+
return IOStream._global_default
69+
4270
@staticmethod
4371
def get_default() -> "IOStream":
4472
"""Get the default input/output stream.
@@ -48,13 +76,10 @@ def get_default() -> "IOStream":
4876
"""
4977
iostream = IOStream._default_io_stream.get()
5078
if iostream is None:
51-
raise RuntimeError("No default IOStream has been set")
79+
logger.warning("No default IOStream has been set, defaulting to IOConsole.")
80+
return IOStream.get_global_default()
5281
return iostream
5382

54-
# ContextVar must be used in multithreaded or async environments
55-
_default_io_stream: ContextVar[Optional["IOStream"]] = ContextVar("default_iostream")
56-
_default_io_stream.set(None)
57-
5883
@staticmethod
5984
@contextmanager
6085
def set_default(stream: Optional["IOStream"]) -> Iterator[None]:

test/io/test_base.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any
1+
from threading import Thread
2+
from typing import Any, List
23

34
from autogen.io import IOConsole, IOStream, IOWebsockets
45

@@ -26,3 +27,23 @@ def input(self, prompt: str = "", *, password: bool = False) -> str:
2627
assert isinstance(IOStream.get_default(), MyIOStream)
2728

2829
assert isinstance(IOStream.get_default(), IOConsole)
30+
31+
def test_get_default_on_new_thread(self) -> None:
32+
exceptions: List[Exception] = []
33+
34+
def on_new_thread(exceptions: List[Exception] = exceptions) -> None:
35+
try:
36+
assert isinstance(IOStream.get_default(), IOConsole)
37+
except Exception as e:
38+
exceptions.append(e)
39+
40+
# create a new thread and run the function
41+
thread = Thread(target=on_new_thread)
42+
43+
thread.start()
44+
45+
# get exception from the thread
46+
thread.join()
47+
48+
if exceptions:
49+
raise exceptions[0]

0 commit comments

Comments
 (0)