diff --git a/malsim/agents/__init__.py b/malsim/agents/__init__.py index 9b4bb0bb..830af82a 100644 --- a/malsim/agents/__init__.py +++ b/malsim/agents/__init__.py @@ -2,11 +2,16 @@ from .passive_agent import PassiveAgent from .keyboard_input import KeyboardAgent from .searchers import BreadthFirstAttacker, DepthFirstAttacker +from .heuristic_agent import ( + DefendCompromisedDefender, DefendFutureCompromisedDefender +) __all__ = [ 'PassiveAgent', 'DecisionAgent', 'KeyboardAgent', 'BreadthFirstAttacker', - 'DepthFirstAttacker' + 'DepthFirstAttacker', + 'DefendCompromisedDefender', + 'DefendFutureCompromisedDefender' ] diff --git a/malsim/agents/heuristic_agent.py b/malsim/agents/heuristic_agent.py new file mode 100644 index 00000000..e1377207 --- /dev/null +++ b/malsim/agents/heuristic_agent.py @@ -0,0 +1,125 @@ +from __future__ import annotations +from typing import Optional, TYPE_CHECKING +import logging +import math + +import numpy as np + +from .decision_agent import DecisionAgent + +if TYPE_CHECKING: + from maltoolbox.attackgraph import AttackGraphNode + from ..mal_simulator import MalSimAgentStateView + +logger = logging.getLogger(__name__) + +class DefendCompromisedDefender(DecisionAgent): + """A defender that defends compromised assets using notPresent""" + + def __init__(self, agent_config, **_): + # Seed and rng not currently used + seed = ( + agent_config["seed"] + if agent_config.get("seed") + else np.random.SeedSequence().entropy + ) + self.rng = ( + np.random.default_rng(seed) + if agent_config.get("randomize") + else None + ) + + def get_next_action( + self, agent_state: MalSimAgentStateView, **kwargs + ) -> Optional[AttackGraphNode]: + + """Return an action that disables a compromised node""" + + selected_node_cost = math.inf + selected_node = None + + # To make it deterministic + possible_choices = list(agent_state.action_surface) + possible_choices.sort(key=lambda n: n.id) + + for node in possible_choices: + + if node.is_enabled_defense(): + continue + + node_cost = node.extras.get('reward', 0) + + # Strategy: + # - Enabled the cheapest defense node + # that has compromised child nodes + if node_cost < selected_node_cost: + + node_has_compromised_child = ( + any( + child_node.is_compromised() + for child_node in node.children + ) + ) + + if node_has_compromised_child: + selected_node = node + selected_node_cost = node_cost + + return selected_node + + +class DefendFutureCompromisedDefender(DecisionAgent): + """A defender that defends compromised assets using notPresent""" + + def __init__(self, agent_config, **_): + # Seed and rng not currently used + seed = ( + agent_config["seed"] + if agent_config.get("seed") + else np.random.SeedSequence().entropy + ) + self.rng = ( + np.random.default_rng(seed) + if agent_config.get("randomize") + else None + ) + + def get_next_action( + self, agent_state: MalSimAgentStateView, **kwargs + ) -> Optional[AttackGraphNode]: + + """Return an action that disables a compromised node""" + + selected_node_cost = math.inf + selected_node = None + + # To make it deterministic + possible_choices = list(agent_state.action_surface) + possible_choices.sort(key=lambda n: n.id) + + for node in possible_choices: + + if node.is_enabled_defense(): + continue + + node_cost = node.extras.get('reward', 0) + + # Strategy: + # - Enabled the cheapest defense node + # that has a non compromised child + # that has a compromised parent. + if node_cost < selected_node_cost: + + node_has_child_that_can_be_compromised = ( + any( + any(p.is_compromised() for p in child_node.parents) + and not child_node.is_compromised() + for child_node in node.children + ) + ) + + if node_has_child_that_can_be_compromised: + selected_node = node + selected_node_cost = node_cost + + return selected_node diff --git a/malsim/scenario.py b/malsim/scenario.py index f7612589..d7aa4f8e 100644 --- a/malsim/scenario.py +++ b/malsim/scenario.py @@ -22,7 +22,9 @@ BreadthFirstAttacker, DepthFirstAttacker, KeyboardAgent, - PassiveAgent + PassiveAgent, + DefendCompromisedDefender, + DefendFutureCompromisedDefender ) from .mal_simulator import AgentType, MalSimulator @@ -32,6 +34,8 @@ 'BreadthFirstAttacker': BreadthFirstAttacker, 'KeyboardAgent': KeyboardAgent, 'PassiveAgent': PassiveAgent, + 'DefendCompromisedDefender': DefendCompromisedDefender, + 'DefendFutureCompromisedDefender': DefendFutureCompromisedDefender } deprecated_fields = [ diff --git a/tests/agents/test_heuristic_agents.py b/tests/agents/test_heuristic_agents.py new file mode 100644 index 00000000..841e933b --- /dev/null +++ b/tests/agents/test_heuristic_agents.py @@ -0,0 +1,131 @@ +from unittest.mock import MagicMock +from maltoolbox.attackgraph import AttackGraphNode, Attacker +from maltoolbox.language import LanguageGraph +from malsim.mal_simulator import MalSimAgentStateView +from malsim.agents import ( + DefendCompromisedDefender, + DefendFutureCompromisedDefender +) + +def test_defend_compromised_defender(dummy_lang_graph: LanguageGraph): + r""" + node1 node2 + / \ / \ + node3 node4 node5 + + """ + dummy_or_attack_step = ( + dummy_lang_graph.assets['DummyAsset'] + .attack_steps['DummyOrAttackStep'] + ) + + # Create nodes + node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1) + node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2) + + node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3) + node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4) + node5 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=5) + + # Connect nodes (Node1 -> Node2 -> Node3 -> Node4) + node1.children.add(node3) + node3.parents.add(node1) + node1.children.add(node4) + node4.parents.add(node1) + + node2.children.add(node4) + node4.parents.add(node2) + node2.children.add(node5) + node5.parents.add(node5) + + # Set up an attacker + attacker = Attacker(name="TestAttacker") + attacker.compromise(node4) + + # Set up a mock MalSimAgentState + agent = MagicMock() + agent.action_surface = [node1, node2] + + # Set up MalSimAgentStateView + agent_view = MalSimAgentStateView(agent) + + # Configure BreadthFirstAttacker + agent_config = {"seed": 42, "randomize": False} + defender_ai = DefendCompromisedDefender(agent_config) + + # Should pick cheapest one + node1.extras['reward'] = 100 + node2.extras['reward'] = 10 + + # Get next action + action_node = defender_ai.get_next_action(agent_view) + assert action_node is not None, "Action node shouldn't be None" + assert action_node.id == node2.id + + # Should pick cheapest one + node1.extras['reward'] = 10 + node2.extras['reward'] = 100 + + # Get next action + action_node = defender_ai.get_next_action(agent_view) + assert action_node is not None, "Action node shouldn't be None" + assert action_node.id == node1.id + + +def test_defend_future_compromised_defender(dummy_lang_graph: LanguageGraph): + r""" + node1 node2 + / \ / \ + node3 node4 | node5 + | | + \ / + node 6 + """ + + dummy_or_attack_step = ( + dummy_lang_graph.assets['DummyAsset'] + .attack_steps['DummyOrAttackStep'] + ) + + # Create nodes + node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1) + node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2) + + node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3) + node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4) + node5 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=5) + node6 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=6) + + # Connect nodes (Node1 -> Node2 -> Node3 -> Node4) + node1.children.add(node3) + node3.parents.add(node1) + node1.children.add(node4) + node4.parents.add(node1) + + node2.children.add(node6) + node6.parents.add(node2) + node2.children.add(node5) + node5.parents.add(node5) + + node4.children.add(node6) + node6.parents.add(node4) + + # Set up an attacker + attacker = Attacker(name="TestAttacker") + attacker.compromise(node4) + + # Set up a mock MalSimAgentState + agent = MagicMock() + agent.action_surface = [node1, node2] + + # Set up MalSimAgentStateView + agent_view = MalSimAgentStateView(agent) + + # Configure BreadthFirstAttacker + agent_config = {"seed": 42, "randomize": False} + defender_ai = DefendFutureCompromisedDefender(agent_config) + + # Should pick node 2 either way + action_node = defender_ai.get_next_action(agent_view) + assert action_node is not None, "Action node shouldn't be None" + assert action_node.id == node2.id