Skip to content

Commit

Permalink
Fix: queued scan count (#850)
Browse files Browse the repository at this point in the history
Wait 0, 10, 20, 30, ..., or 90 miliseconds when add a new scan in the queue.
This try to avoid a race condition counting the amount of queued scans when many task
are started at the same time. This leads into wrong queue position
number shown in the logs and could jump the max_queued_scans setting
in some cases.
  • Loading branch information
jjnicola authored Feb 8, 2023
1 parent 52ef7a0 commit fd35308
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 230 deletions.
166 changes: 85 additions & 81 deletions ospd/command/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@


class BaseCommand(metaclass=InitSubclassMeta):

name = None
description = None
attributes = None
Expand Down Expand Up @@ -544,99 +543,104 @@ def handle_xml(self, xml: Element) -> bytes:
Return:
Response string for <start_scan> command.
"""
with self._daemon.scan_collection.scan_collection_lock:
current_queued_scans = self._daemon.get_count_queued_scans()
if (
self._daemon.max_queued_scans
and current_queued_scans >= self._daemon.max_queued_scans
):
logger.info(
'Maximum number of queued scans set to %d reached.',
self._daemon.max_queued_scans,
)
raise OspdCommandError(
'Maximum number of queued scans set to '
f'{str(self._daemon.max_queued_scans)} reached.',
'start_scan',
)

current_queued_scans = self._daemon.get_count_queued_scans()
if (
self._daemon.max_queued_scans
and current_queued_scans >= self._daemon.max_queued_scans
):
logger.info(
'Maximum number of queued scans set to %d reached.',
self._daemon.max_queued_scans,
)
raise OspdCommandError(
'Maximum number of queued scans set to '
f'{str(self._daemon.max_queued_scans)} reached.',
'start_scan',
)

target_str = xml.get('target')
ports_str = xml.get('ports')

# For backward compatibility, if target and ports attributes are set,
# <targets> element is ignored.
if target_str is None or ports_str is None:
target_element = xml.find('targets/target')
if target_element is None:
raise OspdCommandError('No targets or ports', 'start_scan')
target_str = xml.get('target')
ports_str = xml.get('ports')

# For backward compatibility, if target and ports attributes
# are set, <targets> element is ignored.
if target_str is None or ports_str is None:
target_element = xml.find('targets/target')
if target_element is None:
raise OspdCommandError('No targets or ports', 'start_scan')
else:
scan_target = OspRequest.process_target_element(
target_element
)
else:
scan_target = OspRequest.process_target_element(target_element)
else:
scan_target = {
'hosts': target_str,
'ports': ports_str,
'credentials': {},
'exclude_hosts': '',
'finished_hosts': '',
'options': {},
}
logger.warning(
"Legacy start scan command format is being used, which "
"is deprecated since 20.08. Please read the documentation "
"for start scan command."
)
scan_target = {
'hosts': target_str,
'ports': ports_str,
'credentials': {},
'exclude_hosts': '',
'finished_hosts': '',
'options': {},
}
logger.warning(
"Legacy start scan command format is being used, which "
"is deprecated since 20.08. Please read the documentation "
"for start scan command."
)

scan_id = xml.get('scan_id')
if scan_id is not None and scan_id != '' and not valid_uuid(scan_id):
raise OspdCommandError('Invalid scan_id UUID', 'start_scan')
scan_id = xml.get('scan_id')
if (
scan_id is not None
and scan_id != ''
and not valid_uuid(scan_id)
):
raise OspdCommandError('Invalid scan_id UUID', 'start_scan')

if xml.get('parallel'):
logger.warning(
"parallel attribute of start_scan will be ignored, sice "
"parallel scan is not supported by OSPd."
)

if xml.get('parallel'):
logger.warning(
"parallel attribute of start_scan will be ignored, sice "
"parallel scan is not supported by OSPd."
scanner_params = xml.find('scanner_params')
if scanner_params is None:
scanner_params = {}

# params are the parameters we got from the <scanner_params> XML.
params = self._daemon.preprocess_scan_params(scanner_params)

# VTS is an optional element. If present should not be empty.
vt_selection = {} # type: Dict
scanner_vts = xml.find('vt_selection')
if scanner_vts is not None:
if len(scanner_vts) == 0:
raise OspdCommandError('VTs list is empty', 'start_scan')
else:
vt_selection = OspRequest.process_vts_params(scanner_vts)

scan_params = self._daemon.process_scan_params(params)
scan_id_aux = scan_id
scan_id = self._daemon.create_scan(
scan_id, scan_target, scan_params, vt_selection
)

scanner_params = xml.find('scanner_params')
if scanner_params is None:
scanner_params = {}

# params are the parameters we got from the <scanner_params> XML.
params = self._daemon.preprocess_scan_params(scanner_params)
if not scan_id:
id_ = Element('id')
id_.text = scan_id_aux
return simple_response_str('start_scan', 100, 'Continue', id_)

# VTS is an optional element. If present should not be empty.
vt_selection = {} # type: Dict
scanner_vts = xml.find('vt_selection')
if scanner_vts is not None:
if len(scanner_vts) == 0:
raise OspdCommandError('VTs list is empty', 'start_scan')
else:
vt_selection = OspRequest.process_vts_params(scanner_vts)

scan_params = self._daemon.process_scan_params(params)
scan_id_aux = scan_id
scan_id = self._daemon.create_scan(
scan_id, scan_target, scan_params, vt_selection
)
logger.info(
'Scan %s added to the queue in position %d.',
scan_id,
self._daemon.get_count_queued_scans() + 1,
)

if not scan_id:
id_ = Element('id')
id_.text = scan_id_aux
return simple_response_str('start_scan', 100, 'Continue', id_)

logger.info(
'Scan %s added to the queue in position %d.',
scan_id,
current_queued_scans + 1,
)

id_ = Element('id')
id_.text = scan_id
id_.text = scan_id

return simple_response_str('start_scan', 200, 'OK', id_)
return simple_response_str('start_scan', 200, 'OK', id_)


class GetMemoryUsage(BaseCommand):

name = "get_memory_usage"
description = "print the memory consumption of all processes"
attributes = {
Expand Down
4 changes: 2 additions & 2 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,10 @@ def handle_client_stream(self, stream: Stream) -> None:
except (AttributeError, ValueError) as message:
logger.error(message)
return
except (ssl.SSLError) as exception:
except ssl.SSLError as exception:
logger.debug('Error: %s', exception)
break
except (socket.timeout) as exception:
except socket.timeout as exception:
logger.debug('Request timeout: %s', exception)
break

Expand Down
1 change: 0 additions & 1 deletion ospd/resultlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def add_result_to_list(
qod: str = '',
uri: str = '',
) -> None:

result = OrderedDict() # type: Dict
result['type'] = result_type
result['name'] = name
Expand Down
4 changes: 4 additions & 0 deletions ospd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,13 @@ def __init__(self, file_storage_dir: str) -> None:
) # type: Optional[multiprocessing.managers.SyncManager]
self.scans_table = dict() # type: Dict
self.file_storage_dir = file_storage_dir
self.scan_collection_lock = (
None
) # type: Optional[multiprocessing.managers.Lock]

def init(self):
self.data_manager = multiprocessing.Manager()
self.scan_collection_lock = self.data_manager.RLock()

def add_result(
self,
Expand Down
4 changes: 0 additions & 4 deletions ospd_openvas/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,6 @@ def __init__(
self._mqtt_broker_port = mqtt_broker_port

def init(self, server: BaseServer) -> None:

notus_handler = NotusResultHandler(self.report_results)

if self._mqtt_broker_address:
Expand Down Expand Up @@ -610,7 +609,6 @@ def get_feed_info(self) -> Dict[str, Any]:
feed_info = {}
with feed_info_file.open(encoding='utf-8') as fcontent:
for line in fcontent:

try:
key, value = line.split('=', 1)
except ValueError:
Expand Down Expand Up @@ -1000,7 +998,6 @@ def report_results(self, results: list, scan_id: str) -> bool:

@staticmethod
def is_openvas_process_alive(openvas_process: psutil.Popen) -> bool:

try:
if openvas_process.status() == psutil.STATUS_ZOMBIE:
logger.debug("Process is a Zombie, waiting for it to clean up")
Expand Down Expand Up @@ -1191,7 +1188,6 @@ def exec_scan(self, scan_id: str):

got_results = False
while True:

openvas_process_is_alive = self.is_openvas_process_alive(
openvas_process
)
Expand Down
1 change: 0 additions & 1 deletion ospd_openvas/notus.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class Cache:
def __init__(
self, main_db: MainDB, prefix: str = "internal/notus/advisories"
):

self._main_db = main_db
# Check if it was previously uploaded
self.ctx, _ = OpenvasDB.find_database_by_pattern(
Expand Down
1 change: 0 additions & 1 deletion ospd_openvas/nvticache.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


class NVTICache(BaseDB):

QOD_TYPES = {
'exploit': '100',
'remote_vul': '99',
Expand Down
Loading

0 comments on commit fd35308

Please sign in to comment.