diff --git a/.github/workflows/test-pytest.yml b/.github/workflows/test-pytest.yml index 86af6187..ffa2d717 100644 --- a/.github/workflows/test-pytest.yml +++ b/.github/workflows/test-pytest.yml @@ -17,8 +17,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install . + pip install ".[ml]" pip install pytest - name: Test with pytest run: | - pytest tests \ No newline at end of file + pytest tests diff --git a/.gitignore b/.gitignore index 47389d12..67b803eb 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ *.swp *.swo tmp/ +logs/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 602d5d65..c3dfdf43 100644 --- a/README.md +++ b/README.md @@ -35,12 +35,6 @@ they are a setup for running a simulation. This is how the format looks like: lang_file: model_file: -attacker_agent_class: 'BreadthFirstAttacker' | 'DepthFirstAttacker' | 'KeyboardAgent' - -# For defender_agent_class, null and False are treated the same - no defender will be used in the simulation -defender_agent_class: 'BreadthFirstAttacker' | 'DepthFirstAttacker' | 'KeyboardAgent' | null | False - - # Optionally add rewards for each attack step rewards: : @@ -50,17 +44,19 @@ rewards: # Data A:read: 100 ... +# Add entry points to AttackGraph with attacker names +# and attack step full_names +agents: + 'Attacker1': + type: 'attacker' + agent_class: BreadthFirstAttacker | DepthFirstAttacker | KeyboardAgent | null + entry_points: + - 'Credentials:6:attemptCredentialsReuse' -# Optionally add entry points to AttackGraph with attacker name and attack step full_names. -# NOTE: If attacker entry points defined in both model and scenario, -# the scenario overrides the ones in the model. -attacker_entry_points: - : - - + 'Defender1': + type: 'defender' + agent_class: BreadthFirstDefender | DepthFirstDefender | KeyboardAgent | null - # example: - # 'Attacker1': - # - 'Credentials:6:attemptCredentialsReuse' # Optionally add observability rules that are applied to AttackGrapNodes # to make only certain steps observable diff --git a/malsim/__init__.py b/malsim/__init__.py index 7861412a..b957450b 100644 --- a/malsim/__init__.py +++ b/malsim/__init__.py @@ -17,7 +17,7 @@ import logging -from malsim.wrappers.gym_wrapper import AttackerEnv, DefenderEnv, register_envs +from malsim.mal_simulator import MalSimulator """ MAL Simulator @@ -29,7 +29,7 @@ __license__ = "Apache 2.0" __docformat__ = "restructuredtext en" -__all__ = ("AttackerEnv", "DefenderEnv", "register_envs") +__all__ = ["MalSimulator"] # TODO: Make sure logging dir exists and make it configurable (or use same as maltoolbox) diff --git a/malsim/__main__.py b/malsim/__main__.py new file mode 100644 index 00000000..30eb7978 --- /dev/null +++ b/malsim/__main__.py @@ -0,0 +1,92 @@ +"""CLI to run simulations in MAL Simulator using scenario files""" + +from __future__ import annotations +import argparse +import logging + +from .mal_simulator import MalSimulator +from .agents import DecisionAgent +from .scenario import create_simulator_from_scenario + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logging.getLogger().setLevel(logging.INFO) + +def run_simulation(sim: MalSimulator, agents: list[dict]): + """Run a simulation with agents""" + + sim.reset() + total_rewards = {agent_dict['name']: 0 for agent_dict in agents} + all_agents_term_or_trunc = False + + logger.info("Starting CLI env simulator.") + + i = 1 + while not all_agents_term_or_trunc: + logger.info("Iteration %s", i) + 
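+        # Assume all agents are terminated or truncated; the per-agent
+        # check after sim.step() below flips this back to False as long
+        # as any agent is still active.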
+        all_agents_term_or_trunc = True
+        actions = {}
+
+        # Select actions for each agent
+        for agent_dict in agents:
+            decision_agent: DecisionAgent = agent_dict.get('agent')
+            agent_name = agent_dict['name']
+            if decision_agent is None:
+                logger.warning(
+                    'Agent "%s" has no decision agent class '
+                    'specified in scenario. Waiting.', agent_name,
+                )
+                continue
+
+            sim_agent_state = sim.agent_states[agent_name]
+            agent_action = decision_agent.get_next_action(sim_agent_state)
+            if agent_action:
+                actions[agent_name] = [agent_action]
+                logger.info(
+                    'Agent "%s" chose action: %s',
+                    agent_name, agent_action.full_name
+                )
+
+        # Perform next step of simulation
+        sim.step(actions)
+
+        for agent_dict in agents:
+            agent_name = agent_dict['name']
+            agent_state = sim.agent_states[agent_name]
+            total_rewards[agent_name] += agent_state.reward
+            if not agent_state.terminated and not agent_state.truncated:
+                all_agents_term_or_trunc = False
+        print("---\n")
+        i += 1
+
+    logger.info("Game Over.")
+
+    # Print total rewards
+    for agent_dict in agents:
+        agent_name = agent_dict['name']
+        print(f'Total reward "{agent_name}"', total_rewards[agent_name])
+
+def main():
+    """Entrypoint function of the MAL Simulator CLI"""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        'scenario_file',
+        type=str,
+        help="Can be found in https://github.com/mal-lang/malsim-scenarios/"
+    )
+    parser.add_argument(
+        '-o', '--output-attack-graph', type=str,
+        help="If set to a path, attack graph will be dumped there",
+    )
+    args = parser.parse_args()
+
+    sim, agents = create_simulator_from_scenario(args.scenario_file)
+
+    if args.output_attack_graph:
+        sim.attack_graph.save_to_file(args.output_attack_graph)
+
+    run_simulation(sim, agents)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/malsim/agents/__init__.py b/malsim/agents/__init__.py
new file mode 100644
index 00000000..9b4bb0bb
--- /dev/null
+++ b/malsim/agents/__init__.py
@@ -0,0 +1,12 @@
+from .decision_agent import DecisionAgent
+from .passive_agent import PassiveAgent
+from .keyboard_input import KeyboardAgent
+from .searchers import BreadthFirstAttacker, DepthFirstAttacker
+
+__all__ = [
+    'PassiveAgent',
+    'DecisionAgent',
+    'KeyboardAgent',
+    'BreadthFirstAttacker',
+    'DepthFirstAttacker'
+]
diff --git a/malsim/agents/decision_agent.py b/malsim/agents/decision_agent.py
new file mode 100644
index 00000000..34e2d92a
--- /dev/null
+++ b/malsim/agents/decision_agent.py
@@ -0,0 +1,28 @@
+"""Abstract base class for heuristic decision agents"""
+
+from __future__ import annotations
+from typing import TYPE_CHECKING, Optional
+from abc import ABC, abstractmethod
+
+if TYPE_CHECKING:
+    from ..mal_simulator import MalSimAgentStateView
+    from maltoolbox.attackgraph import AttackGraphNode
+
+class DecisionAgent(ABC):
+
+    @abstractmethod
+    def get_next_action(
+        self,
+        agent_state: MalSimAgentStateView,
+        **kwargs
+    ) -> Optional[AttackGraphNode]:
+        """
+        Select the next action for the agent to perform.
+
+        Args:
+            agent_state: Current state of, and other info about, the agent as seen by the simulator
+
+        Returns:
+            The selected action or None if there are no actions to select from.
+        """
+        ...
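A minimal sketch of how a custom heuristic plugs into this interface (not part of the patch; RandomAgent and its config handling are hypothetical, the import paths follow the modules added above):

    import random
    from typing import Optional

    from maltoolbox.attackgraph import AttackGraphNode

    from malsim.agents import DecisionAgent
    from malsim.mal_simulator import MalSimAgentStateView

    class RandomAgent(DecisionAgent):
        """Hypothetical agent that picks a random action-surface node."""

        def __init__(self, agent_config: dict) -> None:
            self.rng = random.Random(agent_config.get('seed'))

        def get_next_action(
            self, agent_state: MalSimAgentStateView, **kwargs
        ) -> Optional[AttackGraphNode]:
            # The state view exposes action_surface as a frozenset, so sort
            # it by node id for a deterministic order before sampling.
            surface = sorted(agent_state.action_surface, key=lambda n: n.id)
            return self.rng.choice(surface) if surface else None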
diff --git a/malsim/agents/keyboard_input.py b/malsim/agents/keyboard_input.py
index 9764c682..1d738e27 100644
--- a/malsim/agents/keyboard_input.py
+++ b/malsim/agents/keyboard_input.py
@@ -1,20 +1,29 @@
-import numpy as np
+from __future__ import annotations
 import logging
+from typing import TYPE_CHECKING, Optional
 
-AGENT_ATTACKER = "attacker"
-AGENT_DEFENDER = "defender"
+from .decision_agent import DecisionAgent
+from ..mal_simulator import MalSimAgentStateView
+
+if TYPE_CHECKING:
+    from maltoolbox.attackgraph import AttackGraphNode
 
 logger = logging.getLogger(__name__)
 
-null_action = (0, None)
+class KeyboardAgent(DecisionAgent):
+    """An agent that makes decisions by asking the user for keyboard input"""
+    def __init__(self, _, **kwargs):
+        super().__init__(**kwargs)
+        logger.info("Creating KeyboardAgent")
 
-class KeyboardAgent:
-    def __init__(self, vocab):
-        logger.debug("Create Keyboard agent.")
-        self.vocab = vocab
+    def get_next_action(
+        self,
+        agent_state: MalSimAgentStateView,
+        **kwargs
+    ) -> Optional[AttackGraphNode]:
+        """Ask the user to pick the next action from the action surface"""
 
-    def compute_action_from_dict(self, obs: dict, mask: tuple) -> tuple:
         def valid_action(user_input: str) -> bool:
             if user_input == "":
                 return True
@@ -24,40 +33,35 @@ def valid_action(user_input: str) -> bool:
             except ValueError:
                 return False
 
-            try:
-                a = associated_action[action_strings[node]]
-            except IndexError:
-                return False
-
-            if a == 0:
-                return True  # wait is always valid
-            return node < len(available_actions) and node >= 0
+            return 0 <= node < len(agent_state.action_surface)
 
-        def get_action_object(user_input: str) -> tuple:
+        def get_action_index(user_input: str) -> Optional[int]:
             node = int(user_input) if user_input != "" else None
-            action = associated_action[action_strings[node]] if user_input != "" else 0
-            return node, action
-
-        available_actions = np.flatnonzero(mask[1])
+            return node
 
-        action_strings = [self.vocab[i] for i in available_actions]
-        associated_action = {i: 1 for i in action_strings}
-        action_strings += ["wait"]
-        associated_action["wait"] = 0
+        if not agent_state.action_surface:
+            print("No actions to pick, waiting.")
+            return None
 
+        index_to_node = dict(enumerate(agent_state.action_surface))
         user_input = "xxx"
         while not valid_action(user_input):
             print("Available actions:")
-            print("\n".join([f"{i}. {a}" for i, a in enumerate(action_strings)]))
+            print(
+                "\n".join(
+                    [f"{i}. {n.full_name}" for i, n in index_to_node.items()]
+                )
+            )
             print("Enter action or leave empty to wait:")
             user_input = input("> ")
 
             if not valid_action(user_input):
                 print("Invalid action.")
 
-        node, a = get_action_object(user_input)
+        index = get_action_index(user_input)
 
         print(
-            f"Selected action: {action_strings[node] if node is not None else 'wait'}"
+            "Selected action: "
+            + (index_to_node[index].full_name if index is not None else 'wait')
         )
 
-        return (a, available_actions[node] if a != 0 else -1)
+        return index_to_node[index] if index is not None else None
diff --git a/malsim/agents/passive_agent.py b/malsim/agents/passive_agent.py
new file mode 100644
index 00000000..f75ddde5
--- /dev/null
+++ b/malsim/agents/passive_agent.py
@@ -0,0 +1,22 @@
+"""A passive agent that always chooses to do nothing"""
+
+from __future__ import annotations
+from typing import TYPE_CHECKING, Optional
+
+from .decision_agent import DecisionAgent
+
+if TYPE_CHECKING:
+    from ..mal_simulator import MalSimAgentStateView
+    from maltoolbox.attackgraph import AttackGraphNode
+
+class PassiveAgent(DecisionAgent):
+    def __init__(self, *args, **kwargs):
+        ...
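+        # Accepts and ignores any constructor arguments so a PassiveAgent
+        # can be dropped in wherever another agent class is expected.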
+ + def get_next_action( + self, + agent_state: MalSimAgentStateView, + **kwargs + ) -> Optional[AttackGraphNode]: + # A passive agent never does anything + return None diff --git a/malsim/agents/searchers.py b/malsim/agents/searchers.py index 6b1f3a42..db54ded0 100644 --- a/malsim/agents/searchers.py +++ b/malsim/agents/searchers.py @@ -1,137 +1,99 @@ +from __future__ import annotations import logging +import random +import re from collections import deque -from typing import Any, Deque, Dict, List, Set, Union +from collections.abc import Iterable +from typing import Optional, TYPE_CHECKING -import numpy as np +from .decision_agent import DecisionAgent +from ..mal_simulator import MalSimAgentStateView + +if TYPE_CHECKING: + from maltoolbox.attackgraph import AttackGraphNode logger = logging.getLogger(__name__) -def get_new_targets( - observation: dict, discovered_targets: Set[int], mask: tuple -) -> List[int]: - attack_surface = mask[1] - surface_indexes = list(np.flatnonzero(attack_surface)) - new_targets = [idx for idx in surface_indexes if idx not in discovered_targets] - return new_targets, surface_indexes +class BreadthFirstAttacker(DecisionAgent): + """A Breadth-First agent, with possible randomization at each level.""" + _extend_method = 'extendleft' + # Controls where newly discovered steps will be appended to the list of + # available actions. Currently used to differentiate between BFS and DFS + # agents. -class PassiveAttacker: - def compute_action_from_dict(self, observation, mask): - return (0, None) + name = ' '.join(re.findall(r'[A-Z][^A-Z]*', __qualname__)) + # A human-friendly name for the agent. -class BreadthFirstAttacker: - def __init__(self, agent_config: dict) -> None: - self.targets: Deque[int] = deque([]) - self.current_target: int = None - seed = ( - agent_config["seed"] - if agent_config.get("seed", None) - else np.random.SeedSequence().entropy - ) - self.rng = ( - np.random.default_rng(seed) - if agent_config.get("randomize", False) - else None - ) - - def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple): - new_targets, surface_indexes = get_new_targets(observation, self.targets, mask) - - # Add new targets to the back of the queue - # if desired, shuffle the new targets to make the attacker more unpredictable - if self.rng: - self.rng.shuffle(new_targets) - for c in new_targets: - self.targets.appendleft(c) + default_settings = { + 'randomize': False, + # Whether to randomize next target selection, still respecting the + # policy of the agent (e.g. BFS or DFS). + 'seed': None, + # The random seed to initialize the randomness engine with. If set, the + # simulation will be deterministic. + } - self.current_target, done = self.select_next_target( - self.current_target, self.targets, surface_indexes - ) + def __init__(self, agent_config: dict) -> None: + """Initialize a BFS agent. 
- self.current_target = None if done else self.current_target - action = 0 if done else 1 - if action == 0: - logger.debug( - "Attacker Breadth First agent does not have " - "any valid targets it will terminate" - ) + Args: + agent_config: Dict with settings to override defaults + """ + self.targets: deque[AttackGraphNode] = deque() + self.current_target: Optional[AttackGraphNode] = None - return (action, self.current_target) + self.settings = self.default_settings | agent_config - @staticmethod - def select_next_target( - current_target: int, - targets: Union[List[int], Deque[int]], - attack_surface: Set[int], - ) -> int: - # If the current target was not compromised, put it - # back, but on the bottom of the stack. - if current_target in attack_surface: - targets.appendleft(current_target) - current_target = targets.pop() + self.rng = random.Random(self.settings.get('seed')) - while current_target not in attack_surface: - if len(targets) == 0: - return None, True + def get_next_action( + self, agent_state: MalSimAgentStateView, **kwargs + ) -> Optional[AttackGraphNode]: + self._update_targets(agent_state.action_surface) + self._select_next_target() - current_target = targets.pop() + return self.current_target - return current_target, False + def _update_targets(self, action_surface: Iterable[AttackGraphNode]): + if self.settings['seed']: + # If a seed is set, we assume the user wants determinism in the + # simulation. Thus, we sort to an ordered list to make sure the + # non-deterministic ordering of the action_surface set does not + # break simulation determinism. + action_surface = sorted(list(action_surface), key=lambda n: n.id) + new_targets = [ + step + for step in action_surface + if step not in self.targets and not step.is_compromised() + ] -class DepthFirstAttacker: - def __init__(self, agent_config: dict) -> None: - self.current_target = -1 - self.targets: List[int] = [] - seed = ( - agent_config["seed"] - if agent_config.get("seed", None) - else np.random.SeedSequence().entropy - ) - self.rng = ( - np.random.default_rng(seed) - if agent_config.get("randomize", False) - else None - ) - - def compute_action_from_dict(self, observation: Dict[str, Any], mask: tuple): - new_targets, surface_indexes = get_new_targets(observation, self.targets, mask) - - # Add new targets to the top of the stack - if self.rng: + if self.settings['randomize']: self.rng.shuffle(new_targets) - for c in new_targets: - self.targets.append(c) - - self.current_target, done = self.select_next_target( - self.current_target, self.targets, surface_indexes - ) - - self.current_target = None if done else self.current_target - action = 0 if done else 1 - return (action, self.current_target) - @staticmethod - def select_next_target( - current_target: int, - targets: Union[List[int], Deque[int]], - attack_surface: Set[int], - ) -> int: - if current_target in attack_surface: - return current_target, False + if self.current_target in new_targets: + # If self.current_target is not yet compromised, e.g. due to TTCs, + # keep using that as the target. + new_targets.remove(self.current_target) + new_targets.append(self.current_target) - while current_target not in attack_surface: - if len(targets) == 0: - return None, True + # Enabled defenses may remove previously possible attack steps. 
+ self.targets = deque(filter(lambda n: n.is_viable, self.targets)) - current_target = targets.pop() + getattr(self.targets, self._extend_method)(new_targets) - return current_target, False + def _select_next_target(self) -> None: + """ + Implement the actual next target selection logic. + """ + try: + self.current_target = self.targets.pop() + except IndexError: + self.current_target = None -AGENTS = { - BreadthFirstAttacker.__name__: BreadthFirstAttacker, - DepthFirstAttacker.__name__: DepthFirstAttacker, -} +class DepthFirstAttacker(BreadthFirstAttacker): + _extend_method = 'extend' diff --git a/malsim/cli.py b/malsim/cli.py deleted file mode 100644 index 8265f8d3..00000000 --- a/malsim/cli.py +++ /dev/null @@ -1,134 +0,0 @@ -"""CLI to run simulations in MAL Simulator using scenario files""" - -from __future__ import annotations -import argparse -import logging - -from .sims.mal_simulator import MalSimulator -from .agents.keyboard_input import KeyboardAgent -from .scenario import create_simulator_from_scenario - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) -logging.getLogger().setLevel(logging.INFO) - - -def run_simulation(sim: MalSimulator, sim_config: dict): - """Run a simulation on an attack graph with given config""" - - # Constants - NULL_ACTION = (0, None) - - attacker_agent_id = next(iter(sim.get_attacker_agents())) - defender_agent_id = next(iter(sim.get_defender_agents()), None) - - reverse_vocab = sim._index_to_full_name - - # Initialize defender and attacker according to classes - defender_class = sim_config['agents'][defender_agent_id]['agent_class']\ - if defender_agent_id else None - defender_agent = (defender_class(reverse_vocab) - if defender_class == KeyboardAgent - else defender_class({}) - if defender_class - else None) - - attacker_class = sim_config['agents'][attacker_agent_id]['agent_class'] - attacker_agent = (attacker_class(reverse_vocab) - if attacker_class == KeyboardAgent - else attacker_class({})) - - obs, infos = sim.reset() - done = False - - logger.info("Starting game.") - - total_reward_defender = 0 - total_reward_attacker = 0 - - while not done: - - defender_action = NULL_ACTION - if defender_agent: - defender_action = defender_agent.compute_action_from_dict( - obs[defender_agent_id], - infos[defender_agent_id]["action_mask"] - ) - - attacker_action = attacker_agent.compute_action_from_dict( - obs[attacker_agent_id], - infos[attacker_agent_id]["action_mask"] - ) - - if attacker_action[1] is not None: - logger.info( - "Attacker Action: %s", reverse_vocab[attacker_action[1]]) - else: - logger.info("Attacker Action: None") - # Stop the attacker if it has run out of things to do since - # the experiment cannot progress any further. 
- done = True - - action_dict = { - attacker_agent_id: attacker_action, - defender_agent_id: defender_action - } - - # Perform next step of simulation - obs, rewards, terminated, truncated, infos = sim.step(action_dict) - - logger.debug( - "Attacker has compromised the following attack steps so far:" - ) - attacker_obj = sim.attack_graph.attackers[ - sim.agents_dict[attacker_agent_id]["attacker"] - ] - for step in attacker_obj.reached_attack_steps: - logger.debug(step.id) - - logger.info("Attacker Reward: %s", rewards.get(attacker_agent_id)) - - if defender_agent: - logger.info("Defender Reward: %s", rewards.get(defender_agent_id)) - - total_reward_defender += rewards.get(defender_agent_id, 0) if defender_agent else 0 - total_reward_attacker += rewards.get(attacker_agent_id, 0) - - done |= terminated.get(attacker_agent_id, True) or truncated.get(attacker_agent_id, True) - - print("---\n") - - logger.info("Game Over.") - - if defender_agent: - logger.info("Total Defender Reward: %s", total_reward_defender) - logger.info("Total Attacker Reward: %s", total_reward_attacker) - - print("Press Enter to exit.") - input() - sim.close() - - -def main(): - """Entrypoint function of the MAL Toolbox CLI""" - parser = argparse.ArgumentParser() - parser.add_argument( - 'scenario_file', - type=str, - help="Can be found in https://github.com/mal-lang/malsim-scenarios/" - ) - parser.add_argument( - '-o', '--output-attack-graph', type=str, - help="If set to a path, attack graph will be dumped there", - ) - args = parser.parse_args() - - # Create simulator from scenario - simulator, sim_config = create_simulator_from_scenario(args.scenario_file) - if args.output_attack_graph: - simulator.attack_graph.save_to_file(args.output_attack_graph) - run_simulation(simulator, sim_config) - - -if __name__ == '__main__': - main() diff --git a/malsim/envs/__init__.py b/malsim/envs/__init__.py new file mode 100644 index 00000000..697e7dfe --- /dev/null +++ b/malsim/envs/__init__.py @@ -0,0 +1,10 @@ +from .malsim_vectorized_obs_env import MalSimVectorizedObsEnv +from .gym_envs import AttackerEnv, DefenderEnv, register_envs + +# not needed, used to silence ruff F401 +__all__ = [ + "MalSimVectorizedObsEnv", + "AttackerEnv", + "DefenderEnv", + "register_envs", +] diff --git a/malsim/envs/base_classes.py b/malsim/envs/base_classes.py new file mode 100644 index 00000000..07970b65 --- /dev/null +++ b/malsim/envs/base_classes.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from ..mal_simulator import MalSimulator, MalSimAgentStateView + +class MalSimEnv(ABC): + + def __init__(self, sim: MalSimulator): + self.sim = sim + + @abstractmethod + def step(self, actions): + ... 
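+        # Concrete environments are expected to translate their own action
+        # format into AttackGraphNode lists and delegate to the wrapped
+        # simulator, e.g. self.sim.step({agent_name: [node]}).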
+ + def reset(self, seed=None, options=None): + self.sim.reset(seed=seed, options=options) + + def register_attacker( + self, attacker_name: str, attacker_id: int + ): + self.sim.register_attacker(attacker_name, attacker_id) + + def register_defender( + self, defender_name: str + ): + self.sim.register_defender(defender_name) + + def get_agent_state(self, agent_name: str) -> MalSimAgentStateView: + return self.sim.agent_states[agent_name] + + def render(self): + pass diff --git a/malsim/wrappers/gym_wrapper.py b/malsim/envs/gym_envs.py similarity index 57% rename from malsim/wrappers/gym_wrapper.py rename to malsim/envs/gym_envs.py index 159be7b0..22d64698 100644 --- a/malsim/wrappers/gym_wrapper.py +++ b/malsim/envs/gym_envs.py @@ -7,7 +7,10 @@ from gymnasium.core import RenderFrame import numpy as np -from ..scenario import create_simulator_from_scenario +from ..scenario import load_scenario +from ..mal_simulator import MalSimulator, AgentType +from ..envs import MalSimVectorizedObsEnv +from ..agents import DecisionAgent class AttackerEnv(gym.Env): @@ -23,13 +26,28 @@ def __init__(self, scenario_file: str, **kwargs) -> None: self.render_mode = kwargs.pop('render_mode', None) # Create a simulator from the scenario given - self.sim, _ = create_simulator_from_scenario(scenario_file, **kwargs) + attack_graph, agents = load_scenario(scenario_file, **kwargs) + self.sim = MalSimVectorizedObsEnv(MalSimulator(attack_graph)) - # Use first attacker as attacker agent in simulation - # since only single agent is currently supported - self.attacker_agent_id = list(self.sim.get_attacker_agents().keys())[0] - self.observation_space = self.sim.observation_space(self.attacker_agent_id) - self.action_space = self.sim.action_space(self.attacker_agent_id) + attacker_agents = [ + agent for agent in agents if agent['type'] == AgentType.ATTACKER] + + assert len(attacker_agents) == 1, ( + "More than one attacker in scenario," + "can not decide which one to use in AttackerEnv") + + attacker_agent = attacker_agents[0] + self.attacker_agent_name = attacker_agent['name'] + self.sim.register_attacker( + self.attacker_agent_name, + attacker_agent['attacker_id'] + ) + self.sim.reset() + + self.observation_space = \ + self.sim.observation_space(self.attacker_agent_name) + self.action_space = \ + self.sim.action_space(self.attacker_agent_name) super().__init__() def reset( @@ -38,8 +56,8 @@ def reset( super().reset(seed=seed, options=options) # TODO: params not used by method, find out if we need to send them - obs, info = self.sim.reset(seed=seed, options=options) - return obs[self.attacker_agent_id], info[self.attacker_agent_id] + obs, infos = self.sim.reset(seed=seed, options=options) + return obs[self.attacker_agent_name], infos[self.attacker_agent_name] def step( self, action: Any @@ -48,15 +66,17 @@ def step( # TODO: Add potential defender and give defender action if it exists actions = { - self.attacker_agent_id: action, + self.attacker_agent_name: action, } - obs, rewards, terminated, truncated, infos = self.sim.step(actions) + + obs, rew, term, trunc, infos = self.sim.step(actions) + return ( - obs[self.attacker_agent_id], - rewards[self.attacker_agent_id], - terminated[self.attacker_agent_id], - truncated[self.attacker_agent_id], - infos[self.attacker_agent_id], + obs[self.attacker_agent_name], + rew[self.attacker_agent_name], + term[self.attacker_agent_name], + trunc[self.attacker_agent_name], + infos[self.attacker_agent_name] ) def render(self): @@ -64,12 +84,11 @@ def render(self): @property def 
num_assets(self): - return self.sim.num_assets + return len(self.sim._index_to_asset_type) @property def num_step_names(self): - return self.sim.num_step_names - + return len(self.sim._index_to_step_name) class DefenderEnv(gym.Env): metadata = {'render_modes': []} @@ -78,59 +97,94 @@ def __init__(self, scenario_file, **kwargs) -> None: self.randomize = kwargs.pop('randomize_attacker_behavior', False) self.render_mode = kwargs.pop('render_mode', None) - self.sim, conf = create_simulator_from_scenario(scenario_file, **kwargs) - - # Select first attacker and first defender for the simulation - # currently only one of each agent is supported - self.attacker_agent_id = list(self.sim.get_attacker_agents().keys())[0] - self.defender_agent_id = list(self.sim.get_defender_agents().keys())[0] - - self.attacker_class = conf['agents'][self.attacker_agent_id]['agent_class'] - self.attacker = self.attacker_class({}) - - self.observation_space = self.sim.observation_space(self.defender_agent_id) - self.action_space = self.sim.action_space(self.defender_agent_id) - - self.attacker_obs = None - self.attacker_mask = None + ag, agents = load_scenario(scenario_file) + + self.scenario_agents = agents + self.sim = MalSimVectorizedObsEnv(MalSimulator(ag), **kwargs) + + # Register attacker agents from scenario + self._register_attacker_agents(self.scenario_agents) + self.attacker_decision_agents = {} + + # Register defender agent + self.defender_agent_name = "DefenderEnvAgent" + self.sim.register_defender(self.defender_agent_name) + self.sim.reset() + + self.observation_space = \ + self.sim.observation_space(self.defender_agent_name) + self.action_space = \ + self.sim.action_space(self.defender_agent_name) + + def _register_attacker_agents(self, agents: list[dict]): + """Register attackers in simulator""" + for agent_info in agents: + if agent_info['type'] == AgentType.ATTACKER: + agent_name = agent_info['name'] + attacker_id = agent_info['attacker_id'] + self.sim.register_attacker(agent_name, attacker_id) + + def _create_attacker_decision_agents( + self, agents: list[dict], seed=None + ) -> dict[str, DecisionAgent]: + """Create decision agents for each attacker""" + + attacker_agents = {} + + for agent_info in agents: + if agent_info['type'] == AgentType.ATTACKER: + agent_name = agent_info['name'] + agent_class = agent_info.get('agent_class') + if agent_class: + attacker_agents[agent_name] = ( + agent_class( + {'seed': seed, 'randomize': self.randomize} + ) + ) + return attacker_agents def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None + self, *, + seed: int | None = None, + options: dict[str, Any] | None = None ) -> tuple[Any, dict[str, Any]]: - super().reset(seed=seed, options=options) - self.attacker = self.attacker_class({'seed': seed, 'randomize': self.randomize}) - - # TODO: params not used by method, find out if we need to send them - obs, info = self.sim.reset(seed=seed, options=options) - - self.attacker_obs = obs[self.attacker_agent_id] - self.attacker_mask = info[self.attacker_agent_id]['action_mask'] - - return obs[self.defender_agent_id], info[self.defender_agent_id] + super().reset(seed=seed, options=options) + self.attacker_decision_agents = self._create_attacker_decision_agents( + self.scenario_agents, seed=seed + ) + obs, infos = self.sim.reset(seed=seed, options=options) + return ( + obs[self.defender_agent_name], + infos[self.defender_agent_name] + ) def step( - self, action: Any + self, action: tuple[int, int] ) -> tuple[Any, SupportsFloat, bool, bool, 
dict[str, Any]]:
-        attacker_action = self.attacker.compute_action_from_dict(
-            self.attacker_obs, self.attacker_mask
-        )
-        actions = {
-            self.defender_agent_id: action,
-            self.attacker_agent_id: attacker_action,
-        }
-        obs, rewards, terminated, truncated, infos = self.sim.step(actions)
+        actions = {}
+        actions[self.defender_agent_name] = action
+
+        # Get actions from scenario attackers
+        for agent_name, decision_agent in self.attacker_decision_agents.items():
+            # Get next action from decision agent and put it in actions dict
+            attacker_state = self.sim.get_agent_state(agent_name)
+            attacker_action_node = decision_agent.get_next_action(attacker_state)
+            if attacker_action_node:
+                node_index = self.sim.node_to_index(attacker_action_node)
+                actions[agent_name] = (1, node_index)
 
-        self.attacker_obs = obs[self.attacker_agent_id]
-        self.attacker_mask = infos[self.attacker_agent_id]['action_mask']
+        # Perform step
+        obs, rewards, terminated, truncated, infos = \
+            self.sim.step(actions)
 
         return (
-            obs[self.defender_agent_id],
-            rewards[self.defender_agent_id],
-            terminated[self.defender_agent_id],
-            truncated[self.defender_agent_id],
-            infos[self.defender_agent_id],
+            obs[self.defender_agent_name],
+            rewards[self.defender_agent_name],
+            terminated[self.defender_agent_name],
+            truncated[self.defender_agent_name],
+            infos[self.defender_agent_name],
         )
 
     def render(self):
@@ -149,11 +203,11 @@ def add_reverse_edges(edges: np.ndarray, defense_steps: set) -> np.ndarray:
 
     @property
     def num_assets(self):
-        return len(self.sim.unwrapped._index_to_asset_type)
+        return len(self.sim._index_to_asset_type)
 
     @property
    def num_step_names(self):
-        return len(self.sim.unwrapped._index_to_step_name)
+        return len(self.sim._index_to_step_name)
 
 
 def _to_binary(val, max_val):
diff --git a/malsim/envs/malsim_vectorized_obs_env.py b/malsim/envs/malsim_vectorized_obs_env.py
new file mode 100644
index 00000000..0f8952ca
--- /dev/null
+++ b/malsim/envs/malsim_vectorized_obs_env.py
@@ -0,0 +1,937 @@
+"""MalSimVectorizedObsEnv:
+    - Abides by the ParallelEnv interface
+    - Builds serialized observations from the MalSimulator state
+    - step() takes serialized actions and converts them to AttackGraphNodes
+    - Used by AttackerEnv/DefenderEnv to expose gym-style observations and spaces
+"""
+
+from __future__ import annotations
+
+import functools
+import logging
+import sys
+
+import numpy as np
+from gymnasium.spaces import MultiDiscrete, Box, Dict
+from pettingzoo import ParallelEnv
+from maltoolbox.attackgraph import AttackGraphNode
+
+from ..mal_simulator import (
+    MalSimulator,
+    AgentType,
+    MalSimAgentStateView,
+    MalSimAttackerState,
+    MalSimDefenderState
+)
+
+from .base_classes import MalSimEnv
+
+ITERATIONS_LIMIT = int(1e9)
+logger = logging.getLogger(__name__)
+
+# First the logging methods:
+
+def format_full_observation(sim, observation):
+    """
+    Return a formatted string of the entire observation. This includes
+    sections that will not change over time; these define the structure of
+    the attack graph.
+ """ + obs_str = '\nAttack Graph Steps\n' + + str_format = "{:<5} {:<80} {:<6} {:<5} {:<5} {:<30} {:<8} {:<}\n" + header_entry = [ + "Entry", "Name", "Is_Obs", "State", + "RTTC", "Asset Type(Index)", "Asset Id", "Step" + ] + entries = [] + for entry in range(0, len(observation["observed_state"])): + asset_type_index = observation["asset_type"][entry] + asset_type_str = sim._index_to_asset_type[asset_type_index ] + \ + '(' + str(asset_type_index) + ')' + entries.append( + [ + entry, + sim._index_to_full_name[entry], + observation["is_observable"][entry], + observation["observed_state"][entry], + observation["remaining_ttc"][entry], + asset_type_str, + observation["asset_id"][entry], + observation["step_name"][entry], + ] + ) + obs_str += format_table( + str_format, header_entry, entries, reprint_header = 30 + ) + + obs_str += "\nAttack Graph Edges:\n" + for edge in observation["attack_graph_edges"]: + obs_str += str(edge) + "\n" + + obs_str += "\nInstance Model Assets:\n" + str_format = "{:<5} {:<5} {:<}\n" + header_entry = [ + "Entry", "Id", "Type(Index)"] + entries = [] + for entry in range(0, len(observation["model_asset_id"])): + asset_type_str = sim._index_to_asset_type[ + observation["model_asset_type"][entry]] + \ + '(' + str(observation["model_asset_type"][entry]) + ')' + entries.append( + [ + entry, + observation["model_asset_id"][entry], + asset_type_str + ] + ) + obs_str += format_table( + str_format, header_entry, entries, reprint_header = 30 + ) + + obs_str += "\nInstance Model Edges:\n" + str_format = "{:<5} {:<40} {:<40} {:<}\n" + header_entry = [ + "Entry", + "Left Asset(Id/Index)", + "Right Asset(Id/Index)", + "Type(Index)" + ] + entries = [] + for entry in range(0, len(observation["model_edges_ids"])): + assoc_type_str = sim._index_to_model_assoc_type[ + observation["model_edges_type"][entry]] + \ + '(' + str(observation["model_edges_type"][entry]) + ')' + left_asset_index = int(observation["model_edges_ids"][entry][0]) + right_asset_index = int(observation["model_edges_ids"][entry][1]) + left_asset_id = sim._index_to_model_asset_id[left_asset_index] + right_asset_id = sim._index_to_model_asset_id[right_asset_index] + left_asset_str = \ + sim.model.get_asset_by_id(left_asset_id).name + \ + '(' + str(left_asset_id) + '/' + str(left_asset_index) + ')' + right_asset_str = \ + sim.model.get_asset_by_id(right_asset_id).name + \ + '(' + str(right_asset_id) + '/' + str(right_asset_index) + ')' + entries.append( + [ + entry, + left_asset_str, + right_asset_str, + assoc_type_str + ] + ) + obs_str += format_table( + str_format, header_entry, entries, reprint_header = 30 + ) + + return obs_str + +def format_obs_var_sec( + sim, + observation, + included_values = [-1, 0, 1] + ): + """ + Return a formatted string of the sections of the observation that can + vary over time. 
+ + Arguments: + observation - the observation to format + included_values - the values to list, any values not present in the + list will be filtered out + """ + + str_format = "{:>5} {:>80} {:<5} {:<5} {:<}\n" + header_entry = ["Id", "Name", "State", "RTTC", "Entry"] + entries = [] + for entry in range(0, len(observation["observed_state"])): + if observation["is_observable"][entry] and \ + observation["observed_state"][entry] in included_values: + entries.append( + [ + sim._index_to_id[entry], + sim._index_to_full_name[entry], + observation["observed_state"][entry], + observation["remaining_ttc"][entry], + entry + ] + ) + + obs_str = format_table( + str_format, header_entry, entries, reprint_header = 30 + ) + + return obs_str + +def format_info(sim, info): + can_act = "Yes" if info["action_mask"][0][1] > 0 else "No" + agent_info_str = f"Can act? {can_act}\n" + for entry in range(0, len(info["action_mask"][1])): + if info["action_mask"][1][entry] == 1: + agent_info_str += f"{sim._index_to_id[entry]} " \ + f"{sim._index_to_full_name[entry]}\n" + return agent_info_str + + +def log_mapping_tables(logger, sim): + """Log all mapping tables in MalSimulator""" + + str_format = "{:<5} {:<15} {:<}\n" + table = "\n" + header_entry = ["Index", "Attack Step Id", "Attack Step Full Name"] + entries = [] + for entry in sim._index_to_id: + entries.append( + [ + sim._id_to_index[entry], + entry, + sim._index_to_full_name[sim._id_to_index[entry]] + ] + ) + table += format_table( + str_format, + header_entry, + entries, + reprint_header = 30 + ) + logger.debug(table) + + str_format = "{:<5} {:<}\n" + table = "\n" + header_entry = ["Index", "Asset Id"] + entries = [] + for entry in sim._model_asset_id_to_index: + entries.append( + [ + sim._model_asset_id_to_index[entry], + entry + ] + ) + table += format_table( + str_format, + header_entry, + entries, + reprint_header = 30 + ) + logger.debug(table) + + str_format = "{:<5} {:<}\n" + table = "\n" + header_entry = ["Index", "Asset Type"] + entries = [] + for entry in sim._asset_type_to_index: + entries.append( + [ + sim._asset_type_to_index[entry], + entry + ] + ) + table += format_table( + str_format, + header_entry, + entries, + reprint_header = 30 + ) + logger.debug(table) + + str_format = "{:<5} {:<}\n" + table = "\n" + header_entry = ["Index", "Attack Step Name"] + entries = [] + for entry in sim._index_to_step_name: + entries.append([sim._step_name_to_index[entry], entry]) + table += format_table( + str_format, + header_entry, + entries, + reprint_header = 30 + ) + logger.debug(table) + + str_format = "{:<5} {:<}\n" + table = "\n" + header_entry = ["Index", "Association Type"] + entries = [] + for entry in sim._index_to_model_assoc_type: + entries.append([sim._model_assoc_type_to_index[entry], entry]) + table += format_table( + str_format, + header_entry, + entries, + reprint_header = 30 + ) + logger.debug(table) + + +def format_table( + entry_format: str, + header_entry: list[str], + entries: list[list[str]], + reprint_header: int = 0 + ) -> str: + """ + Format a table according to the parameters specified. + + Arguments: + entry_format - The string format for the table + reprint_header - How many rows apart to reprint the header. If 0 the + header will not be reprinted. + header_entry - The entry representing the header of the table + entries - The list of entries to format + + Return: + The formatted table. 
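+
+    Example:
+        format_table("{:<5} {:<}\n", ["Id", "Name"], [[1, "Host"]])
+        returns "Id    Name\n1     Host\n".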
+ """ + + formatted_str = '' + header = entry_format.format(*header_entry) + formatted_str += header + for entry_nr, entry in zip(range(0, len(entries)), entries): + formatted_str += entry_format.format(*entry) + if (reprint_header != 0) and ((entry_nr + 1) % reprint_header == 0): + formatted_str += header + return formatted_str + + +def log_agent_state( + logger, sim, agent, terminations, truncations, infos + ): + """Debug log all an agents current state""" + + agent_obs_str = format_obs_var_sec( + sim, agent.observation, included_values = [0, 1] + ) + + logger.debug( + 'Observation for agent "%s":\n%s', agent.name, agent_obs_str) + logger.debug( + 'Rewards for agent "%s": %d', agent.name, agent.reward) + logger.debug( + 'Termination for agent "%s": %s', + agent.name, terminations[agent.name]) + logger.debug( + 'Truncation for agent "%s": %s', + agent.name, str(truncations[agent.name])) + agent_info_str = format_info(sim, infos[agent.name]) + logger.debug( + 'Info for agent "%s":\n%s', agent.name, agent_info_str) + + +# Now the actual class: + +class MalSimVectorizedObsEnv(ParallelEnv, MalSimEnv): + """ + Environment that runs simulation between agents. + Builds serialized observations. + Implements the ParallelEnv. + """ + + def __init__( + self, + sim: MalSimulator + ): + + super().__init__(sim) + + # Useful instead of having to fetch .sim.attack_graph + self.attack_graph = sim.attack_graph + + # List mapping from node/asset index to id/name/type + self._index_to_id = [n.id for n in self.attack_graph.nodes.values()] + self._index_to_full_name = ( + [n.full_name for n in self.attack_graph.nodes.values()] + ) + self._index_to_asset_type = ( + [n.name for n in self.attack_graph.lang_graph.assets.values()] + ) + + unique_step_type_names = { + n.full_name + for asset in self.attack_graph.lang_graph.assets.values() + for n in asset.attack_steps.values() + } + self._index_to_step_name = list(unique_step_type_names) + + self._index_to_model_asset_id = ( + [int(asset_id) for asset_id in self.attack_graph.model.assets] + ) + + unique_assoc_type_names = { + assoc.full_name + for asset in self.attack_graph.lang_graph.assets.values() + for assoc in asset.associations.values() + } + self._index_to_model_assoc_type = list(unique_assoc_type_names) + + # Lookup dicts attribute to index + self._id_to_index = { + n: i for i, n in enumerate(self._index_to_id)} + self._asset_type_to_index = { + n: i for i, n in enumerate(self._index_to_asset_type)} + self._step_name_to_index = { + n: i for i, n in enumerate(self._index_to_step_name) + } + self._model_asset_id_to_index = { + asset: i for i, asset in enumerate(self._index_to_model_asset_id) + } + self._model_assoc_type_to_index = { + assoc_type: i for i, assoc_type in + enumerate(self._index_to_model_assoc_type) + } + + if logger.isEnabledFor(logging.DEBUG): + log_mapping_tables(logger, self) + + self._blank_observation = self._create_blank_observation() + + self._agent_observations = {} + self._agent_infos = {} + + self.reset() + + @property + def agents(self): + """Required by ParallelEnv""" + return list(self.sim._alive_agents) + + @property + def possible_agents(self): + """Required by ParallelEnv""" + return list(self.sim._agents) + + def _create_blank_observation(self, default_obs_state=-1): + """Create the initial observation""" + # For now, an `object` is an attack step + num_steps = len(self.sim.attack_graph.nodes) + + observation = { + # If no observability set for node, assume observable. 
+ "is_observable": [step.extras.get('observable', 1) + for step in self.attack_graph.nodes.values()], + # Same goes for actionable. + "is_actionable": [step.extras.get('actionable', 1) + for step in self.attack_graph.nodes.values()], + "observed_state": num_steps * [default_obs_state], + "remaining_ttc": num_steps * [0], + "asset_type": [ + self._asset_type_to_index[step.lg_attack_step.asset.name] + for step in self.attack_graph.nodes.values()], + "asset_id": [step.model_asset.id + for step in self.attack_graph.nodes.values()], + "step_name": [ + self._step_name_to_index.get( + str(step.lg_attack_step.asset.name + ":" + step.name) + ) for step in self.attack_graph.nodes.values()], + } + + logger.debug( + 'Create blank observation with %d attack steps.', num_steps) + + # Add attack graph edges to observation + observation["attack_graph_edges"] = [] + for attack_step in self.attack_graph.nodes.values(): + # For determinism we need to order the children + ordered_children = list(attack_step.children) + ordered_children.sort(key=lambda n: n.id) + for child in ordered_children: + observation["attack_graph_edges"].append( + [ + self._id_to_index[attack_step.id], + self._id_to_index[child.id] + ] + ) + + # Add reverse attack graph edges for defense steps (required by some + # defender agent logic) + for attack_step in self.attack_graph.nodes.values(): + if attack_step.type == "defense": + # For determinism we need to order the children + ordered_children = list(attack_step.children) + ordered_children.sort(key=lambda n: n.id) + for child in ordered_children: + observation["attack_graph_edges"].append( + [ + self._id_to_index[child.id], + self._id_to_index[attack_step.id] + ] + ) + + # Add instance model assets + observation["model_asset_id"] = [] + observation["model_asset_type"] = [] + observation["model_edges_ids"] = [] + observation["model_edges_type"] = [] + + for asset in self.attack_graph.model.assets.values(): + observation["model_asset_id"].append(asset.id) + observation["model_asset_type"].append( + self._asset_type_to_index[asset.type]) + + for fieldname, other_assets in asset.associated_assets.items(): + for other_asset in other_assets: + observation["model_edges_ids"].append( + [ + self._model_asset_id_to_index[asset.id], + self._model_asset_id_to_index[other_asset.id] + ] + ) + + lg_assoc = asset.lg_asset.associations[fieldname] + observation["model_edges_type"].append( + self._model_assoc_type_to_index[lg_assoc.full_name] + ) + + np_obs = { + "is_observable": np.array(observation["is_observable"], + dtype=np.int8), + "is_actionable": np.array(observation["is_actionable"], + dtype=np.int8), + "observed_state": np.array(observation["observed_state"], + dtype=np.int8), + "remaining_ttc": np.array(observation["remaining_ttc"], + dtype=np.int64), + "asset_type": np.array(observation["asset_type"], dtype=np.int64), + "asset_id": np.array(observation["asset_id"], dtype=np.int64), + "step_name": np.array(observation["step_name"], dtype=np.int64), + "attack_graph_edges": np.array(observation["attack_graph_edges"], + dtype=np.int64), + "model_asset_id": np.array(observation["model_asset_id"], + dtype=np.int64), + "model_asset_type": np.array(observation["model_asset_type"], + dtype=np.int64), + "model_edges_ids": np.array(observation["model_edges_ids"], + dtype=np.int64), + "model_edges_type": np.array(observation["model_edges_type"], + dtype=np.int64) + } + + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + format_full_observation(self, np_obs) + ) + + return np_obs + + def 
create_action_mask(self, agent: MalSimAgentStateView): + """ + Create an action mask for an agent based on its action_surface. + + Parameters: + agent: The agent for whom the mask is created. + + Returns: + A dictionary with the action mask for the agent. + """ + + available_actions = [0] * len(self.sim.attack_graph.nodes) + can_wait = 1 if agent.type == AgentType.DEFENDER else 0 + can_act = 0 + + for node in agent.action_surface: + + if agent.type == AgentType.DEFENDER: + # Defender can act on its whole action surface + index = self._id_to_index[node.id] + available_actions[index] = 1 + can_act = 1 + + if agent.type == AgentType.ATTACKER: + # Attacker can only act on nodes that are not compromised + attacker = \ + self.sim.attack_graph.attackers[agent.attacker_id] + if not node.is_compromised_by(attacker): + index = self._id_to_index[node.id] + available_actions[index] = 1 + can_act = 1 + + return { + 'action_mask': ( + np.array([can_wait, can_act], dtype=np.int8), + np.array(available_actions, dtype=np.int8) + ) + } + + def _update_agent_infos(self): + for agent in self.sim.agent_states.values(): + self._agent_infos[agent.name] = self.create_action_mask(agent) + + def _get_association_full_name(self, association) -> str: + """Get association full name + + TODO: Remove this method once the language graph integration is + complete in the mal-toolbox because the language graph associations + will use their full names for the name property + + Arguments: + association - the association whose full name will be returned + + Return: + A string containing the association name and the name of each of the + two asset types for the left and right fields separated by + underscores. + """ + + assoc_name = association.__class__.__name__ + if '_' in assoc_name: + # TODO: Not actually a to-do, but just an extra clarification that + # this is an ugly hack that will work for now until we get the + # unique association names. Right now some associations already + # use the asset types as part of their name if there are multiple + # associations with the same name. + return assoc_name + + left_field_name, right_field_name = \ + self.sim.attack_graph.model.get_association_field_names(association) + left_field = getattr(association, left_field_name) + right_field = getattr(association, right_field_name) + lang_assoc = self.sim.attack_graph.lang_graph.get_association_by_fields_and_assets( + left_field_name, + right_field_name, + left_field[0].type, + right_field[0].type + ) + if lang_assoc is None: + raise LookupError('Failed to find association for fields ' + '"%s" "%s" and asset types "%s" "%s"!' 
% ( + left_field_name, + right_field_name, + left_field[0].type, + right_field[0].type + ) + ) + assoc_full_name = lang_assoc.name + '_' + \ + lang_assoc.left_field.asset.name + '_' + \ + lang_assoc.right_field.asset.name + return assoc_full_name + + @functools.lru_cache(maxsize=None) + def action_space(self, agent=None): + num_actions = 2 # two actions: wait or use + # For now, an `object` is an attack step + num_steps = len(self.sim.attack_graph.nodes) + return MultiDiscrete([num_actions, num_steps], dtype=np.int64) + + @functools.lru_cache(maxsize=None) + def observation_space(self, agent_name: str = None): + # For now, an `object` is an attack step + num_assets = len(self.attack_graph.model.assets) + num_steps = len(self.attack_graph.nodes) + num_lang_asset_types = len(self.sim.attack_graph.lang_graph.assets) + + unique_step_types = set() + for asset_type in self.sim.attack_graph.lang_graph.assets.values(): + unique_step_types |= set(asset_type.attack_steps.values()) + num_lang_attack_steps = len(unique_step_types) + + unique_assoc_type_names = set() + for asset_type in self.sim.attack_graph.lang_graph.assets.values(): + for assoc_type in asset_type.associations.values(): + unique_assoc_type_names.add( + assoc_type.full_name + ) + num_lang_association_types = len(unique_assoc_type_names) + + num_attack_graph_edges = len( + self._blank_observation["attack_graph_edges"]) + num_model_edges = len( + self._blank_observation["model_edges_ids"]) + + return Dict( + { + "is_observable": Box( + 0, 1, shape=(num_steps,), dtype=np.int8 + ), # 0 for unobservable, 1 for observable + "is_actionable": Box( + 0, 1, shape=(num_steps,), dtype=np.int8 + ), # 0 for non-actionable, 1 for actionable + "observed_state": Box( + -1, 1, shape=(num_steps,), dtype=np.int8 + ), # -1 for unknown, + # 0 for disabled/not compromised, + # 1 for enabled/compromised + "remaining_ttc": Box( + 0, sys.maxsize, shape=(num_steps,), dtype=np.int64 + ), # remaining TTC + "asset_type": Box( + 0, + num_lang_asset_types, + shape=(num_steps,), + dtype=np.int64, + ), # asset type + "asset_id": Box( + 0, sys.maxsize, shape=(num_steps,), dtype=np.int64 + ), # asset id + "step_name": Box( + 0, + num_lang_attack_steps, + shape=(num_steps,), + dtype=np.int64, + ), # attack/defense step name + "attack_graph_edges": Box( + 0, + num_steps, + shape=(num_attack_graph_edges, 2), + dtype=np.int64, + ), # edges between attack graph steps + "model_asset_id": Box( + 0, + num_assets, + shape=(num_assets,), + dtype=np.int64, + ), # instance model asset ids + "model_asset_type": Box( + 0, + num_lang_asset_types, + shape=(num_assets,), + dtype=np.int64, + ), # instance model asset types + "model_edges_ids": Box( + 0, + num_assets, + shape=(num_model_edges, 2), + dtype=np.int64, + ), # instance model edge ids + "model_edges_type": Box( + 0, + num_lang_association_types, + shape=(num_model_edges, ), + dtype=np.int64, + ), # instance model edge types + } + ) + + def index_to_node(self, index: int) -> AttackGraphNode: + """Get a node from the attack graph by index + + Index is the position of the node in the lookup list. + First convert index to id and then fetch the node from the + AttackGraph. + + Raise LookupError if node with given index does not map to a node in + the attack graph and IndexError if the index is out of range for the + lookup list. 
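+
+        Example:
+            index_to_node(0) returns the node whose id appears first in
+            the internal lookup list (self._index_to_id).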
+ + Returns: + Attack graph node matching the id of the index in the lookup list + """ + + if index >= len(self._index_to_id): + raise IndexError( + f'Index {index}, is out of range of the ' + f'lookup list which is of length {len(self._index_to_id)}' + ) + + node_id = self._index_to_id[index] + node = self.sim.attack_graph.nodes[node_id] + if not node: + raise LookupError( + f'Index {index} (id: {node_id}), does not map to a node' + ) + return node + + def node_to_index(self, node: AttackGraphNode) -> int: + """Get the index of an attack graph node + + Returns: + Index of the attack graph node in the lookup list + """ + + assert node, "Node can not be None" + return self._id_to_index[node.id] + + def serialized_action_to_node( + self, serialized_action: tuple[int, int] + ) -> list[AttackGraphNode]: + """Convert serialized action to malsim action format + + (0, None) -> [] + (1, idx) -> [Node with idx] + + Currently supports single action only. + """ + nodes = [] + act, step_idx = serialized_action + if act: + nodes = [self.index_to_node(step_idx)] + return nodes + + def register_attacker(self, attacker_name: str, attacker_id: int): + super().register_attacker(attacker_name, attacker_id) + agent = self.sim.agent_states[attacker_name] + self._init_agent(agent) + + def register_defender(self, defender_name: str): + super().register_defender(defender_name) + agent = self.sim.agent_states[defender_name] + self._init_agent(agent) + + def _init_agent(self, agent: MalSimAgentStateView): + # Fill dicts with env specific agent obs/infos + self._agent_observations[agent.name] = \ + self._create_blank_observation() + + self._agent_infos[agent.name] = \ + self.create_action_mask(agent) + + def _update_attacker_obs( + self, + compromised_nodes, + disabled_nodes, + attacker_agent: MalSimAttackerState + ): + """Update the observation of the serialized obs attacker""" + + def _enable_node( + node: AttackGraphNode, agent_observation: dict + ): + """Set enabled node obs state to enabled and + its children to disabled""" + + # Mark enabled node obs state with 1 (enabled) + node_index = self._id_to_index[node.id] + agent_observation['observed_state'][node_index] = 1 + + # Mark unknown (-1) children node obs states with 0 (disabled) + for child_node in node.children: + child_index = self._id_to_index[child_node.id] + child_obs = agent_observation['observed_state'][child_index] + if child_obs == -1: + agent_observation['observed_state'][child_index] = 0 + + attacker = ( + self.sim.attack_graph.attackers[attacker_agent.attacker_id] + ) + attacker_observation = self._agent_observations[attacker_agent.name] + + for node in compromised_nodes: + if node.is_compromised_by(attacker): + # Enable node + logger.debug("Enable %s in attacker obs", node.full_name) + _enable_node(node, attacker_observation) + + for node in disabled_nodes: + is_entrypoint = node.extras.get('entrypoint', False) + if node.is_compromised_by(attacker) and not is_entrypoint: + logger.debug("Disable %s in attacker obs", node.full_name) + # Mark attacker compromised steps that were + # disabled by a defense as disabled in obs + node_idx = self.node_to_index(node) + attacker_observation['observed_state'][node_idx] = 0 + + def _update_defender_obs( + self, + compromised_nodes: list[AttackGraphNode], + disabled_nodes: list[AttackGraphNode], + defender_agent: MalSimDefenderState + ): + """Update the observation of the defender""" + + defender_observation = self._agent_observations[defender_agent.name] + + for node in compromised_nodes: + 
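+            # Defenders observe every node that any attacker compromised
+            # during this step, regardless of which attacker performed it.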
logger.debug("Enable %s in defender obs", node.full_name) + node_idx = self.node_to_index(node) + defender_observation['observed_state'][node_idx] = 1 + + for node in disabled_nodes: + is_entrypoint = node.extras.get('entrypoint', False) + if not is_entrypoint: + logger.debug("Disable %s in defender obs", node.full_name) + node_idx = self.node_to_index(node) + defender_observation['observed_state'][node_idx] = 0 + + def reset( + self, + seed: int | None = None, + options: dict | None = None + ) -> tuple[dict, dict]: + """Reset simulator and return current + observation and infos for each agent""" + + MalSimEnv.reset(self, seed, options) + + self.attack_graph = self.sim.attack_graph # new ref + + for agent in self.sim.agent_states.values(): + # Reset observation and action mask for agents + self._agent_observations[agent.name] = \ + self._create_blank_observation() + self._agent_infos[agent.name] = \ + self.create_action_mask(agent) + + # Enable pre-enabled nodes in observation + attacker_entry_points = [ + n for n in self.sim.attack_graph.nodes.values() + if n.is_compromised() + ] + pre_enabled_defenses = [ + n for n in self.sim.attack_graph.nodes.values() + if n.defense_status == 1.0 + ] + + for node in attacker_entry_points: + node.extras['entrypoint'] = True + + self._update_observations( + attacker_entry_points + pre_enabled_defenses, [] + ) + + # TODO: should we return copies instead so they are not modified externally? + return self._agent_observations, self._agent_infos + + def _update_observations(self, compromised_nodes, disabled_nodes): + """Update observations of all agents""" + + if not self.sim.sim_settings.uncompromise_untraversable_steps: + disabled_nodes = [] + + # TODO: Is this correct? All attackers get the same compromised_nodes? 
+ logger.debug("Enable:\n\t%s", [n.full_name for n in compromised_nodes]) + logger.debug("Disable:\n\t%s", [n.full_name for n in disabled_nodes]) + + for agent in self.sim.agent_states.values(): + if agent.type == AgentType.ATTACKER: + self._update_attacker_obs( + compromised_nodes, disabled_nodes, agent + ) + elif agent.type == AgentType.DEFENDER: + self._update_defender_obs( + compromised_nodes, disabled_nodes, agent + ) + + def step(self, actions: dict[str, tuple[int, int]]): + """Perform step with mal simulator and observe in parallel env""" + + malsim_actions = {} + for agent_name, agent_action in actions.items(): + malsim_actions[agent_name] = [] + if agent_action[0]: + # If agent wants to act, convert index to node + malsim_actions[agent_name].append( + self.index_to_node(agent_action[1]) + ) + + states = self.sim.step(malsim_actions) + + all_actioned = [ + n + for state in states.values() + for n in state.step_performed_nodes + ] + disabled_nodes = next(iter(states.values())).step_unviable_nodes + + self._update_agent_infos() # Update action masks + self._update_observations(all_actioned, disabled_nodes) + + observations = self._agent_observations + rewards = {} + terminations = {} + truncations = {} + infos = self._agent_infos + + for agent in self.sim.agent_states.values(): + rewards[agent.name] = agent.reward + terminations[agent.name] = agent.terminated + truncations[agent.name] = agent.truncated + + return observations, rewards, terminations, truncations, infos diff --git a/malsim/mal_simulator.py b/malsim/mal_simulator.py new file mode 100644 index 00000000..0db01071 --- /dev/null +++ b/malsim/mal_simulator.py @@ -0,0 +1,520 @@ +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +import logging +from enum import Enum +from types import MappingProxyType +from typing import Any, Optional + +from maltoolbox import neo4j_configs +from maltoolbox.ingestors import neo4j +from maltoolbox.attackgraph import AttackGraph, AttackGraphNode, query +from maltoolbox.attackgraph.analyzers import apriori + +ITERATIONS_LIMIT = int(1e9) +logger = logging.getLogger(__name__) + + +class AgentType(Enum): + """Enum for agent types""" + ATTACKER = 'attacker' + DEFENDER = 'defender' + + +@dataclass +class MalSimAgentState: + """Stores the state of an agent in the simulator""" + + # Identifier of the agent, used in MalSimulator for lookup + name: str + type: AgentType + + # Contains current agent reward in the simulation + # Attackers get positive rewards, defenders negative + reward: int = 0 + + # Contains the steps performed successfully in the last step + step_performed_nodes: set[AttackGraphNode] = field(default_factory=set) + + # Contains possible actions for the agent in the next step + action_surface: set[AttackGraphNode] = field(default_factory=set) + + # Contains possible actions that became available in the last step + step_action_surface_additions: set[AttackGraphNode] = ( + field(default_factory = set)) + + # Contains previously possible actions that became unavailable in the last + # step + step_action_surface_removals: set[AttackGraphNode] = ( + field(default_factory = set)) + + # Contains nodes that defender actions made unviable in the last step + step_unviable_nodes: set[AttackGraphNode] = field(default_factory=set) + + # Fields that tell if the agent is done or stopped + truncated: bool = False + terminated: bool = False + + +class MalSimAttackerState(MalSimAgentState): + """Stores the state of an attacker in the simulator""" + + def 
+        super().__init__(name, AgentType.ATTACKER)
+        self.attacker_id = attacker_id
+
+
+class MalSimDefenderState(MalSimAgentState):
+    """Stores the state of a defender in the simulator"""
+
+    # Contains the steps performed successfully by all of the attacker agents
+    # in the last step
+    step_all_compromised_nodes: set[AttackGraphNode]
+
+    def __init__(self, name: str):
+        super().__init__(name, AgentType.DEFENDER)
+        # Use a per-instance set; a class-level `= set()` default
+        # would be shared between all defender states
+        self.step_all_compromised_nodes = set()
+
+
+class MalSimAgentStateView:
+    """Read-only interface to MalSimAgentState."""
+
+    _frozen = False
+
+    def __init__(self, agent):
+        self._agent = agent
+        self._frozen = True
+
+    def __setattr__(self, key, value) -> None:
+        if self._frozen:
+            raise AttributeError("Cannot modify agent state view")
+        self.__dict__[key] = value
+
+    def __delattr__(self, key) -> None:
+        if self._frozen:
+            raise AttributeError("Cannot modify agent state view")
+        super().__delattr__(key)
+
+    def __getattr__(self, attr) -> Any:
+        """Return read-only version of proxied agent's properties."""
+        value = getattr(self._agent, attr)
+
+        if isinstance(value, dict):
+            return MappingProxyType(value)
+        if isinstance(value, list):
+            return tuple(value)
+        if isinstance(value, set):
+            return frozenset(value)
+
+        return value
+
+    def __dir__(self):
+        """Dynamically resolve attribute names for REPL autocompletion."""
+        return list(vars(self._agent).keys()) + ["_agent", "_frozen"]
+
+
+@dataclass
+class MalSimulatorSettings():
+    """Contains settings used in MalSimulator"""
+
+    # uncompromise_untraversable_steps
+    # - Uncompromise (evict attacker) from nodes/steps that are no longer
+    #   traversable (often because a defense kicked in) if set to True
+    # otherwise:
+    # - Leave the node/step compromised even after it becomes untraversable
+    uncompromise_untraversable_steps: bool = False
+
+
+class MalSimulator():
+    """A MAL Simulator that works on the AttackGraph
+
+    Allows the user to register agents (attackers and defenders),
+    lets the agents perform actions step by step, and updates the
+    state of the attack graph based on the chosen steps.
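+
+    A minimal usage sketch (the agent names and the attacker id are
+    illustrative; the id must refer to an attacker in the attack graph):
+
+        sim = MalSimulator(attack_graph)
+        sim.register_attacker('Attacker1', attacker_id=0)
+        sim.register_defender('Defender1')
+        states = sim.reset()
+        states = sim.step({'Attacker1': [some_node], 'Defender1': []})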
+ """ + + def __init__( + self, + attack_graph: AttackGraph, + prune_unviable_unnecessary: bool = True, + sim_settings: MalSimulatorSettings = MalSimulatorSettings(), + max_iter=ITERATIONS_LIMIT, + ): + """ + Args: + attack_graph - The attack graph to use + max_iter - Max iterations in simulation + prune_unviable_unnecessary - Prunes graph if set to true + sim_settings - Settings for simulator + """ + logger.info("Creating Base MAL Simulator.") + + # Calculate viability and necessity and optionally prune graph + apriori.calculate_viability_and_necessity(attack_graph) + if prune_unviable_unnecessary: + apriori.prune_unviable_and_unnecessary_nodes(attack_graph) + + # Keep a backup attack graph to use when resetting + self.attack_graph_backup = copy.deepcopy(attack_graph) + + # Initialize all values + self.attack_graph = attack_graph + + self.sim_settings = sim_settings + self.max_iter = max_iter # Max iterations before stopping simulation + self.cur_iter = 0 # Keep track on current iteration + + # All internal agent states (dead or alive) + self._agent_states: dict[str, MalSimAgentState] = {} + + # Keep track on all 'living' agents sorted by order to step in + self._alive_agents: set[str] = set() + + def reset( + self, + seed: Optional[int] = None, + options: Optional[dict] = None + ) -> dict[str, MalSimAgentStateView]: + """Reset attack graph, iteration and reinitialize agents""" + + logger.info("Resetting MAL Simulator.") + # Reset attack graph + self.attack_graph = copy.deepcopy(self.attack_graph_backup) + # Reset current iteration + self.cur_iter = 0 + # Reset agents + self._reset_agents() + + return self.agent_states + + def _init_agent_rewards(self): + """Give rewards for pre-enabled attack/defense steps""" + + for agent in self._get_attacker_agents(): + attacker = self.attack_graph.attackers[agent.attacker_id] + agent.reward = sum( + n.extras.get("reward", 0) for n in attacker.reached_attack_steps + ) + + lost_reward = sum( + node.extras.get("reward", 0) + for node in self.attack_graph.nodes.values() + if node.is_compromised() or node.is_enabled_defense() + ) + + for agent in self._get_defender_agents(): + agent.reward = lost_reward + + def _init_agent_action_surfaces(self): + """Set agent action surfaces according to current state""" + for agent in self._agent_states.values(): + + if isinstance(agent, MalSimAttackerState): + attacker = self.attack_graph.attackers[agent.attacker_id] + agent.action_surface = query.calculate_attack_surface(attacker) + + elif isinstance(agent, MalSimDefenderState): + agent.action_surface = \ + query.get_defense_surface(self.attack_graph) + else: + raise LookupError(f"Agent type {agent.type} not supported") + + def _reset_agents(self): + """Reset agent rewards and action surfaces""" + + # Revive all agents + self._alive_agents = set(self._agent_states.keys()) + + for agent_state in self._get_attacker_agents(): + # Create a new agent state for the attacker + self._agent_states[agent_state.name] = ( + MalSimAttackerState(agent_state.name, agent_state.attacker_id) + ) + + for agent_state in self._get_defender_agents(): + # Create a new agent state for the attacker + self._agent_states[agent_state.name] = ( + MalSimDefenderState(agent_state.name) + ) + + self._init_agent_rewards() + self._init_agent_action_surfaces() + + def register_attacker(self, name: str, attacker_id: int): + """Register a mal sim attacker agent""" + assert name not in self._agent_states, \ + f"Duplicate agent named {name} not allowed" + + agent_state = MalSimAttackerState(name, 
+    def register_attacker(self, name: str, attacker_id: int):
+        """Register a mal sim attacker agent"""
+        assert name not in self._agent_states, \
+            f"Duplicate agent named {name} not allowed"
+
+        agent_state = MalSimAttackerState(name, attacker_id)
+        self._agent_states[name] = agent_state
+        self._alive_agents.add(name)
+
+    def register_defender(self, name: str):
+        """Register a mal sim defender agent"""
+        assert name not in self._agent_states, \
+            f"Duplicate agent named {name} not allowed"
+
+        agent_state = MalSimDefenderState(name)
+        self._agent_states[name] = agent_state
+        self._alive_agents.add(name)
+
+    @property
+    def agent_states(self) -> dict[str, MalSimAgentStateView]:
+        """Return read-only agent states for all dead and alive agents"""
+        return {
+            name: MalSimAgentStateView(agent)
+            for name, agent in self._agent_states.items()
+        }
+
+    def _get_attacker_agents(self) -> list[MalSimAttackerState]:
+        """Return list of mutable attacker agent states of alive attackers"""
+        return [
+            a for a in self._agent_states.values()
+            if a.name in self._alive_agents
+            and isinstance(a, MalSimAttackerState)
+        ]
+
+    def _get_defender_agents(self) -> list[MalSimDefenderState]:
+        """Return list of mutable defender agent states of alive defenders"""
+        return [
+            a for a in self._agent_states.values()
+            if a.name in self._alive_agents
+            and isinstance(a, MalSimDefenderState)
+        ]
+
+    def _uncompromise_attack_steps(
+        self, attack_steps_to_uncompromise: set[AttackGraphNode]
+    ):
+        """Uncompromise nodes for each attacker agent
+
+        Go through the nodes in `attack_steps_to_uncompromise` for each
+        attacker agent. If a node is compromised by the attacker agent:
+        - Uncompromise the node and remove rewards for it.
+        """
+        for attacker_agent in self._get_attacker_agents():
+            attacker = self.attack_graph.attackers[attacker_agent.attacker_id]
+
+            for unviable_node in attack_steps_to_uncompromise:
+                if unviable_node.is_compromised_by(attacker):
+
+                    # Reward is no longer present for attacker
+                    node_reward = unviable_node.extras.get('reward', 0)
+                    attacker_agent.reward -= node_reward
+
+                    # Reward is no longer present for defenders
+                    for defender_agent in self._get_defender_agents():
+                        defender_agent.reward += node_reward
+
+                    # Uncompromise node if requested
+                    attacker.undo_compromise(unviable_node)
+
+    def _attacker_step(
+        self, agent: MalSimAttackerState, nodes: list[AttackGraphNode]
+    ) -> set[AttackGraphNode]:
+        """Compromise attack step nodes with attacker
+
+        Args:
+            agent - the agent to compromise nodes with
+            nodes - the nodes to compromise
+
+        Returns: A set of attack step nodes that were compromised
+        """
+
+        compromised_nodes = set()
+        attacker = self.attack_graph.attackers[agent.attacker_id]
+
+        for node in nodes:
+            assert node == self.attack_graph.nodes[node.id], (
+                f"{agent.name} tried to compromise a node that is not part "
+                "of this simulator's attack_graph. Make sure the node "
+                "comes from the agent's action surface."
+            )
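+
+            # Nodes that are untraversable or outside the agent's action
+            # surface are skipped with a warning below instead of raising.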
+            logger.info(
+                'Attacker agent "%s" stepping through "%s"(%d).',
+                agent.name, node.full_name, node.id
+            )
+
+            # Compromise node if possible
+            if query.is_node_traversable_by_attacker(node, attacker) \
+                    and node in agent.action_surface:
+                attacker.compromise(node)
+                agent.reward += node.extras.get('reward', 0)
+                compromised_nodes.add(node)
+
+                logger.info(
+                    'Attacker agent "%s" compromised "%s"(%d).',
+                    agent.name, node.full_name, node.id
+                )
+            else:
+                logger.warning('Attacker agent "%s" could not compromise "%s"',
+                               agent.name, node.full_name)
+
+        # Update attacker action surface
+        attack_surface_additions = query.calculate_attack_surface(
+            attacker, from_nodes=compromised_nodes, skip_compromised=True
+        )
+
+        agent.step_action_surface_additions = attack_surface_additions
+        agent.action_surface |= attack_surface_additions
+        agent.step_performed_nodes = compromised_nodes
+
+        # Terminate attacker if it has nothing left to do
+        terminate = True
+        for node in agent.action_surface:
+            if not node.is_compromised_by(attacker):
+                terminate = False
+                break
+        agent.terminated = terminate
+
+        return compromised_nodes
+
+    def _defender_step(
+        self, agent: MalSimDefenderState, nodes: list[AttackGraphNode]
+    ):
+        """Enable defense step nodes with defender.
+
+        Args:
+            agent - the agent to activate defense nodes with
+            nodes - the defense step nodes to enable
+        """
+
+        enabled_defenses = set()
+        attack_steps_made_unviable = set()
+
+        for node in nodes:
+            assert node == self.attack_graph.nodes[node.id], (
+                f"{agent.name} tried to enable a node that is not part "
+                "of this simulator's attack_graph. Make sure the node "
+                "comes from the agent's action surface."
+            )
+            logger.info(
+                'Defender agent "%s" stepping through "%s"(%d).',
+                agent.name, node.full_name, node.id
+            )
+
+            if node not in agent.action_surface:
+                logger.warning(
+                    'Defender agent "%s" tried to step through "%s"(%d), '
+                    'which is not part of its defense surface. The action '
+                    'will be skipped.', agent.name, node.full_name, node.id
+                )
+                continue
+
+            # Enable defense if possible
+            if node.is_available_defense():
+                node.defense_status = 1.0
+                node.is_viable = False
+                attack_steps_made_unviable |= \
+                    apriori.propagate_viability_from_unviable_node(node)
+                agent.reward -= node.extras.get("reward", 0)
+                enabled_defenses.add(node)
+                logger.info(
+                    'Defender agent "%s" enabled "%s"(%d).',
+                    agent.name, node.full_name, node.id
+                )
+
+        agent.step_performed_nodes = enabled_defenses
+        agent.step_unviable_nodes |= attack_steps_made_unviable
+
+        for defender_agent in self._get_defender_agents():
+            # Remove enabled defenses from all defenders' action surfaces
+            defender_agent.step_action_surface_removals |= enabled_defenses
+            defender_agent.action_surface -= enabled_defenses
+
+        for attacker_agent in self._get_attacker_agents():
+            # Remove attack steps made unviable from all attackers'
+            # action surfaces if they were part of them
+            attacker_agent.step_action_surface_removals |= (
+                attacker_agent.action_surface & attack_steps_made_unviable
+            )
+            attacker_agent.action_surface -= attack_steps_made_unviable
+
+        if self.sim_settings.uncompromise_untraversable_steps:
+            self._uncompromise_attack_steps(attack_steps_made_unviable)
+
+    def step(
+        self, actions: dict[str, list[AttackGraphNode]]
+    ) -> dict[str, MalSimAgentStateView]:
+        """Take a step in the simulation
+
+        Args:
+            actions - a dict mapping each agent name to the list of
+                      nodes that the agent wants to act on this step
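+
+                      Example (agent name and node are illustrative):
+                          {'Attacker1': [node_a], 'Defender1': []}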
+
+        Returns:
+        - A dictionary containing the agent state views keyed by agent names
+        """
+        logger.info(
+            "Stepping through iteration %d/%d", self.cur_iter, self.max_iter
+        )
+        logger.debug("Performing actions: %s", actions)
+
+        # Populate these from the results for all agents' actions.
+        all_compromised = set()
+        unviable_nodes = set()
+        all_attackers_terminated = True
+
+        # Prepare agent states for new step
+        for agent_name in self._alive_agents:
+            agent = self._agent_states[agent_name]
+            # Clear action surface removals from previous step to
+            # make sure the old values do not carry over.
+            agent.step_action_surface_removals = set()
+            # All agents share the same set of unviable_nodes
+            agent.step_unviable_nodes = unviable_nodes
+
+        # Perform defender actions first
+        for agent in self._get_defender_agents():
+            agent_actions = actions.get(agent.name, [])
+            self._defender_step(agent, agent_actions)
+            # All defenders share the same set of compromised nodes,
+            # which is built from what the attackers do this step
+            agent.step_all_compromised_nodes = all_compromised
+
+        # Perform attacker actions afterwards
+        for agent in self._get_attacker_agents():
+            agent_actions = actions.get(agent.name, [])
+            self._attacker_step(agent, agent_actions)
+            all_compromised |= agent.step_performed_nodes
+
+            if not agent.terminated:
+                all_attackers_terminated = False
+
+        # Apply defenders' negative rewards from compromises this step
+        lost_rewards = sum(n.extras.get("reward", 0) for n in all_compromised)
+        for defender in self._get_defender_agents():
+            defender.reward -= lost_rewards
+
+        if self._alive_agents and all_attackers_terminated:
+            # Terminate all defenders if all attackers are terminated
+            logger.info("All attackers are terminated")
+            for agent in self._agent_states.values():
+                agent.terminated = True
+
+        if self.cur_iter >= self.max_iter:
+            # Truncate all agents when max iter is reached
+            logger.info("Max iteration reached - all agents truncated")
+            for agent in self._agent_states.values():
+                agent.truncated = True
+
+        for agent_name in self._alive_agents.copy():
+            agent = self._agent_states[agent_name]
+            if agent.terminated or agent.truncated:
+                logger.info("Removing agent %s", agent.name)
+                self._alive_agents.remove(agent_name)
+
+        self.cur_iter += 1
+
+        return self.agent_states
+
+    def render(self):
+        """Render attack graph from simulation in Neo4J"""
+        logger.debug("Sending attack graph to Neo4J database.")
+        neo4j.ingest_attack_graph(
+            self.attack_graph,
+            neo4j_configs["uri"],
+            neo4j_configs["username"],
+            neo4j_configs["password"],
+            neo4j_configs["dbname"],
+            delete=True,
+        )
diff --git a/malsim/scenario.py b/malsim/scenario.py
index 0ac4f8f1..f7612589 100644
--- a/malsim/scenario.py
+++ b/malsim/scenario.py
@@ -12,34 +12,44 @@
 """
 
 import os
-from typing import Optional
+from typing import Optional, Any
 
 import yaml
 
 from maltoolbox.attackgraph import AttackGraph, Attacker, create_attack_graph
 
-from .agents.searchers import BreadthFirstAttacker, DepthFirstAttacker
-from .agents.keyboard_input import KeyboardAgent
-from .sims.mal_simulator import MalSimulator
+from .agents import (
+    BreadthFirstAttacker,
+    DepthFirstAttacker,
+    KeyboardAgent,
+    PassiveAgent
+)
+
+from .mal_simulator import AgentType, MalSimulator
 
 agent_class_name_to_class = {
-    'BreadthFirstAttacker': BreadthFirstAttacker,
     'DepthFirstAttacker': DepthFirstAttacker,
-    'KeyboardAgent': KeyboardAgent
+    'BreadthFirstAttacker': BreadthFirstAttacker,
+    'KeyboardAgent': KeyboardAgent,
+    'PassiveAgent': PassiveAgent,
 }
 
+deprecated_fields = [
+    'attacker_agent_class',
+    'defender_agent_class',
+    'attacker_entry_points'
+]
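+
+# The deprecated fields above were replaced by the per-agent 'agents'
+# section; validate_scenario rejects them with a SyntaxError.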
 
 # All required fields in scenario yml file
 required_fields = [
+    'agents',
     'lang_file',
     'model_file',
-    'attacker_agent_class',
-    'defender_agent_class',
 ]
 
 # All allowed fields in scenario yml file
 allowed_fields = required_fields + [
     'rewards',
-    'attacker_entry_points',
     'observable_steps',
     'actionable_steps'
 ]
 
@@ -50,8 +60,11 @@ def validate_scenario(scenario_dict):
 
     # Verify that all keys in dict are supported
     for key in scenario_dict.keys():
+        if key in deprecated_fields:
+            raise SyntaxError(f"Scenario setting '{key}' is deprecated, see "
+                              "README or ./tests/testdata/scenarios")
         if key not in allowed_fields:
-            raise SyntaxError(f"The setting '{key}' is not supported")
+            raise SyntaxError(f"Scenario setting '{key}' is not supported")
 
     # Verify that all required fields are in scenario file
     for key in required_fields:
@@ -124,9 +137,11 @@ def _validate_scenario_property_rules(
             "observability/actionability rules"
         )
 
+    # TODO: revisit this variable once LookupDicts are merged
+    asset_names = set(a.name for a in graph.model.assets.values())
     for asset_name in rules.get('by_asset_name', []):
         # Make sure each specified asset exists
-        assert asset_name in graph.model.asset_names, (
+        assert asset_name in asset_names, (
             f"Failed to find asset name '{asset_name}' in model "
             f"'{graph.model.name}' when applying scenario"
            " observability/actionability rules")
@@ -185,68 +200,94 @@ def apply_scenario_node_property_rules(
             step.extras[node_prop] = 0
 
 
-def apply_scenario_attacker_entrypoints(
-    attack_graph: AttackGraph, entry_points: dict
-) -> None:
+def add_attacker_entrypoints(
+    attack_graph: AttackGraph, attacker_name: str, entry_points: list[str]
+) -> Attacker:
     """Apply attacker entry points from scenario to the attack graph
 
-    Go through attacker entry points from scenario file and add
-    them to the referenced attacker in the attack graph
+    Create the attacker, add entry points to it, and compromise them.
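+
+    Example (the attacker name and step full name are hypothetical):
+
+        add_attacker_entrypoints(
+            attack_graph, 'Attacker1', ['Host:0:access']
+        )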
 
     Args:
     - attack_graph: the attack graph to apply entry points to
-    - entry_points: the entry points to apply
+    - attacker_name: the name to give the attacker
+    - entry_points: the entry points to apply for the attacker
+
+    Returns:
+    - the Attacker with the relevant entry points
    """
-    for attacker_name, entry_point_names in entry_points.items():
-        attacker = Attacker(
-            attacker_name, entry_points=set(), reached_attack_steps=set()
-        )
-        attack_graph.add_attacker(attacker)
+    if not entry_points:
+        # Guard against scenarios that declare an attacker without entry
+        # points, which would otherwise fail with an unbound name below
+        raise ValueError(
+            f"Attacker '{attacker_name}' needs at least one entry point"
+        )
+
+    # Override attackers in attack graph / model if
+    # entry points are defined in scenario
+    all_attackers = list(attack_graph.attackers.values())
+    for attacker in all_attackers:
+        attack_graph.remove_attacker(attacker)
+
+    attacker = Attacker(attacker_name)
+    attack_graph.add_attacker(attacker)
 
-        for entry_point_name in entry_point_names:
-            entry_point = attack_graph.get_node_by_full_name(entry_point_name)
-            if not entry_point:
-                raise LookupError(f"Node {entry_point_name} does not exist")
-            attacker.compromise(entry_point)
+    for entry_point_name in entry_points:
+        entry_point = attack_graph.get_node_by_full_name(entry_point_name)
+        if not entry_point:
+            raise LookupError(f"Node {entry_point_name} does not exist")
+        attacker.compromise(entry_point)
 
-        attacker.entry_points = list(attacker.reached_attack_steps)
+    attacker.entry_points = attacker.reached_attack_steps.copy()
+    return attacker
 
 
-def load_scenario_simulation_config(scenario: dict) -> dict:
-    """Load configurations used in MALSimulator
 
-    Load parts of scenario are used for the MALSimulator
+
+def load_simulator_agents(
+    attack_graph: AttackGraph, scenario: dict
+    ) -> list[dict[str, Any]]:
+    """Load agents to be registered in MALSimulator
+
+    Create the agents from the specified classes and
+    register entry points for attackers.
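+
+    The scenario 'agents' section is expected to look roughly like
+    (agent and step names are hypothetical):
+
+        {'Attacker1': {'type': 'attacker',
+                       'agent_class': 'BreadthFirstAttacker',
+                       'entry_points': ['Host:0:access']}}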
 
     Args:
+    - attack_graph: the attack graph
     - scenario: the scenario in question as a dict
 
     Return:
-    - config: a dict containing config
+    - agents: a list of dicts describing the agents and their settings
    """
-    # Create config object which is later returned
-    config = {}
-    config['agents'] = {}
+    # Create list of agent dicts
+    agents = []
+
+    for agent_name, agent_info in scenario.get('agents', {}).items():
+        class_name = agent_info.get('agent_class')
+        agent_type = AgentType(agent_info.get('type'))
+        agent_dict = {'name': agent_name, 'type': agent_type}
+        agent_config = agent_info.get('config', {})
+
+        if agent_type == AgentType.ATTACKER:
+            # Attackers have entry points
+            entry_points = agent_info.get('entry_points')
+            attacker = add_attacker_entrypoints(
+                attack_graph, agent_name, entry_points
+            )
+            agent_dict['attacker_id'] = attacker.id
+
+        if class_name is None:
+            # No class name - no agent object created
+            agents.append(agent_dict)
+            continue
 
-    # Currently only support one defender and attacker
-    attacker_id = "attacker"
-    defender_id = "defender"
+        if class_name not in agent_class_name_to_class:
+            # Unknown agent class
+            raise LookupError(
+                f"Agent class '{class_name}' not supported"
+            )
 
-    if a_class := scenario.get('attacker_agent_class'):
-        if a_class not in agent_class_name_to_class:
-            raise LookupError(f"Agent class '{a_class}' not supported")
-        config['agents'][attacker_id] = {}
-        config['agents'][attacker_id]['type'] = 'attacker'
-        config['agents'][attacker_id]['agent_class'] = \
-            agent_class_name_to_class.get(a_class)
+        # Initialize the agent object
+        agent_class = agent_class_name_to_class[class_name]
+        agent = agent_class(agent_config)
+        agent_dict['agent_class'] = agent_class
+        agent_dict['agent'] = agent
+        agents.append(agent_dict)
 
-    if d_class := scenario.get('defender_agent_class'):
-        if d_class not in agent_class_name_to_class:
-            raise LookupError(f"Agent class '{d_class}' not supported")
-        config['agents'][defender_id] = {}
-        config['agents'][defender_id]['type'] = 'defender'
-        config['agents'][defender_id]['agent_class'] = \
-            agent_class_name_to_class.get(d_class)
-    return config
+    return agents
 
 
 def apply_scenario_to_attack_graph(
@@ -268,18 +309,6 @@
     rewards = scenario.get('rewards', {})
     apply_scenario_rewards(attack_graph, rewards)
 
-    # Apply attacker entrypoints to attack graph
-    entry_points = scenario.get('attacker_entry_points', {})
-    if entry_points:
-        # Override attackers in attack graph if
-        # entry points defined in scenario
-        all_attackers = set(attack_graph.attackers.values())
-        for attacker in all_attackers:
-            attack_graph.remove_attacker(attacker)
-
-        # Apply attacker entry points from scenario
-        apply_scenario_attacker_entrypoints(attack_graph, entry_points)
-
     # Apply observability and actionability settings to attack graph
     for node_prop in ['observable', 'actionable']:
         node_prop_settings = scenario.get(node_prop + '_steps')
@@ -287,7 +316,7 @@
         attack_graph, node_prop, node_prop_settings)
 
 
-def load_scenario(scenario_file: str) -> tuple[AttackGraph, dict]:
+def load_scenario(scenario_file: str) -> tuple[AttackGraph, list[dict[str, Any]]]:
     """Load a scenario from a scenario file to an AttackGraph"""
 
     with open(scenario_file, 'r', encoding='utf-8') as s_file:
@@ -301,14 +330,16 @@
     apply_scenario_to_attack_graph(attack_graph, scenario)
 
     # Load the scenario configuration
-    scenario_config = load_scenario_simulation_config(scenario)
+
scenario_agents = load_simulator_agents(attack_graph, scenario) - return attack_graph, scenario_config + return attack_graph, scenario_agents def create_simulator_from_scenario( - scenario_file: str, **kwargs - ) -> tuple[MalSimulator, dict]: + scenario_file: str, + sim_class=MalSimulator, + **kwargs, + ) -> tuple[MalSimulator, list[dict[str, Any]]]: """Creates and returns a MalSimulator created according to scenario file A wrapper that loads the graph and config from the scenario file @@ -319,28 +350,24 @@ def create_simulator_from_scenario( - scenario_file: the file name of the scenario Returns: - - MalSimulator: the resulting simulator + - sim: the resulting simulator + - agents: the agent infos as a list of dicts """ - attack_graph, conf = load_scenario(scenario_file) + attack_graph, scenario_agents = load_scenario(scenario_file) - sim = MalSimulator( - attack_graph.lang_graph, - attack_graph.model, - attack_graph, - **kwargs - ) + sim = sim_class(attack_graph, **kwargs) - # This version only supports one defender and one attacker - for agent_id, agent_info in conf['agents'].items(): - if agent_info['type'] == "attacker": - assert len(attack_graph.attackers) == 1, ( - "You have defined more than one attacker, ", - "cannot decide which one belongs to agent in simulator" + # Register agents in simulator + for agent_dict in scenario_agents: + if agent_dict['type'] == AgentType.ATTACKER: + sim.register_attacker( + agent_dict['name'], + agent_dict['attacker_id'] + ) + elif agent_dict['type'] == AgentType.DEFENDER: + sim.register_defender( + agent_dict['name'] ) - attacker = next(iter(sim.attack_graph.attackers.values())) - sim.register_attacker(agent_id, attacker.id) - elif agent_info['type'] == "defender": - sim.register_defender(agent_id) - return sim, conf + return sim, scenario_agents diff --git a/malsim/sims/__init__.py b/malsim/sims/__init__.py deleted file mode 100644 index 4de75f97..00000000 --- a/malsim/sims/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .mal_simulator import MalSimulator as MalSimulator -from .mal_simulator_settings import MalSimulatorSettings as MalSimulatorSettings \ No newline at end of file diff --git a/malsim/sims/mal_simulator.py b/malsim/sims/mal_simulator.py deleted file mode 100644 index 323f8174..00000000 --- a/malsim/sims/mal_simulator.py +++ /dev/null @@ -1,1252 +0,0 @@ -from __future__ import annotations - -import sys -import copy -import logging -import functools -from typing import Optional, TYPE_CHECKING -import numpy as np - -from gymnasium.spaces import MultiDiscrete, Box, Dict -from pettingzoo import ParallelEnv - -from maltoolbox import neo4j_configs -from maltoolbox.attackgraph import AttackGraph, AttackGraphNode, Attacker -from maltoolbox.attackgraph.analyzers import apriori -from maltoolbox.attackgraph import query -from maltoolbox.ingestors import neo4j - -from .mal_simulator_settings import MalSimulatorSettings - -if TYPE_CHECKING: - from maltoolbox.language import LanguageGraph - from maltoolbox.model import Model - -ITERATIONS_LIMIT = int(1e9) - -logger = logging.getLogger(__name__) - -def format_table( - entry_format: str, - header_entry: list[str], - entries: list[list[str]], - reprint_header: int = 0 - ) -> str: - """ - Format a table according to the parameters specified. - - Arguments: - entry_format - The string format for the table - reprint_header - How many rows apart to reprint the header. If 0 the - header will not be reprinted. 
- header_entry - The entry representing the header of the table - entries - The list of entries to format - - Return: - The formatted table. - """ - - formatted_str = '' - header = entry_format.format(*header_entry) - formatted_str += header - for entry_nr, entry in zip(range(0, len(entries)), entries): - formatted_str += entry_format.format(*entry) - if (reprint_header != 0) and ((entry_nr + 1) % reprint_header == 0): - formatted_str += header - return formatted_str - -class MalSimulator(ParallelEnv): - def __init__( - self, - lang_graph: LanguageGraph, - model: Model, - attack_graph: AttackGraph, - max_iter=ITERATIONS_LIMIT, - prune_unviable_unnecessary: bool = True, - sim_settings: MalSimulatorSettings = MalSimulatorSettings(), - **kwargs, - ): - """ - Args: - lang_graph - The language graph to use - model - The model to use - attack_graph - The attack graph to use - max_iter - Max iterations in simulation - prune_unviable_unnecessary - Prunes graph if set to true - sim_settings - Settings for simulator - """ - super().__init__() - logger.info("Create Mal Simulator.") - self.lang_graph = lang_graph - self.model = model - - apriori.calculate_viability_and_necessity(attack_graph) - if prune_unviable_unnecessary: - apriori.prune_unviable_and_unnecessary_nodes(attack_graph) - - self.attack_graph = attack_graph - self.sim_settings = sim_settings - self.max_iter = max_iter - - self.attack_graph_backup = copy.deepcopy(self.attack_graph) - - self.possible_agents = [] - self.agents = [] - self.agents_dict = {} - - self.initialize(self.max_iter) - - def __call__(self): - return self - - def create_blank_observation(self, default_obs_state=-1): - # For now, an `object` is an attack step - num_steps = len(self.attack_graph.nodes) - - observation = { - # If no observability set for node, assume observable. - "is_observable": [step.extras.get('observable', 1) - for step in self.attack_graph.nodes.values()], - # Same goes for actionable. 
- "is_actionable": [step.extras.get('actionable', 1) - for step in self.attack_graph.nodes.values()], - "observed_state": num_steps * [default_obs_state], - "remaining_ttc": num_steps * [0], - "asset_type": [self._asset_type_to_index[step.lg_attack_step.asset.name] - for step in self.attack_graph.nodes.values()], - "asset_id": [step.model_asset.id - for step in self.attack_graph.nodes.values()], - "step_name": [ - self._step_name_to_index.get( - str(step.lg_attack_step.asset.name + ":" + step.name) - ) for step in self.attack_graph.nodes.values()], - } - - logger.debug( - 'Create blank observation with %d attack steps.', num_steps) - - # Add attack graph edges to observation - observation["attack_graph_edges"] = [] - for attack_step in self.attack_graph.nodes.values(): - # For determinism we need to order the children - ordered_children = list(attack_step.children) - ordered_children.sort(key=lambda n: n.id) - - for child in ordered_children: - observation["attack_graph_edges"].append( - [ - self._id_to_index[attack_step.id], - self._id_to_index[child.id] - ] - ) - - # Add reverse attack graph edges for defense steps - # (required by some defender agent logic) - if attack_step.type == "defense": - observation["attack_graph_edges"].append( - [ - self._id_to_index[child.id], - self._id_to_index[attack_step.id] - ] - ) - - # Add instance model assets - observation["model_asset_id"] = [] - observation["model_asset_type"] = [] - observation["model_edges_ids"] = [] - observation["model_edges_type"] = [] - - for asset in self.model.assets.values(): - observation["model_asset_id"].append(asset.id) - observation["model_asset_type"].append( - self._asset_type_to_index[asset.type]) - - for fieldname, other_assets in asset.associated_assets.items(): - for other_asset in other_assets: - observation["model_edges_ids"].append( - [ - self._model_asset_id_to_index[asset.id], - self._model_asset_id_to_index[other_asset.id] - ] - ) - - lg_assoc = asset.lg_asset.associations[fieldname] - observation["model_edges_type"].append( - self._model_assoc_type_to_index[lg_assoc.full_name] - ) - - - np_obs = { - "is_observable": np.array(observation["is_observable"], - dtype=np.int8), - "is_actionable": np.array(observation["is_actionable"], - dtype=np.int8), - "observed_state": np.array(observation["observed_state"], - dtype=np.int8), - "remaining_ttc": np.array(observation["remaining_ttc"], - dtype=np.int64), - "asset_type": np.array(observation["asset_type"], dtype=np.int64), - "asset_id": np.array(observation["asset_id"], dtype=np.int64), - "step_name": np.array(observation["step_name"], dtype=np.int64), - "attack_graph_edges": np.array(observation["attack_graph_edges"], - dtype=np.int64), - "model_asset_id": np.array(observation["model_asset_id"], - dtype=np.int64), - "model_asset_type": np.array(observation["model_asset_type"], - dtype=np.int64), - "model_edges_ids": np.array(observation["model_edges_ids"], - dtype=np.int64), - "model_edges_type": np.array(observation["model_edges_type"], - dtype=np.int64) - } - - return np_obs - - def format_full_observation(self, observation): - """ - Return a formatted string of the entire observation. This includes - sections that will not change over time, these define the structure of - the attack graph. 
- """ - obs_str = '\nAttack Graph Steps\n' - - str_format = "{:<5} {:<80} {:<6} {:<5} {:<5} {:<30} {:<8} {:<}\n" - header_entry = [ - "Entry", "Name", "Is_Obs", "State", "RTTC", "Asset Type(Index)", "Asset Id", "Step"] - entries = [] - for entry in range(0, len(observation["observed_state"])): - asset_type_index = observation["asset_type"][entry] - asset_type_str = self._index_to_asset_type[asset_type_index ] + \ - '(' + str(asset_type_index) + ')' - entries.append( - [ - entry, - self._index_to_full_name[entry], - observation["is_observable"][entry], - observation["observed_state"][entry], - observation["remaining_ttc"][entry], - asset_type_str, - observation["asset_id"][entry], - observation["step_name"][entry], - ] - ) - obs_str += format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - obs_str += "\nAttack Graph Edges:\n" - for edge in observation["attack_graph_edges"]: - obs_str += str(edge) + "\n" - - obs_str += "\nInstance Model Assets:\n" - str_format = "{:<5} {:<5} {:<}\n" - header_entry = [ - "Entry", "Id", "Type(Index)"] - entries = [] - for entry in range(0, len(observation["model_asset_id"])): - asset_type_str = self._index_to_asset_type[ - observation["model_asset_type"][entry]] + \ - '(' + str(observation["model_asset_type"][entry]) + ')' - entries.append( - [ - entry, - observation["model_asset_id"][entry], - asset_type_str - ] - ) - obs_str += format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - obs_str += "\nInstance Model Edges:\n" - str_format = "{:<5} {:<40} {:<40} {:<}\n" - header_entry = [ - "Entry", - "Left Asset(Id/Index)", - "Right Asset(Id/Index)", - "Type(Index)" - ] - entries = [] - for entry in range(0, len(observation["model_edges_ids"])): - assoc_type_str = self._index_to_model_assoc_type[ - observation["model_edges_type"][entry]] + \ - '(' + str(observation["model_edges_type"][entry]) + ')' - left_asset_index = int(observation["model_edges_ids"][entry][0]) - right_asset_index = int(observation["model_edges_ids"][entry][1]) - left_asset_id = self._index_to_model_asset_id[left_asset_index] - right_asset_id = self._index_to_model_asset_id[right_asset_index] - left_asset_str = \ - self.model.get_asset_by_id(left_asset_id).name + \ - '(' + str(left_asset_id) + '/' + str(left_asset_index) + ')' - right_asset_str = \ - self.model.get_asset_by_id(right_asset_id).name + \ - '(' + str(right_asset_id) + '/' + str(right_asset_index) + ')' - entries.append( - [ - entry, - left_asset_str, - right_asset_str, - assoc_type_str - ] - ) - obs_str += format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - return obs_str - - def format_obs_var_sec(self, - observation, - included_values = [-1, 0, 1]): - """ - Return a formatted string of the sections of the observation that can - vary over time. 
- - Arguments: - observation - the observation to format - included_values - the values to list, any values not present in the - list will be filtered out - """ - - str_format = "{:>5} {:>80} {:<5} {:<5} {:<}\n" - header_entry = ["Id", "Name", "State", "RTTC", "Entry"] - entries = [] - for entry in range(0, len(observation["observed_state"])): - if observation["is_observable"][entry] and \ - observation["observed_state"][entry] in included_values: - entries.append( - [ - self._index_to_id[entry], - self._index_to_full_name[entry], - observation["observed_state"][entry], - observation["remaining_ttc"][entry], - entry - ] - ) - - obs_str = format_table( - str_format, header_entry, entries, reprint_header = 30 - ) - - return obs_str - - def _format_info(self, info): - can_act = "Yes" if info["action_mask"][0][1] > 0 else "No" - agent_info_str = f"Can act? {can_act}\n" - for entry in range(0, len(info["action_mask"][1])): - if info["action_mask"][1][entry] == 1: - agent_info_str += f"{self._index_to_id[entry]} " \ - f"{self._index_to_full_name[entry]}\n" - return agent_info_str - - @functools.lru_cache(maxsize=None) - def observation_space(self, agent=None): - # For now, an `object` is an attack step - num_assets = len(self.attack_graph.model.assets) - num_steps = len(self.attack_graph.nodes) - num_lang_asset_types = len(self.lang_graph.assets) - - unique_step_types = set() - for asset_type in self.lang_graph.assets.values(): - unique_step_types |= set(asset_type.attack_steps.values()) - num_lang_attack_steps = len(unique_step_types) - - unique_assoc_type_names = set() - for asset_type in self.lang_graph.assets.values(): - for assoc_type in asset_type.associations.values(): - unique_assoc_type_names.add( - assoc_type.full_name - ) - num_lang_association_types = len(unique_assoc_type_names) - - num_attack_graph_edges = len( - self._blank_observation["attack_graph_edges"]) - num_model_edges = len( - self._blank_observation["model_edges_ids"]) - return Dict( - { - "is_observable": Box( - 0, 1, shape=(num_steps,), dtype=np.int8 - ), # 0 for unobservable, 1 for observable - "is_actionable": Box( - 0, 1, shape=(num_steps,), dtype=np.int8 - ), # 0 for non-actionable, 1 for actionable - "observed_state": Box( - -1, 1, shape=(num_steps,), dtype=np.int8 - ), # -1 for unknown, - # 0 for disabled/not compromised, - # 1 for enabled/compromised - "remaining_ttc": Box( - 0, sys.maxsize, shape=(num_steps,), dtype=np.int64 - ), # remaining TTC - "asset_type": Box( - 0, - num_lang_asset_types, - shape=(num_steps,), - dtype=np.int64, - ), # asset type - "asset_id": Box( - 0, sys.maxsize, shape=(num_steps,), dtype=np.int64 - ), # asset id - "step_name": Box( - 0, - num_lang_attack_steps, - shape=(num_steps,), - dtype=np.int64, - ), # attack/defense step name - "attack_graph_edges": Box( - 0, - num_steps, - shape=(num_attack_graph_edges, 2), - dtype=np.int64, - ), # edges between attack graph steps - "model_asset_id": Box( - 0, - num_assets, - shape=(num_assets,), - dtype=np.int64, - ), # instance model asset ids - "model_asset_type": Box( - 0, - num_lang_asset_types, - shape=(num_assets,), - dtype=np.int64, - ), # instance model asset types - "model_edges_ids": Box( - 0, - num_assets, - shape=(num_model_edges, 2), - dtype=np.int64, - ), # instance model edge ids - "model_edges_type": Box( - 0, - num_lang_association_types, - shape=(num_model_edges, ), - dtype=np.int64, - ), # instance model edge types - } - ) - - @functools.lru_cache(maxsize=None) - def action_space(self, agent=None): - num_actions = 2 # two 
actions: wait or use - # For now, an `object` is an attack step - num_steps = len(self.attack_graph.nodes) - return MultiDiscrete([num_actions, num_steps], dtype=np.int64) - - def reset( - self, - seed: Optional[int] = None, - options: Optional[dict] = None - ): - logger.info("Resetting simulator.") - self.attack_graph = copy.deepcopy(self.attack_graph_backup) - return self.initialize(self.max_iter) - - def log_mapping_tables(self): - """Log all mapping tables in MalSimulator""" - - str_format = "{:<5} {:<15} {:<}\n" - table = "\n" - header_entry = ["Index", "Attack Step Id", "Attack Step Full Name"] - entries = [] - for entry in self._index_to_id: - entries.append( - [ - self._id_to_index[entry], - entry, - self._index_to_full_name[self._id_to_index[entry]] - ] - ) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Asset Id"] - entries = [] - for entry in self._model_asset_id_to_index: - entries.append( - [ - self._model_asset_id_to_index[entry], - entry - ] - ) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Asset Type"] - entries = [] - for entry in self._asset_type_to_index: - entries.append( - [ - self._asset_type_to_index[entry], - entry - ] - ) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Attack Step Name"] - entries = [] - for entry in self._index_to_step_name: - entries.append([self._step_name_to_index[entry], entry]) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - str_format = "{:<5} {:<}\n" - table = "\n" - header_entry = ["Index", "Association Type"] - entries = [] - for entry in self._index_to_model_assoc_type: - entries.append([self._model_assoc_type_to_index[entry], entry]) - table += format_table( - str_format, - header_entry, - entries, - reprint_header = 30 - ) - logger.debug(table) - - - def _create_mapping_tables(self): - """Create mapping tables""" - logger.debug("Creating and listing mapping tables.") - - # Lookup lists index to attribute - self._index_to_id = [n.id for n in self.attack_graph.nodes.values()] - self._index_to_full_name = ( - [n.full_name for n in self.attack_graph.nodes.values()] - ) - self._index_to_asset_type = ( - [n.name for n in self.lang_graph.assets.values()] - ) - - unique_step_type_names = { - n.full_name - for asset in self.lang_graph.assets.values() - for n in asset.attack_steps.values() - } - self._index_to_step_name = list(unique_step_type_names) - - self._index_to_model_asset_id = ( - [int(asset_id) for asset_id in self.attack_graph.model.assets] - ) - - unique_assoc_type_names = { - assoc.full_name - for asset in self.lang_graph.assets.values() - for assoc in asset.associations.values() - } - self._index_to_model_assoc_type = list(unique_assoc_type_names) - - # Lookup dicts attribute to index - self._id_to_index = { - n: i for i, n in enumerate(self._index_to_id)} - self._asset_type_to_index = { - n: i for i, n in enumerate(self._index_to_asset_type)} - self._step_name_to_index = { - n: i for i, n in enumerate(self._index_to_step_name) - } - self._model_asset_id_to_index = { - asset: i for i, asset in enumerate(self._index_to_model_asset_id) - } - 
self._model_assoc_type_to_index = { - assoc_type: i for i, assoc_type in - enumerate(self._index_to_model_assoc_type) - } - - def index_to_node(self, index: int) -> AttackGraphNode: - """Get a node from the attack graph by index - - Index is the position of the node in the lookup list. - First convert index to id and then fetch the node from the - AttackGraph. - - Raise LookupError if node with given index does not map to a node in - the attack graph and IndexError if the index is out of range for the - lookup list. - - Returns: - Attack graph node matching the id of the index in the lookup list - """ - - if index >= len(self._index_to_id): - raise IndexError( - f'Index {index}, is out of range of the ' - f'lookup list which is of length {len(self._index_to_id)}' - ) - - node_id = self._index_to_id[index] - node = self.attack_graph.nodes[node_id] - if not node: - raise LookupError( - f'Index {index} (id: {node_id}), does not map to a node' - ) - return node - - def node_to_index(self, node: AttackGraphNode) -> int: - """Get the index of an attack graph node - - Returns: - Index of the attack graph node in the lookup list - """ - - assert node, "Node can not be None" - return self._id_to_index[node.id] - - def action_to_node( - self, action: tuple[int, int] - ) -> list[AttackGraphNode]: - """Convert serialized action to malsim action format - - (0, None) -> None - (1, idx) -> Node with index idx - - Currently supports single action only. - """ - node = None - act, step_idx = action - if act: - node = self.index_to_node(step_idx) - return node - - - def _initialize_agents(self) -> dict[str, list[int]]: - """Initialize agent rewards, observations, and action surfaces - - Return: - - An action dictionary mapping agent to initial actions - (attacker entry points and pre-activated defenses) - """ - # Initialize list of agent - self.agents = copy.deepcopy(self.possible_agents) - - # Will contain initally enabled steps - initial_actions = {} - - for agent in self.agents: - # Initialize rewards - self.agents_dict[agent]["rewards"] = 0 - agent_type = self.agents_dict[agent]["type"] - initial_actions[agent] = [] - - if agent_type == "attacker": - attacker_id = self.agents_dict[agent]["attacker"] - attacker = self.attack_graph.attackers[attacker_id] - assert attacker, f"No attacker with id {attacker_id}" - - # Initialize observations and action surfaces - self.agents_dict[agent]["observation"] = \ - self.create_blank_observation() - self.agents_dict[agent]["action_surface"] = \ - query.get_attack_surface(attacker) - - # Initial actions for attacker are its entrypoints - for entry_point in attacker.entry_points: - initial_actions[agent].append( - self._id_to_index[entry_point.id]) - entry_point.extras['entrypoint'] = True - - elif agent_type == "defender": - # Initialize observations and action surfaces - self.agents_dict[agent]["observation"] = \ - self.create_blank_observation(default_obs_state = 0) - self.agents_dict[agent]["action_surface"] = \ - query.get_defense_surface(self.attack_graph) - - # Initial actions for defender are all pre-enabled defenses - initial_actions[agent] = [self._id_to_index[node.id] - for node in self.attack_graph.nodes.values() - if node.is_enabled_defense()] - - else: - self.agents_dict[agent]["action_surface"] = [] - - # Sort to make action surface deterministic - self.agents_dict[agent]["action_surface"].sort(key=lambda n: n.id) - - return initial_actions - - def initialize(self, max_iter=ITERATIONS_LIMIT): - """Create mapping tables, register agents, and initialize 
their - observations, action surfaces, and rewards. - - Return initial observations and infos. - """ - - logger.info("Initializing MAL ParralelEnv Simulator.") - self._create_mapping_tables() - - if logger.isEnabledFor(logging.DEBUG): - self.log_mapping_tables() - - self.max_iter = max_iter - self.cur_iter = 0 - - logger.debug("Creating and listing blank observation space.") - self._blank_observation = self.create_blank_observation() - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - self.format_full_observation(self._blank_observation) - ) - - # Initialize agents and record the entry point actions - initial_actions = self._initialize_agents() - - observations, _, _, _, infos = ( - self._observe_and_reward(initial_actions, [])) - - return observations, infos - - def register_attacker(self, agent_name, attacker: int): - logger.info( - 'Register attacker "%s" agent with ' - "attacker index %d.", agent_name, attacker - ) - assert agent_name not in self.agents_dict, \ - f"Duplicate attacker agent named {agent_name} not allowed" - - self.possible_agents.append(agent_name) - self.agents_dict[agent_name] = { - "type": "attacker", - "attacker": attacker, - "observation": {}, - "action_surface": [], - "rewards": 0 - } - - def register_defender(self, agent_name): - """Add defender agent to the simulator - - Defenders are run first so that the defenses prevent attackers - appropriately in case any attackers select attack steps that the - defenders safeguards against during the same step. - """ - logger.info('Register defender "%s" agent.', agent_name) - assert agent_name not in self.agents_dict, \ - f"Duplicate defender agent named {agent_name} not allowed" - - # Add defenders at the front of the list to make sure they have - # priority. - self.possible_agents.insert(0, agent_name) - self.agents_dict[agent_name] = { - "type": "defender", - "observation": {}, - "action_surface": [], - "rewards": 0 - } - - def get_attacker_agents(self) -> dict: - """Return agents dictionaries of attacker agents""" - return {k: v for k, v in self.agents_dict.items() - if v['type'] == "attacker"} - - def get_defender_agents(self) -> dict: - """Return agents dictionaries of defender agents""" - return {k: v for k, v in self.agents_dict.items() - if v['type'] == "defender"} - - def state(self): - # Should return a state for all agents - return NotImplementedError - - def _attacker_step(self, agent, attack_step): - actions = [] - attacker_index = self.agents_dict[agent]["attacker"] - attacker = self.attack_graph.attackers[attacker_index] - attack_step_node = self.index_to_node(attack_step) - - logger.info( - 'Attacker agent "%s" stepping through "%s"(%d).', - agent, - attack_step_node.full_name, - attack_step_node.id - ) - if query.is_node_traversable_by_attacker(attack_step_node, attacker): - if not attack_step_node.is_compromised_by(attacker): - logger.debug( - 'Attacker agent "%s" has compromised "%s"(%d).', - agent, attack_step_node.full_name, attack_step_node.id - ) - attacker.compromise(attack_step_node) - self.agents_dict[agent]["action_surface"] = \ - query.update_attack_surface_add_nodes( - attacker, - self.agents_dict[agent]["action_surface"], - [attack_step_node] - ) - self.agents_dict[agent]["action_surface"].sort(key=lambda n: n.id) - else: - logger.warning( - 'Attacker agent "%s" has already compromised "%s"(%d).', - agent, attack_step_node.full_name, attack_step_node.id - ) - - actions.append(attack_step) - else: - logger.warning( - 'Attacker agent "%s" tried to compromise untraversable ' - 'attack 
step"%s"(%d).', - agent, - attack_step_node.full_name, - attack_step_node.id - ) - return actions - - def update_viability( - self, - node: AttackGraphNode, - unviable_attack_steps: list[AttackGraphNode] = None - ) -> list[AttackGraphNode]: - """ - Update the viability of the node in the graph and return any - attack steps that are no longer viable. - Propagate this recursively via children as long as changes occur. - - Arguments: - node - the node to propagate updates from - unviable_attack_steps - a list of the attack steps that have been - made unviable by a defense enabled in the - current step - """ - - unviable_attack_steps = [] if unviable_attack_steps is None \ - else unviable_attack_steps - logger.debug( - 'Update viability for node "%s"(%d)', - node.full_name, - node.id - ) - assert not node.is_viable, ("update_viability should not be called" - f" on viable node {node.full_name}") - - if node.extras.get('entrypoint'): - # Never make entrypoint unviable, and do not - # propagate its viability further - node.is_viable = True - return unviable_attack_steps - - if node.type in ('and', 'or'): - unviable_attack_steps.append(node) - - for child in node.children: - original_value = child.is_viable - if child.type == 'or': - child.is_viable = False - for parent in child.parents: - child.is_viable = child.is_viable or parent.is_viable - if child.type == 'and': - child.is_viable = False - - if child.is_viable != original_value: - self.update_viability(child, unviable_attack_steps) - - return unviable_attack_steps - - def _defender_step( - self, agent, defense_step_index - ) -> tuple[list[int], list[AttackGraphNode]]: - - actions = [] - defense_step_node = self.attack_graph.nodes[ - self._index_to_id[defense_step_index] - ] - logger.info( - 'Defender agent "%s" stepping through "%s"(%d).', - agent, - defense_step_node.full_name, - defense_step_node.id - ) - if defense_step_node not in self.agents_dict[agent]["action_surface"]: - logger.warning( - 'Defender agent "%s" tried to step through "%s"(%d).' - 'which is not part of its defense surface. Defender ' - 'step will skip', - agent, - defense_step_node.full_name, - defense_step_node.id - ) - return actions, [] - - defense_step_node.defense_status = 1.0 - defense_step_node.is_viable = False - prevented_attack_steps = self.update_viability( - defense_step_node - ) - actions.append(defense_step_index) - - # Remove defense from all defender agents' action surfaces since it is - # already enabled. And remove all of the prevented attack steps from - # the attackers' action surfaces. - for agent_el in self.agents: - if self.agents_dict[agent_el]["type"] == "defender": - try: - self.agents_dict[agent_el]["action_surface"].\ - remove(defense_step_node) - except ValueError: - # Optimization: the defender is told to remove - # the node from its defense surface even if it - # may have not been present to save one extra - # lookup. - pass - elif self.agents_dict[agent_el]["type"] == "attacker": - for attack_step in prevented_attack_steps: - try: - # Node is no longer part of attacker action surface - self.agents_dict[agent_el]\ - ["action_surface"].remove(attack_step) - except ValueError: - # Optimization: the attacker is told to remove - # the node from its attack surface even if it may - # have not been present to save one extra lookup. 
- pass - - - return actions, prevented_attack_steps - - def _observe_attacker( - self, - attacker_agent, - performed_actions: dict[str, list[int]] - ) -> None: - """ - Update the attacker observation based on the actions performed - in current step. - - Arguments: - attacker_agent - the attacker agent to fill in the observation for - observation - the blank observation to fill in - """ - - obs_state = self.agents_dict[attacker_agent]["observation"]\ - ["observed_state"] - - # Set obs state of reached attack steps to 1 (enabled) - for _, actions in performed_actions.items(): - for step_index in actions: - - if step_index is None: - # Waiting does not affect obs - continue - - node_id = self._index_to_id[step_index] - node = self.attack_graph.nodes[node_id] - if node.type in ('or', 'and'): - # Attack step activated, set to 1 (enabled) - obs_state[step_index] = 1 - - for child in node.children: - # Set its children to 0 (disabled) - child_index = self._id_to_index[child.id] - if obs_state[child_index] == -1: - obs_state[child_index] = 0 - - def _observe_defender( - self, - defender_agent, - performed_actions: dict[str, list[int]] - ): - - obs_state = self.agents_dict[defender_agent]["observation"]\ - ["observed_state"] - - if not self.sim_settings.cumulative_defender_obs: - # Clear the state if we do not it to accumulate observations over - # time. - obs_state.fill(0) - - # Only show the latest steps taken - for _, actions in performed_actions.items(): - for action in actions: - obs_state[action] = 1 - - def _observe_agents(self, performed_actions): - """Collect agents observations""" - - for agent in self.agents: - agent_type = self.agents_dict[agent]["type"] - if agent_type == "defender": - self._observe_defender(agent, performed_actions) - - elif agent_type == "attacker": - self._observe_attacker(agent, performed_actions) - - else: - logger.error( - "Agent %s has unknown type: %s", - agent, self.agents_dict[agent]["type"] - ) - - def _reward_agents(self, performed_actions): - """Update rewards from latest performed actions""" - for agent, actions in performed_actions.items(): - agent_type = self.agents_dict[agent]["type"] - - for action in actions: - if action is None: - continue - - node_id = self._index_to_id[action] - node = self.attack_graph.nodes[node_id] - node_reward = node.extras.get('reward', 0) - - if agent_type == "attacker": - # If attacker performed step, it will receive - # a reward and penalize all defenders - self.agents_dict[agent]["rewards"] += node_reward - - for d_agent in self.get_defender_agents(): - self.agents_dict[d_agent]["rewards"] -= node_reward - else: - # If a defender performed step, it will be penalized - self.agents_dict[agent]["rewards"] -= node_reward - - - def _collect_agents_infos(self): - """Collect agent info, this is used to determine the possible - actions in the next iteration step. 
Then fill in all of the""" - - attackers_done = True - infos = {} - can_wait = { - "attacker": 0, - "defender": 1, - } - - for agent in self.agents: - agent_type = self.agents_dict[agent]["type"] - available_actions = [0] * len(self.attack_graph.nodes) - can_act = 0 - - if agent_type == "defender": - for node in self.agents_dict[agent]["action_surface"]: - index = self._id_to_index[node.id] - available_actions[index] = 1 - can_act = 1 - - if agent_type == "attacker": - attacker = self.attack_graph.attackers[ - self.agents_dict[agent]["attacker"] - ] - for node in self.agents_dict[agent]["action_surface"]: - if not node.is_compromised_by(attacker): - index = self._id_to_index[node.id] - available_actions[index] = 1 - can_act = 1 - attackers_done = False - - infos[agent] = { - "action_mask": ( - np.array( - [can_wait[agent_type], can_act], dtype=np.int8), - np.array( - available_actions, dtype=np.int8) - )} - - return attackers_done, infos - - def _disable_attack_steps( - self, attack_steps_to_disable: list[AttackGraphNode] - ): - """Disable nodes for each attacker agent - - For each compromised attack step uncompromise the node, disable its - observed_state, and remove the rewards. - """ - - for attacker_agent in self.get_attacker_agents(): - attacker_index = self.agents_dict[attacker_agent]["attacker"] - attacker: Attacker = self.attack_graph.attackers[attacker_index] - - for unviable_node in attack_steps_to_disable: - if unviable_node.is_compromised_by(attacker): - - # Reward is no longer present for attacker - node_reward = unviable_node.extras.get('reward', 0) - self.agents_dict[attacker_agent]["rewards"] -= node_reward - - # Reward is no longer present for defenders - for defender_agent in self.get_defender_agents(): - self.agents_dict[defender_agent]["rewards"] += node_reward - - # Uncompromise node if requested - attacker.undo_compromise(unviable_node) - - # Uncompromised nodes observed state is 0 (disabled) - step_index = self._id_to_index[unviable_node.id] - agent_obs = self.agents_dict[attacker_agent]["observation"] - agent_obs['observed_state'][step_index] = 0 - - - def _observe_and_reward( - self, - performed_actions: dict[str, list[int]], - prevented_attack_steps: list[AttackGraphNode] - ): - """Update observations and reward agents based on latest actions - - Returns 5 dicts, each mapping from agent to: - observations, rewards, terminations, truncations, infos - """ - - terminations = {} - truncations = {} - infos = {} - finished_agents = [] - - if self.sim_settings.uncompromise_untraversable_steps: - # Disable attack steps for attackers to update the - # observations, rewards and action surface - self._disable_attack_steps(prevented_attack_steps) - - # Fill in the agent observations, rewards, - # infos, terminations, truncations. - # If no attackers have any actions left - # to take the simulation will terminate. 
- self._observe_agents(performed_actions) - self._reward_agents(performed_actions) - attackers_done, infos = self._collect_agents_infos() - - for agent in self.agents: - # Terminate simulation if no attackers have actions to take - terminations[agent] = attackers_done - if attackers_done: - logger.debug( - "No attacker has actions left to perform, " - "terminate agent \"%s\".", agent) - - truncations[agent] = False - if self.cur_iter >= self.max_iter: - logger.debug( - "Simulation has reached the maximum number of " - "iterations, %d, terminate agent \"%s\".", - self.max_iter, agent) - truncations[agent] = True - - if terminations[agent] or truncations[agent]: - finished_agents.append(agent) - - if logger.isEnabledFor(logging.DEBUG): - # Debug print agent states - agent_obs_str = self.format_obs_var_sec( - self.agents_dict[agent]["observation"], - included_values = [0, 1]) - - logger.debug( - 'Observation for agent "%s":\n%s', agent, agent_obs_str) - logger.debug( - 'Rewards for agent "%s": %d', agent, - self.agents_dict[agent]["rewards"]) - logger.debug( - 'Termination for agent "%s": %s', - agent, terminations[agent]) - logger.debug( - 'Truncation for agent "%s": %s', - agent, str(truncations[agent])) - - agent_info_str = self._format_info(infos[agent]) - logger.debug( - 'Info for agent "%s":\n%s', agent, agent_info_str) - - for agent in finished_agents: - self.agents.remove(agent) - - observations = {agent: self.agents_dict[agent]["observation"] \ - for agent in self.agents_dict} - rewards = {agent: self.agents_dict[agent]["rewards"] \ - for agent in self.agents_dict} - return ( - observations, - rewards, - terminations, - truncations, - infos - ) - - def step(self, actions): - """ - step(action) takes in an action for each agent and should return the - - observations - - rewards - - terminations - - truncations - - infos - dicts where each dict looks like {agent_1: item_1, agent_2: item_2} - """ - logger.debug( - "Stepping through iteration %d/%d", self.cur_iter, self.max_iter) - logger.debug("Performing actions: %s", actions) - - # Map agent to defense/attack steps performed in this step - performed_actions = {} - prevented_attack_steps = [] - - # Peform agent actions - for agent in self.agents: - action = actions[agent] - if action[0] == 0: - # Agent wants to wait - do nothing - continue - - action_step = action[1] - - if self.agents_dict[agent]["type"] == "attacker": - performed_actions[agent] = \ - self._attacker_step(agent, action_step) - - elif self.agents_dict[agent]["type"] == "defender": - defender_actions, prevented_attack_steps = \ - self._defender_step(agent, action_step) - performed_actions[agent] = defender_actions - - else: - logger.error( - 'Agent %s has unknown type: %s', - agent, self.agents_dict[agent]["type"]) - - observations, rewards, terminations, truncations, infos = ( - self._observe_and_reward( - performed_actions, - prevented_attack_steps - ) - ) - - self.cur_iter += 1 - - return observations, rewards, terminations, truncations, infos - - def render(self): - logger.debug("Ingest attack graph into Neo4J database.") - neo4j.ingest_attack_graph( - self.attack_graph, - neo4j_configs["uri"], - neo4j_configs["username"], - neo4j_configs["password"], - neo4j_configs["dbname"], - delete=True, - ) diff --git a/malsim/sims/mal_simulator_settings.py b/malsim/sims/mal_simulator_settings.py deleted file mode 100644 index 6b661150..00000000 --- a/malsim/sims/mal_simulator_settings.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class 
MalSimulatorSettings(): - """Contains settings used in MalSimulator""" - - # uncompromise_untraversable_steps - # - Uncompromise (evict attacker) from nodes/steps that are no longer - # traversable (often because a defense kicked in) if set to True - # otherwise: - # - Leave the node/step compromised even after it becomes untraversable - uncompromise_untraversable_steps: bool = False - - # cumulative_defender_obs - # - Defender sees the status of the whole attack graph if set to True - # otherwise: - # - Defender only sees the status of nodes changed in the current step - cumulative_defender_obs: bool = True diff --git a/malsim/wrappers/__init__.py b/malsim/wrappers/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pyproject.toml b/pyproject.toml index 818f0843..b7105001 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,7 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "py2neo>=2021.2.3", - "python-jsonschema-objects>=0.4.1", "mal-toolbox>=0.3,<0.4", - "numpy>=1.21.4", - "pettingzoo>=1.24.2", - "gymnasium==1.0", "PyYAML>=6.0.1" ] license = {text = "Apache Software License"} @@ -31,6 +27,18 @@ classifiers = [ "Topic :: Scientific/Engineering" ] +[project.optional-dependencies] +ml = [ + "numpy>=1.21.4", + "pettingzoo>=1.24.2", + "gymnasium==1.0", +] +dev = [ + "pytest", + "mypy", + "ruff", +] + [project.urls] "Homepage" = "https://github.com/mal-lang/mal-simulator" "Bug Tracker" = "https://github.com/mal-lang/mal-simulator/issues" @@ -38,7 +46,7 @@ classifiers = [ [project.scripts] -malsim = "malsim.cli:main" +malsim = "malsim.__main__:main" [build-system] requires = ["setuptools>=61.0"] diff --git a/tests/agents/test_searchers.py b/tests/agents/test_searchers.py new file mode 100644 index 00000000..628a0343 --- /dev/null +++ b/tests/agents/test_searchers.py @@ -0,0 +1,247 @@ +from unittest.mock import MagicMock +from maltoolbox.attackgraph import AttackGraphNode, Attacker +from maltoolbox.attackgraph.query import calculate_attack_surface +from maltoolbox.language import LanguageGraph +from malsim.mal_simulator import MalSimAgentStateView +from malsim.agents import BreadthFirstAttacker, DepthFirstAttacker + + +def test_breadth_first_traversal_simple(dummy_lang_graph: LanguageGraph): + """ + node1 + | + node2 + | + node3 + | + node4 + """ + dummy_or_attack_step = ( + dummy_lang_graph.assets['DummyAsset'] + .attack_steps['DummyOrAttackStep'] + ) + + # Create nodes + node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1) + node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2) + node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3) + node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4) + + # Connect nodes (Node1 -> Node2 -> Node3 -> Node4) + node1.children.add(node2) + node2.parents.add(node1) + node2.children.add(node3) + node3.parents.add(node2) + node3.children.add(node4) + node4.parents.add(node3) + + # Set up an attacker + attacker = Attacker( + name = "TestAttacker", + entry_points = {node1}, + reached_attack_steps = set(), + attacker_id = 100) + + # Set up a mock MalSimAgentState + agent = MagicMock() + agent.action_surface = [node1] + + # Set up MalSimAgentStateView + agent_view = MalSimAgentStateView(agent) + + # Configure BreadthFirstAttacker + agent_config = {"seed": 42, "randomize": False} + attacker_ai = BreadthFirstAttacker(agent_config) + + # Expected traversal order + expected_order = [1, 2, 3, 4] + + actual_order = [] + for _ in expected_order: + # 
Get next action
+        action_node = attacker_ai.get_next_action(agent_view)
+        assert action_node is not None, "Action node shouldn't be None"
+
+        # Mark node as compromised
+        attacker.compromise(action_node)
+        agent.action_surface = calculate_attack_surface(attacker)
+
+        # Store the ID for verification
+        actual_order.append(action_node.id)
+
+    assert actual_order == expected_order, \
+        "Traversal order does not match expected breadth-first order"
+
+def test_breadth_first_traversal_complicated(dummy_lang_graph: LanguageGraph):
+    r"""
+              node1 ______________
+             /     \              \
+          node2    node3          node8
+         /    \    /    \
+      node4 node5 node6 node7
+
+    """
+
+    dummy_or_attack_step = (
+        dummy_lang_graph.assets['DummyAsset']
+        .attack_steps['DummyOrAttackStep']
+    )
+
+    # Create nodes
+    node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1)
+    node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2)
+    node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3)
+    node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4)
+    node5 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=5)
+    node6 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=6)
+    node7 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=7)
+    node8 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=8)
+
+    # Connect nodes according to the tree above
+    # (node1 -> {node2, node3, node8}, node2 -> {node4, node5},
+    # node3 -> {node6, node7})
+    node1.children.add(node2)
+    node2.parents.add(node1)
+    node1.children.add(node3)
+    node3.parents.add(node1)
+    node1.children.add(node8)
+    node8.parents.add(node1)
+
+    node2.children.add(node4)
+    node4.parents.add(node2)
+    node2.children.add(node5)
+    node5.parents.add(node2)
+
+    node3.children.add(node6)
+    node6.parents.add(node3)
+    node3.children.add(node7)
+    node7.parents.add(node3)
+
+    # Set up an attacker
+    attacker = Attacker(
+        name = "TestAttacker",
+        entry_points = {node1},
+        reached_attack_steps = set(),
+        attacker_id = 100)
+
+    # Set up a mock MalSimAgentState
+    agent = MagicMock()
+    agent.action_surface = [node1]
+
+    # Set up MalSimAgentStateView
+    agent_view = MalSimAgentStateView(agent)
+
+    # Configure BreadthFirstAttacker
+    agent_config = {"seed": 42, "randomize": False}
+    attacker_ai = BreadthFirstAttacker(agent_config)
+
+    # Expected traversal order
+    expected_order = [1, 2, 3, 8, 4, 5, 6, 7]
+
+    actual_order = []
+    for _ in expected_order:
+        # Get next action
+        action_node = attacker_ai.get_next_action(agent_view)
+        assert action_node is not None, "Action node shouldn't be None"
+
+        # Mark node as compromised
+        attacker.compromise(action_node)
+        agent.action_surface = calculate_attack_surface(attacker)
+
+        # Store the ID for verification
+        actual_order.append(action_node.id)
+
+    # Nodes within one BFS level may come in any order, so compare per level
+    for level in (0, 1), (1, 4), (4, 8):
+        assert set(expected_order[level[0]:level[1]]) == set(actual_order[level[0]:level[1]]), \
+            "Traversal order does not match expected breadth-first order"
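+
+# A rough sketch of how these searcher agents are typically driven,
+# mirroring the loop used by the tests in this file (all names come from
+# the imports and fixtures above; the while-form is only an illustration):
+#
+#     attacker_ai = BreadthFirstAttacker({"seed": 42, "randomize": False})
+#     while (node := attacker_ai.get_next_action(agent_view)) is not None:
+#         attacker.compromise(node)
+#         agent.action_surface = calculate_attack_surface(attacker)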
+
+
+def test_depth_first_traversal_complicated(dummy_lang_graph: LanguageGraph):
+    r"""
+              node1 ______________
+             /     \              \
+          node2    node3          node8
+         /    \    /    \
+      node4 node5 node6 node7
+
+    """
+    dummy_or_attack_step = (
+        dummy_lang_graph.assets['DummyAsset']
+        .attack_steps['DummyOrAttackStep']
+    )
+
+    # Create nodes
+    node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1)
+    node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2)
+    node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3)
+    node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4)
+    node5 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=5)
+    node6 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=6)
+    node7 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=7)
+    node8 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=8)
+
+    # Connect nodes according to the tree above
+    # (node1 -> {node2, node3, node8}, node2 -> {node4, node5},
+    # node3 -> {node6, node7})
+    node1.children.add(node2)
+    node2.parents.add(node1)
+    node1.children.add(node3)
+    node3.parents.add(node1)
+    node1.children.add(node8)
+    node8.parents.add(node1)
+
+    node2.children.add(node4)
+    node4.parents.add(node2)
+    node2.children.add(node5)
+    node5.parents.add(node2)
+
+    node3.children.add(node6)
+    node6.parents.add(node3)
+    node3.children.add(node7)
+    node7.parents.add(node3)
+
+    # Set up an attacker
+    attacker = Attacker(
+        name = "TestAttacker",
+        entry_points = {node1},
+        reached_attack_steps = set(),
+        attacker_id = 100)
+
+    # Set up a mock MalSimAgentState
+    agent = MagicMock()
+    agent.action_surface = [node1]
+
+    # Set up MalSimAgentStateView
+    agent_view = MalSimAgentStateView(agent)
+
+    # Configure DepthFirstAttacker
+    agent_config = {"seed": 42, "randomize": False}
+    attacker_ai = DepthFirstAttacker(agent_config)
+
+    # Expected traversal order
+    expected_order = [1, 8, 3, 7, 6, 2, 5, 4]
+
+    actual_order = []
+    for _ in expected_order:
+        # Get next action
+        action_node = attacker_ai.get_next_action(agent_view)
+        assert action_node is not None, "Action node shouldn't be None"
+
+        # Mark node as compromised
+        attacker.compromise(action_node)
+        agent.action_surface = calculate_attack_surface(attacker)
+
+        # Store the ID for verification
+        actual_order.append(action_node.id)
+
+    assert actual_order == expected_order, \
+        "Traversal order does not match expected depth-first order"
+
+    # All children of 1 must come after 1
+    assert actual_order.index(8) > actual_order.index(1)
+    assert actual_order.index(2) > actual_order.index(1)
+    assert actual_order.index(3) > actual_order.index(1)
+
+    # All children of 3 must come after 3
+    assert actual_order.index(7) > actual_order.index(3)
+    assert actual_order.index(6) > actual_order.index(3)
+
+    # All children of 2 must come after 2
+    assert actual_order.index(4) > actual_order.index(2)
+    assert actual_order.index(5) > actual_order.index(2)
diff --git a/tests/conftest.py b/tests/conftest.py
index c5504292..1c227d5d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,13 @@
 from os import path
 import pytest
 
-from maltoolbox.language import LanguageGraph
 from maltoolbox.model import Model
 from maltoolbox.attackgraph import create_attack_graph
-from malsim.sims.mal_simulator import MalSimulator
+from maltoolbox.language import (
+    LanguageGraph, LanguageGraphAttackStep, LanguageGraphAsset
+)
+from malsim.mal_simulator import MalSimulator
+from malsim.envs import MalSimVectorizedObsEnv
 
 model_file_name = 'tests/testdata/models/simple_test_model.yml'
 attack_graph_file_name = path.join('/tmp','attack_graph.json')
@@ -34,18 +37,15 @@ def empty_model(name, lang_classes_factory):
 
 ## Fixtures
 @pytest.fixture(scope="session", name="env")
-def fixture_env()-> MalSimulator:
+def fixture_env() -> MalSimVectorizedObsEnv:
     attack_graph = create_attack_graph(lang_file_name, model_file_name)
-    lang_graph = attack_graph.lang_graph
-    model = attack_graph.model
-    attack_graph.save_to_file(attack_graph_file_name)
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph, max_iter=1000))
+    env.register_defender('defender')
 
-    env = MalSimulator(lang_graph, model, attack_graph, max_iter=1000)
-
env.register_attacker("attacker", 0) - env.register_defender("defender") + attacker_id = env.sim.attack_graph.attackers[0].id + env.register_attacker('attacker', attacker_id) return env @@ -85,3 +85,38 @@ def model(corelang_lang_graph): """ # Init LanguageClassesFactory return Model.load_from_file(model_file_name, corelang_lang_graph) + +@pytest.fixture +def dummy_lang_graph(corelang_lang_graph): + """Fixture that generates a dummy LanguageGraph with a dummy + LanguageGraphAsset and LanguageGraphAttackStep + """ + lang_graph = LanguageGraph() + dummy_asset = LanguageGraphAsset( + name = 'DummyAsset' + ) + lang_graph.assets['DummyAsset'] = dummy_asset + dummy_or_attack_step_node = LanguageGraphAttackStep( + name = 'DummyOrAttackStep', + type = 'or', + asset = dummy_asset + ) + dummy_asset.attack_steps['DummyOrAttackStep'] = dummy_or_attack_step_node + + dummy_and_attack_step_node = LanguageGraphAttackStep( + name = 'DummyAndAttackStep', + type = 'and', + asset = dummy_asset + ) + dummy_asset.attack_steps['DummyAndAttackStep'] =\ + dummy_and_attack_step_node + + dummy_defense_attack_step_node = LanguageGraphAttackStep( + name = 'DummyDefenseAttackStep', + type = 'defense', + asset = dummy_asset + ) + dummy_asset.attack_steps['DummyDefenseAttackStep'] =\ + dummy_defense_attack_step_node + + return lang_graph diff --git a/tests/envs/test_example_scenarios.py b/tests/envs/test_example_scenarios.py new file mode 100644 index 00000000..ea7637e8 --- /dev/null +++ b/tests/envs/test_example_scenarios.py @@ -0,0 +1,151 @@ +""" +Run a scenario and make sure expected actions are chosen and +expected reward is given to agents +""" + +from malsim.scenario import create_simulator_from_scenario + +def test_bfs_vs_bfs_state_and_reward(): + """ + The point of this test is to see that a specific + scenario runs deterministically. + + The test creates a simulator, two agents and runs them both with + BFS Agents against each other. + + It then verifies that rewards and actions performed are what we expected. 
+ """ + + sim, agents = create_simulator_from_scenario( + "tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml" + ) + sim.reset() + + defender_agent_name = "defender1" + attacker_agent_name = "attacker1" + + attacker_agent_info = next( + agent for agent in agents if agent["name"] == attacker_agent_name + ) + defender_agent_info = next( + agent for agent in agents if agent["name"] == defender_agent_name + ) + + attacker_agent = attacker_agent_info["agent"] + defender_agent = defender_agent_info["agent"] + + total_reward_defender = 0 + total_reward_attacker = 0 + + attacker_actions = [] + defender_actions = [] + + while True: + # Run the simulation until agents are terminated/truncated + + # Select attacker node + attacker_agent_state = sim.agent_states[attacker_agent_info["name"]] + attacker_node = attacker_agent.get_next_action(attacker_agent_state) + + # Select defender node + defender_agent_state = sim.agent_states[defender_agent_info["name"]] + defender_node = defender_agent.get_next_action(defender_agent_state) + + # Step + actions = { + defender_agent_name: [defender_node] if defender_node else [], + attacker_agent_name: [attacker_node] if attacker_node else [] + } + states = sim.step(actions) + + # If actions were performed, add them to respective list + if attacker_node and attacker_node in \ + states['attacker1'].step_performed_nodes: + attacker_actions.append(attacker_node.full_name) + assert attacker_node in states['defender1'].step_all_compromised_nodes + + if defender_node and defender_node in \ + states['defender1'].step_performed_nodes: + defender_actions.append(defender_node.full_name) + + total_reward_defender += defender_agent_state.reward + total_reward_attacker += attacker_agent_state.reward + + # Break simulation if trunc or term + if defender_agent_state.terminated or attacker_agent_state.terminated: + break + if defender_agent_state.truncated or attacker_agent_state.truncated: + break + + # Make sure the actions performed were as expected + assert attacker_actions == [ + 'Internet:attemptReverseReach', + 'Internet:networkForwardingUninspected', + 'Internet:deny', + 'Internet:accessNetworkData', + 'ConnectionRule Internet->Linux ' + 'System:attemptConnectToApplicationsUninspected', + 'Internet:reverseReach', + 'Internet:networkForwardingInspected', + 'ConnectionRule Internet->Linux System:attemptAccessNetworksUninspected', + 'ConnectionRule Internet->Linux System:attemptDeny', + 'Internet:attemptEavesdrop', + 'Internet:attemptAdversaryInTheMiddle', + 'ConnectionRule Internet->Linux System:bypassRestricted', + 'ConnectionRule Internet->Linux System:bypassPayloadInspection', + 'ConnectionRule Internet->Linux System:connectToApplicationsUninspected', + 'ConnectionRule Internet->Linux System:attemptReverseReach', + 'ConnectionRule Internet->Linux System:attemptAccessNetworksInspected', + 'ConnectionRule Internet->Linux System:attemptConnectToApplicationsInspected', + 'ConnectionRule Internet->Linux System:successfulAccessNetworksUninspected', + 'ConnectionRule Internet->Linux System:deny', + 'Internet:bypassEavesdropDefense', + 'Internet:successfulEavesdrop', + 'Internet:bypassAdversaryInTheMiddleDefense', + 'Internet:successfulAdversaryInTheMiddle', + 'Linux system:networkConnectUninspected', + 'Linux system:networkConnectInspected', + 'ConnectionRule Internet->Linux System:reverseReach', + 'ConnectionRule Internet->Linux System:successfulAccessNetworksInspected', + 'ConnectionRule Internet->Linux System:connectToApplicationsInspected', + 'ConnectionRule 
+        'ConnectionRule Internet->Linux System:accessNetworksUninspected',
+        'Linux system:denyFromNetworkingAsset',
+        'Internet:eavesdrop',
+        'Internet:adversaryInTheMiddle',
+        'Linux system:attemptUseVulnerability',
+        'Linux system:networkConnect',
+        'Linux system:specificAccessNetworkConnect',
+        'Linux system:softwareProductVulnerabilityNetworkAccessAchieved',
+        'Linux system:attemptReverseReach',
+        'ConnectionRule Internet->Linux System:accessNetworksInspected',
+        'Linux system:attemptDeny',
+        'Internet:accessInspected'
+    ]
+
+    assert defender_actions == [
+        'Linux system:notPresent',
+        'Linux system:supplyChainAuditing',
+        'Internet:networkAccessControl',
+        'Internet:eavesdropDefense',
+        'Internet:adversaryInTheMiddleDefense',
+        'ConnectionRule Internet->Linux System:restricted',
+        'ConnectionRule Internet->Linux System:payloadInspection',
+        'Secret data:notPresent',
+        'SoftwareVuln:notPresent'
+    ]
+
+    for full_name in attacker_actions:
+        # Make sure that all attacker actions led to compromise
+        node = sim.attack_graph.get_node_by_full_name(full_name)
+        assert node.is_compromised()
+
+    for full_name in defender_actions:
+        # Make sure that all defender actions led to an enabled defense
+        node = sim.attack_graph.get_node_by_full_name(full_name)
+        assert node.is_enabled_defense()
+
+    # Verify rewards in latest run and total rewards
+    assert attacker_agent_state.reward == 0
+    assert defender_agent_state.reward == -50
+
+    assert total_reward_attacker == 0
+    assert total_reward_defender == -2000
diff --git a/tests/api_test.py b/tests/envs/test_gym_envs.py
similarity index 89%
rename from tests/api_test.py
rename to tests/envs/test_gym_envs.py
index ec20ce17..7c7fa592 100644
--- a/tests/api_test.py
+++ b/tests/envs/test_gym_envs.py
@@ -11,8 +11,8 @@
 from gymnasium.utils import env_checker
 from pettingzoo.test import parallel_api_test
 
-from malsim.sims.mal_simulator import MalSimulator
-from malsim.wrappers.gym_wrapper import AttackerEnv, DefenderEnv, MaskingWrapper
+from malsim.envs import MalSimVectorizedObsEnv, AttackerEnv, DefenderEnv
+from malsim.envs.gym_envs import MaskingWrapper
 from malsim.agents.searchers import BreadthFirstAttacker, DepthFirstAttacker
 
 logger = logging.getLogger(__name__)
@@ -31,7 +31,7 @@ def register_gym_agent(agent_id, entry_point):
     gym.register(agent_id, entry_point=entry_point)
 
 
-def test_pz(env: MalSimulator):
+def test_pz(env: MalSimVectorizedObsEnv):
     logger.debug('Run Parrallel API test.')
     parallel_api_test(env)
 
@@ -43,7 +43,6 @@ def test_gym():
     env = gym.make(
         'MALDefenderEnv-v0',
         scenario_file=scenario_file,
-        unholy=False,
     )
     env_checker.check_env(env.unwrapped)
     register_gym_agent('MALAttackerEnv-v0', entry_point=AttackerEnv)
@@ -53,8 +52,6 @@
     )
     env_checker.check_env(env.unwrapped)
 
-    pass
-
 
 def test_random_defender_actions():
     register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv)
@@ -88,7 +85,6 @@ def test_episode():
     env = gym.make(
         'MALDefenderEnv-v0',
         scenario_file=scenario_file,
-        unholy=False,
     )
 
     done = False
@@ -107,7 +103,7 @@ def test_episode():
 
 
 def test_mask():
-    gym.register('MALDefenderEnv-v0', entry_point=DefenderEnv)
+    register_gym_agent('MALDefenderEnv-v0', entry_point=DefenderEnv)
     env = gym.make(
         'MALDefenderEnv-v0',
         scenario_file='tests/testdata/scenarios/simple_scenario.yml',
@@ -126,7 +122,6 @@ def test_defender_penalty():
     env = gym.make(
         'MALDefenderEnv-v0',
         scenario_file=scenario_file,
-        unholy=False,
     )
 
     _, info = env.reset()
@@ -160,7 +155,7 @@ def test_action_mask():
 
 # assert reward < 0 # All defense steps cost something
 
-def test_env_step(env: MalSimulator) -> None:
+def test_env_step(env: MalSimVectorizedObsEnv) -> None:
     obs, info = env.reset()
     attacker_action = env.action_space('attacker').sample()
     defender_action = env.action_space('defender').sample()
@@ -171,7 +166,7 @@ def test_env_step(env: MalSimulator) -> None:
     assert 'defender' in obs
 
 
-def test_check_space_env(env: MalSimulator) -> None:
+def test_check_space_env(env: MalSimVectorizedObsEnv) -> None:
     attacker_space = env.observation_space('attacker')
     defender_space = env.observation_space('defender')
 
@@ -206,7 +201,7 @@ def check_space(space, obs):
         DepthFirstAttacker,
     ],
 )
-def test_attacker(env: MalSimulator, attacker_class) -> None:
+def test_attacker(env: MalSimVectorizedObsEnv, attacker_class) -> None:
     obs, info = env.reset()
     attacker = attacker_class(
         dict(
@@ -218,13 +213,15 @@
     step_limit = 1000000
     done = False
     while not done and steps < step_limit:
-        action = attacker.compute_action_from_dict(
-            obs[AGENT_ATTACKER], info[AGENT_ATTACKER]['action_mask']
-        )
+        action_node = attacker.get_next_action(
+            env.get_agent_state(AGENT_ATTACKER))
+        action = (0, None)
+        if action_node:
+            action = (1, env.node_to_index(action_node))
         assert action != ACTION_TERMINATE
         assert action != ACTION_WAIT
         obs, rewards, terminated, truncated, info = env.step(
-            {AGENT_ATTACKER: action, AGENT_DEFENDER: [0]}
+            {AGENT_ATTACKER: action, AGENT_DEFENDER: (0, None)}
         )
         sum_rewards += rewards[AGENT_ATTACKER]
         done = terminated[AGENT_ATTACKER] or truncated[AGENT_ATTACKER]
@@ -233,7 +230,7 @@
     assert done, 'Attacker failed to explore attack steps'
 
 
-def test_env_multiple_steps(env: MalSimulator) -> None:
+def test_env_multiple_steps(env: MalSimVectorizedObsEnv) -> None:
     obs, info = env.reset()
     for _ in range(100):
         attacker_action = env.action_space('attacker').sample()
diff --git a/tests/envs/test_vectorized_obs_mal_simulator.py b/tests/envs/test_vectorized_obs_mal_simulator.py
new file mode 100644
index 00000000..52cb2783
--- /dev/null
+++ b/tests/envs/test_vectorized_obs_mal_simulator.py
@@ -0,0 +1,570 @@
+"""Test MalSimulator class"""
+
+from maltoolbox.attackgraph import AttackGraph, Attacker
+from malsim.mal_simulator import MalSimulator, MalSimAttackerState
+from malsim.envs import MalSimVectorizedObsEnv
+from malsim.scenario import load_scenario
+
+def test_create_blank_observation(corelang_lang_graph, model):
+    """Make sure blank observation contains correct default values"""
+
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    sim = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    num_objects = len(attack_graph.nodes)
+    blank_observation = sim._create_blank_observation()
+
+    assert len(blank_observation['is_observable']) == num_objects
+    for state in blank_observation['is_observable']:
+        # Default is that all nodes are observable,
+        # unless anything else is given through its extras field
+        assert state == 1
+
+    assert len(blank_observation['observed_state']) == num_objects
+    for state in blank_observation['observed_state']:
+        assert state == -1 # This is the default (which we get in blank observation)
+
+    assert len(blank_observation['remaining_ttc']) == num_objects
+    for ttc in blank_observation['remaining_ttc']:
+        assert ttc == 0 # TTC is currently always 0 no matter what
+
+    # asset_type_index points us to an asset type in sim._index_to_asset_type.
+    # The index at which asset_type_index appears corresponds to an attack
+    # step id in sim._index_to_id.
+    # The type we get from sim._index_to_asset_type[asset_type_index]
+    # should be the same as the asset type of attack step with id index in attack graph
+    assert len(blank_observation['asset_type']) == num_objects
+    for index, asset_type_index in enumerate(blank_observation['asset_type']):
+        # Note: offset is decremented from asset_type_index
+        expected_type = sim._index_to_asset_type[asset_type_index]
+        node = sim.index_to_node(index)
+        assert node.lg_attack_step.asset.name == expected_type
+
+    # asset_id on index X in blank_observation['asset_id']
+    # should be the same as the id of the asset of attack step X
+    assert len(blank_observation['asset_id']) == num_objects
+    for index, expected_asset_id in enumerate(blank_observation['asset_id']):
+        node = sim.index_to_node(index)
+        assert node.model_asset.id == expected_asset_id
+
+    assert len(blank_observation['step_name']) == num_objects
+
+    expected_num_edges = sum([1 for step in attack_graph.nodes.values()
+                              for child in step.children] +
+                             # We expect all defenses again (reversed)
+                             [1 for step in attack_graph.nodes.values()
+                              for child in step.children
+                              if step.type == "defense"])
+    assert len(blank_observation['attack_graph_edges']) == expected_num_edges
+
+
+def test_create_blank_observation_deterministic(
+        corelang_lang_graph, model
+    ):
+    """Make sure blank observation is deterministic with seed given"""
+
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    attack_graph.attach_attackers()
+    attacker = next(iter(attack_graph.attackers.values()))
+
+    sim = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+    sim.register_attacker("test_attacker", attacker.id)
+    sim.register_defender("test_defender")
+
+    obs1, _ = sim.reset(seed=123)
+    obs2, _ = sim.reset(seed=123)
+
+    assert list(obs1['test_attacker']['is_observable']) == list(obs2['test_attacker']['is_observable'])
+    assert list(obs1['test_attacker']['is_actionable']) == list(obs2['test_attacker']['is_actionable'])
+    assert list(obs1['test_attacker']['observed_state']) == list(obs2['test_attacker']['observed_state'])
+    assert list(obs1['test_attacker']['remaining_ttc']) == list(obs2['test_attacker']['remaining_ttc'])
+    assert list(obs1['test_attacker']['asset_type']) == list(obs2['test_attacker']['asset_type'])
+    assert list(obs1['test_attacker']['asset_id']) == list(obs2['test_attacker']['asset_id'])
+    assert list(obs1['test_attacker']['step_name']) == list(obs2['test_attacker']['step_name'])
+
+    for i, elem in enumerate(obs1['test_attacker']['attack_graph_edges']):
+        assert list(obs2['test_attacker']['attack_graph_edges'][i]) == list(elem)
+
+    assert list(obs1['test_attacker']['model_asset_id']) == list(obs2['test_attacker']['model_asset_id'])
+    assert list(obs1['test_attacker']['model_asset_type']) == list(obs2['test_attacker']['model_asset_type'])
+
+    for i, elem in enumerate(obs1['test_attacker']['model_edges_ids']):
+        assert list(obs2['test_attacker']['model_edges_ids'][i]) == list(elem)
+
+    assert list(obs1['test_attacker']['model_edges_type']) == list(obs2['test_attacker']['model_edges_type'])
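+
+# Determinism in the tests above and below hinges on resetting with the
+# same seed; a rough sketch of the pattern (names as used in this file):
+#
+#     obs1, _ = sim.reset(seed=123)
+#     obs2, _ = sim.reset(seed=123)
+#     assert list(obs1['test_attacker']['observed_state']) == \
+#         list(obs2['test_attacker']['observed_state'])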
+
+def test_step_deterministic(
+        corelang_lang_graph, model
+    ):
+    """Make sure stepping is deterministic when a seed is given"""
+
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    attack_graph.attach_attackers()
+    attacker = next(iter(attack_graph.attackers.values()))
+
+    sim = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+    sim.register_attacker("test_attacker", attacker.id)
+    sim.register_defender("test_defender")
+
+    obs1 = {}
+    obs2 = {}
+
+    # Run 1
+    sim.reset(seed=123)
+    for _ in range(10):
+        attacker_node = next(
+            n for n in sim.get_agent_state('test_attacker').action_surface
+            if not n.is_compromised()
+        )
+        attacker_action = (1, sim.node_to_index(attacker_node))
+        obs1, _, _, _, _ = sim.step(
+            {'test_defender': (0, None), 'test_attacker': attacker_action}
+        )
+
+    # Run 2 - identical
+    sim.reset(seed=123)
+    for _ in range(10):
+        attacker_node = next(
+            n for n in sim.get_agent_state('test_attacker').action_surface
+            if not n.is_compromised()
+        )
+        attacker_action = (1, sim.node_to_index(attacker_node))
+        obs2, _, _, _, _ = sim.step(
+            {'test_defender': (0, None), 'test_attacker': attacker_action}
+        )
+
+    assert list(obs1['test_attacker']['observed_state']) == list(obs2['test_attacker']['observed_state'])
+    assert list(obs1['test_defender']['observed_state']) == list(obs2['test_defender']['observed_state'])
+
+
+def test_create_blank_observation_observability_given(
+        corelang_lang_graph, model
+    ):
+    """Make sure observability propagates correctly from extras field/scenario
+    to observation in mal simulator"""
+
+    # Load Scenario with observability rules set
+    scenario_file = \
+        'tests/testdata/scenarios/traininglang_observability_scenario.yml'
+    ag, _ = load_scenario(scenario_file)
+    env = MalSimVectorizedObsEnv(MalSimulator(ag))
+
+    num_objects = len(env.sim.attack_graph.nodes)
+    blank_observation = env._create_blank_observation()
+
+    assert len(blank_observation['is_observable']) == num_objects
+
+    for index, observable in enumerate(blank_observation['is_observable']):
+        node = env.index_to_node(index)
+
+        # Below are the rules from the traininglang observability scenario
+        # made into if statements
+        if node.lg_attack_step.asset.name == 'Host' and node.name in ('access',):
+            assert observable
+        elif node.lg_attack_step.asset.name == 'Host' and node.name in ('authenticate',):
+            assert observable
+        elif node.lg_attack_step.asset.name == 'Data' and node.name in ('read',):
+            assert observable
+        elif node.model_asset.name == 'User:3' and node.name in ('phishing',):
+            assert observable
+        else:
+            assert not observable
+
+def test_create_blank_observation_actionability_given(
+        corelang_lang_graph, model
+    ):
+    """Make sure actionability propagates correctly from extras field/scenario
+    to observation in mal simulator"""
+
+    # Load Scenario with actionability rules set
+    scenario_file = 'tests/testdata/scenarios/traininglang_actionability_scenario.yml'
+    ag, _ = load_scenario(scenario_file)
+    env = MalSimVectorizedObsEnv(MalSimulator(ag))
+
+    num_objects = len(env.sim.attack_graph.nodes)
+    blank_observation = env._create_blank_observation()
+
+    assert len(blank_observation['is_actionable']) == num_objects
+
+    for index, actionable in enumerate(blank_observation['is_actionable']):
+        node = env.index_to_node(index)
+
+        # Below are the rules from the traininglang actionability scenario
+        # made into if statements
+        if node.lg_attack_step.asset.name == 'Host' and node.name in ('notPresent',):
+            assert actionable
+        elif node.lg_attack_step.asset.name == 'Data' and node.name in ('notPresent',):
+            assert actionable
+        elif node.model_asset.name == 'User:3' and node.name in ('notPresent',):
+            assert actionable
+        else:
+            assert not actionable
+
+def test_step(corelang_lang_graph, model):
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    entry_point = attack_graph.get_node_by_full_name('OS App:fullAccess')
+
+    attacker = Attacker(
+        'attacker1',
+        reached_attack_steps = {entry_point},
+        entry_points = {entry_point},
+        attacker_id = 100)
+    attack_graph.add_attacker(attacker, attacker.id)
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    # Refresh the attack graph reference to the one used by the simulator
+    attack_graph = env.sim.attack_graph
+
+    agent_info = MalSimAttackerState(attacker.name, attacker.id)
+
+    # Cannot attack the notPresent step
+    defense_step = attack_graph\
+        .get_node_by_full_name('OS App:notPresent')
+    actions = env.sim._attacker_step(agent_info, {defense_step})
+    assert not actions
+    assert not agent_info.step_action_surface_additions
+
+    attack_step = attack_graph.get_node_by_full_name('OS App:attemptRead')
+
+    # Action needs to be in action surface to be an allowed action
+    agent_info.action_surface = {attack_step}
+
+    # Since the action is in the action surface and is traversable,
+    # the action will be performed.
+    env.sim._attacker_step(agent_info, {attack_step})
+    assert agent_info.step_performed_nodes == {attack_step}
+    assert agent_info.step_action_surface_additions == attack_step.children
+
+
+def test_malsimulator_defender_step(corelang_lang_graph, model):
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    agent_name = "defender1"
+    env.register_defender(agent_name)
+    env.reset()
+
+    defender_agent = env.sim._agent_states[agent_name]
+    defense_step = env.sim.attack_graph.get_node_by_full_name(
+        'OS App:notPresent')
+    env.sim._defender_step(defender_agent, {defense_step})
+    assert defender_agent.step_performed_nodes == {defense_step}
+
+    # Cannot defend attack_step
+    attack_step = env.sim.attack_graph.get_node_by_full_name(
+        'OS App:attemptUseVulnerability')
+    env.sim._defender_step(defender_agent, {attack_step})
+    assert not defender_agent.step_performed_nodes
+
+
+def test_malsimulator_observe_attacker():
+    attack_graph, _ = load_scenario(
+        'tests/testdata/scenarios/simple_scenario.yml')
+
+    # Create the simulator
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    # Register the agents
+    defender_agent_name = 'defender'
+    attacker_agent_name = 'attacker'
+
+    attacker = next(iter(attack_graph.attackers.values()))
+
+    env.register_attacker(attacker_agent_name, attacker.id)
+    env.register_defender(defender_agent_name)
+
+    # Must reset after registering agents
+    env.reset()
+
+    # Fetch the attacker from the simulator's attack graph
+    assert len(env.sim.attack_graph.attackers) == 1
+    attacker = next(iter(env.sim.attack_graph.attackers.values()))
+
+    assert len(attacker.reached_attack_steps) == 1
+    reached_step = next(iter(attacker.reached_attack_steps))
+
+    # Select actions for the attacker
+    actions_to_take = []
+    for child_node in reached_step.children:
+        if child_node.type in ('and', 'or'):
+            # In the end the attacker will have three reached steps
+            # where two are children of the first one
+            actions_to_take.append(child_node)
+
+    num_reached_steps_before = len(attacker.reached_attack_steps)
+
+    for attacker_action in actions_to_take:
+        obs, _, _, _, _ = env.step({
+            defender_agent_name: (0, None),
+            attacker_agent_name: (1, env.node_to_index(attacker_action))
+        })
+
+        num_reached_steps_now = len(attacker.reached_attack_steps)
+        assert num_reached_steps_now == num_reached_steps_before + 1
+        num_reached_steps_before = num_reached_steps_now
+
+    attacker_observation = obs[attacker_agent_name]["observed_state"]
+
+    for node in attacker.reached_attack_steps:
+        node_index = env._id_to_index[node.id]
+        node_obs_state = attacker_observation[node_index]
+        assert node_obs_state == 1
+
+    for index, state in enumerate(attacker_observation):
+        node = env.index_to_node(index)
+
+        if node.is_compromised():
+            assert state == 1
+        else:
+            if state == -1:
+                for parent in node.parents:
+                    assert parent not in attacker.reached_attack_steps
+            else:
+                assert state == 0
+
+
+def test_malsimulator_observe_and_reward_attacker_defender():
+    """Run attacker and defender actions and make sure
+    rewards and observation states are updated correctly"""
+
+    def verify_attacker_obs_state(
+            observed_state,
+            expected_reached,
+            expected_children_of_reached
+        ):
+        """Make sure obs state looks as expected"""
+        for index, state in enumerate(observed_state):
+            node_id = env._index_to_id[index]
+            if state == 1:
+                assert node_id in expected_reached
+            elif state == 0:
+                assert node_id in expected_children_of_reached
+            else:
+                assert state == -1
+
+    def verify_defender_obs_state(
+            observed_state
+        ):
+        """Make sure obs state looks as expected"""
+        for index, state in enumerate(observed_state):
+            node = env.index_to_node(index)
+            if state == 1:
+                assert node.is_compromised() or node.is_enabled_defense()
+            elif state == 0:
+                assert not node.is_compromised() and not node.is_enabled_defense(), f"{node.full_name} has incorrect state {state}"
+            else:
+                assert state == -1
+
+    attack_graph, _ = load_scenario(
+        'tests/testdata/scenarios/traininglang_scenario.yml')
+    # Create the simulator
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    attacker = next(iter(attack_graph.attackers.values()))
+    attacker_agent_name = "Attacker1"
+    env.register_attacker(attacker_agent_name, attacker.id)
+
+    defender_agent_name = "Defender1"
+    env.register_defender(defender_agent_name)
+
+    env.reset()
+
+    attacker_reached_steps = [n.id for n in attacker.entry_points]
+    attacker_reached_step_children = []
+    for reached in attacker.entry_points:
+        attacker_reached_step_children.extend(
+            [n.id for n in reached.children]
+        )
+
+    # Prepare nodes that will be stepped through in order
+    user_3_compromise = env.sim.attack_graph\
+        .get_node_by_full_name("User:3:compromise")
+    host_0_authenticate = env.sim.attack_graph\
+        .get_node_by_full_name("Host:0:authenticate")
+    host_0_access = env.sim.attack_graph\
+        .get_node_by_full_name("Host:0:access")
+    host_0_notPresent = env.sim.attack_graph\
+        .get_node_by_full_name("Host:0:notPresent")
+    data_2_read = env.sim.attack_graph\
+        .get_node_by_full_name("Data:2:read")
+
+    # Step with attacker action
+    obs, rew, _, _, _ = env.step({
+        defender_agent_name: (0, None),
+        attacker_agent_name: (1, env.node_to_index(user_3_compromise))
+        }
+    )
+
+    # Verify obs state
+    attacker_reached_steps.append(user_3_compromise.id)
+    attacker_reached_step_children.extend(
+        [n.id for n in user_3_compromise.children])
+
+    verify_attacker_obs_state(
+        obs[attacker_agent_name]['observed_state'],
+        attacker_reached_steps,
+        attacker_reached_step_children)
+    verify_defender_obs_state(
+        obs[defender_agent_name]['observed_state']
+    )
+
+    # Verify rewards
+    assert rew[defender_agent_name] == 0
+    assert rew[attacker_agent_name] == 0
+
+    # Step with attacker again
+    obs, rew, _, _, _ = env.step({
+        defender_agent_name: (0, None),
+        attacker_agent_name: (1, env.node_to_index(host_0_authenticate))
+        }
+    )
+
+    # Verify obs state
+    attacker_reached_steps.append(host_0_authenticate.id)
+    attacker_reached_step_children.extend(
+        [n.id for n in host_0_authenticate.children])
+    verify_attacker_obs_state(
+        obs[attacker_agent_name]['observed_state'],
+        attacker_reached_steps,
+        attacker_reached_step_children)
+    verify_defender_obs_state(
+        obs[defender_agent_name]['observed_state']
+    )
+
+    # Verify rewards
+    assert rew[defender_agent_name] == 0
+    assert rew[attacker_agent_name] == 0
+
+    # Step attacker again
+    obs, rew, _, _, _ = env.step({
+        defender_agent_name: (0, None),
+        attacker_agent_name: (1, env.node_to_index(host_0_access))
+        }
+    )
+
+    # Verify obs state
+    attacker_reached_steps.append(host_0_access.id)
+    attacker_reached_step_children.extend(
+        [n.id for n in host_0_access.children])
+    verify_attacker_obs_state(
+        obs[attacker_agent_name]['observed_state'],
+        attacker_reached_steps,
+        attacker_reached_step_children)
+    verify_defender_obs_state(
+        obs[defender_agent_name]['observed_state']
+    )
+
+    reward_host_0_access = 4
+    # Verify rewards
+    assert rew[attacker_agent_name] == reward_host_0_access
+    assert rew[defender_agent_name] == -rew[attacker_agent_name]
+
+    # Step defender and attacker
+    # Attacker won't be able to traverse Data:2:read since
+    # Host:0:notPresent is enabled first
+    obs, rew, _, _, _ = env.step({
+        defender_agent_name: (1, env.node_to_index(host_0_notPresent)),
+        attacker_agent_name: (1, env.node_to_index(data_2_read))
+        }
+    )
+
+    # Attacker obs state should look the same as before
+    verify_attacker_obs_state(
+        obs[attacker_agent_name]['observed_state'],
+        attacker_reached_steps,
+        attacker_reached_step_children)
+    verify_defender_obs_state(
+        obs[defender_agent_name]['observed_state']
+    )
+
+    # Verify rewards
+    reward_host_0_not_present = 2
+    assert rew[attacker_agent_name] == reward_host_0_access # no additional reward
+    assert rew[defender_agent_name] == -rew[attacker_agent_name] - reward_host_0_not_present
+
+
+def test_malsimulator_initial_observation_defender(corelang_lang_graph, model):
+    """Make sure the initial defender observation observes nodes
+    and sets their observed state"""
+
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    defender_agent_name = "defender"
+    env.register_defender(defender_agent_name)
+    obs, _ = env.reset()
+
+    defender_obs_state = obs[defender_agent_name]["observed_state"]
+
+    nodes_to_observe = [
+        node for node in env.sim.attack_graph.nodes.values()
+        if node.is_enabled_defense() or node.is_compromised()
+    ]
+
+    # Assert that observed state is 1 after the initial observation
+    for node in nodes_to_observe:
+        index = env._id_to_index[node.id]
+        assert defender_obs_state[index] == 1
+
+
+def test_malsimulator_observe_and_reward_attacker_no_entrypoints(
+        corelang_lang_graph, model
+    ):
+
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    attacker = Attacker("TestAttacker", [], [])
+    attack_graph.add_attacker(attacker)
+    sim = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    # Register an attacker
+    sim.register_attacker(attacker.name, attacker.id)
+    sim.reset()
+
+    obs, rew, _, _, _ = sim.step({})
+
+    # Observe and reward with no new actions
+    # Since attacker has no entry points and no steps have been performed
+    # the observed state should be empty
+    for state in obs[attacker.name]['observed_state']:
+        assert state == -1
+    assert rew[attacker.name] == 0
+
+
+def test_malsimulator_observe_and_reward_attacker_entrypoints(
+        traininglang_lang_graph, traininglang_model
+    ):
+
+    attack_graph = AttackGraph(
+        traininglang_lang_graph, traininglang_model)
+    attack_graph.attach_attackers()
+    env = MalSimVectorizedObsEnv(MalSimulator(attack_graph))
+
+    # Register an attacker
+    attacker = env.sim.attack_graph.attackers[0]
+    env.register_attacker(attacker.name, attacker.id)
+
+    # We 
need to reinitialize to initialize agent + obs, _ = env.reset() + + # Since reset deepcopies attack graph we + # need to fetch attacker again + attacker = env.sim.attack_graph.attackers[0] + + for index, state in enumerate( + obs[attacker.name]['observed_state']): + + node = env.index_to_node(index) + if state == -1: + assert node not in attacker.entry_points + assert node not in attacker.reached_attack_steps + assert not node.is_compromised() + assert not any([p.is_compromised() for p in node.parents]) + elif state == 0: + assert node not in attacker.entry_points + assert node not in attacker.reached_attack_steps + assert not node.is_compromised() + assert any([p.is_compromised() for p in node.parents]) + elif state == 1: + assert node in attacker.entry_points + assert node in attacker.reached_attack_steps + assert node.is_compromised() diff --git a/tests/run_demo.py b/tests/run_demo.py deleted file mode 100644 index b7aef280..00000000 --- a/tests/run_demo.py +++ /dev/null @@ -1,163 +0,0 @@ -import sys -import os - -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - -from json import JSONEncoder -import numpy as np -import logging - -from maltoolbox.language import LanguageClassesFactory, LanguageGraph -from maltoolbox.attackgraph import AttackGraph -from maltoolbox.model import Model - -from malsim.agents.keyboard_input import KeyboardAgent -from malsim.agents.searchers import BreadthFirstAttacker -from malsim.sims.mal_simulator import MalSimulator - - -logger = logging.getLogger(__name__) -logging.getLogger().setLevel(logging.DEBUG) -logging.getLogger("maltoolbox").setLevel(logging.DEBUG) - - -# Raise the logging level for the py2neo module to clean the logs a bit -# cleaner. -py2neo_logger = logging.getLogger("py2neo") -py2neo_logger.setLevel(logging.INFO) - -null_action = (0, None) - - -class NumpyArrayEncoder(JSONEncoder): - def default(self, o): - if isinstance(o, np.ndarray): - return o.tolist() - if isinstance(o, np.bool_): - return bool(o) - if isinstance(o, np.int64): - return int(o) - return JSONEncoder.default(self, o) - - -attacker_only = False - -AGENT_ATTACKER = "attacker" -AGENT_DEFENDER = "defender" - - -#MAL toolbox to load the graph attack -lang_file = "tests/testdata/langs/org.mal-lang.coreLang-1.0.0.mar" -lang_graph = LanguageGraph.from_mar_archive(lang_file) -lang_classes_factory = LanguageClassesFactory(lang_graph) - -model = Model.load_from_file("tests/testdata/models/run_demo_model.json", lang_classes_factory) - -attack_graph = AttackGraph(lang_graph, model) -attack_graph.attach_attackers() -attack_graph.save_to_file("logs/attack_graph.json") - -env = MalSimulator(lang_graph, model, attack_graph, max_iter=500) - -env.register_attacker(AGENT_ATTACKER, 0) -env.register_defender(AGENT_DEFENDER) - -control_attacker = False - -reverse_vocab = env._index_to_full_name - -defender = KeyboardAgent(reverse_vocab) -attacker = ( - KeyboardAgent(reverse_vocab) if control_attacker else BreadthFirstAttacker({}) -) - -obs, infos = env.reset() -done = False - -# Set rewards -# TODO Have a nice and configurable way of doing this when we have the -# scenario configuration format decided upon. 
-MAX_REWARD = int(1e9) - -env.attack_graph.get_node_by_full_name("OS App:notPresent").extras['reward'] = 50 -env.attack_graph.get_node_by_full_name("OS App:supplyChainAuditing").extras['reward'] = MAX_REWARD -env.attack_graph.get_node_by_full_name("Program 1:notPresent").extras['reward'] = 30 -env.attack_graph.get_node_by_full_name("Program 1:supplyChainAuditing").extras['reward'] = MAX_REWARD -env.attack_graph.get_node_by_full_name("SoftwareVulnerability:2:notPresent").extras['reward'] = 40 -env.attack_graph.get_node_by_full_name("Data:3:notPresent").extras['reward'] = 20 -env.attack_graph.get_node_by_full_name("Credentials:4:notPhishable").extras['reward'] = MAX_REWARD -env.attack_graph.get_node_by_full_name("Identity:5:notPresent").extras['reward'] = 35 -env.attack_graph.get_node_by_full_name("ConnectionRule:6:restricted").extras['reward'] = 40 -env.attack_graph.get_node_by_full_name("ConnectionRule:6:payloadInspection").extras['reward'] = 30 -env.attack_graph.get_node_by_full_name("Other OS App:notPresent").extras['reward'] = 50 -env.attack_graph.get_node_by_full_name("Other OS App:supplyChainAuditing").extras['reward'] = MAX_REWARD - -env.attack_graph.get_node_by_full_name("OS App:fullAccess").extras['reward'] = 100 -env.attack_graph.get_node_by_full_name("Program 1:fullAccess").extras['reward'] = 50 -env.attack_graph.get_node_by_full_name("Identity:5:assume").extras['reward'] = 50 -env.attack_graph.get_node_by_full_name("Other OS App:fullAccess").extras['reward'] = 200 - - -logger.info("Starting game.") - -total_reward_defender = 0 -total_reward_attacker = 0 - -while not done: - # env.render() - defender_action = ( - defender.compute_action_from_dict( - obs[AGENT_DEFENDER], infos[AGENT_DEFENDER]["action_mask"] - ) - if not attacker_only - else null_action - ) - attacker_action = attacker.compute_action_from_dict( - obs[AGENT_ATTACKER], infos[AGENT_ATTACKER]["action_mask"] - ) - - if attacker_action[1] is not None: - print("Attacker Action: ", reverse_vocab[attacker_action[1]]) - logger.debug(f"Attacker Action: {reverse_vocab[attacker_action[1]]}") - else: - print("Attacker Action: None") - logger.debug("Attacker Action: None") - # Stop the attacker if it has run out of things to do since the - # experiment cannot progress any further. - # TODO Perhaps we want to only do this if none of the agents have - # anything to do or we may simply wish to have them running to accrue - # penalties/rewards. This was added just to make it easier to debug. 
-        done = True
-    action_dict = {AGENT_ATTACKER: attacker_action, AGENT_DEFENDER: defender_action}
-    obs, rewards, terminated, truncated, infos = env.step(action_dict)
-
-    logger.debug("Attacker has compromised the following attack steps so " "far:")
-    attacker_obj = env.attack_graph.attackers[
-        env.agents_dict[AGENT_ATTACKER]["attacker"]
-    ]
-    for step in attacker_obj.reached_attack_steps:
-        logger.debug(step.id)
-
-    print("Attacker Reward: ", rewards[AGENT_ATTACKER])
-    logger.debug(f"Attacker Reward: {rewards[AGENT_ATTACKER]}")
-    if not attacker_only:
-        print("Defender Reward: ", rewards[AGENT_DEFENDER])
-        logger.debug(f"Defender Reward: {rewards[AGENT_DEFENDER]}")
-    total_reward_defender += rewards[AGENT_DEFENDER] if not attacker_only else 0
-    total_reward_attacker += rewards[AGENT_ATTACKER]
-
-    done |= terminated[AGENT_ATTACKER] or truncated[AGENT_ATTACKER]
-
-    print("---\n")
-
-# env.render()
-print("Game Over.")
-logger.debug("Game Over.")
-if not attacker_only:
-    print("Total Defender Reward: ", total_reward_defender)
-    logger.debug(f"Total Defender Reward: {total_reward_defender}")
-print("Total Attacker Reward: ", total_reward_attacker)
-logger.debug(f"Total Attacker Reward: {total_reward_attacker}")
-print("Press Enter to exit.")
-input()
-env.close()
diff --git a/tests/test_example_scenarios.py b/tests/test_example_scenarios.py
deleted file mode 100644
index 817d1174..00000000
--- a/tests/test_example_scenarios.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from malsim.scenario import create_simulator_from_scenario
-
-
-def test_bfs_vs_bfs_state_and_reward():
-    sim, sim_config = create_simulator_from_scenario(
-        "tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml"
-    )
-    obs, infos = sim.reset()
-
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    defender_agent_id = next(iter(sim.get_defender_agents()), None)
-
-    # Initialize defender and attacker according to classes
-    defender_class = sim_config["agents"][defender_agent_id]["agent_class"]
-    defender_agent = defender_class({})
-
-    attacker_class = sim_config["agents"][attacker_agent_id]["agent_class"]
-    attacker_agent = attacker_class({})
-
-    total_reward_defender = 0
-    total_reward_attacker = 0
-
-    attacker_actions = []
-    defender_actions = []
-
-    while True:
-        defender_action = defender_agent.compute_action_from_dict(
-            obs[defender_agent_id], infos[defender_agent_id]["action_mask"]
-        )
-
-        attacker_action = attacker_agent.compute_action_from_dict(
-            obs[attacker_agent_id], infos[attacker_agent_id]["action_mask"]
-        )
-
-        if attacker_action[0]:
-            attacker_node = sim.action_to_node(attacker_action)
-            attacker_actions.append(attacker_node.full_name)
-        if defender_action[0]:
-            defender_node = sim.action_to_node(defender_action)
-            defender_actions.append(defender_node.full_name)
-
-        actions = {"defender": defender_action, "attacker": attacker_action}
-        obs, rewards, terminated, truncated, infos = sim.step(actions)
-
-        total_reward_defender += rewards.get(defender_agent_id, 0)
-        total_reward_attacker += rewards.get(attacker_agent_id, 0)
-
-        if terminated[defender_agent_id] or terminated[attacker_agent_id]:
-            break
-
-    assert attacker_actions == [
-        'Internet:attemptReverseReach',
-        'Internet:networkForwardingUninspected',
-        'Internet:deny',
-        'Internet:accessNetworkData',
-        'ConnectionRule Internet->Linux '
-        'System:attemptConnectToApplicationsUninspected',
-        'Internet:reverseReach',
-        'Internet:networkForwardingInspected',
-        'ConnectionRule Internet->Linux System:attemptAccessNetworksUninspected',
-        'ConnectionRule Internet->Linux System:attemptDeny',
-        'Internet:attemptEavesdrop',
-        'Internet:attemptAdversaryInTheMiddle',
-        'ConnectionRule Internet->Linux System:bypassRestricted',
-        'ConnectionRule Internet->Linux System:bypassPayloadInspection',
-        'ConnectionRule Internet->Linux System:connectToApplicationsUninspected',
-        'ConnectionRule Internet->Linux System:attemptReverseReach',
-        'ConnectionRule Internet->Linux System:attemptAccessNetworksInspected',
-        'ConnectionRule Internet->Linux System:attemptConnectToApplicationsInspected',
-        'ConnectionRule Internet->Linux System:successfulAccessNetworksUninspected',
-        'ConnectionRule Internet->Linux System:deny',
-        'Internet:bypassEavesdropDefense',
-        'Internet:successfulEavesdrop',
-        'Internet:bypassAdversaryInTheMiddleDefense',
-        'Internet:successfulAdversaryInTheMiddle',
-        'Linux system:networkConnectUninspected',
-        'Linux system:networkConnectInspected',
-        'ConnectionRule Internet->Linux System:reverseReach',
-        'ConnectionRule Internet->Linux System:successfulAccessNetworksInspected',
-        'ConnectionRule Internet->Linux System:connectToApplicationsInspected',
-        'ConnectionRule Internet->Linux System:accessNetworksUninspected',
-        'Linux system:denyFromNetworkingAsset',
-        'Internet:eavesdrop',
-        'Internet:adversaryInTheMiddle',
-        'Linux system:attemptUseVulnerability',
-        'Linux system:networkConnect',
-        'Linux system:specificAccessNetworkConnect',
-        'Linux system:softwareProductVulnerabilityNetworkAccessAchieved',
-        'Linux system:attemptReverseReach',
-        'ConnectionRule Internet->Linux System:accessNetworksInspected',
-        'Linux system:attemptDeny',
-        'Internet:accessInspected'
-    ]
-
-    assert defender_actions == [
-        'Linux system:notPresent',
-        'Linux system:supplyChainAuditing',
-        'Internet:networkAccessControl',
-        'Internet:eavesdropDefense',
-        'Internet:adversaryInTheMiddleDefense',
-        'ConnectionRule Internet->Linux System:restricted',
-        'ConnectionRule Internet->Linux System:payloadInspection',
-        'Secret data:notPresent',
-        'SoftwareVuln:notPresent'
-    ]
-
-    # Verify observations
-    for step_fullname in attacker_actions:
-        node = sim.attack_graph.get_node_by_full_name(step_fullname)
-        if node.is_compromised():
-            node_index = sim.node_to_index(node)
-            assert obs[defender_agent_id]["observed_state"][node_index]
-
-    for step_fullname in defender_actions:
-        node = sim.attack_graph.get_node_by_full_name(step_fullname)
-        node_index = sim.node_to_index(node)
-        assert obs[defender_agent_id]["observed_state"][node_index]
-
-    assert rewards[attacker_agent_id] == 0
-    assert rewards[defender_agent_id] == -50
-
-    assert total_reward_attacker == 0
-    assert total_reward_defender == -2000
-
-
-def test_scenario_step_by_step():
-    sim, sim_config = create_simulator_from_scenario(
-        "tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml"
-    )
-    sim.reset()
-
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    defender_agent_id = next(iter(sim.get_defender_agents()), None)
-
-    attacker_actions = [
-        "Internet:attemptReverseReach",
-        "Internet:reverseReach",
-        "ConnectionRule Internet->Linux System:attemptReverseReach",
-        "ConnectionRule Internet->Linux System:reverseReach",
-        "Linux system:attemptReverseReach",
-        "Linux system:successfulReverseReach",
-        "Linux system:reverseReach",
-        "Secret data:attemptReverseReach",
-        "Secret data:reverseReach",
-    ]
-
-    # Make sure attacker can take these steps
-    for attacker_action_fn in attacker_actions:
-        attacker_node = sim.attack_graph.get_node_by_full_name(attacker_action_fn)
-        attacker_action_index = sim.node_to_index(attacker_node)
-        actions = {
-            defender_agent_id: (0, None),
-            attacker_agent_id: (1, attacker_action_index)
-        }
-        sim.step(actions)
-        assert attacker_node.is_compromised()
-
-    # TODO: find out if this fails because is_traversable is not correct
-    # But not this one
-    # attacker_node = sim.attack_graph.get_node_by_full_name(
-    #     "Secret data:read"
-    # )
-    # attacker_action_index = sim.node_to_index(attacker_node)
-
-    # actions = {
-    #     defender_agent_id: (0, None),
-    #     attacker_agent_id: (1, attacker_action_index)
-    # }
-
-    # sim.step(actions)
-    # assert not attacker_node.is_compromised()
\ No newline at end of file
diff --git a/tests/test_cli.py b/tests/test_main.py
similarity index 61%
rename from tests/test_cli.py
rename to tests/test_main.py
index c6d584d0..062405b4 100644
--- a/tests/test_cli.py
+++ b/tests/test_main.py
@@ -3,8 +3,9 @@
 import os
 from unittest.mock import patch
 
+from malsim.__main__ import run_simulation
 from malsim.scenario import create_simulator_from_scenario
-from malsim.cli import run_simulation
+from malsim.mal_simulator import MalSimulator
 
 
 def path_relative_to_tests(filename):
@@ -22,23 +23,22 @@ def test_run_simulation(mock_input):
     """Make sure we can run simulation with defender agent
     registered in scenario"""
 
-    simulator, config = create_simulator_from_scenario(
-        path_relative_to_tests(
-            './testdata/scenarios/bfs_vs_bfs_scenario.yml'
-        )
+    scenario_file = path_relative_to_tests(
+        './testdata/scenarios/bfs_vs_bfs_scenario.yml'
     )
-    run_simulation(simulator, config)
+    sim, agents = create_simulator_from_scenario(
+        scenario_file, sim_class=MalSimulator)
+    run_simulation(sim, agents)
 
 
 @patch("builtins.input", return_value="\n")  # to not freeze on input()
 def test_run_simulation_without_defender_agent(mock_input):
     """Make sure we can run simulation without defender agent
     registered in scenario"""
 
-    simulator, config = create_simulator_from_scenario(
-        path_relative_to_tests(
-            './testdata/scenarios/no_defender_agent_scenario.yml'
-        )
+    scenario_file = path_relative_to_tests(
+        './testdata/scenarios/no_defender_agent_scenario.yml'
     )
-    run_simulation(simulator, config)
-
+    sim, agents = create_simulator_from_scenario(
+        scenario_file, sim_class=MalSimulator)
+    run_simulation(sim, agents)
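A minimal usage sketch of the renamed entry point, assembled only from the calls exercised in tests/test_main.py above (the scenario path is one of the repository's test fixtures):

    from malsim.__main__ import run_simulation
    from malsim.mal_simulator import MalSimulator
    from malsim.scenario import create_simulator_from_scenario

    # create_simulator_from_scenario returns the simulator plus the agent
    # dicts declared under the scenario file's `agents:` section.
    sim, agents = create_simulator_from_scenario(
        'tests/testdata/scenarios/bfs_vs_bfs_scenario.yml',
        sim_class=MalSimulator)
    run_simulation(sim, agents)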
diff --git a/tests/test_mal_simulator.py b/tests/test_mal_simulator.py
index 89af3754..71759bd4 100644
--- a/tests/test_mal_simulator.py
+++ b/tests/test_mal_simulator.py
@@ -1,310 +1,46 @@
 """Test MalSimulator class"""
 
 from maltoolbox.attackgraph import AttackGraph, Attacker
 
-from malsim.sims.mal_simulator import MalSimulator
-from malsim.scenario import load_scenario, create_simulator_from_scenario
-from malsim.sims import MalSimulatorSettings
+from malsim.mal_simulator import MalSimulator
+from malsim.scenario import load_scenario
 
 
-def test_malsimulator(corelang_lang_graph, model):
+def test_init(corelang_lang_graph, model):
     attack_graph = AttackGraph(corelang_lang_graph, model)
-    MalSimulator(corelang_lang_graph, model, attack_graph)
+    MalSimulator(attack_graph)
 
 
-def test_malsimulator_create_blank_observation(corelang_lang_graph, model):
-    """Make sure blank observation contains correct default values"""
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-
-    num_objects = len(attack_graph.nodes)
-    blank_observation = sim.create_blank_observation()
-
-    assert len(blank_observation['is_observable']) == num_objects
-    for state in blank_observation['is_observable']:
-        # Default is that all nodes are observable,
-        # unless anything else is given through its extras field
-        assert state == 1
-
-    assert len(blank_observation['observed_state']) == num_objects
-    for state in blank_observation['observed_state']:
-        assert state == -1  # This is the default (which we get in blank observation)
-
-    assert len(blank_observation['remaining_ttc']) == num_objects
-    for ttc in blank_observation['remaining_ttc']:
-        assert ttc == 0  # TTC is currently always 0 no matter what
-
-    # asset_type_index points us to an asset type in sim._index_to_asset_type
-    # the index where asset_type_index lies on will point to an attack step id in sim._index_to_id
-    # The type we get from sim._index_to_asset_type[asset_type_index]
-    # should be the same as the asset type of attack step with id index in attack graph
-    assert len(blank_observation['asset_type']) == num_objects
-    for index, asset_type_index in enumerate(blank_observation['asset_type']):
-        # Note: offset is decremented from asset_type_index
-        expected_type = sim._index_to_asset_type[asset_type_index]
-        node = sim.index_to_node(index)
-        assert node.lg_attack_step.asset.name == expected_type
-
-    # asset_id on index X in blank_observation['asset_id']
-    # should be the same as the id of the asset of attack step X
-    assert len(blank_observation['asset_id']) == num_objects
-    for index, expected_asset_id in enumerate(blank_observation['asset_id']):
-        node = sim.index_to_node(index)
-        assert node.model_asset.id == expected_asset_id
-
-    assert len(blank_observation['step_name']) == num_objects
-
-    expected_num_edges = sum([1 for step in attack_graph.nodes.values()
-                              for child in step.children] +
-                             # We expect all defenses again (reversed)
-                             [1 for step in attack_graph.nodes.values()
-                              for child in step.children
-                              if step.type == "defense"])
-    assert len(blank_observation['attack_graph_edges']) == expected_num_edges
-
-
-def test_malsimulator_create_blank_observation_deterministic(
-        corelang_lang_graph, model
-    ):
-    """Make sure blank observation is deterministic with seed given"""
-
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    attack_graph.attach_attackers()
-    all_attackers = list(attack_graph.attackers.values())
-
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-    sim.register_attacker("test_attacker", all_attackers[0].id)
-    sim.register_defender("test_defender")
-
-    obs1, _ = sim.reset(seed=123)
-    obs2, _ = sim.reset(seed=123)
-
-    assert list(obs1['test_attacker']['is_observable']) == list(obs2['test_attacker']['is_observable'])
-    assert list(obs1['test_attacker']['is_actionable']) == list(obs2['test_attacker']['is_actionable'])
-    assert list(obs1['test_attacker']['observed_state']) == list(obs2['test_attacker']['observed_state'])
-    assert list(obs1['test_attacker']['remaining_ttc']) == list(obs2['test_attacker']['remaining_ttc'])
-    assert list(obs1['test_attacker']['asset_type']) == list(obs2['test_attacker']['asset_type'])
-    assert list(obs1['test_attacker']['asset_id']) == list(obs2['test_attacker']['asset_id'])
-    assert list(obs1['test_attacker']['step_name']) == list(obs2['test_attacker']['step_name'])
-
-    for i, elem in enumerate(obs1['test_attacker']['attack_graph_edges']):
-        assert list(obs2['test_attacker']['attack_graph_edges'][i]) == list(elem)
-
-    assert list(obs1['test_attacker']['model_asset_id']) == list(obs2['test_attacker']['model_asset_id'])
-    assert list(obs1['test_attacker']['model_asset_type']) == list(obs2['test_attacker']['model_asset_type'])
-
-    for i, elem in enumerate(obs1['test_attacker']['model_edges_ids']):
-        assert list(obs2['test_attacker']['model_edges_ids'][i]) == list(elem)
-
-    assert list(obs1['test_attacker']['model_edges_type']) == list(obs2['test_attacker']['model_edges_type'])
-
-
-def test_malsimulator_step_deterministic(
-        corelang_lang_graph, model
-    ):
-    """Make sure stepping is deterministic with seed given"""
-
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    attack_graph.attach_attackers()
-    all_attackers = list(attack_graph.attackers.values())
-
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-    sim.register_attacker("test_attacker", all_attackers[0].id)
-    sim.register_defender("test_defender")
-
-    obs1 = {}
-    obs2 = {}
-
-    # Run 1
-    sim.reset(seed=123)
-    for _ in range(10):
-        attacker_node = next(
-            n for n in sim.agents_dict['test_attacker']['action_surface']
-            if not n.is_compromised()
-        )
-        attacker_action = (1, sim.node_to_index(attacker_node))
-        obs1, _, _, _, _ = sim.step(
-            {'test_defender': [1, 0], 'test_attacker': attacker_action}
-        )
-
-    # Run 2 - identical
-    sim.reset(seed=123)
-    for _ in range(10):
-        attacker_node = next(
-            n for n in sim.agents_dict['test_attacker']['action_surface']
-            if not n.is_compromised()
-        )
-        attacker_action = (1, sim.node_to_index(attacker_node))
-        obs2, _, _, _, _ = sim.step(
-            {'test_defender': [1, 0], 'test_attacker': attacker_action}
-        )
-
-    assert list(obs1['test_attacker']['observed_state']) == list(obs2['test_attacker']['observed_state'])
-    assert list(obs1['test_defender']['observed_state']) == list(obs2['test_defender']['observed_state'])
-
-
-def test_malsimulator_create_blank_observation_observability_given(
-        corelang_lang_graph, model
-    ):
-    """Make sure observability propagates correctly from extras field/scenario
-    to observation in mal simulator"""
-
-    # Load Scenario with observability rules set
-    scenario_file = 'tests/testdata/scenarios/traininglang_observability_scenario.yml'
-    sim, _ = create_simulator_from_scenario(scenario_file)
-
-    num_objects = len(sim.attack_graph.nodes)
-    blank_observation = sim.create_blank_observation()
-
-    assert len(blank_observation['is_observable']) == num_objects
-
-    for index, observable in enumerate(blank_observation['is_observable']):
-        node = sim.index_to_node(index)
-
-        # Below are the rules from the traininglang observability scenario
-        # made into if statements
-        if node.lg_attack_step.asset.name == 'Host' and node.name in ('access'):
-            assert observable
-        elif node.lg_attack_step.asset.name == 'Host' and node.name in ('authenticate'):
-            assert observable
-        elif node.lg_attack_step.asset.name == 'Data' and node.name in ('read'):
-            assert observable
-        elif node.model_asset.name == 'User:3' and node.name in ('phishing'):
-            assert observable
-        else:
-            assert not observable
-
-
-def test_malsimulator_create_blank_observation_actionability_given(
-        corelang_lang_graph, model
-    ):
-    """Make sure actionability propagates correctly from extras field/scenario
-    to observation in mal simulator"""
-
-    # Load Scenario with actionability rules set
-    scenario_file = 'tests/testdata/scenarios/traininglang_actionability_scenario.yml'
-    sim, _ = create_simulator_from_scenario(scenario_file)
-
-    num_objects = len(sim.attack_graph.nodes)
-    blank_observation = sim.create_blank_observation()
-
-    assert len(blank_observation['is_actionable']) == num_objects
-
-    for index, actionable in enumerate(blank_observation['is_actionable']):
-        node_id = sim._index_to_id[index]
-        node = sim.attack_graph.nodes[node_id]
-
-        # Below are the rules from the traininglang actionability scenario
-        # made into if statements
-        if node.lg_attack_step.asset.name == 'Host' and node.name in ('notPresent'):
-            assert actionable
-        elif node.lg_attack_step.asset.name == 'Data' and node.name in ('notPresent'):
-            assert actionable
-        elif node.model_asset.name == 'User:3' and node.name in ('notPresent'):
-            assert actionable
-        else:
-            assert not actionable
-
-
-def test_malsimulator_format_info(corelang_lang_graph, model):
-    """Make sure format info works as expected"""
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-
-    # Preparations of info to send to _format_info
-    can_wait = {"attacker": 0, "defender": 1}
-    infos = {}
-    agent = "attacker1"
-    agent_type = "attacker"
-    can_act = 1
-    available_actions = [0] * len(attack_graph.nodes)
-    available_actions[0] = 1  # Only first action is available
-
-    infos[agent] = {
-        "action_mask": (
-            [can_wait[agent_type], can_act],
-            available_actions
-        )
-    }
-    formatted = sim._format_info(infos[agent])
-    assert formatted == "Can act? Yes\n0 OS App:notPresent\n"
-
-    # Add an action and change 'can_act' to false
-    available_actions[1] = 1  # Also second action is available
-    can_act = False
-    infos[agent] = {
-        "action_mask": (
-            [can_wait[agent_type], can_act],
-            available_actions
-        )
-    }
-    formatted = sim._format_info(infos[agent])
-    assert formatted == "Can act? No\n0 OS App:notPresent\n1 OS App:attemptUseVulnerability\n"
-
-
-def test_malsimulator_observation_space(corelang_lang_graph, model):
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-    observation_space = sim.observation_space()
-    assert set(observation_space.keys()) == {
-        'is_observable',
-        'is_actionable',
-        'observed_state',
-        'remaining_ttc',
-        'asset_type',
-        'asset_id',
-        'step_name',
-        'attack_graph_edges',
-        'model_asset_id',
-        'model_asset_type',
-        'model_edges_ids',
-        'model_edges_type',
-    }
-    # All values in the observation space dict are of type Box
-    # which comes from gymnasium.spaces (Box is a Space)
-    # spaces have a shape (tuple) and a datatype (from numpy)
-
-
-def test_malsimulator_action_space(corelang_lang_graph, model):
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-    action_space = sim.action_space()
-    # action_space is a 'MultiDiscrete' (from gymnasium.spaces)
-    assert action_space.shape == (2,)
-
-
-def test_malsimulator_reset(corelang_lang_graph, model):
+def test_reset(corelang_lang_graph, model):
     """Make sure attack graph is reset"""
     attack_graph = AttackGraph(corelang_lang_graph, model)
-    agent_name = "testagent"
-    agent_id = 0
+
     agent_entry_point = attack_graph.get_node_by_full_name(
         'OS App:networkConnectUninspected')
+    attacker_name = "testagent"
+
     attacker = Attacker(
-        agent_name,
-        entry_points={agent_entry_point},
-        reached_attack_steps={agent_entry_point}
+        attacker_name,
+        entry_points={agent_entry_point},
+        reached_attack_steps={agent_entry_point},
+        attacker_id=100
     )
-    attack_graph.add_attacker(attacker, agent_id)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
+    attack_graph.add_attacker(attacker, attacker.id)
 
-    assert sim._index_to_id
-    assert sim._index_to_full_name
-    assert sim._id_to_index
+    sim = MalSimulator(attack_graph)
 
     attack_graph_before = sim.attack_graph
-    sim.register_attacker(agent_name, agent_id)
-    assert agent_name in sim.possible_agents
-    assert agent_name in sim.agents_dict
-    assert agent_name not in sim.agents
+    sim.register_attacker(attacker_name, attacker.id)
+    assert attacker.name in sim.agent_states
+    assert len(sim.agent_states) == 1
 
     sim.reset()
     attack_graph_after = sim.attack_graph
 
     # Make sure agent was added (and not removed)
-    assert agent_name in sim.agents
+    assert attacker.name in sim.agent_states
 
     # Make sure the attack graph is not the same object but identical
     assert id(attack_graph_before) != id(attack_graph_after)
@@ -316,672 +52,288 @@ def test_malsimulator_reset(corelang_lang_graph, model):
     assert attack_graph_before._to_dict() == attack_graph_after._to_dict()
 
 
-def test_malsimulator_register_attacker(corelang_lang_graph, model):
+def test_register_agent_attacker(corelang_lang_graph, model):
     attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
+    sim = MalSimulator(attack_graph)
+
+    attacker = 1
     agent_name = "attacker1"
-    attacker_id = 1
-    sim.register_attacker(agent_name, attacker_id)
-    assert agent_name in sim.possible_agents
-    assert agent_name in sim.agents_dict
+    sim.register_attacker(agent_name, attacker)
+    assert agent_name in sim.agent_states
 
 
-def test_malsimulator_register_defender(corelang_lang_graph, model):
+def test_register_agent_defender(corelang_lang_graph, model):
     attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
+    sim = MalSimulator(attack_graph)
+
     agent_name = "defender1"
     sim.register_defender(agent_name)
-    assert agent_name in sim.possible_agents
-    assert agent_name in sim.agents_dict
+    assert agent_name in sim.agent_states
+
+
+def test_register_agent_action_surface(corelang_lang_graph, model):
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    sim = MalSimulator(attack_graph)
+
+    agent_name = "defender1"
+    sim.register_defender(agent_name)
 
-def test_simulator_initialize_agents():
+    sim._init_agent_action_surfaces()
+    action_surface = sim.agent_states[agent_name].action_surface
+    for node in action_surface:
+        assert node.is_available_defense()
+
+
+def test_simulator_initialize_agents(corelang_lang_graph, model):
     """Test _initialize_agents"""
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/simple_scenario.yml',
-    )
+    ag, _ = load_scenario('tests/testdata/scenarios/simple_scenario.yml')
+    sim = MalSimulator(ag)
+
+    # Register the agents
+    attacker_name = "attacker"
+    attacker_id = 1
+    defender_name = "defender"
+    sim.register_attacker(attacker_name, attacker_id)
+    sim.register_defender(defender_name)
+
     sim.reset()
-    agents = sim._initialize_agents()
-    assert set(agents.keys()) == {'defender', 'attacker'}
-    for node in sim.attack_graph.nodes.values():
-        node_index = sim._id_to_index[node.id]
-        if node.is_enabled_defense():
-            assert node_index in agents['defender']
-        elif node.is_compromised():
-            assert node_index in agents['attacker']
-        else:
-            assert node_index not in agents['defender']
-            assert node_index not in agents['attacker']
+    assert set(sim.agent_states.keys()) == {attacker_name, defender_name}
 
 
 def test_get_agents():
-    """Test get_attacker_agents and get_defender_agents"""
+    """Test _get_attacker_agents and _get_defender_agents"""
 
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/simple_scenario.yml',
-    )
+    ag, _ = load_scenario('tests/testdata/scenarios/simple_scenario.yml')
+    sim = MalSimulator(ag)
     sim.reset()
 
-    sim.get_attacker_agents() == ['attacker']
-    sim.get_defender_agents() == ['defender']
+    sim._get_attacker_agents() == ['attacker']
+    sim._get_defender_agents() == ['defender']
 
 
-def test_malsimulator_attacker_step(corelang_lang_graph, model):
+def test_attacker_step(corelang_lang_graph, model):
     attack_graph = AttackGraph(corelang_lang_graph, model)
+    entry_point = attack_graph.get_node_by_full_name('OS App:fullAccess')
 
-    attacker_id = 0
-    attacker = Attacker('attacker1', set(), set())
-    attack_graph.add_attacker(attacker, attacker_id)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
+    attacker = Attacker(
+        'attacker1',
+        reached_attack_steps={entry_point},
+        entry_points={entry_point},
+        attacker_id=100
+    )
+    attack_graph.add_attacker(attacker, attacker.id)
+    sim = MalSimulator(attack_graph)
 
-    sim.register_attacker(attacker.name, attacker_id)
+    sim.register_attacker(attacker.name, attacker.id)
     sim.reset()
+    attacker_agent = sim._agent_states[attacker.name]
 
     # Can not attack the notPresent step
-    defense_step = attack_graph.get_node_by_full_name('OS App:notPresent')
-    actions = sim._attacker_step(attacker.name, defense_step.id)
+    defense_step = sim.attack_graph.get_node_by_full_name('OS App:notPresent')
+    actions = sim._attacker_step(attacker_agent, {defense_step})
     assert not actions
+    assert not attacker_agent.step_action_surface_additions
 
-    # Can attack the attemptUseVulnerability step!
-    attack_step = attack_graph.get_node_by_full_name('OS App:attemptUseVulnerability')
-    actions = sim._attacker_step(attacker.name, attack_step.id)
-    assert actions == [attack_step.id]
+    attack_step = sim.attack_graph.get_node_by_full_name('OS App:attemptRead')
+    sim._attacker_step(attacker_agent, {attack_step})
+    assert attacker_agent.step_performed_nodes == {attack_step}
+    assert attacker_agent.step_action_surface_additions == attack_step.children
 
 
-def test_malsimulator_defender_step(corelang_lang_graph, model):
+def test_defender_step(corelang_lang_graph, model):
     attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
+    sim = MalSimulator(attack_graph)
+
     defender_name = "defender"
     sim.register_defender(defender_name)
     sim.reset()
 
-    defense_step = attack_graph.get_node_by_full_name(
+    defender_agent = sim._agent_states[defender_name]
+    defense_step = sim.attack_graph.get_node_by_full_name(
         'OS App:notPresent')
-    actions, _ = sim._defender_step(defender_name, defense_step.id)
-    assert actions == [defense_step.id]
+    sim._defender_step(defender_agent, {defense_step})
+    assert defender_agent.step_performed_nodes == {defense_step}
 
     # Can not defend attack_step
-    attack_step = attack_graph.get_node_by_full_name(
+    attack_step = sim.attack_graph.get_node_by_full_name(
         'OS App:attemptUseVulnerability')
-    actions, _ = sim._defender_step(defender_name, attack_step.id)
-    assert not actions
+    sim._defender_step(defender_agent, {attack_step})
+    assert not defender_agent.step_performed_nodes
 
 
-def test_malsimulator_observe_attacker():
+def test_agent_state_views_simple(corelang_lang_graph, model):
+    attack_graph = AttackGraph(corelang_lang_graph, model)
+    entry_point = attack_graph.get_node_by_full_name('OS App:fullAccess')
+
+    attacker = Attacker(
+        'attacker1',
+        reached_attack_steps={entry_point},
+        entry_points=set(),
+        attacker_id=100
+    )
+    attack_graph.add_attacker(attacker, attacker.id)
+    sim = MalSimulator(attack_graph)
+    attacker_name = 'attacker'
+    defender_name = 'defender'
+    sim.register_attacker(attacker_name, attacker.id)
+
+    sim.register_defender(defender_name)
+
+    # Evaluate the agent state views after reset
+    state_views = sim.reset()
+    asv = state_views['attacker']
+    dsv = state_views['defender']
+    assert asv.step_performed_nodes == set()
+    assert dsv.step_performed_nodes == set()
+    assert len(asv.action_surface) == 6
+    assert len(dsv.action_surface) == 21
+    assert dsv.step_action_surface_additions == set()
+    assert asv.step_action_surface_removals == set()
+    assert dsv.step_action_surface_removals == set()
+
+    # Evaluate the agent state views after stepping through an attack step and
+    # a defense that will not impact it in any way
+    state_views = sim.step({
+        'defender': {sim.attack_graph.get_node_by_full_name(
+            'Program 2:notPresent')},
+        'attacker': {sim.attack_graph.get_node_by_full_name(
+            'OS App:attemptDeny')}
+    })
+    asv = state_views['attacker']
+    dsv = state_views['defender']
+    assert asv.step_performed_nodes == {
+        sim.attack_graph.get_node_by_full_name('OS App:attemptDeny')}
+    assert dsv.step_performed_nodes == {
+        sim.attack_graph.get_node_by_full_name('Program 2:notPresent')}
+    assert asv.step_action_surface_additions == {
+        sim.attack_graph.get_node_by_full_name('OS App:successfulDeny')}
+    assert dsv.step_action_surface_additions == set()
+    assert asv.step_action_surface_removals == set()
+    assert dsv.step_action_surface_removals == {
+        sim.attack_graph.get_node_by_full_name('Program 2:notPresent')}
+    assert dsv.step_all_compromised_nodes == {
+        sim.attack_graph.get_node_by_full_name('OS App:attemptDeny')}
+    assert len(dsv.step_unviable_nodes) == 49
+
+    # Evaluate the agent state views after stepping through an attack step and
+    # a defense that would prevent it from occurring
+    state_views = sim.step({
+        'defender': {sim.attack_graph.get_node_by_full_name(
+            'OS App:notPresent')},
+        'attacker': {sim.attack_graph.get_node_by_full_name(
+            'OS App:successfulDeny')}
+    })
+    asv = state_views['attacker']
+    dsv = state_views['defender']
+    assert asv.step_performed_nodes == set()
+    assert dsv.step_performed_nodes == {
+        sim.attack_graph.get_node_by_full_name('OS App:notPresent')}
+    assert asv.step_action_surface_additions == set()
+    assert dsv.step_action_surface_additions == set()
+    assert asv.step_action_surface_removals == {
+        sim.attack_graph.get_node_by_full_name('OS App:accessNetworkAndConnections'),
+        sim.attack_graph.get_node_by_full_name('OS App:specificAccess'),
+        sim.attack_graph.get_node_by_full_name('OS App:attemptApplicationRespondConnectThroughData'),
+        sim.attack_graph.get_node_by_full_name('OS App:attemptRead'),
+        sim.attack_graph.get_node_by_full_name('OS App:attemptModify'),
+        sim.attack_graph.get_node_by_full_name('OS App:successfulDeny'),
+    }
+    assert dsv.step_action_surface_removals == {
+        sim.attack_graph.get_node_by_full_name('OS App:notPresent')}
+    assert dsv.step_all_compromised_nodes == set()
+    assert len(dsv.step_unviable_nodes) == 55
+
+
+def test_observe_attacker():
     attack_graph, _ = load_scenario(
         'tests/testdata/scenarios/simple_scenario.yml'
     )
 
     # Create the simulator
-    sim = MalSimulator(
-        attack_graph.lang_graph, attack_graph.model, attack_graph)
+    sim = MalSimulator(attack_graph)
 
     # Register the agents
     attacker_agent_id = "attacker"
     defender_agent_id = "defender"
-    attacker = next(iter(sim.attack_graph.attackers.values()))
-
-    sim.register_attacker(attacker_agent_id, attacker.id)
+    sim.register_attacker(attacker_agent_id, 1)
     sim.register_defender(defender_agent_id)
-
-    obs, _ = sim.reset()
+    sim.reset()
 
     # Make alteration to the attack graph attacker
     assert len(sim.attack_graph.attackers) == 1
-
-    # We reset to get the new attacker
-    attacker = sim.attack_graph.attackers[attacker.id]
+    attacker = next(iter(sim.attack_graph.attackers.values()))
     assert len(attacker.reached_attack_steps) == 1
-    reached_step = list(attacker.reached_attack_steps)[0]
-
-    # Select actions for the attacker
-    actions_to_take = []
-    for child_node in reached_step.children:
-        if child_node.type in ('and', 'or'):
-            # In the end the attacker will have three reached steps
-            # where two are children of the first one
-            actions_to_take.append(child_node)
-
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    num_reached_steps_before = len(attacker.reached_attack_steps)
-
-    for attacker_action in actions_to_take:
-        action_index = sim._id_to_index[attacker_action.id]
-
-        obs, _, _, _, _ = sim.step({
-            defender_agent_id: (0, None),
-            attacker_agent_id: (1, action_index)
-        })
-
-        num_reached_steps_now = len(attacker.reached_attack_steps)
-        assert num_reached_steps_now == num_reached_steps_before + 1
-        num_reached_steps_before = num_reached_steps_now
-
-    attacker_observation = obs[attacker_agent_id]["observed_state"]
-
-    for node in attacker.reached_attack_steps:
-        node_index = sim._id_to_index[node.id]
-        node_obs_state = attacker_observation[node_index]
-        assert node_obs_state == 1
-
-    for index, state in enumerate(attacker_observation):
-        node = sim.index_to_node(index)
-
-        if node.is_compromised():
-            assert state == 1
-        else:
-            if state == -1:
-                for parent in node.parents:
-                    assert parent not in attacker.reached_attack_steps
-            else:
-                assert state == 0
-
-
-def test_malsimulator_initial_observation_defender(corelang_lang_graph, model):
-    """Make sure ._observe_defender observes nodes and set observed state"""
-
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-
-    defender_name = "defender"
-    sim.register_defender(defender_name)
-    observations, _ = sim.initialize()
-    defender_obs_state = observations[defender_name]["observed_state"]
-
-    # Assert that observed state is not 1 before observe_defender
-    nodes_to_observe = [
-        node for node in sim.attack_graph.nodes.values()
-        if node.is_enabled_defense() or node.is_compromised()
-    ]
-
-    # Assert that observed state is 1 after observe_defender
-    for node in nodes_to_observe:
-        index = sim._id_to_index[node.id]
-        # Make sure observed after
-        assert defender_obs_state[index] == 1
-
-
-def test_malsimulator_observe_and_reward_attacker_no_entrypoints(
-        corelang_lang_graph, model):
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-
-    attacker = Attacker("TestAttacker", set(), set())
-    attack_graph.add_attacker(attacker)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-
-    # No agents
-    # No actions given - no one observes/rewards anything
-    obs, rew, term, trunc, infos = sim._observe_and_reward({}, [])
-    assert not obs and not rew and not term and not trunc and not infos
-
-    # Register an attacker
-    attacker_name = "attacker"
-    sim.register_attacker(attacker_name, 0)
-    obs, rew, term, trunc, infos = sim._observe_and_reward({}, [])
-
-    # We need to reinitialize (or reset if we want to reset the attackgraph)
-    # to also initialize the registered agents
-    obs, infos = sim.initialize()
-    assert list(obs.keys()) == [attacker_name]
-    assert list(infos.keys()) == [attacker_name]
-
-    # Observe and reward with no new actions
-    obs, rew, term, trunc, infos = sim._observe_and_reward({}, [])
-    # Since attacker has no entry points and no steps have been performed
-    # the observed state should be empty
-    for state in obs['attacker']['observed_state']:
-        assert state == -1
-    assert rew[attacker_name] == 0
-
-
-def test_malsimulator_observe_and_reward_attacker_entrypoints(
-        traininglang_lang_graph, traininglang_model
-    ):
-
-    attack_graph = AttackGraph(
-        traininglang_lang_graph, traininglang_model)
-    attack_graph.attach_attackers()
-    sim = MalSimulator(
-        traininglang_lang_graph, traininglang_model, attack_graph)
-
-    # Register an attacker
-    attacker_name = "attacker"
-    attacker = sim.attack_graph.attackers[0]
-    sim.register_attacker(attacker_name, attacker.id)
-    # We need to reinitialize to initialize agent
-    obs, infos = sim.initialize()
-
-    # Observe and reward with no new actions
-    obs, rew, term, trunc, infos = sim._observe_and_reward({}, [])
-
-    for index, state in enumerate(obs['attacker']['observed_state']):
-        node = sim.index_to_node(index)
-        if state == -1:
-            assert node not in attacker.entry_points
-            assert node not in attacker.reached_attack_steps
-            assert not node.is_compromised()
-            assert not any([p.is_compromised() for p in node.parents])
-        elif state == 0:
-            assert node not in attacker.entry_points
-            assert node not in attacker.reached_attack_steps
-            assert not node.is_compromised()
-            assert any([p.is_compromised() for p in node.parents])
-        elif state == 1:
-            assert node in attacker.entry_points
-            assert node in attacker.reached_attack_steps
-            assert node.is_compromised()
-
-    assert rew[attacker_name] == 0
-
-
-def test_malsimulator_agents_registered(
-        traininglang_lang_graph, traininglang_model
-    ):
-
-    sim, _ = create_simulator_from_scenario(
+def test_step_attacker_defender_action_surface_updates():
+    ag, _ = load_scenario(
         'tests/testdata/scenarios/traininglang_scenario.yml')
-    attacker_name = "attacker"
-    defender_name = "defender"
-
-    # We need to reinitialize to initialize agents
-    obs, infos = sim.initialize()
-    assert set(obs.keys()) == {attacker_name, defender_name}
-    assert set(infos.keys()) == {attacker_name, defender_name}
-
-
-def test_malsimulator_update_viability(corelang_lang_graph, model):
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-    attempt_vuln_node = attack_graph.get_node_by_full_name('OS App:attemptUseVulnerability')
-    assert attempt_vuln_node.is_viable
-    success_vuln_node = attack_graph.get_node_by_full_name('OS App:successfulUseVulnerability')
-    assert success_vuln_node.is_viable
-
-    # Make attempt unviable
-    attempt_vuln_node.is_viable = False
-    sim.update_viability(attempt_vuln_node)
-    # Should make success unviable
-    assert not success_vuln_node.is_viable
-
-
-def test_malsimulator_step_attacker(corelang_lang_graph, model):
-    attack_graph = AttackGraph(corelang_lang_graph, model)
-    attack_graph.attach_attackers()
-    sim = MalSimulator(corelang_lang_graph, model, attack_graph)
-
-    agent_name = "attacker1"
-    attacker_id = attack_graph.attackers[0].id
-    sim.register_attacker(agent_name, attacker_id)
-    assert agent_name in sim.agents_dict
-    assert not sim.agents_dict[agent_name]['action_surface']
-
-    obs, infos = sim.reset()
-
-    # Run step() with action crafted in test
-    action = 1
-    step = sim.attack_graph.get_node_by_full_name('OS App:attemptRead')
-    assert step in sim.agents_dict[agent_name]['action_surface']
-
-    step_index = sim._id_to_index[step.id]
-    actions = {agent_name: (action, step_index)}
-    observations, rewards, terminations, truncations, infos = sim.step(actions)
-    assert len(observations[agent_name]['observed_state']) == len(attack_graph.nodes)
-    assert agent_name in sim.agents_dict
-    assert sim.agents_dict[agent_name]['action_surface']
-
-    # Make sure 'OS App:attemptUseVulnerability' is observed and set to 1 (active)
-    assert observations[agent_name]['observed_state'][step_index] == 1
-    for child in step.children:
-        child_step_index = sim._id_to_index[child.id]
-        # Make sure 'OS App:attemptUseVulnerability' children are observed and set to 0 (not active)
-        assert observations[agent_name]['observed_state'][child_step_index] == 0
-
-
-def test_step_attacker_defender_action_surface_updates(
-        corelang_lang_graph, model):
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/traininglang_scenario.yml')
+    sim = MalSimulator(ag)
 
-    attacker_agent = next(iter(sim.get_attacker_agents()))
-    defender_agent = next(iter(sim.get_defender_agents()))
+    # Register the agents
+    attacker_agent_id = "attacker"
+    defender_agent_id = "defender"
+    sim.register_attacker(attacker_agent_id, 1)
+    sim.register_defender(defender_agent_id)
 
     sim.reset()
 
+    attacker_agent = sim.agent_states[attacker_agent_id]
+    defender_agent = sim.agent_states[defender_agent_id]
+
     # Run step() with action crafted in test
     attacker_step = sim.attack_graph.get_node_by_full_name('User:3:compromise')
-    assert attacker_step in sim.agents_dict[attacker_agent]['action_surface']
+    assert attacker_step in attacker_agent.action_surface
 
     defender_step = sim.attack_graph.get_node_by_full_name('User:3:notPresent')
-    assert defender_step in sim.agents_dict[defender_agent]['action_surface']
+    assert defender_step in defender_agent.action_surface
 
     actions = {
-        attacker_agent: (1, sim._id_to_index[attacker_step.id]),
-        defender_agent: (1, sim._id_to_index[defender_step.id])
+        attacker_agent.name: [attacker_step],
+        defender_agent.name: [defender_step]
     }
 
     sim.step(actions)
 
-    assert attacker_step not in sim.agents_dict[attacker_agent]['action_surface']
-    assert defender_step not in sim.agents_dict[defender_agent]['action_surface']
+    # Make sure no nodes added to action surface
+    assert not attacker_agent.step_action_surface_additions
+    assert not defender_agent.step_action_surface_additions
 
+    # Make sure the steps are removed from the action surfaces
+    assert attacker_step in attacker_agent.step_action_surface_removals
+    assert defender_step in defender_agent.step_action_surface_removals
+    assert attacker_step not in attacker_agent.action_surface
+    assert defender_step not in defender_agent.action_surface
 
-def test_default_simulator_default_settings_eviction():
-    """Test attacker node eviction using MalSimulatorSettings default"""
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/traininglang_scenario.yml',
-    )
-
-    sim.reset()
-
-    attacker = next(iter(sim.attack_graph.attackers.values()))
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    defender_agent_id = next(iter(sim.get_defender_agents()))
-
-    # Get a step to compromise and its defense parent
-    user_3_compromise = sim.attack_graph.get_node_by_full_name('User:3:compromise')
-    assert attacker not in user_3_compromise.compromised_by
-    user_3_compromise_defense = next(n for n in user_3_compromise.parents if n.type=='defense')
-    assert not user_3_compromise_defense.is_enabled_defense()
-
-    # First let the attacker compromise User:3:compromise
-    actions = {
-        attacker_agent_id: (1, sim._id_to_index[user_3_compromise.id]),
-        defender_agent_id: (0, None)
-    }
-    sim.step(actions)
-
-    # Check that the compromise happened and that the defense did not
-    assert attacker in user_3_compromise.compromised_by
-    assert not user_3_compromise_defense.is_enabled_defense()
-
-    # Now let the defender defend, and the attacker waits
-    actions = {
-        attacker_agent_id: (0, None),
-        defender_agent_id: (1, sim._id_to_index[user_3_compromise_defense.id])
-    }
-    sim.step(actions)
-
-    # Verify defense was performed and attacker NOT kicked out
-    assert user_3_compromise_defense.is_enabled_defense()
-    assert attacker in user_3_compromise.compromised_by
-
-
-def test_malsimulator_observe_and_reward_attacker_defender():
-    """Run attacker and defender actions and make sure
-    rewards and observation states are updated correctly"""
-
-    def verify_attacker_obs_state(
-            obs_state,
-            expected_reached,
-            expected_children_of_reached
-        ):
-        """Make sure obs state looks as expected"""
-        for index, state in enumerate(obs_state):
-            node_id = sim._index_to_id[index]
-            if state == 1:
-                assert node_id in expected_reached
-            elif state == 0:
-                assert node_id in expected_children_of_reached
-            else:
-                assert state == -1
-
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/traininglang_scenario.yml')
-    sim.reset()
-
-    attacker = next(iter(sim.attack_graph.attackers.values()))
-    attacker_name = "attacker"
-    defender_name = "defender"
-    attacker_reached_steps = [n.id for n in attacker.entry_points]
-    attacker_reached_step_children = []
-    for reached in attacker.entry_points:
-        attacker_reached_step_children.extend(
-            [n.id for n in reached.children])
-
-    # Prepare nodes that will be stepped through in order
-    user_3_compromise = sim.attack_graph\
-        .get_node_by_full_name("User:3:compromise")
-    host_0_authenticate = sim.attack_graph\
-        .get_node_by_full_name("Host:0:authenticate")
-    host_0_access = sim.attack_graph\
-        .get_node_by_full_name("Host:0:access")
-    host_0_notPresent = sim.attack_graph\
-        .get_node_by_full_name("Host:0:notPresent")
-    data_2_read = sim.attack_graph\
-        .get_node_by_full_name("Data:2:read")
-
-    # Step with attacker action
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (0, None),
-        attacker_name: (1, sim._id_to_index[user_3_compromise.id])
-        }
-    )
-    # Verify obs state
-    attacker_reached_steps.append(user_3_compromise.id)
-    attacker_reached_step_children.extend(
-        [n.id for n in user_3_compromise.children])
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    # Verify rewards
-    assert rew[defender_name] == 0
-    assert rew[attacker_name] == 0
-
-    # Step with attacker again
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (0, None),
-        attacker_name: (1, sim._id_to_index[host_0_authenticate.id])
-    })
-
-    # Verify obs state
-    attacker_reached_steps.append(host_0_authenticate.id)
-    attacker_reached_step_children.extend(
-        [n.id for n in host_0_authenticate.children])
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    # Verify rewards
-    assert rew[defender_name] == 0
-    assert rew[attacker_name] == 0
-
-    # Step attacker again
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (0, None),
-        attacker_name: (1, sim._id_to_index[host_0_access.id])
-    })
-
-    # Verify obs state
-    attacker_reached_steps.append(host_0_access.id)
-    attacker_reached_step_children.extend(
-        [n.id for n in host_0_access.children])
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    reward_host_0_access = 4
-    # Verify rewards
-    assert rew[attacker_name] == reward_host_0_access
-    assert rew[defender_name] == -rew[attacker_name]
-
-    # Step defender and attacker
-    # Attacker wont be able to traverse Data:2:read since
-    # Host:0:notPresent is activated before
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (1, sim._id_to_index[host_0_notPresent.id]),
-        attacker_name: (1, sim._id_to_index[data_2_read.id])
-    })
-
-    # Attacker obs state should look the same as before
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    # Verify rewards
-    reward_host_0_not_present = 2
-    assert rew[attacker_name] == reward_host_0_access  # no additional reward
-    assert rew[defender_name] == -rew[attacker_name] - reward_host_0_not_present
-
-
-def test_malsimulator_observe_and_reward_uncompromise_untraversable():
-    """Run attacker and defender actions and make sure
-    rewards and observation states are updated correctly"""
-
-    def verify_attacker_obs_state(
-            obs_state,
-            expected_reached,
-            expected_children_of_reached
-        ):
-        """Make sure obs state looks as expected"""
-        for index, state in enumerate(obs_state):
-            node_id = sim._index_to_id[index]
-            node = sim.attack_graph.nodes[node_id]
-            if state == 1:
-                assert node_id in expected_reached
-            elif state == 0:
-                assert node_id in expected_children_of_reached or \
-                    (node_id in expected_reached and not node.is_viable)
-            else:
-                assert state == -1
-
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/traininglang_scenario.yml',
-        sim_settings=MalSimulatorSettings(
-            uncompromise_untraversable_steps=True
-        )
-    )
-    sim.reset()
-
-    attacker = next(iter(sim.attack_graph.attackers.values()))
-    attacker_name = "attacker"
-    defender_name = "defender"
-    attacker_reached_steps = [n.id for n in attacker.entry_points]
-    attacker_reached_step_children = []
-    for reached in attacker.entry_points:
-        attacker_reached_step_children.extend(
-            [n.id for n in reached.children])
-
-    # Prepare nodes that will be stepped through in order
-    user_3_compromise = sim.attack_graph\
-        .get_node_by_full_name("User:3:compromise")
-    host_0_authenticate = sim.attack_graph\
-        .get_node_by_full_name("Host:0:authenticate")
-    host_0_access = sim.attack_graph\
-        .get_node_by_full_name("Host:0:access")
-    host_0_notPresent = sim.attack_graph\
-        .get_node_by_full_name("Host:0:notPresent")
-    data_2_read = sim.attack_graph\
-        .get_node_by_full_name("Data:2:read")
-
-    # Step with attacker action
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (0, None),
-        attacker_name: (1, sim._id_to_index[user_3_compromise.id])
-        }
-    )
-
-    # Verify obs state
-    attacker_reached_steps.append(user_3_compromise.id)
-    attacker_reached_step_children.extend(
-        [n.id for n in user_3_compromise.children])
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    # Verify rewards
-    assert rew[defender_name] == 0
-    assert rew[attacker_name] == 0
-
-    # Step with attacker again
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (0, None),
-        attacker_name: (1, sim._id_to_index[host_0_authenticate.id])
-    })
-
-    # Verify obs state
-    attacker_reached_steps.append(host_0_authenticate.id)
-    attacker_reached_step_children.extend(
-        [n.id for n in host_0_authenticate.children])
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    # Verify rewards
-    assert rew[defender_name] == 0
-    assert rew[attacker_name] == 0
-
-    # Step attacker again
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (0, None),
-        attacker_name: (1, sim._id_to_index[host_0_access.id])
-    })
-
-    # Verify obs state
-    attacker_reached_steps.append(host_0_access.id)
-    attacker_reached_step_children.extend(
-        [n.id for n in host_0_access.children])
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    reward_host_0_access = 4
-    # Verify rewards
-    assert rew[attacker_name] == reward_host_0_access
-    assert rew[defender_name] == -rew[attacker_name]
-
-    # Step defender and attacker
-    # Attacker wont be able to traverse Data:2:read since
-    # Host:0:notPresent is activated before
-    obs, rew, _, _, _ = sim.step({
-        defender_name: (1, sim._id_to_index[host_0_notPresent.id]),
-        attacker_name: (1, sim._id_to_index[data_2_read.id])
-    })
-
-    # Attacker obs state should look the same as before
-    verify_attacker_obs_state(
-        obs[attacker_name]['observed_state'],
-        attacker_reached_steps,
-        attacker_reached_step_children)
-
-    # Verify rewards
-    reward_host_0_not_present = 2
-    assert rew[attacker_name] == 0  # no reward anymore
-    assert rew[defender_name] == -reward_host_0_not_present
-
-
-def test_simulator_settings_evict_attacker():
-    """Test MalSimulatorSettings when it should evict attacker
-    from untraversable node"""
-
-    settings_evict_attacker = MalSimulatorSettings(
-        uncompromise_untraversable_steps=True
-    )
-
-    sim, _ = create_simulator_from_scenario(
+def test_default_simulator_default_settings_eviction():
+    """Test attacker node eviction using MalSimulatorSettings default"""
+    ag, _ = load_scenario(
         'tests/testdata/scenarios/traininglang_scenario.yml',
-        sim_settings=settings_evict_attacker
     )
+    sim = MalSimulator(ag)
+
+    # Register the agents
+    attacker_agent_id = "attacker"
+    defender_agent_id = "defender"
+    sim.register_attacker(attacker_agent_id, 1)
+    sim.register_defender(defender_agent_id)
 
     sim.reset()
 
     attacker = next(iter(sim.attack_graph.attackers.values()))
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    defender_agent_id = next(iter(sim.get_defender_agents()))
 
     # Get a step to compromise and its defense parent
     user_3_compromise = sim.attack_graph.get_node_by_full_name('User:3:compromise')
     assert attacker not in user_3_compromise.compromised_by
     user_3_compromise_defense = next(n for n in user_3_compromise.parents if n.type=='defense')
     assert not user_3_compromise_defense.is_enabled_defense()
 
@@ -989,8 +341,8 @@ def test_simulator_settings_evict_attacker():
     # First let the attacker compromise User:3:compromise
     actions = {
-        attacker_agent_id: (1, sim._id_to_index[user_3_compromise.id]),
-        defender_agent_id: (0, None)
+        attacker_agent_id: [user_3_compromise],
+        defender_agent_id: []
     }
     sim.step(actions)
 
@@ -1000,133 +352,11 @@ def test_simulator_settings_evict_attacker():
     # Now let the defender defend, and the attacker waits
     actions = {
-        attacker_agent_id: (0, None),
-        defender_agent_id: (1, sim._id_to_index[user_3_compromise_defense.id])
+        attacker_agent_id: [],
+        defender_agent_id: [user_3_compromise_defense]
     }
     sim.step(actions)
 
-    # Verify defense was performed and attacker WAS kicked out
+    # Verify defense was performed and attacker NOT kicked out
     assert user_3_compromise_defense.is_enabled_defense()
-    assert attacker not in user_3_compromise.compromised_by
-
-
-def test_simulator_default_settings_defender_observation():
-    """Test MalSimulatorSettings show previous steps in obs"""
-
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/traininglang_scenario.yml'
-    )
-    sim.reset()
-
-    attacker = next(iter(sim.attack_graph.attackers.values()))
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    defender_agent_id = next(iter(sim.get_defender_agents()))
-
-    # Get an uncompromised step
-    user_3_compromise = sim.attack_graph.get_node_by_full_name(
-        'User:3:compromise')
-    assert attacker not in user_3_compromise.compromised_by
-
-    # Get a defense for the uncompromised step
-    user_3_compromise_defense = next(
-        n for n in user_3_compromise.parents if n.type=='defense')
-    assert not user_3_compromise_defense.is_enabled_defense()
-
-    # First let the attacker compromise User:3:compromise
-    actions = {
-        attacker_agent_id: (1, sim._id_to_index[user_3_compromise.id]),
-        defender_agent_id: (0, None)
-    }
-
-    obs, _, _, _, _ = sim.step(actions)
-    defender_observation = obs[defender_agent_id]['observed_state']
-
-    # Verify that all states in obs match the state of the attack graph
-    for index, state in enumerate(defender_observation):
-        step_id = sim._index_to_id[index]
-        node = sim.attack_graph.nodes[step_id]
-        if state == 1:
-            assert node.is_compromised()
-        else:
-            assert not node.is_compromised()
-
-    # Now let the defender defend, and the attacker waits
-    actions = {
-        attacker_agent_id: (0, None),
-        defender_agent_id: (1, sim._id_to_index[user_3_compromise_defense.id])
-    }
-    obs, _, _, _, _ = sim.step(actions)
-    defender_observation = obs[defender_agent_id]['observed_state']
-
-    # Verify that all states in obs match the state of the attack graph
-    for index, state in enumerate(defender_observation):
-        step_id = sim._index_to_id[index]
-        node = sim.attack_graph.nodes[step_id]
-        if state == 1:
-            assert node.is_compromised() or node.is_enabled_defense()
-        else:
-            assert not node.is_compromised() and not node.is_enabled_defense()
-
-
-def test_simulator_settings_defender_observation():
-    """Test MalSimulatorSettings only show last steps in obs"""
-
-    settings_dont_show_previous = MalSimulatorSettings(
-        cumulative_defender_obs=False
-    )
-
-    sim, _ = create_simulator_from_scenario(
-        'tests/testdata/scenarios/traininglang_scenario.yml',
-        sim_settings=settings_dont_show_previous
-    )
-    sim.reset()
-
-    attacker = next(iter(sim.attack_graph.attackers.values()))
-    attacker_agent_id = next(iter(sim.get_attacker_agents()))
-    defender_agent_id = next(iter(sim.get_defender_agents()))
-
-    # Get an uncompromised step
-    user_3_compromise = sim.attack_graph.get_node_by_full_name(
-        'User:3:compromise')
-    assert attacker not in user_3_compromise.compromised_by
-
-    # Get a defense for the uncompromised step
-    user_3_compromise_defense = next(
-        n for n in user_3_compromise.parents if n.type=='defense')
-    assert not user_3_compromise_defense.is_enabled_defense()
-
-    # First let the attacker compromise User:3:compromise
-    actions = {
-        attacker_agent_id: (1, sim._id_to_index[user_3_compromise.id]),
-        defender_agent_id: (0, None)
-    }
-
-    obs, _, _, _, _ = sim.step(actions)
-    defender_observation = obs[defender_agent_id]['observed_state']
-
-    # Verify that the only active state node in obs
-    # is the latest performed step (User:3:compromise)
-    for index, state in enumerate(defender_observation):
-        step_id = sim._index_to_id[index]
-        node = sim.attack_graph.nodes[step_id]
-        if node == user_3_compromise:
-            assert state == 1  # Last performed step known active state
-        else:
-            assert state == 0  # All others inactive
-
-    # Now let the defender defend, and the attacker waits
-    actions = {
-        attacker_agent_id: (0, None),
-        defender_agent_id: (1, sim._id_to_index[user_3_compromise_defense.id])
-    }
-    obs, _, _, _, _ = sim.step(actions)
-    defender_observation = obs[defender_agent_id]['observed_state']
-
-    # Verify that the only active state node in obs
-    # is the latest performed step (the defense step)
-    for index, state in enumerate(defender_observation):
-        node = sim.index_to_node(index)
-        if node == user_3_compromise_defense:
-            assert state == 1  # Last performed step known active state
-        else:
-            assert state == 0  # All others inactive
+    assert attacker in user_3_compromise.compromised_by
diff --git a/tests/test_scenario.py b/tests/test_scenario.py
index 53d7cf07..314ce00a 100644
--- a/tests/test_scenario.py
+++ b/tests/test_scenario.py
@@ -7,9 +7,7 @@
     apply_scenario_node_property_rules,
     load_scenario
 )
-from malsim.agents.keyboard_input import KeyboardAgent
-from malsim.agents.searchers import BreadthFirstAttacker
-
+from malsim.agents import PassiveAgent, BreadthFirstAttacker
 
 def path_relative_to_tests(filename):
     """Returns the absolute path of a file in ./tests
@@ -25,7 +23,7 @@ def test_load_scenario():
     """Make sure we can load a scenario"""
 
     # Load the scenario
-    attack_graph, config = load_scenario(
+    attack_graph, agents = load_scenario(
         path_relative_to_tests('./testdata/scenarios/simple_scenario.yml')
     )
 
@@ -64,8 +62,8 @@ def test_load_scenario():
     # Entry points list and reached attack steps list are different lists
     assert id(attacker.entry_points) != id(attacker.reached_attack_steps)
 
-    assert config['agents']['attacker']['agent_class'] == BreadthFirstAttacker
-    assert config['agents']['defender']['agent_class'] == KeyboardAgent
+    assert isinstance(agents[0]['agent'], BreadthFirstAttacker)
+    assert isinstance(agents[1]['agent'], PassiveAgent)
 
 
 def test_load_scenario_no_attacker_in_model():
@@ -119,13 +117,13 @@ def test_load_scenario_no_defender_agent():
     """Make sure we can load a scenario"""
 
     # Load the scenario
-    _, config = load_scenario(
+    _, agents = load_scenario(
         path_relative_to_tests(
             './testdata/scenarios/no_defender_agent_scenario.yml'
         )
     )
 
-    assert 'defender' not in config['agents']
-    assert config['agents']['attacker']['agent_class'] == BreadthFirstAttacker
+    assert 'defender' not in [a['name'] for a in agents]
+    assert isinstance(agents[0]['agent'], BreadthFirstAttacker)
 
 
 def test_load_scenario_agent_class_error():
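A minimal sketch of consuming the new load_scenario return shape, using only the keys asserted in tests/test_scenario.py above ('name' and 'agent'); the printout is purely illustrative:

    from malsim.scenario import load_scenario

    attack_graph, agents = load_scenario(
        'tests/testdata/scenarios/simple_scenario.yml')
    for agent_dict in agents:
        # Each entry carries the agent's name and an already instantiated
        # agent object (e.g. a BreadthFirstAttacker or PassiveAgent).
        print(agent_dict['name'], type(agent_dict['agent']).__name__)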
diff --git a/tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml b/tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml
index 82c6d1a1..938aedc7 100644
--- a/tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml
+++ b/tests/testdata/scenarios/bfs_vs_bfs_network_app_data_scenario.yml
@@ -1,26 +1,22 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/network_app_data_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'BreadthFirstAttacker'
-
-attacker_entry_points:
-  'attacker1':
-    - 'Internet:accessUninspected'
-
 # Rewards for each attack step
 rewards:
   Linux system:notPresent: 50
   Secret data:read: 100
 
-# This is for new base mal simulator:
-# agents:
-#   'attacker1':
-#     agent_class: KeyboardAgent
-#     type: attacker
-#     entry_points:
-#       - 'Internet:accessUninspected'
+agents:
+  'attacker1':
+    agent_class: BreadthFirstAttacker
+    type: attacker
+    config:
+      seed: 1
+    entry_points:
+      - 'Internet:accessUninspected'
 
-#   'defender1':
-#     agent_class: BreadthFirstAttacker
-#     type: defender
+  'defender1':
+    agent_class: BreadthFirstAttacker
+    type: defender
+    config:
+      seed: 1
diff --git a/tests/testdata/scenarios/bfs_vs_bfs_scenario.yml b/tests/testdata/scenarios/bfs_vs_bfs_scenario.yml
index 46bbd969..713bc61e 100644
--- a/tests/testdata/scenarios/bfs_vs_bfs_scenario.yml
+++ b/tests/testdata/scenarios/bfs_vs_bfs_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_test_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'BreadthFirstAttacker'
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -16,8 +13,13 @@ rewards:
   Identity:11:notPresent: 3.5
   Identity:8:assume: 50
 
-# Add entry points to AttackGraph with attacker names
-# and attack step full_names
-attacker_entry_points:
-  'Attacker1':
-    - 'Credentials:6:attemptCredentialsReuse'
+agents:
+  'attacker1':
+    agent_class: BreadthFirstAttacker
+    type: attacker
+    entry_points:
+      - 'Credentials:6:attemptCredentialsReuse'
+
+  'defender1':
+    agent_class: BreadthFirstAttacker
+    type: defender
diff --git a/tests/testdata/scenarios/no_defender_agent_scenario.yml b/tests/testdata/scenarios/no_defender_agent_scenario.yml
index 52b78e1b..326f5d8b 100644
--- a/tests/testdata/scenarios/no_defender_agent_scenario.yml
+++ b/tests/testdata/scenarios/no_defender_agent_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_test_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: null  # false should be fine as well
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -18,6 +15,11 @@ rewards:
 
 # Add entry points to AttackGraph with attacker names
 # and attack step full_names
-attacker_entry_points:
-  'Attacker1':
-    - 'Credentials:6:attemptCredentialsReuse'
+agents:
+  'attacker1':
+    type: attacker
+    agent_class: BreadthFirstAttacker
+    config:
+      seed: 1
+    entry_points:
+      - 'Credentials:6:attemptCredentialsReuse'
diff --git a/tests/testdata/scenarios/no_entry_points_simple_scenario.yml b/tests/testdata/scenarios/no_entry_points_simple_scenario.yml
index 0a46d460..d9f7ba60 100644
--- a/tests/testdata/scenarios/no_entry_points_simple_scenario.yml
+++ b/tests/testdata/scenarios/no_entry_points_simple_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_test_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -16,4 +13,8 @@ rewards:
   Identity:11:notPresent: 3.5
   Identity:8:assume: 50
 
-# No attacker entry points
\ No newline at end of file
+# No attacker entry points
+agents:
+  'Lonely Defender':
+    type: 'defender'
+    agent_class: PassiveAgent
diff --git a/tests/testdata/scenarios/no_existing_attacker_in_model_scenario.yml b/tests/testdata/scenarios/no_existing_attacker_in_model_scenario.yml
index 910a3a2a..d417d528 100644
--- a/tests/testdata/scenarios/no_existing_attacker_in_model_scenario.yml
+++ b/tests/testdata/scenarios/no_existing_attacker_in_model_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_no_attacker_test_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -17,6 +14,13 @@ rewards:
 
 # Add entry points to AttackGraph with attacker names
 # and attack step full_names
-attacker_entry_points:
+agents:
   'Attacker1':
-    - 'Credentials:6:attemptCredentialsReuse'
+    agent_class: BreadthFirstAttacker
+    type: attacker
+    entry_points:
+      - 'Credentials:6:attemptCredentialsReuse'
+
+  'Defender1':
+    agent_class: KeyboardAgent
+    type: defender
diff --git a/tests/testdata/scenarios/run_demo_scenario.yml b/tests/testdata/scenarios/run_demo_scenario.yml
index 10277850..157916a7 100644
--- a/tests/testdata/scenarios/run_demo_scenario.yml
+++ b/tests/testdata/scenarios/run_demo_scenario.yml
@@ -4,9 +4,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/run_demo_model.json
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
-
 # Rewards for each attack step (same as in run_demo.py)
 rewards:
   OS App:notPresent: 50
@@ -25,3 +22,14 @@ rewards:
   Program 1:fullAccess: 50
   Identity:5:assume: 50
   Other OS App:fullAccess: 200
+
+agents:
+  'Attacker1':
+    agent_class: BreadthFirstAttacker
+    type: attacker
+    entry_points:
+      - 'OS App:networkConnectUninspected'
+
+  'Defender1':
+    agent_class: KeyboardAgent
+    type: defender
diff --git a/tests/testdata/scenarios/simple_filtered_observability_scenario.yml b/tests/testdata/scenarios/simple_filtered_observability_scenario.yml
index 5f61cbff..433f56ab 100644
--- a/tests/testdata/scenarios/simple_filtered_observability_scenario.yml
+++ b/tests/testdata/scenarios/simple_filtered_observability_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_test_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -18,9 +15,16 @@ rewards:
 
 # Add entry points to AttackGraph with attacker names
 # and attack step full_names
-attacker_entry_points:
+agents:
   'Attacker1':
-    - 'Credentials:6:attemptCredentialsReuse'
+    agent_class: BreadthFirstAttacker
+    type: attacker
+    entry_points:
+      - 'Credentials:6:attemptCredentialsReuse'
+
+  'Defender1':
+    agent_class: KeyboardAgent
+    type: defender
 
 # Optional way to make only certain attack steps observable
 # If observable_steps are set, all attack steps not
@@ -33,4 +37,4 @@ observable_steps:
   by_asset_name:
     Identity:8:
-      - assume
\ No newline at end of file
+      - assume
diff --git a/tests/testdata/scenarios/simple_scenario.yml b/tests/testdata/scenarios/simple_scenario.yml
index 200872b4..e39619e9 100644
--- a/tests/testdata/scenarios/simple_scenario.yml
+++ b/tests/testdata/scenarios/simple_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_test_model.yml
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -18,6 +15,15 @@ rewards:
 
 # Add entry points to AttackGraph with attacker names
 # and attack step full_names
-attacker_entry_points:
+agents:
   'Attacker1':
+    type: attacker
+    agent_class: BreadthFirstAttacker
+    config:
+      seed: 1
+    entry_points:
     - 'Credentials:6:attemptCredentialsReuse'
+
+  'Defender1':
+    type: defender
+    agent_class: PassiveAgent
diff --git a/tests/testdata/scenarios/traininglang_actionability_scenario.yml b/tests/testdata/scenarios/traininglang_actionability_scenario.yml
index 8a42e2dd..04512047 100644
--- a/tests/testdata/scenarios/traininglang_actionability_scenario.yml
+++ b/tests/testdata/scenarios/traininglang_actionability_scenario.yml
@@ -11,14 +11,17 @@ rewards:
   Data:2:read: 5
   Data:2:modify: 10
 
-# The possible entry points
-attacker_entry_points:
+agents:
   'Attacker1':
+    type: 'attacker'
+    agent_class: BreadthFirstAttacker
+    entry_points:
     - 'User:3:phishing'
     - 'Host:0:connect'
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
+  'Defender1':
+    type: 'defender'
+    agent_class: KeyboardAgent
 
 # Optional way to make only certain steps actionable
 # If actionable_steps are set, all steps not
diff --git a/tests/testdata/scenarios/traininglang_observability_scenario.yml b/tests/testdata/scenarios/traininglang_observability_scenario.yml
index a90d26e1..c930e16f 100644
--- a/tests/testdata/scenarios/traininglang_observability_scenario.yml
+++ b/tests/testdata/scenarios/traininglang_observability_scenario.yml
@@ -11,14 +11,17 @@ rewards:
   Data:2:read: 5
   Data:2:modify: 10
 
-# The possible entry points
-attacker_entry_points:
+agents:
   'Attacker1':
+    type: 'attacker'
+    agent_class: BreadthFirstAttacker
+    entry_points:
     - 'User:3:phishing'
     - 'Host:0:connect'
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
+  'Defender1':
+    type: 'defender'
+    agent_class: KeyboardAgent
 
 # Optional way to make only certain steps observable
 # If observable_steps are set, all steps not
diff --git a/tests/testdata/scenarios/traininglang_scenario.yml b/tests/testdata/scenarios/traininglang_scenario.yml
index 0daa87b5..697cdad5 100644
--- a/tests/testdata/scenarios/traininglang_scenario.yml
+++ b/tests/testdata/scenarios/traininglang_scenario.yml
@@ -11,11 +11,12 @@ rewards:
   Data:2:read: 5
   Data:2:modify: 10
 
-# The possible entry points
-attacker_entry_points:
+agents:
   'Attacker1':
+    type: 'attacker'
+    entry_points:
     - 'User:3:phishing'
     - 'Host:0:connect'
 
-attacker_agent_class: 'BreadthFirstAttacker'
-defender_agent_class: 'KeyboardAgent'
+  'Defender1':
+    type: 'defender'
diff --git a/tests/testdata/scenarios/wrong_agent_classes_scenario.yml b/tests/testdata/scenarios/wrong_agent_classes_scenario.yml
index edb33eb2..4cbdd45b 100644
--- a/tests/testdata/scenarios/wrong_agent_classes_scenario.yml
+++ b/tests/testdata/scenarios/wrong_agent_classes_scenario.yml
@@ -1,9 +1,6 @@
 lang_file: ../langs/org.mal-lang.coreLang-1.0.0.mar
 model_file: ../models/simple_test_model.yml
 
-attacker_agent_class: 'BananaAttacker'
-defender_agent_class: 'FishAttacker'
-
 # Rewards for each attack step
 rewards:
   OS App:notPresent: 2
@@ -17,6 +14,13 @@ rewards:
 
 # Add entry points to AttackGraph with attacker names
 # and attack step full_names
-attacker_entry_points:
+agents:
   'Attacker1':
+    agent_class: BananaAttacker
+    type: 'attacker'
+    entry_points:
     - 'Credentials:6:attemptCredentialsReuse'
+
+  'Defender1':
+    agent_class: FishAttacker
+    type: 'defender'
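The scenario files above all declare their agents under a single agents: mapping. A minimal sketch of driving the node-based step API those scenarios feed into, using only calls and attributes exercised in tests/test_mal_simulator.py (the agent names and the attacker id of 1 mirror the tests and are otherwise arbitrary):

    from malsim.mal_simulator import MalSimulator
    from malsim.scenario import load_scenario

    ag, _ = load_scenario('tests/testdata/scenarios/traininglang_scenario.yml')
    sim = MalSimulator(ag)
    sim.register_attacker('attacker', 1)
    sim.register_defender('defender')

    views = sim.reset()
    step = sim.attack_graph.get_node_by_full_name('User:3:compromise')
    # Actions now map agent names to collections of attack graph nodes.
    views = sim.step({'attacker': [step], 'defender': []})
    # Per-step deltas are exposed on the returned agent state views.
    print(views['attacker'].step_performed_nodes)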