Merged
239 changes: 164 additions & 75 deletions paddle/fluid/primitive/rule/vjp/details.h
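The rewrite below adds a dynamic-shape branch to prod_grad: when has_dynamic_shape(x.shape()) is true, the reduced-dim sizes and reshape targets are built at runtime from shape<T>(x) (via get_slice<T>, concat<T>, backend::reshape<T>, and backend::expand_with_tensor<T>) instead of from the compile-time vectorized dims; the original static-shape path is kept as the else branch. A minimal Paddle sketch of composing a reshape target from a runtime shape tensor (names and shapes here are illustrative, not the PR's code):

import paddle

# Build the reshape target from the runtime shape rather than Python ints,
# so the same program works when some dims are only known at run time.
x = paddle.rand([3, 4, 5])
x_dim = paddle.shape(x)                         # 1-D int tensor holding the runtime shape
numel = (x_dim[1] * x_dim[2]).reshape([1])      # product of the dims being reduced
new_shape = paddle.concat([x_dim[0:1], numel])  # e.g. [3, 20], kept as a tensor
x_reshaped = paddle.reshape(x, new_shape)       # paddle.reshape accepts a shape tensor
print(x_reshaped.shape)                         # [3, 20]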
@@ -2039,9 +2039,8 @@ void prod_grad(const Tensor& x,
bool reduce_all,
Tensor* x_grad) {
if (x_grad) {
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
int64_t x_dim_size = x.dims().size();
reduce_all = false;
if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
reduce_all = true;
@@ -2050,90 +2049,180 @@
}
auto out_grad_tmp = Tensor();
auto x_reshape = Tensor();
std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
cumprod_shape;
std::vector<int> transpose_dim, origin_position;
if (x_dim_size == 1) {
out_grad_tmp = out_grad.expand(IntArray(x_dim));
} else {
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
if (has_dynamic_shape(x.shape())) {
Tensor x_dim = shape<T>(x);
std::vector<int64_t> unchange_axis, change_axis;
std::vector<int> transpose_dim, origin_position;
std::vector<Tensor> transpose_shape, cumprod_shape;
if (x_dim_size == 1) {
out_grad_tmp = backend::expand_with_tensor<T>(out_grad, x_dim);
} else {
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
Tensor out_grad_shape =
get_unsqueeze_dims<T>(shape<T>(out_grad), axis_);
Tensor out_grad_ = backend::reshape<T>(out_grad, out_grad_shape);
out_grad_tmp = backend::expand_with_tensor<T>(out_grad_, x_dim);
} else {
out_grad_tmp = backend::expand_with_tensor<T>(out_grad, x_dim);
}
auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
out_grad_tmp = out_grad_.expand(IntArray(x_dim));
} else {
out_grad_tmp = out_grad.expand(IntArray(x_dim));
}
}
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
int64_t numel = 1;
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
numel *= x_dim[i];
if (reduce_all) {
Tensor numel = full<T>({1}, 1.0, x_dim.dtype());
for (int64_t i = 0; i < x_dim_size; i++) {
numel = numel * get_slice<T>(x_dim, i);
}
cumprod_shape.push_back(numel);
x_reshape = backend::reshape<T>(x, concat<T>(cumprod_shape));
Tensor left_cumprod = cumprod<T>(x_reshape, -1, true, false);
Tensor right_cumprod = cumprod<T>(x_reshape, -1, true, true);
Tensor x_grad_tmp = left_cumprod * right_cumprod;
Tensor x_grad_tmp2 = backend::reshape<T>(x_grad_tmp, x_dim);
Tensor x_grad_res = x_grad_tmp2 * out_grad_tmp;
set_output<T>(x_grad_res, x_grad);
} else {
auto axis_ = std::vector<int64_t>();
int64_t unchange_size = x_dim_size - axis_size;
int64_t unchange_index = 0;
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_.push_back(axis[i] + x_dim_size);
} else {
axis_.push_back(axis[i]);
}
}
for (int64_t i = 0; i < x_dim_size; i++) {
auto it = find(axis_.begin(), axis_.end(), i);
if (it != axis_.end()) {
int64_t index = it - axis_.begin();
origin_position.push_back(static_cast<int>(unchange_size + index));
} else {
unchange_axis.push_back(i);
origin_position.push_back(static_cast<int>(unchange_index));
unchange_index += 1;
}
}
Tensor numel = full<T>({1}, 1.0, x_dim.dtype());
for (int64_t i = 0; i < unchange_size; i++) {
transpose_shape.push_back(get_slice<T>(x_dim, unchange_axis[i]));
cumprod_shape.push_back(get_slice<T>(x_dim, unchange_axis[i]));
transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
}
for (int64_t i = 0; i < axis_size; i++) {
transpose_shape.push_back(get_slice<T>(x_dim, axis_[i]));
transpose_dim.push_back(static_cast<int>(axis_[i]));
numel = numel * get_slice<T>(x_dim, axis_[i]);
}
cumprod_shape.push_back(numel);
Tensor x_transpose = transpose<T>(x, transpose_dim);
x_reshape = backend::reshape<T>(x_transpose, concat<T>(cumprod_shape));
Tensor left_cumprod = cumprod<T>(x_reshape, -1, true, false);
Tensor right_cumprod = cumprod<T>(x_reshape, -1, true, true);
Tensor x_grad_tmp = left_cumprod * right_cumprod;
Tensor x_grad_reshape =
backend::reshape<T>(x_grad_tmp, concat<T>(transpose_shape));
Tensor x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
Tensor x_grad_res = x_grad_tmp2 * out_grad_tmp;
set_output<T>(x_grad_res, x_grad);
}
cumprod_shape.push_back(numel);
x_reshape = reshape<T>(x, cumprod_shape);
auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
auto x_grad_tmp = left_cumprod * right_cumprod;
auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
set_output<T>(x_grad_res, x_grad);
} else {
int64_t unchange_size = x_dim_size - axis_size;
int64_t unchange_index = 0;
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_.push_back(axis[i] + x_dim_size);
std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
cumprod_shape;
std::vector<int> transpose_dim, origin_position;
if (x_dim_size == 1) {
out_grad_tmp = out_grad.expand(IntArray(x_dim));
} else {
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}
auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
out_grad_tmp = out_grad_.expand(IntArray(x_dim));
} else {
axis_.push_back(axis[i]);
out_grad_tmp = out_grad.expand(IntArray(x_dim));
}
}
for (int64_t i = 0; i < x_dim_size; i++) {
auto it = find(axis_.begin(), axis_.end(), i);
if (it != axis_.end()) {
int64_t index = it - axis_.begin();
origin_position.push_back(static_cast<int>(unchange_size + index));
} else {
unchange_axis.push_back(i);
origin_position.push_back(static_cast<int>(unchange_index));
unchange_index += 1;
if (reduce_all) {
int64_t numel = 1;
for (int64_t i = 0; i < x_dim_size; i++) {
numel *= x_dim[i];
}
cumprod_shape.push_back(numel);
x_reshape = reshape<T>(x, cumprod_shape);
auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
auto x_grad_tmp = left_cumprod * right_cumprod;
auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
set_output<T>(x_grad_res, x_grad);
} else {
auto axis_ = std::vector<int64_t>();
int64_t unchange_size = x_dim_size - axis_size;
int64_t unchange_index = 0;
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_.push_back(axis[i] + x_dim_size);
} else {
axis_.push_back(axis[i]);
}
}
for (int64_t i = 0; i < x_dim_size; i++) {
auto it = find(axis_.begin(), axis_.end(), i);
if (it != axis_.end()) {
int64_t index = it - axis_.begin();
origin_position.push_back(static_cast<int>(unchange_size + index));
} else {
unchange_axis.push_back(i);
origin_position.push_back(static_cast<int>(unchange_index));
unchange_index += 1;
}
}
int64_t numel = 1;
for (int64_t i = 0; i < unchange_size; i++) {
transpose_shape.push_back(x_dim[unchange_axis[i]]);
cumprod_shape.push_back(x_dim[unchange_axis[i]]);
transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
}
for (int64_t i = 0; i < axis_size; i++) {
transpose_shape.push_back(x_dim[axis_[i]]);
transpose_dim.push_back(static_cast<int>(axis_[i]));
numel *= x_dim[axis_[i]];
}
cumprod_shape.push_back(numel);
auto x_transpose = transpose<T>(x, transpose_dim);
x_reshape = reshape<T>(x_transpose, cumprod_shape);
auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
auto x_grad_tmp = left_cumprod * right_cumprod;
auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
set_output<T>(x_grad_res, x_grad);
}
int64_t numel = 1;
for (int64_t i = 0; i < unchange_size; i++) {
transpose_shape.push_back(x_dim[unchange_axis[i]]);
cumprod_shape.push_back(x_dim[unchange_axis[i]]);
transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
}
for (int64_t i = 0; i < axis_size; i++) {
transpose_shape.push_back(x_dim[axis_[i]]);
transpose_dim.push_back(static_cast<int>(axis_[i]));
numel *= x_dim[axis_[i]];
}
cumprod_shape.push_back(numel);
auto x_transpose = transpose<T>(x, transpose_dim);
x_reshape = reshape<T>(x_transpose, cumprod_shape);
auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
auto x_grad_tmp = left_cumprod * right_cumprod;
auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
set_output<T>(x_grad_res, x_grad);
}
}
}
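Both the dynamic- and static-shape paths compute the product gradient without dividing by x, using an exclusive forward cumprod and an exclusive reverse cumprod along the flattened reduced axis: the cumprod<T>(x_reshape, -1, true, false) and cumprod<T>(x_reshape, -1, true, true) calls, whose trailing arguments appear to correspond to (exclusive, reverse). A NumPy sketch of that identity for a 1-D input (prod_grad_1d is an illustrative name, not part of the PR):

import numpy as np

def prod_grad_1d(x, out_grad):
    # d prod(x) / dx_i = (prod of x[:i]) * (prod of x[i+1:]) * out_grad,
    # so zeros in x are handled without any division.
    left = np.cumprod(np.concatenate(([1.0], x[:-1])))            # exclusive cumprod
    right = np.cumprod(np.concatenate(([1.0], x[:0:-1])))[::-1]   # exclusive reverse cumprod
    return left * right * out_grad

x = np.array([2.0, 3.0, 0.0, 5.0])
print(prod_grad_1d(x, out_grad=1.0))   # [ 0.  0. 30.  0.]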
1 change: 1 addition & 0 deletions python/paddle/autograd/backward_utils.py
@@ -54,6 +54,7 @@
"pd_op.multiply",
"pd_op.pad",
"pd_op.pow",
"pd_op.prod",
"pd_op.reduce_as",
"pd_op.relu",
"pd_op.reshape",
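The test diff below adds prod_net1 through prod_net4 and five TestPrimProdWithGrad cases that run paddle.prod backward with some input dims marked dynamic (the None entries in init_x_shape). A quick eager-mode sketch of the call pattern those nets wrap (shapes are illustrative and this snippet is not part of the PR):

import paddle

x = paddle.rand([4, 5])
x.stop_gradient = False
y = paddle.prod(x, 0)      # same reduction as prod_net2
y.sum().backward()
print(x.grad.shape)        # [4, 5]: the gradient keeps the input's shape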
@@ -35,6 +35,22 @@ def pow_net(x):
return paddle.pow(x, 3.2)


def prod_net1(x):
return paddle.prod(x)


def prod_net2(x):
return paddle.prod(x, 0)


def prod_net3(x):
return paddle.prod(x, keepdim=False)


def prod_net4(x):
return paddle.prod(x, 0, keepdim=False)


def scale_net(x):
return paddle.scale(x, scale=-2.3)

@@ -161,6 +177,66 @@ def setUp(self):
self.tol = 1e-6


class TestPrimProdWithGrad1(TestPrimBaseWithGrad):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [100]
self.init_x_shape = [None]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = prod_net1
self.enable_cinn = False
self.tol = 1e-6


class TestPrimProdWithGrad2(TestPrimBaseWithGrad):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [100, 20, 30]
self.init_x_shape = [None, None, 30]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = prod_net1
self.enable_cinn = False
self.tol = 1e-6


class TestPrimProdWithGrad3(TestPrimBaseWithGrad):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [100, 20, 30]
self.init_x_shape = [None, None, 30]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = prod_net2
self.enable_cinn = False
self.tol = 1e-6


class TestPrimProdWithGrad4(TestPrimBaseWithGrad):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [100, 20, 30]
self.init_x_shape = [None, None, 30]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = prod_net3
self.enable_cinn = False
self.tol = 1e-6


class TestPrimProdWithGrad5(TestPrimBaseWithGrad):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [100, 20, 30]
self.init_x_shape = [None, None, 30]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = prod_net4
self.enable_cinn = False
self.tol = 1e-6


class TestPrimScaleWithGrad(TestPrimBaseWithGrad):
def setUp(self):
np.random.seed(2023)