Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lib/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def convert_sarif(app_name, repo_context, sarif_files, findings_fname):
tags.append(
{
"key": "owasp_category",
"value": owasp_category.split("-")[0].capitalize(),
"value": owasp_category,
"shiftleft_managed": True,
}
)
Expand All @@ -402,6 +402,9 @@ def convert_sarif(app_name, repo_context, sarif_files, findings_fname):
lineno = location.get("physicalLocation", {})["region"][
"startLine"
]
end_lineno = location.get("physicalLocation", {})[
"contextRegion"
]["endLine"]
finding = {
"app": app_name,
"type": "extscan",
Expand All @@ -412,6 +415,7 @@ def convert_sarif(app_name, repo_context, sarif_files, findings_fname):
utils.calculate_line_hash(
filename,
lineno,
end_lineno,
location.get("physicalLocation", {})["region"][
"snippet"
]["text"],
Expand Down
14 changes: 12 additions & 2 deletions lib/pyt/cfg/expr_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def init_function_cfg(self, node, module_definitions):

first_node = module_statements.first_statement

if CALL_IDENTIFIER not in first_node.label:
if first_node is not None and CALL_IDENTIFIER not in first_node.label:
entry_node.connect(first_node)

last_nodes = module_statements.last_statements
Expand Down Expand Up @@ -138,7 +138,7 @@ def visit_Attribute(self, node):
def visit_Name(self, node):
    """Handle a ``Name`` node by delegating to ``visit_miscelleaneous_node``.

    Returns whatever ``visit_miscelleaneous_node`` returns for ``node``.
    """
    return self.visit_miscelleaneous_node(node)

def visit_NameConstant(self, node):
def visit_Constant(self, node):
return self.visit_miscelleaneous_node(node)

def visit_Str(self, node):
Expand Down Expand Up @@ -250,6 +250,8 @@ def save_def_args_in_temp(

# Create e.g. temp_N_def_arg1 = call_arg1_label_visitor.result for each argument
for i, call_arg in enumerate(call_args):
if i > len(def_args) - 1:
break
# If this results in an IndexError it is invalid Python
def_arg_temp_name = (
"temp_" + str(saved_function_call_index) + "_" + def_args[i]
Expand Down Expand Up @@ -333,7 +335,11 @@ def create_local_scope_from_def_args(
preceding call to save_def_args_in_temp.
"""
# Create e.g. def_arg1 = temp_N_def_arg1 for each argument
if def_args is None:
return
for i in range(len(call_args)):
if i > len(def_args) - 1:
return
def_arg_local_name = def_args[i]
def_arg_temp_name = (
"temp_" + str(saved_function_call_index) + "_" + def_args[i]
Expand Down Expand Up @@ -370,6 +376,8 @@ def visit_and_get_function_nodes(self, definition, first_node):
self.connect_if_allowed(previous_node, entry_node)

function_body_connect_statements = self.stmt_star_handler(definition.node.body)
if isinstance(function_body_connect_statements, IgnoredNode):
return (IgnoredNode, first_node)
entry_node.connect(function_body_connect_statements.first_statement)

exit_node = self.append_node(EntryOrExitNode("Exit " + definition.name))
Expand Down Expand Up @@ -439,6 +447,8 @@ def return_handler(
saved_function_call_index(int): Unique number for each call.
first_node(EntryOrExitNode or RestoreNode): Used to connect previous statements to this function.
"""
if not hasattr(function_nodes, "__iter__"):
return
if any(isinstance(node, YieldNode) for node in function_nodes):
# Presence of a `YieldNode` means that the function is a generator
rhs_prefix = "yld_"
Expand Down
38 changes: 32 additions & 6 deletions lib/pyt/cfg/stmt_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
module.name for module in iter_modules()
} # Don't warn about failing to import these

# Package names whose imports are never followed: add_module records them as
# uninspectable and returns an IgnoredNode instead of visiting them, which
# breaks import recursion. (Not all of them are standard-library packages.)
BUILTIN_PKGS = ["os", "logging", "json", "markdown"]

# Module-level cache of module names/paths that have already been visited
# (key -> True). It is checked before an import is followed so that the same
# module is not entered again, which breaks import recursion.
visited_module_paths = {}


class StmtVisitor(ast.NodeVisitor):
def __init__(self, allow_local_directory_imports=True):
Expand Down Expand Up @@ -190,6 +196,8 @@ def handle_or_else(self, orelse, test):
else_connect_statements = self.stmt_star_handler(
orelse, prev_node_to_avoid=self.nodes[-1]
)
if isinstance(else_connect_statements, IgnoredNode):
return IgnoredNode()
test.connect(else_connect_statements.first_statement)
return else_connect_statements.last_statements

Expand All @@ -205,6 +213,8 @@ def visit_If(self, node):

if node.orelse:
orelse_last_nodes = self.handle_or_else(node.orelse, test)
if isinstance(orelse_last_nodes, IgnoredNode):
return IgnoredNode()
body_connect_stmts.last_statements.extend(orelse_last_nodes)
else:
body_connect_stmts.last_statements.append(
Expand All @@ -229,6 +239,8 @@ def visit_Return(self, node):

if isinstance(node.value, ast.Call):
return_value_of_call = self.visit(node.value)
if not hasattr(return_value_of_call, "left_hand_side"):
return None
return_node = ReturnNode(
LHS + " = " + return_value_of_call.left_hand_side,
LHS,
Expand Down Expand Up @@ -293,7 +305,8 @@ def visit_Try(self, node):
orelse_last_nodes = self.handle_or_else(
node.orelse, body.last_statements[-1]
)
body.last_statements.extend(orelse_last_nodes)
if not isinstance(orelse_last_nodes, IgnoredNode):
body.last_statements.extend(orelse_last_nodes)

if node.finalbody:
finalbody = self.stmt_star_handler(node.finalbody)
Expand Down Expand Up @@ -482,6 +495,8 @@ def assignment_call_node(self, left_hand_label, ast_node):
self.undecided = True # Used for handling functions in assignments

call = self.visit(ast_node.value)
if not hasattr(call, "left_hand_side"):
return None
call_label = call.left_hand_side

call_assignment = AssignmentCallNode(
Expand Down Expand Up @@ -807,12 +822,16 @@ def add_module( # noqa: C901
"""
module_path = module[1]

if module_or_package_name in BUILTIN_PKGS:
uninspectable_modules.add(module_or_package_name)
return IgnoredNode()
if visited_module_paths.get(module[0]):
return IgnoredNode()
visited_module_paths[module[0]] = True
parent_definitions = self.module_definitions_stack[-1]
# Here, in `visit_Import` and in `visit_ImportFrom` are the only places the `import_alias_mapping` is updated
parent_definitions.import_alias_mapping.update(import_alias_mapping)
parent_definitions.import_names = local_names
if not module_or_package_name:
return
new_module_definitions = ModuleDefinitions(local_names, module_or_package_name)
new_module_definitions.is_init = is_init
self.module_definitions_stack.append(new_module_definitions)
Expand All @@ -824,7 +843,7 @@ def add_module( # noqa: C901
)
tree = generate_ast(module_path)
if not tree:
return None
return IgnoredNode()
# module[0] is None during e.g. "from . import foo", so we must str()
self.nodes.append(EntryOrExitNode("Module Entry " + str(module[0])))
self.visit(tree)
Expand Down Expand Up @@ -899,7 +918,6 @@ def add_module( # noqa: C901
)
parent_definition.node = def_.node
parent_definitions.definitions.append(parent_definition)

return exit_node

def from_directory_import(
Expand Down Expand Up @@ -1001,6 +1019,9 @@ def handle_relative_import(self, node):

# Is it a file?
if name_with_dir.endswith(".py"):
if visited_module_paths.get(name_with_dir):
return IgnoredNode()
visited_module_paths[name_with_dir] = True
return self.add_module(
module=(node.module, name_with_dir),
module_or_package_name=None,
Expand Down Expand Up @@ -1083,8 +1104,14 @@ def visit_ImportFrom(self, node):
)
for module in self.project_modules:
name = module[0]
if node.level == 0:
break
if node.module == name:
if os.path.isdir(module[1]):
if visited_module_paths.get(module[1]):
return IgnoredNode()
# Break recursion
visited_module_paths[module[1]] = True
return self.from_directory_import(
module,
not_as_alias_handler(node.names),
Expand All @@ -1107,7 +1134,6 @@ def visit_ImportFrom(self, node):
local_definitions.import_alias_mapping[
name.asname or name.name
] = "{}.{}".format(node.module, name.name)

if node.module not in uninspectable_modules:
uninspectable_modules.add(node.module)
return IgnoredNode()
2 changes: 1 addition & 1 deletion lib/pyt/helper_visitors/label_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def visit_keyword(self, node):
def insert_space(self):
    """Append a single space character to ``self.result``."""
    self.result += " "

def visit_NameConstant(self, node):
def visit_Constant(self, node):
self.result += str(node.value)

def visit_Subscript(self, node):
Expand Down
3 changes: 3 additions & 0 deletions lib/pyt/vulnerabilities/trigger_definitions_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def kwarg_propagates(self, keyword):
def get_kwarg_from_position(self, index):
    """Look up ``index`` in ``self.arg_position_to_kwarg``.

    Returns the value stored for ``index`` via ``.get``; presumably this is
    the keyword-argument name for that positional slot and ``None`` when the
    position is unknown -- TODO confirm against where the mapping is built.
    """
    return self.arg_position_to_kwarg.get(index)

def __str__(self):
    """Return a readable summary built from ``self.sink_type`` and ``self._trigger``."""
    return f"Sink: Type: {self.sink_type}, Trigger: {self._trigger}"

@property
def all_arguments_propagate_taint(self):
if self.kwarg_list:
Expand Down
62 changes: 57 additions & 5 deletions lib/pyt/vulnerabilities/vulnerabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def find_triggers(nodes, trigger_words):
"""
trigger_nodes = list()
for node in nodes:
trigger_nodes.extend(iter(label_contains(node, trigger_words)))
trigger_nodes.extend(iter(label_starts_with(node, trigger_words)))
return trigger_nodes


Expand All @@ -124,6 +124,28 @@ def label_contains(node, triggers):
yield TriggerNode(trigger, node)


def label_starts_with(node, triggers):
    """Yield a TriggerNode for every trigger whose word begins a token in the node label.

    A trigger matches when its ``trigger_word`` occurs in ``node.label`` and
    at least one of the following holds:

    * the label contains ``ret_<word>``, `` <word>`` (space-prefixed) or
      ``.<word>`` (dot-prefixed), or
    * the label itself starts with ``<word>``.

    Args:
        node(Node): CFG node whose ``label`` is inspected.
        triggers(list[Union[Sink, Source]]): triggers to look for; each must
            expose a ``trigger_word`` attribute.

    Returns:
        Iterable of TriggerNodes found. Can be multiple because several
        triggers can match one node.
    """
    for trigger in triggers:
        label = node.label
        word = trigger.trigger_word
        if word not in label:
            # Cheap rejection before building the prefixed variants.
            continue
        prefixed_forms = (f"ret_{word}", f" {word}", f".{word}")
        if any(form in label for form in prefixed_forms) or label.startswith(word):
            yield TriggerNode(trigger, node)


def build_sanitiser_node_dict(cfg, sinks_in_file):
"""Build a dict of string -> TriggerNode pairs, where the string
is the sanitiser and the TriggerNode is a TriggerNode of the sanitiser.
Expand Down Expand Up @@ -229,8 +251,11 @@ def get_vulnerability_chains(current_node, sink, def_use, chain=[]):
yield chain
else:
vuln_chain = list(chain)
vuln_chain.append(use)
yield from get_vulnerability_chains(use, sink, def_use, vuln_chain)
if use not in vuln_chain:
vuln_chain.append(use)
yield from get_vulnerability_chains(use, sink, def_use, vuln_chain)
else:
yield chain


def how_vulnerable(
Expand Down Expand Up @@ -319,7 +344,6 @@ def get_vulnerability(source, sink, triggers, lattice, cfg, blackbox_mapping):
sink_args = get_sink_args(sink.cfg_node)
else:
sink_args = get_sink_args_which_propagate(sink, sink.cfg_node.ast_node)

tainted_node_in_sink_arg = get_tainted_node_in_sink_args(
sink_args, nodes_in_constraint,
)
Expand Down Expand Up @@ -363,10 +387,33 @@ def get_vulnerability(source, sink, triggers, lattice, cfg, blackbox_mapping):

vuln_deets["reassignment_nodes"] = chain
return vuln_factory(vulnerability_type)(**vuln_deets)

return None


def filter_over_taint(vulnerability, source, sink, blackbox_mapping):
    """Filter out over-tainted findings such as non-sensitive data leaks.

    Two kinds of finding are suppressed:

    * ``Logging`` sinks, unless the lower-cased source label is listed in
      ``sensitive_data_list``; and even then, when the lower-cased sink
      trigger word contains one of ``sensitive_allowed_log_levels``.
    * ``ReturnedToUser`` sinks whose trigger word is ``render(`` and whose
      source type is ``Framework_Parameter`` (a known false positive).

    Args:
        vulnerability: The finding for this source/sink pair; returned
            unchanged when it survives filtering.
        source: Source trigger; ``cfg_node.label`` and ``source_type`` are read.
        sink: Sink trigger; ``sink_type`` and ``trigger_word`` are read.
        blackbox_mapping(dict): Configuration mapping. The keys
            ``sensitive_data_list`` and ``sensitive_allowed_log_levels`` are
            optional; a missing or null entry is treated as empty.

    Returns:
        ``vulnerability`` if it should be reported, otherwise ``None``.
    """
    sink_type = sink.sink_type
    if sink_type == "Logging":
        # ``dict.get`` yields None for an absent key; fall back to an empty
        # tuple so the membership test and the loop below cannot raise TypeError.
        sensitive_data_list = blackbox_mapping.get("sensitive_data_list") or ()
        sensitive_allowed_log_levels = (
            blackbox_mapping.get("sensitive_allowed_log_levels") or ()
        )
        # Ignore logging for non-sensitive data
        if source.cfg_node.label.lower() not in sensitive_data_list:
            return None
        # Ignore vulnerabilities with acceptable log levels
        for log_level in sensitive_allowed_log_levels:
            if log_level in sink.trigger_word.lower():
                return None
    elif sink_type == "ReturnedToUser":
        # render method based on Framework_Parameter is a known FP
        if (
            sink.trigger_word == "render("
            and source.source_type == "Framework_Parameter"
        ):
            return None
    return vulnerability


def find_vulnerabilities_in_cfg(
cfg, definitions, lattice, blackbox_mapping, vulnerabilities_list
):
Expand All @@ -385,6 +432,11 @@ def find_vulnerabilities_in_cfg(
vulnerability = get_vulnerability(
source, sink, triggers, lattice, cfg, blackbox_mapping
)
# Filter out over-tainted vulnerabilities
if vulnerability:
vulnerability = filter_over_taint(
vulnerability, source, sink, blackbox_mapping
)
if vulnerability:
vulnerabilities_list.append(vulnerability)

Expand Down
13 changes: 10 additions & 3 deletions lib/pyt/vulnerabilities/vulnerability_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def as_dict(self):
source_path = self.source.path.split("/")[-1]
sink_path = self.sink.path.split("/")[-1]
rule_id = self.source_type

rule_name = f"Data flow from {self.source_type} to {self.sink_type}"
description = f"User controlled data flow from the source `{source_path}:{self.source.line_number}` to the sink `{sink_path}:{self.sink.line_number}`"
rule = source_sink_rules.get(f"{self.source_type}:{self.sink_type}")
Expand All @@ -116,13 +117,19 @@ def as_dict(self):
rule_name = rule.name
severity = rule.severity
message_format = rule.message_format
message_format = message_format.replace(
"{$sources}", f"{source_path}:{self.source.line_number}"
)
owasp_category = rule.owasp_category
sources = f"{source_path}:{self.source.line_number}"
if self.source_type == "Framework_Parameter":
source_parameter = self.source.label
sources = (
f"{source_parameter} in {source_path}:{self.source.line_number}"
)
message_format = message_format.replace("{$sources}", sources)
message_format = message_format.replace(
"{$sinks}", f"{sink_path}:{self.sink.line_number}"
)
description = message_format

return {
"rule_id": rule_id,
"rule_name": rule_name,
Expand Down
Loading