Skip to content

Commit 84b5ade

Browse files
committed
fix recompute index bug
1 parent b2a43a7 commit 84b5ade

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

python/paddle/base/core.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -567,6 +567,15 @@ def _enable_dist_prim_all():
567567

568568
def _enable_auto_recompute():
569569
flag = os.getenv("FLAGS_enable_auto_recompute")
570+
571+
# NOTE(chenxi67): open recompute when cinn is enabled
572+
from paddle.base.framework import in_cinn_mode
573+
574+
if in_cinn_mode():
575+
if flag and flag.lower() in ("0", "false"):
576+
return False
577+
else:
578+
return True
570579
if flag and flag.lower() in ("1", "true"):
571580
return True
572581
else:

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -926,9 +926,10 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
926926
inputs = train_runnable_program.x_values
927927
params = train_runnable_program.param_values
928928
combined_inputs = list(itertools.chain(inputs, params))
929-
forward_end_idx = len(program.global_block().ops)
929+
forward_prog_len = len(program.global_block().ops)
930+
forward_end_idx = forward_prog_len - 1
930931
forward_end_op = None
931-
if forward_end_idx > 0:
932+
if forward_prog_len > 0:
932933
forward_end_op = program.global_block().ops[-1]
933934
grad_info_map = [None] * len(combined_inputs)
934935
with backend_guard(self._backend):
@@ -958,7 +959,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
958959
"grad_input_",
959960
)
960961
op_between_forward_and_backward = (
961-
len(program.global_block().ops) - forward_end_idx
962+
len(program.global_block().ops) - forward_prog_len
962963
)
963964

964965
# call grad to get backward ops.
@@ -985,7 +986,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
985986
if forward_end_op is not None:
986987
for idx, op in enumerate(program.global_block().ops):
987988
if op == forward_end_op:
988-
forward_end_idx = idx + 1
989+
forward_end_idx = idx
989990
break
990991

991992
for hooker in self._hookers:
@@ -1019,11 +1020,12 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
10191020
output_grads_to_append = list(
10201021
filter(lambda x: not is_fake_value(x), x_grad_value + p_grad_value)
10211022
)
1022-
backward_end_op_index = len(program.global_block().ops)
1023+
backward_prog_len = len(program.global_block().ops)
1024+
backward_end_op_index = backward_prog_len - 1
10231025
paddle.base.libpaddle.pir.append_shadow_outputs(
10241026
program,
10251027
output_grads_to_append,
1026-
backward_end_op_index,
1028+
backward_prog_len,
10271029
"grad_output_",
10281030
)
10291031

@@ -1036,7 +1038,11 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
10361038
[inputs, params, targets, x_grad_value, p_grad_value, o_grad_value]
10371039
)
10381040
forward_index_pass = IndicesPreservePass(
1039-
[forward_end_idx, backward_start_op_index, backward_end_op_index],
1041+
[
1042+
forward_end_idx + 1,
1043+
backward_start_op_index + 1,
1044+
backward_end_op_index + 1,
1045+
],
10401046
fused_bn_add_act_pass,
10411047
)
10421048
program = forward_index_pass(program)
@@ -1049,17 +1055,17 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
10491055
o_grad_value,
10501056
) = fused_bn_add_act_pass.values
10511057
(
1052-
forward_end_idx,
1053-
backward_start_op_index,
1054-
backward_end_op_index,
1058+
forward_end_range,
1059+
backward_start_range,
1060+
backward_end_op_range,
10551061
) = forward_index_pass.new_indices
10561062

10571063
return RunnableProgram(
10581064
program,
10591065
(inputs, params, targets),
10601066
(x_grad_value, p_grad_value, o_grad_value),
1061-
(0, forward_end_idx),
1062-
(backward_start_op_index, backward_end_op_index),
1067+
(0, forward_end_range),
1068+
(backward_start_range, backward_end_op_range),
10631069
)
10641070

10651071
def _prepare_attributes(self, in_sot_mode=False):

0 commit comments

Comments (0)