Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Initial implementation for Rust #762

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions data_prep/introspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
INTROSPECTOR_ORACLE_FAR_REACH = ''
INTROSPECTOR_ORACLE_KEYWORD = ''
INTROSPECTOR_ORACLE_EASY_PARAMS = ''
INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES = ''
INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = ''
INTROSPECTOR_ORACLE_OPTIMAL = ''
INTROSPECTOR_ORACLE_ALL_TESTS = ''
INTROSPECTOR_FUNCTION_SOURCE = ''
Expand Down Expand Up @@ -90,6 +90,7 @@ def get_oracle_dict() -> Dict[str, Any]:
'jvm-public-candidates': query_introspector_jvm_all_public_candidates,
'optimal-targets': query_introspector_for_optimal_targets,
'test-migration': query_introspector_for_tests,
'all-public-candidates': query_introspector_all_public_candidates,
}
return oracle_dict

Expand All @@ -102,7 +103,7 @@ def set_introspector_endpoints(endpoint):
INTROSPECTOR_ORACLE_KEYWORD, INTROSPECTOR_ADDR_TYPE, \
INTROSPECTOR_ALL_HEADER_FILES, INTROSPECTOR_ALL_FUNC_TYPES, \
INTROSPECTOR_SAMPLE_XREFS, INTROSPECTOR_ORACLE_EASY_PARAMS, \
INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES, \
INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES, \
INTROSPECTOR_ALL_JVM_SOURCE_PATH, INTROSPECTOR_ORACLE_OPTIMAL, \
INTROSPECTOR_HEADERS_FOR_FUNC, \
INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, \
Expand All @@ -119,7 +120,7 @@ def set_introspector_endpoints(endpoint):
f'{INTROSPECTOR_ENDPOINT}/far-reach-low-cov-fuzz-keyword')
INTROSPECTOR_ORACLE_EASY_PARAMS = (
f'{INTROSPECTOR_ENDPOINT}/easy-params-far-reach')
INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES = (
INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = (
f'{INTROSPECTOR_ENDPOINT}/all-public-candidates')
INTROSPECTOR_ORACLE_OPTIMAL = f'{INTROSPECTOR_ENDPOINT}/optimal-targets'
INTROSPECTOR_FUNCTION_SOURCE = f'{INTROSPECTOR_ENDPOINT}/function-source-code'
Expand Down Expand Up @@ -277,8 +278,17 @@ def query_introspector_jvm_all_public_candidates(project: str) -> list[dict]:
"""Queries Fuzz Introspector for all public accessible function or
constructor candidates.
"""
return query_introspector_oracle(
project, INTROSPECTOR_ORACLE_ALL_JVM_PUBLIC_CANDIDATES)
return query_introspector_oracle(project,
INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES)


def query_introspector_all_public_candidates(project: str) -> list[dict]:
  """Fetches every publicly accessible function or constructor candidate of
  |project| from the Fuzz Introspector all-public-candidates oracle.
  """
  #TODO May combine this with query_introspector_jvm_all_public_candidates
  oracle_api = INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES
  candidates = query_introspector_oracle(project, oracle_api)
  return candidates


def query_introspector_for_targets(project, target_oracle) -> list[Dict]:
Expand Down Expand Up @@ -859,7 +869,7 @@ def populate_benchmarks_using_introspector(project: str, language: str,
# arguments. Thus skipping it.
continue

if language == 'jvm':
elif language == 'jvm':
# Retrieve list of source file from introspector
src_path_list = query_introspector_jvm_source_path(project)
if src_path_list:
Expand All @@ -872,7 +882,8 @@ def populate_benchmarks_using_introspector(project: str, language: str,
if src_file not in src_path_list:
logger.error('error: %s %s', filename, interesting.keys())
continue
elif language != 'python' and interesting and filename not in [

elif language != 'rust' and interesting and filename not in [
os.path.basename(i) for i in interesting.keys()
]:
# TODO: Bazel messes up paths to include "/proc/self/cwd/..."
Expand Down
8 changes: 8 additions & 0 deletions data_prep/project_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def _get_harness(src_file: str, out: str, language: str) -> tuple[str, str]:
return '', ''
if language.lower() == 'python' and 'atheris.Fuzz()' not in content:
return '', ''
if language.lower() == 'rust' and 'fuzz_target!' not in content:
return '', ''

short_path = src_file[len(out):]
return short_path, content
Expand Down Expand Up @@ -307,6 +309,12 @@ def _identify_fuzz_targets(out: str, interesting_filenames: list[str],
interesting_filepaths.append(path)
if path.endswith('.py'):
potential_harnesses.append(path)
elif language == 'rust':
# For Rust
if path.endswith(tuple(interesting_filenames)):
interesting_filepaths.append(path)
if path.endswith('.rs'):
potential_harnesses.append(path)
else:
# For C/C++
short_path = path[len(out):]
Expand Down
7 changes: 7 additions & 0 deletions experiment/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ def __init__(self,
# zipp-zipp.difference.
self.id = self.id.replace('._', '.')

if self.language == 'rust':
# For Rust projects, a double colon (::) is sometimes used to identify the
# crate, impl or trait name of a function. This could affect the
# benchmark_id and cause the OSS-Fuzz build to fail.
# Special handling of benchmark_id is needed to avoid this situation.
self.id = self.id.replace('::', '-')

def __str__(self):
return (f'Benchmark<id={self.id}, project={self.project}, '
f'language={self.language}, '
Expand Down
32 changes: 27 additions & 5 deletions experiment/builder_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ def _contains_target_python_function(self, target_path: str) -> bool:

return min_func_name in generated_code

def _contains_target_rust_function(self, target_path: str) -> bool:
  """Checks whether the LLM-generated Rust harness at |target_path| mentions
  the bare name of the benchmark's target function."""
  with open(target_path) as harness_file:
    harness_source = harness_file.read()

  qualified_name = self._get_minimum_func_name(
      self.benchmark.function_signature)

  # Keep only the trailing function name: drop any crate, trait, impl or mod
  # qualifier before the last '::', then anything before the last '.'.
  bare_name = qualified_name.rpartition('::')[2].rpartition('.')[2]

  return bare_name in harness_source

def _pre_build_check(self, target_path: str,
build_result: BuildResult) -> bool:
"""Checks the generated target before building and running it."""
Expand All @@ -204,7 +219,10 @@ def _pre_build_check(self, target_path: str,
result = self._contains_target_jvm_method(target_path)
elif self.benchmark.language == 'python':
result = self._contains_target_python_function(target_path)
elif self.benchmark.language == 'rust':
result = self._contains_target_rust_function(target_path)
else:
# For C/C++
result = self._contains_target_function(target_path)

if not result:
Expand Down Expand Up @@ -482,8 +500,8 @@ def build_and_run_local(
build_result.succeeded = self.build_target_local(generated_project,
benchmark_log_path)

# Copy err.log into work dir (Ignored for JVM projects)
if language != 'jvm':
# Copy err.log into work dir (Ignored for JVM/Rust projects)
if language not in ['jvm', 'rust']:
try:
shutil.copyfile(
os.path.join(get_build_artifact_dir(generated_project, "workspace"),
Expand Down Expand Up @@ -514,7 +532,8 @@ def build_and_run_local(
# In many cases JVM/python projects won't show much coverage
# difference in a short run. Adding the flag for JVM/python
# projects to temporarily skip the checking of coverage change.
flag = not self.benchmark.language in ['jvm', 'python']
# Also skipping for rust projects in initial implementation.
flag = not self.benchmark.language in ['jvm', 'python', 'rust']
run_result.cov_pcs, run_result.total_pcs, \
run_result.crashes, run_result.crash_info, \
run_result.semantic_check = \
Expand Down Expand Up @@ -683,7 +702,8 @@ def _get_coverage_text_filename(self, project_name: str) -> str:
'jvm': 'jacoco.xml',
'python': 'all_cov.json',
'c++': f'{self.benchmark.target_name}.covreport',
'c': f'{self.benchmark.target_name}.covreport'
'c': f'{self.benchmark.target_name}.covreport',
'rust': f'{self.benchmark.target_name}.covreport',
}

return os.path.join(get_build_artifact_dir(project_name,
Expand All @@ -699,6 +719,7 @@ def _extract_local_textcoverage_data(self,
'python': 'r',
'c': 'rb',
'c++': 'rb',
'rust': 'rb',
}
with open(local_textcov_location,
language_modes.get(self.benchmark.language, 'rb')) as f:
Expand Down Expand Up @@ -1040,7 +1061,7 @@ def build_and_run_cloud(
self._copy_textcov_to_workdir(bucket, textcov_blob_path,
generated_target_name)
else:
# C/C++
# C/C++/Rust
blob = bucket.blob(textcov_blob_path)
if blob.exists():
with blob.open('rb') as f:
Expand Down Expand Up @@ -1082,6 +1103,7 @@ def _get_cloud_textcov_path(self, coverage_name: str) -> str:
if self.benchmark.language == 'python':
return f'{coverage_name}/textcov_reports/all_cov.json'

# For C/C++/Rust
return (f'{coverage_name}/textcov_reports/{self.benchmark.target_name}'
'.covreport')

Expand Down
16 changes: 16 additions & 0 deletions llm_toolkit/code_fixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,22 @@ def extract_error_message(log_path: str, project_target_basename: str,

return errors

# Error message extraction for Rust projects
if language == 'rust':
started = False
errors = []
for log_line in log_lines:
if started:
errors.append(log_line)
if log_line == 'error: could not compile':
break
else:
if log_line.startswith(('error[E', 'warning:')):
errors.append(log_line)
started = True

return errors

target_name, _ = os.path.splitext(project_target_basename)

error_lines_range: list[Optional[int]] = [None, None]
Expand Down
1 change: 1 addition & 0 deletions llm_toolkit/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def parse_code(response_path: str) -> str:
lines = _parse_code_block_by_marker(lines, '```c', '```')
lines = _parse_code_block_by_marker(lines, '```java', '```')
lines = _parse_code_block_by_marker(lines, '```python', '```')
lines = _parse_code_block_by_marker(lines, '```rust', '```')
lines = _parse_code_block_by_marker(lines, '```java_code', '```')
lines = _parse_code_block_by_marker(lines, '<code>', '</code>')
lines = _parse_code_block_by_marker(lines, '<java_code>', '</java_code>')
Expand Down
92 changes: 92 additions & 0 deletions llm_toolkit/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,98 @@ def post_process_generated_code(self, generated_code: str) -> str:
return generated_code


class DefaultRustTemplateBuilder(PromptBuilder):
  """Prompt builder that assembles the default prompt for Rust projects."""

  def __init__(self,
               model: models.LLM,
               benchmark: Benchmark,
               template_dir: str = DEFAULT_TEMPLATE_DIR):
    super().__init__(model)
    self._template_dir = template_dir
    self.benchmark = benchmark
    self.project_url = oss_fuzz_checkout.get_project_repository(
        self.benchmark.project)

    # Resolve the paths of the two Rust prompt templates.
    self.base_template_file = self._find_template(template_dir, 'rust_base.txt')
    self.problem_template_file = self._find_template(template_dir,
                                                     'rust_problem.txt')

  def _find_template(self, template_dir: str, template_name: str) -> str:
    """Returns the path of |template_name| under |template_dir| when that file
    exists, otherwise its path under DEFAULT_TEMPLATE_DIR."""
    candidate = os.path.join(template_dir, template_name)
    if not os.path.isfile(candidate):
      # No override present: point at the default template location instead.
      candidate = os.path.join(DEFAULT_TEMPLATE_DIR, template_name)
    return candidate

  def _get_template(self, template_file: str) -> str:
    """Returns the full text of |template_file|."""
    with open(template_file) as template:
      content = template.read()
    return content

  def _format_target(self, signature: str) -> str:
    """Fills the problem template with |signature| and with the count and
    types of the benchmark parameters."""
    filled = self._get_template(self.problem_template_file)
    param_count = len(self.benchmark.params)
    param_types = [param['type'] for param in self.benchmark.params]

    # Placeholders are substituted one after another, in this order.
    substitutions = (
        ('{FUNCTION_SIGNATURE}', signature),
        ('{ARG_COUNT}', str(param_count)),
        ('{ARG_TYPE}', ','.join(param_types)),
    )
    for placeholder, value in substitutions:
      filled = filled.replace(placeholder, value)

    return filled

  def _format_problem(self, signature: str) -> str:
    """Joins the base template with the formatted target section and fills in
    the project name and project URL placeholders."""
    base_text = self._get_template(self.base_template_file)
    combined = base_text + self._format_target(signature)

    combined = combined.replace('{PROJECT_NAME}', self.benchmark.project)
    return combined.replace('{PROJECT_URL}', self.project_url)

  def _prepare_prompt(self, prompt_str: str):
    """Adds |prompt_str| to the prompt as priming text."""
    self._prompt.add_priming(prompt_str)

  def build(self,
            example_pair: list[list[str]],
            project_example_content: Optional[list[list[str]]] = None,
            project_context_content: Optional[dict] = None) -> prompts.Prompt:
    """Builds the Rust prompt from the templates and the benchmark's function
    signature, stores it in the prompt and returns the prompt.
    |example_pair|, |project_example_content| and |project_context_content|
    are not used.
    """
    problem_text = self._format_problem(self.benchmark.function_signature)
    self._prepare_prompt(problem_text)
    return self._prompt

  def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str,
                         error_desc: Optional[str],
                         errors: list[str]) -> prompts.Prompt:
    """Returns the current prompt unchanged."""
    # No fixer prompt is built for Rust projects yet.
    return self._prompt

  def build_triager_prompt(self, benchmark: Benchmark, driver_code: str,
                           crash_info: str, crash_func: dict) -> prompts.Prompt:
    """Returns the current prompt unchanged."""
    # No triager prompt is built for Rust projects yet.
    return self._prompt

  def post_process_generated_code(self, generated_code: str) -> str:
    """Returns |generated_code| as is."""
    # No post-processing is applied for Rust projects yet.
    return generated_code


class JvmErrorFixingBuilder(PromptBuilder):
"""Prompt builder for fixing JVM harness with compilation error."""

Expand Down
3 changes: 3 additions & 0 deletions prompts/template_xml/rust_base.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
You are a security testing engineer who wants to write a Rust program to execute all lines in a given method by defining and initialising its parameters and necessary objects in a suitable way before fuzzing the method.
The <target> tag contains information of the target method to invoke.
The <requirements> tag contains additional requirements that you MUST follow for this code generation.
36 changes: 36 additions & 0 deletions prompts/template_xml/rust_problem.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<task>
Your goal is to write a fuzzing harness for the provided method signature to fuzz the method with random data. It is important that the provided solution compiles and actually calls the function specified by the method signature:
<target>
<function_signature>
{FUNCTION_SIGNATURE}
</function_signature>
The target function belongs to the Rust project {PROJECT_NAME} ({PROJECT_URL}).
You MUST call this target function from the original project; do NOT create a dummy function.
This function requires {ARG_COUNT} arguments. You must prepare them with random seeded data.
Here is a list of types for all arguments in order, separated by comma. You MUST preserve the modifiers.
{ARG_TYPE}
</target>
<requirements>
<item>Try as many variations of these inputs as possible.</item>
<item>Try creating the harness as complex as possible.</item>
<item>Try adding nested loops to invoke the target method multiple times.</item>
<item>The generated fuzzing harness should be wrapped with the <code> tag.</item>
<item>Please avoid using any multithreading or multi-processing approach.</item>
<item>You MUST create the fuzzing harness using Cargo-Fuzz approach.</item>
<item>You MUST use the #![no_main] attribute.</item>
<item>You MUST import the fuzz_target macro from the libfuzzer_sys crate.</item>
<item>You MUST place all fuzzing statements inside the fuzz_target macro.</item>
<item>You MUST include the use of the necessary functions and crate for calling the target function.</item>
<item>You MUST generate the harness with the assumption that the Cargo.toml for the Cargo-Fuzz directory cannot be changed.</item>
<item>The following is a sample of the fuzzing harness.
<code>
#![no_main]
use libfuzzer_sys::fuzz_target;

fuzz_target!(|data: &[u8]| {
// This macro acts as the entry point for the fuzzing harness.
// Add fuzzing logic here.
});
</code></item>
</requirements>
</task>
3 changes: 2 additions & 1 deletion report/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@

MAX_RUN_LOGS_LEN = 16 * 1024

TARGET_EXTS = project_src.SEARCH_EXTS + ['.java', '.py'] + ['.fuzz_target']
TARGET_EXTS = project_src.SEARCH_EXTS + ['.java', '.py', '.rs'
] + ['.fuzz_target']

_CHAT_PROMPT_START_MARKER = re.compile(r'<CHAT PROMPT:ROUND\s+\d+>')
_CHAT_PROMPT_END_MARKER = re.compile(r'</CHAT PROMPT:ROUND\s+\d+>')
Expand Down