Merged
237 changes: 54 additions & 183 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -135,40 +135,24 @@ void divide_grad(const Tensor& x,
                  Tensor* dy) {
   if (dy) {
     // dy = -(x/y^2) * dout
-    auto dy_res = -(x / (y * y)) * out_grad;
-    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
+    auto dy_res = -out_grad * (x / y / y);
+    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != y.dims()) {
       auto dy_tmp = reduce_as<T>(dy_res, y);
       set_output<T>(dy_tmp, dy);
     } else {
-      if (out_grad.dims() != y.dims()) {
-        phi::DDim reduce_dim =
-            get_reduce_dims_from_out(out_grad.dims(), y.dims());
-        auto dy_reduce_res =
-            sum<T>(dy_res, common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
-      } else {
-        set_output<T>(dy_res, dy);
-      }
+      set_output<T>(dy_res, dy);
     }
   }  // indicate we will compute dy
   if (dx) {
     // dx = (1/y) * dout
-    Tensor one_tensor = full_scalar<T>(1.0, y.dtype());
-    auto dx_res = one_tensor / y * out_grad;
-    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
+    auto dx_res = out_grad / y;
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != x.dims()) {
       auto dx_tmp = reduce_as<T>(dx_res, x);
       set_output<T>(dx_tmp, dx);
     } else {
-      if (out_grad.dims() != x.dims()) {
-        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-        auto dx_reduce_res =
-            sum<T>(dx_res, common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
-      } else {
-        set_output<T>(dx_res, dx);
-      }
+      set_output<T>(dx_res, dx);
     }
   }  // indicate we will compute dx
 }
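Every hunk in this file follows the same pattern: the old static-shape path computed the reduce axes with get_reduce_dims_from_out, summed, then reshaped, while the new code folds that case into the existing reduce_as<T> branch by also testing out_grad.dims() != y.dims(). The divide gradients themselves are unchanged: for out = x / y, dx = dout / y and dy = -dout * x / y^2. A minimal standalone sketch (plain C++, no Paddle types; the finite-difference check is added here purely for illustration) of those scalar formulas:

```cpp
#include <cassert>
#include <cmath>

int main() {
  const double x = 3.0, y = 2.0, dout = 1.0, eps = 1e-6;
  const double dx = dout / y;             // (1/y) * dout
  const double dy = -dout * (x / y / y);  // -(x/y^2) * dout
  // Finite-difference checks of out = x / y:
  const double fd_dx = ((x + eps) / y - (x - eps) / y) / (2 * eps);
  const double fd_dy = (x / (y + eps) - x / (y - eps)) / (2 * eps);
  assert(std::fabs(dx - fd_dx) < 1e-6 && std::fabs(dy - fd_dy) < 1e-5);
  return 0;
}
```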
@@ -601,37 +585,22 @@ void add_grad(const Tensor& x,
               Tensor* dx,
               Tensor* dy) {
   if (dy) {
-    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != y.dims()) {
       auto dy_tmp = reduce_as<T>(out_grad, y);
       set_output<T>(dy_tmp, dy);
     } else {
-      if (out_grad.dims() != y.dims()) {
-        phi::DDim reduce_dim =
-            get_reduce_dims_from_out(out_grad.dims(), y.dims());
-        auto dy_reduce_res =
-            out_grad.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
-      } else {
-        by_pass<T>(out_grad, dy);
-      }
+      by_pass<T>(out_grad, dy);
     }
   }
 
   if (dx) {
-    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != x.dims()) {
       auto dx_tmp = reduce_as<T>(out_grad, x);
       set_output<T>(dx_tmp, dx);
     } else {
-      if (out_grad.dims() != x.dims()) {
-        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-        auto dx_reduce_res =
-            out_grad.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
-      } else {
-        by_pass<T>(out_grad, dx);
-      }
+      by_pass<T>(out_grad, dx);
     }
   }
 }
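add_grad routes out_grad straight through (by_pass) when the shapes already match; otherwise reduce_as<T> has to sum out_grad over the broadcast axes of the addition. A sketch of that axis selection under right-aligned NumPy-style broadcasting; this illustrates the assumed semantics, not Paddle's actual implementation:

```cpp
#include <cstdint>
#include <vector>

// Axes of out_dims to sum so the result matches in_dims: leading axes
// that in_dims lacks, plus axes where in_dims is 1 but out_dims is not.
std::vector<int64_t> reduce_axes(const std::vector<int64_t>& out_dims,
                                 const std::vector<int64_t>& in_dims) {
  std::vector<int64_t> axes;
  const int64_t offset = static_cast<int64_t>(out_dims.size()) -
                         static_cast<int64_t>(in_dims.size());
  for (int64_t i = 0; i < static_cast<int64_t>(out_dims.size()); ++i) {
    if (i < offset || (in_dims[i - offset] == 1 && out_dims[i] != 1)) {
      axes.push_back(i);
    }
  }
  return axes;
}

int main() {
  // out_grad of shape [4, 3, 2] vs y of shape [3, 1]: sum over axes 0 and 2.
  return reduce_axes({4, 3, 2}, {3, 1}) == std::vector<int64_t>{0, 2} ? 0 : 1;
}
```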
@@ -645,36 +614,21 @@ void subtract_grad(const Tensor& x,
                    Tensor* dy) {
   if (dy) {
     auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
-    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != y.dims()) {
       auto dy_tmp = reduce_as<T>(scale_out_grad, y);
       set_output<T>(dy_tmp, dy);
     } else {
-      if (out_grad.dims() != y.dims()) {
-        phi::DDim reduce_dim =
-            get_reduce_dims_from_out(out_grad.dims(), y.dims());
-        auto dy_reduce_res =
-            scale_out_grad.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
-      } else {
-        by_pass<T>(scale_out_grad, dy);
-      }
+      set_output<T>(scale_out_grad, dy);
     }
   }
   if (dx) {
-    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != x.dims()) {
       auto dx_tmp = reduce_as<T>(out_grad, x);
       set_output<T>(dx_tmp, dx);
     } else {
-      if (out_grad.dims() != x.dims()) {
-        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-        auto dx_reduce_res =
-            out_grad.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
-      } else {
-        by_pass<T>(out_grad, dx);
-      }
+      by_pass<T>(out_grad, dx);
     }
   }
 }
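For out = x - y the gradients are dx = dout and dy = -dout, which is what the scale<T>(out_grad, -1.0, 0.0, true) call produces. A scalar sketch of the assumed scale semantics (scale * t + bias when bias_after_scale is true); note the new dy branch writes the scaled tensor with set_output rather than by_pass, presumably because scale_out_grad is already a fresh tensor:

```cpp
#include <cassert>

// Assumed semantics: s * t + b when bias_after_scale, else s * (t + b).
double scale_op(double t, double s, double b, bool bias_after_scale) {
  return bias_after_scale ? s * t + b : s * (t + b);
}

int main() {
  // scale(out_grad, -1.0, 0.0, true) == -out_grad, i.e. d(x - y)/dy = -1.
  assert(scale_op(2.5, -1.0, 0.0, true) == -2.5);
  return 0;
}
```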
@@ -689,41 +643,23 @@ void multiply_grad(const Tensor& x,
   if (x_grad) {
     auto x_grad_unreduce = out_grad * y;
     if (has_dynamic_shape(x.shape()) ||
-        has_dynamic_shape(x_grad_unreduce.shape())) {
+        has_dynamic_shape(x_grad_unreduce.shape()) ||
+        x_grad_unreduce.dims() != x.dims()) {
       auto x_grad_reduced = reduce_as<T>(x_grad_unreduce, x);
       set_output<T>(x_grad_reduced, x_grad);
     } else {
-      if (x_grad_unreduce.dims() != x.dims()) {
-        auto axes = get_reduce_dims_from_out(x_grad_unreduce.dims(), x.dims());
-        auto x_grad_reduced = x_grad_unreduce.sum(
-            common::vectorize(axes), x_grad_unreduce.dtype(), false);
-        if (x_grad_reduced.dims().size() != x.dims().size()) {
-          x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
-        }
-        set_output<T>(x_grad_reduced, x_grad);
-      } else {
-        set_output<T>(x_grad_unreduce, x_grad);
-      }
+      set_output<T>(x_grad_unreduce, x_grad);
     }
   }
   if (y_grad) {
     auto y_grad_unreduce = out_grad * x;
     if (has_dynamic_shape(y.shape()) ||
-        has_dynamic_shape(y_grad_unreduce.shape())) {
+        has_dynamic_shape(y_grad_unreduce.shape()) ||
+        y_grad_unreduce.dims() != y.dims()) {
       auto y_grad_reduced = reduce_as<T>(y_grad_unreduce, y);
       set_output<T>(y_grad_reduced, y_grad);
     } else {
-      if (y_grad_unreduce.dims() != y.dims()) {
-        auto axes = get_reduce_dims_from_out(y_grad_unreduce.dims(), y.dims());
-        auto y_grad_reduced = y_grad_unreduce.sum(
-            common::vectorize(axes), y_grad_unreduce.dtype(), false);
-        if (y_grad_reduced.dims().size() != y.dims().size()) {
-          y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
-        }
-        set_output<T>(y_grad_reduced, y_grad);
-      } else {
-        set_output<T>(y_grad_unreduce, y_grad);
-      }
+      set_output<T>(y_grad_unreduce, y_grad);
     }
   }
 }
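multiply_grad applies the product rule (dx = dout * y, dy = dout * x) and then un-broadcasts each unreduced gradient back to its operand's shape. A standalone sketch of the broadcast case, with x of shape [3] broadcast against y of shape [2][3]:

```cpp
#include <cstdio>

int main() {
  const double x[3] = {1, 2, 3};
  const double y[2][3] = {{4, 5, 6}, {7, 8, 9}};
  const double out_grad[2][3] = {{1, 1, 1}, {1, 1, 1}};
  double dx[3] = {0, 0, 0};  // x was broadcast, so its grad sums over axis 0
  double dy[2][3] = {{0, 0, 0}, {0, 0, 0}};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      dx[j] += out_grad[i][j] * y[i][j];  // the reduction reduce_as performs
      dy[i][j] = out_grad[i][j] * x[j];   // already y's shape, no reduction
    }
  }
  printf("dx = [%g, %g, %g]\n", dx[0], dx[1], dx[2]);  // [11, 13, 15]
  return 0;
}
```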
@@ -739,20 +675,12 @@ void elementwise_pow_grad(const Tensor& x,
     auto lnx = log<T>(x);
     auto x_pow_y = elementwise_pow<T>(x, y);
     auto dy_res = lnx * x_pow_y * out_grad;
-    if (has_dynamic_shape(out_grad.shape()) || has_dynamic_shape(y.shape())) {
+    if (has_dynamic_shape(out_grad.shape()) || has_dynamic_shape(y.shape()) ||
+        out_grad.dims() != y.dims()) {
       auto dy_reduce_res = reduce_as<T>(dy_res, y);
       set_output<T>(dy_reduce_res, dy);
     } else {
-      if (out_grad.dims() != y.dims()) {
-        phi::DDim reduce_dim =
-            get_reduce_dims_from_out(out_grad.dims(), y.dims());
-        auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
-      } else {
-        set_output<T>(dy_res, dy);
-      }
+      set_output<T>(dy_res, dy);
     }
   }  // indicate we will compute dy
   if (dx) {
@@ -768,11 +696,8 @@
     auto x_pow_z = elementwise_pow<T>(x, tmp_z);
     auto dx_res = y * x_pow_z * out_grad;
     if (out_grad.dims() != x.dims()) {
-      auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-      auto dx_reduce_res =
-          dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-      set_output<T>(dx_tmp, dx);
+      auto dx_reduce_res = reduce_as<T>(dx_res, x);
+      set_output<T>(dx_reduce_res, dx);
     } else {
       set_output<T>(dx_res, dx);
     }
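The pow gradients used here are d(x^y)/dy = ln(x) * x^y and d(x^y)/dx = y * x^(y-1); the elided tmp_z must equal y - 1 for dx_res = y * x_pow_z * out_grad to implement the latter. A scalar finite-difference check, standalone for illustration:

```cpp
#include <cassert>
#include <cmath>

int main() {
  const double x = 2.0, y = 3.0, dout = 1.0, eps = 1e-6;
  const double dx = y * std::pow(x, y - 1) * dout;        // 12
  const double dy = std::log(x) * std::pow(x, y) * dout;  // ln(2) * 8
  const double fd_dx =
      (std::pow(x + eps, y) - std::pow(x - eps, y)) / (2 * eps);
  const double fd_dy =
      (std::pow(x, y + eps) - std::pow(x, y - eps)) / (2 * eps);
  assert(std::fabs(dx - fd_dx) < 1e-4 && std::fabs(dy - fd_dy) < 1e-4);
  return 0;
}
```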
@@ -1029,20 +954,12 @@ void expand_grad(const Tensor& x,
                  const IntArray& shape,
                  Tensor* x_grad) {
   if (x_grad) {
-    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != x.dims()) {
       auto reduced = reduce_as<T>(out_grad, x);
       set_output<T>(reduced, x_grad);
     } else {
-      if (out_grad.dims() != x.dims()) {
-        auto axes = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-        auto reduced = out_grad.sum(common::vectorize(axes), x.dtype(), false);
-        if (reduced.dims().size() != x.dims().size()) {
-          reduced = reshape<T>(reduced, x.shape());
-        }
-        set_output<T>(reduced, x_grad);
-      } else {
-        by_pass<T>(out_grad, x_grad);
-      }
+      by_pass<T>(out_grad, x_grad);
     }
   }
 }
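expand's VJP sums out_grad over the axes the forward pass copied along, which is exactly what reduce_as<T>(out_grad, x) provides; by_pass remains only for the no-op expand. A standalone sketch with x of shape [3] expanded to [2][3]:

```cpp
#include <cstdio>

int main() {
  const double out_grad[2][3] = {{1, 2, 3}, {10, 20, 30}};
  double x_grad[3] = {0, 0, 0};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j) x_grad[j] += out_grad[i][j];  // sum the copies
  printf("x_grad = [%g, %g, %g]\n", x_grad[0], x_grad[1],
         x_grad[2]);  // [11, 22, 33]
  return 0;
}
```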
@@ -1241,21 +1158,13 @@ void matmul_grad(const Tensor& x,
       x_grad_trans = transpose<T>(x_grad_mm, reverse_perm);
     }
     if (has_dynamic_shape(x.shape()) ||
-        has_dynamic_shape(x_grad_trans.shape())) {
+        has_dynamic_shape(x_grad_trans.shape()) ||
+        x_grad_trans.dims() != x.dims()) {
       auto x_grad_out = reduce_as<T>(x_grad_trans, temp_x_unsqueeze);
       set_output<T>(x_grad_out, x_grad);
     } else {
-      if (x_grad_trans.dims() != x.dims()) {
-        phi::DDim x_reduce_dim = get_reduce_dims_from_out(
-            x_grad_trans.dims(), temp_x_unsqueeze.dims());
-        auto dx_reduce_res = sum<T>(
-            x_grad_trans, common::vectorize(x_reduce_dim), x.dtype(), false);
-        auto x_grad_out = reshape<T>(dx_reduce_res, x.shape());
-        set_output<T>(x_grad_out, x_grad);
-      } else {
-        auto x_grad_out = x_grad_trans;
-        set_output<T>(x_grad_out, x_grad);
-      }
+      auto x_grad_out = x_grad_trans;
+      set_output<T>(x_grad_out, x_grad);
     }
   }
 
@@ -1274,21 +1183,13 @@ void matmul_grad(const Tensor& x,
       y_grad_trans = transpose<T>(y_grad_mm, reverse_perm);
     }
     if (has_dynamic_shape(y.shape()) ||
-        has_dynamic_shape(y_grad_trans.shape())) {
+        has_dynamic_shape(y_grad_trans.shape()) ||
+        y_grad_trans.dims() != y.dims()) {
       auto y_grad_out = reduce_as<T>(y_grad_trans, temp_y_unsqueeze);
       set_output<T>(y_grad_out, y_grad);
     } else {
-      if (y_grad_trans.dims() != y.dims()) {
-        phi::DDim y_reduce_dim = get_reduce_dims_from_out(
-            y_grad_trans.dims(), temp_y_unsqueeze.dims());
-        auto dy_reduce_res = sum<T>(
-            y_grad_trans, common::vectorize(y_reduce_dim), y.dtype(), false);
-        auto y_grad_out = reshape<T>(dy_reduce_res, y.shape());
-        set_output<T>(y_grad_out, y_grad);
-      } else {
-        auto y_grad_out = y_grad_trans;
-        set_output<T>(y_grad_out, y_grad);
-      }
+      auto y_grad_out = y_grad_trans;
+      set_output<T>(y_grad_out, y_grad);
     }
   }
 }
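For out = A B the matmul VJPs are dA = dout * B^T and dB = A^T * dout; the reduce_as calls above only fold broadcast batch dimensions back onto each operand (via the unsqueezed temporaries). A standalone sketch of the dA formula on small fixed-size matrices:

```cpp
#include <cstdio>

int main() {
  // A: 2x3, B: 3x2, out_grad: 2x2 of ones; dA = out_grad * B^T is 2x3.
  const double B[3][2] = {{1, 2}, {3, 4}, {5, 6}};
  const double out_grad[2][2] = {{1, 1}, {1, 1}};
  double dA[2][3] = {{0, 0, 0}, {0, 0, 0}};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      for (int k = 0; k < 2; ++k) dA[i][j] += out_grad[i][k] * B[j][k];
  printf("dA[0] = [%g, %g, %g]\n", dA[0][0], dA[0][1], dA[0][2]);  // [3, 7, 11]
  return 0;
}
```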
@@ -1302,39 +1203,24 @@ void maximum_grad(const Tensor& x,
   if (x_grad) {
     auto x_tmp = cast<T>(greater_than<T>(x, y), out_grad.dtype());
     auto dx_res = out_grad * x_tmp;
-    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != x.dims()) {
       auto dx_reduce_res = reduce_as<T>(dx_res, x);
       set_output<T>(dx_reduce_res, x_grad);
     } else {
-      if (out_grad.dims() != x.dims()) {
-        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-        auto dx_reduce_res =
-            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, x_grad);
-      } else {
-        set_output<T>(dx_res, x_grad);
-      }
+      set_output<T>(dx_res, x_grad);
     }
   }
 
   if (y_grad) {
     auto y_tmp = cast<T>(less_equal<T>(x, y), out_grad.dtype());
     auto dy_res = out_grad * y_tmp;
-    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != y.dims()) {
       auto dy_reduce_res = reduce_as<T>(dy_res, y);
       set_output<T>(dy_reduce_res, y_grad);
     } else {
-      if (out_grad.dims() != y.dims()) {
-        phi::DDim reduce_dim =
-            get_reduce_dims_from_out(out_grad.dims(), y.dims());
-        auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, y_grad);
-      } else {
-        set_output<T>(dy_res, y_grad);
-      }
+      set_output<T>(dy_res, y_grad);
     }
   }
 }
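maximum_grad splits the incoming gradient with complementary masks: x receives out_grad where x > y, y receives it where x <= y, so the two shares always sum to out_grad and ties send everything to y. A scalar sketch of that subgradient choice:

```cpp
#include <cassert>

int main() {
  const double x = 2.0, y = 2.0, dout = 5.0;      // tie case
  const double dx = dout * (x > y ? 1.0 : 0.0);   // greater_than mask
  const double dy = dout * (x <= y ? 1.0 : 0.0);  // less_equal mask
  assert(dx == 0.0 && dy == 5.0 && dx + dy == dout);
  return 0;
}
```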
@@ -2306,39 +2192,24 @@ void minimum_grad(const Tensor& x,
   if (x_grad) {
     auto x_tmp = cast<T>(less_than<T>(x, y), out_grad.dtype());
     auto dx_res = out_grad * x_tmp;
-    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(x.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != x.dims()) {
       auto dx_reduce_res = reduce_as<T>(dx_res, x);
       set_output<T>(dx_reduce_res, x_grad);
     } else {
-      if (out_grad.dims() != x.dims()) {
-        auto reduce_dim = get_reduce_dims_from_out(out_grad.dims(), x.dims());
-        auto dx_reduce_res =
-            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, x_grad);
-      } else {
-        set_output<T>(dx_res, x_grad);
-      }
+      set_output<T>(dx_res, x_grad);
     }
   }
 
   if (y_grad) {
     auto y_tmp = cast<T>(greater_equal<T>(x, y), out_grad.dtype());
     auto dy_res = out_grad * y_tmp;
-    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape())) {
+    if (has_dynamic_shape(y.shape()) || has_dynamic_shape(out_grad.shape()) ||
+        out_grad.dims() != y.dims()) {
       auto dy_reduce_res = reduce_as<T>(dy_res, y);
       set_output<T>(dy_reduce_res, y_grad);
     } else {
-      if (out_grad.dims() != y.dims()) {
-        phi::DDim reduce_dim =
-            get_reduce_dims_from_out(out_grad.dims(), y.dims());
-        auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, y_grad);
-      } else {
-        set_output<T>(dy_res, y_grad);
-      }
+      set_output<T>(dy_res, y_grad);
     }
   }
 }
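minimum_grad mirrors maximum_grad with the masks flipped: x receives out_grad where x < y, y where x >= y, so ties again route the whole gradient to y. A matching scalar sketch:

```cpp
#include <cassert>

int main() {
  const double x = 2.0, y = 2.0, dout = 5.0;      // tie case
  const double dx = dout * (x < y ? 1.0 : 0.0);   // less_than mask
  const double dy = dout * (x >= y ? 1.0 : 0.0);  // greater_equal mask
  assert(dx == 0.0 && dy == 5.0);
  return 0;
}
```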