Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions malsim/mal_simulator/agent_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
from types import MappingProxyType

from maltoolbox.attackgraph import AttackGraphNode
from malsim.mal_simulator.simulator_state import MalSimulatorState
Expand All @@ -17,6 +18,8 @@ class MalSimAgentState:
action_surface: frozenset[AttackGraphNode]
# Contains all nodes that this agent has performed successfully
performed_nodes: frozenset[AttackGraphNode]
# Contains the order of performed nodes
performed_nodes_order: MappingProxyType[int, frozenset[AttackGraphNode]]
# Contains the nodes performed successfully in the last step
step_performed_nodes: frozenset[AttackGraphNode]
# Contains possible nodes that became available in the last step
Expand Down
6 changes: 6 additions & 0 deletions malsim/mal_simulator/attacker_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __getstate__(self) -> dict[str, Any]:
state['num_attempts'] = dict(state['num_attempts'])
state['ttc_overrides'] = dict(state['ttc_overrides'])
state['ttc_value_overrides'] = dict(state['ttc_value_overrides'])
state['performed_nodes_order'] = dict(state['performed_nodes_order'])
return state

def __setstate__(self, state: dict[str, Any]) -> None:
Expand All @@ -58,6 +59,11 @@ def __setstate__(self, state: dict[str, Any]) -> None:
object.__setattr__(
self, 'ttc_value_overrides', MappingProxyType(state['ttc_value_overrides'])
)
object.__setattr__(
self,
'performed_nodes_order',
MappingProxyType(state['performed_nodes_order']),
)
# set other frozen attributes
for key, value in state.items():
if key not in ('num_attempts', 'ttc_overrides', 'ttc_value_overrides'):
Expand Down
12 changes: 11 additions & 1 deletion malsim/mal_simulator/attacker_state_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ def create_attacker_state(
)
action_surface_removals: set[AttackGraphNode] = set()
action_surface_additions = new_action_surface
performed_nodes_order: dict[int, frozenset[AttackGraphNode]] = {}

if not sim_state.settings.compromise_entrypoints_at_start:
if sim_state.settings.compromise_entrypoints_at_start:
performed_nodes_order[0] = frozenset(entry_points)
else:
# If entrypoints not compromised at start,
# we need to put them in action surface
new_action_surface |= entry_points
Expand All @@ -82,6 +85,12 @@ def create_attacker_state(
ttc_value_overrides = previous_state.ttc_value_overrides
impossible_step_overrides = previous_state.impossible_step_overrides
compromised_nodes = previous_state.performed_nodes | step_compromised_nodes
performed_nodes_order = dict(previous_state.performed_nodes_order)

if step_compromised_nodes:
performed_nodes_order[previous_state.iteration] = frozenset(
step_compromised_nodes
)

# Build on previous attack surface (for performance)
action_surface_additions = (
Expand Down Expand Up @@ -128,6 +137,7 @@ def create_attacker_state(
iteration=(previous_state.iteration + 1) if previous_state else 1,
reward_rule=reward_rule,
actionability_rule=actionability_rule,
performed_nodes_order=MappingProxyType(performed_nodes_order),
)


Expand Down
19 changes: 18 additions & 1 deletion malsim/mal_simulator/defender_state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import Optional
from types import MappingProxyType
from typing import Any, Optional
from maltoolbox.attackgraph import AttackGraphNode
from malsim.config.node_property_rule import NodePropertyRule
from malsim.mal_simulator.agent_state import MalSimAgentState
Expand Down Expand Up @@ -27,6 +28,22 @@ class MalSimDefenderState(MalSimAgentState):
false_negative_rates_rule: Optional[NodePropertyRule] = None
observability_rule: Optional[NodePropertyRule] = None

# Pickling
def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
state['performed_nodes_order'] = dict(state['performed_nodes_order'])
return state

def __setstate__(self, state: dict[str, Any]) -> None:
object.__setattr__(
self,
'performed_nodes_order',
MappingProxyType(state['performed_nodes_order']),
)
for key, value in state.items():
if key not in ('performed_nodes_order'):
object.__setattr__(self, key, value)


def get_defender_agents(
agent_states: AgentStates, alive_agents: set[str], only_alive: bool = False
Expand Down
12 changes: 12 additions & 0 deletions malsim/mal_simulator/defender_state_factories.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Creation/manipulation of defender state"""

from collections.abc import Set
from types import MappingProxyType
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -43,6 +44,11 @@ def create_defender_state(
previous_observed_nodes: Set[AttackGraphNode] = frozenset()
action_surface_additions: Set[AttackGraphNode] = action_surface
action_surface_removals: Set[AttackGraphNode] = frozenset()
performed_nodes_order: dict[int, frozenset[AttackGraphNode]] = {}

if step_enabled_defenses:
# Pre enabled defenses go into iteration 0
performed_nodes_order[0] = frozenset(step_enabled_defenses)
else:
# Previous rules used if previous state given
reward_rule = previous_state.reward_rule
Expand All @@ -64,6 +70,11 @@ def create_defender_state(
previous_observed_nodes = previous_state.observed_nodes
action_surface_additions = frozenset()
action_surface_removals = step_enabled_defenses
performed_nodes_order = dict(previous_state.performed_nodes_order)
if step_enabled_defenses:
performed_nodes_order[previous_state.iteration] = frozenset(
step_enabled_defenses
)

step_observed_nodes = defender_observed_nodes(
observability_rule,
Expand Down Expand Up @@ -94,6 +105,7 @@ def create_defender_state(
actionability_rule=actionability_rule,
false_negative_rates_rule=false_negative_rates_rule,
false_positive_rates_rule=false_positive_rates_rule,
performed_nodes_order=MappingProxyType(performed_nodes_order),
)


Expand Down
2 changes: 1 addition & 1 deletion malsim/mal_simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
defender_step,
)
from malsim.mal_simulator.node_getters import (
full_name_dict_to_node_dict,
full_names_or_nodes_to_nodes,
get_node,
)
Expand Down Expand Up @@ -694,6 +693,7 @@ def step(
agent_state = agent_states[agent_name]

if isinstance(agent_state, MalSimDefenderState):
current_iteration = agent_state.iteration
# Update defender state
updated_defender_state = create_defender_state(
sim_state=sim_state,
Expand Down
77 changes: 77 additions & 0 deletions tests/test_mal_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test MalSimulator class"""

from __future__ import annotations
import random
from typing import TYPE_CHECKING

from maltoolbox.attackgraph import AttackGraph, AttackGraphNode
Expand Down Expand Up @@ -1434,3 +1435,79 @@ def test_active_defenses() -> None:
sim.get_node('Creds:notDisclosed')
in sim.sim_state.graph_state.pre_enabled_defenses
)


def test_compromise_order() -> None:
"""Verify that the compromise order is correctly recorded"""

scenario = Scenario.load_from_file(
'tests/testdata/scenarios/socialEngineering_scenario.yml'
)
sim = MalSimulator.from_scenario(scenario)
sim.register_defender('Defender1')
states = sim.reset()

attacker_record = (
{0: set(states['Attacker1'].performed_nodes_order[0])}
if len(states['Attacker1'].performed_nodes_order) > 0
else {}
)
defender_record = (
{0: set(states['Defender1'].performed_nodes_order[0])}
if len(states['Defender1'].performed_nodes_order) > 0
else {}
)

for i in range(1, 101):
actions: dict[str, list[AttackGraphNode]] = {}
if len(states['Attacker1'].action_surface) == 0 or random.random() < 0.5:
actions['Attacker1'] = []
else:
actions['Attacker1'] = [
random.choice(list(states['Attacker1'].action_surface))
]

if len(states['Defender1'].action_surface) == 0 or random.random() < 0.3:
actions['Defender1'] = []
else:
actions['Defender1'] = [
random.choice(list(states['Defender1'].action_surface))
]
states = sim.step(actions)
if len(sim.recording[i]['Attacker1']) > 0:
attacker_record[i] = set(sim.recording[i]['Attacker1'])
if len(sim.recording[i]['Defender1']) > 0:
defender_record[i] = set(sim.recording[i]['Defender1'])
if sim.done():
break

for i in range(max(max(attacker_record.keys()), max(defender_record.keys())) + 1):
if i in attacker_record and i in states['Attacker1'].performed_nodes_order:
assert attacker_record[i] == states['Attacker1'].performed_nodes_order[i], (
f'Attacker record does not match simulator at time {i}'
)
elif i in attacker_record:
assert False, (
f'Attacker record has steps for time {i} but simulator does not'
)
elif i in states['Attacker1'].performed_nodes_order:
assert False, (
f'Simulator has steps for time {i} but attacker record does not'
)

for i in range(100):
if i in defender_record and i in states['Defender1'].performed_nodes_order:
assert defender_record[i] == states['Defender1'].performed_nodes_order[i], (
f'Defender record does not match simulator at time {i}'
)
elif i in defender_record:
assert False, (
f'Defender record has steps for time {i} but simulator does not'
)
elif i in states['Defender1'].performed_nodes_order:
assert False, (
f'Simulator has steps for time {i} but defender record does not'
)

assert states['Attacker1'].performed_nodes_order == attacker_record
assert states['Defender1'].performed_nodes_order == defender_record