Skip to content

Commit 2072742

Browse files
authored
[XPU] Fix argmax, strided_slice when L3 cache and auto-tune is enabled. (#10121) (#10125)
1 parent 1f8b83e commit 2072742

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

lite/kernels/xpu/argmax_compute.cc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ namespace lite {
2222
namespace kernels {
2323
namespace xpu {
2424

25+
void ArgmaxCompute::PrepareForRun() {
26+
auto& param = this->template Param<param_t>();
27+
if (param.dtype == 2) {
28+
auto out = param.Out;
29+
out_int64_xpu_guard_ =
30+
TargetWrapperXPU::MallocScratchPad(out->numel() * sizeof(int64_t));
31+
}
32+
}
33+
2534
void ArgmaxCompute::Run() {
2635
auto& param = this->template Param<param_t>();
2736
auto& ctx = this->ctx_->template As<XPUContext>();
@@ -45,19 +54,15 @@ void ArgmaxCompute::Run() {
4554
CHECK_EQ(r, 0);
4655
} else if (param.dtype == 2) {
4756
// int32
48-
Tensor out_int64;
49-
out_int64.Resize(out->dims());
57+
int64_t* out_int64_data =
58+
reinterpret_cast<int64_t*>(out_int64_xpu_guard_->addr_);
5059
int r = xdnn::argmax<float, int64_t>(
51-
ctx.GetRawContext(),
52-
x->data<float>(),
53-
out_int64.mutable_data<int64_t>(TARGET(kXPU)),
54-
x_dims,
55-
axis);
60+
ctx.GetRawContext(), x->data<float>(), out_int64_data, x_dims, axis);
5661
CHECK_EQ(r, 0);
5762
r = xdnn::cast_v2<int64_t, int>(ctx.GetRawContext(),
58-
out_int64.data<int64_t>(),
63+
out_int64_data,
5964
out->mutable_data<int>(TARGET(kXPU)),
60-
out_int64.numel());
65+
static_cast<int>(out->numel()));
6166
CHECK_EQ(r, 0);
6267
} else {
6368
LOG(FATAL) << "argmax unsupported param type for xpu: " << param.dtype;

lite/kernels/xpu/argmax_compute.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@ class ArgmaxCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
2525
public:
2626
using param_t = operators::ArgmaxParam;
2727

28+
void PrepareForRun() override;
29+
2830
virtual void Run();
2931

3032
virtual ~ArgmaxCompute() = default;
33+
34+
private:
35+
XPUScratchPadGuard out_int64_xpu_guard_;
3136
};
3237

3338
} // namespace xpu

lite/kernels/xpu/strided_slice_compute.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ void StridedSliceCompute<T, PType>::Run() {
296296
auto* out_t = param.Out->template mutable_data<T>(TARGET(kXPU));
297297

298298
if (need_reverse) {
299-
lite::Tensor* tmp = new lite::Tensor();
300-
tmp->Resize(out_dims);
301-
auto* tmp_t = tmp->template mutable_data<T>(TARGET(kXPU));
299+
XPUScratchPadGuard tmp_xpu_guard =
300+
TargetWrapperXPU::MallocScratchPad(param.Out->numel() * sizeof(T));
301+
auto tmp_t = reinterpret_cast<T*>(tmp_xpu_guard->addr_);
302302
int r = xdnn::strided_slice<T>(ctx.GetRawContext(),
303303
in_t,
304304
tmp_t,

0 commit comments

Comments
 (0)