diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 00000000..5d8828f0 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,24 @@ +name: MyPy maltoolbox +on: [push] + +jobs: + + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.x' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[dev] .[ml] + - name: Type checking with MyPy + run: | + pip install mypy + mypy malsim tests --install-types --non-interactive diff --git a/malsim/__main__.py b/malsim/__main__.py index 30eb7978..aba2df1a 100644 --- a/malsim/__main__.py +++ b/malsim/__main__.py @@ -3,6 +3,7 @@ from __future__ import annotations import argparse import logging +from typing import Any, Optional from .mal_simulator import MalSimulator from .agents import DecisionAgent @@ -12,11 +13,11 @@ logger = logging.getLogger(__name__) logging.getLogger().setLevel(logging.INFO) -def run_simulation(sim: MalSimulator, agents: list[dict]): +def run_simulation(sim: MalSimulator, agents: list[dict[str, Any]]) -> None: """Run a simulation with agents""" sim.reset() - total_rewards = {agent_dict['name']: 0 for agent_dict in agents} + total_rewards = {agent_dict['name']: 0.0 for agent_dict in agents} all_agents_term_or_trunc = False logger.info("Starting CLI env simulator.") @@ -29,7 +30,7 @@ def run_simulation(sim: MalSimulator, agents: list[dict]): # Select actions for each agent for agent_dict in agents: - decision_agent: DecisionAgent = agent_dict.get('agent') + decision_agent: Optional[DecisionAgent] = agent_dict.get('agent') agent_name = agent_dict['name'] if decision_agent is None: logger.warning( @@ -66,7 +67,7 @@ def run_simulation(sim: MalSimulator, agents: list[dict]): agent_name = agent_dict['name'] print(f'Total reward "{agent_name}"', total_rewards[agent_name]) -def main(): +def main() -> None: """Entrypoint function of the MAL Toolbox CLI""" parser = argparse.ArgumentParser() parser.add_argument( diff --git a/malsim/agents/decision_agent.py b/malsim/agents/decision_agent.py index 34e2d92a..7687cd44 100644 --- a/malsim/agents/decision_agent.py +++ b/malsim/agents/decision_agent.py @@ -1,7 +1,7 @@ """A decision agent is a heuristic agent""" from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Any from abc import ABC, abstractmethod if TYPE_CHECKING: @@ -14,7 +14,7 @@ class DecisionAgent(ABC): def get_next_action( self, agent_state: MalSimAgentStateView, - **kwargs + **kwargs: Any ) -> Optional[AttackGraphNode]: """ Select next action the agent will work with. diff --git a/malsim/agents/heuristic_agent.py b/malsim/agents/heuristic_agent.py index e1377207..0fbc7285 100644 --- a/malsim/agents/heuristic_agent.py +++ b/malsim/agents/heuristic_agent.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import logging import math @@ -16,7 +16,7 @@ class DefendCompromisedDefender(DecisionAgent): """A defender that defends compromised assets using notPresent""" - def __init__(self, agent_config, **_): + def __init__(self, agent_config: dict[str, Any], **_: Any): # Seed and rng not currently used seed = ( agent_config["seed"] @@ -30,7 +30,7 @@ def __init__(self, agent_config, **_): ) def get_next_action( - self, agent_state: MalSimAgentStateView, **kwargs + self, agent_state: MalSimAgentStateView, **kwargs: Any ) -> Optional[AttackGraphNode]: """Return an action that disables a compromised node""" @@ -71,7 +71,7 @@ def get_next_action( class DefendFutureCompromisedDefender(DecisionAgent): """A defender that defends compromised assets using notPresent""" - def __init__(self, agent_config, **_): + def __init__(self, agent_config: dict[str, Any], **_: Any): # Seed and rng not currently used seed = ( agent_config["seed"] @@ -85,7 +85,7 @@ def __init__(self, agent_config, **_): ) def get_next_action( - self, agent_state: MalSimAgentStateView, **kwargs + self, agent_state: MalSimAgentStateView, **kwargs: Any ) -> Optional[AttackGraphNode]: """Return an action that disables a compromised node""" diff --git a/malsim/agents/keyboard_input.py b/malsim/agents/keyboard_input.py index 1d738e27..ed3209ea 100644 --- a/malsim/agents/keyboard_input.py +++ b/malsim/agents/keyboard_input.py @@ -1,6 +1,6 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Any from .decision_agent import DecisionAgent from ..mal_simulator import MalSimAgentStateView @@ -13,14 +13,14 @@ class KeyboardAgent(DecisionAgent): """An agent that makes decisions by asking user for keyboard input""" - def __init__(self, _, **kwargs): + def __init__(self, _: Any, **kwargs: Any): super().__init__(**kwargs) logger.info("Creating KeyboardAgent") def get_next_action( self, agent_state: MalSimAgentStateView, - **kwargs + **kwargs: Any ) -> Optional[AttackGraphNode]: """Compute action from action_surface""" @@ -35,13 +35,13 @@ def valid_action(user_input: str) -> bool: return 0 <= node <= len(agent_state.action_surface) - def get_action_object(user_input: str) -> tuple: + def get_action_object(user_input: str) -> Optional[int]: node = int(user_input) if user_input != "" else None return node if not agent_state.action_surface: print("No actions to pick for defender") - return [] + return None index_to_node = dict(enumerate(agent_state.action_surface)) user_input = "xxx" diff --git a/malsim/agents/passive_agent.py b/malsim/agents/passive_agent.py index f75ddde5..23de2724 100644 --- a/malsim/agents/passive_agent.py +++ b/malsim/agents/passive_agent.py @@ -1,7 +1,7 @@ """A passive agent that always choose to do nothing""" from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Any from .decision_agent import DecisionAgent from ..mal_simulator import MalSimAgentStateView @@ -11,13 +11,13 @@ from maltoolbox.attackgraph import AttackGraphNode class PassiveAgent(DecisionAgent): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): ... def get_next_action( self, agent_state: MalSimAgentStateView, - **kwargs + **kwargs: Any ) -> Optional[AttackGraphNode]: # A passive agent never does anything return None diff --git a/malsim/agents/searchers.py b/malsim/agents/searchers.py index c7f1bc4e..02712f9d 100644 --- a/malsim/agents/searchers.py +++ b/malsim/agents/searchers.py @@ -33,7 +33,7 @@ class BreadthFirstAttacker(DecisionAgent): 'seed': None, } - def __init__(self, agent_config: dict) -> None: + def __init__(self, agent_config: dict[str, Any]) -> None: """Initialize a BFS/DFS agent. Args: @@ -47,7 +47,7 @@ def __init__(self, agent_config: dict) -> None: self._started = False def get_next_action( - self, agent_state: MalSimAgentStateView, **kwargs + self, agent_state: MalSimAgentStateView, **kwargs: Any ) -> Optional[AttackGraphNode]: """Receive the next action according to agent policy (bfs/dfs)""" @@ -69,7 +69,7 @@ def _update_targets( self, new_nodes: set[AttackGraphNode], disabled_nodes: set[AttackGraphNode], - ): + ) -> None: new_targets: list[AttackGraphNode] = [] if self._settings['seed']: # If a seed is set, we assume the user wants determinism in the diff --git a/malsim/envs/__init__.py b/malsim/envs/__init__.py index 697e7dfe..a5096849 100644 --- a/malsim/envs/__init__.py +++ b/malsim/envs/__init__.py @@ -1,5 +1,5 @@ from .malsim_vectorized_obs_env import MalSimVectorizedObsEnv -from .gym_envs import AttackerEnv, DefenderEnv, register_envs +from .gym_envs import AttackerEnv, DefenderEnv, register_envs # type: ignore # not needed, used to silence ruff F401 __all__ = [ diff --git a/malsim/envs/base_classes.py b/malsim/envs/base_classes.py index 07970b65..8c655e6f 100644 --- a/malsim/envs/base_classes.py +++ b/malsim/envs/base_classes.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any, Optional from ..mal_simulator import MalSimulator, MalSimAgentStateView class MalSimEnv(ABC): @@ -7,24 +8,28 @@ def __init__(self, sim: MalSimulator): self.sim = sim @abstractmethod - def step(self, actions): + def step(self, actions: Any) -> Any: ... - def reset(self, seed=None, options=None): + def reset( + self, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None + ) -> None: self.sim.reset(seed=seed, options=options) def register_attacker( self, attacker_name: str, attacker_id: int - ): + ) -> None: self.sim.register_attacker(attacker_name, attacker_id) def register_defender( self, defender_name: str - ): + ) -> None: self.sim.register_defender(defender_name) def get_agent_state(self, agent_name: str) -> MalSimAgentStateView: return self.sim.agent_states[agent_name] - def render(self): + def render(self) -> None: pass diff --git a/malsim/envs/gym_envs.py b/malsim/envs/gym_envs.py index 81d58405..38679d54 100644 --- a/malsim/envs/gym_envs.py +++ b/malsim/envs/gym_envs.py @@ -1,4 +1,9 @@ -from typing import Any, Dict, SupportsFloat +# type: ignore +# Ignoring type checking in this file for now +# before someone with more Gymnasium knowledge +# can jump into the code base + +from typing import Any, Dict, SupportsFloat, Optional import gymnasium as gym import gymnasium.utils.env_checker as env_checker @@ -13,10 +18,10 @@ from ..agents import DecisionAgent -class AttackerEnv(gym.Env): +class AttackerEnv(gym.Env[Any, Any]): metadata = {'render_modes': []} - def __init__(self, scenario_file: str, **kwargs) -> None: + def __init__(self, scenario_file: str, **kwargs: Any) -> None: """ Params: - scenario_file: the scenario that should be loaded @@ -79,21 +84,21 @@ def step( infos[self.attacker_agent_name] ) - def render(self): - return self.sim.render() + def render(self) -> None: + self.sim.render() @property - def num_assets(self): + def num_assets(self) -> int: return len(self.sim._index_to_asset_type) @property - def num_step_names(self): + def num_step_names(self) -> int: return len(self.sim._index_to_step_name) -class DefenderEnv(gym.Env): +class DefenderEnv(gym.Env[Any, Any]): metadata = {'render_modes': []} - def __init__(self, scenario_file, **kwargs) -> None: + def __init__(self, scenario_file: str, **kwargs: Any) -> None: self.randomize = kwargs.pop('randomize_attacker_behavior', False) self.render_mode = kwargs.pop('render_mode', None) @@ -104,7 +109,7 @@ def __init__(self, scenario_file, **kwargs) -> None: # Register attacker agents from scenario self._register_attacker_agents(self.scenario_agents) - self.attacker_decision_agents = {} + self.attacker_decision_agents: dict[str, DecisionAgent] = {} # Register defender agent self.defender_agent_name = "DefenderEnvAgent" @@ -116,7 +121,7 @@ def __init__(self, scenario_file, **kwargs) -> None: self.action_space = \ self.sim.action_space(self.defender_agent_name) - def _register_attacker_agents(self, agents: list[dict]): + def _register_attacker_agents(self, agents: list[dict[str, Any]]) -> None: """Register attackers in simulator""" for agent_info in agents: if agent_info['type'] == AgentType.ATTACKER: @@ -125,7 +130,7 @@ def _register_attacker_agents(self, agents: list[dict]): agent_info['attacker_id']) def _create_attacker_decision_agents( - self, agents: list[dict], seed=None + self, agents: list[dict[str, Any]], seed: Optional[int] = None ) -> dict[str, DecisionAgent]: """Create decision agents for each attacker""" @@ -187,8 +192,8 @@ def step( infos[self.defender_agent_name], ) - def render(self): - return self.sim.render() + def render(self) -> None: + self.sim.render() @staticmethod def add_reverse_edges(edges: np.ndarray, defense_steps: set) -> np.ndarray: @@ -202,25 +207,25 @@ def add_reverse_edges(edges: np.ndarray, defense_steps: set) -> np.ndarray: return edges @property - def num_assets(self): + def num_assets(self) -> int: return len(self.sim._index_to_asset_type) @property - def num_step_names(self): + def num_step_names(self) -> int: return len(self.sim._index_to_step_name) -def _to_binary(val, max_val): +def _to_binary(val: int, max_val: int) -> np.typing.NDArray[np.int64]: return np.array( list(np.binary_repr(val, width=max_val.bit_length())), dtype=np.int64 ) -def vec_to_binary(vec, max_val): +def vec_to_binary(vec: list[int], max_val: int) -> np.typing.NDArray[np.int64]: return np.array([_to_binary(val, max_val) for val in vec]) -def vec_to_one_hot(vec, num_vals): +def vec_to_one_hot(vec: list[int], num_vals: int) -> np.typing.NDArray[np.int8]: return np.eye(num_vals, dtype=np.int8)[vec] @@ -302,7 +307,11 @@ def _to_graph(self, obs: dict[str, Any], info: dict[str, Any]) -> dict[str, Any] return to_graph(obs, info, self.num_steps) -def to_graph(obs: dict[str, Any], info: dict[str, Any], num_steps) -> dict[str, Any]: +def to_graph( + obs: dict[str, Any], + info: dict[str, Any], + num_steps: int + ) -> dict[str, Any]: nodes = np.concatenate( [ vec_to_one_hot(obs['observed_state'] + 1, 3), @@ -318,7 +327,7 @@ def to_graph(obs: dict[str, Any], info: dict[str, Any], num_steps) -> dict[str, } -def register_envs(): +def register_envs() -> None: gym.register('MALDefenderEnv-v0', entry_point=DefenderEnv) gym.register('MALAttackerEnv-v0', entry_point=AttackerEnv) diff --git a/malsim/envs/malsim_vectorized_obs_env.py b/malsim/envs/malsim_vectorized_obs_env.py index 58b08e90..ed72f4c4 100644 --- a/malsim/envs/malsim_vectorized_obs_env.py +++ b/malsim/envs/malsim_vectorized_obs_env.py @@ -7,6 +7,7 @@ from __future__ import annotations +from typing import Any, Optional import functools import logging import sys @@ -24,302 +25,12 @@ MalSimDefenderState ) -from .base_classes import MalSimEnv ITERATIONS_LIMIT = int(1e9) logger = logging.getLogger(__name__) -# First the logging methods: -def format_full_observation(sim, observation): - """ - Return a formatted string of the entire observation. This includes - sections that will not change over time, these define the structure of - the attack graph. - """ - obs_str = '\nAttack Graph Steps\n' - - str_format = "{:<5} {:<80} {:<6} {:<5} {:<5} {:<30} {:<8} {:<}\n" - header_entry = [ - "Entry", "Name", "Is_Obs", "State", - "RTTC", "Asset Type(Index)", "Asset Id", "Step" - ] - entries = [] - for entry in range(0, len(observation["observed_state"])): - asset_type_index = observation["asset_type"][entry] - asset_type_str = sim._index_to_asset_type[asset_type_index ] + \ - '(' + str(asset_type_index) + ')' - entries.append( - [ - entry, - sim._index_to_full_name[entry], - observation["is_observable"][entry], - observation["observed_state"][entry], - observation["remaining_ttc"][entry], - asset_type_str, - observation["asset_id"][entry], - observation["step_name"][entry], - ] - ) - obs_str += format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - obs_str += "\nAttack Graph Edges:\n" - for edge in observation["attack_graph_edges"]: - obs_str += str(edge) + "\n" - - obs_str += "\nInstance Model Assets:\n" - str_format = "{:<5} {:<5} {:<}\n" - header_entry = [ - "Entry", "Id", "Type(Index)"] - entries = [] - for entry in range(0, len(observation["model_asset_id"])): - asset_type_str = sim._index_to_asset_type[ - observation["model_asset_type"][entry]] + \ - '(' + str(observation["model_asset_type"][entry]) + ')' - entries.append( - [ - entry, - observation["model_asset_id"][entry], - asset_type_str - ] - ) - obs_str += format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - obs_str += "\nInstance Model Edges:\n" - str_format = "{:<5} {:<40} {:<40} {:<}\n" - header_entry = [ - "Entry", - "Left Asset(Id/Index)", - "Right Asset(Id/Index)", - "Type(Index)" - ] - entries = [] - for entry in range(0, len(observation["model_edges_ids"])): - assoc_type_str = sim._index_to_model_assoc_type[ - observation["model_edges_type"][entry]] + \ - '(' + str(observation["model_edges_type"][entry]) + ')' - left_asset_index = int(observation["model_edges_ids"][entry][0]) - right_asset_index = int(observation["model_edges_ids"][entry][1]) - left_asset_id = sim._index_to_model_asset_id[left_asset_index] - right_asset_id = sim._index_to_model_asset_id[right_asset_index] - left_asset_str = \ - sim.model.get_asset_by_id(left_asset_id).name + \ - '(' + str(left_asset_id) + '/' + str(left_asset_index) + ')' - right_asset_str = \ - sim.model.get_asset_by_id(right_asset_id).name + \ - '(' + str(right_asset_id) + '/' + str(right_asset_index) + ')' - entries.append( - [ - entry, - left_asset_str, - right_asset_str, - assoc_type_str - ] - ) - obs_str += format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - return obs_str - -def format_obs_var_sec( - sim, - observation, - included_values = [-1, 0, 1] - ): - """ - Return a formatted string of the sections of the observation that can - vary over time. - - Arguments: - observation - the observation to format - included_values - the values to list, any values not present in the - list will be filtered out - """ - - str_format = "{:>5} {:>80} {:<5} {:<5} {:<}\n" - header_entry = ["Id", "Name", "State", "RTTC", "Entry"] - entries = [] - for entry in range(0, len(observation["observed_state"])): - if observation["is_observable"][entry] and \ - observation["observed_state"][entry] in included_values: - entries.append( - [ - sim._index_to_id[entry], - sim._index_to_full_name[entry], - observation["observed_state"][entry], - observation["remaining_ttc"][entry], - entry - ] - ) - - obs_str = format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - return obs_str - -def format_info(sim, info): - can_act = "Yes" if info["action_mask"][0][1] > 0 else "No" - agent_info_str = f"Can act? {can_act}\n" - for entry in range(0, len(info["action_mask"][1])): - if info["action_mask"][1][entry] == 1: - agent_info_str += f"{sim._index_to_id[entry]} " \ - f"{sim._index_to_full_name[entry]}\n" - return agent_info_str - - -def log_mapping_tables(logger, sim): - """Log all mapping tables in MalSimulator""" - - str_format = "{:<5} {:<15} {:<}\n" - table = "\n" - header_entry = ["Index", "Attack Step Id", "Attack Step Full Name"] - entries = [] - for entry in sim._index_to_id: - entries.append( - [ - sim._id_to_index[entry], - entry, - sim._index_to_full_name[sim._id_to_index[entry]] - ] - ) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Asset Id"] - entries = [] - for entry in sim._model_asset_id_to_index: - entries.append( - [ - sim._model_asset_id_to_index[entry], - entry - ] - ) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Asset Type"] - entries = [] - for entry in sim._asset_type_to_index: - entries.append( - [ - sim._asset_type_to_index[entry], - entry - ] - ) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Attack Step Name"] - entries = [] - for entry in sim._index_to_step_name: - entries.append([sim._step_name_to_index[entry], entry]) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Association Type"] - entries = [] - for entry in sim._index_to_model_assoc_type: - entries.append([sim._model_assoc_type_to_index[entry], entry]) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - -def format_table( - entry_format: str, - header_entry: list[str], - entries: list[list[str]], - reprint_header: int = 0 - ) -> str: - """ - Format a table according to the parameters specified. - - Arguments: - entry_format - The string format for the table - reprint_header - How many rows apart to reprint the header. If 0 the - header will not be reprinted. - header_entry - The entry representing the header of the table - entries - The list of entries to format - - Return: - The formatted table. - """ - - formatted_str = '' - header = entry_format.format(*header_entry) - formatted_str += header - for entry_nr, entry in zip(range(0, len(entries)), entries): - formatted_str += entry_format.format(*entry) - if (reprint_header != 0) and ((entry_nr + 1) % reprint_header == 0): - formatted_str += header - return formatted_str - - -def log_agent_state( - logger, sim, agent, terminations, truncations, infos - ): - """Debug log all an agents current state""" - - agent_obs_str = format_obs_var_sec( - sim, agent.observation, included_values = [0, 1] - ) - - logger.debug( - 'Observation for agent "%s":\n%s', agent.name, agent_obs_str) - logger.debug( - 'Rewards for agent "%s": %d', agent.name, agent.reward) - logger.debug( - 'Termination for agent "%s": %s', - agent.name, terminations[agent.name]) - logger.debug( - 'Truncation for agent "%s": %s', - agent.name, str(truncations[agent.name])) - agent_info_str = format_info(sim, infos[agent.name]) - logger.debug( - 'Info for agent "%s":\n%s', agent.name, agent_info_str) - - -# Now the actual class: - -class MalSimVectorizedObsEnv(ParallelEnv, MalSimEnv): +class MalSimVectorizedObsEnv(ParallelEnv): # type: ignore """ Environment that runs simulation between agents. Builds serialized observations. @@ -331,11 +42,13 @@ def __init__( sim: MalSimulator ): - super().__init__(sim) + self.sim = sim # Useful instead of having to fetch .sim.attack_graph self.attack_graph = sim.attack_graph - + assert self.attack_graph.model, ( + "Attack graph in simulator needs to have a model attached to it" + ) # List mapping from node/asset index to id/name/type self._index_to_id = [n.id for n in self.attack_graph.nodes.values()] self._index_to_full_name = ( @@ -379,25 +92,26 @@ def __init__( enumerate(self._index_to_model_assoc_type) } - if logger.isEnabledFor(logging.DEBUG): - log_mapping_tables(logger, self) - self._blank_observation = self._create_blank_observation() - - self._agent_observations = {} - self._agent_infos = {} + self._agent_observations: dict[str, Any] = {} + self._agent_infos: dict[str, Any] = {} @property - def agents(self): + def agents(self) -> list[str]: """Required by ParallelEnv""" return list(self.sim._alive_agents) @property - def possible_agents(self): + def possible_agents(self) -> list[str]: """Required by ParallelEnv""" return list(self.sim._agent_states.keys()) - def _create_blank_observation(self, default_obs_state=-1): + def get_agent_state(self, agent_name: str) -> MalSimAgentStateView: + return self.sim.agent_states[agent_name] + + def _create_blank_observation( + self, default_obs_state: int = -1 + ) -> dict[str, Any]: """Create the initial observation""" # For now, an `object` is an attack step num_steps = len(self.sim.attack_graph.nodes) @@ -415,7 +129,8 @@ def _create_blank_observation(self, default_obs_state=-1): self._asset_type_to_index[step.lg_attack_step.asset.name] for step in self.attack_graph.nodes.values()], "asset_id": [step.model_asset.id - for step in self.attack_graph.nodes.values()], + for step in self.attack_graph.nodes.values() + if step.model_asset], "step_name": [ self._step_name_to_index.get( str(step.lg_attack_step.asset.name + ":" + step.name) @@ -460,6 +175,7 @@ def _create_blank_observation(self, default_obs_state=-1): observation["model_edges_ids"] = [] observation["model_edges_type"] = [] + assert self.attack_graph.model, "Graph needs model attached to it" for asset in self.attack_graph.model.assets.values(): observation["model_asset_id"].append(asset.id) observation["model_asset_type"].append( @@ -502,15 +218,9 @@ def _create_blank_observation(self, default_obs_state=-1): "model_edges_type": np.array(observation["model_edges_type"], dtype=np.int64) } - - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - format_full_observation(self, np_obs) - ) - return np_obs - def create_action_mask(self, agent: MalSimAgentStateView): + def create_action_mask(self, agent: MalSimAgentStateView) -> dict[str, Any]: """ Create an action mask for an agent based on its action_surface. @@ -548,69 +258,23 @@ def create_action_mask(self, agent: MalSimAgentStateView): ) } - def _update_agent_infos(self): + def _update_agent_infos(self) -> None: for agent in self.sim.agent_states.values(): self._agent_infos[agent.name] = self.create_action_mask(agent) - def _get_association_full_name(self, association) -> str: - """Get association full name - - TODO: Remove this method once the language graph integration is - complete in the mal-toolbox because the language graph associations - will use their full names for the name property - - Arguments: - association - the association whose full name will be returned - - Return: - A string containing the association name and the name of each of the - two asset types for the left and right fields separated by - underscores. - """ - - assoc_name = association.__class__.__name__ - if '_' in assoc_name: - # TODO: Not actually a to-do, but just an extra clarification that - # this is an ugly hack that will work for now until we get the - # unique association names. Right now some associations already - # use the asset types as part of their name if there are multiple - # associations with the same name. - return assoc_name - - left_field_name, right_field_name = \ - self.sim.attack_graph.model.get_association_field_names(association) - left_field = getattr(association, left_field_name) - right_field = getattr(association, right_field_name) - lang_assoc = self.sim.attack_graph.lang_graph.get_association_by_fields_and_assets( - left_field_name, - right_field_name, - left_field[0].type, - right_field[0].type - ) - if lang_assoc is None: - raise LookupError('Failed to find association for fields ' - '"%s" "%s" and asset types "%s" "%s"!' % ( - left_field_name, - right_field_name, - left_field[0].type, - right_field[0].type - ) - ) - assoc_full_name = lang_assoc.name + '_' + \ - lang_assoc.left_field.asset.name + '_' + \ - lang_assoc.right_field.asset.name - return assoc_full_name - @functools.lru_cache(maxsize=None) - def action_space(self, agent=None): + def action_space(self, agent: Optional[str] = None) -> MultiDiscrete: num_actions = 2 # two actions: wait or use # For now, an `object` is an attack step num_steps = len(self.sim.attack_graph.nodes) return MultiDiscrete([num_actions, num_steps], dtype=np.int64) @functools.lru_cache(maxsize=None) - def observation_space(self, agent_name: str = None): + def observation_space(self, agent_name: Optional[str] = None) -> Dict: # For now, an `object` is an attack step + assert self.attack_graph.model, ( + "Attack graph in simulator needs to have a model attached to it" + ) num_assets = len(self.attack_graph.model.assets) num_steps = len(self.attack_graph.nodes) num_lang_asset_types = len(self.sim.attack_graph.lang_graph.assets) @@ -752,17 +416,17 @@ def serialized_action_to_node( nodes = [self.index_to_node(step_idx)] return nodes - def register_attacker(self, attacker_name: str, attacker_id: int): - super().register_attacker(attacker_name, attacker_id) + def register_attacker(self, attacker_name: str, attacker_id: int) -> None: + self.sim.register_attacker(attacker_name, attacker_id) agent = self.sim.agent_states[attacker_name] self._init_agent(agent) - def register_defender(self, defender_name: str): - super().register_defender(defender_name) + def register_defender(self, defender_name: str) -> None: + self.sim.register_defender(defender_name) agent = self.sim.agent_states[defender_name] self._init_agent(agent) - def _init_agent(self, agent: MalSimAgentStateView): + def _init_agent(self, agent: MalSimAgentStateView) -> None: # Fill dicts with env specific agent obs/infos self._agent_observations[agent.name] = \ self._create_blank_observation() @@ -772,15 +436,15 @@ def _init_agent(self, agent: MalSimAgentStateView): def _update_attacker_obs( self, - compromised_nodes, - disabled_nodes, + compromised_nodes: set[AttackGraphNode], + disabled_nodes: set[AttackGraphNode], attacker_agent: MalSimAttackerState - ): + ) -> None: """Update the observation of the serialized obs attacker""" def _enable_node( - node: AttackGraphNode, agent_observation: dict - ): + node: AttackGraphNode, agent_observation: dict[str, Any] + ) -> None: """Set enabled node obs state to enabled and its children to disabled""" @@ -815,10 +479,10 @@ def _enable_node( def _update_defender_obs( self, - compromised_nodes: list[AttackGraphNode], - disabled_nodes: list[AttackGraphNode], + compromised_nodes: set[AttackGraphNode], + disabled_nodes: set[AttackGraphNode], defender_agent: MalSimDefenderState - ): + ) -> None: """Update the observation of the defender""" defender_observation = self._agent_observations[defender_agent.name] @@ -837,15 +501,17 @@ def _update_defender_obs( def reset( self, - seed: int | None = None, - options: dict | None = None - ) -> tuple[dict, dict]: + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None + ) -> tuple[dict[str, Any], dict[str, Any]]: """Reset simulator and return current observation and infos for each agent""" - MalSimEnv.reset(self, seed, options) - + self.sim.reset(seed=seed, options=options) self.attack_graph = self.sim.attack_graph # new ref + assert self.attack_graph.model, ( + "Attack graph in simulator needs to have a model attached to it" + ) for agent in self.sim.agent_states.values(): # Reset observation and action mask for agents @@ -855,30 +521,34 @@ def reset( self.create_action_mask(agent) # Enable pre-enabled nodes in observation - attacker_entry_points = [ + attacker_entry_points = set( n for n in self.sim.attack_graph.nodes.values() if n.is_compromised() - ] - pre_enabled_defenses = [ + ) + pre_enabled_defenses = set( n for n in self.sim.attack_graph.nodes.values() if n.defense_status == 1.0 - ] + ) for node in attacker_entry_points: node.extras['entrypoint'] = True self._update_observations( - attacker_entry_points + pre_enabled_defenses, [] + attacker_entry_points | pre_enabled_defenses, set() ) # TODO: should we return copies instead so they are not modified externally? return self._agent_observations, self._agent_infos - def _update_observations(self, compromised_nodes, disabled_nodes): + def _update_observations( + self, + compromised_nodes: set[AttackGraphNode], + disabled_nodes: set[AttackGraphNode] + ) -> None: """Update observations of all agents""" if not self.sim.sim_settings.uncompromise_untraversable_steps: - disabled_nodes = [] + disabled_nodes = set() # TODO: Is this correct? All attackers get the same compromised_nodes? logger.debug("Enable:\n\t%s", [n.full_name for n in compromised_nodes]) @@ -894,13 +564,22 @@ def _update_observations(self, compromised_nodes, disabled_nodes): compromised_nodes, disabled_nodes, agent ) - def step(self, actions: dict[str, tuple[int, int]]): + def step( + self, actions: dict[str, tuple[int, Optional[int]]] + ) -> tuple[ + dict[str, dict[str, Any]], + dict[str, float], + dict[str, bool], + dict[str, bool], + dict[str, dict[str, Any]] + ]: """Perform step with mal simulator and observe in parallel env""" - malsim_actions = {} + malsim_actions: dict[str, list[AttackGraphNode]] = {} for agent_name, agent_action in actions.items(): malsim_actions[agent_name] = [] - if agent_action[0]: + + if agent_action[0] and agent_action[1] is not None: # If agent wants to act, convert index to node malsim_actions[agent_name].append( self.index_to_node(agent_action[1]) @@ -908,11 +587,10 @@ def step(self, actions: dict[str, tuple[int, int]]): states = self.sim.step(malsim_actions) - all_actioned = [ - n - for state in states.values() + all_actioned = set( + n for state in states.values() for n in state.step_performed_nodes - ] + ) disabled_nodes = next(iter(states.values())).step_unviable_nodes self._update_agent_infos() # Update action masks diff --git a/malsim/mal_simulator.py b/malsim/mal_simulator.py index c331fc67..b359dec0 100644 --- a/malsim/mal_simulator.py +++ b/malsim/mal_simulator.py @@ -9,8 +9,12 @@ from maltoolbox import neo4j_configs from maltoolbox.ingestors import neo4j -from maltoolbox.attackgraph import (AttackGraph, AttackGraphNode, - Attacker, query) +from maltoolbox.attackgraph import ( + AttackGraph, + AttackGraphNode, + Attacker, + query +) from maltoolbox.attackgraph.analyzers import apriori ITERATIONS_LIMIT = int(1e9) @@ -33,7 +37,7 @@ class MalSimAgentState: # Contains current agent reward in the simulation # Attackers get positive rewards, defenders negative - reward: int = 0 + reward: float = 0 # Contains possible actions for the agent in the next step action_surface: set[AttackGraphNode] = field(default_factory=set) @@ -90,21 +94,21 @@ class MalSimAgentStateView(MalSimAttackerState, MalSimDefenderState): _frozen = False - def __init__(self, agent): + def __init__(self, agent: MalSimAgentState): self._agent = agent self._frozen = True - def __setattr__(self, key, value) -> None: + def __setattr__(self, key: str, value: Any) -> None: if self._frozen: raise AttributeError("Cannot modify agent state view") super().__setattr__(key, value) - def __delattr__(self, key) -> None: + def __delattr__(self, key: str) -> None: if self._frozen: raise AttributeError("Cannot modify agent state view") super().__delattr__(key) - def __getattribute__(self, attr) -> Any: + def __getattribute__(self, attr: str) -> Any: """Return read-only version of proxied agent's properties. If the attribute exists in the View only return it from there. Using @@ -135,7 +139,7 @@ def __getattribute__(self, attr) -> Any: return value - def __dir__(self): + def __dir__(self) -> list[str]: """Dynamically resolve attribute names for REPL autocompletion.""" dunder_attrs = [ attr @@ -173,7 +177,7 @@ def __init__( attack_graph: AttackGraph, prune_unviable_unnecessary: bool = True, sim_settings: MalSimulatorSettings = MalSimulatorSettings(), - max_iter=ITERATIONS_LIMIT, + max_iter: int = ITERATIONS_LIMIT, ): """ Args: @@ -200,7 +204,7 @@ def __init__( self.cur_iter = 0 # Keep track on current iteration # All internal agent states (dead or alive) - self._agent_states: dict[str, MalSimAgentState] = {} + self._agent_states: dict[str, MalSimAttackerState | MalSimDefenderState] = {} # Keep track on all 'living' agents sorted by order to step in self._alive_agents: set[str] = set() @@ -208,7 +212,7 @@ def __init__( def reset( self, seed: Optional[int] = None, - options: Optional[dict] = None + options: Optional[dict[str, Any]] = None ) -> dict[str, MalSimAgentStateView]: """Reset attack graph, iteration and reinitialize agents""" @@ -270,7 +274,7 @@ def _update_attacker_state( attacker_state: MalSimAttackerState, step_agent_compromised_nodes: set[AttackGraphNode], step_nodes_made_unviable: set[AttackGraphNode] - ): + ) -> None: """ Update a previous attacker state based on what the agent compromised and what nodes became unviable. @@ -314,7 +318,7 @@ def _update_defender_state( step_all_compromised_nodes: set[AttackGraphNode], step_enabled_defenses: set[AttackGraphNode], step_nodes_made_unviable: set[AttackGraphNode], - ): + ) -> None: """ Update a previous defender state based on what steps were enabled/compromised during last step @@ -332,7 +336,7 @@ def _update_defender_state( self._get_attacker_agents() ) - def _reset_agents(self): + def _reset_agents(self) -> None: """Reset agent states to a fresh start""" # Revive all agents @@ -340,21 +344,24 @@ def _reset_agents(self): # Create new attacker agent states compromised_steps = set() - for agent_state in self._get_attacker_agents(): + for attacker_state in self._get_attacker_agents(): # Create a new agent state for the attacker - attacker = self.attack_graph.attackers[agent_state.attacker.id] - self._agent_states[agent_state.name] = ( - self._create_attacker_state(agent_state.name, attacker) + assert attacker_state.attacker.id is not None, ( + f"Attacker {attacker_state.attacker} must have ID defined" + ) + attacker = self.attack_graph.attackers[attacker_state.attacker.id] + self._agent_states[attacker_state.name] = ( + self._create_attacker_state(attacker_state.name, attacker) ) compromised_steps |= attacker.reached_attack_steps # Create new defender agent states - for agent_state in self._get_defender_agents(): - self._agent_states[agent_state.name] = ( - self._create_defender_state(agent_state.name) + for defender_state in self._get_defender_agents(): + self._agent_states[defender_state.name] = ( + self._create_defender_state(defender_state.name) ) - def register_attacker(self, name: str, attacker_id: int): + def register_attacker(self, name: str, attacker_id: int) -> None: """Register a mal sim attacker agent""" assert name not in self._agent_states, \ f"Duplicate agent named {name} not allowed" @@ -370,7 +377,7 @@ def register_attacker(self, name: str, attacker_id: int): self._agent_states[name] = agent_state self._alive_agents.add(name) - def register_defender(self, name: str): + def register_defender(self, name: str) -> None: """Register a mal sim defender agent""" assert name not in self._agent_states, \ f"Duplicate agent named {name} not allowed" @@ -405,7 +412,7 @@ def _get_defender_agents(self) -> list[MalSimDefenderState]: def _uncompromise_attack_steps( self, attack_steps_to_uncompromise: set[AttackGraphNode] - ): + ) -> None: """Uncompromise nodes for each attacker agent Go through the nodes in `attack_steps_to_uncompromise` for each @@ -431,12 +438,14 @@ def _uncompromise_attack_steps( def _attacker_step( self, agent: MalSimAttackerState, nodes: list[AttackGraphNode] - ): + ) -> set[AttackGraphNode]: """Compromise attack step nodes with attacker Args: agent - the agent to compromise nodes with nodes - the nodes to compromise + + Returns: set of nodes that were compromised. """ compromised_nodes = set() @@ -472,13 +481,15 @@ def _attacker_step( def _defender_step( self, agent: MalSimDefenderState, nodes: list[AttackGraphNode] - ): + ) -> tuple[set[AttackGraphNode], set[AttackGraphNode]]: """Enable defense step nodes with defender. Args: agent - the agent to activate defense nodes with nodes - the defense step nodes to enable + Returns a tuple of two sets, `enabled_defenses` + and `attack_steps_made_unviable`. """ enabled_defenses = set() @@ -521,7 +532,7 @@ def _defender_step( return enabled_defenses, attack_steps_made_unviable @staticmethod - def _attacker_reward(attacker_state: MalSimAttackerState): + def _attacker_reward(attacker_state: MalSimAttackerState) -> float: """ Calculate current attacker reward by adding this steps compromised node rewards to the previous attacker reward. @@ -534,12 +545,12 @@ def _attacker_reward(attacker_state: MalSimAttackerState): """ # Attacker is rewarded for compromised nodes return attacker_state.reward + sum( - n.extras.get("reward", 0) + float(n.extras.get("reward", 0)) for n in attacker_state.step_performed_nodes ) @staticmethod - def _defender_reward(defender_state: MalSimDefenderState): + def _defender_reward(defender_state: MalSimDefenderState) -> float: """ Calculate current defender reward by subtracting this steps compromised/enabled node rewards from the previous defender reward. @@ -554,7 +565,7 @@ def _defender_reward(defender_state: MalSimDefenderState): step_enabled_defenses = defender_state.step_performed_nodes step_compromised_nodes = defender_state.step_all_compromised_nodes return defender_state.reward - sum( - n.extras.get("reward", 0) + float(n.extras.get("reward", 0)) for n in step_enabled_defenses | step_compromised_nodes ) @@ -644,7 +655,7 @@ def step( self.cur_iter += 1 return self.agent_states - def render(self): + def render(self) -> None: """Render attack graph from simulation in Neo4J""" logger.debug("Sending attack graph to Neo4J database.") neo4j.ingest_attack_graph( diff --git a/malsim/scenario.py b/malsim/scenario.py index a79731a5..7101bd86 100644 --- a/malsim/scenario.py +++ b/malsim/scenario.py @@ -12,7 +12,7 @@ """ import os -from typing import Any, Optional +from typing import Any, Optional, TextIO import yaml @@ -66,7 +66,7 @@ ] -def validate_scenario(scenario_dict): +def validate_scenario(scenario_dict: dict[str, Any]) -> None: """Verify scenario file keys""" # Verify that all keys in dict are supported @@ -83,7 +83,7 @@ def validate_scenario(scenario_dict): raise RuntimeError(f"Setting '{key}' missing from scenario file") -def path_relative_to_file_dir(rel_path, file): +def path_relative_to_file_dir(rel_path: str, file: TextIO) -> str: """Returns the absolute path of a relative path in a second file Arguments: @@ -98,7 +98,7 @@ def path_relative_to_file_dir(rel_path, file): def _validate_scenario_node_property_config( - graph: AttackGraph, prop_config: dict): + graph: AttackGraph, prop_config: dict[str, dict[str, Any]]) -> None: """Verify that node property configurations in a scenario contains only valid assets, asset types and step names""" @@ -133,6 +133,9 @@ def _validate_scenario_node_property_config( ) # TODO: revisit this variable once LookupDicts are merged + assert graph.model, ( + "Attack graph in scenario needs to have a model attached to it" + ) asset_names = set(a.name for a in graph.model.assets.values()) for asset_name in prop_config.get('by_asset_name', []): # Make sure each specified asset exist @@ -155,11 +158,11 @@ def _validate_scenario_node_property_config( def apply_scenario_node_property( attack_graph: AttackGraph, node_prop: str, - prop_config: dict, + prop_config: dict[str, dict[str, Any]], assumed_value: Optional[Any] = None, default_value: Optional[Any] = None, set_as_extras: bool = True -): +) -> None: """Apply node property values from scenario configuration. Note: Property values provided 'by_asset_name' will take precedence over @@ -182,7 +185,9 @@ def apply_scenario_node_property( themselves. """ - def _extract_value_from_entries(entries: dict|list, step_name: str) -> Any: + def _extract_value_from_entries( + entries: dict[str, Any] | list[str], step_name: str + ) -> Any: """ Return the property value matching the step name in the provided entries. @@ -207,8 +212,12 @@ def _extract_value_from_entries(entries: dict|list, step_name: str) -> Any: raise ValueError('Error! Scenario node property configuration ' 'is neither dictionary, nor list!') - def _set_value(step: AttackGraphNode, node_prop: str, value: Any, - set_as_extras:bool): + def _set_value( + step: AttackGraphNode, + node_prop: str, + value: Any, + set_as_extras: bool + ) -> None: """ Set the value of the node property to the value provided @@ -254,6 +263,9 @@ def _set_value(step: AttackGraphNode, node_prop: str, value: Any, # Check for matching specific asset(given by name) property # configuration entry + assert step.model_asset, ( + f"Attack step {step} missing connection to model" + ) prop_specific_asset_entries = ( prop_config.get('by_asset_name', {}) .get(step.model_asset.name, {}) @@ -315,7 +327,7 @@ def create_scenario_attacker( def load_simulator_agents( - attack_graph: AttackGraph, scenario: dict + attack_graph: AttackGraph, scenario: dict[str, Any] ) -> list[dict[str, Any]]: """Load agents to be registered in MALSimulator @@ -379,7 +391,7 @@ def load_simulator_agents( def apply_scenario_to_attack_graph( - attack_graph: AttackGraph, scenario: dict): + attack_graph: AttackGraph, scenario: dict[str, Any]) -> None: """Update attack graph according to scenario configuration Apply scenario configurations from a loaded scenario file @@ -451,8 +463,8 @@ def load_scenario(scenario_file: str) -> tuple[AttackGraph, list[dict[str, Any]] def create_simulator_from_scenario( scenario_file: str, - sim_class=MalSimulator, - **kwargs, + sim_class: Any = MalSimulator, + **kwargs: Any, ) -> tuple[MalSimulator, list[dict[str, Any]]]: """Creates and returns a MalSimulator created according to scenario file diff --git a/pyproject.toml b/pyproject.toml index a9482dce..65c57f4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,3 +55,12 @@ include = ["malsim*"] [tool.pytest.ini_options] pythonpath = ['.'] + +[tool.mypy] +strict = true +ignore_missing_imports = true +warn_unused_ignores = true +warn_no_return = true +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_any_generics = true diff --git a/tests/agents/test_heuristic_agents.py b/tests/agents/test_heuristic_agents.py index 841e933b..da5bd033 100644 --- a/tests/agents/test_heuristic_agents.py +++ b/tests/agents/test_heuristic_agents.py @@ -7,7 +7,9 @@ DefendFutureCompromisedDefender ) -def test_defend_compromised_defender(dummy_lang_graph: LanguageGraph): +def test_defend_compromised_defender( + dummy_lang_graph: LanguageGraph + ) -> None: r""" node1 node2 / \ / \ @@ -72,7 +74,9 @@ def test_defend_compromised_defender(dummy_lang_graph: LanguageGraph): assert action_node.id == node1.id -def test_defend_future_compromised_defender(dummy_lang_graph: LanguageGraph): +def test_defend_future_compromised_defender( + dummy_lang_graph: LanguageGraph + ) -> None: r""" node1 node2 / \ / \ diff --git a/tests/agents/test_searchers.py b/tests/agents/test_searchers.py index 7eb006a4..9187e032 100644 --- a/tests/agents/test_searchers.py +++ b/tests/agents/test_searchers.py @@ -6,7 +6,9 @@ from malsim.agents import BreadthFirstAttacker, DepthFirstAttacker -def test_breadth_first_traversal_simple(dummy_lang_graph: LanguageGraph): +def test_breadth_first_traversal_simple( + dummy_lang_graph: LanguageGraph + ) -> None: """ node1 | @@ -74,7 +76,9 @@ def test_breadth_first_traversal_simple(dummy_lang_graph: LanguageGraph): assert actual_order == expected_order, \ "Traversal order does not match expected breadth-first order" -def test_breadth_first_traversal_complicated(dummy_lang_graph: LanguageGraph): +def test_breadth_first_traversal_complicated( + dummy_lang_graph: LanguageGraph + ) -> None: r""" node1 ______________ / \ \ @@ -158,7 +162,9 @@ def test_breadth_first_traversal_complicated(dummy_lang_graph: LanguageGraph): "Traversal order does not match expected breadth-first order" -def test_depth_first_traversal_complicated(dummy_lang_graph: LanguageGraph): +def test_depth_first_traversal_complicated( + dummy_lang_graph: LanguageGraph + ) -> None: r""" node1 ______________ / \ \ diff --git a/tests/conftest.py b/tests/conftest.py index 8f80409f..6f36c5ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ ## Helpers -def path_testdata(filename): +def path_testdata(filename: str) -> str: """Returns the absolute path of a test data file (in ./testdata) Arguments: @@ -24,16 +24,6 @@ def path_testdata(filename): current_dir = path.dirname(path.realpath(__file__)) return path.join(current_dir, f"testdata/{filename}") - -def empty_model(name, lang_classes_factory): - """Fixture that generates a model for tests - - Uses coreLang specification (fixture) to create and return Model - """ - - # Create instance model from model json file - return Model(name, lang_classes_factory) - ## Fixtures @pytest.fixture(scope="session", name="env") @@ -50,21 +40,21 @@ def fixture_env()-> MalSimVectorizedObsEnv: @pytest.fixture -def corelang_lang_graph(): +def corelang_lang_graph() -> LanguageGraph: """Fixture that returns the coreLang language specification as dict""" mar_file_path = path_testdata("org.mal-lang.coreLang-1.0.0.mar") return LanguageGraph.from_mar_archive(mar_file_path) @pytest.fixture -def traininglang_lang_graph(): +def traininglang_lang_graph() -> LanguageGraph: """Fixture that returns the trainingLang language specification as dict""" mar_file_path = path_testdata("langs/org.mal-lang.trainingLang-1.0.0.mar") return LanguageGraph.from_mar_archive(mar_file_path) @pytest.fixture -def traininglang_model(traininglang_lang_graph): +def traininglang_model(traininglang_lang_graph: LanguageGraph) -> Model: """Fixture that generates a model for tests Uses trainingLang specification (fixture) to create and return a @@ -76,7 +66,7 @@ def traininglang_model(traininglang_lang_graph): @pytest.fixture -def model(corelang_lang_graph): +def model(corelang_lang_graph: LanguageGraph) -> Model: """Fixture that generates a model for tests Uses coreLang specification (fixture) to create and return a @@ -86,7 +76,7 @@ def model(corelang_lang_graph): return Model.load_from_file(model_file_name, corelang_lang_graph) @pytest.fixture -def dummy_lang_graph(corelang_lang_graph): +def dummy_lang_graph(corelang_lang_graph: LanguageGraph) -> LanguageGraph: """Fixture that generates a dummy LanguageGraph with a dummy LanguageGraphAsset and LanguageGraphAttackStep """ diff --git a/tests/envs/test_example_scenarios.py b/tests/envs/test_example_scenarios.py index ea7637e8..975bdbb1 100644 --- a/tests/envs/test_example_scenarios.py +++ b/tests/envs/test_example_scenarios.py @@ -5,7 +5,7 @@ from malsim.scenario import create_simulator_from_scenario -def test_bfs_vs_bfs_state_and_reward(): +def test_bfs_vs_bfs_state_and_reward() -> None: """ The point of this test is to see that a specific scenario runs deterministically. @@ -34,8 +34,8 @@ def test_bfs_vs_bfs_state_and_reward(): attacker_agent = attacker_agent_info["agent"] defender_agent = defender_agent_info["agent"] - total_reward_defender = 0 - total_reward_attacker = 0 + total_reward_defender = 0.0 + total_reward_attacker = 0.0 attacker_actions = [] defender_actions = [] diff --git a/tests/envs/test_gym_envs.py b/tests/envs/test_gym_envs.py index 7c7fa592..690bdfff 100644 --- a/tests/envs/test_gym_envs.py +++ b/tests/envs/test_gym_envs.py @@ -1,3 +1,8 @@ +# type: ignore +# Ignoring type checking in this file for now +# before someone with more Gymnasium knowledge +# can jump into the code base + import sys import os @@ -26,18 +31,18 @@ scenario_file_no_defender = 'tests/testdata/scenarios/no_defender_agent_scenario.yml' -def register_gym_agent(agent_id, entry_point): +def register_gym_agent(agent_id: str, entry_point: gym.Env) -> None: if agent_id not in gym.envs.registry.keys(): gym.register(agent_id, entry_point=entry_point) -def test_pz(env: MalSimVectorizedObsEnv): +def test_pz(env: MalSimVectorizedObsEnv) -> None: logger.debug('Run Parrallel API test.') parallel_api_test(env) # Check that an environment follows Gym API -def test_gym(): +def test_gym() -> None: logger.debug('Run Gym Test.') register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv) env = gym.make( @@ -53,7 +58,7 @@ def test_gym(): env_checker.check_env(env.unwrapped) -def test_random_defender_actions(): +def test_random_defender_actions() -> None: register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv) env = gym.make( 'MALDefenderEnv-v0', @@ -79,7 +84,7 @@ def available_actions(x): done = term or trunc -def test_episode(): +def test_episode() -> None: logger.debug('Run Episode Test.') register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv) env = gym.make( @@ -102,7 +107,7 @@ def test_episode(): # assert _return < 0.0 # If the defender does nothing then it will get a penalty for being attacked -def test_mask(): +def test_mask() -> None: register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv) env = gym.make( 'MALDefenderEnv-v0', @@ -117,7 +122,7 @@ def test_mask(): print(obs) -def test_defender_penalty(): +def test_defender_penalty() -> None: register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv) env = gym.make( 'MALDefenderEnv-v0', @@ -131,7 +136,7 @@ def test_defender_penalty(): # assert reward < 0 # All defense steps cost something -def test_action_mask(): +def test_action_mask() -> None: register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv) env = gym.make( 'MALDefenderEnv-v0', diff --git a/tests/envs/test_vectorized_obs_mal_simulator.py b/tests/envs/test_vectorized_obs_mal_simulator.py index 84bdf037..8c663fbd 100644 --- a/tests/envs/test_vectorized_obs_mal_simulator.py +++ b/tests/envs/test_vectorized_obs_mal_simulator.py @@ -1,11 +1,19 @@ """Test MalSimulator class""" +from __future__ import annotations +from typing import TYPE_CHECKING, Any from maltoolbox.attackgraph import AttackGraph, Attacker from malsim.mal_simulator import MalSimulator from malsim.envs import MalSimVectorizedObsEnv from malsim.scenario import load_scenario -def test_create_blank_observation(corelang_lang_graph, model): +if TYPE_CHECKING: + from maltoolbox.language import LanguageGraph + from maltoolbox.model import Model + +def test_create_blank_observation( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Make sure blank observation contains correct default values""" attack_graph = AttackGraph(corelang_lang_graph, model) @@ -58,8 +66,8 @@ def test_create_blank_observation(corelang_lang_graph, model): def test_create_blank_observation_deterministic( - corelang_lang_graph, model - ): + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Make sure blank observation is deterministic with seed given""" attack_graph = AttackGraph(corelang_lang_graph, model) @@ -94,8 +102,8 @@ def test_create_blank_observation_deterministic( def test_step_deterministic( - corelang_lang_graph, model - ): + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Make sure blank observation is deterministic with seed given""" attack_graph = AttackGraph(corelang_lang_graph, model) @@ -106,8 +114,8 @@ def test_step_deterministic( sim.register_attacker("test_attacker", attacker.id) sim.register_defender("test_defender") - obs1 = {} - obs2 = {} + obs1: dict[str, Any] = {} + obs2: dict[str, Any] = {} # Run 1 sim.reset(seed=123) @@ -138,8 +146,8 @@ def test_step_deterministic( def test_create_blank_observation_observability_given( - corelang_lang_graph, model - ): + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Make sure observability propagates correctly from extras field/scenario to observation in mal simulator""" @@ -171,8 +179,8 @@ def test_create_blank_observation_observability_given( assert not observable def test_create_blank_observation_actionability_given( - corelang_lang_graph, model - ): + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Make sure actionability propagates correctly from extras field/scenario to observation in mal simulator""" @@ -201,7 +209,7 @@ def test_create_blank_observation_actionability_given( assert not actionable -def test_malsimulator_observe_attacker(): +def test_malsimulator_observe_attacker() -> None: attack_graph, _ = load_scenario( 'tests/testdata/scenarios/simple_scenario.yml') @@ -267,15 +275,15 @@ def test_malsimulator_observe_attacker(): assert state == 0 -def test_malsimulator_observe_and_reward_attacker_defender(): +def test_malsimulator_observe_and_reward_attacker_defender() -> None: """Run attacker and defender actions and make sure rewards and observation states are updated correctly""" def verify_attacker_obs_state( - observed_state, - expected_reached, - expected_children_of_reached - ): + observed_state: list[int], + expected_reached: list[int], + expected_children_of_reached: list[int] + ) -> None: """Make sure obs state looks as expected""" for index, state in enumerate(observed_state): node_id = env._index_to_id[index] @@ -287,8 +295,8 @@ def verify_attacker_obs_state( assert state == -1 def verify_defender_obs_state( - observed_state - ): + observed_state: list[int] + ) -> None: """Make sure obs state looks as expected""" for index, state in enumerate(observed_state): node = env.index_to_node(index) @@ -427,7 +435,9 @@ def verify_defender_obs_state( assert rew[defender_agent_name] == - rew[attacker_agent_name] - reward_host_0_not_present -def test_malsimulator_initial_observation_defender(corelang_lang_graph, model): +def test_malsimulator_initial_observation_defender( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Make sure ._observe_defender observes nodes and set observed state""" attack_graph = AttackGraph(corelang_lang_graph, model) @@ -452,8 +462,8 @@ def test_malsimulator_initial_observation_defender(corelang_lang_graph, model): def test_malsimulator_observe_and_reward_attacker_no_entrypoints( - corelang_lang_graph, model - ): + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) attacker = Attacker("TestAttacker", [], []) @@ -475,8 +485,8 @@ def test_malsimulator_observe_and_reward_attacker_no_entrypoints( def test_malsimulator_observe_and_reward_attacker_entrypoints( - traininglang_lang_graph, traininglang_model - ): + traininglang_lang_graph: LanguageGraph, traininglang_model: Model + ) -> None: attack_graph = AttackGraph( traininglang_lang_graph, traininglang_model) diff --git a/tests/test_main.py b/tests/test_main.py index 062405b4..045f1d7b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2,13 +2,14 @@ import os from unittest.mock import patch +from typing import Any from malsim.__main__ import run_simulation from malsim.scenario import create_simulator_from_scenario from malsim.mal_simulator import MalSimulator -def path_relative_to_tests(filename): +def path_relative_to_tests(filename: str) -> str: """Returns the absolute path of a file in ./tests Arguments: @@ -19,7 +20,7 @@ def path_relative_to_tests(filename): @patch("builtins.input", return_value="\n") # to not freeze on input() -def test_run_simulation(mock_input): +def test_run_simulation(mock_input: Any) -> None: """Make sure we can run simulation with defender agent registered in scenario""" @@ -32,7 +33,7 @@ def test_run_simulation(mock_input): run_simulation(sim, agents) @patch("builtins.input", return_value="\n") # to not freeze on input() -def test_run_simulation_without_defender_agent(mock_input): +def test_run_simulation_without_defender_agent(mock_input: Any) -> None: """Make sure we can run simulation without defender agent registered in scenario""" diff --git a/tests/test_mal_simulator.py b/tests/test_mal_simulator.py index 1c6d5fca..37111f09 100644 --- a/tests/test_mal_simulator.py +++ b/tests/test_mal_simulator.py @@ -1,16 +1,22 @@ """Test MalSimulator class""" +from __future__ import annotations +from typing import TYPE_CHECKING from maltoolbox.attackgraph import AttackGraphNode, AttackGraph, Attacker from malsim.mal_simulator import MalSimulator +from malsim.scenario import load_scenario, create_simulator_from_scenario +from malsim.mal_simulator import MalSimDefenderState, MalSimAttackerState -from malsim.scenario import load_scenario +if TYPE_CHECKING: + from maltoolbox.language import LanguageGraph + from maltoolbox.model import Model -def test_init(corelang_lang_graph, model): +def test_init(corelang_lang_graph: LanguageGraph, model: Model) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) MalSimulator(attack_graph) -def test_reset(corelang_lang_graph, model): +def test_reset(corelang_lang_graph: LanguageGraph, model: Model) -> None: """Make sure attack graph is reset""" attack_graph = AttackGraph(corelang_lang_graph, model) @@ -52,7 +58,9 @@ def test_reset(corelang_lang_graph, model): assert attack_graph_before._to_dict() == attack_graph_after._to_dict() -def test_register_agent_attacker(corelang_lang_graph, model): +def test_register_agent_attacker( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) attack_graph.attach_attackers() sim = MalSimulator(attack_graph) @@ -64,7 +72,9 @@ def test_register_agent_attacker(corelang_lang_graph, model): assert agent_name in sim.agent_states -def test_register_agent_defender(corelang_lang_graph, model): +def test_register_agent_defender( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) sim = MalSimulator(attack_graph) @@ -75,7 +85,9 @@ def test_register_agent_defender(corelang_lang_graph, model): assert agent_name in sim.agent_states -def test_register_agent_action_surface(corelang_lang_graph, model): +def test_register_agent_action_surface( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) sim = MalSimulator(attack_graph) @@ -87,7 +99,9 @@ def test_register_agent_action_surface(corelang_lang_graph, model): assert node.is_available_defense() -def test_simulator_initialize_agents(corelang_lang_graph, model): +def test_simulator_initialize_agents( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: """Test _initialize_agents""" ag, _ = load_scenario('tests/testdata/scenarios/simple_scenario.yml') @@ -104,21 +118,26 @@ def test_simulator_initialize_agents(corelang_lang_graph, model): assert set(sim.agent_states.keys()) == {attacker_name, defender_name} -def test_get_agents(): +def test_get_agents() -> None: """Test _get_attacker_agents and _get_defender_agents""" - ag, _ = load_scenario('tests/testdata/scenarios/simple_scenario.yml') - sim = MalSimulator(ag) + sim, _ = create_simulator_from_scenario( + 'tests/testdata/scenarios/simple_scenario.yml' + ) sim.reset() - sim._get_attacker_agents() == ['attacker'] - sim._get_defender_agents() == ['defender'] + assert [a.name for a in sim._get_attacker_agents()] == ['Attacker1'] + assert [a.name for a in sim._get_defender_agents()] == ['Defender1'] -def test_attacker_step(corelang_lang_graph, model): +def test_attacker_step( + corelang_lang_graph: LanguageGraph, model: Model + ) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) entry_point = attack_graph.get_node_by_full_name('OS App:fullAccess') + assert entry_point, "OS App:fullAccess Should exist" + attacker = Attacker( 'attacker1', reached_attack_steps = {entry_point}, @@ -130,19 +149,23 @@ def test_attacker_step(corelang_lang_graph, model): sim.register_attacker(attacker.name, attacker.id) sim.reset() + attacker_agent = sim._agent_states[attacker.name] + assert isinstance(attacker_agent, MalSimAttackerState) # Can not attack the notPresent step defense_step = sim.attack_graph.get_node_by_full_name('OS App:notPresent') - actions = sim._attacker_step(attacker_agent, {defense_step}) + assert defense_step + actions = sim._attacker_step(attacker_agent, [defense_step]) assert not actions attack_step = sim.attack_graph.get_node_by_full_name('OS App:attemptRead') - actions = sim._attacker_step(attacker_agent, {attack_step}) + assert attack_step + actions = sim._attacker_step(attacker_agent, [attack_step]) assert actions == {attack_step} -def test_defender_step(corelang_lang_graph, model): +def test_defender_step(corelang_lang_graph: LanguageGraph, model: Model) -> None: attack_graph = AttackGraph(corelang_lang_graph, model) sim = MalSimulator(attack_graph) @@ -151,28 +174,33 @@ def test_defender_step(corelang_lang_graph, model): sim.reset() defender_agent = sim._agent_states[defender_name] + assert isinstance(defender_agent, MalSimDefenderState) + defense_step = sim.attack_graph.get_node_by_full_name( 'OS App:notPresent') - enabled, made_unviable = sim._defender_step(defender_agent, {defense_step}) + assert defense_step + enabled, made_unviable = sim._defender_step(defender_agent, [defense_step]) assert enabled == {defense_step} assert made_unviable # Can not defend attack_step attack_step = sim.attack_graph.get_node_by_full_name( 'OS App:attemptUseVulnerability') - enabled, made_unviable = sim._defender_step(defender_agent, {attack_step}) + assert attack_step + enabled, made_unviable = sim._defender_step(defender_agent, [attack_step]) assert enabled == set() assert not made_unviable -def test_agent_state_views_simple(corelang_lang_graph, model): +def test_agent_state_views_simple(corelang_lang_graph: LanguageGraph, model: Model) -> None: - def get_node(full_name) -> AttackGraphNode: + def get_node(full_name: str) -> AttackGraphNode: node = sim.attack_graph.get_node_by_full_name(full_name) assert node return node attack_graph = AttackGraph(corelang_lang_graph, model) entry_point = attack_graph.get_node_by_full_name('OS App:fullAccess') + assert entry_point, "Should exist" attacker = Attacker( 'attacker1', @@ -283,7 +311,7 @@ def get_node(full_name) -> AttackGraphNode: assert len(dsv.step_unviable_nodes) == 55 -def test_observe_attacker(): +def test_observe_attacker() -> None: attack_graph, _ = load_scenario( 'tests/testdata/scenarios/simple_scenario.yml' ) @@ -305,7 +333,7 @@ def test_observe_attacker(): assert len(attacker.reached_attack_steps) == 1 -def test_step_attacker_defender_action_surface_updates(): +def test_step_attacker_defender_action_surface_updates() -> None: ag, _ = load_scenario( 'tests/testdata/scenarios/traininglang_scenario.yml') @@ -348,7 +376,7 @@ def test_step_attacker_defender_action_surface_updates(): assert defender_step not in defender_agent.action_surface -def test_default_simulator_default_settings_eviction(): +def test_default_simulator_default_settings_eviction() -> None: """Test attacker node eviction using MalSimulatorSettings default""" ag, _ = load_scenario( 'tests/testdata/scenarios/traininglang_scenario.yml', diff --git a/tests/test_scenario.py b/tests/test_scenario.py index 7b13481e..b5aabb1c 100644 --- a/tests/test_scenario.py +++ b/tests/test_scenario.py @@ -11,7 +11,7 @@ ) from malsim.agents import PassiveAgent, BreadthFirstAttacker -def path_relative_to_tests(filename): +def path_relative_to_tests(filename: str) -> str: """Returns the absolute path of a file in ./tests Arguments: @@ -21,7 +21,7 @@ def path_relative_to_tests(filename): return os.path.join(current_dir, f"{filename}") -def test_load_scenario(): +def test_load_scenario() -> None: """Make sure we can load a scenario""" # Load the scenario @@ -68,7 +68,7 @@ def test_load_scenario(): assert isinstance(agents[1]['agent'], PassiveAgent) -def test_load_scenario_no_attacker_in_model(): +def test_load_scenario_no_attacker_in_model() -> None: """Make sure we can load a scenario""" # Load the scenario @@ -89,7 +89,7 @@ def test_load_scenario_no_attacker_in_model(): assert attack_step in attacker.entry_points -def test_load_scenario_attacker_in_model(): +def test_load_scenario_attacker_in_model() -> None: """ Make sure model attacker is removed if scenario has attacker Make sure model attacker is not removed if scenario has no attacker @@ -115,7 +115,7 @@ def test_load_scenario_attacker_in_model(): assert all_attackers[0].name == 'Attacker:15' # From scenario -def test_load_scenario_no_defender_agent(): +def test_load_scenario_no_defender_agent() -> None: """Make sure we can load a scenario""" # Load the scenario @@ -128,7 +128,7 @@ def test_load_scenario_no_defender_agent(): assert isinstance(agents[0]['agent'], BreadthFirstAttacker) -def test_load_scenario_agent_class_error(): +def test_load_scenario_agent_class_error() -> None: """Make sure we get error when loading with wrong class""" # Load the scenario @@ -140,7 +140,7 @@ def test_load_scenario_agent_class_error(): ) -def test_load_scenario_observability_given(): +def test_load_scenario_observability_given() -> None: """Load a scenario with observability settings given and make sure observability is applied correctly""" @@ -163,7 +163,7 @@ def test_load_scenario_observability_given(): assert not node.extras['observable'] -def test_load_scenario_observability_not_given(): +def test_load_scenario_observability_not_given() -> None: """Load a scenario where no observability settings are given""" # Load scenario with no observability specifed attack_graph, _ = load_scenario( @@ -177,7 +177,7 @@ def test_load_scenario_observability_not_given(): assert node.extras['observable'] -def test_apply_scenario_observability(): +def test_apply_scenario_observability() -> None: """Try different cases for observability settings""" # Load scenario with no observability specified @@ -219,7 +219,7 @@ def test_apply_scenario_observability(): else: assert not node.extras['observable'] -def test_apply_scenario_observability_faulty(): +def test_apply_scenario_observability_faulty() -> None: """Try different failing cases for observability settings""" # Load scenario with no observability specified @@ -297,7 +297,7 @@ def test_apply_scenario_observability_faulty(): ) -def test_load_scenario_false_positive_negative_rate(): +def test_load_scenario_false_positive_negative_rate() -> None: """Load a scenario with observability settings given and make sure observability is applied correctly""" @@ -337,7 +337,7 @@ def test_load_scenario_false_positive_negative_rate(): assert 'false_positive_rate' not in node.extras assert 'false_negative_rate' not in node.extras -def test_apply_scenario_fpr_fnr(): +def test_apply_scenario_fpr_fnr() -> None: """Try different cases for false positives/negatives rates""" # Load scenario with no specified @@ -389,7 +389,7 @@ def test_apply_scenario_fpr_fnr(): assert 'false_negative_rate' not in node.extras -def test_apply_scenario_rewards_old_format(): +def test_apply_scenario_rewards_old_format() -> None: """Try different cases for rewards""" # Load scenario with no specified