Commit e7341fd

support the dynamic shape for prod_grad (#67775)
1 parent 8b9bf87 commit e7341fd
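
This change teaches the decomposed (prim) backward rule for prod to handle inputs with dynamic shapes. In paddle/fluid/primitive/rule/vjp/details.h, prod_grad gains a has_dynamic_shape(x.shape()) branch that materializes every shape it needs as a tensor (shape<T>(x), get_slice<T>, concat<T>) and routes reshapes and expands through backend::reshape<T> / backend::expand_with_tensor<T>; the original static-shape code is kept as the else branch. "pd_op.prod" is also registered in python/paddle/autograd/backward_utils.py, and dynamic-shape gradient tests for paddle.prod are added.

The gradient formula itself is unchanged: for y = prod(x), dy/dx_i is the product of all the other elements, which the rule builds as (exclusive forward cumprod) * (exclusive reverse cumprod) — that is how the two cumprod<T>(x_reshape, -1, true, false/true) calls read — so there is no division by x_i and zero entries are handled. A rough NumPy sketch of that identity for the reduce-all case (illustration only, not part of the commit):

import numpy as np

def prod_grad_reference(x, out_grad=1.0):
    # dy/dx_i = prod_{j<i} x_j * prod_{j>i} x_j over the flattened input,
    # built from an exclusive forward cumprod and an exclusive reverse cumprod.
    flat = x.reshape(-1)
    left = np.concatenate(([1.0], np.cumprod(flat[:-1])))
    right = np.concatenate((np.cumprod(flat[::-1][:-1])[::-1], [1.0]))
    return (left * right).reshape(x.shape) * out_grad

x = np.array([[1.0, 2.0], [3.0, 0.0]])
print(prod_grad_reference(x))  # [[0. 0.] [0. 6.]] -- product of the other entries, even with a zero present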

File tree

3 files changed (+241 lines, -75 lines)


paddle/fluid/primitive/rule/vjp/details.h

Lines changed: 164 additions & 75 deletions
@@ -2053,9 +2053,8 @@ void prod_grad(const Tensor& x,
                bool reduce_all,
                Tensor* x_grad) {
   if (x_grad) {
-    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
     int64_t axis_size = axis.size();
-    int64_t x_dim_size = x_dim.size();
+    int64_t x_dim_size = x.dims().size();
     reduce_all = false;
     if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
       reduce_all = true;
@@ -2064,90 +2063,180 @@ void prod_grad(const Tensor& x,
     }
     auto out_grad_tmp = Tensor();
     auto x_reshape = Tensor();
-    std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
-        cumprod_shape;
-    std::vector<int> transpose_dim, origin_position;
-    if (x_dim_size == 1) {
-      out_grad_tmp = out_grad.expand(IntArray(x_dim));
-    } else {
-      if (!keep_dim) {
-        auto axis_ = std::vector<int64_t>();
-        if (reduce_all) {
-          for (int64_t i = 0; i < x_dim_size; i++) {
-            axis_.push_back(i);
-          }
-        } else {
-          axis_ = axis.GetData();
-          for (int64_t i = 0; i < axis_size; i++) {
-            if (axis[i] < 0) {
-              axis_[i] = axis[i] + x_dim_size;
+    if (has_dynamic_shape(x.shape())) {
+      Tensor x_dim = shape<T>(x);
+      std::vector<int64_t> unchange_axis, change_axis;
+      std::vector<int> transpose_dim, origin_position;
+      std::vector<Tensor> transpose_shape, cumprod_shape;
+      if (x_dim_size == 1) {
+        out_grad_tmp = backend::expand_with_tensor<T>(out_grad, x_dim);
+      } else {
+        if (!keep_dim) {
+          auto axis_ = std::vector<int64_t>();
+          if (reduce_all) {
+            for (int64_t i = 0; i < x_dim_size; i++) {
+              axis_.push_back(i);
+            }
+          } else {
+            axis_ = axis.GetData();
+            for (int64_t i = 0; i < axis_size; i++) {
+              if (axis[i] < 0) {
+                axis_[i] = axis[i] + x_dim_size;
+              }
             }
           }
+          Tensor out_grad_shape =
+              get_unsqueeze_dims<T>(shape<T>(out_grad), axis_);
+          Tensor out_grad_ = backend::reshape<T>(out_grad, out_grad_shape);
+          out_grad_tmp = backend::expand_with_tensor<T>(out_grad_, x_dim);
+        } else {
+          out_grad_tmp = backend::expand_with_tensor<T>(out_grad, x_dim);
         }
-        auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
-        auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
-        out_grad_tmp = out_grad_.expand(IntArray(x_dim));
-      } else {
-        out_grad_tmp = out_grad.expand(IntArray(x_dim));
       }
-    }
-    auto axis_ = std::vector<int64_t>();
-    if (reduce_all) {
-      int64_t numel = 1;
-      for (int64_t i = 0; i < x_dim_size; i++) {
-        axis_.push_back(i);
-        numel *= x_dim[i];
+      if (reduce_all) {
+        Tensor numel = full<T>({1}, 1.0, x_dim.dtype());
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          numel = numel * get_slice<T>(x_dim, i);
+        }
+        cumprod_shape.push_back(numel);
+        x_reshape = backend::reshape<T>(x, concat<T>(cumprod_shape));
+        Tensor left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        Tensor right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        Tensor x_grad_tmp = left_cumprod * right_cumprod;
+        Tensor x_grad_tmp2 = backend::reshape<T>(x_grad_tmp, x_dim);
+        Tensor x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
+      } else {
+        auto axis_ = std::vector<int64_t>();
+        int64_t unchange_size = x_dim_size - axis_size;
+        int64_t unchange_index = 0;
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_.push_back(axis[i] + x_dim_size);
+          } else {
+            axis_.push_back(axis[i]);
+          }
+        }
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          auto it = find(axis_.begin(), axis_.end(), i);
+          if (it != axis_.end()) {
+            int64_t index = it - axis_.begin();
+            origin_position.push_back(static_cast<int>(unchange_size + index));
+          } else {
+            unchange_axis.push_back(i);
+            origin_position.push_back(static_cast<int>(unchange_index));
+            unchange_index += 1;
+          }
+        }
+        Tensor numel = full<T>({1}, 1.0, x_dim.dtype());
+        for (int64_t i = 0; i < unchange_size; i++) {
+          transpose_shape.push_back(get_slice<T>(x_dim, unchange_axis[i]));
+          cumprod_shape.push_back(get_slice<T>(x_dim, unchange_axis[i]));
+          transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
+        }
+        for (int64_t i = 0; i < axis_size; i++) {
+          transpose_shape.push_back(get_slice<T>(x_dim, axis_[i]));
+          transpose_dim.push_back(static_cast<int>(axis_[i]));
+          numel = numel * get_slice<T>(x_dim, axis_[i]);
+        }
+        cumprod_shape.push_back(numel);
+        Tensor x_transpose = transpose<T>(x, transpose_dim);
+        x_reshape = backend::reshape<T>(x_transpose, concat<T>(cumprod_shape));
+        Tensor left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        Tensor right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        Tensor x_grad_tmp = left_cumprod * right_cumprod;
+        Tensor x_grad_reshape =
+            backend::reshape<T>(x_grad_tmp, concat<T>(transpose_shape));
+        Tensor x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
+        Tensor x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
       }
-      cumprod_shape.push_back(numel);
-      x_reshape = reshape<T>(x, cumprod_shape);
-      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
-      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
-      auto x_grad_tmp = left_cumprod * right_cumprod;
-      auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
-      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
-      set_output<T>(x_grad_res, x_grad);
     } else {
-      int64_t unchange_size = x_dim_size - axis_size;
-      int64_t unchange_index = 0;
-      for (int64_t i = 0; i < axis_size; i++) {
-        if (axis[i] < 0) {
-          axis_.push_back(axis[i] + x_dim_size);
+      std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
+      std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
+          cumprod_shape;
+      std::vector<int> transpose_dim, origin_position;
+      if (x_dim_size == 1) {
+        out_grad_tmp = out_grad.expand(IntArray(x_dim));
+      } else {
+        if (!keep_dim) {
+          auto axis_ = std::vector<int64_t>();
+          if (reduce_all) {
+            for (int64_t i = 0; i < x_dim_size; i++) {
+              axis_.push_back(i);
+            }
+          } else {
+            axis_ = axis.GetData();
+            for (int64_t i = 0; i < axis_size; i++) {
+              if (axis[i] < 0) {
+                axis_[i] = axis[i] + x_dim_size;
+              }
+            }
+          }
+          auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
+          auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
+          out_grad_tmp = out_grad_.expand(IntArray(x_dim));
         } else {
-          axis_.push_back(axis[i]);
+          out_grad_tmp = out_grad.expand(IntArray(x_dim));
         }
       }
-      for (int64_t i = 0; i < x_dim_size; i++) {
-        auto it = find(axis_.begin(), axis_.end(), i);
-        if (it != axis_.end()) {
-          int64_t index = it - axis_.begin();
-          origin_position.push_back(static_cast<int>(unchange_size + index));
-        } else {
-          unchange_axis.push_back(i);
-          origin_position.push_back(static_cast<int>(unchange_index));
-          unchange_index += 1;
+      if (reduce_all) {
+        int64_t numel = 1;
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          numel *= x_dim[i];
         }
+        cumprod_shape.push_back(numel);
+        x_reshape = reshape<T>(x, cumprod_shape);
+        auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        auto x_grad_tmp = left_cumprod * right_cumprod;
+        auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
+        auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
+      } else {
+        auto axis_ = std::vector<int64_t>();
+        int64_t unchange_size = x_dim_size - axis_size;
+        int64_t unchange_index = 0;
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_.push_back(axis[i] + x_dim_size);
+          } else {
+            axis_.push_back(axis[i]);
+          }
+        }
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          auto it = find(axis_.begin(), axis_.end(), i);
+          if (it != axis_.end()) {
+            int64_t index = it - axis_.begin();
+            origin_position.push_back(static_cast<int>(unchange_size + index));
+          } else {
+            unchange_axis.push_back(i);
+            origin_position.push_back(static_cast<int>(unchange_index));
+            unchange_index += 1;
+          }
+        }
+        int64_t numel = 1;
+        for (int64_t i = 0; i < unchange_size; i++) {
+          transpose_shape.push_back(x_dim[unchange_axis[i]]);
+          cumprod_shape.push_back(x_dim[unchange_axis[i]]);
+          transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
+        }
+        for (int64_t i = 0; i < axis_size; i++) {
+          transpose_shape.push_back(x_dim[axis_[i]]);
+          transpose_dim.push_back(static_cast<int>(axis_[i]));
+          numel *= x_dim[axis_[i]];
+        }
+        cumprod_shape.push_back(numel);
+        auto x_transpose = transpose<T>(x, transpose_dim);
+        x_reshape = reshape<T>(x_transpose, cumprod_shape);
+        auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        auto x_grad_tmp = left_cumprod * right_cumprod;
+        auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
+        auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
+        auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
       }
-      int64_t numel = 1;
-      for (int64_t i = 0; i < unchange_size; i++) {
-        transpose_shape.push_back(x_dim[unchange_axis[i]]);
-        cumprod_shape.push_back(x_dim[unchange_axis[i]]);
-        transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
-      }
-      for (int64_t i = 0; i < axis_size; i++) {
-        transpose_shape.push_back(x_dim[axis_[i]]);
-        transpose_dim.push_back(static_cast<int>(axis_[i]));
-        numel *= x_dim[axis_[i]];
-      }
-      cumprod_shape.push_back(numel);
-      auto x_transpose = transpose<T>(x, transpose_dim);
-      x_reshape = reshape<T>(x_transpose, cumprod_shape);
-      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
-      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
-      auto x_grad_tmp = left_cumprod * right_cumprod;
-      auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
-      auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
-      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
-      set_output<T>(x_grad_res, x_grad);
     }
   }
 }
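
When only a subset of axes is reduced, both the new dynamic branch and the retained static branch use the same recipe: move the reduced axes to the end (transpose_dim), flatten them into a single trailing dimension (cumprod_shape), take the two exclusive cumprods along that dimension, reshape back (transpose_shape), and invert the transpose (origin_position) before multiplying by out_grad expanded to x's shape. A NumPy sketch of that recipe, for illustration only (the out_grad factor is left out):

import numpy as np

def prod_grad_over_axes(x, axes):
    # Mirror of the non-reduce_all path: reduced axes go to the end, get
    # flattened into one dimension, two exclusive cumprods run along it,
    # and the reshape/transpose is then undone.
    axes = [a % x.ndim for a in axes]
    keep = [d for d in range(x.ndim) if d not in axes]
    perm = keep + axes
    xt = np.transpose(x, perm)
    flat = xt.reshape(*[x.shape[d] for d in keep], -1)
    ones = np.ones_like(flat[..., :1])
    left = np.concatenate([ones, np.cumprod(flat[..., :-1], axis=-1)], axis=-1)
    right = np.concatenate(
        [np.cumprod(flat[..., ::-1][..., :-1], axis=-1)[..., ::-1], ones], axis=-1
    )
    return np.transpose((left * right).reshape(xt.shape), np.argsort(perm))

x = np.random.rand(2, 3, 4) + 0.5
assert np.allclose(prod_grad_over_axes(x, [1, 2]),
                   np.prod(x, axis=(1, 2), keepdims=True) / x)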

python/paddle/autograd/backward_utils.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@
     "pd_op.multiply",
     "pd_op.pad",
     "pd_op.pow",
+    "pd_op.prod",
     "pd_op.reduce_as",
     "pd_op.relu",
     "pd_op.reshape",

test/prim/pir_prim/test_prim_sub_graph_pqrst_backward_dynamic_shape.py

Lines changed: 76 additions & 0 deletions
@@ -35,6 +35,22 @@ def pow_net(x):
     return paddle.pow(x, 3.2)
 
 
+def prod_net1(x):
+    return paddle.prod(x)
+
+
+def prod_net2(x):
+    return paddle.prod(x, 0)
+
+
+def prod_net3(x):
+    return paddle.prod(x, keepdim=False)
+
+
+def prod_net4(x):
+    return paddle.prod(x, 0, keepdim=False)
+
+
 def scale_net(x):
     return paddle.scale(x, scale=-2.3)
 
@@ -161,6 +177,66 @@ def setUp(self):
         self.tol = 1e-6
 
 
+class TestPrimProdWithGrad1(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [100]
+        self.init_x_shape = [None]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = prod_net1
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimProdWithGrad2(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [100, 20, 30]
+        self.init_x_shape = [None, None, 30]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = prod_net1
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimProdWithGrad3(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [100, 20, 30]
+        self.init_x_shape = [None, None, 30]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = prod_net2
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimProdWithGrad4(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [100, 20, 30]
+        self.init_x_shape = [None, None, 30]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = prod_net3
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
+class TestPrimProdWithGrad5(TestPrimBaseWithGrad):
+    def setUp(self):
+        np.random.seed(2023)
+        self.dtype = "float32"
+        self.x_shape = [100, 20, 30]
+        self.init_x_shape = [None, None, 30]
+        self.x = np.random.random(self.x_shape).astype(self.dtype)
+        self.net = prod_net4
+        self.enable_cinn = False
+        self.tol = 1e-6
+
+
 class TestPrimScaleWithGrad(TestPrimBaseWithGrad):
     def setUp(self):
         np.random.seed(2023)
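
The new test classes plug prod_net1 through prod_net4 into TestPrimBaseWithGrad, with None entries in init_x_shape marking the dimensions that stay symbolic. Independently of that harness, a quick eager-mode sanity check of the gradient the decomposition has to reproduce (prod over axis 0, as in prod_net2) might look like this — not part of the commit:

import numpy as np
import paddle

np.random.seed(2023)
x_np = np.random.rand(100, 20, 30).astype("float32") + 0.5  # keep entries away from zero
x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.prod(x, 0)                              # the same op prod_net2 uses
(dx,) = paddle.grad(y.sum(), x)
ref = np.prod(x_np, axis=0, keepdims=True) / x_np  # d prod(x) / dx = prod(x) / x
np.testing.assert_allclose(dx.numpy(), ref, rtol=1e-3)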
