@@ -42,6 +42,8 @@ bool ParsePlace(const pir::Type& type, OpFuncType* type_) {
return true;
}
}
} else if (!type) {
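// A null type cannot provide a place; report failure instead of throwing.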
return false;
} else {
PADDLE_THROW(common::errors::PreconditionNotMet(
"Only support AllocatedDenseTensorType and "
14 changes: 14 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -945,6 +945,20 @@ std::vector<std::vector<pir::Value>> TuplePushOpVjpInterfaceModel::Vjp(
res[i].resize(1);
res[i][0] = pop_op.result(i - 1);
}

// Set the pop op's stop_gradient attribute from the pushed values
// (defaults to true when a pushed value carries no attribute).
std::vector<pir::Attribute> outs_stop_gradient;
for (auto i = 1u; i < op->num_operands(); ++i) {
auto value = op->operand_source(i);
auto bool_attr = value.attribute<pir::BoolAttribute>(kStopGradientAttrName);
outs_stop_gradient.push_back(
bool_attr ? bool_attr
: pir::BoolAttribute::get(pir::IrContext::Instance(), true));
}

pop_op->set_attribute(
kStopGradientAttrName,
pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient));
return res;
}

43 changes: 43 additions & 0 deletions python/paddle/autograd/backward_utils.py
@@ -756,3 +756,46 @@ def update_while_output_stopgradient(while_op, yield_op):
# Set to False if stop_gradient is False
if not stop_grad:
while_op.result(i - 1).stop_gradient = False


def find_index_of_yield(value, yield_op):
# Return the operand index of `value` in `yield_op`, or -1 if it is not found.
for i, v in enumerate(yield_op.operands_source()):
if v.is_same(value):
return i
return -1


def update_tuple_pop_origin_inputs(tuple_pop_outputs):
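# Map the results of a cf.tuple_pop back to the values pushed by the matching
# cf.tuple_push. When a pushed value's first use is a cf.yield, the parent
# (if) op's corresponding result is used instead. The stack inlet is dropped.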
if tuple_pop_outputs == []:
return tuple_pop_outputs
op = tuple_pop_outputs[0][0].get_defining_op()
assert op.name() == "cf.tuple_pop"
stack_op = op.operand_source(0).get_defining_op()
tuple_push_inputs = stack_op.result(1).first_use().owner().operands_source()
tuple_push_inputs_with_if = []
for input in tuple_push_inputs:
if input.first_use().owner().name() == "cf.yield":
yield_op = input.first_use().owner()
index = find_index_of_yield(input, yield_op)
assert index != -1
tuple_push_inputs_with_if.append(
yield_op.get_parent_block().parent_op.result(index)
)
else:
tuple_push_inputs_with_if.append(input)

# Skip the stack inlet (the first operand of tuple_push).
return tuple_push_inputs_with_if[1:]


def value_in_block(value, block):
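# Return True if `value` is defined in `block` or in one of its ancestor
# blocks, walking upward until the builtin.module block.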
value_block = value.get_defining_op().get_parent_block()
while block.parent_op.name() != "builtin.module":
if block == value_block:
return True
block = block.parent_block
# now block is module op's block
if block == value_block:
return True

return False
92 changes: 90 additions & 2 deletions python/paddle/autograd/ir_backward.py
@@ -47,7 +47,9 @@
some_in_set,
update_if_output_stopgradient,
update_no_grad_set_by_stopgradient,
update_tuple_pop_origin_inputs,
update_while_output_stopgradient,
value_in_block,
warning_once,
while_prune_check,
)
@@ -588,6 +590,16 @@ def update_input_grad_map(op, input_grads, all_inputs):
state.value_to_valuegrad[input].append([input_grad])
i += 1

def update_if_double_grad_input_grad_map(input_grads, all_inputs):
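# For double grad of an if op, grads are matched to inputs purely by
# position (the cond input has already been dropped by the caller).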
assert len(input_grads) == len(
all_inputs
), "input_grads should have the same length as all_inputs"
for input, input_grad in zip(all_inputs, input_grads):
if isinstance(input_grad, list):
state.value_to_valuegrad[input].append(input_grad)
else:
state.value_to_valuegrad[input].append([input_grad])

def append_yield(
block,
base_op,
@@ -634,7 +646,15 @@ def append_yield(
new_value = return_map_value(
value, control_flow_value_to_copyvalue_map
)
append_full_like(0.0, new_value, value, state, backward_ops)
if not value_in_block(new_value, block):
# new_value is defined by another if block's tuple_pop; use a fake
# value as its gradient placeholder instead of a zero-filled tensor.
state.value_to_valuegrad[value] = [
[paddle.pir.fake_value()]
]
else:
append_full_like(
0.0, new_value, value, state, backward_ops
)

input_grad = return_map_value(
state.value_to_valuegrad[value][0][0],
@@ -748,6 +768,42 @@ def append_yield(
origin_inputs = get_real_op_inputs(op)
for sub_block in op.blocks():
build_pipe_for_block(sub_block)
# Only needed for double grad of an if op: when a branch block starts with
# cf.tuple_pop, collect its results as extra vjp inputs so that their
# gradients are produced.
true_block = op.as_if_op().true_block()
false_block = op.as_if_op().false_block()

true_block_pop_inputs = []
true_block_pop_input_grad_stopgradients = []
if true_block.ops[0].name() == "cf.tuple_pop":
for result in true_block.ops[0].results():
true_block_pop_inputs.append([result])
true_block_pop_input_grad_stopgradients.append(
[result.stop_gradient]
)
false_block_pop_inputs = []
false_block_pop_input_grad_stopgradients = []
if false_block.ops[0].name() == 'cf.tuple_pop':
for result in false_block.ops[0].results():
false_block_pop_inputs.append([result])
false_block_pop_input_grad_stopgradients.append(
[result.stop_gradient]
)

if (
true_block_pop_inputs != []
or false_block_pop_inputs != []
):
inputs = (
inputs
+ true_block_pop_inputs
+ false_block_pop_inputs
)
input_grad_stopgradients = (
input_grad_stopgradients
+ true_block_pop_input_grad_stopgradients
+ false_block_pop_input_grad_stopgradients
)

with dynamic_shape_prim_vjp_guard(op, inputs):
input_grads = paddle.framework.core.call_vjp(
op,
@@ -777,6 +833,7 @@ def append_yield(
sub_control_flow_value_to_copyvalue_map = (
control_flow_value_to_copyvalue_map.copy()
)

append_backward_ops(
op,
[input[0] for input in inputs[1:]],
@@ -801,8 +858,36 @@ def append_yield(
)
for input_tuple in inputs_used_by_other_op:
state.value_to_valuegrad[input_tuple[0]] = []

# update input_grad map
update_input_grad_map(op, input_grads, origin_inputs)
if (
true_block_pop_inputs != []
or false_block_pop_inputs != []
):
true_block_pop_inputs = (
update_tuple_pop_origin_inputs(
true_block_pop_inputs
)
)
false_block_pop_inputs = (
update_tuple_pop_origin_inputs(
false_block_pop_inputs
)
)
# Drop the cond input and append the remapped tuple_pop origin inputs.
origin_inputs = (
origin_inputs[1:]
+ true_block_pop_inputs
+ false_block_pop_inputs
)
update_if_double_grad_input_grad_map(
input_grads, origin_inputs
)
else:
update_input_grad_map(
op, input_grads, origin_inputs
)

elif op.name() == "pd_op.while":
origin_inputs = get_real_op_inputs(op)
# prepare while[cond, loop_vars, other_input] other_input's grad
@@ -938,8 +1023,11 @@ def append_yield(
op.num_operands() == 0
and op.num_results() != 0
or op.name() == "pd_op.full_like"
or op.name() == "cf.tuple_pop"
):
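# cf.tuple_pop results may accumulate multiple pending grads that need
# add_n; results that never received a grad are skipped.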
for value in op.results():
if value not in state.value_to_valuegrad:
continue
if len(state.value_to_valuegrad[value]) > 1:
append_add_n(
op,
59 changes: 58 additions & 1 deletion test/dygraph_to_static/test_high_order_net.py
@@ -14,6 +14,7 @@

import unittest

import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
@@ -41,7 +42,7 @@ def forward(self, x, y):
class TestBackwardHasNoGradError(Dy2StTestBase):
@test_ast_only
@test_pir_only
def test_backward_has_no_grad_error(self):
def _test_backward_has_no_grad_error(self):
net = HighOrderNet()
static_net = paddle.jit.to_static(net, full_graph=True)

@@ -58,5 +59,61 @@ def test_backward_has_no_grad_error(self):
x_grad_grad.backward()


class HighOrderControlFlowNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.eps = 1e-5

def forward(self, x):
if x.numel() > 0:
variance, mean = (
paddle.var(x, axis=-1, unbiased=False, keepdim=True),
paddle.mean(x, axis=-1, keepdim=True),
)
y = (x - mean) / paddle.sqrt(variance + self.eps)
else:
y = x

x_grad = paddle.grad(y, x, create_graph=True)[0]

return x_grad.mean()


class HighOrderCompareNet(HighOrderControlFlowNet):
def __init__(self):
super().__init__()
self.eps = 1e-5

def forward(self, x):
variance, mean = (
paddle.var(x, axis=-1, unbiased=False, keepdim=True),
paddle.mean(x, axis=-1, keepdim=True),
)
y = (x - mean) / paddle.sqrt(variance + self.eps)

x_grad = paddle.grad(y, x, create_graph=True)[0]

return x_grad.mean()


class TestBackwardControlFlow(Dy2StTestBase):
@test_ast_only
@test_pir_only
def test_control_flow_high_order_backward(self):
conf_net = HighOrderControlFlowNet()
net = HighOrderCompareNet()
x = paddle.rand((5, 5)).astype('float32')
x.stop_gradient = False
static_net = paddle.jit.to_static(net, full_graph=True)
x_grad_grad = static_net(x)

conf_static_net = paddle.jit.to_static(conf_net, full_graph=True)
x_grad_grad_conf = conf_static_net(x)

np.testing.assert_allclose(
x_grad_grad.numpy(), x_grad_grad_conf.numpy()
)


if __name__ == "__main__":
unittest.main()