/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

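// SliceMKLDNNKernel computes the slice op on CPU via oneDNN: the output is
// produced by reordering a sub-memory view of the input, so no explicit
// element-wise copy loop is needed.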
template <typename T>
class SliceMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx);
  }

  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* x = ctx.Input<Tensor>("Input");
    auto* out = ctx.Output<Tensor>("Out");

    auto x_vec_dims = framework::vectorize(x->dims());
    auto out_vec_dims = framework::vectorize(out->dims());

    auto axes_int = ctx.Attr<std::vector<int>>("axes");
    auto starts_int = ctx.Attr<std::vector<int>>("starts");
    auto ends_int = ctx.Attr<std::vector<int>>("ends");

    std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
    std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
    std::vector<int64_t> ends(ends_int.begin(), ends_int.end());

    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");

    std::vector<int64_t> offsets(x_vec_dims.size(), 0);
    std::vector<int64_t> slice_dims(x_vec_dims);

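    // Resolve negative starts/ends against the input shape, clamp ends to
    // the axis size, then record the slice offset and length for each axis.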
    for (size_t i = 0; i < axes.size(); ++i) {
      starts[i] = starts[i] < 0 ? x_vec_dims[axes[i]] + starts[i] : starts[i];
      ends[i] = ends[i] < 0 ? x_vec_dims[axes[i]] + ends[i]
                            : std::min(ends[i], x_vec_dims[axes[i]]);
      offsets[axes[i]] = starts[i];
      slice_dims[axes[i]] = ends[i] - starts[i];
    }

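    // The slice itself is a oneDNN reorder: a sub-memory view of the input
    // (slice_dims at the given offsets) is copied into the output. The key
    // identifies the cached reorder primitives for this shape/configuration.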
    mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
    auto key = platform::CreateKey(dev_ctx, x_vec_dims, axes, starts, ends,
                                   x->format(), x_type);

    platform::ReorderMKLDNNHandler reorder_handler(
        x_vec_dims, x->type(), x_type, dev_ctx, onednn_engine, key);

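    // Source memory wraps the full input buffer, the sub-memory is a view of
    // the sliced region, and the destination memory owns the output buffer.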
    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        x->format(), platform::to_void_cast(x->data<T>()));
    auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
                                                        reorder_src_memory_p);
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        out, slice_dims, 0, x->format(), ctx.GetPlace());

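    // Execute the reorder on the oneDNN stream, copying the sliced region
    // into the output tensor.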
    auto reorder_p =
        reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);
    astream.wait();

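    // Reshape the destination descriptor to the final output dims so the
    // reported format accounts for any dimensions removed by decrease_axis.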
    out->set_layout(framework::DataLayout::kMKLDNN);
    out->set_format(platform::GetMKLDNNFormat(
        reorder_dst_memory_p->get_desc().reshape(out_vec_dims)));
  }
};

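// The gradient of slice scatters dOut back into a zero-initialized dInput:
// dOut is reordered into a sub-memory view of dInput covering the sliced
// region, while everything outside that region keeps its zero gradient.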
template <typename T>
class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    this->RunKernel(ctx);
  }

  void RunKernel(const framework::ExecutionContext& ctx) const {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& onednn_engine = dev_ctx.GetEngine();

    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("Input"));

    auto dx_vec_dims = framework::vectorize(dx->dims());
    auto dout_vec_dims = framework::vectorize(dout->dims());

    auto axes_int = ctx.Attr<std::vector<int>>("axes");
    auto starts_int = ctx.Attr<std::vector<int>>("starts");
    auto ends_int = ctx.Attr<std::vector<int>>("ends");

    std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
    std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
    std::vector<int64_t> ends(ends_int.begin(), ends_int.end());

    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");

    std::vector<int64_t> offsets(dx_vec_dims.size(), 0);
    std::vector<int64_t> slice_dims(dx_vec_dims);

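    // As in the forward kernel, resolve negative starts/ends against the
    // dInput shape and record the offset and length of the sliced region.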
    for (size_t i = 0; i < axes.size(); ++i) {
      starts[i] = starts[i] < 0 ? dx_vec_dims[axes[i]] + starts[i] : starts[i];
      ends[i] = ends[i] < 0 ? dx_vec_dims[axes[i]] + ends[i]
                            : std::min(ends[i], dx_vec_dims[axes[i]]);
      offsets[axes[i]] = starts[i];
      slice_dims[axes[i]] = ends[i] - starts[i];
    }

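    // dOut may have fewer dims than the sliced region when decrease_axis was
    // used, so reshape its descriptor to slice_dims to obtain the format tag
    // shared by both ends of the reorder.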
    mkldnn::memory::data_type dout_type =
        framework::ToMKLDNNDataType(dout->type());
    mkldnn::memory::desc md(dout_vec_dims, platform::MKLDNNGetDataType<T>(),
                            dout->format());
    mkldnn::memory::format_tag reorder_format_tag =
        platform::GetMKLDNNFormat(md.reshape(slice_dims));

    auto key = platform::CreateKey(dev_ctx, dout_vec_dims, axes, starts, ends,
                                   reorder_format_tag, dout_type);

    platform::ReorderMKLDNNHandler reorder_handler(
        slice_dims, dout->type(), dout_type, dev_ctx, onednn_engine, key);

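    // Source memory wraps dOut; destination memory owns the dInput buffer,
    // which is zero-filled first since the gradient outside the sliced
    // region is zero. The sub-memory is the view of dInput to scatter into.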
    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
        reorder_format_tag, platform::to_void_cast(dout->data<T>()));
    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
        dx, dx_vec_dims, 0, reorder_format_tag, ctx.GetPlace());
    memset(dx->data<T>(), 0, reorder_dst_memory_p->get_desc().get_size());

    auto slice_mem_p = reorder_handler.AcquireSubmemory(slice_dims, offsets,
                                                        reorder_dst_memory_p);

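    // Execute the reorder, copying dOut into the sliced region of dInput.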
    auto reorder_p =
        reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p);
    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p);
    astream.wait();

    dx->set_layout(framework::DataLayout::kMKLDNN);
    dx->set_format(reorder_format_tag);
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace,
                   ops::SliceMKLDNNKernel<float>,
                   ops::SliceMKLDNNKernel<paddle::platform::bfloat16>);

REGISTER_OP_KERNEL(slice_grad, MKLDNN, paddle::platform::CPUPlace,
                   ops::SliceGradMKLDNNKernel<float>,
                   ops::SliceGradMKLDNNKernel<paddle::platform::bfloat16>);