From e1d5855cf505cf0bc9aeec1f30e8554dc8f2b552 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Wed, 10 Jul 2024 11:52:58 +0200 Subject: [PATCH 1/3] First attempt --- utils/trajectory_analysis.py | 96 +++++++++++++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 1 deletion(-) diff --git a/utils/trajectory_analysis.py b/utils/trajectory_analysis.py index 355960d..4c54a6f 100644 --- a/utils/trajectory_analysis.py +++ b/utils/trajectory_analysis.py @@ -420,7 +420,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non fig, ax = plt.subplots() _ = ax.imshow(transitions) ax.set_xticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) - ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) + ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. @@ -433,6 +433,100 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non fig.tight_layout() fig.savefig(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"), dpi=600) +def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason=None)->dict: + idx_mapping = { + "Start":0, + ActionType.ScanNetwork:1, + ActionType.FindServices:2, + ActionType.FindData:3, + ActionType.ExploitService:4, + ActionType.ExfiltrateData:5, + } + counts = { + "Start":0, + ActionType.ScanNetwork:0, + ActionType.FindServices:0, + ActionType.FindData:0, + ActionType.ExploitService:0, + ActionType.ExfiltrateData:0, + } + transitions = np.zeros([len(counts), len(counts)]) + for play in game_plays: + if end_reason and play["end_reason"] not in end_reason: + continue + previous_action = "Start" + for action_dict in play["trajectory"]["actions"]: + counts[previous_action] += 1 + action = Action.from_dict(action_dict).type + transitions[idx_mapping[previous_action], 
idx_mapping[action]] += 1 + previous_action = action + # Create a list of unique labels + labels = list(idx_mapping.keys()) + + # Define colors for each node + node_colors = ['rgba(31, 119, 180, 0.8)', 'rgba(255, 127, 14, 0.8)', 'rgba(44, 160, 44, 0.8)', + 'rgba(214, 39, 40, 0.8)', 'rgba(148, 103, 189, 0.8)', 'rgba(170, 103, 189, 0.8)', 'rgba(255,221,51,0.8)'] + + # Convert the source and target lists to indices + source_indices = [labels.index(s) for s in idx_mapping.keys()] + target_indices = [labels.index(t) for t in idx_mapping.keys()] + + # Normalize the values to use for opacity + max_value = max(values) + opacities = [value / max_value for value in values] + + # Generate link colors based on source node colors and opacity + link_colors = [] + for s, opacity in zip(source_indices, opacities): + color = node_colors[s] + # Adjust the color's alpha value based on opacity + rgba_color = color[:-4] + f'{opacity})' + link_colors.append(rgba_color) + + # Create the Sankey diagram + fig = go.Figure(data=[go.Sankey( + node=dict( + pad=15, + thickness=20, + line=dict(color="black", width=0.5), + label=labels, + color=node_colors + ), + link=dict( + source=source_indices, + target=target_indices, + value=values, + color=link_colors + ) + )]) + + fig.update_layout(title_text="Sankey Diagram with Node and Link Colors", font_size=10) + fig.show() + + + + # make transitions probabilities + # for action_type, count in counts.items(): + # transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/count + # transitions = np.round(transitions, 2) + + # fig, ax = plt.subplots() + # _ = ax.imshow(transitions) + # ax.set_xticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) + # ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) + # plt.setp(ax.get_xticklabels(), rotation=45, ha="right", + # rotation_mode="anchor") + # # Loop over data dimensions and create text annotations. 
+ # for i in range(len(idx_mapping)): + # for j in range(len(idx_mapping)): + # _ = ax.text(j, i, transitions[i, j], + # ha="center", va="center", color="w") + + # ax.set_title(f"Visualization of MDP for {play['model']}") + # fig.tight_layout() + # fig.savefig(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"), dpi=600) + + def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple: edges = {} for play in game_plays: From e9db5056a0c4dac90f86d9343f7598df7513558d Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Wed, 10 Jul 2024 17:08:32 +0200 Subject: [PATCH 2/3] Code for sankey visualization of the MDP --- utils/trajectory_analysis.py | 167 +++++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 68 deletions(-) diff --git a/utils/trajectory_analysis.py b/utils/trajectory_analysis.py index 4c54a6f..df19d01 100644 --- a/utils/trajectory_analysis.py +++ b/utils/trajectory_analysis.py @@ -7,6 +7,7 @@ import matplotlib from mpl_toolkits.axes_grid1 import make_axes_locatable import umap +import plotly.graph_objects as go from sklearn.preprocessing import StandardScaler @@ -393,6 +394,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non ActionType.FindData:3, ActionType.ExploitService:4, ActionType.ExfiltrateData:5, + "Invalid":6 } counts = { "Start":0, @@ -401,6 +403,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non ActionType.FindData:0, ActionType.ExploitService:0, ActionType.ExfiltrateData:0, + "Invalid":0 } transitions = np.zeros([len(counts), len(counts)]) for play in game_plays: @@ -409,12 +412,15 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non previous_action = "Start" for action_dict in play["trajectory"]["actions"]: counts[previous_action] += 1 - action = Action.from_dict(action_dict).type + try: + action = Action.from_dict(action_dict).type + except ValueError: + action = "Invalid" 
transitions[idx_mapping[previous_action], idx_mapping[action]] += 1 previous_action = action # make transitions probabilities for action_type, count in counts.items(): - transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/count + transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/max(count,1) transitions = np.round(transitions, 2) fig, ax = plt.subplots() @@ -433,7 +439,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non fig.tight_layout() fig.savefig(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"), dpi=600) -def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason=None)->dict: +def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason=None, probs=True, threshold=0)->dict: idx_mapping = { "Start":0, ActionType.ScanNetwork:1, @@ -441,6 +447,7 @@ def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason= ActionType.FindData:3, ActionType.ExploitService:4, ActionType.ExfiltrateData:5, + "Invalid":6 } counts = { "Start":0, @@ -449,6 +456,7 @@ def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason= ActionType.FindData:0, ActionType.ExploitService:0, ActionType.ExfiltrateData:0, + "Invalid":0 } transitions = np.zeros([len(counts), len(counts)]) for play in game_plays: @@ -460,73 +468,92 @@ def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason= action = Action.from_dict(action_dict).type transitions[idx_mapping[previous_action], idx_mapping[action]] += 1 previous_action = action + if probs: + # convert values to probabilities + for action_type, count in counts.items(): + transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/max(count,1) + transitions = np.round(transitions, 2) + # Create a list of unique labels - labels = list(idx_mapping.keys()) - + labels = [str(x).lstrip("ActionType.") for x in idx_mapping.keys()] 
+ labels += labels[1:] # Define colors for each node - node_colors = ['rgba(31, 119, 180, 0.8)', 'rgba(255, 127, 14, 0.8)', 'rgba(44, 160, 44, 0.8)', - 'rgba(214, 39, 40, 0.8)', 'rgba(148, 103, 189, 0.8)', 'rgba(170, 103, 189, 0.8)', 'rgba(255,221,51,0.8)'] - - # Convert the source and target lists to indices - source_indices = [labels.index(s) for s in idx_mapping.keys()] - target_indices = [labels.index(t) for t in idx_mapping.keys()] - - # Normalize the values to use for opacity + node_colors = [ + 'rgba(255, 0, 0, 0.8)', + 'rgba(255, 153, 0, 0.8)', + 'rgba(0, 204, 0, 0.8)', + 'rgba(51, 204, 204, 0.8)', + 'rgba(51, 102, 255, 0.8)', + 'rgba(204, 204, 0, 0.8)', + 'rgba(255, 0, 102, 0.8)', + 'rgba(255, 153, 0, 0.8)', + 'rgba(0, 204, 0, 0.8)', + 'rgba(51, 204, 204, 0.8)', + 'rgba(51, 102, 255, 0.8)', + 'rgba(204, 204, 0, 0.8)', + 'rgba(255, 0, 102, 0.8)', + ] + + # use hard-coded edges for now + source_indices = [0,0,0,0,0,0,1,1,1,1,1,1,2,2,2,2,2,2,3,3,3,3,3,3,4,4,4,4,4,4,5,5,5,5,5,5,6,6,6,6,6,6,] + target_indices = [1,2,3,4,5,6,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,] + + # only show transitons with value higher than threshold + transitions = np.where(transitions < threshold, 0,transitions) + # no edge leads to start so we can skip the first column of the transition matrix + values = transitions[:,1:].flatten() max_value = max(values) - opacities = [value / max_value for value in values] + opacities = [min(1,1.5*value / max_value) for value in values] # Generate link colors based on source node colors and opacity link_colors = [] for s, opacity in zip(source_indices, opacities): - color = node_colors[s] - # Adjust the color's alpha value based on opacity - rgba_color = color[:-4] + f'{opacity})' - link_colors.append(rgba_color) - + color = node_colors[s] + # Adjust the color's alpha value based on opacity + rgba_color = color[:-4] + f'{opacity})' + link_colors.append(rgba_color) + + # generate node positions + 
valid_nodes = [k for (k,v) in counts.items() if v > 0] + # add start + x_pos = [0.001] + # first column + x_pos += [0.25 for i in range(len(valid_nodes)-1)] + # second column + x_pos += [0.999 for i in range(len(valid_nodes)-1)] + # start + y_pos = [0.5] + for _ in range(2): + for node_idx in range(0, len(valid_nodes)-1): + y_pos.append(0.001 + node_idx*0.999/(len(valid_nodes)-1)) # Create the Sankey diagram fig = go.Figure(data=[go.Sankey( - node=dict( - pad=15, - thickness=20, - line=dict(color="black", width=0.5), - label=labels, - color=node_colors - ), - link=dict( - source=source_indices, - target=target_indices, - value=values, - color=link_colors - ) + arrangement='snap', + valueformat = ".2f", + valuesuffix = "", + node=dict( + pad=5, + thickness=10, + line=dict(color="black", width=0.5), + label=labels, + color=node_colors, + x = x_pos, + y = y_pos, + ), + link=dict( + source=source_indices, + target=target_indices, + value=values, + color=link_colors, + arrowlen=5, + ) )]) - fig.update_layout(title_text="Sankey Diagram with Node and Link Colors", font_size=10) + fig.update_layout(title_text=f"ActionType Sankey Diagram - {play['model']}") + fig.write_image(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png")) + fig.update_layout(font_size=18) fig.show() - - - - # make transitions probabilities - # for action_type, count in counts.items(): - # transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/count - # transitions = np.round(transitions, 2) - - # fig, ax = plt.subplots() - # _ = ax.imshow(transitions) - # ax.set_xticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) - # ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys()) - # plt.setp(ax.get_xticklabels(), rotation=45, ha="right", - # rotation_mode="anchor") - # # Loop over data dimensions and create text annotations. 
- # for i in range(len(idx_mapping)): - # for j in range(len(idx_mapping)): - # _ = ax.text(j, i, transitions[i, j], - # ha="center", va="center", color="w") - - # ax.set_title(f"Visualization of MDP for {play['model']}") - # fig.tight_layout() - # fig.savefig(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"), dpi=600) - def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple: edges = {} for play in game_plays: @@ -616,19 +643,23 @@ def get_change_in_nodes(edge_list1, edge_list2): # filter trajectories based on their ending END_REASON = None #END_REASON = ["goal_reached"] - game_plays = read_json("./trajectories/2024-07-02_BaseAgent_Attacker.jsonl") + #game_plays = read_json("./trajectories/2024-07-03_QAgent_Attacker.jsonl") + game_plays = read_json("trajectories/2024-07-02_BaseAgent_Attacker.jsonl") for play in game_plays: play["model"] = "Optimal" - print(compute_mean_length(game_plays)) - get_action_type_barplot_per_step(game_plays, end_reason=END_REASON) - get_action_type_histogram_per_step(game_plays, end_reason=END_REASON) - generate_mdp_from_trajecotries(game_plays,filename="MDP_visualization_optimal", end_reason=END_REASON) - states = {} - actions = {} - edges_optimal = gameplay_graph(game_plays, states, actions,end_reason=END_REASON) - state_to_id = {v:k for k,v in states.items()} - action_to_id = {v:k for k,v in states.items()} - get_graph_stats(edges_optimal, state_to_id, action_to_id) + generate_mdp_from_trajecotries(game_plays, filename="mdp_test", end_reason=END_REASON) + generate_sankey_from_trajecotries(game_plays, filename="sankey_test", end_reason=END_REASON, threshold=0.1) + # print(compute_mean_length(game_plays)) + # get_action_type_barplot_per_step(game_plays, end_reason=END_REASON) + # get_action_type_histogram_per_step(game_plays, end_reason=END_REASON) + # generate_mdp_from_trajecotries(game_plays,filename="MDP_visualization_optimal", end_reason=END_REASON) + # states = {} + # actions = {} + 
# edges_optimal = gameplay_graph(game_plays, states, actions,end_reason=END_REASON) + # state_to_id = {v:k for k,v in states.items()} + # action_to_id = {v:k for k,v in states.items()} + # get_graph_stats(edges_optimal, state_to_id, action_to_id) + # # load trajectories from files # game_plays_q_learning = read_json("./NSG_trajectories_q_agent_marl.experiment0004-episodes-20000.json") # for play in game_plays_q_learning: From 44427501d678852c87ea74f8440ed85eecf14ca8 Mon Sep 17 00:00:00 2001 From: Ondrej Lukas Date: Wed, 10 Jul 2024 17:10:42 +0200 Subject: [PATCH 3/3] Visualization requirements --- requirements.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index 5069ca8..629ca25 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,12 +11,14 @@ iniconfig==2.0.0 Jinja2==3.1.3 jsonlines==4.0.0 jsonpickle==3.0.2 +kaleido==0.2.1 MarkupSafe==2.1.5 mypy-extensions==1.0.0 netaddr==1.2.1 networkx==3.2.1 numpy==1.26.4 packaging==23.2 +plotly==5.22.0 pluggy==1.4.0 plum-dispatch==2.2.2 py-flags==1.1.4