Skip to content
19 changes: 17 additions & 2 deletions python/pyspark/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,27 @@ def handle(self) -> None:
auth_token = self.server.auth_token # type: ignore[attr-defined]

def poll(func: Callable[[], bool]) -> None:
poller = None
if os.name == "posix":
# On posix systems use poll to avoid problems with file descriptor
# numbers above 1024.
poller = select.poll()
poller.register(self.rfile, select.POLLIN)

while not self.server.server_shutdown: # type: ignore[attr-defined]
# Poll every 1 second for new data -- don't block in case of shutdown.
r, _, _ = select.select([self.rfile], [], [], 1)
if self.rfile in r and func():
if poller is not None:
# Unlike select, poll timeout is in millis. Rule out error events.
r = [fd for fd, event in poller.poll(1000) if event & select.POLLIN]
else:
# If poll is not available, use select.
r, _, _ = select.select([self.rfile.fileno()], [], [], 1)
if self.rfile.fileno() in r and func():
break

if poller is not None:
poller.unregister(self.rfile)

def accum_updates() -> bool:
num_updates = read_int(self.rfile)
for _ in range(num_updates):
Expand Down