1616import six
1717import sys
1818import traceback
19+ import linecache
1920
2021from paddle .fluid .dygraph .dygraph_to_static .origin_info import Location , OriginInfo , global_origin_info_map
2122
2930DISABLE_ERROR_ENV_NAME = "TRANSLATOR_DISABLE_NEW_ERROR"
3031DEFAULT_DISABLE_NEW_ERROR = 0
3132
33+ SOURCE_CODE_RANGE = 5
34+ BLANK_COUNT_BEFORE_FILE_STR = 4
35+
3236
3337def attach_error_data (error , in_runtime = False ):
3438 """
@@ -40,6 +44,7 @@ def attach_error_data(error, in_runtime=False):
4044 Returns:
4145 An error attached data about original source code information and traceback.
4246 """
47+
4348 e_type , e_value , e_traceback = sys .exc_info ()
4449 tb = traceback .extract_tb (e_traceback )[1 :]
4550
@@ -82,12 +87,49 @@ def __init__(self, location, function_name, source_code):
8287 def formated_message (self ):
8388 # self.source_code may be empty in some functions.
8489 # For example, decorator generated function
85- return ' File "{}", line {}, in {}\n \t {}' .format (
90+ return ' ' * BLANK_COUNT_BEFORE_FILE_STR + ' File "{}", line {}, in {}\n \t {}' .format (
8691 self .location .filepath , self .location .lineno , self .function_name ,
8792 self .source_code .lstrip ()
8893 if isinstance (self .source_code , str ) else self .source_code )
8994
9095
96+ class TraceBackFrameRange (OriginInfo ):
97+ """
98+ Traceback frame information.
99+ """
100+
101+ def __init__ (self , location , function_name ):
102+ self .location = location
103+ self .function_name = function_name
104+ self .source_code = []
105+ blank_count = []
106+ begin_lineno = max (1 , self .location .lineno - int (SOURCE_CODE_RANGE / 2 ))
107+
108+ for i in range (begin_lineno , begin_lineno + SOURCE_CODE_RANGE ):
109+ line = linecache .getline (self .location .filepath , i )
110+ line_lstrip = line .strip ()
111+ self .source_code .append (line_lstrip )
112+ blank_count .append (len (line ) - len (line_lstrip ))
113+
114+ if i == self .location .lineno :
115+ hint_msg = '~' * len (self .source_code [- 1 ]) + ' <--- HERE'
116+ self .source_code .append (hint_msg )
117+ blank_count .append (blank_count [- 1 ])
118+ linecache .clearcache ()
119+
120+ min_black_count = min (blank_count )
121+ for i in range (len (self .source_code )):
122+ self .source_code [i ] = ' ' * (blank_count [i ] - min_black_count +
123+ BLANK_COUNT_BEFORE_FILE_STR * 2
124+ ) + self .source_code [i ]
125+
126+ def formated_message (self ):
127+ msg = ' ' * BLANK_COUNT_BEFORE_FILE_STR + 'File "{}", line {}, in {}\n ' .format (
128+ self .location .filepath , self .location .lineno , self .function_name )
129+ # add empty line after range code
130+ return msg + '\n ' .join (self .source_code ) + '\n '
131+
132+
91133class ErrorData (object ):
92134 """
93135 Error data attached to an exception which is raised in un-transformed code.
@@ -128,26 +170,34 @@ def create_message(self):
128170 return '\n ' .join (message_lines )
129171
130172 # Step2: Optimizes stack information with source code information of dygraph from user.
131- for filepath , lineno , funcname , code in self .origin_traceback :
173+ whether_source_range = True
174+ for filepath , lineno , funcname , code in self .origin_traceback [::- 1 ]:
132175 loc = Location (filepath , lineno )
133-
134176 dygraph_func_info = self .origin_info_map .get (loc .line_location ,
135177 None )
136178 if dygraph_func_info :
137- # TODO(liym27): more information to prompt users that this is the original information.
138- # Replaces trace stack information about transformed static code with original dygraph code.
139- traceback_frame = self .origin_info_map [loc .line_location ]
140- else :
141- traceback_frame = TraceBackFrame (loc , funcname , code )
142-
143- message_lines .append (traceback_frame .formated_message ())
179+ if whether_source_range :
180+ traceback_frame = TraceBackFrameRange (
181+ dygraph_func_info .location ,
182+ dygraph_func_info .function_name )
183+ whether_source_range = False
184+ else :
185+ traceback_frame = TraceBackFrame (
186+ dygraph_func_info .location ,
187+ dygraph_func_info .function_name ,
188+ dygraph_func_info .source_code )
189+ # Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2
190+ message_lines .insert (2 , traceback_frame .formated_message ())
144191
145192 # Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
146193 # NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
147194 # is gather than 1, for example, the error_type is IndentationError.
148195 format_exception = traceback .format_exception_only (self .error_type ,
149196 self .error_value )
150- error_message = [" " * 4 + line for line in format_exception ]
197+ error_message = [
198+ " " * BLANK_COUNT_BEFORE_FILE_STR + line
199+ for line in format_exception
200+ ]
151201 message_lines .extend (error_message )
152202
153203 return '\n ' .join (message_lines )
@@ -175,7 +225,6 @@ def _simplify_error_value(self):
175225 self .error_value = self .error_type (error_value_str )
176226
177227 def raise_new_exception (self ):
178-
179228 # Raises the origin error if disable dygraph2static error module,
180229 if int (os .getenv (DISABLE_ERROR_ENV_NAME , DEFAULT_DISABLE_NEW_ERROR )):
181230 raise
0 commit comments