@@ -926,9 +926,10 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
926926 inputs = train_runnable_program .x_values
927927 params = train_runnable_program .param_values
928928 combined_inputs = list (itertools .chain (inputs , params ))
929- forward_end_idx = len (program .global_block ().ops )
929+ forward_prog_len = len (program .global_block ().ops )
930+ forward_end_idx = forward_prog_len - 1
930931 forward_end_op = None
931- if forward_end_idx > 0 :
932+ if forward_prog_len > 0 :
932933 forward_end_op = program .global_block ().ops [- 1 ]
933934 grad_info_map = [None ] * len (combined_inputs )
934935 with backend_guard (self ._backend ):
@@ -958,7 +959,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
958959 "grad_input_" ,
959960 )
960961 op_between_forward_and_backward = (
961- len (program .global_block ().ops ) - forward_end_idx
962+ len (program .global_block ().ops ) - forward_prog_len
962963 )
963964
964965 # call grad to get backward ops.
@@ -985,7 +986,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
985986 if forward_end_op is not None :
986987 for idx , op in enumerate (program .global_block ().ops ):
987988 if op == forward_end_op :
988- forward_end_idx = idx + 1
989+ forward_end_idx = idx
989990 break
990991
991992 for hooker in self ._hookers :
@@ -1019,11 +1020,12 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
10191020 output_grads_to_append = list (
10201021 filter (lambda x : not is_fake_value (x ), x_grad_value + p_grad_value )
10211022 )
1022- backward_end_op_index = len (program .global_block ().ops )
1023+ backward_prog_len = len (program .global_block ().ops )
1024+ backward_end_op_index = backward_prog_len - 1
10231025 paddle .base .libpaddle .pir .append_shadow_outputs (
10241026 program ,
10251027 output_grads_to_append ,
1026- backward_end_op_index ,
1028+ backward_prog_len ,
10271029 "grad_output_" ,
10281030 )
10291031
@@ -1036,7 +1038,11 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
10361038 [inputs , params , targets , x_grad_value , p_grad_value , o_grad_value ]
10371039 )
10381040 forward_index_pass = IndicesPreservePass (
1039- [forward_end_idx , backward_start_op_index , backward_end_op_index ],
1041+ [
1042+ forward_end_idx + 1 ,
1043+ backward_start_op_index + 1 ,
1044+ backward_end_op_index + 1 ,
1045+ ],
10401046 fused_bn_add_act_pass ,
10411047 )
10421048 program = forward_index_pass (program )
@@ -1049,17 +1055,17 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
10491055 o_grad_value ,
10501056 ) = fused_bn_add_act_pass .values
10511057 (
1052- forward_end_idx ,
1053- backward_start_op_index ,
1054- backward_end_op_index ,
1058+ forward_end_range ,
1059+ backward_start_range ,
1060+ backward_end_op_range ,
10551061 ) = forward_index_pass .new_indices
10561062
10571063 return RunnableProgram (
10581064 program ,
10591065 (inputs , params , targets ),
10601066 (x_grad_value , p_grad_value , o_grad_value ),
1061- (0 , forward_end_idx ),
1062- (backward_start_op_index , backward_end_op_index ),
1067+ (0 , forward_end_range ),
1068+ (backward_start_range , backward_end_op_range ),
10631069 )
10641070
10651071 def _prepare_attributes (self , in_sot_mode = False ):
0 commit comments