
Commit 94fe929

[XPU] add set_value and set_value_grad (#48845)
1 parent 95332be commit 94fe929

4 files changed · +392 -0 lines changed


paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 7 additions & 0 deletions
@@ -417,6 +417,13 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::FLOAT32})},
     {"sampling_id",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT64})},
+    {"set_value",
+     XPUKernelSet({phi::DataType::INT32,
+                   phi::DataType::INT64,
+                   phi::DataType::FLOAT16,
+                   phi::DataType::FLOAT32})},
+    {"set_value_grad",
+     XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"sgd", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"sgd_dense_param_sparse_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
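
For orientation, the entries added above gate which dtypes may be dispatched to the new XPU kernels. Below is a minimal sketch of that kind of lookup, assuming the op map is keyed by op name and stores the set of supported dtypes, as the diff suggests; the type names are illustrative stand-ins, not the exact Paddle declarations.

// Illustrative stand-ins only; the real map is built by get_kl2_ops() above.
#include <string>
#include <unordered_map>
#include <unordered_set>

enum class DataType { INT32, INT64, FLOAT16, FLOAT32 };
using KernelSet = std::unordered_set<DataType>;
using OpMap = std::unordered_map<std::string, KernelSet>;

// Returns true when `op` has a registered XPU kernel for `dtype`.
bool Supports(const OpMap& ops, const std::string& op, DataType dtype) {
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(dtype) > 0;
}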
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/set_value_grad_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace phi {

template <typename T, typename Context>
void SetValueGradKernel(const Context& dev_ctx,
                        const DenseTensor& out_grad,
                        const IntArray& starts,
                        const IntArray& ends,
                        const IntArray& steps,
                        const std::vector<int64_t>& axes,
                        const std::vector<int64_t>& decrease_axes,
                        const std::vector<int64_t>& none_axes,
                        DenseTensor* x_grad,
                        DenseTensor* value_grad) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  x_grad->Resize(out_grad.dims());
  dev_ctx.template Alloc<T>(x_grad);
  dev_ctx.template Alloc<T>(value_grad);

  const XPUType* dy_data = reinterpret_cast<const XPUType*>(out_grad.data<T>());
  XPUType* dx_data = reinterpret_cast<XPUType*>(x_grad->data<T>());
  XPUType* dv_data = reinterpret_cast<XPUType*>(value_grad->data<T>());

  std::vector<int64_t> starts_vec = starts.GetData();
  std::vector<int64_t> ends_vec = ends.GetData();
  std::vector<int64_t> steps_vec = steps.GetData();

  auto dy_dims = out_grad.dims();
  std::vector<int> dy_shape;
  for (int i = 0; i < dy_dims.size(); ++i) {
    dy_shape.push_back(dy_dims[i]);
  }

  auto dv_dims = value_grad->dims();
  std::vector<int> dv_shape;
  for (int i = 0; i < dv_dims.size(); ++i) {
    dv_shape.push_back(dv_dims[i]);
  }

  auto dx_dims = x_grad->dims();
  std::vector<int> dx_shape;
  for (int i = 0; i < dx_dims.size(); ++i) {
    dx_shape.push_back(dx_dims[i]);
  }

  std::vector<int> starts_vec_int32;
  for (size_t i = 0; i < starts_vec.size(); ++i) {
    starts_vec_int32.push_back(starts_vec[i]);
  }

  std::vector<int> ends_vec_int32;
  for (size_t i = 0; i < ends_vec.size(); ++i) {
    ends_vec_int32.push_back(ends_vec[i]);
  }

  std::vector<int> steps_vec_int32;
  for (size_t i = 0; i < steps_vec.size(); ++i) {
    steps_vec_int32.push_back(steps_vec[i]);
  }

  std::vector<int> axes_int32;
  for (size_t i = 0; i < axes.size(); ++i) {
    axes_int32.push_back(axes[i]);
  }

  std::vector<int> decrease_axes_int32;
  for (size_t i = 0; i < decrease_axes.size(); ++i) {
    decrease_axes_int32.push_back(decrease_axes[i]);
  }

  std::vector<int> none_axes_int32;
  for (size_t i = 0; i < none_axes.size(); ++i) {
    none_axes_int32.push_back(none_axes[i]);
  }

  int r = xpu::set_value_grad(dev_ctx.x_context(),
                              dy_data,
                              dx_data,
                              dv_data,
                              dy_shape,
                              dv_shape,
                              starts_vec_int32,
                              ends_vec_int32,
                              steps_vec_int32,
                              axes_int32,
                              decrease_axes_int32,
                              none_axes_int32);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "set_value_grad");
}

}  // namespace phi

PD_REGISTER_KERNEL(set_value_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::SetValueGradKernel,
                   float,
                   phi::dtype::float16) {}
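
To make the data flow above concrete, here is a plain C++ sketch of the slice-assignment backward that xpu::set_value_grad is assumed to implement, shown for the 1-D case: the gradient at positions overwritten by the value tensor routes to value_grad, and all other positions pass through to x_grad. This illustrates the expected semantics only; it is not Paddle or XDNN code.

// Plain C++ illustration of set_value_grad semantics for x[start:end:step] = v.
#include <vector>

int main() {
  std::vector<float> dy = {1, 2, 3, 4, 5, 6, 7, 8};  // out_grad
  int start = 1, end = 7, step = 2;                  // slice used in the forward pass
  std::vector<float> dx = dy;                        // x_grad: pass-through copy
  std::vector<float> dv;                             // value_grad
  for (int i = start; i < end; i += step) {
    dv.push_back(dy[i]);  // gradient of each overwritten element goes to value
    dx[i] = 0.0f;         // and is masked out of x_grad
  }
  // dx == {1, 0, 3, 0, 5, 0, 7, 8}, dv == {2, 4, 6}
  return 0;
}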
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/set_value_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"

namespace phi {

template <typename T, typename Context>
void SetTensorValueKernel(const Context& dev_ctx,
                          const DenseTensor& x,
                          const DenseTensor& value,
                          const IntArray& starts,
                          const IntArray& ends,
                          const IntArray& steps,
                          const std::vector<int64_t>& axes,
                          const std::vector<int64_t>& decrease_axes,
                          const std::vector<int64_t>& none_axes,
                          DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  out->Resize(x.dims());
  dev_ctx.template Alloc<T>(out);

  const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
  const XPUType* v_data = reinterpret_cast<const XPUType*>(value.data<T>());
  XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());

  std::vector<int64_t> starts_vec = starts.GetData();
  std::vector<int64_t> ends_vec = ends.GetData();
  std::vector<int64_t> steps_vec = steps.GetData();

  std::vector<int> starts_vec_int32;
  for (size_t i = 0; i < starts_vec.size(); ++i) {
    starts_vec_int32.push_back(starts_vec[i]);
  }

  std::vector<int> ends_vec_int32;
  for (size_t i = 0; i < ends_vec.size(); ++i) {
    ends_vec_int32.push_back(ends_vec[i]);
  }

  std::vector<int> steps_vec_int32;
  for (size_t i = 0; i < steps_vec.size(); ++i) {
    steps_vec_int32.push_back(steps_vec[i]);
  }

  std::vector<int> axes_int32;
  for (size_t i = 0; i < axes.size(); ++i) {
    axes_int32.push_back(axes[i]);
  }

  std::vector<int> decrease_axes_int32;
  for (size_t i = 0; i < decrease_axes.size(); ++i) {
    decrease_axes_int32.push_back(decrease_axes[i]);
  }

  std::vector<int> none_axes_int32;
  for (size_t i = 0; i < none_axes.size(); ++i) {
    none_axes_int32.push_back(none_axes[i]);
  }

  auto x_dims = x.dims();
  std::vector<int> x_shape;
  for (int i = 0; i < x_dims.size(); ++i) {
    x_shape.push_back(x_dims[i]);
  }

  auto v_dims = value.dims();
  std::vector<int> v_shape;
  for (int i = 0; i < v_dims.size(); ++i) {
    v_shape.push_back(v_dims[i]);
  }

  int r = xpu::set_value(dev_ctx.x_context(),
                         x_data,
                         v_data,
                         y_data,
                         x_shape,
                         v_shape,
                         starts_vec_int32,
                         ends_vec_int32,
                         steps_vec_int32,
                         axes_int32,
                         decrease_axes_int32,
                         none_axes_int32);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "set_value");
}

template <typename T, typename Context>
void SetValueKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const IntArray& starts,
                    const IntArray& ends,
                    const IntArray& steps,
                    const std::vector<int64_t>& axes,
                    const std::vector<int64_t>& decrease_axes,
                    const std::vector<int64_t>& none_axes,
                    const std::vector<int64_t>& shape,
                    const std::vector<Scalar>& values,
                    DenseTensor* out) {
  std::vector<T> assgin_values;
  assgin_values.reserve(values.size());
  for (const auto& val : values) {
    assgin_values.push_back(val.to<T>());
  }
  DenseTensor value_tensor = Empty<T>(dev_ctx, shape);
  paddle::framework::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
  value_tensor.Resize(phi::make_ddim(shape));

  SetTensorValueKernel<T, Context>(dev_ctx,
                                   x,
                                   value_tensor,
                                   starts,
                                   ends,
                                   steps,
                                   axes,
                                   decrease_axes,
                                   none_axes,
                                   out);
}

}  // namespace phi

PD_REGISTER_KERNEL(set_value,
                   XPU,
                   ALL_LAYOUT,
                   phi::SetValueKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(set_value_with_tensor,
                   XPU,
                   ALL_LAYOUT,
                   phi::SetTensorValueKernel,
                   float,
                   phi::dtype::float16,
                   int,
                   int64_t) {}
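
As a semantic reference for the two kernels registered above: set_value writes a value tensor into a strided slice of x and copies every other element through unchanged. Below is a minimal plain C++ sketch of the 1-D case, an illustration of the intended slice-assignment semantics under the usual start/end/step rules, not the XDNN implementation.

// Plain C++ illustration of set_value semantics for x[start:end:step] = v.
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<float> v = {10, 20, 30};  // one entry per element of the slice
  int start = 1, end = 7, step = 2;     // x[1:7:2] = v
  int j = 0;
  for (int i = start; i < end; i += step) {
    x[i] = v[j++];  // overwrite the sliced positions; others stay untouched
  }
  for (float e : x) std::printf("%g ", e);  // prints: 0 10 2 20 4 30 6 7
  std::printf("\n");
  return 0;
}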
