diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 59f7856688ee..5d3f8bda04c0 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -266,12 +266,37 @@ def handle(self) -> None:
         auth_token = self.server.auth_token  # type: ignore[attr-defined]
 
         def poll(func: Callable[[], bool]) -> None:
+            poller = None
+            if os.name == "posix":
+                # On POSIX systems use poll to avoid problems with file descriptor
+                # numbers above 1024.
+                poller = select.poll()
+                poller.register(self.rfile, select.POLLIN)
+
             while not self.server.server_shutdown:  # type: ignore[attr-defined]
                 # Poll every 1 second for new data -- don't block in case of shutdown.
-                r, _, _ = select.select([self.rfile], [], [], 1)
-                if self.rfile in r and func():
+                r: set | list
+                if poller is not None:
+                    r = set()
+                    # Unlike select, the poll timeout is in milliseconds.
+                    for fd, event in poller.poll(1000):
+                        if event & (select.POLLIN | select.POLLHUP):
+                            # Data can be read (for POLLHUP the peer hung up, so reads will
+                            # return 0 bytes, in which case we want to break out - this is
+                            # consistent with how select behaves).
+                            r.add(fd)
+                        else:
+                            # Could be POLLERR or POLLNVAL (select would raise in this case).
+                            raise PySparkRuntimeError(f"Polling error - event {event} on fd {fd}")
+                else:
+                    # If poll is not available, use select.
+                    r, _, _ = select.select([self.rfile.fileno()], [], [], 1)
+                if self.rfile.fileno() in r and func():
                     break
 
+            if poller is not None:
+                poller.unregister(self.rfile)
+
         def accum_updates() -> bool:
             num_updates = read_int(self.rfile)
             for _ in range(num_updates):
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 7495569ecfda..5744fcefa436 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -30,6 +30,7 @@
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
 
 from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
+from pyspark.errors import PySparkRuntimeError
 
 
 def compute_real_exit_code(exit_code):
@@ -173,16 +174,22 @@ def handle_sigterm(*args):
         poller.register(listen_sock, select.POLLIN)
 
     while True:
+        ready_fds: set | list
         if poller is not None:
-            ready_fds = [fd_reverse_map[fd] for fd, _ in poller.poll(1000)]
-        else:
-            try:
-                ready_fds = select.select([0, listen_sock], [], [], 1)[0]
-            except select.error as ex:
-                if ex[0] == EINTR:
-                    continue
+            ready_fds = set()
+            # Unlike select, the poll timeout is in milliseconds.
+            for fd, event in poller.poll(1000):
+                if event & (select.POLLIN | select.POLLHUP):
+                    # Data can be read (for POLLHUP the peer hung up, so reads will return
+                    # 0 bytes, in which case we want to break out - this is consistent
+                    # with how select behaves).
+                    ready_fds.add(fd_reverse_map[fd])
                 else:
-                    raise
+                    # Could be POLLERR or POLLNVAL (select would raise in this case).
+                    raise PySparkRuntimeError(f"Polling error - event {event} on fd {fd}")
+        else:
+            # If poll is not available, use select.
+            ready_fds = select.select([0, listen_sock], [], [], 1)[0]
 
         if 0 in ready_fds:
             try:
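
For context: CPython's select() is built on fixed-size fd_set bitmaps, so it raises "ValueError: filedescriptor out of range in select()" for any descriptor numbered FD_SETSIZE (typically 1024) or higher, while poll() accepts arbitrary descriptors. Below is a minimal standalone sketch of the poll-with-select-fallback pattern this patch applies -- it is not part of the patch; the wait_readable helper and the socketpair demo are illustrative names of my own:

```python
import select
import socket


def wait_readable(sock: socket.socket, timeout_secs: float = 1.0) -> bool:
    """Return True once `sock` is readable (pending data or peer hang-up)."""
    if hasattr(select, "poll"):  # poll() is only available on POSIX systems.
        poller = select.poll()
        poller.register(sock, select.POLLIN)
        try:
            # Unlike select, poll takes its timeout in milliseconds.
            for _fd, event in poller.poll(int(timeout_secs * 1000)):
                if event & (select.POLLIN | select.POLLHUP):
                    # POLLHUP means the peer hung up: reads return 0 bytes,
                    # which select would also have reported as "readable".
                    return True
                # POLLERR / POLLNVAL: select would have raised here.
                raise RuntimeError(f"polling error - event {event}")
        finally:
            poller.unregister(sock)
        return False  # Timed out with no events.
    # Fallback: select works everywhere but cannot watch fds >= FD_SETSIZE.
    r, _, _ = select.select([sock], [], [], timeout_secs)
    return bool(r)


if __name__ == "__main__":
    a, b = socket.socketpair()
    b.sendall(b"x")
    print(wait_readable(a))  # True: data is pending
    b.close()
    a.recv(1)
    print(wait_readable(a))  # True: hang-up also reads as "readable"
```

The patch follows the same shape, but registers self.rfile (the handler's file object, whose fileno() poll uses) and surfaces POLLERR/POLLNVAL as PySparkRuntimeError so callers see a consistent exception type in both branches.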