Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion malsim/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
from .passive_agent import PassiveAgent
from .keyboard_input import KeyboardAgent
from .searchers import BreadthFirstAttacker, DepthFirstAttacker
from .heuristic_agent import (
DefendCompromisedDefender, DefendFutureCompromisedDefender
)

__all__ = [
'PassiveAgent',
'DecisionAgent',
'KeyboardAgent',
'BreadthFirstAttacker',
'DepthFirstAttacker'
'DepthFirstAttacker',
'DefendCompromisedDefender',
'DefendFutureCompromisedDefender'
]
125 changes: 125 additions & 0 deletions malsim/agents/heuristic_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
import logging
import math

import numpy as np

from .decision_agent import DecisionAgent

if TYPE_CHECKING:
from maltoolbox.attackgraph import AttackGraphNode
from ..mal_simulator import MalSimAgentStateView

logger = logging.getLogger(__name__)

class DefendCompromisedDefender(DecisionAgent):
"""A defender that defends compromised assets using notPresent"""

def __init__(self, agent_config, **_):
# Seed and rng not currently used
seed = (
agent_config["seed"]
if agent_config.get("seed")
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize")
else None
)

def get_next_action(
self, agent_state: MalSimAgentStateView, **kwargs
) -> Optional[AttackGraphNode]:

"""Return an action that disables a compromised node"""

selected_node_cost = math.inf
selected_node = None

# To make it deterministic
possible_choices = list(agent_state.action_surface)
possible_choices.sort(key=lambda n: n.id)

for node in possible_choices:

if node.is_enabled_defense():
continue

node_cost = node.extras.get('reward', 0)

# Strategy:
# - Enabled the cheapest defense node
# that has compromised child nodes
if node_cost < selected_node_cost:

node_has_compromised_child = (
any(
child_node.is_compromised()
for child_node in node.children
)
)

if node_has_compromised_child:
selected_node = node
selected_node_cost = node_cost

return selected_node


class DefendFutureCompromisedDefender(DecisionAgent):
"""A defender that defends compromised assets using notPresent"""

def __init__(self, agent_config, **_):
# Seed and rng not currently used
seed = (
agent_config["seed"]
if agent_config.get("seed")
else np.random.SeedSequence().entropy
)
self.rng = (
np.random.default_rng(seed)
if agent_config.get("randomize")
else None
)

def get_next_action(
self, agent_state: MalSimAgentStateView, **kwargs
) -> Optional[AttackGraphNode]:

"""Return an action that disables a compromised node"""

selected_node_cost = math.inf
selected_node = None

# To make it deterministic
possible_choices = list(agent_state.action_surface)
possible_choices.sort(key=lambda n: n.id)

for node in possible_choices:

if node.is_enabled_defense():
continue

node_cost = node.extras.get('reward', 0)

# Strategy:
# - Enabled the cheapest defense node
# that has a non compromised child
# that has a compromised parent.
if node_cost < selected_node_cost:

node_has_child_that_can_be_compromised = (
any(
any(p.is_compromised() for p in child_node.parents)
and not child_node.is_compromised()
for child_node in node.children
)
)

if node_has_child_that_can_be_compromised:
selected_node = node
selected_node_cost = node_cost

return selected_node
6 changes: 5 additions & 1 deletion malsim/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
BreadthFirstAttacker,
DepthFirstAttacker,
KeyboardAgent,
PassiveAgent
PassiveAgent,
DefendCompromisedDefender,
DefendFutureCompromisedDefender
)

from .mal_simulator import AgentType, MalSimulator
Expand All @@ -32,6 +34,8 @@
'BreadthFirstAttacker': BreadthFirstAttacker,
'KeyboardAgent': KeyboardAgent,
'PassiveAgent': PassiveAgent,
'DefendCompromisedDefender': DefendCompromisedDefender,
'DefendFutureCompromisedDefender': DefendFutureCompromisedDefender
}

deprecated_fields = [
Expand Down
131 changes: 131 additions & 0 deletions tests/agents/test_heuristic_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from unittest.mock import MagicMock
from maltoolbox.attackgraph import AttackGraphNode, Attacker
from maltoolbox.language import LanguageGraph
from malsim.mal_simulator import MalSimAgentStateView
from malsim.agents import (
DefendCompromisedDefender,
DefendFutureCompromisedDefender
)

def test_defend_compromised_defender(dummy_lang_graph: LanguageGraph):
r"""
node1 node2
/ \ / \
node3 node4 node5

"""
dummy_or_attack_step = (
dummy_lang_graph.assets['DummyAsset']
.attack_steps['DummyOrAttackStep']
)

# Create nodes
node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1)
node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2)

node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3)
node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4)
node5 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=5)

# Connect nodes (Node1 -> Node2 -> Node3 -> Node4)
node1.children.add(node3)
node3.parents.add(node1)
node1.children.add(node4)
node4.parents.add(node1)

node2.children.add(node4)
node4.parents.add(node2)
node2.children.add(node5)
node5.parents.add(node5)

# Set up an attacker
attacker = Attacker(name="TestAttacker")
attacker.compromise(node4)

# Set up a mock MalSimAgentState
agent = MagicMock()
agent.action_surface = [node1, node2]

# Set up MalSimAgentStateView
agent_view = MalSimAgentStateView(agent)

# Configure BreadthFirstAttacker
agent_config = {"seed": 42, "randomize": False}
defender_ai = DefendCompromisedDefender(agent_config)

# Should pick cheapest one
node1.extras['reward'] = 100
node2.extras['reward'] = 10

# Get next action
action_node = defender_ai.get_next_action(agent_view)
assert action_node is not None, "Action node shouldn't be None"
assert action_node.id == node2.id

# Should pick cheapest one
node1.extras['reward'] = 10
node2.extras['reward'] = 100

# Get next action
action_node = defender_ai.get_next_action(agent_view)
assert action_node is not None, "Action node shouldn't be None"
assert action_node.id == node1.id


def test_defend_future_compromised_defender(dummy_lang_graph: LanguageGraph):
r"""
node1 node2
/ \ / \
node3 node4 | node5
| |
\ /
node 6
"""

dummy_or_attack_step = (
dummy_lang_graph.assets['DummyAsset']
.attack_steps['DummyOrAttackStep']
)

# Create nodes
node1 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=1)
node2 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=2)

node3 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=3)
node4 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=4)
node5 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=5)
node6 = AttackGraphNode(lg_attack_step=dummy_or_attack_step, node_id=6)

# Connect nodes (Node1 -> Node2 -> Node3 -> Node4)
node1.children.add(node3)
node3.parents.add(node1)
node1.children.add(node4)
node4.parents.add(node1)

node2.children.add(node6)
node6.parents.add(node2)
node2.children.add(node5)
node5.parents.add(node5)

node4.children.add(node6)
node6.parents.add(node4)

# Set up an attacker
attacker = Attacker(name="TestAttacker")
attacker.compromise(node4)

# Set up a mock MalSimAgentState
agent = MagicMock()
agent.action_surface = [node1, node2]

# Set up MalSimAgentStateView
agent_view = MalSimAgentStateView(agent)

# Configure BreadthFirstAttacker
agent_config = {"seed": 42, "randomize": False}
defender_ai = DefendFutureCompromisedDefender(agent_config)

# Should pick node 2 either way
action_node = defender_ai.get_next_action(agent_view)
assert action_node is not None, "Action node shouldn't be None"
assert action_node.id == node2.id