Skip to content

Commit

Permalink
feat: Implement notifications listener
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Oct 3, 2019
1 parent 8bcfa82 commit 33ee9ae
Show file tree
Hide file tree
Showing 3 changed files with 255 additions and 5 deletions.
44 changes: 44 additions & 0 deletions src/aria2p/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
interact easily with a remote aria2c process.
"""
import shutil
import threading
from base64 import b64encode
from pathlib import Path

Expand Down Expand Up @@ -35,6 +36,7 @@ def __init__(self, client=None):
if client is None:
client = Client()
self.client = client
self.listener = None

def add_magnet(self, magnet_uri, options=None, position=None):
"""
Expand Down Expand Up @@ -625,3 +627,45 @@ def copy_files(downloads, to_directory, force=False):
else:
results.append(False)
return results

def listen_to_notifications(self, threaded=False, **kwargs):
"""
Start listening to aria2 notifications via WebSocket.
This method differs from :method:`~aria2p.client.Client.listen_to_notifications` in that it expects callbacks
accepting two arguments, "api" and "gid", instead of only "gid". Accepting "api" allows to use the high-level
methods of the API class.
Stop listening to notifications with the :method:`~aria2p.api.API.stop_listening` method.
Args:
threaded (bool): Whether to start the listening loop in a thread or not (non-blocking or blocking).
"""

def closure(callback):
return (lambda gid: callback(self, gid)) if callable(callback) else None

def run():
self.client.listen_to_notifications(
**{key: closure(value) if key.startswith("on_") else value for key, value in kwargs.items()}
)

if threaded:
if "handle_signals" in kwargs:
kwargs["handle_signals"] = False
self.listener = threading.Thread(target=run)
self.listener.start()
else:
run()

def stop_listening(self):
"""
Stop listening to notifications.
If the listening loop was threaded, this method will wait for the thread to finish. The time it takes
for the thread to finish will depend on the timeout given while calling ``listen_to_notifications``.
"""
self.client.stop_listening()
if self.listener:
self.listener.join()
self.listener = None
186 changes: 182 additions & 4 deletions src/aria2p/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
process through the JSON-RPC protocol.
"""


import json

import requests
import websocket
from loguru import logger

from .utils import SignalHandler

DEFAULT_ID = -1
DEFAULT_HOST = "http://localhost"
Expand All @@ -26,6 +29,22 @@
JSONRPC_INTERNAL_ERROR: "Internal JSON-RPC error.",
}

NOTIFICATION_START = "aria2.onDownloadStart"
NOTIFICATION_PAUSE = "aria2.onDownloadPause"
NOTIFICATION_STOP = "aria2.onDownloadStop"
NOTIFICATION_COMPLETE = "aria2.onDownloadComplete"
NOTIFICATION_ERROR = "aria2.onDownloadError"
NOTIFICATION_BT_COMPLETE = "aria2.onBtDownloadComplete"

NOTIFICATION_TYPES = [
NOTIFICATION_START,
NOTIFICATION_PAUSE,
NOTIFICATION_STOP,
NOTIFICATION_COMPLETE,
NOTIFICATION_ERROR,
NOTIFICATION_BT_COMPLETE,
]


class ClientException(Exception):
"""An exception specific to JSON-RPC errors."""
Expand Down Expand Up @@ -161,6 +180,7 @@ def __init__(self, host=DEFAULT_HOST, port=DEFAULT_PORT, secret=""): # nosec
self.host = host
self.port = port
self.secret = secret
self.listening = False

def __str__(self):
return self.server
Expand All @@ -170,6 +190,11 @@ def server(self):
"""Property to return the full remote process / server address."""
return f"{self.host}:{self.port}/jsonrpc"

@property
def ws_server(self):
"""Property to return the full WebSocket remote server address."""
return f"ws{self.host[4:]}:{self.port}/jsonrpc"

# utils
def call(self, method, params=None, msg_id=None, insert_secret=True):
"""
Expand Down Expand Up @@ -292,6 +317,10 @@ def post(self, payload):
"""
return requests.post(self.server, data=payload).json()

@staticmethod
def response_as_exception(response):
return ClientException(response["error"]["code"], response["error"]["message"])

@staticmethod
def res_or_raise(response):
"""
Expand All @@ -307,9 +336,9 @@ def res_or_raise(response):
ClientException: when the response contains an error (client/server error).
See the :class:`~aria2p.client.ClientException` class.
"""
if "result" in response:
return response["result"]
raise ClientException(response["error"]["code"], response["error"]["message"])
if "error" in response:
raise Client.response_as_exception(response)
return response["result"]

@staticmethod
def get_payload(method, params=None, msg_id=None, as_json=True):
Expand Down Expand Up @@ -1598,3 +1627,152 @@ def list_notifications(self):
['aria2.onDownloadStart', 'aria2.onDownloadPause', ...
"""
return self.call(self.LIST_NOTIFICATIONS)

def listen_to_notifications(
self,
on_download_start=None,
on_download_pause=None,
on_download_stop=None,
on_download_complete=None,
on_download_error=None,
on_bt_download_complete=None,
timeout=5,
handle_signals=True,
):
"""
Start listening to aria2 notifications via WebSocket.
This method opens a WebSocket connection to the server and wait for notifications (or events) to be received.
It accepts callbacks as arguments, which are functions accepting one parameter called "gid", for each type
of notification.
Stop listening to notifications with the :method:`~aria2p.client.Client.stop_listening` method.
Args:
on_download_start (func): Callback for the "aria2.onDownloadStart" event.
on_download_pause (func): Callback for the "aria2.onDownloadPause" event.
on_download_stop (func): Callback for the "aria2.onDownloadStop" event.
on_download_complete (func): Callback for the "aria2.onDownloadComplete" event.
on_download_error (func): Callback for the "aria2.onDownloadError" event.
on_bt_download_complete (func): Callback for the "aria2.onBtDownloadComplete" event.
timeout (int): Timeout when waiting for data to be received. Use a small value for faster reactivity
when stopping to listen. Default is 5 seconds.
handle_signals (bool): Whether to add signal handlers to gracefully stop the loop on SIGTERM and SIGINT.
"""
self.listening = True
ws_server = self.ws_server

logger.debug(f"Notifications ({ws_server}): opening WebSocket with timeout={timeout}")
try:
ws = websocket.create_connection(ws_server, timeout=timeout)
except ConnectionRefusedError:
logger.error(f"Notifications ({ws_server}): connection refused. Is the server running?")
return

stopped = SignalHandler(["SIGTERM", "SIGINT"]) if handle_signals else False

while not stopped:
try:
logger.debug(f"Notifications ({ws_server}): waiting for data over WebSocket")
message = ws.recv()
except websocket.WebSocketConnectionClosedException:
logger.error(f"Notifications ({ws_server}): connection to server was closed. Is the server running?")
break
except websocket.WebSocketTimeoutException:
logger.debug(f"Notifications ({ws_server}): reached timeout ({timeout}s)")
else:
notification = Notification.get_or_raise(json.loads(message))
for notification_type, callback in (
(NOTIFICATION_START, on_download_start),
(NOTIFICATION_PAUSE, on_download_pause),
(NOTIFICATION_STOP, on_download_stop),
(NOTIFICATION_COMPLETE, on_download_complete),
(NOTIFICATION_ERROR, on_download_error),
(NOTIFICATION_BT_COMPLETE, on_bt_download_complete),
):
if notification.type == notification_type:
logger.info(
f"Notifications ({ws_server}): received {notification_type} with gid={notification.gid}"
)
if callable(callback):
logger.debug(f"Notifications ({ws_server}): calling {callback} with gid={notification.gid}")
callback(notification.gid)
else:
logger.debug(
f"Notifications ({ws_server}): no callback given for type " + notification.type
)
break

if not self.listening:
logger.debug(f"Notifications ({ws_server}): stopped listening")
break

if stopped:
logger.debug("Notifications: stopped listening after receiving a signal")
self.listening = False

logger.debug(f"Notifications ({ws_server}): closing WebSocket")
ws.close()

def stop_listening(self):
"""
Stop listening to notifications.
Although this method returns instantly, the actual listening loop can take some time to break out,
depending on the timeout that was given to it.
"""
self.listening = False


class Notification:
"""
A helper class for notifications.
You should not need to use this class. It simply provides methods to instantiate a notification with a
message received from the server through a WebSocket, or to raise a ClientException if the message is invalid.
"""

def __init__(self, type, gid):
f"""
Initialization method.
Args:
type (str): The notification type. Possible types are {",".join(NOTIFICATION_TYPES)}.
gid (str): The GID of the download related to the notification.
"""

self.type = type
self.gid = gid

@staticmethod
def get_or_raise(message):
"""
Static method to raise a ClientException when the message is invalid or return a Notification instance.
Args:
message (dict): The JSON-loaded message received over WebSocket.
Returns:
Notification: a Notification instance if the message is valid.
Raises:
ClientException: when the message contains an error.
"""
if "error" in message:
raise Client.response_as_exception(message)
return Notification.from_message(message)

@staticmethod
def from_message(message):
"""
Static method to return an instance of Notification.
This method expects a valid message (not containing errors).
Args:
message (dict): A valid message received over WebSocket.
Returns:
Notification: a Notification instance.
"""
return Notification(type=message["method"], gid=message["params"][0]["gid"])
30 changes: 29 additions & 1 deletion src/aria2p/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
"""
Utils module.
This module contains simple utility functions.
This module contains simple utility classes and functions.
"""
import signal

from loguru import logger


class SignalHandler:
"""A helper class to handle signals."""

def __init__(self, signals):
"""
Initialization method.
Args:
signals (list of str): List of signals names as found in the ``signal`` module (example: SIGTERM).
"""
logger.debug("Signal handler: handling signals " + ", ".join(signals))
self.triggered = False
for sig in signals:
signal.signal(signal.Signals[sig], self.trigger)

def __bool__(self):
"""Return True when one of the given signal was received, False otherwise."""
return self.triggered

def trigger(self, signum, frame):
"""Mark this instance as 'triggered' (a specified signal was received)."""
logger.debug(f"Signal handler: caught signal {signal.Signals(signum).name} ({signum})")
self.triggered = True


def human_readable_timedelta(value):
Expand Down

0 comments on commit 33ee9ae

Please sign in to comment.