diff --git a/colorama/initialise.py b/colorama/initialise.py index 430d066..eb114b9 100644 --- a/colorama/initialise.py +++ b/colorama/initialise.py @@ -5,9 +5,10 @@ from .ansitowin32 import AnsiToWin32 +UNSET = object() -orig_stdout = None -orig_stderr = None +orig_stdout = UNSET +orig_stderr = UNSET wrapped_stdout = None wrapped_stderr = None @@ -21,15 +22,17 @@ def reset_all(): def init(autoreset=False, convert=None, strip=None, wrap=True): - if not wrap and any([autoreset, convert, strip]): raise ValueError('wrap=False conflicts with any other arg=True') global wrapped_stdout, wrapped_stderr global orig_stdout, orig_stderr - orig_stdout = sys.stdout - orig_stderr = sys.stderr + # Prevent multiple calls from losing the original stdout/err + if orig_stdout is UNSET: + orig_stdout = sys.stdout + if orig_stderr is UNSET: + orig_stderr = sys.stderr if sys.stdout is None: wrapped_stdout = None @@ -49,10 +52,13 @@ def init(autoreset=False, convert=None, strip=None, wrap=True): def deinit(): - if orig_stdout is not None: + global orig_stdout, orig_stderr + if orig_stdout is not UNSET: sys.stdout = orig_stdout - if orig_stderr is not None: + orig_stdout = UNSET + if orig_stderr is not UNSET: sys.stderr = orig_stderr + orig_stderr = UNSET @contextlib.contextmanager @@ -78,3 +84,4 @@ def wrap_stream(stream, convert, strip, autoreset, wrap): if wrapper.should_wrap(): stream = wrapper.stream return stream + diff --git a/colorama/tests/initialise_test.py b/colorama/tests/initialise_test.py index 2f7384d..0aac36e 100644 --- a/colorama/tests/initialise_test.py +++ b/colorama/tests/initialise_test.py @@ -6,7 +6,8 @@ from mock import patch from ..ansitowin32 import StreamWrapper -from ..initialise import init +from .. import initialise +from ..initialise import deinit, init from .utils import osname, redirected_output, replace_by orig_stdout = sys.stdout @@ -22,6 +23,8 @@ def setUp(self): def tearDown(self): sys.stdout = orig_stdout sys.stderr = orig_stderr + initialise.orig_stdout = initialise.UNSET + initialise.orig_stderr = initialise.UNSET def assertWrapped(self): self.assertIsNot(sys.stdout, orig_stdout, 'stdout should be wrapped') @@ -75,11 +78,29 @@ def testInitWrapOffDoesntWrapOnWindows(self): def testInitWrapOffIncompatibleWithAutoresetOn(self): self.assertRaises(ValueError, lambda: init(autoreset=True, wrap=False)) - @patch('colorama.ansitowin32.winterm', None) @patch('colorama.ansitowin32.winapi_test', lambda *_: True) - def testInitOnlyWrapsOnce(self): - with osname("nt"): + def testInitTwiceCanBeUndoneWithDeinitOnce(self): + with osname('nt'): + self.assertNotWrapped() + init() + self.assertWrapped() init() + self.assertWrapped() + deinit() + self.assertNotWrapped() + + @patch('colorama.ansitowin32.winapi_test', lambda *_: True) + def testInitDeinitInitWorks(self): + with osname('nt'): + self.assertNotWrapped() + init() + self.assertWrapped() + deinit() + self.assertNotWrapped() + init() + self.assertWrapped() + deinit() + self.assertNotWrapped() init() self.assertWrapped() diff --git a/colorama/tests/utils.py b/colorama/tests/utils.py index de2abf5..a8af441 100644 --- a/colorama/tests/utils.py +++ b/colorama/tests/utils.py @@ -18,16 +18,20 @@ def isatty(self): def osname(name): orig = os.name os.name = name - yield - os.name = orig + try: + yield + finally: + os.name = orig @contextmanager def redirected_output(): orig = sys.stdout sys.stdout = Mock() sys.stdout.isatty = lambda: False - yield - sys.stdout = orig + try: + yield + finally: + sys.stdout = orig @contextmanager def replace_by(stream): @@ -35,9 +39,11 @@ def replace_by(stream): orig_stderr = sys.stderr sys.stdout = stream sys.stderr = stream - yield - sys.stdout = orig_stdout - sys.stderr = orig_stderr + try: + yield + finally: + sys.stdout = orig_stdout + sys.stderr = orig_stderr @contextmanager def replace_original_by(stream): @@ -45,14 +51,19 @@ def replace_original_by(stream): orig_stderr = sys.__stderr__ sys.__stdout__ = stream sys.__stderr__ = stream - yield - sys.__stdout__ = orig_stdout - sys.__stderr__ = orig_stderr + try: + yield + finally: + sys.__stdout__ = orig_stdout + sys.__stderr__ = orig_stderr @contextmanager def pycharm(): os.environ["PYCHARM_HOSTED"] = "1" non_tty = StreamNonTTY() - with replace_by(non_tty), replace_original_by(non_tty): - yield - del os.environ["PYCHARM_HOSTED"] + try: + with replace_by(non_tty), replace_original_by(non_tty): + yield + finally: + del os.environ["PYCHARM_HOSTED"] +