Skip to content

Commit

Permalink
Added '.run()' method to TelnetServer.
Browse files Browse the repository at this point in the history
This is a better than having a separate .start() and .stop() when it comes to
cancellation.
  • Loading branch information
jonathanslenders committed Feb 14, 2023
1 parent 233a818 commit 5eb6efd
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 35 deletions.
5 changes: 1 addition & 4 deletions examples/telnet/hello-world.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ async def interact(connection):

async def main():
server = TelnetServer(interact=interact, port=2323)
server.start()

# Run forever.
await Future()
await server.run()


if __name__ == "__main__":
Expand Down
5 changes: 1 addition & 4 deletions examples/telnet/toolbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ def get_toolbar():

async def main():
server = TelnetServer(interact=interact, port=2323)
server.start()

# Run forever.
await Future()
await server.run()


if __name__ == "__main__":
Expand Down
85 changes: 58 additions & 27 deletions src/prompt_toolkit/contrib/telnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,11 @@ def __init__(
self.encoding = encoding
self.style = style
self.enable_cpr = enable_cpr

self._run_task: asyncio.Task[None] | None = None
self._application_tasks: list[asyncio.Task[None]] = []

self.connections: set[TelnetConnection] = set()
self._listen_socket: socket.socket | None = None

@classmethod
def _create_socket(cls, host: str, port: int) -> socket.socket:
Expand All @@ -298,44 +299,74 @@ def _create_socket(cls, host: str, port: int) -> socket.socket:
s.listen(4)
return s

def start(self) -> None:
async def run(self, ready_cb: Callable[[], None] | None = None) -> None:
"""
Start the telnet server.
Don't forget to call `loop.run_forever()` after doing this.
Run the telnet server, until this gets cancelled.
:param ready_cb: Callback that will be called at the point that we're
actually listening.
"""
self._listen_socket = self._create_socket(self.host, self.port)
socket = self._create_socket(self.host, self.port)
logger.info(
"Listening for telnet connections on %s port %r", self.host, self.port
)

get_running_loop().add_reader(self._listen_socket, self._accept)
get_running_loop().add_reader(socket, lambda: self._accept(socket))

if ready_cb:
ready_cb()

try:
# Run forever, until cancelled.
await asyncio.Future()
finally:
get_running_loop().remove_reader(socket)
socket.close()

# Wait for all applications to finish.
for t in self._application_tasks:
t.cancel()

# (This is similar to
# `Application.cancel_and_wait_for_background_tasks`. We wait for the
# background tasks to complete, but don't propagate exceptions, because
# we can't use `ExceptionGroup` yet.)
if len(self._application_tasks) > 0:
await asyncio.wait(
self._application_tasks,
timeout=None,
return_when=asyncio.ALL_COMPLETED,
)

def start(self) -> None:
"""
Start the telnet server (stop by calling and awaiting `stop()`).
Note: When possible, it's better to call `.run()` instead.
"""
if self._run_task is not None:
# Already running.
return

self._run_task = get_running_loop().create_task(self.run())

async def stop(self) -> None:
if self._listen_socket:
get_running_loop().remove_reader(self._listen_socket)
self._listen_socket.close()

# Wait for all applications to finish.
for t in self._application_tasks:
t.cancel()

# (This is similar to
# `Application.cancel_and_wait_for_background_tasks`. We wait for the
# background tasks to complete, but don't propagate exceptions, because
# we can't use `ExceptionGroup` yet.)
if len(self._application_tasks) > 0:
await asyncio.wait(
self._application_tasks, timeout=None, return_when=asyncio.ALL_COMPLETED
)
"""
Stop a telnet server that was started using `.start()` and wait for the
cancellation to complete.
"""
if self._run_task is not None:
self._run_task.cancel()
try:
await self._run_task
except asyncio.CancelledError:
pass

def _accept(self) -> None:
def _accept(self, listen_socket: socket.socket) -> None:
"""
Accept new incoming connection.
"""
if self._listen_socket is None:
return # Should not happen. `_accept` is called after `start`.

conn, addr = self._listen_socket.accept()
conn, addr = listen_socket.accept()
logger.info("New connection %r %r", *addr)

# Run application for this connection.
Expand Down

0 comments on commit 5eb6efd

Please sign in to comment.