diff --git a/requirements.txt b/requirements.txt
index 5069ca8..629ca25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,12 +11,14 @@ iniconfig==2.0.0
 Jinja2==3.1.3
 jsonlines==4.0.0
 jsonpickle==3.0.2
+kaleido==0.2.1
 MarkupSafe==2.1.5
 mypy-extensions==1.0.0
 netaddr==1.2.1
 networkx==3.2.1
 numpy==1.26.4
 packaging==23.2
+plotly==5.22.0
 pluggy==1.4.0
 plum-dispatch==2.2.2
 py-flags==1.1.4
diff --git a/utils/trajectory_analysis.py b/utils/trajectory_analysis.py
index 355960d..df19d01 100644
--- a/utils/trajectory_analysis.py
+++ b/utils/trajectory_analysis.py
@@ -7,6 +7,7 @@ import matplotlib
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 import umap
+import plotly.graph_objects as go
 from sklearn.preprocessing import StandardScaler
@@ -393,6 +394,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
         ActionType.FindData:3,
         ActionType.ExploitService:4,
         ActionType.ExfiltrateData:5,
+        "Invalid":6
     }
     counts = {
         "Start":0,
@@ -401,6 +403,7 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
         ActionType.FindData:0,
         ActionType.ExploitService:0,
         ActionType.ExfiltrateData:0,
+        "Invalid":0
     }
     transitions = np.zeros([len(counts), len(counts)])
     for play in game_plays:
@@ -409,18 +412,21 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
         previous_action = "Start"
         for action_dict in play["trajectory"]["actions"]:
             counts[previous_action] += 1
-            action = Action.from_dict(action_dict).type
+            try:
+                action = Action.from_dict(action_dict).type
+            except ValueError:
+                action = "Invalid"
             transitions[idx_mapping[previous_action], idx_mapping[action]] += 1
             previous_action = action
     # make transitions probabilities
     for action_type, count in counts.items():
-        transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/count
+        transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/max(count,1)
     transitions = np.round(transitions, 2)
     fig, ax = plt.subplots()
     _ = ax.imshow(transitions)
     ax.set_xticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys())
-    ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys())
+    ax.set_yticks(np.arange(len(idx_mapping)), labels=idx_mapping.keys())
     plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
     # Loop over data dimensions and create text annotations.
@@ -433,6 +439,121 @@ def generate_mdp_from_trajecotries(game_plays:list, filename:str, end_reason=Non
     fig.tight_layout()
     fig.savefig(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"), dpi=600)
+def generate_sankey_from_trajecotries(game_plays:list, filename:str, end_reason=None, probs=True, threshold=0)->dict:
+    idx_mapping = {
+        "Start":0,
+        ActionType.ScanNetwork:1,
+        ActionType.FindServices:2,
+        ActionType.FindData:3,
+        ActionType.ExploitService:4,
+        ActionType.ExfiltrateData:5,
+        "Invalid":6
+    }
+    counts = {
+        "Start":0,
+        ActionType.ScanNetwork:0,
+        ActionType.FindServices:0,
+        ActionType.FindData:0,
+        ActionType.ExploitService:0,
+        ActionType.ExfiltrateData:0,
+        "Invalid":0
+    }
+    transitions = np.zeros([len(counts), len(counts)])
+    for play in game_plays:
+        if end_reason and play["end_reason"] not in end_reason:
+            continue
+        previous_action = "Start"
+        for action_dict in play["trajectory"]["actions"]:
+            counts[previous_action] += 1
+            action = Action.from_dict(action_dict).type
+            transitions[idx_mapping[previous_action], idx_mapping[action]] += 1
+            previous_action = action
+    if probs:
+        # convert values to probabilities
+        for action_type, count in counts.items():
+            transitions[idx_mapping[action_type]] = transitions[idx_mapping[action_type]]/max(count,1)
+        transitions = np.round(transitions, 2)
+
+    # Create a list of unique labels
+    labels = [str(x).lstrip("ActionType.") for x in idx_mapping.keys()]
+    labels += labels[1:]
+    # Define colors for each node
+    node_colors = [
+        'rgba(255, 0, 0, 0.8)',
+        'rgba(255, 153, 0, 0.8)',
+        'rgba(0, 204, 0, 0.8)',
+        'rgba(51, 204, 204, 0.8)',
+        'rgba(51, 102, 255, 0.8)',
+        'rgba(204, 204, 0, 0.8)',
+        'rgba(255, 0, 102, 0.8)',
+        'rgba(255, 153, 0, 0.8)',
+        'rgba(0, 204, 0, 0.8)',
+        'rgba(51, 204, 204, 0.8)',
+        'rgba(51, 102, 255, 0.8)',
+        'rgba(204, 204, 0, 0.8)',
+        'rgba(255, 0, 102, 0.8)',
+    ]
+
+    # use hard-coded edges for now
+    source_indices = [0,0,0,0,0,0,1,1,1,1,1,1,2,2,2,2,2,2,3,3,3,3,3,3,4,4,4,4,4,4,5,5,5,5,5,5,6,6,6,6,6,6,]
+    target_indices = [1,2,3,4,5,6,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,7,8,9,10,11,12,]
+
+    # only show transitions with value higher than threshold
+    transitions = np.where(transitions < threshold, 0,transitions)
+    # no edge leads to start so we can skip the first column of the transition matrix
+    values = transitions[:,1:].flatten()
+    max_value = max(values)
+    opacities = [min(1,1.5*value / max_value) for value in values]
+
+    # Generate link colors based on source node colors and opacity
+    link_colors = []
+    for s, opacity in zip(source_indices, opacities):
+        color = node_colors[s]
+        # Adjust the color's alpha value based on opacity
+        rgba_color = color[:-4] + f'{opacity})'
+        link_colors.append(rgba_color)
+
+    # generate node positions
+    valid_nodes = [k for (k,v) in counts.items() if v > 0]
+    # add start
+    x_pos = [0.001]
+    # first column
+    x_pos += [0.25 for i in range(len(valid_nodes)-1)]
+    # second column
+    x_pos += [0.999 for i in range(len(valid_nodes)-1)]
+    # start
+    y_pos = [0.5]
+    for _ in range(2):
+        for node_idx in range(0, len(valid_nodes)-1):
+            y_pos.append(0.001 + node_idx*0.999/(len(valid_nodes)-1))
+    # Create the Sankey diagram
+    fig = go.Figure(data=[go.Sankey(
+        arrangement='snap',
+        valueformat = ".2f",
+        valuesuffix = "",
+        node=dict(
+            pad=5,
+            thickness=10,
+            line=dict(color="black", width=0.5),
+            label=labels,
+            color=node_colors,
+            x = x_pos,
+            y = y_pos,
+        ),
+        link=dict(
+            source=source_indices,
+            target=target_indices,
+            value=values,
+            color=link_colors,
+            arrowlen=5,
+        )
+    )])
+
+    fig.update_layout(title_text=f"ActionType Sankey Diagram - {play['model']}")
+    fig.write_image(os.path.join("figures", f"{filename}_{END_REASON if end_reason else ''}.png"))
+    fig.update_layout(font_size=18)
+    fig.show()
+
 def gameplay_graph(game_plays:list, states, actions, end_reason=None)->tuple:
     edges = {}
     for play in game_plays:
@@ -522,19 +643,23 @@ def get_change_in_nodes(edge_list1, edge_list2):
     # filter trajectories based on their ending
     END_REASON = None
     #END_REASON = ["goal_reached"]
-    game_plays = read_json("./trajectories/2024-07-02_BaseAgent_Attacker.jsonl")
+    #game_plays = read_json("./trajectories/2024-07-03_QAgent_Attacker.jsonl")
+    game_plays = read_json("trajectories/2024-07-02_BaseAgent_Attacker.jsonl")
     for play in game_plays:
         play["model"] = "Optimal"
-    print(compute_mean_length(game_plays))
-    get_action_type_barplot_per_step(game_plays, end_reason=END_REASON)
-    get_action_type_histogram_per_step(game_plays, end_reason=END_REASON)
-    generate_mdp_from_trajecotries(game_plays,filename="MDP_visualization_optimal", end_reason=END_REASON)
-    states = {}
-    actions = {}
-    edges_optimal = gameplay_graph(game_plays, states, actions,end_reason=END_REASON)
-    state_to_id = {v:k for k,v in states.items()}
-    action_to_id = {v:k for k,v in states.items()}
-    get_graph_stats(edges_optimal, state_to_id, action_to_id)
+    generate_mdp_from_trajecotries(game_plays, filename="mdp_test", end_reason=END_REASON)
+    generate_sankey_from_trajecotries(game_plays, filename="sankey_test", end_reason=END_REASON, threshold=0.1)
+    # print(compute_mean_length(game_plays))
+    # get_action_type_barplot_per_step(game_plays, end_reason=END_REASON)
+    # get_action_type_histogram_per_step(game_plays, end_reason=END_REASON)
+    # generate_mdp_from_trajecotries(game_plays,filename="MDP_visualization_optimal", end_reason=END_REASON)
+    # states = {}
+    # actions = {}
+    # edges_optimal = gameplay_graph(game_plays, states, actions,end_reason=END_REASON)
+    # state_to_id = {v:k for k,v in states.items()}
+    # action_to_id = {v:k for k,v in states.items()}
+    # get_graph_stats(edges_optimal, state_to_id, action_to_id)
+
     # # load trajectories from files
     # game_plays_q_learning = read_json("./NSG_trajectories_q_agent_marl.experiment0004-episodes-20000.json")
     # for play in game_plays_q_learning:
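
For context on the hard-coded source_indices/target_indices in the new generate_sankey_from_trajecotries: the diagram uses a two-column node layout, so each action type appears twice in the label list (labels += labels[1:]), once as a "from" node on the left and once as a "to" node on the right, and entry [i, j] of the transition matrix becomes a link from left node i to right node len(left)+j-1 (the "Start" column is skipped because nothing flows back into it). The following minimal, standalone sketch is not taken from the patch; the toy matrix and label names are made up purely to illustrate the same encoding with plotly's go.Sankey.

# Standalone illustration of the two-column Sankey encoding (toy data, hypothetical names).
import numpy as np
import plotly.graph_objects as go

# Toy transition matrix over {"Start", "ScanNetwork", "FindServices"};
# row = previous action type, column = next action type (no edge leads back to "Start").
transitions = np.array([
    [0.0, 0.8, 0.2],
    [0.0, 0.1, 0.9],
    [0.0, 0.5, 0.5],
])
left_labels = ["Start", "ScanNetwork", "FindServices"]
right_labels = left_labels[1:]        # nothing flows into "Start"
labels = left_labels + right_labels   # indices 0..2 = left column, 3..4 = right column

sources, targets, values = [], [], []
for i in range(transitions.shape[0]):
    for j in range(1, transitions.shape[1]):           # skip the "Start" column
        if transitions[i, j] > 0:
            sources.append(i)                           # left-column node
            targets.append(len(left_labels) + j - 1)    # matching right-column node
            values.append(float(transitions[i, j]))

fig = go.Figure(data=[go.Sankey(
    node=dict(pad=5, thickness=10, label=labels),
    link=dict(source=sources, target=targets, value=values),
)])
fig.show()  # exporting to PNG via fig.write_image additionally needs the kaleido package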