
Commit 070cab1

Added slice BF16/FP32 FWD/BWD kernels (#34332)
* added slice FWD FP32
* added tests for slice FWD FP32
* added slice bwd
* added bf16 tests
* CI fix
* CI fix
* added reason to skip_if
* minor change
* temporary fix for failing test
* temporary fix
* changes after review
* CI rerun
1 parent a647b80 commit 070cab1

File tree

7 files changed: +436, -12 lines changed
paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc (new file)

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

template <typename T>
class SliceMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx);
  }

  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<Tensor>("Input");
    auto* out = ctx.Output<Tensor>("Out");

    auto x_vec_dims = framework::vectorize(x->dims());
    auto out_vec_dims = framework::vectorize(out->dims());

    auto axes_int = ctx.Attr<std::vector<int>>("axes");
    auto starts_int = ctx.Attr<std::vector<int>>("starts");
    auto ends_int = ctx.Attr<std::vector<int>>("ends");

    std::vector<int64_t> axes(ctx.Attr<std::vector<int>>("axes").begin(),
                              ctx.Attr<std::vector<int>>("axes").end());
    std::vector<int64_t> starts(ctx.Attr<std::vector<int>>("starts").begin(),
                                ctx.Attr<std::vector<int>>("starts").end());
    std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(),
                              ctx.Attr<std::vector<int>>("ends").end());

    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");

    std::vector<int64_t> offsets(x_vec_dims.size(), 0);
    std::vector<int64_t> slice_dims(x_vec_dims);

    for (size_t i = 0; i < axes.size(); ++i) {
      starts[i] = starts[i] < 0 ? x_vec_dims[axes[i]] + starts[i] : starts[i];
      ends[i] = ends[i] < 0 ? x_vec_dims[axes[i]] + ends[i]
                            : std::min(ends[i], x_vec_dims[axes[i]]);
      offsets[axes[i]] = starts[i];
      slice_dims[axes[i]] = ends[i] - starts[i];
    }

    mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
    auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends,
                                   x->format(), x_type);

    platform::ReorderMKLDNNHandler reorder_handler(
        x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key);

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        x->format(), platform::to_void_cast(x->data<T>()));
    auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
                                                        reorder_src_memory_p);
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        out, slice_dims, 0, x->format(), ctx.GetPlace());

    auto reorder_p =
        reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
    astream.wait();

    out->set_layout(framework::DataLayout::kMKLDNN);
    out->set_format(platform::GetMKLDNNFormat(
        reorder_dst_memory_p->get_desc().reshape(out_vec_dims)));
  }
};

template <typename T>
class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx);
  }

  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));

    auto dx_vec_dims = framework::vectorize(dx->dims());
    auto dout_vec_dims = framework::vectorize(dout->dims());

    auto axes_int = ctx.Attr<std::vector<int>>("axes");
    auto starts_int = ctx.Attr<std::vector<int>>("starts");
    auto ends_int = ctx.Attr<std::vector<int>>("ends");

    std::vector<int64_t> axes(ctx.Attr<std::vector<int>>("axes").begin(),
                              ctx.Attr<std::vector<int>>("axes").end());
    std::vector<int64_t> starts(ctx.Attr<std::vector<int>>("starts").begin(),
                                ctx.Attr<std::vector<int>>("starts").end());
    std::vector<int64_t> ends(ctx.Attr<std::vector<int>>("ends").begin(),
                              ctx.Attr<std::vector<int>>("ends").end());

    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");

    std::vector<int64_t> offsets(dx_vec_dims.size(), 0);
    std::vector<int64_t> slice_dims(dx_vec_dims);

    for (size_t i = 0; i < axes.size(); ++i) {
      starts[i] = starts[i] < 0 ? dx_vec_dims[axes[i]] + starts[i] : starts[i];
      ends[i] = ends[i] < 0 ? dx_vec_dims[axes[i]] + ends[i]
                            : std::min(ends[i], dx_vec_dims[axes[i]]);
      offsets[axes[i]] = starts[i];
      slice_dims[axes[i]] = ends[i] - starts[i];
    }

    mkldnn::memory::data_type dout_type =
        framework::ToMKLDNNDataType(dout->type());
    mkldnn::memory::desc md(dout_vec_dims, platform::MKLDNNGetDataType<T>(),
                            dout->format());
    mkldnn::memory::format_tag reorder_format_tag =
        platform::GetMKLDNNFormat(md.reshape(slice_dims));

    auto key = platform::CreateKey(dev_ctx, dout_vec_dims, axes, starts, ends,
                                   reorder_format_tag, dout_type);

    platform::ReorderMKLDNNHandler reorder_handler(
        slice_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key);

    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        reorder_format_tag, platform::to_void_cast(dout->data<T>()));
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        dx, dx_vec_dims, 0, reorder_format_tag, ctx.GetPlace());
    memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size());

    auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
                                                        reorder_dst_memory_p);

    auto reorder_p =
        reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p);
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p);
    astream.wait();

    dx->set_layout(framework::DataLayout::kMKLDNN);
    dx->set_format(reorder_format_tag);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace,
                   ops::SliceMKLDNNKernel<float>,
                   ops::SliceMKLDNNKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(slice_grad, MKLDNN, paddle::platform::CPUPlace,
                   ops::SliceGradMKLDNNKernel<float>,
                   ops::SliceGradMKLDNNKernel<paddle::platform::bfloat16>);
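Both kernels share the index arithmetic above: negative starts and ends wrap around the sliced dimension, ends are clamped to the dimension size, and the results become the submemory offsets and dims. The forward kernel then reorders that submemory out of the input, while the backward kernel zero-fills the input gradient and reorders dout into the matching submemory. A minimal standalone sketch of the shared arithmetic, with made-up shape and attribute values:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the offset/shape computation from the kernels above: negative
// starts/ends wrap around the dimension, ends are clamped, and the result is
// an offset vector plus the dims of the sub-tensor handed to the reorder.
int main() {
  std::vector<int64_t> x_dims = {2, 4, 6};  // hypothetical input shape
  std::vector<int64_t> axes = {1, 2};       // hypothetical slice attributes
  std::vector<int64_t> starts = {-3, 1};
  std::vector<int64_t> ends = {4, -1};

  std::vector<int64_t> offsets(x_dims.size(), 0);
  std::vector<int64_t> slice_dims(x_dims);
  for (size_t i = 0; i < axes.size(); ++i) {
    starts[i] = starts[i] < 0 ? x_dims[axes[i]] + starts[i] : starts[i];
    ends[i] = ends[i] < 0 ? x_dims[axes[i]] + ends[i]
                          : std::min(ends[i], x_dims[axes[i]]);
    offsets[axes[i]] = starts[i];
    slice_dims[axes[i]] = ends[i] - starts[i];
  }

  // Prints offsets 0 1 1 and slice_dims 2 3 4 for the values above.
  for (auto o : offsets) std::cout << o << ' ';
  std::cout << '\n';
  for (auto d : slice_dims) std::cout << d << ' ';
  std::cout << '\n';
}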

paddle/fluid/operators/mkldnn/split_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ class SplitMKLDNNKernel : public framework::OpKernel<T> {

     for (size_t i = 0; i < outs_number; ++i) {
       auto out_vec_dims = framework::vectorize(outs[i]->dims());
-      auto slice_mem_p = reorder_handler.AcquireSrcSubmemory(
+      auto slice_mem_p = reorder_handler.AcquireSubmemory(
           out_vec_dims, offset, reorder_src_memory_p, i);

       auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(

paddle/fluid/operators/slice_op.cc

Lines changed: 51 additions & 3 deletions
@@ -132,6 +132,26 @@ class SliceOp : public framework::OperatorWithKernel {
       if (platform::is_cuda_pinned_place(in_tensor.place())) {
         return framework::OpKernelType(in_tensor.type(), ctx.device_context());
       }
+
+#ifdef PADDLE_WITH_MKLDNN
+      auto input_data_type =
+          framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
+
+      if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+        // OneDNN uses blocking format, which cannot be always supported with
+        // reorders, because if blocked dimension is not divisible by 8 or
+        // 16(depending on which blocking format is used) submemory cannot be
+        // created, so in that scenario a fallback is needed
+        auto tmp_md = dnnl::memory::desc(
+            framework::vectorize(ctx.Input<Tensor>("Input")->dims()),
+            dnnl::memory::data_type::f32, ctx.Input<Tensor>("Input")->format());
+        if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
+          return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                         framework::DataLayout::kMKLDNN,
+                                         framework::LibraryType::kMKLDNN);
+      }
+#endif
+
       return framework::OpKernelType(in_tensor.type(), in_tensor.place());
     }
     return framework::OpKernelType(
@@ -216,6 +236,14 @@ class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault({});
     AddAttr<std::vector<int>>("decrease_axis", "(list<int>) decrease_axis")
         .SetDefault({});
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "mkldnn_data_type",
+        "(string, default \"float32\"). Data type of mkldnn kernel")
+        .SetDefault("float32")
+        .InEnum({"float32", "bfloat16"});
     AddComment(R"DOC(
 Slice Operator.

@@ -278,12 +306,32 @@ class SliceOpGrad : public framework::OperatorWithKernel {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
   }
+
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      // OneDNN uses blocking format, which cannot be always supported with
+      // reorders, because if blocked dimension is not divisible by 8 or
+      // 16(depending on which blocking format is used) submemory cannot be
+      // created, so in that scenario a fallback is needed
+      auto tmp_md = dnnl::memory::desc(
+          framework::vectorize(
+              ctx.Input<Tensor>(framework::GradVarName("Out"))->dims()),
+          dnnl::memory::data_type::f32,
+          ctx.Input<Tensor>(framework::GradVarName("Out"))->format());
+      if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
+        return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                       framework::DataLayout::kMKLDNN,
+                                       framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name, const Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
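The inner_nblks test above is what separates a plain layout from a blocked one: only plain layouts are guaranteed to support a submemory at arbitrary offsets, so anything blocked falls back to the native slice kernel. A small sketch of the distinction, assuming oneDNN headers are available (the shape and format tags are illustrative only):

#include <iostream>
#include "dnnl.hpp"  // assumes oneDNN (DNNL) is installed

// A plain nchw descriptor reports zero inner blocks, while a blocked layout
// such as nChw16c reports one; the operators above dispatch to the OneDNN
// kernel only in the first case.
int main() {
  dnnl::memory::dims dims = {1, 32, 8, 8};  // made-up shape

  auto plain_md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
                                     dnnl::memory::format_tag::nchw);
  auto blocked_md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
                                       dnnl::memory::format_tag::nChw16c);

  std::cout << plain_md.data.format_desc.blocking.inner_nblks << '\n';    // 0
  std::cout << blocked_md.data.format_desc.blocking.inner_nblks << '\n';  // 1
}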

paddle/fluid/operators/split_op.cc

Lines changed: 3 additions & 4 deletions
@@ -78,10 +78,9 @@ class SplitOp : public framework::OperatorWithKernel {

 #ifdef PADDLE_WITH_MKLDNN
     if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-      // OneDNN uses blocking format, which cannot be always
-      // supported with reorders, because if blocked dimension is not divisible
-      // by
-      // 8 or 16(depending on which blocking format is used) submemory cannot be
+      // OneDNN uses blocking format, which cannot be always supported with
+      // reorders, because if blocked dimension is not divisible by 8 or
+      // 16(depending on which blocking format is used) submemory cannot be
       // created, so in that scenario a fallback is needed
       auto tmp_md = dnnl::memory::desc(
           framework::vectorize(ctx.Input<Tensor>("X")->dims()),

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 2 additions & 2 deletions
@@ -1090,9 +1090,9 @@ class ReorderMKLDNNHandler : public MKLDNNHandler {
     return this->AcquireMemory(dims_, dtype_, fmt, ptr, "@user_src_mem_p");
   }

-  std::shared_ptr<mkldnn::memory> AcquireSrcSubmemory(
+  std::shared_ptr<mkldnn::memory> AcquireSubmemory(
       const std::vector<int64_t>& dims, const std::vector<int64_t>& offset,
-      const std::shared_ptr<mkldnn::memory>& mem_p, int submemory_number) {
+      const std::shared_ptr<mkldnn::memory>& mem_p, int submemory_number = 0) {
     std::string local_key = key_;
     local_key.append("@submem")
         .append(std::to_string(submemory_number))
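Defaulting submemory_number to 0 lets the new slice kernels, which need only a single submemory, share this helper with split, which still passes an explicit index per output. Beneath the handler's key-based caching, a submemory is plain oneDNN machinery: dnnl::memory::desc::submemory_desc() builds a view into an existing buffer at given dims/offsets, and a reorder against that view copies only the slice. A hedged sketch with illustrative names and values:

#include <numeric>
#include <vector>
#include "dnnl.hpp"  // assumes oneDNN (DNNL) is installed

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);

  // A dense 4x6 source buffer filled with 0..23.
  dnnl::memory::dims dims = {4, 6};
  std::vector<float> buf(4 * 6);
  std::iota(buf.begin(), buf.end(), 0.f);
  auto md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
                               dnnl::memory::format_tag::nc);
  dnnl::memory src(md, eng, buf.data());

  // A 2x3 view of that buffer starting at row 1, column 2.
  auto sub_md = md.submemory_desc({2, 3}, {1, 2});
  dnnl::memory sub_view(sub_md, eng, buf.data());

  // Reordering the view into a dense 2x3 destination performs the slice;
  // running the reorder the other way would perform the backward scatter.
  auto dst_md = dnnl::memory::desc({2, 3}, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::nc);
  dnnl::memory dst(dst_md, eng);

  dnnl::stream s(eng);
  dnnl::reorder(sub_view, dst).execute(s, sub_view, dst);
  s.wait();
}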
