Commit 8d1354a

[XPU] Add layernorm_relu pass and kernel (PaddlePaddle#68451)

1 parent 8ce0de5 commit 8d1354a

File tree

11 files changed, +442 -1 lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -297,6 +297,8 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(add_layernorm_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(layer_norm_relu_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(xpu_delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})

paddle/fluid/framework/ir/pass.cc

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
     "constant_folding_pass",
     "delete_elementwise_mul_op_pass",
     "generate_sequence_xpu_fuse_pass",
+    "layer_norm_relu_xpu_fuse_pass",
     "embedding_with_eltwise_add_xpu_fuse_pass",
    "multi_encoder_xpu_fuse_pass",
    "multi_encoder_xpu_adaptive_seqlen_fuse_pass",
paddle/fluid/framework/ir/xpu/layer_norm_relu_xpu_fuse_pass.cc

Lines changed: 214 additions & 0 deletions

@@ -0,0 +1,214 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+/*
+Fuse a layer_norm + relu block into a single layer_norm_relu_xpu op.
+For example:
+graph:
+                    X
+      Scale         |         Bias
+           \        |        /
+             layer_norm
+           /        |        \
+          /         |         \
+   variance         |          mean
+                    |
+                  relu
+                    |
+                 output
+------------------------------------------------------
+After the pass is applied:
+                    X
+      Scale         |         Bias
+           \        |        /
+        layer_norm_relu_xpu
+                    |
+                   Out
+*/
+struct LayerNormalizeReluXPUPattern : public PatternBase {
+  LayerNormalizeReluXPUPattern(PDPattern* pattern,
+                               const std::string& name_scope);
+  // declare operator node's name
+  PATTERN_DECL_NODE(ln);
+  PATTERN_DECL_NODE(relu);
+  // declare variable node's name
+  PATTERN_DECL_NODE(ln_x);
+  PATTERN_DECL_NODE(ln_bias);
+  PATTERN_DECL_NODE(ln_scale);
+  PATTERN_DECL_NODE(ln_y);
+  PATTERN_DECL_NODE(ln_mean);
+  PATTERN_DECL_NODE(ln_variance);
+  PATTERN_DECL_NODE(relu_out);
+};
+
+LayerNormalizeReluXPUPattern::LayerNormalizeReluXPUPattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto ln = pattern->NewNode(ln_repr())->assert_is_op("layer_norm");
+  auto ln_x = pattern->NewNode(ln_x_repr())
+                  ->assert_is_op_input("layer_norm", "X")
+                  ->AsInput();
+  auto ln_bias = pattern->NewNode(ln_bias_repr())
+                     ->assert_is_op_input("layer_norm", "Bias")
+                     ->assert_is_persistable_var()
+                     ->AsInput();
+  auto ln_scale = pattern->NewNode(ln_scale_repr())
+                      ->assert_is_op_input("layer_norm", "Scale")
+                      ->assert_is_persistable_var()
+                      ->AsInput();
+  auto ln_y = pattern->NewNode(ln_y_repr())
+                  ->assert_is_op_output("layer_norm", "Y")
+                  ->assert_is_op_input("relu", "X")
+                  ->assert_has_n_outputs(1);
+  auto ln_mean = pattern->NewNode(ln_mean_repr())
+                     ->assert_is_op_output("layer_norm", "Mean")
+                     ->assert_has_n_outputs(0);
+  auto ln_variance = pattern->NewNode(ln_variance_repr())
+                         ->assert_is_op_output("layer_norm", "Variance")
+                         ->assert_has_n_outputs(0);
+  ln->LinksFrom({ln_x, ln_bias, ln_scale})
+      .LinksTo({ln_y, ln_mean, ln_variance});
+
+  auto relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
+  auto relu_out = pattern->NewNode(relu_out_repr())
+                      ->AsOutput()
+                      ->assert_is_op_output("relu", "Out");
+  relu->LinksFrom({ln_y}).LinksTo({relu_out});
+}
+
+}  // namespace patterns
+
+class LayerNormalizeReluXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void FuseLayerNormalizeRelu(ir::Graph* graph) const;
+
+  const std::string name_scope_{"layer_norm_relu_xpu_fuse_pass"};
+};
+
+void LayerNormalizeReluXPUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, common::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+  auto* dev_ctx = static_cast<phi::CPUContext*>(
+      phi::DeviceContextPool::Instance().Get(phi::XPUPlace()));
+  auto version =
+      phi::backends::xpu::get_xpu_version(dev_ctx->GetPlace().GetDeviceId());
+  if (version == phi::backends::xpu::XPUVersion::XPU2) {
+    FuseLayerNormalizeRelu(graph);
+  }
+}
+
+void LayerNormalizeReluXPUFusePass::FuseLayerNormalizeRelu(
+    ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::LayerNormalizeReluXPUPattern pattern(gpd.mutable_pattern(),
+                                                 name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle LayerNormalizeReluXPUFusePass fuse";
+    // operator nodes
+    GET_IR_NODE(ln);
+    GET_IR_NODE(relu);
+    // variable nodes
+    GET_IR_NODE(ln_x);
+    GET_IR_NODE(ln_bias);
+    GET_IR_NODE(ln_scale);
+    GET_IR_NODE(ln_y);
+    GET_IR_NODE(ln_mean);
+    GET_IR_NODE(ln_variance);
+    GET_IR_NODE(relu_out);
+
+    auto* block = ln->Op()->Block();
+    auto* scope = param_scope();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, common::errors::InvalidArgument("Scope cannot be nullptr."));
+    // nodes that become useless after fusion and should be removed
+    std::unordered_set<const Node*> delete_nodes;
+
+    float eps = PADDLE_GET_CONST(float, ln->Op()->GetAttr("epsilon"));
+    int begin_norm_axis =
+        PADDLE_GET_CONST(int, ln->Op()->GetAttr("begin_norm_axis"));
+
+    std::string fused_op_out_name;
+    fused_op_out_name = relu_out->Name();
+    // Generate the layer_norm_relu_xpu fused op
+    framework::OpDesc fused_op_desc(block);
+
+    fused_op_desc.SetType("layer_norm_relu_xpu");
+    // set attrs for fused op
+    fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
+    fused_op_desc.SetInput("x", {ln_x->Name()});
+    fused_op_desc.SetInput("bias", {ln_bias->Name()});
+    fused_op_desc.SetInput("scale", {ln_scale->Name()});
+    fused_op_desc.SetAttr("epsilon", eps);
+    fused_op_desc.SetOutput("out", {fused_op_out_name});
+    // relink fused op
+    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
+    IR_NODE_LINK_TO(ln_x, fused_op);
+    IR_NODE_LINK_TO(ln_bias, fused_op);
+    IR_NODE_LINK_TO(ln_scale, fused_op);
+    IR_NODE_LINK_TO(fused_op, relu_out);
+
+    delete_nodes.insert({ln, relu, ln_y, ln_mean, ln_variance});
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(layer_norm_relu_xpu_fuse_pass,
+              paddle::framework::ir::LayerNormalizeReluXPUFusePass);
+
+REGISTER_PASS_CAPABILITY(layer_norm_relu_xpu_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "layer_norm_relu_xpu", 0));

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
@@ -535,6 +535,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "cast_embedding_trans_ids_to_int32_pass",
       "delete_elementwise_mul_op_pass",
       "generate_sequence_xpu_fuse_pass",
+      "layer_norm_relu_xpu_fuse_pass",
       "embedding_with_eltwise_add_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_adaptive_seqlen_fuse_pass",

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 10 additions & 0 deletions
@@ -371,6 +371,16 @@
     func : layer_norm_act_xpu
     data_type : x
 
+- op : layer_norm_relu_xpu
+  args : (Tensor x, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon = 1e-5)
+  output : Tensor(out)
+  infer_meta :
+    func : LayerNormalizeReluXPUInferMeta
+  kernel :
+    func : layer_norm_relu_xpu
+    data_type : x
+  optional : scale, bias
+
 - op : multi_encoder_xpu
   args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
   output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
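
From this yaml entry, Paddle's op/API code generator is expected to emit a C++ signature along these lines; the exact header and namespace are generator details, shown here only as a sketch:

// Sketch of the generated C++ API (namespace and header are assumptions).
// scale and bias are optional per the `optional : scale, bias` line above.
paddle::Tensor layer_norm_relu_xpu(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& scale,
    const paddle::optional<paddle::Tensor>& bias,
    int begin_norm_axis,
    float epsilon = 1e-5f);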

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 2 additions & 0 deletions
@@ -501,6 +501,8 @@ XPUOpMap& get_kl2_ops() {
                           phi::DataType::FLOAT32})},
       {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
+      {"layer_norm_relu_xpu",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
       {"hard_sigmoid",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},

paddle/phi/infermeta/fusion.cc

Lines changed: 12 additions & 0 deletions
@@ -116,6 +116,18 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void LayerNormalizeReluXPUInferMeta(const MetaTensor& x,
+                                    const MetaTensor& scale,
+                                    const MetaTensor& bias,
+                                    int begin_norm_axis,
+                                    float epsilon,
+                                    MetaTensor* out) {
+  out->set_dims(x.dims());
+  // out->share_lod(x);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}
+
 void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& key_cache,
                                       const MetaTensor& value_cache,
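
The infer-meta contract is shape-preserving; a worked example of what it implies:

// Worked example (illustrative values):
//   x.dims() = [2, 128, 768], begin_norm_axis = 2
//   => out.dims() = [2, 128, 768]; dtype and layout mirror x.
// begin_norm_axis never changes the output shape: it only tells the kernel
// where to flatten x into a [2 * 128, 768] matrix whose rows are normalized.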

paddle/phi/infermeta/fusion.h

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,13 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
                               float epsilon,
                               MetaTensor* out);
 
+void LayerNormalizeReluXPUInferMeta(const MetaTensor& x,
+                                    const MetaTensor& scale,
+                                    const MetaTensor& bias,
+                                    int begin_norm_axis,
+                                    float epsilon,
+                                    MetaTensor* out);
+
 void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& key_cache,
                                       const MetaTensor& value_cache,
paddle/phi/kernels/fusion/xpu/layer_norm_relu_xpu_kernel.cc

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void LayerNormalizeReluXPUKernel(const Context& ctx,
+                                 const DenseTensor& x,
+                                 const paddle::optional<DenseTensor>& scale,
+                                 const paddle::optional<DenseTensor>& bias,
+                                 int begin_norm_axis,
+                                 float epsilon,
+                                 DenseTensor* y) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  const auto& x_dims = x.dims();
+  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
+  int left = static_cast<int>(matrix_dim[0]);
+  int right = static_cast<int>(matrix_dim[1]);
+  const auto* x_data = x.data<T>();
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+
+  // scale: the fused XDNN call expects fp32, so cast an fp16 scale up
+  const float* scale_data_fp32 = nullptr;
+  const auto* scale_ptr = scale.get_ptr();
+  if (scale_ptr == nullptr) {
+    // no scale, do nothing
+  } else if (scale_ptr->dtype() == phi::DataType::FLOAT16) {
+    float* scale_data_temp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+        scale_data_temp,
+        scale_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    scale_data_fp32 = scale_data_temp;
+  } else {
+    // already fp32, no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+  }
+
+  // bias: same dtype handling as scale
+  const float* bias_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  if (bias_ptr == nullptr) {
+    // no bias, do nothing
+  } else if (bias_ptr->dtype() == phi::DataType::FLOAT16) {
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+        bias_data_temp,
+        bias_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    bias_data_fp32 = bias_data_temp;
+  } else {
+    // already fp32, no need to cast
+    bias_data_fp32 = bias_ptr->data<float>();
+  }
+
+  auto* out_data = ctx.template Alloc<T>(y);
+
+  int r = xpu::layer_norm_relu_fusion(ctx.x_context(),
+                                      reinterpret_cast<const XPUType*>(x_data),
+                                      reinterpret_cast<XPUType*>(out_data),
+                                      left,
+                                      right,
+                                      epsilon,
+                                      scale_data_fp32,
+                                      bias_data_fp32);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_relu_fusion");
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(layer_norm_relu_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::LayerNormalizeReluXPUKernel,
+                   float,
+                   phi::dtype::float16) {}
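
For readers without the XDNN sources, the fused call computes, per row of the flattened [left, right] matrix, layer normalization followed by relu. A plain-C++ reference of the intended semantics (a sketch, not the XDNN implementation):

#include <cmath>

// Reference semantics of xpu::layer_norm_relu_fusion above (fp32 only):
// normalize each of the `left` rows of length `right`, apply the optional
// per-column scale/bias, then clamp negatives to zero (the fused relu).
void layer_norm_relu_reference(const float* x, float* out,
                               int left, int right, float epsilon,
                               const float* scale, const float* bias) {
  for (int i = 0; i < left; ++i) {
    const float* row = x + i * right;
    float mean = 0.f, var = 0.f;
    for (int j = 0; j < right; ++j) mean += row[j];
    mean /= right;
    for (int j = 0; j < right; ++j) var += (row[j] - mean) * (row[j] - mean);
    var /= right;
    const float inv_std = 1.f / std::sqrt(var + epsilon);
    for (int j = 0; j < right; ++j) {
      float v = (row[j] - mean) * inv_std;
      if (scale != nullptr) v *= scale[j];
      if (bias != nullptr) v += bias[j];
      out[i * right + j] = v > 0.f ? v : 0.f;  // fused relu
    }
  }
}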
