Skip to content

Commit

Permalink
[Dy2stat]modify dy2stat error message in compile time (#35320)
Browse files Browse the repository at this point in the history
* modify dy2stat error message in compile time

* fix variable name
  • Loading branch information
0x45f authored Sep 1, 2021
1 parent b53887f commit b24f84c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 15 deletions.
73 changes: 61 additions & 12 deletions python/paddle/fluid/dygraph/dygraph_to_static/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import six
import sys
import traceback
import linecache

from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginInfo, global_origin_info_map

Expand All @@ -29,6 +30,9 @@
DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR = 0

SOURCE_CODE_RANGE = 5
BLANK_COUNT_BEFORE_FILE_STR = 4


def attach_error_data(error, in_runtime=False):
"""
Expand All @@ -40,6 +44,7 @@ def attach_error_data(error, in_runtime=False):
Returns:
An error attached data about original source code information and traceback.
"""

e_type, e_value, e_traceback = sys.exc_info()
tb = traceback.extract_tb(e_traceback)[1:]

Expand Down Expand Up @@ -82,12 +87,49 @@ def __init__(self, location, function_name, source_code):
def formated_message(self):
# self.source_code may be empty in some functions.
# For example, decorator generated function
return ' File "{}", line {}, in {}\n\t{}'.format(
return ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n\t{}'.format(
self.location.filepath, self.location.lineno, self.function_name,
self.source_code.lstrip()
if isinstance(self.source_code, str) else self.source_code)


class TraceBackFrameRange(OriginInfo):
"""
Traceback frame information.
"""

def __init__(self, location, function_name):
self.location = location
self.function_name = function_name
self.source_code = []
blank_count = []
begin_lineno = max(1, self.location.lineno - int(SOURCE_CODE_RANGE / 2))

for i in range(begin_lineno, begin_lineno + SOURCE_CODE_RANGE):
line = linecache.getline(self.location.filepath, i)
line_lstrip = line.strip()
self.source_code.append(line_lstrip)
blank_count.append(len(line) - len(line_lstrip))

if i == self.location.lineno:
hint_msg = '~' * len(self.source_code[-1]) + ' <--- HERE'
self.source_code.append(hint_msg)
blank_count.append(blank_count[-1])
linecache.clearcache()

min_black_count = min(blank_count)
for i in range(len(self.source_code)):
self.source_code[i] = ' ' * (blank_count[i] - min_black_count +
BLANK_COUNT_BEFORE_FILE_STR * 2
) + self.source_code[i]

def formated_message(self):
msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n'.format(
self.location.filepath, self.location.lineno, self.function_name)
# add empty line after range code
return msg + '\n'.join(self.source_code) + '\n'


class ErrorData(object):
"""
Error data attached to an exception which is raised in un-transformed code.
Expand Down Expand Up @@ -128,26 +170,34 @@ def create_message(self):
return '\n'.join(message_lines)

# Step2: Optimizes stack information with source code information of dygraph from user.
for filepath, lineno, funcname, code in self.origin_traceback:
whether_source_range = True
for filepath, lineno, funcname, code in self.origin_traceback[::-1]:
loc = Location(filepath, lineno)

dygraph_func_info = self.origin_info_map.get(loc.line_location,
None)
if dygraph_func_info:
# TODO(liym27): more information to prompt users that this is the original information.
# Replaces trace stack information about transformed static code with original dygraph code.
traceback_frame = self.origin_info_map[loc.line_location]
else:
traceback_frame = TraceBackFrame(loc, funcname, code)

message_lines.append(traceback_frame.formated_message())
if whether_source_range:
traceback_frame = TraceBackFrameRange(
dygraph_func_info.location,
dygraph_func_info.function_name)
whether_source_range = False
else:
traceback_frame = TraceBackFrame(
dygraph_func_info.location,
dygraph_func_info.function_name,
dygraph_func_info.source_code)
# Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2
message_lines.insert(2, traceback_frame.formated_message())

# Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
# NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
# is gather than 1, for example, the error_type is IndentationError.
format_exception = traceback.format_exception_only(self.error_type,
self.error_value)
error_message = [" " * 4 + line for line in format_exception]
error_message = [
" " * BLANK_COUNT_BEFORE_FILE_STR + line
for line in format_exception
]
message_lines.extend(error_message)

return '\n'.join(message_lines)
Expand Down Expand Up @@ -175,7 +225,6 @@ def _simplify_error_value(self):
self.error_value = self.error_type(error_value_str)

def raise_new_exception(self):

# Raises the origin error if disable dygraph2static error module,
if int(os.getenv(DISABLE_ERROR_ENV_NAME, DEFAULT_DISABLE_NEW_ERROR)):
raise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ def set_message(self):
['File "{}", line 35, in func_error_in_compile_time'.format(self.filepath),
'inner_func()',
'File "{}", line 28, in inner_func'.format(self.filepath),
'def inner_func():',
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'return',
]

def set_func_call(self):
Expand All @@ -242,7 +245,11 @@ def set_message(self):
self.expected_message = \
[
'File "{}", line 46, in func_error_in_compile_time_2'.format(self.filepath),
'x = fluid.layers.reshape(x, shape=[1, 2])'
'def func_error_in_compile_time_2(x):',
'x = fluid.dygraph.to_variable(x)',
'x = fluid.layers.reshape(x, shape=[1, 2])',
'<--- HERE',
'return x'
]


Expand All @@ -261,7 +268,10 @@ def set_exception_type(self):
def set_message(self):
self.expected_message = \
['File "{}", line 91, in forward'.format(self.filepath),
'@paddle.jit.to_static',
'def forward(self):',
'self.test_func()',
'<--- HERE'
]

def set_func_call(self):
Expand Down Expand Up @@ -318,7 +328,12 @@ def set_exception_type(self):
def set_message(self):
self.expected_message = \
['File "{}", line 80, in forward'.format(self.filepath),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'def forward(self, x):',
'y = self._linear(x)',
'z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")',
'<--- HERE',
'out = fluid.layers.mean(y[z])',
'return out'
]

def set_func_call(self):
Expand All @@ -329,7 +344,7 @@ def test_error(self):
self._test_raise_new_exception()


# Situation 4: NotImplementedError
# # Situation 4: NotImplementedError
class TestErrorInOther(unittest.TestCase):
def test(self):
paddle.disable_static()
Expand Down

0 comments on commit b24f84c

Please sign in to comment.