@@ -24,103 +24,6 @@ limitations under the License. */
2424#include " paddle/phi/kernels/split_kernel.h"
2525namespace paddle {
2626namespace operators {
27- static inline std::vector<phi::DDim> UpdateOutsDims (
28- const bool is_runtime,
29- const bool each_section_is_known,
30- const phi::DDim in_dims,
31- const size_t num,
32- std::vector<int > sections,
33- const size_t axis,
34- const int outs_number) {
35- std::vector<phi::DDim> outs_dims (outs_number, in_dims);
36- int64_t input_axis_dim = in_dims[axis];
37- if (num > 0 ) {
38- if (is_runtime || input_axis_dim > 0 ) {
39- PADDLE_ENFORCE_EQ (
40- input_axis_dim % num,
41- 0 ,
42- phi::errors::InvalidArgument (
43- " The input's size along the split dimension "
44- " must be evenly divisible by Attr(num_or_sections). "
45- " But received Attr(num_or_sections) "
46- " = %d, input(X)'s shape = [%s], Attr(dim) = %d." ,
47- num,
48- in_dims,
49- axis));
50- size_t out_axis_dim = input_axis_dim / num;
51-
52- for (auto & out_dim : outs_dims) {
53- out_dim[axis] = out_axis_dim;
54- }
55- } else {
56- for (auto & out_dim : outs_dims) {
57- out_dim[axis] = -1 ;
58- }
59- }
60- } else if (sections.size () > 0 ) {
61- if (is_runtime || input_axis_dim > 0 ) {
62- const int unk_dim_val = -1 ;
63- int unk_dim_idx = -1 , num_of_unk = 0 ;
64- int sum_of_section = 0 ;
65- for (size_t i = 0 ; i < sections.size (); ++i) {
66- if (sections[i] == unk_dim_val) {
67- num_of_unk++;
68- unk_dim_idx = i;
69- } else {
70- sum_of_section += sections[i];
71- }
72- }
73-
74- if (each_section_is_known) {
75- PADDLE_ENFORCE_LE (
76- num_of_unk,
77- 1 ,
78- phi::errors::InvalidArgument (
79- " Only one dimension value of Attr(num_or_sections) "
80- " in SplitOp can be -1. "
81- " But received Attr(num_or_sections) = [%s]." ,
82- common::make_ddim (sections)));
83- }
84-
85- if (unk_dim_idx != -1 ) {
86- // for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1].
87- // input_axis_dim = 5, sum_of_sections = 5.
88- // the following check will fail.
89- PADDLE_ENFORCE_LT (
90- sum_of_section,
91- input_axis_dim,
92- phi::errors::InvalidArgument (
93- " Sum of Attr(num_or_sections) other than unknown section "
94- " must be less than the input's "
95- " size "
96- " along the split dimension. But received Attr(num_or_sections) "
97- " = [%s], input(X)'s shape = [%s], Attr(dim) = %d." ,
98- common::make_ddim (sections),
99- in_dims,
100- axis));
101- if (each_section_is_known) {
102- sections[unk_dim_idx] = input_axis_dim - sum_of_section;
103- }
104- } else {
105- PADDLE_ENFORCE_EQ (
106- sum_of_section,
107- input_axis_dim,
108- phi::errors::InvalidArgument (
109- " Sum of Attr(num_or_sections) must be equal to the input's "
110- " size "
111- " along the split dimension. But received Attr(num_or_sections)"
112- " = [%s], input(X)'s shape = [%s], Attr(dim) = %d." ,
113- common::make_ddim (sections),
114- in_dims,
115- axis));
116- }
117- }
118- for (int i = 0 ; i < outs_number; ++i) {
119- outs_dims[i][axis] = sections[i];
120- }
121- }
122- return outs_dims;
123- }
12427
12528template <typename T>
12629class SplitGradMaker : public framework ::SingleGradOpMaker<T> {
0 commit comments