Commit 86ea8dc

Added scale op FP32/BF16 FWD/BWD kernels (#32975)
1 parent 88b43b5 commit 86ea8dc

File tree

9 files changed: +345 additions, -3 deletions


paddle/fluid/framework/data_layout_transform.cc

Lines changed: 2 additions & 2 deletions

@@ -143,7 +143,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
 void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
                                     const Tensor& in, Tensor* out,
-                                    platform::Place place) {
+                                    platform::Place place, bool always_copy) {
   PADDLE_ENFORCE_NE(in.format(), MKLDNNMemoryFormat::undef,
                     platform::errors::InvalidArgument(
                         "Input tensor format is invalid. Input tensor should "

@@ -177,7 +177,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
   // output tensor has the same dims as input. Reorder don't change dims
   out->Resize(in.dims());

-  if (in_format != out_format) {
+  if ((in_format != out_format) || always_copy) {
     void* in_data = GetDataFromTensor(in, in_type);
     std::string key =
         platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type);

paddle/fluid/framework/data_layout_transform.h

Lines changed: 2 additions & 1 deletion

@@ -78,7 +78,8 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {

 void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
                                     const Tensor& in, Tensor* out,
-                                    platform::Place place);
+                                    platform::Place place,
+                                    bool always_copy = false);

 void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                                const OpKernelType& expected_kernel_type,

paddle/fluid/inference/api/details/zero_copy_tensor.cc

Lines changed: 17 additions & 0 deletions

@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"

@@ -161,8 +162,24 @@ void Tensor::CopyToCpu(T *data) {
   auto *t_data = tensor->data<T>();
   auto t_place = tensor->place();

+  paddle::framework::Tensor out;
+  auto mem_allocation = std::make_shared<paddle::memory::Allocation>(
+      static_cast<void *>(data), ele_num * sizeof(T),
+      paddle::platform::CPUPlace());
+  out.ResetHolder(mem_allocation);
+
   if (paddle::platform::is_cpu_place(t_place)) {
+#ifdef PADDLE_WITH_MKLDNN
+    if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN)
+      paddle::framework::innerTransDataLayoutFromMKLDNN(
+          tensor->layout(), paddle::platform::MKLDNNDeviceContext::tls()
+                                .get_cur_paddle_data_layout(),
+          *tensor, &out, paddle::platform::CPUPlace(), true);
+    else
+      std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
+#else
     std::memcpy(static_cast<void *>(data), t_data, ele_num * sizeof(T));
+#endif
   } else if (place_ == PlaceType::kGPU) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     paddle::platform::DeviceContextPool &pool =
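With this change, CopyToCpu reorders oneDNN-laid-out CPU tensors into the current Paddle data layout before filling the user buffer (always_copy=true forces the copy even when the formats already match, since the destination is the caller's memory, not the tensor's). A minimal sketch of where this surfaces in the Python inference API; the model paths and input shape are hypothetical:

    import numpy as np
    from paddle.inference import Config, create_predictor

    config = Config("model/__model__", "model/params")  # hypothetical paths
    config.enable_mkldnn()  # run oneDNN kernels, producing kMKLDNN-layout tensors
    predictor = create_predictor(config)

    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    input_handle.copy_from_cpu(np.random.rand(1, 3, 224, 224).astype(np.float32))
    predictor.run()

    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    result = output_handle.copy_to_cpu()  # data arrives reordered out of MKLDNN layout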
paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+
+template <typename T>
+class ScaleMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+
+    bool bias_after_scale = ctx.Attr<bool>("bias_after_scale");
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    auto* scale_tensor = ctx.Input<Tensor>("ScaleTensor");
+
+    float scale = (scale_tensor == nullptr) ? ctx.Attr<float>("scale")
+                                            : (float)*(scale_tensor->data<T>());
+    float bias = ctx.Attr<float>("bias");
+
+    // if bias_after_scale == true
+    //    out = scale*X + bias
+    // else
+    //    out = scale*(X + bias) = scale*X + scale*bias
+
+    if (!bias_after_scale) bias *= scale;
+
+    auto x_tz = framework::vectorize<int64_t>(x->dims());
+    bool is_inplaced = x->IsSharedBufferWith(*out);
+
+    platform::ActivationMKLDNNHandler<T> handler(
+        x_tz, mkldnn::algorithm::eltwise_linear, scale, bias, x->format(),
+        dev_ctx, ctx.GetPlace(), ctx.InputName("X"), is_inplaced);
+
+    auto src_memory_p = handler.AcquireSrcMemory(x);
+    auto dst_memory_p = handler.AcquireDstMemory(out);
+    auto activation_p = handler.AcquireForwardPrimitive();
+
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+    activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
+                                    {MKLDNN_ARG_TO, *dst_memory_p}});
+    astream.wait();
+
+    out->set_layout(framework::DataLayout::kMKLDNN);
+    out->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(scale, MKLDNN, paddle::platform::CPUPlace,
+                   ops::ScaleMKLDNNKernel<float>,
+                   ops::ScaleMKLDNNKernel<paddle::platform::bfloat16>);
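The kernel lowers scale onto oneDNN's eltwise_linear primitive, which computes scale * x + bias in a single pass; the bias_after_scale=false case is handled by pre-folding the bias (bias *= scale), since scale * (x + bias) == scale * x + scale * bias. A minimal NumPy sketch of that folding (names here are illustrative, not Paddle API):

    import numpy as np

    def scale_op(x, scale=1.0, bias=0.0, bias_after_scale=True):
        # eltwise_linear computes scale * x + bias; when the bias belongs
        # before the scaling, fold it into an equivalent post-scale bias
        if not bias_after_scale:
            bias *= scale
        return scale * x + bias

    x = np.random.random((10, 10)).astype(np.float32)
    assert np.allclose(scale_op(x, -2.3, 0.4, bias_after_scale=False), -2.3 * (x + 0.4))
    assert np.allclose(scale_op(x, -2.3, 0.4, bias_after_scale=True), -2.3 * x + 0.4)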

paddle/fluid/operators/scale_op.cc

Lines changed: 20 additions & 0 deletions

@@ -54,6 +54,21 @@ class ScaleOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };

 class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {

@@ -87,6 +102,9 @@ if bias_after_scale=True:
              "Apply bias addition after or before scaling. It is useful for "
              "numeric stability in some circumstances.")
         .SetDefault(true);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
   }
 };

@@ -112,6 +130,8 @@ class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetAttr("scale", this->GetAttr("scale"));
     grad_op->SetAttr("bias", 0.0f);
     grad_op->SetAttr("bias_after_scale", true);
+    if (grad_op->HasAttr("use_mkldnn"))
+      grad_op->SetAttr("use_mkldnn", this->GetAttr("use_mkldnn"));
   }
 };
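ScaleGradMaker maps the backward pass onto another scale op (dX = scale * dOut, the bias contributing nothing to the gradient), so forwarding use_mkldnn keeps the backward op on oneDNN as well. A small sketch of the op from the Python side; assuming a Paddle build with oneDNN, routing to the new kernel is governed by CanMKLDNNBeUsed (e.g. via the FLAGS_use_mkldnn flag in static mode):

    import paddle

    paddle.enable_static()  # use_mkldnn is an attribute of static-graph ops
    x = paddle.static.data(name='x', shape=[10, 10], dtype='float32')
    # out = -2.3 * x + 0.4; with bias_after_scale=False it would be -2.3 * (x + 0.4)
    out = paddle.scale(x, scale=-2.3, bias=0.4, bias_after_scale=True)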

paddle/fluid/operators/unity_build_rule.cmake

Lines changed: 1 addition & 0 deletions

@@ -234,6 +234,7 @@ register_unity_group(cc
   save_combine_op.cc
   save_op.cc
   scale_op.cc
+  mkldnn/scale_mkldnn_op.cc
   scatter_nd_add_op.cc
   scatter_op.cc
   seed_op.cc
Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+@unittest.skipIf(core.is_compiled_with_cuda(),
+                 "core is compiled with CUDA which has no BF implementation")
+class TestScaleOpBF16(OpTest):
+    def setUp(self):
+        self.op_type = "scale"
+        self.x_fp32 = np.random.random((10, 10)).astype(np.float32)
+        self.x_bf16 = convert_float_to_uint16(self.x_fp32)
+        self.scale = -2.3
+        self.inputs = {'X': self.x_bf16}
+        self.attrs = {'scale': self.scale, 'use_mkldnn': True, 'bias': 0.4}
+        self.use_mkldnn = True
+        self.outputs = {
+            'Out': (self.x_fp32 * self.attrs['scale']) + self.attrs['bias']
+        }
+
+    def calculate_grads(self):
+        bias = 0
+        if 'bias' in self.attrs:
+            bias = self.attrs['bias']
+
+        scale = self.scale
+        if 'ScaleTensor' in self.attrs:
+            scale = self.attrs['ScaleTensor']
+
+        self.out = (self.x_fp32 * scale) + bias
+        self.dx = (self.out * scale)
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False)
+
+    def test_check_grad(self):
+        self.calculate_grads()
+        self.check_grad_with_place(
+            core.CPUPlace(), ["X"],
+            "Out",
+            check_dygraph=False,
+            user_defined_grads=[self.dx],
+            user_defined_grad_outputs=[convert_float_to_uint16(self.out)])
+
+
+class TestScaleOpBF16BiasNotAfterScale(TestScaleOpBF16):
+    def setUp(self):
+        self.op_type = "scale"
+        self.x_fp32 = np.random.random((10, 10)).astype(np.float32)
+        self.x_bf16 = convert_float_to_uint16(self.x_fp32)
+        self.scale = 1.5
+        self.inputs = {'X': self.x_bf16}
+        self.attrs = {
+            'scale': self.scale,
+            'use_mkldnn': True,
+            'bias': 0.0,
+            'bias_after_scale': False
+        }
+        self.use_mkldnn = True
+        self.outputs = {
+            'Out': (self.x_fp32 + self.attrs['bias']) * self.attrs['scale']
+        }
+
+
+class TestScaleOpBF16ScaleTensor(TestScaleOpBF16):
+    def setUp(self):
+        self.op_type = "scale"
+        self.scale = -2.3
+        self.x_fp32 = np.random.random((10, 10)).astype(np.float32)
+        self.x_bf16 = convert_float_to_uint16(self.x_fp32)
+        self.scale_tensor = np.array([self.scale]).astype(np.float32)
+        self.inputs = {
+            'X': self.x_bf16,
+            'ScaleTensor': convert_float_to_uint16(self.scale_tensor)
+        }
+        self.attrs = {'use_mkldnn': True}
+        self.outputs = {'Out': self.x_fp32 * self.scale}
+
+
+class TestScaleOpBF16ScaleTensorNotBiasAfterScale(TestScaleOpBF16):
+    def setUp(self):
+        self.op_type = "scale"
+        self.scale = 1.2
+        self.x_fp32 = np.random.random((9, 13)).astype(np.float32)
+        self.x_bf16 = convert_float_to_uint16(self.x_fp32)
+        self.scale_tensor = np.array([self.scale]).astype(np.float32)
+        self.inputs = {
+            'X': self.x_bf16,
+            'ScaleTensor': convert_float_to_uint16(self.scale_tensor)
+        }
+        self.attrs = {
+            'bias': -1.1,
+            'bias_after_scale': False,
+            'use_mkldnn': True
+        }
+        self.outputs = {'Out': (self.x_fp32 + self.attrs['bias']) * self.scale}
+
+
+if __name__ == "__main__":
+    paddle.enable_static()
+    unittest.main()
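The BF16 tests feed uint16 bit patterns produced by convert_float_to_uint16 and check the backward pass with user-defined gradients (dX = scale * dOut; the test sets dOut to the forward output, hence self.dx = self.out * scale). A rough NumPy sketch of the bf16 round trip such a helper performs, assuming plain truncation of the fp32 mantissa (the real helper may round):

    import numpy as np

    def to_bf16_bits(x):
        # bf16 keeps the sign, exponent, and top 7 mantissa bits of fp32,
        # i.e. the upper 16 bits of the 32-bit pattern (truncation assumed)
        return (x.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

    def from_bf16_bits(b):
        # widen back to fp32 by zero-filling the lower 16 bits
        return (b.astype(np.uint32) << 16).view(np.float32)

    x = np.random.random((10, 10)).astype(np.float32)
    assert np.allclose(from_bf16_bits(to_bf16_bits(x)), x, atol=1e-2)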
