Skip to content
This repository has been archived by the owner on Jan 22, 2024. It is now read-only.

Synthetically add runtime errors to the dataset. #3

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added __init__.py
Empty file.
Empty file added error_generation/__init__.py
Empty file.
93 changes: 93 additions & 0 deletions error_generation/add_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (C) 2021 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import redbaron as rb
from misc_utils import (
get_random_list_sample,
load_json,
load_data,
write_csv,
get_random_int,
get_random_float,
get_valid_code_trace,
)

VARIABLE_NAMES = ["tmp{}".format(i) for i in range(10)] + [
chr(ord("a") + i) for i in range(26)
]
NUM_RANGE = [1, 100]


def get_perturb_line_step(code_trace_org, err_suffx):
# Not all the variable types are valid for all the errors.
# For instance for index out of range an int var is not valid.
code_trace = get_valid_code_trace(code_trace_org, err_suffx)
if not code_trace:
return None, None, None
perturb_line = get_random_list_sample(code_trace.keys(), 1)[0]
perturb_step = get_random_list_sample(code_trace[perturb_line], 1)[0]
perturb_var = get_random_list_sample(perturb_step.keys(), 1)[0]
perturb_val = perturb_step[perturb_var]
return int(perturb_line), perturb_var, perturb_val


def perturb_program(red, code_trace, err_suffx, error_expr_factory_obj):
perturb_line, perturb_var, perturb_val = get_perturb_line_step(
code_trace, err_suffx
)
if perturb_line is None:
return 0
perturb_expression, is_err_present = error_expr_factory_obj.add_err(
err_suffx, perturb_var, perturb_val
)
# TODO(rishab): Need to be careful to ensure that that the insertion
# line is not an AssignmentNode in RedBaron.
if err_suffx == "math_domain_err":
# The sqrt function needs to be imported so that sqrt function
# can be called. I am not sure if we can just add the expression
# without proper imports.
import_statement, perturb_expression = perturb_expression.split(";")
red.at(perturb_line).insert_before(import_statement, offset=perturb_line - 1)
red.at(perturb_line + 1).insert_after(perturb_expression)
else:
red.at(perturb_line).insert_after(perturb_expression)
return is_err_present


def add_error(
org_code_fp, code_trace_fp, err_code_fp, err_suffx, error_expr_factory_obj
):
# We can optimize the code by passing the read file.
# But for now to ensure isolation, I am doing it
# explicitly.
code_trace = load_json(code_trace_fp)
# To keep this function generic the name of the output
# code file has the error type and indicator whether the
# the error is present or not as suffix.
err_code_fp = err_code_fp.replace(".txt", "-" + err_suffx + ".txt")
program = load_data(org_code_fp).strip()
red = rb.RedBaron(program)
try:
is_err_present = perturb_program(
red, code_trace, err_suffx, error_expr_factory_obj
)
err_code_fp = err_code_fp.replace(".txt", "-" + str(is_err_present) + ".txt")
except Exception as e:
# We can handle the exception as we want.
# But for the time being we can return False.
# import pdb;pdb.set_trace()
return False

write_csv(red.dumps(), err_code_fp)
return True
10 changes: 10 additions & 0 deletions error_generation/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
base_path: /Users/rishabgoel/Documents/compressive-ipagnn/data/codeforces
trace_code_path: error_generation/trace_code.py
process_suffix: processed
errors:
- zero_err
- assert_err
- not_subscriptable_err
- idx_out_range_err
- math_domain_err
- not_iterable_err
231 changes: 231 additions & 0 deletions error_generation/error_expression_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright (C) 2021 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from misc_utils import (
get_random_list_sample,
load_json,
load_data,
write_csv,
get_random_int,
get_random_float,
)


class ErrorFactory:
"""TODO (rishab):
1. implement methods for var name not defined and
operand mismatch.
2. make the expressions more complex.
"""

VARIABLE_NAMES = ["tmp{}".format(i) for i in range(10)] + [
chr(ord("a") + i) for i in range(26)
]
NUM_RANGE = [1, 100]

def __init__(self):
self._builders = {
"zero_err": self.get_zero_perturb_expression,
"assert_err": self.get_assert_perturb_expression,
"not_subscriptable_err": self.get_not_subscriptable_perturb_expression,
"idx_out_range_err": self.get_index_range_perturb_expression,
"undef_var_err": self.get_undef_name_perturb_expression, # Caution: not implemented properly.
"math_domain_err": self.get_math_domain_perturb_expression,
"not_iterable_err": self.get_int_not_iterable_perturb_expression,
}

def get_zero_perturb_expression(self, perturb_var, perturb_val):
assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0]
is_zerro_err = get_random_int(0, 1)
# is_zerro_err = 1
if is_zerro_err:
numerator = get_random_float(*self.NUM_RANGE, size=1)[0]
return (
assign_var
+ "="
+ str(int(numerator))
+ "/("
+ str(perturb_val)
+ "-"
+ perturb_var
+ ")",
is_zerro_err,
)
else:
perturb_val_offset, numerator = get_random_float(*self.NUM_RANGE, size=2)
perturb_val = perturb_val + int(perturb_val_offset)
return (
assign_var
+ "="
+ str(int(numerator))
+ "/("
+ str(perturb_val)
+ "-"
+ perturb_var
+ ")",
is_zerro_err,
)

def get_assert_perturb_expression(self, perturb_var, perturb_val):
is_assert_err = get_random_int(0, 1)
# is_assert_err = 1
if is_assert_err:
perturb_val_offset = get_random_float(*self.NUM_RANGE, size=1)[0]
perturb_val = perturb_val + int(perturb_val_offset)
return (
"assert " + perturb_var + "==" + str(perturb_val),
is_assert_err,
)
else:
return (
"assert " + perturb_var + "==" + str(perturb_val),
is_assert_err,
)

def get_not_subscriptable_perturb_expression(self, perturb_var, perturb_val):
is_not_subscriptable_err = get_random_int(0, 1)
# is_not_subscriptable_err = 1
if is_not_subscriptable_err:
random_val, numerator = get_random_float(*self.NUM_RANGE, size=2)
return (
perturb_var + "[" + str(int(numerator)) + "] = " + str(int(random_val)),
is_not_subscriptable_err,
)
else:
return (
"",
is_not_subscriptable_err,
)

def get_index_range_perturb_expression(self, perturb_var, perturb_val):
"""This will occur very less frequently and hence we perhaps
need to rethink how to handle generate the error.
"""
is_index_range_err = get_random_int(0, 1)
# is_index_range_err = 1
if is_index_range_err:
random_ass = get_random_float(*self.NUM_RANGE, size=1)[0]
return (
perturb_var
+ "["
+ str(len(perturb_val))
+ "] = "
+ str(int(random_ass)),
is_index_range_err,
)
else:
valid_idx = int(get_random_float(*[0, len(perturb_val) - 1], size=1)[0])
random_ass = get_random_float(*self.NUM_RANGE, size=1)[0]
return (
perturb_var + "[" + str(valid_idx) + "] = " + str(random_ass),
is_index_range_err,
)

def get_undef_name_perturb_expression(self, perturb_var, perturb_val):
"""Not implemented as per our requirements."""
is_undef_name_err = get_random_int(0, 1)
# is_undef_name_err = 1
if is_undef_name_err:
undef_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0]
return (
perturb_var + "=" + undef_var + "+" + str(perturb_val),
is_undef_name_err,
)
else:
return (
"",
is_undef_name_err,
)

def get_math_domain_perturb_expression(self, perturb_var, perturb_val):
"""The current implementation may cause unforeseen issues when the
is_math_domain_err is 0 as the assign_var can be a part of the program. Also, we may
perhaps need to refine how we import math module."""
is_math_domain_err = get_random_int(0, 1)
# is_math_domain_err = 1
if is_math_domain_err:
assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0]
if perturb_val >= 0:
random_ass = (
str(-1 * int(get_random_float(*self.NUM_RANGE, size=1)[0]))
+ "*"
+ perturb_var
)
else:
random_ass = (
str(int(get_random_float(*self.NUM_RANGE, size=1)[0]))
+ "*"
+ perturb_var
)
return (
"import math;"
+ assign_var
+ "="
+ "math.sqrt("
+ str(random_ass)
+ ")",
is_math_domain_err,
)
else:
assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0]
if perturb_val >= 0:
random_ass = (
str(int(get_random_float(*self.NUM_RANGE, size=1)[0]))
+ "*"
+ perturb_var
)
else:
random_ass = (
str(-1 * int(get_random_float(*self.NUM_RANGE, size=1)[0]))
+ "*"
+ perturb_var
)
return (
"import math;"
+ assign_var
+ "="
+ "math.sqrt("
+ str(random_ass)
+ ")",
is_math_domain_err,
)

def _relevant_operand_val_type(self, val, is_same):
pass

def get_operand_type_mismatch_perturb_expression(self, perturb_var, perturb_val):
pass

def get_int_not_iterable_perturb_expression(self, perturb_var, perturb_val):
"""TODO: 1. Add more variants of the for loop.
2. Add logic to include the while loop.
"""
is_int_not_iterable_err = get_random_int(0, 1)
# is_int_not_iterable_err = 1
if is_int_not_iterable_err:
assign_var = get_random_list_sample(self.VARIABLE_NAMES, 1)[0]
random_ass = int(get_random_float(*self.NUM_RANGE, size=1)[0])
return (
"{}=[{}+val for val in {}]".format(assign_var, random_ass, perturb_var),
is_int_not_iterable_err,
)
else:
return "", is_int_not_iterable_err

def add_err(self, err_type, perturb_var, perturb_val):
expr_builder = self._builders.get(err_type.lower(), None)
if not expr_builder:
raise ValueError(err_type + " is not a valid error generation function.")
return expr_builder(perturb_var, perturb_val)
Loading