Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Visual Threat Model Evaluator #16

Merged
merged 7 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ appPackage/build
.conda

/*.png
/*.svg
/*.svg
*.db
172 changes: 93 additions & 79 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ python = ">=3.12.4,<3.13"
teams-ai = "^1.2.1"
python-dotenv = "^1.0.1"
aiohttp = "3.9.5"
pyautogen = {version="0.2.28", extras=["llm", "retrievechat"]}
pyautogen = {version="0.2.34", extras=["llm", "retrievechat"]}
botbuilder-azure = "^4.15.1"
azure-identity = "^1.17.1"
pillow = "10.3.0"
Expand Down
17 changes: 17 additions & 0 deletions src/autogen_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Callable, Any, Optional
from autogen import ConversableAgent, register_function
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability

class ImmediateExecutorCapability(AgentCapability):
def __init__(self):
super().__init__()

def add_to_agent(self, caller_agent: ConversableAgent, f: Callable[..., Any], description: str, name: Optional[str] = None):
register_function(f, caller=caller_agent, executor=caller_agent, description=description, name=name)
caller_agent.register_hook('process_message_before_send', self._process_message_before_send)

def _process_message_before_send(self, message, sender: ConversableAgent, recipient, silent):
if isinstance(message, dict):
_, res = sender.generate_tool_calls_reply([message], sender)
return res.get("content") if isinstance(res, dict) else "Answered"
return message
14 changes: 14 additions & 0 deletions src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,20 @@ async def on_login(context: TurnContext, state: AppTurnState):

return True

@app.message('/useVisual')
async def set_to_visual(context: TurnContext, state: AppTurnState):
state.conversation.use_xml_evaluator = False
await state.save(context)
await context.send_activity("Ready to use visual evaluator")
return True

@app.message('/useXML')
async def set_to_xml(context: TurnContext, state: AppTurnState):
state.conversation.use_xml_evaluator = True
await state.save(context)
await context.send_activity("Ready to use XML evaluator")
return True


@app.turn_state_factory
async def turn_state_factory(context: TurnContext):
Expand Down
36 changes: 4 additions & 32 deletions src/privacy_review_assistant_group.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from autogen import AssistantAgent, GroupChat, GroupChatManager, Agent
from autogen import AssistantAgent, GroupChat, Agent
from botbuilder.core import TurnContext

from state import AppTurnState
from rag_agents import setup_rag_assistant
from threat_model_reviewer_group import ThreatModelReviewerGroup
from visualizer_agent import setup_visualizer_agent
from threat_model_visualizer import ThreatModelImageVisualizerCapability
from xml_threat_model_reviewer import setup_xml_threat_model_reviewer

USE_XML_ASSISTANT=True

class PrivacyReviewAssistantGroup:
def __init__(self, llm_config):
self.llm_config = llm_config

def group_chat_builder(self, context: TurnContext, state: AppTurnState, user_agent: Agent) -> GroupChat:
use_xml_assistant = state.conversation.use_xml_evaluator
rag_assistant = setup_rag_assistant(self.llm_config)
threat_modeling_assistant = setup_xml_threat_model_reviewer(self.llm_config, context, state) if USE_XML_ASSISTANT else self.setup_threat_modeling_assistant(context, state, user_agent)
threat_modeling_assistant = setup_xml_threat_model_reviewer(self.llm_config, context, state) if use_xml_assistant else setup_visualizer_agent(self.llm_config, context, state)
visualizer_agent = self.setup_visualizer_assistant(context, state, user_agent)
group = GroupChat(
agents=[user_agent, rag_assistant, visualizer_agent, threat_modeling_assistant],
Expand All @@ -31,33 +30,6 @@ def group_chat_builder(self, context: TurnContext, state: AppTurnState, user_age
)

return group

def setup_threat_modeling_assistant(self, context: TurnContext, state: AppTurnState, user_agent: Agent) -> Agent:
def terminate_chat(message):
message_sender_name = message.get("name", "")
return message_sender_name != user_agent.name
assistant = AssistantAgent(
name="Threat_Model_Evaluator",
description="An agent that manages a group chat for threat modeling validation and evaluation.",
is_termination_msg=terminate_chat
)

threat_modeling_group = ThreatModelReviewerGroup(llm_config=self.llm_config).group_chat_builder(context, state, assistant)
threat_modeling_group_manager = GroupChatManager(
groupchat=threat_modeling_group,
llm_config=self.llm_config,
)
def trigger(sender):
return sender not in [assistant]
assistant.register_nested_chats([
{
"recipient": threat_modeling_group_manager,
"sender": assistant,
"summary_method": "last_msg",
"max_turns": 1,
},
], trigger=trigger)
return assistant

def setup_visualizer_assistant(self, _context: TurnContext, state: AppTurnState, _user_agent: Agent) -> Agent:
visualizer_assistant = AssistantAgent(
Expand Down
8 changes: 5 additions & 3 deletions src/rag_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,10 @@ def message_generator(_recipient, messages, sender, _config):
"recipient": rag_assistant_agent,
"sender": rag_proxy_agent,
"summary_method": "last_msg",
"message": message_generator
"message": message_generator,
"chat_id": 1,
},
], trigger=assistant)
], trigger=assistant, use_async=True)

def trigger(sender):
return sender not in [assistant] # To prevent the assistant from triggering itself
Expand All @@ -166,7 +167,8 @@ def custom_summary_method(
"sender": assistant,
"summary_method": custom_summary_method,
"max_turns": 1,
"chat_id": 1,
},
], trigger=trigger)
], trigger=trigger, use_async=True)

return assistant
11 changes: 5 additions & 6 deletions src/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
Licensed under the MIT License.
"""

from typing import Optional, List, Dict
from typing import Optional, List, Dict, Union

from botbuilder.core import Storage, TurnContext
from teams.state import ConversationState, TempState, TurnState, UserState
from autogen import ConversableAgent
from datetime import datetime

class AppConversationState(ConversationState):
message_history: List[Dict] | None = None
message_history: Optional[List[Dict]] = None
is_waiting_for_user_input: bool = False
started_waiting_for_user_input_at: datetime | str | None = None
spec_url: str | None = None
started_waiting_for_user_input_at: Optional[Union[datetime, str]] = None
spec_url: Optional[str] = None
use_xml_evaluator: bool = True

@classmethod
async def load(
Expand All @@ -29,7 +29,6 @@ async def clear(self, context: TurnContext) -> None:
self.spec_url = None
await self.save(context)


class AppTurnState(TurnState[AppConversationState, UserState, TempState]):
conversation: AppConversationState

Expand Down
99 changes: 68 additions & 31 deletions src/svg_to_png/lib/ThreatModel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional, Literal, Dict, List
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
import re
Expand All @@ -22,6 +23,18 @@ def print_all(iter):
ABSTRACTS_XMLNS = "{http://schemas.datacontract.org/2004/07/ThreatModeling.Model.Abstracts}"
ARRAY_XMLNS = "{http://schemas.microsoft.com/2003/10/Serialization/Arrays}"
KNOWLEDGE_BASE_XMLNS = "{http://schemas.datacontract.org/2004/07/ThreatModeling.KnowledgeBase}"

type User_Friendly_Block_Types = Literal["Boundary", "Annotation", "External Interactor", "Node", "Data Store", "Data Flow", "Trust Boundary"]
type Key_Label_Map = Dict[User_Friendly_Block_Types, List[Dict[Literal["index", "key", "name"], str]]]
element_to_user_friendly_key: Dict[str, User_Friendly_Block_Types] = {
"GE.TB.B": "Boundary",
"GE.A": "Annotation",
"GE.EI": "External Interactor",
"GE.P": "Node",
"GE.DS": "Data Store",
"GE.DF": "Data Flow",
"GE.TB.L": "Trust Boundary"
}

def get_shape_details(shape):
height = shape.find(build_tag(ABSTRACTS_XMLNS, "Height")).text
Expand Down Expand Up @@ -74,6 +87,7 @@ def set_groups(nodes, boundaries):
set_appropriate_groups(node, boundary)

class ThreatModel:
key_label_map: Key_Label_Map

def add_element(self, el: Element, icons: dict):
generic_type_id = el.find(build_tag(ABSTRACTS_XMLNS, "GenericTypeId")).text
Expand All @@ -82,7 +96,6 @@ def add_element(self, el: Element, icons: dict):
type = any_type_properties[0][0].text
name = el.get('custom_key') if el.get('custom_key') else get_element_name(el)
if generic_type_id == "GE.DS":

shape = GenericDataStore(generic_type_id, type, name, icons, *get_shape_details(el))
self.nodes.append(shape)
elif generic_type_id == "GE.EI":
Expand All @@ -109,8 +122,15 @@ def add_element(self, el: Element, icons: dict):
shape = None


def __init__(self, file: str = None, svg_content: str = None):
if not file and not svg_content:
def __init__(self, file: Optional[str] = None, svg_content: Optional[str] = None, build_for_ai_context: bool = False):
ET.register_namespace(
'xmlns', 'http://schemas.datacontract.org/2004/07/ThreatModeling.Model')
if file:
tree = ET.parse(file)
root = tree.getroot()
elif svg_content:
root = ET.fromstring(svg_content)
else:
raise Exception("Either file or svg_content should be provided")

self.boundaries = []
Expand All @@ -119,14 +139,6 @@ def __init__(self, file: str = None, svg_content: str = None):
self.curves = []
self.annotations = []
self.trust_line_boundaries = []

ET.register_namespace(
'xmlns', 'http://schemas.datacontract.org/2004/07/ThreatModeling.Model')
if file:
tree = ET.parse(file)
root = tree.getroot()
else:
root = ET.fromstring(svg_content)

knowledgeBase = root.find(build_tag(
THREAT_MODELING_XMLNS, 'KnowledgeBase'))
Expand Down Expand Up @@ -161,34 +173,59 @@ def __init__(self, file: str = None, svg_content: str = None):
tab_lines = tab.find(build_tag(THREAT_MODELING_XMLNS, "Lines"))

borders = tab_borders.findall(build_tag(ARRAY_XMLNS, "KeyValueOfguidanyType"))
key_label_tuples = []
key_index = 0
key_label_map: Key_Label_Map = {}
element_to_key_index = {
"GE.TB.B": 0,
"GE.A": 0,
"GE.EI": 0,
"GE.P": 0,
"GE.DS": 0,
"GE.DF": 0,
"GE.TB.L": 0
}

for border in borders:
value = border.find(build_tag(ARRAY_XMLNS, "Value"))
if not value:
continue
generic_type_id = value.find(build_tag(ABSTRACTS_XMLNS, "GenericTypeId")).text
if generic_type_id == "GE.TB.B":
user_friendly_key = "Boundary"
key = f'Boundary {key_index + 1}'
self.generate_custom_key(build_for_ai_context, key_label_map, element_to_key_index, value, generic_type_id)
self.add_element(value, icons)

if tab_lines is not None:
lines = tab_lines.findall(build_tag(ARRAY_XMLNS, "KeyValueOfguidanyType"))
for line in lines:
value = line.find(build_tag(ARRAY_XMLNS, "Value"))
if not value:
continue
generic_type_id = value.find(build_tag(ABSTRACTS_XMLNS, "GenericTypeId")).text
self.generate_custom_key(build_for_ai_context, key_label_map, element_to_key_index, value, generic_type_id)
self.add_element(value, icons)

# sort all the key_label_map
for key in key_label_map:
key_label_map[key] = sorted(key_label_map[key], key=lambda x: x['index'])

self.key_label_map = key_label_map
set_groups(self.nodes, self.boundaries)

def generate_custom_key(self, build_for_ai_context, key_label_map, element_to_key_index, value, generic_type_id):
if build_for_ai_context:
if isinstance(generic_type_id, str):
user_friendly_key = element_to_user_friendly_key.get(generic_type_id, "Element")
key_index = element_to_key_index[generic_type_id]
element_to_key_index[generic_type_id] += 1
else:
if generic_type_id == "GE.A":
user_friendly_key = "Annotation"
else:
user_friendly_key = "Node"
self.add_element(value, icon)
raise ValueError(f"Unknown generic_type_id: {generic_type_id}")
key = f'{user_friendly_key} {key_index + 1}'
name = get_element_name(value)
key_label_tuples.append((key, name))
key_index += 1

lines = tab_lines.findall(build_tag(ARRAY_XMLNS, "KeyValueOfguidanyType"))
for line in lines:
value = line.find(build_tag(ARRAY_XMLNS, "Value"))
self.add_element(value, icons)

self.key_label_tuples = key_label_tuples
set_groups(self.nodes, self.boundaries)
key_label_map[user_friendly_key] = [] if not key_label_map.get(user_friendly_key) else key_label_map[user_friendly_key]
key_label_map[user_friendly_key].append({
"index": key_index + 1,
"key": key,
"name": name,
})
value.set('custom_key', key)

def convert_to_svg(self, d):
for boundary in self.boundaries:
Expand Down
11 changes: 6 additions & 5 deletions src/svg_to_png/svg_to_png.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional
import drawsvg as draw
from .lib.ThreatModel import ThreatModel
from .lib.utils import get_bbox

def load_threat_model(file: str = None, svg_content: str = None):
return ThreatModel(file, svg_content)
def load_threat_model(file: Optional[str] = None, svg_content: Optional[str] = None,build_for_ai_context: bool = False):
return ThreatModel(file, svg_content, build_for_ai_context)

def convert_svg_to_png(file: str = None, svg_content: str = None, out_file="result"):
threat_model = ThreatModel(file, svg_content)
def convert_svg_to_png(file: Optional[str] = None, svg_content: Optional[str] = None, out_file="result", build_for_ai_context: bool = False):
threat_model = ThreatModel(file, svg_content, build_for_ai_context)

d = draw.Drawing(2000, 2000)
d.append(draw.elements.Raw('<style>@import url("https://fonts.googleapis.com/css?family=Open+Sans:400,400i,700,700i");</style>'))
Expand All @@ -22,4 +23,4 @@ def convert_svg_to_png(file: str = None, svg_content: str = None, out_file="resu
d.save_svg(f'{file_name}.svg')
d.save_png(f'{file_name}.png')

return threat_model.key_label_tuples
return threat_model.key_label_map
Loading