
Commit f3d54e2

Move sgd to phi (#40045)
* move sgd to phi; test=develop
* update
* add sgd kernel; test=develop
1 parent 3fc698f commit f3d54e2

File tree: 10 files changed (+592 / -18 lines)


paddle/fluid/framework/operator.cc

Lines changed: 5 additions & 1 deletion
@@ -2051,7 +2051,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
       // deal with optional here
       if ((it == ctx.inputs.end() || it->second.size() == 0) &&
           (input_defs[i].type_index ==
-               std::type_index(typeid(paddle::optional<const phi::DenseTensor&>)))) {
+               std::type_index(
+                   typeid(paddle::optional<const phi::DenseTensor&>)) ||
+           input_defs[i].type_index ==
+               std::type_index(
+                   typeid(paddle::optional<const phi::SelectedRows&>)))) {
         pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
         auto end_idx = start_idx + 1;
         pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
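Note: the hunk above widens the optional-input check so that a missing input declared as paddle::optional<const phi::SelectedRows&> is also forwarded as nullptr instead of raising an error. A minimal stand-alone sketch of the std::type_index mechanism involved, using assumed stand-in types rather than the framework's:

// Sketch only (stand-in types, not the framework code): comparing a kernel
// argument's std::type_index against the known optional input types decides
// whether an absent input may be passed through as nullptr.
#include <iostream>
#include <typeindex>
#include <typeinfo>

struct DenseTensor {};
struct SelectedRows {};
template <typename T>
struct optional_ref {};  // stand-in for paddle::optional<const T&>

bool IsOptionalInput(std::type_index arg_type) {
  return arg_type == std::type_index(typeid(optional_ref<DenseTensor>)) ||
         arg_type == std::type_index(typeid(optional_ref<SelectedRows>));
}

int main() {
  std::cout << IsOptionalInput(typeid(optional_ref<SelectedRows>)) << "\n";  // 1
  std::cout << IsOptionalInput(typeid(DenseTensor)) << "\n";                 // 0
}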

paddle/fluid/operators/optimizers/dgc_momentum_op.h

Lines changed: 60 additions & 5 deletions
@@ -17,7 +17,7 @@
 #include <memory>

 #include "paddle/fluid/operators/optimizers/momentum_op.h"
-#include "paddle/fluid/operators/optimizers/sgd_op.h"
+#include "paddle/phi/kernels/sgd_kernel.h"

 namespace paddle {
 namespace operators {
@@ -26,8 +26,7 @@ template <typename DeviceContext, typename T>
 class DGCMomentumKernel : public framework::OpKernel<T> {
  public:
   DGCMomentumKernel()
-      : _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()),
-        _sgd_op_kernel(new SGDOpKernel<DeviceContext, T>()) {}
+      : _momentum_op_kernel(new MomentumOpKernel<DeviceContext, T>()) {}

   void Compute(const framework::ExecutionContext& context) const override {
     auto rampup_begin_step = context.Attr<float>("rampup_begin_step");
@@ -67,12 +66,68 @@ class DGCMomentumKernel : public framework::OpKernel<T> {
     }

     VLOG(10) << " so use sgd optimizer";
-    return _sgd_op_kernel->Compute(context);
+
+    const auto* param_var = context.InputVar("Param");
+    const auto* grad_var = context.InputVar("Grad");
+    auto* learning_rate = context.Input<framework::Tensor>("LearningRate");
+    bool multi_precision = context.Attr<bool>("multi_precision");
+    if (param_var->IsType<framework::LoDTensor>()) {
+      auto* param = context.Input<framework::Tensor>("Param");
+      auto* param_out = context.Output<framework::Tensor>("ParamOut");
+      auto* master_param_out =
+          context.Output<framework::Tensor>("MasterParamOut");
+      paddle::optional<const framework::Tensor&> master_param_opt =
+          paddle::none;
+      if (multi_precision) {
+        auto* master_param = context.Input<framework::Tensor>("MasterParam");
+        master_param_opt = *master_param;
+      }
+
+      if (grad_var->IsType<framework::Tensor>()) {
+        // sgd_dense
+        auto* grad = context.Input<framework::Tensor>("Grad");
+        phi::SGDDenseKernel<T>(
+            static_cast<const typename framework::ConvertToPhiContext<
+                DeviceContext>::TYPE&>(dev_ctx),
+            *param, *learning_rate, *grad, master_param_opt, multi_precision,
+            param_out, master_param_out);
+      } else {
+        // sgd dense param sparse grad
+        auto* grad = context.Input<phi::SelectedRows>("Grad");
+        phi::SGDDenseParamSparseGradKernel<T>(
+            static_cast<const typename framework::ConvertToPhiContext<
+                DeviceContext>::TYPE&>(dev_ctx),
+            *param, *learning_rate, *grad, master_param_opt, multi_precision,
+            param_out, master_param_out);
+      }
+    } else if (param_var->IsType<phi::SelectedRows>() &&
+               grad_var->IsType<phi::SelectedRows>() &&
+               platform::is_cpu_place(context.GetPlace())) {
+      // sgd sparse param sparse grad
+      auto* param = context.Input<phi::SelectedRows>("Param");
+      auto* param_out = context.Output<phi::SelectedRows>("ParamOut");
+      auto* master_param_out =
+          context.Output<phi::SelectedRows>("MasterParamOut");
+      paddle::optional<const phi::SelectedRows&> master_param_opt =
+          paddle::none;
+      if (multi_precision) {
+        auto* master_param = context.Input<phi::SelectedRows>("MasterParam");
+        master_param_opt = *master_param;
+      }
+      auto* grad = context.Input<phi::SelectedRows>("Grad");
+      phi::SGDSparseParamSparseGradKernel<T>(
+          static_cast<const typename framework::ConvertToPhiContext<
+              DeviceContext>::TYPE&>(dev_ctx),
+          *param, *learning_rate, *grad, master_param_opt, multi_precision,
+          param_out, master_param_out);
+
+    } else {
+      PADDLE_THROW("gdc not support yet");
+    }
   }

  private:
   std::unique_ptr<MomentumOpKernel<DeviceContext, T>> _momentum_op_kernel;
-  std::unique_ptr<SGDOpKernel<DeviceContext, T>> _sgd_op_kernel;
 };

 }  // namespace operators
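For orientation, the rewritten Compute() no longer delegates to SGDOpKernel; it selects one of the three new phi SGD kernels from the variable types of Param and Grad. A stand-alone sketch of that dispatch, with assumed stand-in types (not Paddle code):

// Minimal sketch of the kernel selection performed above once the DGC
// ramp-up decides to fall back to plain SGD.
#include <iostream>
#include <string>
#include <variant>

struct DenseTensor {};   // stand-in for framework::LoDTensor / phi::DenseTensor
struct SelectedRows {};  // stand-in for phi::SelectedRows
using Var = std::variant<DenseTensor, SelectedRows>;

std::string PickSgdKernel(const Var& param, const Var& grad, bool on_cpu) {
  if (std::holds_alternative<DenseTensor>(param)) {
    return std::holds_alternative<DenseTensor>(grad)
               ? "phi::SGDDenseKernel"                  // dense param, dense grad
               : "phi::SGDDenseParamSparseGradKernel";  // dense param, sparse grad
  }
  if (std::holds_alternative<SelectedRows>(param) &&
      std::holds_alternative<SelectedRows>(grad) && on_cpu) {
    return "phi::SGDSparseParamSparseGradKernel";       // sparse param, sparse grad (CPU only)
  }
  return "unsupported (PADDLE_THROW in the real kernel)";
}

int main() {
  std::cout << PickSgdKernel(DenseTensor{}, SelectedRows{}, true) << "\n";
  // prints: phi::SGDDenseParamSparseGradKernel
}

The real code additionally builds master_param_opt from MasterParam when multi_precision is set, as shown in the hunk above.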

paddle/fluid/operators/optimizers/sgd_op.cc

Lines changed: 0 additions & 5 deletions
@@ -166,8 +166,3 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::SGDOpInferVarType);
-REGISTER_OP_CPU_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CPUDeviceContext,
-                     paddle::platform::bfloat16>,
-    ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/optimizers/sgd_op.cu

Lines changed: 0 additions & 7 deletions
@@ -166,10 +166,3 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
 };
 }  // namespace operators
 }  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);

paddle/phi/core/kernel_registry.h

Lines changed: 6 additions & 0 deletions
@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
+      } else if (arg_type == std::type_index(typeid(
+                                 paddle::optional<const SelectedRows&>))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type ==
                  std::type_index(typeid(const std::vector<DenseTensor>&))) {
         args_def->AppendInput(default_key.backend(),

paddle/phi/core/kernel_utils.h

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {

   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
+  PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows);
   PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sgd_kernel.h"
+#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T>
+void sgd_dense_param_dense_grad_impl(const DenseTensor& param,
+                                     const DenseTensor& learning_rate,
+                                     const DenseTensor& grad,
+                                     DenseTensor* param_out) {
+  const auto sz = param_out->numel();
+  paddle::operators::jit::sgd_attr_t attr(1, sz, 1, sz, 1);
+  const T* lr = learning_rate.data<T>();
+  const T* param_data = param.data<T>();
+  const T* grad_data = grad.data<T>();
+  int64_t rows_idx = 0;
+  T* out_data = param_out->data<T>();
+
+  auto sgd =
+      paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
+                                          phi::CPUPlace>::Cache()
+          .At(attr);
+  sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
+}
+
+template <>
+void sgd_dense_param_dense_grad_impl<phi::dtype::bfloat16>(
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const DenseTensor& grad,
+    DenseTensor* param_out) {
+  auto p = EigenVector<phi::dtype::bfloat16>::Flatten(param);
+  auto g = EigenVector<phi::dtype::bfloat16>::Flatten(grad);
+  auto o = EigenVector<phi::dtype::bfloat16>::Flatten(*param_out);
+  const auto* lr = learning_rate.data<phi::dtype::bfloat16>();
+
+  o = p - lr[0] * g;
+}
+
+template <typename T>
+void sgd_dense_param_sparse_grad_impl(const DenseTensor& param,
+                                      const DenseTensor& learning_rate,
+                                      const SelectedRows& grad,
+                                      DenseTensor* param_out) {
+  const auto& grad_value = grad.value();
+  const auto& grad_rows = grad.rows();
+  const T* param_data = param.data<T>();
+  const T* grad_data = grad_value.data<T>();
+  const T* lr = learning_rate.data<T>();
+  const int64_t* rows_data = grad_rows.data();
+  T* out_data = param_out->data<T>();
+
+  paddle::operators::jit::sgd_attr_t attr;
+  attr.param_height = param_out->dims()[0];
+  attr.param_width = param_out->numel() / attr.param_height;
+  attr.grad_height = grad_rows.size();  // note: it is not grad->height()
+  attr.grad_width = grad_value.numel() / attr.grad_height;
+  attr.selected_rows_size = grad_rows.size();
+
+  auto sgd =
+      paddle::operators::jit::KernelFuncs<paddle::operators::jit::SgdTuple<T>,
+                                          phi::CPUPlace>::Cache()
+          .At(attr);
+  sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
+}
+
+template <>
+void sgd_dense_param_sparse_grad_impl<phi::dtype::bfloat16>(
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    DenseTensor* param_out) {
+  const auto& grad_value = grad.value();
+  const auto& grad_rows = grad.rows();
+  const auto grad_height = grad.height();
+  const int64_t grad_val_height = static_cast<int64_t>(grad_rows.size());
+  const auto grad_width = grad_value.numel() / grad_val_height;
+
+  const auto* grad_data = grad_value.data<phi::dtype::bfloat16>();
+  auto* out_data = param_out->data<phi::dtype::bfloat16>();
+  const auto* lr = learning_rate.data<phi::dtype::bfloat16>();
+
+  for (size_t i = 0; i < grad_rows.size(); ++i) {
+    PADDLE_ENFORCE_LT(
+        grad_rows[i],
+        grad_height,
+        phi::errors::OutOfRange(
+            "Grad rows index value should be less than grad height."
+            "Got [%s], but expected less than [%s]",
+            grad_rows[i],
+            grad_height));
+    const int64_t row = grad_rows[i];
+    for (int64_t j = 0; j < grad_width; ++j) {
+      out_data[row * grad_width + j] -= lr[0] * grad_data[i * grad_width + j];
+    }
+  }
+}
+
+template <typename T, typename Context>
+void SGDDenseKernel(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const DenseTensor& learning_rate,
+                    const DenseTensor& grad,
+                    paddle::optional<const DenseTensor&> master_param,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* master_param_out) {
+  dev_ctx.template Alloc<T>(param_out);
+  sgd_dense_param_dense_grad_impl<T>(param, learning_rate, grad, param_out);
+}
+
+template <typename T, typename Context>
+void SGDDenseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const DenseTensor&> master_param,
+    bool multi_precision,
+    DenseTensor* param_out,
+    DenseTensor* master_param_out) {
+  dev_ctx.template Alloc<T>(param_out);
+  sgd_dense_param_sparse_grad_impl<T>(param, learning_rate, grad, param_out);
+}
+
+template <typename T, typename Context>
+void SGDSparseParamSparseGradKernel(
+    const Context& dev_ctx,
+    const SelectedRows& param,
+    const DenseTensor& learning_rate,
+    const SelectedRows& grad,
+    paddle::optional<const SelectedRows&> master_param,
+    bool multi_precision,
+    SelectedRows* param_out,
+    SelectedRows* master_param_out) {
+  // for distributed training, a sparse var may be empty,
+  // just skip updating.
+  if (grad.rows().size() == 0) {
+    return;
+  }
+
+  auto param_row_width = param.value().dims()[1];
+  auto grad_row_width = grad.value().dims()[1];
+  PADDLE_ENFORCE_EQ(
+      param_row_width,
+      grad_row_width,
+      phi::errors::InvalidArgument(
+          "The param_row in SgdOP should have the same size with grad_row. "
+          "But received param_row's width is [%s], and grad_row's width is "
+          "[%s]",
+          param_row_width,
+          grad_row_width));
+
+  const auto* lr = learning_rate.data<T>();
+  const auto* grad_data = grad.value().data<T>();
+  auto* out_data = param_out->mutable_value()->data<T>();
+  for (size_t i = 0; i < grad.rows().size(); i++) {
+    int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false);
+    PADDLE_ENFORCE_GE(
+        id_index,
+        static_cast<int64_t>(0),
+        phi::errors::InvalidArgument(
+            "The id in SgdOp should be >= 0. But recevied id_index is [%s]",
+            id_index));
+    for (int64_t j = 0; j < grad_row_width; j++) {
+      out_data[id_index * grad_row_width + j] -=
+          lr[0] * grad_data[i * grad_row_width + j];
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sgd,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SGDDenseKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SGDDenseParamSparseGradKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SGDSparseParamSparseGradKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
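All three registered CPU kernels apply the same update rule, param_out = param - lr * grad: the dense path touches every element, while the SelectedRows-grad paths touch only the rows listed in grad.rows(). A simplified, self-contained sketch of those two patterns (plain vectors standing in for phi tensors, not the phi implementation):

// Sketch only: the per-element SGD update used by the kernels above.
#include <cstddef>
#include <cstdint>
#include <vector>

// Dense param, dense grad: every element of param is updated.
void SgdDenseUpdate(std::vector<float>* param,
                    const std::vector<float>& grad, float lr) {
  for (std::size_t i = 0; i < param->size(); ++i) (*param)[i] -= lr * grad[i];
}

// Dense param, SelectedRows grad: only the rows named in `rows` are touched.
void SgdSparseGradUpdate(std::vector<float>* param, std::int64_t row_width,
                         const std::vector<std::int64_t>& rows,
                         const std::vector<float>& grad_value, float lr) {
  for (std::size_t i = 0; i < rows.size(); ++i) {
    for (std::int64_t j = 0; j < row_width; ++j) {
      (*param)[rows[i] * row_width + j] -= lr * grad_value[i * row_width + j];
    }
  }
}

int main() {
  std::vector<float> param = {1.f, 1.f, 1.f, 1.f};  // 2 rows x 2 columns
  SgdSparseGradUpdate(&param, 2, {1}, {0.5f, 0.5f}, 0.1f);
  // row 0 stays {1, 1}; row 1 becomes {0.95, 0.95}
}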
