Skip to content

Commit

Permalink
extended the utility script
Browse files Browse the repository at this point in the history
  • Loading branch information
Cornul11 committed Feb 27, 2024
1 parent 8ed3f29 commit 0e0807c
Showing 1 changed file with 212 additions and 6 deletions.
218 changes: 212 additions & 6 deletions util/check_for_unknown_vulns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,34 @@
import argparse
import json
import os
import subprocess

import requests
from jproperties import Properties
from mysql.connector import pooling
from pymongo import MongoClient
from tqdm import tqdm


def parse_database_url(db_url):
# db_url is in the format "jdbc:postgresql://localhost:5432/maven"
try:
url_parts = db_url.split("//")[1].split("/")
host_port = url_parts[0]
database = url_parts[1]

host = host_port.split(":")[0]

return host, database
except IndexError:
raise ValueError("Invalid database URL format")


properties = Properties()
with open("../config.properties", "rb") as properties_file:
properties.load(properties_file, "utf-8")

db_host, db_name = parse_database_url(properties.get("database.url").data)


def load_json_file(file_path):
Expand Down Expand Up @@ -35,6 +63,7 @@ def analyze_artifacts(maven_artifacts_file, artifacts_directory):
usable_shaded_jar_count = 0
detected_unknown_vulnerable_versions = 0
detected_known_vulnerable_versions = 0
total_known_vulnerable_versions = 0
maven_artifacts = load_json_file(maven_artifacts_file)
for artifact in maven_artifacts:
group_id = artifact["groupId"]
Expand All @@ -51,8 +80,9 @@ def analyze_artifacts(maven_artifacts_file, artifacts_directory):
detected_vulnerabilities = find_vulnerabilities_in_inferred_artifact(
inferred_artifact_file_path
)
is_vulnerable_version = version_key == "mostUsedVulnerableVersion"
total_known_vulnerable_versions += 1 if is_vulnerable_version else 0
if detected_vulnerabilities:
is_vulnerable_version = version_key == "mostUsedVulnerableVersion"
print_artifact_info(
group_id,
artifact_id,
Expand All @@ -70,28 +100,204 @@ def analyze_artifacts(maven_artifacts_file, artifacts_directory):
f"Total detected vulnerable versions: {detected_unknown_vulnerable_versions} ({(detected_unknown_vulnerable_versions / usable_shaded_jar_count) * 100:.2f}%)"
)
print(
f"Total detected known vulnerable versions: {detected_known_vulnerable_versions} ({(detected_known_vulnerable_versions / usable_shaded_jar_count) * 100:.2f}%)"
f"Total detected known vulnerable versions: {detected_known_vulnerable_versions} ({(detected_known_vulnerable_versions / total_known_vulnerable_versions) * 100:.2f}%)"
)


def connect_to_db():
try:
connection_pool = pooling.MySQLConnectionPool(
pool_name="pom_resolution_pool",
pool_size=5,
host=db_host,
database=db_name,
user=properties.get("database.username").data,
password=properties.get("database.password").data,
)
return connection_pool
except Exception as e:
print(f"Error connecting to the database: {e}")
return None


def get_all_libraries(cursor):
cursor.execute("SELECT id, group_id, artifact_id, version FROM libraries")
return cursor.fetchall()


def connect_to_mongodb():
client = MongoClient("mongodb://localhost:27072/")
db = client.osv_db
return db


def check_vulnerability_in_mongodb(db, group_id, artifact_id, version):
query = {
"affected.package.name": f"{group_id}:{artifact_id}",
"affected.package.ecosystem": "Maven",
"affected.versions": {"$in": [version]},
}
count = db.data.count_documents(query)
return count > 0


def get_vulnerable_libraries_from_mongodb(db):
query = {
"affected.package.ecosystem": "Maven",
}
return db.data.find(query)


def update_library_vulnerability_status(vulnerable_libraries, output_file_path):
not_found = 0
total_vulnerable = 0
pool = connect_to_db()
if pool:
cnx = pool.get_connection()
cursor = cnx.cursor(buffered=True)

with open(output_file_path, "w") as file:
for vuln in tqdm(vulnerable_libraries):
maven_affected = [
a for a in vuln["affected"] if a["package"]["ecosystem"] == "Maven"
]
if not maven_affected:
continue

affected_package = maven_affected[0]
if "versions" in affected_package:
for version in affected_package["versions"]:
total_vulnerable += 1
group_id, artifact_id = affected_package["package"][
"name"
].split(":")

file.write(f"{group_id}:{artifact_id}:{version}\n")
print(
f"Updating {group_id}:{artifact_id} version {version} to vulnerable"
)
query = "SELECT id FROM libraries WHERE group_id = %s AND artifact_id = %s AND version = %s"
cursor.execute(query, (group_id, artifact_id, version))
library_id = cursor.fetchone()
if not library_id:
print(
f"Library {group_id}:{artifact_id} version {version} not found in corpus"
)
not_found += 1
# query = "UPDATE libraries SET vulnerable = 1 WHERE group_id = %s AND artifact_id = %s AND version = %s"
# cursor.execute(query, (group_id, artifact_id, version))
cnx.commit()
cursor.close()
cnx.close()
print(f"Total not found: {not_found}")
print(f"Total vulnerable: {total_vulnerable}")


def fill_vulnerabilities(output_file_path="vulnerable_versions.txt"):
mongo_db = connect_to_mongodb()
vulnerable_libraries = get_vulnerable_libraries_from_mongodb(mongo_db)
update_library_vulnerability_status(vulnerable_libraries, output_file_path)


def check_if_exists_in_maven_central_index(group_id, artifact_id, version):
try:
response = requests.get(
"http://localhost:8032/lookup",
params={
"groupId": group_id,
"artifactId": artifact_id,
"version": version,
},
)
return response.status_code == 200
except subprocess.CalledProcessError as e:
print(f"An error occurred: {e}")
return None


def filter_maven_central_artifacts(input_file):
with open(input_file, "r") as file:
vulnerable_artifacts = file.readlines()

vulnerable_artifacts = [a.strip() for a in vulnerable_artifacts]
count_in_maven_index = 0
for artifact in tqdm(vulnerable_artifacts):
group_id, artifact_id, version = artifact.split(":")
exists = check_if_exists_in_maven_central_index(group_id, artifact_id, version)
if exists is not None:
if exists:
count_in_maven_index += 1
with open("filtered_vulnerable_versions.txt", "w") as file:
file.write(f"{group_id}:{artifact_id}:{version}\n")

print(f"Total in Maven Central index: {count_in_maven_index}")
print(
f"Percentage in Maven Central index: {(count_in_maven_index / len(vulnerable_artifacts)) * 100:.2f}%"
)


def download_vulnerable_artifacts(input_file, download_output_path):
with open(input_file, "r") as file:
vulnerable_artifacts = file.readlines()

vulnerable_artifacts = [a.strip() for a in vulnerable_artifacts]
for artifact in tqdm(vulnerable_artifacts):
group_id, artifact_id, version = artifact.split(":")
group_id_path = group_id.replace(".", "/")
download_path = os.path.join(download_output_path, group_id_path, artifact_id, version, f"{artifact_id}-{version}.jar")
# https://repo1.maven.org/maven2/com/daml/participant-state_2.13/2.3.13/participant-state_2.13-2.3.13.jar
if not os.path.exists(download_path):
url = f"https://repo1.maven.org/maven2/{group_id_path}/{artifact_id}/{version}/{artifact_id}-{version}.jar"
response = requests.get(url)
# check if the entire path exists
os.makedirs(os.path.dirname(download_path), exist_ok=True)
with open(download_path, "wb") as file:
file.write(response.content)
import sys
sys.exit(0)


def main():
parser = argparse.ArgumentParser(
description="Analyze Maven artifacts for vulnerabilities"
)
parser.add_argument("--mode", required=True, help="Mode of operation")
parser.add_argument(
"--output_file", help="Path to the vulnerable artifacts GAV output file"
)
parser.add_argument(
"--input_file", help="Path to the vulnerable artifacts GAV input file"
)
parser.add_argument(
"--maven_artifacts_file",
required=True,
help="Path to the JSON file containing Maven artifacts information",
)
parser.add_argument(
"--artifacts_directory",
required=True,
help="Path to the directory containing the inferred artifacts metadata",
)
parser.add_argument("--download_output_path", help="Path to the output directory")

args = parser.parse_args()
analyze_artifacts(args.maven_artifacts_file, args.artifacts_directory)

if args.mode == "analyze_artifacts":
if args.maven_artifacts_file and args.artifacts_directory:
analyze_artifacts(args.maven_artifacts_file, args.artifacts_directory)
else:
print(
"Error: Both --maven_artifacts_file and --artifacts_directory are required for this mode"
)
elif args.mode == "fill_vulnerabilities":
fill_vulnerabilities()
elif args.mode == "filter_maven_central_artifacts":
if not args.input_file:
print("Error: --input_file is required for this mode")
filter_maven_central_artifacts(args.input_file)
elif args.mode == "download_vulnerable_artifacts":
if not args.input_file or not args.download_output_path:
print(
"Error: Both --input_file and --download_output_path are required for this mode"
)
download_vulnerable_artifacts(args.input_file, args.download_output_path)

if __name__ == "__main__":
main()

0 comments on commit 0e0807c

Please sign in to comment.