@@ -19,29 +19,57 @@ namespace lite {
19
19
namespace kernels {
20
20
namespace host {
21
21
22
// Row-wise gather body, expanded inside GatherFunc once per index dtype.
//
// Names that must be in scope at the expansion site:
//   p_index    - pointer to the index values (int32_t or int64_t)
//   index_size - number of indices to process
//   src_dims   - dims of the source tensor; src_dims[0] bounds each index
//   p_src      - source data pointer (DataType)
//   p_output   - destination data pointer (DataType)
//   slice_size - number of DataType elements copied per index
//
// For every i in [0, index_size) the slice starting at
// p_src + p_index[i] * slice_size is copied to p_output + i * slice_size.
// Offsets are computed in 64-bit so large tensors do not overflow `int`.
// Expands to a plain `for` statement; use it without a trailing semicolon.
#define INDEX_COMPUTE                                                        \
  for (int64_t i = 0; i < index_size; ++i) {                                 \
    const int64_t index_ = static_cast<int64_t>(p_index[i]);                 \
    /* Unsigned compare rejects negative and too-large indices at once. */   \
    CHECK_LT(static_cast<uint64_t>(index_),                                  \
             static_cast<uint64_t>(src_dims[0]))                             \
        << "The element of Index must be non-negative and less than the"     \
        << " first dim size of X";                                           \
    memcpy(p_output + i * slice_size,                                        \
           p_src + index_ * slice_size,                                      \
           slice_size * sizeof(DataType));                                   \
  }
29
+
30
// Gather body for GatherV2Func, expanded once per index dtype.
//
// Names that must be in scope at the expansion site:
//   index_data           - pointer to the index values (int32_t or int64_t)
//   index_size           - number of indices
//   input_index_dim_size - exclusive upper bound for every index
//   inner_dim_size       - trip count of the outermost loop
//   outer_dim_size       - trip count of the innermost loop (elements copied
//                          contiguously per index)
//   input_size           - total element count used to derive the per-`i`
//                          base offset (input_size / inner_dim_size)
//   input_data, out_data - source / destination element pointers
//   out_index            - running write position in out_data; advanced by
//                          one per element written
//
// First validates every index, then copies element by element. Offsets are
// computed in 64-bit so large tensors do not overflow `int`.
// Expands to plain `for` statements; use it without a trailing semicolon.
#define INDEX_V2_COMPUTE                                                     \
  for (int i = 0; i < index_size; i++) {                                     \
    /* Unsigned compare rejects negative and too-large indices at once. */   \
    CHECK_LT(static_cast<uint64_t>(index_data[i]),                           \
             static_cast<uint64_t>(input_index_dim_size))                    \
        << "The element of Index must be non-negative and less than the"     \
        << " dim size of axis dim";                                          \
  }                                                                          \
  for (int i = 0; i < inner_dim_size; i++) {                                 \
    /* Same value as the original (i * input_size / inner_dim_size). */      \
    const int64_t inner_offset =                                             \
        static_cast<int64_t>(i) * input_size / inner_dim_size;               \
    for (int j = 0; j < index_size; j++) {                                   \
      const int64_t slice_offset =                                           \
          static_cast<int64_t>(index_data[j]) * outer_dim_size +             \
          inner_offset;                                                      \
      for (int k = 0; k < outer_dim_size; k++) {                             \
        out_data[out_index] = input_data[slice_offset + k];                  \
        out_index++;                                                         \
      }                                                                      \
    }                                                                        \
  }
46
+
22
47
template <typename IndexType, typename DataType>
23
48
void GatherFunc (const operators::GatherParam& param) {
24
49
auto src_dims = param.X ->dims ();
25
50
auto index_size = param.Index ->dims ()[0 ];
26
51
auto * p_src = param.X ->data <DataType>();
27
- const IndexType* p_index = param.Index ->data <IndexType>();
28
52
auto * p_output = param.Out ->mutable_data <DataType>();
29
53
30
54
int slice_size = 1 ;
31
55
for (size_t i = 1 ; i < src_dims.size (); ++i) {
32
56
slice_size *= src_dims[i];
33
57
}
34
- for (int i = 0 ; i < index_size; ++i) {
35
- IndexType index_ = p_index[i];
36
- memcpy (p_output + i * slice_size,
37
- p_src + index_ * slice_size,
38
- slice_size * sizeof (DataType));
58
+
59
+ if (param.Index ->precision () == PrecisionType::kInt64 ) {
60
+ const int64_t * p_index = param.Index ->data <int64_t >();
61
+ INDEX_COMPUTE
62
+ } else if (param.Index ->precision () == PrecisionType::kInt32 ) {
63
+ const int32_t * p_index = param.Index ->data <int32_t >();
64
+ INDEX_COMPUTE
65
+ } else {
66
+ LOG (FATAL) << " Unsupported this index precision: "
67
+ << PrecisionToStr (param.Index ->precision ());
39
68
}
40
69
}
41
70
42
71
template <typename IndexType, typename AxisType, typename DataType>
43
72
void GatherV2Func (const operators::GatherParam& param) {
44
- auto * index_data = param.Index ->data <IndexType>();
45
73
auto * input_data = param.X ->data <DataType>();
46
74
auto * out_data = param.Out ->mutable_data <DataType>();
47
75
@@ -52,11 +80,7 @@ void GatherV2Func(const operators::GatherParam& param) {
52
80
int inner_dim_size = 1 ;
53
81
int outer_dim_size = 1 ;
54
82
int input_index_dim_size = input_dim[axis_index];
55
- for (int i = 0 ; i < index_size; i++) {
56
- CHECK_LT (index_data[i], input_index_dim_size)
57
- << " The element of Index must be less than the size of"
58
- << " dim size of axis dim" ;
59
- }
83
+
60
84
for (int i = 0 ; i < axis_index; i++) {
61
85
inner_dim_size *= input_dim[i];
62
86
}
@@ -65,15 +89,15 @@ void GatherV2Func(const operators::GatherParam& param) {
65
89
}
66
90
67
91
int out_index = 0 ;
68
- for ( int i = 0 ; i < inner_dim_size; i++ ) {
69
- for ( int j = 0 ; j < index_size; j++) {
70
- for ( int k = 0 ; k < outer_dim_size; k++) {
71
- int index = k + index_data[j] * outer_dim_size +
72
- (i * input_size / inner_dim_size );
73
- out_data[out_index] = input_data[index];
74
- out_index++;
75
- }
76
- }
92
+ if (param. Index -> precision () == PrecisionType:: kInt64 ) {
93
+ auto * index_data = param. Index -> data < int64_t >();
94
+ INDEX_V2_COMPUTE
95
+ } else if (param. Index -> precision () == PrecisionType:: kInt32 ) {
96
+ auto * index_data = param. Index -> data < int32_t >( );
97
+ INDEX_V2_COMPUTE
98
+ } else {
99
+ LOG (FATAL) << " Unsupported this index precision: "
100
+ << PrecisionToStr (param. Index -> precision ());
77
101
}
78
102
}
79
103
@@ -136,6 +160,8 @@ void GatherCompute<IndexType, AxisType>::Run() {
136
160
return ;
137
161
}
138
162
}
163
+ #undef INDEX_COMPUTE
164
+ #undef INDEX_V2_COMPUTE
139
165
140
166
} // namespace host
141
167
} // namespace kernels
0 commit comments