Skip to content

Commit 30c0291

Browse files
authored
[xpu] fix bug for use of xpu::sigmoid and change xpu::clip_v2 to xpu::clamp (#69374)
1 parent a747694 commit 30c0291

16 files changed

+50
-55
lines changed

paddle/phi/kernels/fusion/xpu/fused_multi_transformer_int8_xpu_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ void FusedMultiTransformerInt8XpuKernel(
354354
: gather_index_t->dims()[0],
355355
gather_axis);
356356
}
357-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::gather");
357+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::paddle_gather");
358358
r = xpu::copy<XPUTypeT>(
359359
ctx.x_context(),
360360
reinterpret_cast<XPUTypeT*>(cache_kv_gather_tensor.data<T>()),

paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ void FusedMultiTransformerXpuKernel(
277277
: gather_index_t->dims()[0],
278278
gather_axis);
279279
}
280-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::gather");
280+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::paddle_gather");
281281
cache_kv_out[i]->ResizeAndAllocate(cache_kv_gather_dims);
282282
r = xpu::copy<XPUTypeT>(
283283
ctx.x_context(),
@@ -307,7 +307,7 @@ void FusedMultiTransformerXpuKernel(
307307
: gather_index_t->dims()[0],
308308
gather_axis);
309309
}
310-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::gather_inplace");
310+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu::paddle_gather_inplace");
311311
}
312312
}
313313

paddle/phi/kernels/fusion/xpu/fused_rope_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void GetSinCosByPassValue(const Context& dev_ctx,
8383
{seq_len, head_dim},
8484
batch_size * seq_len,
8585
0);
86-
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather");
86+
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "paddle_gather");
8787
ret = xpu::paddle_gather<XPUSCType, int64_t>(
8888
dev_ctx.x_context(),
8989
reinterpret_cast<const XPUSCType*>(cos->data()),
@@ -92,7 +92,7 @@ void GetSinCosByPassValue(const Context& dev_ctx,
9292
{seq_len, head_dim},
9393
batch_size * seq_len,
9494
0);
95-
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather");
95+
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "paddle_gather");
9696
} else {
9797
int sin_cos_batch_size = (dims_size) == 4 ? sin_cos_dims[0] : 1;
9898
ret = xpu::broadcast<XPUSCType>(

paddle/phi/kernels/xpu/activation_grad_kernel.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,12 @@ struct XPUSigmoidGradFunctor : public funcs::BaseActivationFunctor<T> {
408408
const DenseTensor* out,
409409
const DenseTensor* dout,
410410
DenseTensor* dx) const {
411-
int r = xpu_activation_backward<Context, T, XPUType>(
412-
dev_ctx, x, out, dout, dx, xpu::sigmoid_grad<XPUType>);
411+
dev_ctx.template Alloc<T>(dx);
412+
const XPUType* y_data = reinterpret_cast<const XPUType*>(out->data<T>());
413+
const XPUType* y_grad = reinterpret_cast<const XPUType*>(dout->data<T>());
414+
XPUType* x_grad = reinterpret_cast<XPUType*>(dx->data<T>());
415+
int r = xpu::sigmoid_grad(
416+
dev_ctx.x_context(), y_data, y_grad, x_grad, dx->numel());
413417
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_grad");
414418
}
415419
};

paddle/phi/kernels/xpu/c_embedding_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ void CEmbeddingKernel(const Context& dev_ctx,
5050
ids.numel(),
5151
-1,
5252
static_cast<int32_t>(start_index));
53-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
53+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_embedding");
5454
} else if (index_type == phi::DataType::INT64) {
5555
int r = xpu::paddle_embedding(dev_ctx.x_context(),
5656
reinterpret_cast<const XPUType*>(table_data),
@@ -61,7 +61,7 @@ void CEmbeddingKernel(const Context& dev_ctx,
6161
ids.numel(),
6262
-1,
6363
static_cast<int64_t>(start_index));
64-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
64+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_embedding");
6565
} else {
6666
PADDLE_THROW(common::errors::Unavailable(
6767
"XPU c_embedding ids only support int32 or int64."));

paddle/phi/kernels/xpu/clip_kernel.cc

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "glog/logging.h"
1818

19+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1920
#include "paddle/phi/backends/xpu/xpu_context.h"
2021
#include "paddle/phi/backends/xpu/xpu_header.h"
2122
#include "paddle/phi/core/kernel_registry.h"
@@ -32,19 +33,13 @@ void ClipKernel(const Context& dev_ctx,
3233
using XPUDataType = typename XPUTypeTrait<T>::Type;
3334
auto x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
3435
auto out_data = reinterpret_cast<XPUDataType*>(out->data<T>());
35-
int r = xpu::clip_v2(dev_ctx.x_context(),
36-
x_data,
37-
out_data,
38-
x.numel(),
39-
static_cast<XPUDataType>(min.to<T>()),
40-
static_cast<XPUDataType>(max.to<T>()));
41-
42-
PADDLE_ENFORCE_EQ(r,
43-
XPU_SUCCESS,
44-
common::errors::External("XPU API(clip_v2) return wrong "
45-
"value[%d %s]",
46-
r,
47-
XPUAPIErrorMsg[r]));
36+
int r = xpu::clamp(dev_ctx.x_context(),
37+
x_data,
38+
out_data,
39+
x.numel(),
40+
static_cast<XPUDataType>(min.to<T>()),
41+
static_cast<XPUDataType>(max.to<T>()));
42+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clamp");
4843
}
4944

5045
} // namespace phi

paddle/phi/kernels/xpu/distribute_fpn_proposals_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ void DistributeFpnProposalsKernel(
158158
fpn_rois_shape,
159159
sub_idx.numel(),
160160
0);
161-
PADDLE_ENFORCE_XDNN_SUCCESS(r1, "gather");
161+
PADDLE_ENFORCE_XDNN_SUCCESS(r1, "paddle_gather");
162162
} else {
163163
multi_fpn_rois[i]->Resize({sub_rois_num, funcs::kBoxDim});
164164
dev_ctx.template Alloc<T>(multi_fpn_rois[i]);

paddle/phi/kernels/xpu/embedding_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void EmbeddingKernel(const Context &ctx,
107107
padding_idx);
108108
#endif
109109
}
110-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
110+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_embedding");
111111
}
112112

113113
} // namespace phi

paddle/phi/kernels/xpu/gather_kernel.cc

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,7 @@ void GatherKernel(const Context& dev_ctx,
7777
index.dims().size() == 0 ? 1 : index.dims()[0],
7878
axis_v);
7979
}
80-
PADDLE_ENFORCE_EQ(
81-
r,
82-
xpu::Error_t::SUCCESS,
83-
common::errors::External(
84-
"XPU gather kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
80+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
8581
}
8682

8783
} // namespace phi

paddle/phi/kernels/xpu/generate_proposals_kernel.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
107107
{static_cast<int>(scores_slice.numel()), 1},
108108
index_sort.numel(),
109109
0);
110-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
110+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
111111

112112
r = xpu::paddle_gather<T>(
113113
dev_ctx.x_context(),
@@ -117,7 +117,7 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
117117
{static_cast<int>(bbox_deltas_slice.numel()) / 4, 4},
118118
index_sort.numel(),
119119
0);
120-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
120+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
121121

122122
r = xpu::paddle_gather<T>(dev_ctx.x_context(),
123123
anchors.data<T>(),
@@ -126,7 +126,7 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
126126
{static_cast<int>(anchors.numel()) / 4, 4},
127127
index_sort.numel(),
128128
0);
129-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
129+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
130130

131131
r = xpu::paddle_gather<T>(dev_ctx.x_context(),
132132
variances.data<T>(),
@@ -135,7 +135,7 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
135135
{static_cast<int>(variances.numel()) / 4, 4},
136136
index_sort.numel(),
137137
0);
138-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
138+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
139139

140140
int num = scores_slice.numel();
141141
int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num)
@@ -211,7 +211,7 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
211211
{pre_nms_num, 4},
212212
keep_num,
213213
0);
214-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
214+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
215215

216216
r = xpu::paddle_gather<T>(dev_ctx.x_context(),
217217
scores_sel.data<T>(),
@@ -220,7 +220,7 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
220220
{pre_nms_num, 1},
221221
keep_num,
222222
0);
223-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
223+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
224224

225225
if (nms_thresh <= 0) {
226226
if (dev_ctx.x_context()->xpu_stream) {
@@ -257,15 +257,15 @@ std::pair<DenseTensor, DenseTensor> ProposalForOneImage(
257257
{keep_num, 4},
258258
keep_index.numel(),
259259
0);
260-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
260+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
261261
r = xpu::paddle_gather<T>(dev_ctx.x_context(),
262262
scores_filter.data<T>(),
263263
keep_index.data<int>(),
264264
scores_nms.data<T>(),
265265
{keep_num, 1},
266266
keep_index.numel(),
267267
0);
268-
PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather");
268+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_gather");
269269
if (dev_ctx.x_context()->xpu_stream) {
270270
dev_ctx.Wait();
271271
}

0 commit comments

Comments (0)