Skip to content

Commit

Permalink
Merge pull request #1955 from blacklanternsecurity/fix-multiprocess-bug
Browse files Browse the repository at this point in the history
Fix Multiprocessing Shenanigans
  • Loading branch information
TheTechromancer authored Nov 18, 2024
2 parents 0ca76f3 + 53e9084 commit b0e5bc9
Show file tree
Hide file tree
Showing 12 changed files with 389 additions and 176 deletions.
9 changes: 6 additions & 3 deletions bbot/core/config/logger.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys
import atexit
import logging
Expand All @@ -9,6 +10,7 @@

from ..helpers.misc import mkdir, error_and_exit
from ...logger import colorize, loglevel_mapping
from ..multiprocess import SHARED_INTERPRETER_STATE


debug_format = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s %(filename)s:%(lineno)s %(message)s")
Expand Down Expand Up @@ -65,8 +67,9 @@ def __init__(self, core):

self.listener = None

self.process_name = multiprocessing.current_process().name
if self.process_name == "MainProcess":
# if we haven't set up logging yet, do it now
if not "_BBOT_LOGGING_SETUP" in os.environ:
os.environ["_BBOT_LOGGING_SETUP"] = "1"
self.queue = multiprocessing.Queue()
self.setup_queue_handler()
# Start the QueueListener
Expand Down Expand Up @@ -113,7 +116,7 @@ def setup_queue_handler(self, logging_queue=None, log_level=logging.DEBUG):

self.core_logger.setLevel(log_level)
# disable asyncio logging for child processes
if self.process_name != "MainProcess":
if not SHARED_INTERPRETER_STATE.is_main_process:
logging.getLogger("asyncio").setLevel(logging.ERROR)

def addLoggingLevel(self, levelName, levelNum, methodName=None):
Expand Down
23 changes: 20 additions & 3 deletions bbot/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from omegaconf import OmegaConf

from bbot.errors import BBOTError
from .multiprocess import SHARED_INTERPRETER_STATE


DEFAULT_CONFIG = None
Expand Down Expand Up @@ -41,9 +42,23 @@ def __init__(self):
self.logger
self.log = logging.getLogger("bbot.core")

self._prep_multiprocessing()

def _prep_multiprocessing(self):
import multiprocessing
from .helpers.process import BBOTProcess

if SHARED_INTERPRETER_STATE.is_main_process:
# if this is the main bbot process, set the logger and queue for the first time
from functools import partialmethod

self.process_name = multiprocessing.current_process().name
BBOTProcess.__init__ = partialmethod(
BBOTProcess.__init__, log_level=self.logger.log_level, log_queue=self.logger.queue
)

# this makes our process class the default for process pools, etc.
mp_context = multiprocessing.get_context("spawn")
mp_context.Process = BBOTProcess

@property
def home(self):
Expand Down Expand Up @@ -187,12 +202,14 @@ def create_process(self, *args, **kwargs):
if os.environ.get("BBOT_TESTING", "") == "True":
process = self.create_thread(*args, **kwargs)
else:
if self.process_name == "MainProcess":
if SHARED_INTERPRETER_STATE.is_scan_process:
from .helpers.process import BBOTProcess

process = BBOTProcess(*args, **kwargs)
else:
raise BBOTError(f"Tried to start server from process {self.process_name}")
import multiprocessing

raise BBOTError(f"Tried to start server from process {multiprocessing.current_process().name}")
process.daemon = True
return process

Expand Down
6 changes: 3 additions & 3 deletions bbot/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import contextlib
import contextvars
import zmq.asyncio
import multiprocessing
from pathlib import Path
from concurrent.futures import CancelledError
from contextlib import asynccontextmanager, suppress

from bbot.core import CORE
from bbot.errors import BBOTEngineError
from bbot.core.helpers.async_helpers import get_event_loop
from bbot.core.multiprocess import SHARED_INTERPRETER_STATE
from bbot.core.helpers.misc import rand_string, in_exception_chain


Expand Down Expand Up @@ -264,10 +266,8 @@ def available_commands(self):
return [s for s in self.CMDS if isinstance(s, str)]

def start_server(self):
import multiprocessing

process_name = multiprocessing.current_process().name
if process_name == "MainProcess":
if SHARED_INTERPRETER_STATE.is_scan_process:
kwargs = dict(self.server_kwargs)
# if we're in tests, we use a single event loop to avoid weird race conditions
# this allows us to more easily mock http, etc.
Expand Down
7 changes: 4 additions & 3 deletions bbot/core/helpers/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ async def _write_proc_line(proc, chunk):
return True
except Exception as e:
proc_args = [str(s) for s in getattr(proc, "args", [])]
command = " ".join(proc_args)
log.warning(f"Error writing line to stdin for command: {command}: {e}")
log.trace(traceback.format_exc())
command = " ".join(proc_args).strip()
if command:
log.warning(f"Error writing line to stdin for command: {command}: {e}")
log.trace(traceback.format_exc())
return False


Expand Down
18 changes: 0 additions & 18 deletions bbot/core/helpers/process.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import logging
import traceback
import threading
import multiprocessing
from multiprocessing.context import SpawnProcess

from .misc import in_exception_chain


current_process = multiprocessing.current_process()


class BBOTThread(threading.Thread):

default_name = "default bbot thread"
Expand Down Expand Up @@ -57,17 +53,3 @@ def run(self):
if not in_exception_chain(e, (KeyboardInterrupt,)):
log.warning(f"Error in {self.name}: {e}")
log.trace(traceback.format_exc())


if current_process.name == "MainProcess":
# if this is the main bbot process, set the logger and queue for the first time
from bbot.core import CORE
from functools import partialmethod

BBOTProcess.__init__ = partialmethod(
BBOTProcess.__init__, log_level=CORE.logger.log_level, log_queue=CORE.logger.queue
)

# this makes our process class the default for process pools, etc.
mp_context = multiprocessing.get_context("spawn")
mp_context.Process = BBOTProcess
58 changes: 58 additions & 0 deletions bbot/core/multiprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import atexit
from contextlib import suppress


class SharedInterpreterState:
"""
A class to track the primary BBOT process.
Used to prevent spawning multiple unwanted processes with multiprocessing.
"""

def __init__(self):
self.main_process_var_name = "_BBOT_MAIN_PID"
self.scan_process_var_name = "_BBOT_SCAN_PID"
atexit.register(self.cleanup)

@property
def is_main_process(self):
is_main_process = self.main_pid == os.getpid()
return is_main_process

@property
def is_scan_process(self):
is_scan_process = os.getpid() == self.scan_pid
return is_scan_process

@property
def main_pid(self):
main_pid = int(os.environ.get(self.main_process_var_name, 0))
if main_pid == 0:
main_pid = os.getpid()
# if main PID is not set, set it to the current PID
os.environ[self.main_process_var_name] = str(main_pid)
return main_pid

@property
def scan_pid(self):
scan_pid = int(os.environ.get(self.scan_process_var_name, 0))
if scan_pid == 0:
scan_pid = os.getpid()
# if scan PID is not set, set it to the current PID
os.environ[self.scan_process_var_name] = str(scan_pid)
return scan_pid

def update_scan_pid(self):
os.environ[self.scan_process_var_name] = str(os.getpid())

def cleanup(self):
with suppress(Exception):
if self.is_main_process:
with suppress(KeyError):
del os.environ[self.main_process_var_name]
with suppress(KeyError):
del os.environ[self.scan_process_var_name]


SHARED_INTERPRETER_STATE = SharedInterpreterState()
5 changes: 4 additions & 1 deletion bbot/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from collections import OrderedDict

from bbot import __version__

from bbot.core.event import make_event
from .manager import ScanIngress, ScanEgress
from bbot.core.helpers.misc import sha1, rand_string
from bbot.core.helpers.names_generator import random_name
from bbot.core.multiprocess import SHARED_INTERPRETER_STATE
from bbot.core.helpers.async_helpers import async_to_sync_gen
from bbot.errors import BBOTError, ScanError, ValidationError

Expand Down Expand Up @@ -259,6 +259,9 @@ async def _prep(self):
Creates the scan's output folder, loads its modules, and calls their .setup() methods.
"""

# update the master PID
SHARED_INTERPRETER_STATE.update_scan_pid()

self.helpers.mkdir(self.home)
if not self._prepped:
# save scan preset
Expand Down
1 change: 0 additions & 1 deletion bbot/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# silence stdout + trace
root_logger = logging.getLogger()
pytest_debug_file = Path(__file__).parent.parent.parent / "pytest_debug.log"
print(f"pytest_debug_file: {pytest_debug_file}")
debug_handler = logging.FileHandler(pytest_debug_file)
debug_handler.setLevel(logging.DEBUG)
debug_format = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s %(filename)s:%(lineno)s %(message)s")
Expand Down
17 changes: 17 additions & 0 deletions bbot/test/fastapi_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import List
from bbot import Scanner
from fastapi import FastAPI, Query

app = FastAPI()


@app.get("/start")
async def start(targets: List[str] = Query(...)):
scanner = Scanner(*targets, modules=["httpx"])
events = [e async for e in scanner.async_start()]
return [e.json() for e in events]


@app.get("/ping")
async def ping():
return {"status": "ok"}
82 changes: 82 additions & 0 deletions bbot/test/test_step_1/test_bbot_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import time
import httpx
import multiprocessing
from pathlib import Path
from subprocess import Popen
from contextlib import suppress

cwd = Path(__file__).parent.parent.parent


def run_bbot_multiprocess(queue):
from bbot import Scanner

scan = Scanner("http://127.0.0.1:8888", "blacklanternsecurity.com", modules=["httpx"])
events = [e.json() for e in scan.start()]
queue.put(events)


def test_bbot_multiprocess(bbot_httpserver):

bbot_httpserver.expect_request("/").respond_with_data("[email protected]")

queue = multiprocessing.Queue()
events_process = multiprocessing.Process(target=run_bbot_multiprocess, args=(queue,))
events_process.start()
events_process.join()
events = queue.get()
assert len(events) >= 3
scan_events = [e for e in events if e["type"] == "SCAN"]
assert len(scan_events) == 2
assert any([e["data"] == "[email protected]" for e in events])


def test_bbot_fastapi(bbot_httpserver):

bbot_httpserver.expect_request("/").respond_with_data("[email protected]")
fastapi_process = start_fastapi_server()

try:

# wait for the server to start with a timeout of 60 seconds
start_time = time.time()
while True:
try:
response = httpx.get("http://127.0.0.1:8978/ping")
response.raise_for_status()
break
except httpx.HTTPError:
if time.time() - start_time > 60:
raise TimeoutError("Server did not start within 60 seconds.")
time.sleep(0.1)
continue

# run a scan
response = httpx.get(
"http://127.0.0.1:8978/start",
params={"targets": ["http://127.0.0.1:8888", "blacklanternsecurity.com"]},
timeout=100,
)
events = response.json()
assert len(events) >= 3
scan_events = [e for e in events if e["type"] == "SCAN"]
assert len(scan_events) == 2
assert any([e["data"] == "[email protected]" for e in events])

finally:
with suppress(Exception):
fastapi_process.terminate()


def start_fastapi_server():
import os
import sys

env = os.environ.copy()
with suppress(KeyError):
del env["BBOT_TESTING"]
python_executable = str(sys.executable)
process = Popen(
[python_executable, "-m", "uvicorn", "bbot.test.fastapi_test:app", "--port", "8978"], cwd=cwd, env=env
)
return process
Loading

0 comments on commit b0e5bc9

Please sign in to comment.