Skip to content

Commit

Permalink
Merge pull request #215 from stratosphereips/ondra-add-sankey-visualization
Browse files Browse the repository at this point in the history

Ondra add sankey visualization
  • Loading branch information
ondrej-lukas authored Jul 10, 2024

Verified

This commit was signed with the committer’s verified signature.
slarse Simon Larsén
2 parents 4fff6ea + 4442750 commit 4e11806
Showing 2 changed files with 141 additions and 14 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -11,12 +11,14 @@ iniconfig==2.0.0
Jinja2==3.1.3
jsonlines==4.0.0
jsonpickle==3.0.2
kaleido==0.2.1
MarkupSafe==2.1.5
mypy-extensions==1.0.0
netaddr==1.2.1
networkx==3.2.1
numpy==1.26.4
packaging==23.2
plotly==5.22.0
pluggy==1.4.0
plum-dispatch==2.2.2
py-flags==1.1.4
153 changes: 139 additions & 14 deletions utils/trajectory_analysis.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
import umap
import plotly.graph_objects as go
from sklearn.preprocessing import StandardScaler


@@ -393,6 +394,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
ActionType.FindData:3,
ActionType.ExploitService:4,
ActionType.ExfiltrateData:5,
"Invalid":6
}
counts = {
"Start":0,
@@ -401,6 +403,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
ActionType.FindData:0,
ActionType.ExploitService:0,
ActionType.ExfiltrateData:0,
"Invalid":0
}
transitions = np.zeros([len(counts), len(counts)])
for play in game_plays:
@@ -409,18 +412,21 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
previous_action = "Start"
for action_dict in play["trajectory"]["actions"]:
counts[previous_action] += 1
action = Action.from_dict(action_dict).type
try:
action = Action.from_dict(action_dict).type
except ValueError:
action = "Invalid"
transitions[idx_mapping[previous_action], idx_mapping[action]] += 1
previous_action = action
# make transitions probabilities
for action_type, count in counts.items():
transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/count
transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/max(count,1)
transitions = np.round(transitions, 2)

fig, ax = plt.subplots()
_ = ax.imshow(transitions)
ax.set_xticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys())
ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys())
ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys())
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
@@ -433,6 +439,121 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
fig.tight_layout()
fig.savefig(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"), dpi=600)

def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason=None, probs=True, threshold=0)->None:
    """
    Draw a Sankey diagram of action-type transitions, save it as PNG and show it.

    Every play contributes one transition per action: from the previous action
    type (or the artificial "Start" node for the first action) to the current
    one. The diagram has the "Start" node, one column of source nodes and one
    column of target nodes (the same action types repeated).

    Args:
        game_plays: iterable of dicts; each is read via play["end_reason"],
            play["trajectory"]["actions"] and (optionally) play["model"].
        filename: stem of the PNG written into the "figures" directory.
        end_reason: optional collection of end reasons; plays whose
            "end_reason" is not in it are skipped. Falsy means no filtering.
        probs: if True, each row of the transition matrix is divided by the
            number of times its source node was left and rounded to 2 decimals.
        threshold: transitions with a value below this are drawn with value 0.

    Returns:
        None. The figure is written to disk and displayed as a side effect.
    """
    idx_mapping = {
        "Start":0,
        ActionType.ScanNetwork:1,
        ActionType.FindServices:2,
        ActionType.FindData:3,
        ActionType.ExploitService:4,
        ActionType.ExfiltrateData:5,
        "Invalid":6
    }
    n_types = len(idx_mapping)
    counts = dict.fromkeys(idx_mapping, 0)
    transitions = np.zeros([n_types, n_types])
    # Remember the model name while iterating instead of relying on the loop
    # variable after the loop (which is undefined for empty input).
    model_name = "unknown"
    for play in game_plays:
        model_name = play.get("model", model_name)
        if end_reason and play["end_reason"] not in end_reason:
            continue
        previous_action = "Start"
        for action_dict in play["trajectory"]["actions"]:
            counts[previous_action] += 1
            # Same handling as in generate_mdp_from_trajecotries: actions that
            # cannot be parsed are collected in the "Invalid" node.
            try:
                action = Action.from_dict(action_dict).type
            except ValueError:
                action = "Invalid"
            transitions[idx_mapping[previous_action], idx_mapping[action]] += 1
            previous_action = action
    if probs:
        # convert values to probabilities; max(count, 1) keeps unused rows at 0
        for action_type, count in counts.items():
            row = idx_mapping[action_type]
            transitions[row] = transitions[row]/max(count,1)
        transitions = np.round(transitions, 2)

    # Node labels: strip the enum prefix as a prefix (str.lstrip would strip a
    # character set, not the literal "ActionType." string).
    prefix = "ActionType."
    labels = []
    for key in idx_mapping:
        name = str(key)
        labels.append(name[len(prefix):] if name.startswith(prefix) else name)
    # second column repeats every node except "Start"
    labels += labels[1:]

    # One colour per action type; the target column reuses the source colours.
    base_colors = [
        'rgba(255, 0, 0, 0.8)',
        'rgba(255, 153, 0, 0.8)',
        'rgba(0, 204, 0, 0.8)',
        'rgba(51, 204, 204, 0.8)',
        'rgba(51, 102, 255, 0.8)',
        'rgba(204, 204, 0, 0.8)',
        'rgba(255, 0, 102, 0.8)',
    ]
    node_colors = base_colors + base_colors[1:]

    # Edges, in the same row-major order as `values` below:
    #   "Start" (node 0)        -> first-column nodes 1 .. n_types-1
    #   first-column node s > 0 -> second-column nodes n_types .. 2*n_types-2
    source_indices = [s for s in range(n_types) for _ in range(n_types - 1)]
    target_indices = list(range(1, n_types))
    for _ in range(1, n_types):
        target_indices.extend(range(n_types, 2*n_types - 1))

    # only show transitions with value higher than threshold
    transitions = np.where(transitions < threshold, 0, transitions)
    # no edge leads to start so we can skip the first column of the transition matrix
    values = transitions[:,1:].flatten()
    max_value = float(values.max()) if values.size else 0.0
    if max_value > 0:
        opacities = [min(1, 1.5*float(value)/max_value) for value in values]
    else:
        # nothing to draw; avoid 0/0 producing NaN opacities
        opacities = [0.0 for _ in values]

    # Link colour = colour of the source node with the alpha channel replaced.
    link_colors = []
    for s, opacity in zip(source_indices, opacities):
        rgb_part = node_colors[s].rsplit(",", 1)[0]
        link_colors.append(f"{rgb_part}, {opacity})")

    # generate node positions
    # NOTE(review): positions are generated only for node types that were left
    # at least once, but are assigned to nodes by index - confirm this matches
    # how plotly handles unlinked nodes when some action types never occur.
    valid_nodes = [k for (k,v) in counts.items() if v > 0]
    per_column = max(len(valid_nodes) - 1, 0)
    # start, first column, second column
    x_pos = [0.001] + [0.25]*per_column + [0.999]*per_column
    column_y = [0.001 + node_idx*0.999/per_column for node_idx in range(per_column)]
    y_pos = [0.5] + column_y + column_y

    # Create the Sankey diagram
    fig = go.Figure(data=[go.Sankey(
        arrangement='snap',
        valueformat = ".2f",
        valuesuffix = "",
        node=dict(
            pad=5,
            thickness=10,
            line=dict(color="black", width=0.5),
            label=labels,
            color=node_colors,
            x = x_pos,
            y = y_pos,
        ),
        link=dict(
            source=source_indices,
            target=target_indices,
            value=values,
            color=link_colors,
            arrowlen=5,
        )
    )])

    fig.update_layout(title_text=f"ActionType Sankey Diagram - {model_name}")
    # Use the function argument, not the module-level END_REASON, so the
    # function also works when imported from another module.
    os.makedirs("figures", exist_ok=True)
    fig.write_image(os.path.join("figures", f"{filename}_{end_reason if end_reason else ''}.png"))
    # The larger font is applied after saving, so it only affects the
    # interactive view opened by fig.show().
    fig.update_layout(font_size=18)
    fig.show()

def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple:
edges = {}
for play in game_plays:
@@ -522,19 +643,23 @@ def get_change_in_nodes(edge_list1, edge_list2):
# filter trajectories based on their ending
END_REASON = None
#END_REASON = ["goal_reached"]
game_plays = read_json("./trajectories/2024-07-02_BaseAgent_Attacker.jsonl")
#game_plays = read_json("./trajectories/2024-07-03_QAgent_Attacker.jsonl")
game_plays = read_json("trajectories/2024-07-02_BaseAgent_Attacker.jsonl")
for play in game_plays:
play["model"] = "Optimal"
print(compute_mean_length(game_plays))
get_action_type_barplot_per_step(game_plays, end_reason=END_REASON)
get_action_type_histogram_per_step(game_plays, end_reason=END_REASON)
generate_mdp_from_trajecotries(game_plays,filename="MDP_visualization_optimal", end_reason=END_REASON)
states = {}
actions = {}
edges_optimal = gameplay_graph(game_plays, states, actions,end_reason=END_REASON)
state_to_id = {v:k for k,v in states.items()}
action_to_id = {v:k for k,v in states.items()}
get_graph_stats(edges_optimal, state_to_id, action_to_id)
generate_mdp_from_trajecotries(game_plays, filename="mdp_test", end_reason=END_REASON)
generate_sankey_from_trajecotries(game_plays, filename="sankey_test", end_reason=END_REASON, threshold=0.1)
# print(compute_mean_length(game_plays))
# get_action_type_barplot_per_step(game_plays, end_reason=END_REASON)
# get_action_type_histogram_per_step(game_plays, end_reason=END_REASON)
# generate_mdp_from_trajecotries(game_plays,filename="MDP_visualization_optimal", end_reason=END_REASON)
# states = {}
# actions = {}
# edges_optimal = gameplay_graph(game_plays, states, actions,end_reason=END_REASON)
# state_to_id = {v:k for k,v in states.items()}
# action_to_id = {v:k for k,v in states.items()}
# get_graph_stats(edges_optimal, state_to_id, action_to_id)

# # load trajectories from files
# game_plays_q_learning = read_json("./NSG_trajectories_q_agent_marl.experiment0004-episodes-20000.json")
# for play in game_plays_q_learning:

0 comments on commit 4e11806

Please sign in to comment.