Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Commit

Permalink
Merge pull request #112 from bjoernricks/refactor-error
Browse files Browse the repository at this point in the history
Refactor error module
  • Loading branch information
bjoernricks authored Jun 24, 2019
2 parents b39bacb + 27f17db commit 5bb3daf
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 52 deletions.
25 changes: 23 additions & 2 deletions ospd/error.py → ospd/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,34 @@
from ospd.xml import simple_response_str


class OSPDError(Exception):
class OspdError(Exception):
""" Base error class for all Ospd related errors """


class RequiredArgument(OspdError):
"""Raised if a required argument/parameter is missing
Derives from :py:class:`OspdError`
"""

def __init__(self, function, argument):
# pylint: disable=super-init-not-called
self.function = function
self.argument = argument

def __str__(self):
return "{}: Argument {} is required".format(
self.function, self.argument
)


class OspdCommandError(OspdError):

""" This is an exception that will result in an error message to the
client """

def __init__(self, message, command='osp', status=400):
super().__init__()
super().__init__(message)
self.message = message
self.command = command
self.status = status
Expand Down
60 changes: 35 additions & 25 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import defusedxml.ElementTree as secET

from ospd import __version__
from ospd.error import OSPDError
from ospd.errors import OspdCommandError
from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid
from ospd.network import resolve_hostname, target_str_to_list
from ospd.vtfilter import VtsFilter
Expand Down Expand Up @@ -204,7 +204,7 @@ def __init__(
cafile,
niceness=None, # pylint: disable=unused-argument
customvtfilter=None,
**kwargs # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
""" Initializes the daemon's internal data. """
# @todo: Actually it makes sense to move the certificate params to
Expand Down Expand Up @@ -345,7 +345,7 @@ def set_vts_version(self, vts_version):
vts_version (str): Identifies a unique vts version.
"""
if not vts_version:
raise OSPDError(
raise OspdCommandError(
'A vts_version parameter is required', 'set_vts_version'
)
self.vts_version = vts_version
Expand Down Expand Up @@ -400,16 +400,22 @@ def _preprocess_scan_params(self, xml_params):
try:
params[key] = int(params[key])
except ValueError:
raise OSPDError('Invalid %s value' % key, 'start_scan')
raise OspdCommandError(
'Invalid %s value' % key, 'start_scan'
)
if param_type == 'boolean':
if params[key] not in [0, 1]:
raise OSPDError('Invalid %s value' % key, 'start_scan')
raise OspdCommandError(
'Invalid %s value' % key, 'start_scan'
)
elif param_type == 'selection':
selection = self.get_scanner_param_default(key).split('|')
if params[key] not in selection:
raise OSPDError('Invalid %s value' % key, 'start_scan')
raise OspdCommandError(
'Invalid %s value' % key, 'start_scan'
)
if self.get_scanner_param_mandatory(key) and params[key] == '':
raise OSPDError(
raise OspdCommandError(
'Mandatory %s value is missing' % key, 'start_scan'
)
return params
Expand Down Expand Up @@ -451,7 +457,7 @@ def process_vts_params(self, scanner_vts):
vt_selection[vt_id] = {}
for vt_value in vt:
if not vt_value.attrib.get('id'):
raise OSPDError(
raise OspdCommandError(
'Invalid VT preference. No attribute id',
'start_scan',
)
Expand All @@ -461,7 +467,7 @@ def process_vts_params(self, scanner_vts):
if vt.tag == 'vt_group':
vts_filter = vt.attrib.get('filter', None)
if vts_filter is None:
raise OSPDError(
raise OspdCommandError(
'Invalid VT group. No filter given.', 'start_scan'
)
filters.append(vts_filter)
Expand Down Expand Up @@ -571,7 +577,7 @@ def process_targets_element(cls, scanner_target):
if hosts:
target_list.append([hosts, ports, credentials, exclude_hosts])
else:
raise OSPDError('No target to scan', 'start_scan')
raise OspdCommandError('No target to scan', 'start_scan')

return target_list

Expand All @@ -588,7 +594,7 @@ def handle_start_scan_command(self, scan_et):
if target_str is None or ports_str is None:
target_list = scan_et.find('targets')
if target_list is None or len(target_list) == 0:
raise OSPDError('No targets or ports', 'start_scan')
raise OspdCommandError('No targets or ports', 'start_scan')
else:
scan_targets = self.process_targets_element(target_list)
else:
Expand All @@ -598,21 +604,21 @@ def handle_start_scan_command(self, scan_et):

scan_id = scan_et.attrib.get('scan_id')
if scan_id is not None and scan_id != '' and not valid_uuid(scan_id):
raise OSPDError('Invalid scan_id UUID', 'start_scan')
raise OspdCommandError('Invalid scan_id UUID', 'start_scan')

try:
parallel = int(scan_et.attrib.get('parallel', '1'))
if parallel < 1 or parallel > 20:
parallel = 1
except ValueError:
raise OSPDError(
raise OspdCommandError(
'Invalid value for parallel scans. ' 'It must be a number',
'start_scan',
)

scanner_params = scan_et.find('scanner_params')
if scanner_params is None:
raise OSPDError('No scanner_params element', 'start_scan')
raise OspdCommandError('No scanner_params element', 'start_scan')

params = self._preprocess_scan_params(scanner_params)

Expand All @@ -621,7 +627,7 @@ def handle_start_scan_command(self, scan_et):
scanner_vts = scan_et.find('vt_selection')
if scanner_vts is not None:
if len(scanner_vts) == 0:
raise OSPDError('VTs list is empty', 'start_scan')
raise OspdCommandError('VTs list is empty', 'start_scan')
else:
vt_selection = self.process_vts_params(scanner_vts)

Expand Down Expand Up @@ -653,17 +659,21 @@ def handle_stop_scan_command(self, scan_et):

scan_id = scan_et.attrib.get('scan_id')
if scan_id is None or scan_id == '':
raise OSPDError('No scan_id attribute', 'stop_scan')
raise OspdCommandError('No scan_id attribute', 'stop_scan')
self.stop_scan(scan_id)

return simple_response_str('stop_scan', 200, 'OK')

def stop_scan(self, scan_id):
scan_process = self.scan_processes.get(scan_id)
if not scan_process:
raise OSPDError('Scan not found {0}.'.format(scan_id), 'stop_scan')
raise OspdCommandError(
'Scan not found {0}.'.format(scan_id), 'stop_scan'
)
if not scan_process.is_alive():
raise OSPDError('Scan already stopped or finished.', 'stop_scan')
raise OspdCommandError(
'Scan already stopped or finished.', 'stop_scan'
)

self.set_scan_status(scan_id, ScanStatus.STOPPED)
logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)
Expand Down Expand Up @@ -819,12 +829,12 @@ def handle_client_stream(self, stream, is_unix=False):
return
try:
response = self.handle_command(data)
except OSPDError as exception:
except OspdCommandError as exception:
response = exception.as_xml()
logger.debug('Command error: %s', exception.message)
except Exception: # pylint: disable=broad-except
logger.exception('While handling client command:')
exception = OSPDError('Fatal error', 'error')
exception = OspdCommandError('Fatal error', 'error')
response = exception.as_xml()
if is_unix:
send_method = stream.send
Expand Down Expand Up @@ -917,7 +927,7 @@ def start_scan(self, scan_id, targets, parallel=1):
logger.info("%s: Scan started.", scan_id)
target_list = targets
if target_list is None or not target_list:
raise OSPDError('Erroneous targets list', 'start_scan')
raise OspdCommandError('Erroneous targets list', 'start_scan')

self.process_exclude_hosts(scan_id, target_list)

Expand Down Expand Up @@ -1084,7 +1094,7 @@ def handle_help_command(self, scan_et):
elif help_format == "xml":
text = self.get_xml_str(self.commands)
return simple_response_str('help', 200, 'OK', text)
raise OSPDError('Bogus help format', 'help')
raise OspdCommandError('Bogus help format', 'help')

def get_help_text(self):
""" Returns the help output in plain text format."""
Expand Down Expand Up @@ -1143,7 +1153,7 @@ def handle_delete_scan_command(self, scan_et):
self.check_scan_process(scan_id)
if self.delete_scan(scan_id):
return simple_response_str('delete_scan', 200, 'OK')
raise OSPDError('Scan in progress', 'delete_scan')
raise OspdCommandError('Scan in progress', 'delete_scan')

def delete_scan(self, scan_id):
""" Deletes scan_id scan from collection.
Expand Down Expand Up @@ -1610,10 +1620,10 @@ def handle_command(self, command):
tree = secET.fromstring(command)
except secET.ParseError:
logger.debug("Erroneous client input: %s", command)
raise OSPDError('Invalid data')
raise OspdCommandError('Invalid data')

if not self.command_exists(tree.tag) and tree.tag != "authenticate":
raise OSPDError('Bogus command name')
raise OspdCommandError('Bogus command name')

if tree.tag == "get_version":
return self.handle_get_version_command()
Expand Down
10 changes: 5 additions & 5 deletions ospd/vtfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import re
import operator

from ospd.error import OSPDError
from ospd.errors import OspdCommandError


class VtsFilter(object):
Expand Down Expand Up @@ -57,14 +57,14 @@ def parse_filters(self, vt_filter):
for single_filter in filter_list:
filter_aux = re.split(r'(\W)', single_filter, 1)
if len(filter_aux) < 3:
raise OSPDError(
raise OspdCommandError(
"Invalid number of argument in the filter", "get_vts"
)
_element, _oper, _val = filter_aux
if _element not in self.allowed_filter:
raise OSPDError("Invalid filter element", "get_vts")
raise OspdCommandError("Invalid filter element", "get_vts")
if _oper not in self.filter_operator:
raise OSPDError("Invalid filter operator", "get_vts")
raise OspdCommandError("Invalid filter operator", "get_vts")

filters.append(filter_aux)

Expand Down Expand Up @@ -109,7 +109,7 @@ def get_filtered_vts_list(self, vts, vt_filter):
Dictionary with filtered vulnerability tests.
"""
if not vt_filter:
raise OSPDError('vt_filter: A valid filter is required.')
raise OspdCommandError('vt_filter: A valid filter is required.')

filters = self.parse_filters(vt_filter)
if not filters:
Expand Down
39 changes: 33 additions & 6 deletions tests/test_error.py → tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,59 @@
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.

""" Test module for OSPDError class
""" Test module for OspdCommandError class
"""

import unittest

from ospd.ospd import OSPDError
from ospd.errors import OspdError, OspdCommandError, RequiredArgument


class OSPDErrorTestCase(unittest.TestCase):
class OspdCommandErrorTestCase(unittest.TestCase):
def test_is_ospd_error(self):
e = OspdCommandError('message')
self.assertIsInstance(e, OspdError)

def test_default_params(self):
e = OSPDError('message')
e = OspdCommandError('message')

self.assertEqual('message', e.message)
self.assertEqual(400, e.status)
self.assertEqual('osp', e.command)

def test_constructor(self):
e = OSPDError('message', 'command', '304')
e = OspdCommandError('message', 'command', '304')

self.assertEqual('message', e.message)
self.assertEqual('command', e.command)
self.assertEqual('304', e.status)

def test_string_conversion(self):
e = OspdCommandError('message foo bar', 'command', '304')

self.assertEqual('message foo bar', str(e))

def test_as_xml(self):
e = OSPDError('message')
e = OspdCommandError('message')

self.assertEqual(
b'<osp_response status="400" status_text="message" />', e.as_xml()
)


class RequiredArgumentTestCase(unittest.TestCase):
def test_raise_exception(self):
with self.assertRaises(RequiredArgument) as cm:
raise RequiredArgument('foo', 'bar')

ex = cm.exception
self.assertEqual(ex.function, 'foo')
self.assertEqual(ex.argument, 'bar')

def test_string_conversion(self):
ex = RequiredArgument('foo', 'bar')
self.assertEqual(str(ex), 'foo: Argument bar is required')

def test_is_ospd_error(self):
e = RequiredArgument('foo', 'bar')
self.assertIsInstance(e, OspdError)
Loading

0 comments on commit 5bb3daf

Please sign in to comment.