@@ -2053,9 +2053,8 @@ void prod_grad(const Tensor& x,
                bool reduce_all,
                Tensor* x_grad) {
   if (x_grad) {
-    std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
     int64_t axis_size = axis.size();
-    int64_t x_dim_size = x_dim.size();
+    int64_t x_dim_size = x.dims().size();
     reduce_all = false;
     if (reduce_all || axis_size == 0 || axis_size == x_dim_size) {
       reduce_all = true;
@@ -2064,90 +2063,180 @@ void prod_grad(const Tensor& x,
     }
     auto out_grad_tmp = Tensor();
     auto x_reshape = Tensor();
-    std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
-        cumprod_shape;
-    std::vector<int> transpose_dim, origin_position;
-    if (x_dim_size == 1) {
-      out_grad_tmp = out_grad.expand(IntArray(x_dim));
-    } else {
-      if (!keep_dim) {
-        auto axis_ = std::vector<int64_t>();
-        if (reduce_all) {
-          for (int64_t i = 0; i < x_dim_size; i++) {
-            axis_.push_back(i);
-          }
-        } else {
-          axis_ = axis.GetData();
-          for (int64_t i = 0; i < axis_size; i++) {
-            if (axis[i] < 0) {
-              axis_[i] = axis[i] + x_dim_size;
+    if (has_dynamic_shape(x.shape())) {
+      Tensor x_dim = shape<T>(x);
+      std::vector<int64_t> unchange_axis, change_axis;
+      std::vector<int> transpose_dim, origin_position;
+      std::vector<Tensor> transpose_shape, cumprod_shape;
+      if (x_dim_size == 1) {
+        out_grad_tmp = backend::expand_with_tensor<T>(out_grad, x_dim);
+      } else {
+        if (!keep_dim) {
+          auto axis_ = std::vector<int64_t>();
+          if (reduce_all) {
+            for (int64_t i = 0; i < x_dim_size; i++) {
+              axis_.push_back(i);
+            }
+          } else {
+            axis_ = axis.GetData();
+            for (int64_t i = 0; i < axis_size; i++) {
+              if (axis[i] < 0) {
+                axis_[i] = axis[i] + x_dim_size;
+              }
             }
           }
+          Tensor out_grad_shape =
+              get_unsqueeze_dims<T>(shape<T>(out_grad), axis_);
+          Tensor out_grad_ = backend::reshape<T>(out_grad, out_grad_shape);
+          out_grad_tmp = backend::expand_with_tensor<T>(out_grad_, x_dim);
+        } else {
+          out_grad_tmp = backend::expand_with_tensor<T>(out_grad, x_dim);
         }
-        auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
-        auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
-        out_grad_tmp = out_grad_.expand(IntArray(x_dim));
-      } else {
-        out_grad_tmp = out_grad.expand(IntArray(x_dim));
       }
-    }
-    auto axis_ = std::vector<int64_t>();
-    if (reduce_all) {
-      int64_t numel = 1;
-      for (int64_t i = 0; i < x_dim_size; i++) {
-        axis_.push_back(i);
-        numel *= x_dim[i];
+      if (reduce_all) {
+        Tensor numel = full<T>({1}, 1.0, x_dim.dtype());
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          numel = numel * get_slice<T>(x_dim, i);
+        }
+        cumprod_shape.push_back(numel);
+        x_reshape = backend::reshape<T>(x, concat<T>(cumprod_shape));
+        Tensor left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        Tensor right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        Tensor x_grad_tmp = left_cumprod * right_cumprod;
+        Tensor x_grad_tmp2 = backend::reshape<T>(x_grad_tmp, x_dim);
+        Tensor x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
+      } else {
+        auto axis_ = std::vector<int64_t>();
+        int64_t unchange_size = x_dim_size - axis_size;
+        int64_t unchange_index = 0;
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_.push_back(axis[i] + x_dim_size);
+          } else {
+            axis_.push_back(axis[i]);
+          }
+        }
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          auto it = find(axis_.begin(), axis_.end(), i);
+          if (it != axis_.end()) {
+            int64_t index = it - axis_.begin();
+            origin_position.push_back(static_cast<int>(unchange_size + index));
+          } else {
+            unchange_axis.push_back(i);
+            origin_position.push_back(static_cast<int>(unchange_index));
+            unchange_index += 1;
+          }
+        }
+        Tensor numel = full<T>({1}, 1.0, x_dim.dtype());
+        for (int64_t i = 0; i < unchange_size; i++) {
+          transpose_shape.push_back(get_slice<T>(x_dim, unchange_axis[i]));
+          cumprod_shape.push_back(get_slice<T>(x_dim, unchange_axis[i]));
+          transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
+        }
+        for (int64_t i = 0; i < axis_size; i++) {
+          transpose_shape.push_back(get_slice<T>(x_dim, axis_[i]));
+          transpose_dim.push_back(static_cast<int>(axis_[i]));
+          numel = numel * get_slice<T>(x_dim, axis_[i]);
+        }
+        cumprod_shape.push_back(numel);
+        Tensor x_transpose = transpose<T>(x, transpose_dim);
+        x_reshape = backend::reshape<T>(x_transpose, concat<T>(cumprod_shape));
+        Tensor left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        Tensor right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        Tensor x_grad_tmp = left_cumprod * right_cumprod;
+        Tensor x_grad_reshape =
+            backend::reshape<T>(x_grad_tmp, concat<T>(transpose_shape));
+        Tensor x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
+        Tensor x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
       }
-      cumprod_shape.push_back(numel);
-      x_reshape = reshape<T>(x, cumprod_shape);
-      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
-      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
-      auto x_grad_tmp = left_cumprod * right_cumprod;
-      auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
-      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
-      set_output<T>(x_grad_res, x_grad);
     } else {
-      int64_t unchange_size = x_dim_size - axis_size;
-      int64_t unchange_index = 0;
-      for (int64_t i = 0; i < axis_size; i++) {
-        if (axis[i] < 0) {
-          axis_.push_back(axis[i] + x_dim_size);
+      std::vector<int64_t> x_dim = common::vectorize<int64_t>(x.dims());
+      std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
+          cumprod_shape;
+      std::vector<int> transpose_dim, origin_position;
+      if (x_dim_size == 1) {
+        out_grad_tmp = out_grad.expand(IntArray(x_dim));
+      } else {
+        if (!keep_dim) {
+          auto axis_ = std::vector<int64_t>();
+          if (reduce_all) {
+            for (int64_t i = 0; i < x_dim_size; i++) {
+              axis_.push_back(i);
+            }
+          } else {
+            axis_ = axis.GetData();
+            for (int64_t i = 0; i < axis_size; i++) {
+              if (axis[i] < 0) {
+                axis_[i] = axis[i] + x_dim_size;
+              }
+            }
+          }
+          auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
+          auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
+          out_grad_tmp = out_grad_.expand(IntArray(x_dim));
21152179 } else {
2116- axis_. push_back (axis[i] );
2180+ out_grad_tmp = out_grad. expand ( IntArray (x_dim) );
         }
       }
-      for (int64_t i = 0; i < x_dim_size; i++) {
-        auto it = find(axis_.begin(), axis_.end(), i);
-        if (it != axis_.end()) {
-          int64_t index = it - axis_.begin();
-          origin_position.push_back(static_cast<int>(unchange_size + index));
-        } else {
-          unchange_axis.push_back(i);
-          origin_position.push_back(static_cast<int>(unchange_index));
-          unchange_index += 1;
+      if (reduce_all) {
+        int64_t numel = 1;
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          numel *= x_dim[i];
         }
+        cumprod_shape.push_back(numel);
+        x_reshape = reshape<T>(x, cumprod_shape);
+        auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        auto x_grad_tmp = left_cumprod * right_cumprod;
+        auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
+        auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
+      } else {
+        auto axis_ = std::vector<int64_t>();
+        int64_t unchange_size = x_dim_size - axis_size;
+        int64_t unchange_index = 0;
+        for (int64_t i = 0; i < axis_size; i++) {
+          if (axis[i] < 0) {
+            axis_.push_back(axis[i] + x_dim_size);
+          } else {
+            axis_.push_back(axis[i]);
+          }
+        }
+        for (int64_t i = 0; i < x_dim_size; i++) {
+          auto it = find(axis_.begin(), axis_.end(), i);
+          if (it != axis_.end()) {
+            int64_t index = it - axis_.begin();
+            origin_position.push_back(static_cast<int>(unchange_size + index));
+          } else {
+            unchange_axis.push_back(i);
+            origin_position.push_back(static_cast<int>(unchange_index));
+            unchange_index += 1;
+          }
+        }
+        int64_t numel = 1;
+        for (int64_t i = 0; i < unchange_size; i++) {
+          transpose_shape.push_back(x_dim[unchange_axis[i]]);
+          cumprod_shape.push_back(x_dim[unchange_axis[i]]);
+          transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
+        }
+        for (int64_t i = 0; i < axis_size; i++) {
+          transpose_shape.push_back(x_dim[axis_[i]]);
+          transpose_dim.push_back(static_cast<int>(axis_[i]));
+          numel *= x_dim[axis_[i]];
+        }
+        cumprod_shape.push_back(numel);
+        auto x_transpose = transpose<T>(x, transpose_dim);
+        x_reshape = reshape<T>(x_transpose, cumprod_shape);
+        auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+        auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+        auto x_grad_tmp = left_cumprod * right_cumprod;
+        auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
+        auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
+        auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+        set_output<T>(x_grad_res, x_grad);
       }
-      int64_t numel = 1;
-      for (int64_t i = 0; i < unchange_size; i++) {
-        transpose_shape.push_back(x_dim[unchange_axis[i]]);
-        cumprod_shape.push_back(x_dim[unchange_axis[i]]);
-        transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
-      }
-      for (int64_t i = 0; i < axis_size; i++) {
-        transpose_shape.push_back(x_dim[axis_[i]]);
-        transpose_dim.push_back(static_cast<int>(axis_[i]));
-        numel *= x_dim[axis_[i]];
-      }
-      cumprod_shape.push_back(numel);
-      auto x_transpose = transpose<T>(x, transpose_dim);
-      x_reshape = reshape<T>(x_transpose, cumprod_shape);
-      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
-      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
-      auto x_grad_tmp = left_cumprod * right_cumprod;
-      auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
-      auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
-      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
-      set_output<T>(x_grad_res, x_grad);
     }
   }
 }
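
Both the new dynamic-shape branch and the existing static-shape branch rely on the same identity: for y = prod(x), dy/dx[i] = prod of x[j] over j != i, which the kernel obtains by flattening the reduced axes and multiplying an exclusive left cumprod by an exclusive reversed cumprod, then scaling by the expanded out_grad. A minimal standalone sketch of that identity follows; it is plain C++ for illustration only, not the Paddle API, and the names left/right/out_grad are hypothetical.

// Sketch of the cumprod-based product gradient used above,
// assuming a single flattened reduction axis.
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {2.0, 3.0, 4.0};
  const int n = static_cast<int>(x.size());

  // left[i]  = x[0] * ... * x[i-1]   (exclusive cumprod)
  // right[i] = x[i+1] * ... * x[n-1] (exclusive, reversed cumprod)
  std::vector<double> left(n, 1.0), right(n, 1.0);
  for (int i = 1; i < n; ++i) left[i] = left[i - 1] * x[i - 1];
  for (int i = n - 2; i >= 0; --i) right[i] = right[i + 1] * x[i + 1];

  // d prod(x) / d x[i] = left[i] * right[i]; multiplying by the incoming
  // out_grad mirrors x_grad_tmp * out_grad_tmp in the kernel.
  const double out_grad = 1.0;
  for (int i = 0; i < n; ++i) {
    std::printf("grad[%d] = %g\n", i, out_grad * left[i] * right[i]);
  }
  return 0;  // prints 12, 8, 6 for x = {2, 3, 4}
}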