Skip to content

Commit 0b8be48

Browse files
[host-gather-fix] fix gather_compute in index_type not compatible error (#9000)
1 parent e367b05 commit 0b8be48

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

lite/kernels/host/gather_compute.cc

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,57 @@ namespace lite {
1919
namespace kernels {
2020
namespace host {
2121

22+
// INDEX_COMPUTE: slice-copy loop expanded once per supported index type.
//
// The macro expands in a scope that must already provide:
//   index_size - number of entries to read from p_index
//   p_index    - pointer to the index values
//   p_src      - pointer to the source elements
//   p_output   - pointer to the destination elements
//   slice_size - number of elements copied per index entry
//   DataType   - element type of p_src / p_output
//
// For each i in [0, index_size) it copies slice_size elements from
// p_src + p_index[i] * slice_size to p_output + i * slice_size.
//
// Offsets are computed in int64_t. With an `int` loop counter and 32-bit
// index values the products `i * slice_size` and `index * slice_size` were
// evaluated in `int`, which can overflow for large tensors.
// NOTE(review): index values are not range-checked here - confirm that
// callers validate them before this loop runs.
#define INDEX_COMPUTE                                            \
  for (int64_t i = 0; i < index_size; ++i) {                     \
    const int64_t src_row = static_cast<int64_t>(p_index[i]);    \
    memcpy(p_output + i * slice_size,                            \
           p_src + src_row * slice_size,                         \
           slice_size * sizeof(DataType));                       \
  }
29+
30+
// INDEX_V2_COMPUTE: axis-wise gather loop expanded once per supported index
// type.
//
// The macro expands in a scope that must already provide:
//   index_data           - pointer to the index values
//   index_size           - number of entries in index_data
//   input_index_dim_size - upper bound (exclusive) every index must satisfy
//   inner_dim_size       - trip count of the outermost copy loop
//   outer_dim_size       - trip count of the innermost copy loop
//   input_size           - total element count used to derive the per-i base
//   input_data           - pointer to the source elements
//   out_data             - pointer to the destination elements
//   out_index            - running write position in out_data (advanced here)
//
// Step 1: every index is checked with CHECK_LT against input_index_dim_size.
// NOTE(review): only the upper bound is checked; negative indices are not
// rejected here.
// Step 2: for each i, each index j and each k the element at
//   k + index_data[j] * outer_dim_size + i * input_size / inner_dim_size
// is copied to out_data[out_index++].
//
// Changes relative to the previous expansion:
//  * the two message literals are now separated by a space (they used to
//    render as "...the size ofdim size...");
//  * the source offset is computed in int64_t, so 64-bit indices are not
//    truncated to `int` and `i * input_size` cannot overflow 32 bits.
#define INDEX_V2_COMPUTE                                                      \
  for (int i = 0; i < index_size; i++) {                                      \
    CHECK_LT(index_data[i], input_index_dim_size)                             \
        << "The element of Index must be less than the size of "              \
        << "dim size of axis dim";                                            \
  }                                                                           \
  for (int i = 0; i < inner_dim_size; i++) {                                  \
    const int64_t inner_base =                                                \
        static_cast<int64_t>(i) * input_size / inner_dim_size;                \
    for (int j = 0; j < index_size; j++) {                                    \
      const int64_t axis_base =                                               \
          static_cast<int64_t>(index_data[j]) * outer_dim_size + inner_base;  \
      for (int k = 0; k < outer_dim_size; k++) {                              \
        out_data[out_index] = input_data[axis_base + k];                      \
        out_index++;                                                          \
      }                                                                       \
    }                                                                         \
  }
46+
2247
// GatherFunc: copies whole slices of param.X into param.Out, one slice per
// entry of param.Index.
//
// - index_size is the first dimension of Index.
// - slice_size is the product of X's dimensions from dim 1 onward.
// - Output slice i (slice_size elements at offset i * slice_size) is a memcpy
//   of the X slice at offset Index[i] * slice_size.
//
// NOTE(review): this listing is a diff rendering. The lines whose marker ends
// in '-' (the IndexType-typed pointer and the inline copy loop) appear to be
// the replaced code, and the lines whose marker ends in '+' the new code that
// picks the index element type from param.Index->precision().
template <typename IndexType, typename DataType>
2348
void GatherFunc(const operators::GatherParam& param) {
2449
auto src_dims = param.X->dims();
2550
auto index_size = param.Index->dims()[0];
2651
auto* p_src = param.X->data<DataType>();
27-
const IndexType* p_index = param.Index->data<IndexType>();
2852
auto* p_output = param.Out->mutable_data<DataType>();
2953

3054
// Number of elements in one slice: product of every dimension after dim 0.
int slice_size = 1;
3155
for (size_t i = 1; i < src_dims.size(); ++i) {
3256
slice_size *= src_dims[i];
3357
}
34-
for (int i = 0; i < index_size; ++i) {
35-
IndexType index_ = p_index[i];
36-
memcpy(p_output + i * slice_size,
37-
p_src + index_ * slice_size,
38-
slice_size * sizeof(DataType));
58+
59+
// Runtime dispatch on the index tensor's precision: int64 and int32 are
// handled through INDEX_COMPUTE; any other precision reaches LOG(FATAL).
if (param.Index->precision() == PrecisionType::kInt64) {
60+
const int64_t* p_index = param.Index->data<int64_t>();
61+
INDEX_COMPUTE
62+
} else if (param.Index->precision() == PrecisionType::kInt32) {
63+
const int32_t* p_index = param.Index->data<int32_t>();
64+
INDEX_COMPUTE
65+
} else {
66+
LOG(FATAL) << "Unsupported this index precision: "
67+
<< PrecisionToStr(param.Index->precision());
3968
}
4069
}
4170

4271
template <typename IndexType, typename AxisType, typename DataType>
4372
void GatherV2Func(const operators::GatherParam& param) {
44-
auto* index_data = param.Index->data<IndexType>();
4573
auto* input_data = param.X->data<DataType>();
4674
auto* out_data = param.Out->mutable_data<DataType>();
4775

@@ -52,11 +80,7 @@ void GatherV2Func(const operators::GatherParam& param) {
5280
int inner_dim_size = 1;
5381
int outer_dim_size = 1;
5482
int input_index_dim_size = input_dim[axis_index];
55-
for (int i = 0; i < index_size; i++) {
56-
CHECK_LT(index_data[i], input_index_dim_size)
57-
<< "The element of Index must be less than the size of"
58-
<< "dim size of axis dim";
59-
}
83+
6084
for (int i = 0; i < axis_index; i++) {
6185
inner_dim_size *= input_dim[i];
6286
}
@@ -65,15 +89,15 @@ void GatherV2Func(const operators::GatherParam& param) {
6589
}
6690

6791
int out_index = 0;
68-
for (int i = 0; i < inner_dim_size; i++) {
69-
for (int j = 0; j < index_size; j++) {
70-
for (int k = 0; k < outer_dim_size; k++) {
71-
int index = k + index_data[j] * outer_dim_size +
72-
(i * input_size / inner_dim_size);
73-
out_data[out_index] = input_data[index];
74-
out_index++;
75-
}
76-
}
92+
if (param.Index->precision() == PrecisionType::kInt64) {
93+
auto* index_data = param.Index->data<int64_t>();
94+
INDEX_V2_COMPUTE
95+
} else if (param.Index->precision() == PrecisionType::kInt32) {
96+
auto* index_data = param.Index->data<int32_t>();
97+
INDEX_V2_COMPUTE
98+
} else {
99+
LOG(FATAL) << "Unsupported this index precision: "
100+
<< PrecisionToStr(param.Index->precision());
77101
}
78102
}
79103

@@ -136,6 +160,8 @@ void GatherCompute<IndexType, AxisType>::Run() {
136160
return;
137161
}
138162
}
163+
#undef INDEX_COMPUTE
164+
#undef INDEX_V2_COMPUTE
139165

140166
} // namespace host
141167
} // namespace kernels

0 commit comments

Comments
 (0)