diff --git a/selftests/test_leaks.py b/selftests/test_leaks.py new file mode 100644 index 0000000..29f0a80 --- /dev/null +++ b/selftests/test_leaks.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: GPL-2.0-or-later +# Copyright (C) 2022 David Lamparter for NetDEF, Inc. +""" +basic tests for topotato.leaks FD checks +""" + +import sys +import os +import socket +import pytest + +from topotato.leaks import fdinfo, FDState, FDDelta + + +def test_fdinfo_pipe(): + a, b = os.pipe() + try: + assert fdinfo(a).startswith("pipe") + finally: + os.close(a) + os.close(b) + + +def test_fdinfo_socket(): + a, b = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + try: + i = fdinfo(a.fileno()) + finally: + a.close() + b.close() + + assert i.startswith("socket") + assert "AF_UNIX" in i + assert "SOCK_STREAM" in i + + +def test_fdinfo_sockaddr(): + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as fd: + fd.bind(("::1", 0)) + i = fdinfo(fd.fileno()) + + assert i.startswith("socket") + assert "AF_INET6" in i + assert "SOCK_DGRAM" in i + assert "IPPROTO_UDP" in i + assert "'::1'" in i + + +def test_fdinfo_dev(): + with open("/dev/null", "r") as fd: + assert fdinfo(fd.fileno()).startswith("chardev") + + +def test_fdinfo_ns(): + if sys.platform != "linux": + pytest.skip("Linux only test") + + with open("/proc/self/ns/mnt", "r") as fd: + i = fdinfo(fd.fileno()) + assert i.startswith("nsfd") + assert "mnt" in i + + +def test_fdstate(): + state = FDState() + assert 1 in state + + +def test_fddelta(): + state0 = FDState() + state1 = FDState() + + a, b = os.pipe() + + state2 = FDState() + + c, d = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + os.dup2(d.fileno(), b) + os.close(a) + + state3 = FDState() + + delta01 = FDDelta(state0, state1) + assert len(delta01) == 0 + + delta12 = FDDelta(state1, state2) + assert a in delta12.opened + assert b in delta12.opened + + delta23 = FDDelta(state2, state3) + assert a in delta23.closed + assert b in delta23.changed + + c.close() + d.close() diff --git a/topotato/base.py b/topotato/base.py index a3f872f..a43cce3 100644 --- a/topotato/base.py +++ b/topotato/base.py @@ -46,6 +46,7 @@ from .livescapy import LiveScapy from .generatorwrap import GeneratorWrapper, GeneratorChecks from .network import TopotatoNetwork +from .leaks import FDState, FDDelta, fdinfo if typing.TYPE_CHECKING: from types import TracebackType @@ -536,6 +537,8 @@ def reportinfo(self): return fspath, float("-inf"), "startup" def setup(self): + self.cls_node.fdstate_start = FDState() + # this needs to happen before TopotatoItem.setup, since that accesses # cls_node.netinst with _SkipMgr(self): @@ -584,6 +587,18 @@ def setup(self): def __call__(self): self.cls_node.do_stop(self) + fdstate_end = FDState() + delta = FDDelta(self.cls_node.fdstate_start, fdstate_end) + + if delta: + _logger.error("FD leaks detected:") + for fd in sorted(delta.closed): + _logger.error("FD %4d closed", fd) + for fd in sorted(delta.changed): + _logger.error("FD %4d differs, now: %s", fd, fdinfo(fd)) + for fd in sorted(delta.opened): + _logger.error("FD %4d opened: %s", fd, fdinfo(fd)) + class TestBase: """ @@ -847,6 +862,8 @@ class TopotatoClass(_pytest.python.Class): started_ts: float netinst: "TopotatoNetwork" + fdstate_start: FDState + # pylint: disable=protected-access @classmethod def from_hook(cls, obj, collector, name): diff --git a/topotato/leaks.py b/topotato/leaks.py new file mode 100644 index 0000000..9c38565 --- /dev/null +++ b/topotato/leaks.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: GPL-2.0-or-later +# Copyright (C) 2024 David Lamparter for NetDEF, Inc. +""" +FD leak checks +""" + +import sys +import os +import stat +import socket +import fcntl +import errno +import itertools + +from typing import ( + Dict, + Optional, + Set, + Tuple, +) + +_types_s = ["SOCK_STREAM", "SOCK_DGRAM", "SOCK_SEQPACKET", "SOCK_RAW"] + +_afs = {int(getattr(socket, n)): n for n in dir(socket) if n.startswith("AF_")} +_types = {int(getattr(socket, n)): n for n in _types_s if hasattr(socket, n)} +_ipprotos = { + int(getattr(socket, n)): n for n in dir(socket) if n.startswith("IPPROTO_") +} + +if sys.platform == "linux": + from .nswrap import getnstype + +else: + + def getnstype(fd: int) -> Optional[str]: # pylint: disable=unused-argument + return None + + +def _hexbytes(i): + if not isinstance(i, bytes): + return i + return ":".join("%02x" % b for b in i) + + +def _socknamewrap(fn): + try: + name = fn() + except OSError as e: + if e.errno == errno.ENOTCONN: + return "not_connected" + if e.errno == errno.EOPNOTSUPP: + return "not_supported" + return f"E({e!r})" + + if isinstance(name, tuple): + name = tuple(_hexbytes(i) for i in name) + return repr(name) + + +# pylint: disable=too-many-locals,too-many-return-statements,too-many-branches +def fdinfo(fd: int) -> str: + """ + Give a human-usable string description of an open file descriptor. + + Note this shouldn't raise an exception if something goes wrong since it is + a debugging aid. + """ + + extra = [] + + try: + st = os.fstat(fd) + except OSError as e: + return f"stat_failed({e!r})" + + try: + fdlink = os.readlink(f"/proc/self/fd/{fd}") + except OSError: + fdlink = None + if fdlink: + extra.append(f", link={fdlink!r}") + + try: + fdflags = fcntl.fcntl(fd, fcntl.F_GETFD) + except OSError: + fdflags = 0 + if fdflags & fcntl.FD_CLOEXEC: + extra.append(", cloexec") + + extrastr = "".join(extra) + + nstype = getnstype(fd) + + try: + if stat.S_ISSOCK(st.st_mode): + with socket.fromfd(fd, family=-1, type=-1) as s: + # socket.fromfd does a dup() on the fd. otherwise the fd + # would be b0rked afterwards when s is closed + assert s.fileno() != fd + + af = s.getsockopt(socket.SOL_SOCKET, socket.SO_DOMAIN) + typ = s.getsockopt(socket.SOL_SOCKET, socket.SO_TYPE) + protocol = s.getsockopt(socket.SOL_SOCKET, socket.SO_PROTOCOL) + + sockname = _socknamewrap(s.getsockname) + peername = _socknamewrap(s.getpeername) + + if af in {socket.AF_INET, socket.AF_INET6}: + protostr = _ipprotos.get(protocol, str(protocol)) + else: + protostr = str(protocol) + + return f"socket({_afs.get(af, str(af))}, {_types.get(typ, str(typ))}, {protostr}, sockname={sockname}, peername={peername}{extrastr})" + + if nstype is not None: + major, minor = st.st_dev >> 8, st.st_dev & 0xFF + return f"nsfd({nstype}, dev={major}:{minor}, inode={st.st_ino}, mode={stat.S_IMODE(st.st_mode):#o}{extrastr})" + + basic = { + "file": stat.S_ISREG, + "dir": stat.S_ISDIR, + "chardev": stat.S_ISCHR, + "blkdev": stat.S_ISBLK, + } + for kind, test in basic.items(): + if test(st.st_mode): + major, minor = st.st_dev >> 8, st.st_dev & 0xFF + return f"{kind}(dev={major}:{minor}, inode={st.st_ino}, mode={stat.S_IMODE(st.st_mode):#o}{extrastr})" + + if stat.S_ISFIFO(st.st_mode): + return ( + f"pipe(inode={st.st_ino}, mode={stat.S_IMODE(st.st_mode):#o}{extrastr})" + ) + + return f"?({st!r}{extrastr})" + + except OSError as e: + return f"{st!r} [Exc: {e!r}]" + + +class FDState(Dict[int, Tuple[int, int, int]]): + """ + Capture a snapshot of open file descriptor state. + + Does not hold FDs open, that would defeat the purpose. Just record types + and dev/ino numbers to compare. + """ + + stop_after = 256 + + @staticmethod + def _key(st: os.stat_result) -> Tuple[int, int, int]: + return (stat.S_IFMT(st.st_mode), st.st_dev, st.st_ino) + + def __init__(self): + super().__init__() + + stop = 0 + for fd in itertools.count(): + st = None + try: + st = os.fstat(fd) + except OSError as e: + if e.errno != errno.EBADF: + raise + + stop += 1 + if stop >= self.stop_after: + break + continue + + self[fd] = self._key(st) + stop = 0 + + +class FDDelta: + """ + Changes between two :py:class:`FDState`. + """ + + opened: Set[int] + changed: Set[int] + closed: Set[int] + + def __init__(self, before: FDState, after: FDState): + self.before = before + self.after = after + + k1 = set(before.keys()) + k2 = set(after.keys()) + self.opened = k2 - k1 + self.closed = k1 - k2 + self.changed = set() + for fd in k1 & k2: + if before[fd] != after[fd]: + self.changed.add(fd) + + def __len__(self): + """ + Cumulative size of constituent sets, mostly for quick boolean checks. + """ + return len(self.opened) + len(self.closed) + len(self.changed) + + def asdict(self): + items = ["opened", "changed", "closed"] + return {k: {fd: fdinfo(fd) for fd in getattr(self, k)} for k in items} diff --git a/topotato/nswrap.py b/topotato/nswrap.py index c66a646..28606eb 100644 --- a/topotato/nswrap.py +++ b/topotato/nswrap.py @@ -7,11 +7,16 @@ import sys import os import time +import fcntl import ctypes import ctypes.util import errno -from typing import List, ClassVar +from typing import ( + ClassVar, + List, + Optional, +) from .defer import subprocess from .utils import LockedFile, PathDict @@ -29,6 +34,13 @@ CLONE_NEWNS = 0x00020000 CLONE_NEWNET = 0x40000000 +_nstypes = { + CLONE_NEWNS: "mnt", + CLONE_NEWNET: "net", +} + +NS_GET_NSTYPE = (0xB7 << 8) | 0x03 + def setns(nsfd: int, nstype: int = 0): ret = _setns(nsfd, nstype) @@ -44,6 +56,15 @@ def unshare(nstype: int = 0): raise OSError(_errno, os.strerror(_errno)) +def getnstype(fd: int) -> Optional[str]: + try: + nstype = fcntl.ioctl(fd, NS_GET_NSTYPE) + except OSError: + return None + + return _nstypes.get(nstype, hex(nstype)) + + class LinuxNamespaceJoinFailed(SystemError): pass