Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

properly close OutStream and various fixes #1305

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions ipykernel/_version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
store the current version info of the server.
"""
from __future__ import annotations

import re

# Version string must appear intact for hatch versioning
Expand Down
1 change: 1 addition & 0 deletions ipykernel/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ def embed_kernel(module=None, local_ns=None, **kwargs):
app.kernel.user_ns = local_ns
app.shell.set_completer_frame() # type:ignore[union-attr]
app.start()
app.close()
1 change: 1 addition & 0 deletions ipykernel/inprocess/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

from jupyter_client.channelsabc import HBChannelABC

Expand Down
14 changes: 6 additions & 8 deletions ipykernel/iostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ def fileno(self):
"""
Things like subprocess will peak and write to the fileno() of stderr/stdout.
"""
if getattr(self, "_original_stdstream_copy", None) is not None:
return self._original_stdstream_copy
if getattr(self, "_original_stdstream_fd", None) is not None:
return self._original_stdstream_fd
msg = "fileno"
raise io.UnsupportedOperation(msg)

Expand Down Expand Up @@ -527,10 +527,7 @@ def __init__(
# echo on the _copy_ we made during
# this is the actual terminal FD now
echo = io.TextIOWrapper(
io.FileIO(
self._original_stdstream_copy,
"w",
)
io.FileIO(self._original_stdstream_copy, "w", closefd=False)
)
self.echo = echo
else:
Expand Down Expand Up @@ -595,9 +592,10 @@ def close(self):
self._should_watch = False
# thread won't wake unless there's something to read
# writing something after _should_watch will not be echoed
os.write(self._original_stdstream_fd, b"\0")
if self.watch_fd_thread is not None:
if self.watch_fd_thread is not None and self.watch_fd_thread.is_alive():
os.write(self._original_stdstream_fd, b"\0")
self.watch_fd_thread.join()
self.echo = None
# restore original FDs
os.dup2(self._original_stdstream_copy, self._original_stdstream_fd)
os.close(self._original_stdstream_copy)
Expand Down
76 changes: 59 additions & 17 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp, ConnectionFileMix

_ports = Dict()

_original_io = Any()
_log_map = Any()
_io_modified = Bool(False)
_blackhole = Any()

subcommands = {
"install": (
"ipykernel.kernelspec.InstallIPythonKernelSpecApp",
Expand Down Expand Up @@ -470,56 +475,93 @@ def log_connection_info(self):

def init_blackhole(self):
"""redirects stdout/stderr to devnull if necessary"""
self._save_io()
if self.no_stdout or self.no_stderr:
blackhole = open(os.devnull, "w") # noqa: SIM115
# keep reference around so that it would not accidentally close the pipe fds
self._blackhole = open(os.devnull, "w") # noqa: SIM115
if self.no_stdout:
sys.stdout = sys.__stdout__ = blackhole # type:ignore[misc]
if sys.stdout is not None:
sys.stdout.flush()
sys.stdout = self._blackhole
if self.no_stderr:
sys.stderr = sys.__stderr__ = blackhole # type:ignore[misc]
if sys.stderr is not None:
sys.stderr.flush()
sys.stderr = self._blackhole

def init_io(self):
"""Redirect input streams and set a display hook."""
self._save_io()
if self.outstream_class:
outstream_factory = import_item(str(self.outstream_class))
if sys.stdout is not None:
sys.stdout.flush()

e_stdout = None if self.quiet else sys.__stdout__
e_stderr = None if self.quiet else sys.__stderr__
e_stdout = None if self.quiet else sys.stdout
e_stderr = None if self.quiet else sys.stderr

if not self.capture_fd_output:
outstream_factory = partial(outstream_factory, watchfd=False)

if sys.stdout is not None:
sys.stdout.flush()
sys.stdout = outstream_factory(self.session, self.iopub_thread, "stdout", echo=e_stdout)

if sys.stderr is not None:
sys.stderr.flush()
sys.stderr = outstream_factory(self.session, self.iopub_thread, "stderr", echo=e_stderr)

if hasattr(sys.stderr, "_original_stdstream_copy"):
for handler in self.log.handlers:
if isinstance(handler, StreamHandler) and (handler.stream.buffer.fileno() == 2):
if (
isinstance(handler, StreamHandler)
and (buffer := getattr(handler.stream, "buffer", None))
and (fileno := getattr(buffer, "fileno", None))
and fileno() == sys.stderr._original_stdstream_fd # type:ignore[attr-defined]
):
self.log.debug("Seeing logger to stderr, rerouting to raw filedescriptor.")

handler.stream = TextIOWrapper(
FileIO(
sys.stderr._original_stdstream_copy,
"w",
)
io_wrapper = TextIOWrapper(
FileIO(sys.stderr._original_stdstream_copy, "w", closefd=False)
)
self._log_map[id(io_wrapper)] = handler.stream
handler.stream = io_wrapper
if self.displayhook_class:
displayhook_factory = import_item(str(self.displayhook_class))
self.displayhook = displayhook_factory(self.session, self.iopub_socket)
sys.displayhook = self.displayhook

self.patch_io()

def _save_io(self):
if not self._io_modified:
self._original_io = sys.stdout, sys.stderr, sys.displayhook
self._log_map = {}
self._io_modified = True

def reset_io(self):
"""restore original io

restores state after init_io
"""
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
sys.displayhook = sys.__displayhook__
if not self._io_modified:
return
stdout, stderr, displayhook = sys.stdout, sys.stderr, sys.displayhook
sys.stdout, sys.stderr, sys.displayhook = self._original_io
self._original_io = None
self._io_modified = False
if finish_displayhook := getattr(displayhook, "finish_displayhook", None):
finish_displayhook()
if hasattr(stderr, "_original_stdstream_copy"):
for handler in self.log.handlers:
if orig_stream := self._log_map.get(id(handler.stream)):
self.log.debug("Seeing modified logger, rerouting back to stderr")
handler.stream = orig_stream
self._log_map = None
if self.outstream_class:
outstream_factory = import_item(str(self.outstream_class))
if isinstance(stderr, outstream_factory):
stderr.close()
if isinstance(stdout, outstream_factory):
stdout.close()
if self._blackhole:
self._blackhole.close()

def patch_io(self):
"""Patch important libraries that can't handle sys.stdout forwarding"""
Expand Down
2 changes: 2 additions & 0 deletions ipykernel/pickleutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations

import copy
import pickle
import sys
Expand Down
2 changes: 2 additions & 0 deletions ipykernel/thread.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Base class for threads."""
from __future__ import annotations

import typing as t
from threading import Event, Thread

Expand Down
1 change: 1 addition & 0 deletions tests/test_kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_blackhole():
app.no_stderr = True
app.no_stdout = True
app.init_blackhole()
app.close()


def test_start_app():
Expand Down
Loading