diff --git a/malsim/mal_simulator/agent_state.py b/malsim/mal_simulator/agent_state.py index 7d4f94c9..2bc3ca1c 100644 --- a/malsim/mal_simulator/agent_state.py +++ b/malsim/mal_simulator/agent_state.py @@ -1,5 +1,6 @@ from __future__ import annotations from dataclasses import dataclass +from types import MappingProxyType from maltoolbox.attackgraph import AttackGraphNode from malsim.mal_simulator.simulator_state import MalSimulatorState @@ -17,6 +18,8 @@ class MalSimAgentState: action_surface: frozenset[AttackGraphNode] # Contains all nodes that this agent has performed successfully performed_nodes: frozenset[AttackGraphNode] + # Contains the order of performed nodes + performed_nodes_order: MappingProxyType[int, frozenset[AttackGraphNode]] # Contains the nodes performed successfully in the last step step_performed_nodes: frozenset[AttackGraphNode] # Contains possible nodes that became available in the last step diff --git a/malsim/mal_simulator/attacker_state.py b/malsim/mal_simulator/attacker_state.py index 4d2bd674..06edacf7 100644 --- a/malsim/mal_simulator/attacker_state.py +++ b/malsim/mal_simulator/attacker_state.py @@ -45,6 +45,7 @@ def __getstate__(self) -> dict[str, Any]: state['num_attempts'] = dict(state['num_attempts']) state['ttc_overrides'] = dict(state['ttc_overrides']) state['ttc_value_overrides'] = dict(state['ttc_value_overrides']) + state['performed_nodes_order'] = dict(state['performed_nodes_order']) return state def __setstate__(self, state: dict[str, Any]) -> None: @@ -58,6 +59,11 @@ def __setstate__(self, state: dict[str, Any]) -> None: object.__setattr__( self, 'ttc_value_overrides', MappingProxyType(state['ttc_value_overrides']) ) + object.__setattr__( + self, + 'performed_nodes_order', + MappingProxyType(state['performed_nodes_order']), + ) # set other frozen attributes for key, value in state.items(): if key not in ('num_attempts', 'ttc_overrides', 'ttc_value_overrides'): diff --git a/malsim/mal_simulator/attacker_state_factories.py b/malsim/mal_simulator/attacker_state_factories.py index 5487e9aa..634cd422 100644 --- a/malsim/mal_simulator/attacker_state_factories.py +++ b/malsim/mal_simulator/attacker_state_factories.py @@ -63,8 +63,11 @@ def create_attacker_state( ) action_surface_removals: set[AttackGraphNode] = set() action_surface_additions = new_action_surface + performed_nodes_order: dict[int, frozenset[AttackGraphNode]] = {} - if not sim_state.settings.compromise_entrypoints_at_start: + if sim_state.settings.compromise_entrypoints_at_start: + performed_nodes_order[0] = frozenset(entry_points) + else: # If entrypoints not compromised at start, # we need to put them in action surface new_action_surface |= entry_points @@ -82,6 +85,12 @@ def create_attacker_state( ttc_value_overrides = previous_state.ttc_value_overrides impossible_step_overrides = previous_state.impossible_step_overrides compromised_nodes = previous_state.performed_nodes | step_compromised_nodes + performed_nodes_order = dict(previous_state.performed_nodes_order) + + if step_compromised_nodes: + performed_nodes_order[previous_state.iteration] = frozenset( + step_compromised_nodes + ) # Build on previous attack surface (for performance) action_surface_additions = ( @@ -128,6 +137,7 @@ def create_attacker_state( iteration=(previous_state.iteration + 1) if previous_state else 1, reward_rule=reward_rule, actionability_rule=actionability_rule, + performed_nodes_order=MappingProxyType(performed_nodes_order), ) diff --git a/malsim/mal_simulator/defender_state.py b/malsim/mal_simulator/defender_state.py index 9068e7ba..494b8ad4 100644 --- a/malsim/mal_simulator/defender_state.py +++ b/malsim/mal_simulator/defender_state.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Optional +from types import MappingProxyType +from typing import Any, Optional from maltoolbox.attackgraph import AttackGraphNode from malsim.config.node_property_rule import NodePropertyRule from malsim.mal_simulator.agent_state import MalSimAgentState @@ -27,6 +28,22 @@ class MalSimDefenderState(MalSimAgentState): false_negative_rates_rule: Optional[NodePropertyRule] = None observability_rule: Optional[NodePropertyRule] = None + # Pickling + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + state['performed_nodes_order'] = dict(state['performed_nodes_order']) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + object.__setattr__( + self, + 'performed_nodes_order', + MappingProxyType(state['performed_nodes_order']), + ) + for key, value in state.items(): + if key not in ('performed_nodes_order'): + object.__setattr__(self, key, value) + def get_defender_agents( agent_states: AgentStates, alive_agents: set[str], only_alive: bool = False diff --git a/malsim/mal_simulator/defender_state_factories.py b/malsim/mal_simulator/defender_state_factories.py index 3a1007b7..68d4bc49 100644 --- a/malsim/mal_simulator/defender_state_factories.py +++ b/malsim/mal_simulator/defender_state_factories.py @@ -1,6 +1,7 @@ """Creation/manipulation of defender state""" from collections.abc import Set +from types import MappingProxyType from typing import Optional import numpy as np @@ -43,6 +44,11 @@ def create_defender_state( previous_observed_nodes: Set[AttackGraphNode] = frozenset() action_surface_additions: Set[AttackGraphNode] = action_surface action_surface_removals: Set[AttackGraphNode] = frozenset() + performed_nodes_order: dict[int, frozenset[AttackGraphNode]] = {} + + if step_enabled_defenses: + # Pre enabled defenses go into iteration 0 + performed_nodes_order[0] = frozenset(step_enabled_defenses) else: # Previous rules used if previous state given reward_rule = previous_state.reward_rule @@ -64,6 +70,11 @@ def create_defender_state( previous_observed_nodes = previous_state.observed_nodes action_surface_additions = frozenset() action_surface_removals = step_enabled_defenses + performed_nodes_order = dict(previous_state.performed_nodes_order) + if step_enabled_defenses: + performed_nodes_order[previous_state.iteration] = frozenset( + step_enabled_defenses + ) step_observed_nodes = defender_observed_nodes( observability_rule, @@ -94,6 +105,7 @@ def create_defender_state( actionability_rule=actionability_rule, false_negative_rates_rule=false_negative_rates_rule, false_positive_rates_rule=false_positive_rates_rule, + performed_nodes_order=MappingProxyType(performed_nodes_order), ) diff --git a/malsim/mal_simulator/simulator.py b/malsim/mal_simulator/simulator.py index bb816660..3e0c7508 100644 --- a/malsim/mal_simulator/simulator.py +++ b/malsim/mal_simulator/simulator.py @@ -20,7 +20,6 @@ defender_step, ) from malsim.mal_simulator.node_getters import ( - full_name_dict_to_node_dict, full_names_or_nodes_to_nodes, get_node, ) @@ -694,6 +693,7 @@ def step( agent_state = agent_states[agent_name] if isinstance(agent_state, MalSimDefenderState): + current_iteration = agent_state.iteration # Update defender state updated_defender_state = create_defender_state( sim_state=sim_state, diff --git a/tests/test_mal_simulator.py b/tests/test_mal_simulator.py index 4e21f4b1..7eb710e0 100644 --- a/tests/test_mal_simulator.py +++ b/tests/test_mal_simulator.py @@ -1,6 +1,7 @@ """Test MalSimulator class""" from __future__ import annotations +import random from typing import TYPE_CHECKING from maltoolbox.attackgraph import AttackGraph, AttackGraphNode @@ -1434,3 +1435,79 @@ def test_active_defenses() -> None: sim.get_node('Creds:notDisclosed') in sim.sim_state.graph_state.pre_enabled_defenses ) + + +def test_compromise_order() -> None: + """Verify that the compromise order is correctly recorded""" + + scenario = Scenario.load_from_file( + 'tests/testdata/scenarios/socialEngineering_scenario.yml' + ) + sim = MalSimulator.from_scenario(scenario) + sim.register_defender('Defender1') + states = sim.reset() + + attacker_record = ( + {0: set(states['Attacker1'].performed_nodes_order[0])} + if len(states['Attacker1'].performed_nodes_order) > 0 + else {} + ) + defender_record = ( + {0: set(states['Defender1'].performed_nodes_order[0])} + if len(states['Defender1'].performed_nodes_order) > 0 + else {} + ) + + for i in range(1, 101): + actions: dict[str, list[AttackGraphNode]] = {} + if len(states['Attacker1'].action_surface) == 0 or random.random() < 0.5: + actions['Attacker1'] = [] + else: + actions['Attacker1'] = [ + random.choice(list(states['Attacker1'].action_surface)) + ] + + if len(states['Defender1'].action_surface) == 0 or random.random() < 0.3: + actions['Defender1'] = [] + else: + actions['Defender1'] = [ + random.choice(list(states['Defender1'].action_surface)) + ] + states = sim.step(actions) + if len(sim.recording[i]['Attacker1']) > 0: + attacker_record[i] = set(sim.recording[i]['Attacker1']) + if len(sim.recording[i]['Defender1']) > 0: + defender_record[i] = set(sim.recording[i]['Defender1']) + if sim.done(): + break + + for i in range(max(max(attacker_record.keys()), max(defender_record.keys())) + 1): + if i in attacker_record and i in states['Attacker1'].performed_nodes_order: + assert attacker_record[i] == states['Attacker1'].performed_nodes_order[i], ( + f'Attacker record does not match simulator at time {i}' + ) + elif i in attacker_record: + assert False, ( + f'Attacker record has steps for time {i} but simulator does not' + ) + elif i in states['Attacker1'].performed_nodes_order: + assert False, ( + f'Simulator has steps for time {i} but attacker record does not' + ) + + for i in range(100): + if i in defender_record and i in states['Defender1'].performed_nodes_order: + assert defender_record[i] == states['Defender1'].performed_nodes_order[i], ( + f'Defender record does not match simulator at time {i}' + ) + elif i in defender_record: + assert False, ( + f'Defender record has steps for time {i} but simulator does not' + ) + elif i in states['Defender1'].performed_nodes_order: + assert False, ( + f'Simulator has steps for time {i} but defender record does not' + ) + + assert states['Attacker1'].performed_nodes_order == attacker_record + assert states['Defender1'].performed_nodes_order == defender_record