Skip to content
24 changes: 24 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: MyPy maltoolbox
on: [push]

jobs:

build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[dev] .[ml]
- name: Type checking with MyPy
run: |
pip install mypy
mypy malsim tests --install-types --non-interactive
9 changes: 5 additions & 4 deletions malsim/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import logging
from typing import Any, Optional

from .mal_simulator import MalSimulator
from .agents import DecisionAgent
Expand All @@ -12,11 +13,11 @@
logger = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.INFO)

def run_simulation(sim: MalSimulator, agents: list[dict]):
def run_simulation(sim: MalSimulator, agents: list[dict[str, Any]]) -> None:
"""Run a simulation with agents"""

sim.reset()
total_rewards = {agent_dict['name']: 0 for agent_dict in agents}
total_rewards = {agent_dict['name']: 0.0 for agent_dict in agents}
all_agents_term_or_trunc = False

logger.info("Starting CLI env simulator.")
Expand All @@ -29,7 +30,7 @@ def run_simulation(sim: MalSimulator, agents: list[dict]):

# Select actions for each agent
for agent_dict in agents:
decision_agent: DecisionAgent = agent_dict.get('agent')
decision_agent: Optional[DecisionAgent] = agent_dict.get('agent')
agent_name = agent_dict['name']
if decision_agent is None:
logger.warning(
Expand Down Expand Up @@ -66,7 +67,7 @@ def run_simulation(sim: MalSimulator, agents: list[dict]):
agent_name = agent_dict['name']
print(f'Total reward "{agent_name}"', total_rewards[agent_name])

def main():
def main() -> None:
"""Entrypoint function of the MAL Toolbox CLI"""
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions malsim/agents/decision_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A decision agent is a heuristic agent"""

from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Any
from abc import ABC, abstractmethod

if TYPE_CHECKING:
Expand All @@ -14,7 +14,7 @@ class DecisionAgent(ABC):
def get_next_action(
self,
agent_state: MalSimAgentStateView,
**kwargs
**kwargs: Any
) -> Optional[AttackGraphNode]:
"""
Select next action the agent will work with.
Expand Down
10 changes: 5 additions & 5 deletions malsim/agents/heuristic_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import logging
import math

Expand All @@ -16,7 +16,7 @@
class DefendCompromisedDefender(DecisionAgent):
"""A defender that defends compromised assets using notPresent"""

def __init__(self, agent_config, **_):
def __init__(self, agent_config: dict[str, Any], **_: Any):
# Seed and rng not currently used
seed = (
agent_config["seed"]
Expand All @@ -30,7 +30,7 @@ def __init__(self, agent_config, **_):
)

def get_next_action(
self, agent_state: MalSimAgentStateView, **kwargs
self, agent_state: MalSimAgentStateView, **kwargs: Any
) -> Optional[AttackGraphNode]:

"""Return an action that disables a compromised node"""
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_next_action(
class DefendFutureCompromisedDefender(DecisionAgent):
"""A defender that defends compromised assets using notPresent"""

def __init__(self, agent_config, **_):
def __init__(self, agent_config: dict[str, Any], **_: Any):
# Seed and rng not currently used
seed = (
agent_config["seed"]
Expand All @@ -85,7 +85,7 @@ def __init__(self, agent_config, **_):
)

def get_next_action(
self, agent_state: MalSimAgentStateView, **kwargs
self, agent_state: MalSimAgentStateView, **kwargs: Any
) -> Optional[AttackGraphNode]:

"""Return an action that disables a compromised node"""
Expand Down
10 changes: 5 additions & 5 deletions malsim/agents/keyboard_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Any

from .decision_agent import DecisionAgent
from ..mal_simulator import MalSimAgentStateView
Expand All @@ -13,14 +13,14 @@
class KeyboardAgent(DecisionAgent):
"""An agent that makes decisions by asking user for keyboard input"""

def __init__(self, _, **kwargs):
def __init__(self, _: Any, **kwargs: Any):
super().__init__(**kwargs)
logger.info("Creating KeyboardAgent")

def get_next_action(
self,
agent_state: MalSimAgentStateView,
**kwargs
**kwargs: Any
) -> Optional[AttackGraphNode]:
"""Compute action from action_surface"""

Expand All @@ -35,13 +35,13 @@ def valid_action(user_input: str) -> bool:

return 0 <= node <= len(agent_state.action_surface)

def get_action_object(user_input: str) -> tuple:
def get_action_object(user_input: str) -> Optional[int]:
node = int(user_input) if user_input != "" else None
return node

if not agent_state.action_surface:
print("No actions to pick for defender")
return []
return None

index_to_node = dict(enumerate(agent_state.action_surface))
user_input = "xxx"
Expand Down
6 changes: 3 additions & 3 deletions malsim/agents/passive_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A passive agent that always choose to do nothing"""

from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Any

from .decision_agent import DecisionAgent
from ..mal_simulator import MalSimAgentStateView
Expand All @@ -11,13 +11,13 @@
from maltoolbox.attackgraph import AttackGraphNode

class PassiveAgent(DecisionAgent):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
...

def get_next_action(
self,
agent_state: MalSimAgentStateView,
**kwargs
**kwargs: Any
) -> Optional[AttackGraphNode]:
# A passive agent never does anything
return None
6 changes: 3 additions & 3 deletions malsim/agents/searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class BreadthFirstAttacker(DecisionAgent):
'seed': None,
}

def __init__(self, agent_config: dict) -> None:
def __init__(self, agent_config: dict[str, Any]) -> None:
"""Initialize a BFS/DFS agent.

Args:
Expand All @@ -47,7 +47,7 @@ def __init__(self, agent_config: dict) -> None:
self._started = False

def get_next_action(
self, agent_state: MalSimAgentStateView, **kwargs
self, agent_state: MalSimAgentStateView, **kwargs: Any
) -> Optional[AttackGraphNode]:
"""Receive the next action according to agent policy (bfs/dfs)"""

Expand All @@ -69,7 +69,7 @@ def _update_targets(
self,
new_nodes: set[AttackGraphNode],
disabled_nodes: set[AttackGraphNode],
):
) -> None:
new_targets: list[AttackGraphNode] = []
if self._settings['seed']:
# If a seed is set, we assume the user wants determinism in the
Expand Down
2 changes: 1 addition & 1 deletion malsim/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .malsim_vectorized_obs_env import MalSimVectorizedObsEnv
from .gym_envs import AttackerEnv, DefenderEnv, register_envs
from .gym_envs import AttackerEnv, DefenderEnv, register_envs # type: ignore

# not needed, used to silence ruff F401
__all__ = [
Expand Down
15 changes: 10 additions & 5 deletions malsim/envs/base_classes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from ..mal_simulator import MalSimulator, MalSimAgentStateView

class MalSimEnv(ABC):
Expand All @@ -7,24 +8,28 @@ def __init__(self, sim: MalSimulator):
self.sim = sim

@abstractmethod
def step(self, actions):
def step(self, actions: Any) -> Any:
...

def reset(self, seed=None, options=None):
def reset(
self,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None
) -> None:
self.sim.reset(seed=seed, options=options)

def register_attacker(
self, attacker_name: str, attacker_id: int
):
) -> None:
self.sim.register_attacker(attacker_name, attacker_id)

def register_defender(
self, defender_name: str
):
) -> None:
self.sim.register_defender(defender_name)

def get_agent_state(self, agent_name: str) -> MalSimAgentStateView:
return self.sim.agent_states[agent_name]

def render(self):
def render(self) -> None:
pass
Loading