Skip to content

Commit

Permalink
inference output to file, db mode fixes, threshold as argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Cornul11 committed Nov 23, 2023
1 parent 145f3f1 commit 16e60f2
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 113 deletions.
3 changes: 2 additions & 1 deletion config.properties.example
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ numConsumerThreads=10
totalJars=42
mongoDbDatabase=osv_db
mongoDbCollection=data
mongoDbConnectionString=mongodb://localhost:27072
mongoDbConnectionString=mongodb://localhost:27072
databaseMode=file
2 changes: 1 addition & 1 deletion run_inference.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
/usr/lib/jvm/java-11-openjdk-amd64/bin/java -Xmx8g -cp target/dependency/*:target/thesis-1.0-SNAPSHOT.jar nl.tudelft.cornul11.thesis.corpus.MainApp -dbm memory -m DETECTION_MODE -f "$1"
/usr/lib/jvm/java-11-openjdk-amd64/bin/java -Xmx8g -cp target/dependency/*:target/thesis-1.0-SNAPSHOT.jar nl.tudelft.cornul11.thesis.corpus.MainApp -m IDENTIFICATION_MODE -f "$1"
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static void main(String[] args) {

DatabaseConfig databaseConfig = config.getDatabaseConfig();
DatabaseManager databaseManager = DatabaseManager.getInstance(databaseConfig);
SignatureDAO signatureDao = databaseManager.getSignatureDao();
SignatureDAO signatureDao = databaseManager.getSignatureDao(config.getDatabaseMode());
// Fetch hashes for a specific artifactId and version from the database
if (false) {
List<Long> dbHashesForArtifact = ((SignatureDAOImpl) signatureDao).getHashesForArtifactIdVersion("logback-core", "1.4.0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class FatJarServer extends AbstractHandler {
ConfigurationLoader config = new ConfigurationLoader();
DatabaseConfig databaseConfig = config.getDatabaseConfig();
DatabaseManager databaseManager = DatabaseManager.getInstance(databaseConfig);
SignatureDAO signatureDao = databaseManager.getSignatureDao();
SignatureDAO signatureDao = databaseManager.getSignatureDao(config.getDatabaseMode());
JarSignatureMapper jarSignatureMapper = new JarSignatureMapper(signatureDao);

MultipartConfigElement multipartConfig = new MultipartConfigElement(location, maxFileSize, maxRequestSize, fileSizeThreshold);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,51 +41,57 @@ public void run() {
if (mode != null) {
DatabaseConfig databaseConfig = config.getDatabaseConfig();
DatabaseManager databaseManager = DatabaseManager.getInstance(databaseConfig);
SignatureDAO signatureDao = databaseManager.getSignatureDao(options.getDatabaseMode());
SignatureDAO signatureDao = databaseManager.getSignatureDao(config.getDatabaseMode());

if ("CORPUS_GEN_MODE".equals(mode)) {
JarFileExplorer jarFileExplorer = new JarFileExplorer(signatureDao, config);
String directoryPath = options.getDirectory();
String filePaths = options.getFilePaths();
if (directoryPath != null) {
jarFileExplorer.processFiles(directoryPath, options.getLastPath());
} else if (filePaths != null) {
jarFileExplorer.processFilesFromPathListFile(filePaths, options.getLastPath());
} else {
System.out.println("Directory path or file path(s) is required for CORPUS_GEN_MODE");
printHelpMessage();
}
} else if ("DETECTION_MODE".equals(mode)) {
String fileName = options.getFilename();
if (fileName != null) {
JarFrequencyAnalyzer jarFrequencyAnalyzer = new JarFrequencyAnalyzer(signatureDao);
Map<String, Map<String, Object>> frequencyMap = jarFrequencyAnalyzer.processJar(fileName);
if (frequencyMap == null) {
logger.error("Error in processing jar file, ignoring it");
return;
switch (mode) {
case "CORPUS_GEN_MODE":
JarFileExplorer jarFileExplorer = new JarFileExplorer(signatureDao, config);
String directoryPath = options.getDirectory();
String filePaths = options.getFilePaths();
if (directoryPath != null) {
jarFileExplorer.processFiles(directoryPath, options.getLastPath());
} else if (filePaths != null) {
jarFileExplorer.processFilesFromPathListFile(filePaths, options.getLastPath());
} else {
System.out.println("Directory path or file path(s) is required for CORPUS_GEN_MODE");
printHelpMessage();
}
int totalClassCount = jarFrequencyAnalyzer.getTotalClassCount();
break;
case "IDENTIFICATION_MODE":
String fileName = options.getFilename();
Double threshold = options.getThreshold();
if (fileName != null) {
JarFrequencyAnalyzer jarFrequencyAnalyzer = new JarFrequencyAnalyzer(signatureDao);
Map<String, Map<String, Object>> frequencyMap = jarFrequencyAnalyzer.processJar(fileName);
if (frequencyMap == null) {
logger.error("Error in processing jar file, ignoring it");
return;
}
int totalClassCount = jarFrequencyAnalyzer.getTotalClassCount();

VulnerabilityAnalyzer vulnerabilityAnalyzer = new VulnerabilityAnalyzer(totalClassCount, config);
VulnerabilityAnalyzer vulnerabilityAnalyzer = new VulnerabilityAnalyzer(totalClassCount, config, threshold);

vulnerabilityAnalyzer.checkForVulnerability(frequencyMap);
} else {
System.out.println("File name is required for DETECTION_MODE");
printHelpMessage();
}
} else if ("EVALUATION_MODE".equals(mode)) {
String evaluationDirectory = options.getEvaluationDirectory();
if (evaluationDirectory != null) {
JarEvaluator jarEvaluator = new JarEvaluator(signatureDao, evaluationDirectory);
Map<String, List<JarEvaluator.InferredLibrary>> inferredLibrariesMap = jarEvaluator.inferLibrariesFromJars();
jarEvaluator.evaluate(inferredLibrariesMap);
} else {
System.out.println("Evaluation directory is required for EVALUATION_MODE");
vulnerabilityAnalyzer.checkForVulnerability(frequencyMap, options.getOutput());
} else {
System.out.println("File name is required for IDENTIFICATION_MODE");
printHelpMessage();
}
break;
case "EVALUATION_MODE":
String evaluationDirectory = options.getEvaluationDirectory();
if (evaluationDirectory != null) {
JarEvaluator jarEvaluator = new JarEvaluator(signatureDao, evaluationDirectory);
Map<String, List<JarEvaluator.InferredLibrary>> inferredLibrariesMap = jarEvaluator.inferLibrariesFromJars();
jarEvaluator.evaluate(inferredLibrariesMap);
} else {
System.out.println("Evaluation directory is required for EVALUATION_MODE");
printHelpMessage();
}
break;
default:
System.out.println("Invalid mode specified: " + mode);
printHelpMessage();
}
} else {
System.out.println("Invalid mode specified: " + mode);
printHelpMessage();
break;
}
} else {
System.out.println("No mode specified");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,23 @@ public String getFilePaths() {
return cmd.getOptionValue("fp");
}

public String getDatabaseMode() {
return cmd.getOptionValue("dbm");
public String getOutput() {
return cmd.getOptionValue("o");
}

public Double getThreshold() {
String thresholdValue = cmd.getOptionValue("t");
try {
double threshold = Double.parseDouble(thresholdValue);
if (threshold < 0 || threshold > 1) {
logger.error("Threshold value must be between 0 and 1");
System.exit(1);
}
return threshold;
} catch (Exception e) {
logger.error("Invalid threshold value: {}, must be a double value between 0 and 1", thresholdValue);
return null;
}
}

private Options buildOptions() {
Expand Down Expand Up @@ -82,7 +97,7 @@ private Options buildOptions() {
.longOpt("mode")
.hasArg()
.argName("mode")
.desc("Specify the operation mode: CORPUS_GEN_MODE or DETECTION_MODE")
.desc("Specify the operation mode: CORPUS_GEN_MODE, IDENTIFICATION_MODE or EVALUATION_MODE")
.build());

options.addOption(Option.builder("p")
Expand All @@ -106,13 +121,19 @@ private Options buildOptions() {
.desc("Specify the directory path for evaluation mode")
.build());

options.addOption(Option.builder("dbm")
.longOpt("databaseMode")
options.addOption(Option.builder("o")
.longOpt("output")
.hasArg()
.argName("mode")
.desc("Specif the database mode, memory or file")
.argName("file")
.desc("Specify the path to inference output file")
.build());

options.addOption(Option.builder("t")
.longOpt("threshold")
.hasArg()
.argName("threshold")
.desc("Specify the threshold for inference")
.build());
return options;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,6 @@ public static DatabaseManager getInstance(DatabaseConfig config) {
return InstanceHolder.instance;
}

public SignatureDAO getSignatureDao() {
return new SignatureDAOImpl(ds);
}

public SignatureDAO getSignatureDao(String dbmode) {
return new SignatureDAOImpl(ds, dbmode);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public SignatureDAOImpl(HikariDataSource ds) {

public SignatureDAOImpl(HikariDataSource ds, String dbMode) {
this.ds = ds;
this.dbMode = dbMode;
this.dbMode = Objects.requireNonNullElse(dbMode, "file");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,38 @@
package nl.tudelft.cornul11.thesis.corpus.service;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import nl.tudelft.cornul11.thesis.corpus.database.MongoDbClient;
import nl.tudelft.cornul11.thesis.corpus.util.ConfigurationLoader;
import org.bson.Document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

public class VulnerabilityAnalyzer {
private static final Logger logger = LoggerFactory.getLogger(VulnerabilityAnalyzer.class);
public static final double THRESHOLD = 0.8;
private static final double DEFAULT_THRESHOLD = 0.8;
private final double threshold;
private final MongoDbClient mongoDbClient;
private final int totalClassFileCount;
private final ConfigurationLoader config;

public VulnerabilityAnalyzer(int totalClassFileCount, ConfigurationLoader config) {
public VulnerabilityAnalyzer(int totalClassFileCount, ConfigurationLoader config, Double threshold) {
this.totalClassFileCount = totalClassFileCount;
this.config = config;
this.mongoDbClient = new MongoDbClient(config);
this.threshold = Objects.requireNonNullElse(threshold, DEFAULT_THRESHOLD);
}

public void checkForVulnerability(Map<String, Map<String, Object>> libraryVersionMap) {
public void checkForVulnerability(Map<String, Map<String, Object>> libraryVersionMap, String filePath) {
List<LibraryVersion> libraryVersions = new ArrayList<>();
for (Map.Entry<String, Map<String, Object>> entry : libraryVersionMap.entrySet()) {
double ratio = (double) entry.getValue().get("ratio");

if (ratio >= THRESHOLD) { // TODO: make this a parameter, or make it configurable;
if (ratio >= threshold) { // TODO: make this a parameter, or make it configurable;
LibraryVersion libraryVersion = getLibraryVersion(entry);
libraryVersions.add(libraryVersion);
}
Expand All @@ -41,10 +44,27 @@ public void checkForVulnerability(Map<String, Map<String, Object>> libraryVersio
// Log results
for (LibraryVersion libraryVersion : libraryVersions) {
ArrayList<Document> vulnerabilitiesList = mongoDbClient.getVulnerabilities(libraryVersion.getLibrary(), libraryVersion.getVersion());

List<Map<String, Object>> simplifiedVulnerabilities = vulnerabilitiesList.stream()
.map(Document::toJson)
.map(jsonString -> {
try {
return new ObjectMapper().readValue(jsonString, new TypeReference<Map<String, Object>>() {});
} catch (IOException e) {
logger.error("Error parsing vulnerability document to JSON", e);
return new HashMap<String, Object>();
}
})
.collect(Collectors.toList());
libraryVersion.setVulnerabilities(simplifiedVulnerabilities);
logResult(libraryVersion, vulnerabilitiesList);
}

if (libraryVersions.size() == 0) {
if (filePath != null) {
writeAnalysisToJsonFile(libraryVersions, filePath);
}

if (libraryVersions.isEmpty()) {
logger.info("No matches or vulnerabilities found");
}
}
Expand All @@ -60,23 +80,23 @@ private LibraryVersion getLibraryVersion(Map.Entry<String, Map<String, Object>>
return new LibraryVersion(library, version, count, total, ratio);
}

private void logResult(LibraryVersion libraryVersion, ArrayList<Document> vulnList) {
private void logResult(LibraryVersion libraryVersion, ArrayList<Document> vulnerabilities) {
DecimalFormat decimalFormat = new DecimalFormat("#.##");
double libraryPercentage = libraryVersion.getRatio() * 100;
double overallPercentage = (libraryVersion.getCount() * 100.0) / totalClassFileCount;
String libraryPercentageString = decimalFormat.format(libraryPercentage);
String overallPercentageString = decimalFormat.format(overallPercentage);

StringBuilder status;
if (vulnList.isEmpty()) {
if (vulnerabilities.isEmpty()) {
status = new StringBuilder("✅");
} else {
status = new StringBuilder("❌");
try {
for (Document vuln : vulnList) {
Object aliasesField = vuln.get("aliases");
for (Document doc : vulnerabilities) {
Object aliasesField = doc.get("aliases");
if (aliasesField == null) {
status.append(" -> ").append(vuln.get("id"));
status.append(" -> ").append(doc.get("id"));
} else if (aliasesField instanceof List && !((List<?>) aliasesField).isEmpty()) {
List<?> aliases = (List<?>) aliasesField;
Object firstAlias = aliases.get(0);
Expand All @@ -94,19 +114,38 @@ private void logResult(LibraryVersion libraryVersion, ArrayList<Document> vulnLi
logger.info(output);
}

public void writeAnalysisToJsonFile(List<LibraryVersion> libraryVersions, String filePath) {
ObjectMapper objectMapper = new ObjectMapper();
try {
objectMapper.writeValue(new File(filePath), libraryVersions);
} catch (IOException e) {
logger.error("Error while writing analysis to JSON file", e);
}
}

public static class LibraryVersion {
private final String library;
private final String version;
private final long count;
private final long total;
private final double ratio;
private List<Map<String, Object>> vulnerabilities;

public LibraryVersion(String library, String version, long count, long total, double ratio) {
this.library = library;
this.version = version;
this.count = count;
this.total = total;
this.ratio = ratio;
this.vulnerabilities = new ArrayList<>();
}

public void setVulnerabilities(List<Map<String, Object>> vulnerabilities) {
this.vulnerabilities = vulnerabilities;
}

public List<Map<String, Object>> getVulnerabilities() {
return vulnerabilities;
}

public String getLibrary() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ public String getMongoDbCollection() {
public String getMongoDbConnectionString() {
return config.getProperty("mongoDbConnectionString");
}

public String getDatabaseMode() {
return config.getProperty("databaseMode");
}
}
Loading

0 comments on commit 16e60f2

Please sign in to comment.