def param(
    decoded: str,
    dest: Optional[str] = None,
    required: bool = True,
    default=None,
    multiple: bool = False,
    type=str,
):
    """Return a decorator that passes a request parameter to a render method as a keyword argument.

    The parameter is popped from ``txrequest.args`` (keyed by the UTF-8 encoded
    name), so it is no longer present in ``txrequest.args`` when the decorated
    method runs.

    :param decoded: the parameter's name in the request
    :param dest: the keyword argument's name (defaults to ``decoded``)
    :param required: whether to raise an error if the parameter is absent
    :param default: the value to use if the parameter is absent and not required;
        if callable, it is called on each such request to produce the value
    :param multiple: whether to pass all values as a list, instead of the first value
    :param type: ``str`` to decode each value, otherwise a callable applied to each raw value
    :raises twisted.web.error.Error: if the parameter is required and absent
    """
    encoded = decoded.encode()
    if dest is None:
        dest = decoded

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, txrequest, *args, **kwargs):
            if encoded not in txrequest.args:
                if required:
                    raise error.Error(code=http.OK, message=b"'%b' parameter is required" % encoded)

                # Resolve a callable default per request, not once at decoration time. Otherwise,
                # a factory like `lambda: uuid.uuid1().hex` would yield the same value for every
                # request, and a factory like `list` would share one mutable object across requests.
                value = default() if callable(default) else default
            else:
                values = [raw.decode() if type is str else type(raw) for raw in txrequest.args.pop(encoded)]
                value = values if multiple else values[0]

            kwargs[dest] = value
            return func(self, txrequest, *args, **kwargs)

        return wrapper

    return decorator
class Cancel(WsResource):
    """Cancel a job, whether it is still queued or already running."""

    @param('project')
    @param('job')
    # Instead of os.name, use sys.platform, which disambiguates Cygwin, which implements SIGINT not SIGBREAK.
    # https://cygwin.com/cygwin-ug-net/kill.html
    # https://github.com/scrapy/scrapy/blob/06f9c28/tests/test_crawler.py#L886
    @param('signal', required=False, default='INT' if sys.platform != 'win32' else 'BREAK')
    def render_POST(self, txrequest, project, job, signal):
        """Remove the job from the project's queue and signal its process, if any.

        The response's "prevstate" is "running" if a matching process was signalled,
        "pending" if the job was only removed from the queue, and None otherwise.
        """
        try:
            pending_queue = self.root.poller.queues[project]
        except KeyError as exc:
            raise error.Error(code=http.OK, message=b"project %b not found" % str(exc).encode())

        removed = pending_queue.remove(lambda message: message["_job"] == job)
        prevstate = "pending" if removed else None

        # Only the first matching process is signalled.
        for process in self.root.launcher.processes.values():
            if process.project == project and process.job == job:
                process.transport.signalProcess(signal)
                prevstate = "running"
                break

        return {"node_name": self.root.nodename, "status": "ok", "prevstate": prevstate}
class ListProjects(WsResource):
    """Report the projects returned by the scheduler."""

    def render_GET(self, txrequest):
        names = list(self.root.scheduler.list_projects())
        return {"node_name": self.root.nodename, "status": "ok", "projects": names}


class ListVersions(WsResource):
    """Report the versions that the egg storage lists for a project."""

    @param('project')
    def render_GET(self, txrequest, project):
        found = self.root.eggstorage.list(project)
        return {"node_name": self.root.nodename, "status": "ok", "versions": found}


class ListSpiders(WsResource):
    """Report the spiders of a project, optionally for a specific version."""

    @param('project')
    @param('_version', dest='version', required=False, default='')
    def render_GET(self, txrequest, project, version):
        names = get_spider_list(project, runner=self.root.runner, version=version)
        return {"node_name": self.root.nodename, "status": "ok", "spiders": names}
class DeleteVersion(DeleteProject):
    """Delete one version of a project, then invalidate the project's cache entry."""

    @param('project')
    @param('version')
    def render_POST(self, txrequest, project, version):
        self._delete_version(project, version=version)
        UtilsCache.invalid_cache(project)
        response = {"node_name": self.root.nodename, "status": "ok"}
        return response