Commit dff6f2d

[xpu]: add layernorm_relu pass and kernel ;test=develop

1 parent fec5afd
File tree: 12 files changed (+443, -2 lines)

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ if(NOT DEFINED XPU_XRE_BASE_VERSION)
   set(XPU_XRE_BASE_VERSION "4.32.0.1")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
-  set(XPU_XHPC_BASE_DATE "eb35/20240923")
+  set(XPU_XHPC_BASE_DATE "eb35/20241015")
 endif()
 set(XPU_XCCL_BASE_VERSION "1.2.11d")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -306,6 +306,8 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(group_norm_silu_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(layer_norm_relu_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(xpu_delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})

paddle/fluid/framework/ir/pass.cc

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
     "delete_elementwise_mul_op_pass",
     "generate_sequence_xpu_fuse_pass",
     "group_norm_silu_xpu_fuse_pass",
+    "layer_norm_relu_xpu_fuse_pass",
     "embedding_with_eltwise_add_xpu_fuse_pass",
     "multi_encoder_xpu_fuse_pass",
     "multi_encoder_xpu_adaptive_seqlen_fuse_pass",

(new file)

Lines changed: 214 additions & 0 deletions

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

/*
Fuse a layer_norm + relu block into a single layer_norm_relu_xpu op.
For example:
graph:
                  X
        Scale     |     Bias
             \    |    /
             layer norm
             /    |    \
            /     |     \
     variance     |     mean
                  |
                relu
                  |
               output
------------------------------------------------------
After the pass is applied:
                  X
        Scale     |     Bias
             \    |    /
           ln_relu_fusion
                  |
                 Out
*/
struct LayerNormalizeReluXPUPattern : public PatternBase {
  LayerNormalizeReluXPUPattern(PDPattern* pattern,
                               const std::string& name_scope);
  // declare operator node's name
  PATTERN_DECL_NODE(ln);
  PATTERN_DECL_NODE(relu);
  // declare variable node's name
  PATTERN_DECL_NODE(ln_x);
  PATTERN_DECL_NODE(ln_bias);
  PATTERN_DECL_NODE(ln_scale);
  PATTERN_DECL_NODE(ln_y);
  PATTERN_DECL_NODE(ln_mean);
  PATTERN_DECL_NODE(ln_variance);
  PATTERN_DECL_NODE(relu_out);
};

LayerNormalizeReluXPUPattern::LayerNormalizeReluXPUPattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto ln = pattern->NewNode(ln_repr())->assert_is_op("layer_norm");
  auto ln_x = pattern->NewNode(ln_x_repr())
                  ->assert_is_op_input("layer_norm", "X")
                  ->AsInput();
  auto ln_bias = pattern->NewNode(ln_bias_repr())
                     ->assert_is_op_input("layer_norm", "Bias")
                     ->assert_is_persistable_var()
                     ->AsInput();
  auto ln_scale = pattern->NewNode(ln_scale_repr())
                      ->assert_is_op_input("layer_norm", "Scale")
                      ->assert_is_persistable_var()
                      ->AsInput();
  auto ln_y = pattern->NewNode(ln_y_repr())
                  ->assert_is_op_output("layer_norm", "Y")
                  ->assert_is_op_input("relu", "X")
                  ->assert_has_n_outputs(1);
  auto ln_mean = pattern->NewNode(ln_mean_repr())
                     ->assert_is_op_output("layer_norm", "Mean")
                     ->assert_has_n_outputs(0);
  auto ln_variance = pattern->NewNode(ln_variance_repr())
                         ->assert_is_op_output("layer_norm", "Variance")
                         ->assert_has_n_outputs(0);
  ln->LinksFrom({ln_x, ln_bias, ln_scale})
      .LinksTo({ln_y, ln_mean, ln_variance});

  auto relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
  auto relu_out = pattern->NewNode(relu_out_repr())
                      ->AsOutput()
                      ->assert_is_op_output("relu", "Out");
  relu->LinksFrom({ln_y}).LinksTo({relu_out});
}

}  // namespace patterns

class LayerNormalizeReluXPUFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  void FuseLayerNormalizeRelu(ir::Graph* graph) const;

  const std::string name_scope_{"layer_norm_relu_xpu_fuse_pass"};
};

void LayerNormalizeReluXPUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, common::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      phi::DeviceContextPool::Instance().Get(phi::XPUPlace()));
  auto version =
      phi::backends::xpu::get_xpu_version(dev_ctx->GetPlace().GetDeviceId());
  if (version == phi::backends::xpu::XPUVersion::XPU2) {
    FuseLayerNormalizeRelu(graph);
  }
}

void LayerNormalizeReluXPUFusePass::FuseLayerNormalizeRelu(
    ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::LayerNormalizeReluXPUPattern pattern(gpd.mutable_pattern(),
                                                 name_scope_);

  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle LayerNormalizeReluXPUFusePass fuse";
    // operator nodes
    GET_IR_NODE(ln);
    GET_IR_NODE(relu);
    // variable nodes
    GET_IR_NODE(ln_x);
    GET_IR_NODE(ln_bias);
    GET_IR_NODE(ln_scale);
    GET_IR_NODE(ln_y);
    GET_IR_NODE(ln_mean);
    GET_IR_NODE(ln_variance);
    GET_IR_NODE(relu_out);

    auto* block = ln->Op()->Block();
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, common::errors::InvalidArgument("Scope cannot be nullptr."));
    // nodes that become useless after fusion
    std::unordered_set<const Node*> delete_nodes;

    float eps = PADDLE_GET_CONST(float, ln->Op()->GetAttr("epsilon"));
    int begin_norm_axis =
        PADDLE_GET_CONST(int, ln->Op()->GetAttr("begin_norm_axis"));

    std::string fused_op_out_name;
    fused_op_out_name = relu_out->Name();
    // generate the layer_norm_relu_xpu fused op
    framework::OpDesc fused_op_desc(block);

    fused_op_desc.SetType("layer_norm_relu_xpu");
    // set attrs for fused op
    fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
    fused_op_desc.SetInput("x", {ln_x->Name()});
    fused_op_desc.SetInput("bias", {ln_bias->Name()});
    fused_op_desc.SetInput("scale", {ln_scale->Name()});
    fused_op_desc.SetAttr("epsilon", eps);
    fused_op_desc.SetOutput("out", {fused_op_out_name});
    // relink fused op
    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
    IR_NODE_LINK_TO(ln_x, fused_op);
    IR_NODE_LINK_TO(ln_bias, fused_op);
    IR_NODE_LINK_TO(ln_scale, fused_op);
    IR_NODE_LINK_TO(fused_op, relu_out);

    delete_nodes.insert({ln, relu, ln_y, ln_mean, ln_variance});
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(layer_norm_relu_xpu_fuse_pass,
              paddle::framework::ir::LayerNormalizeReluXPUFusePass);

REGISTER_PASS_CAPABILITY(layer_norm_relu_xpu_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "layer_norm_relu_xpu", 0));

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
@@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "delete_elementwise_mul_op_pass",
       "generate_sequence_xpu_fuse_pass",
       "group_norm_silu_xpu_fuse_pass",
+      "layer_norm_relu_xpu_fuse_pass",
       "embedding_with_eltwise_add_xpu_fuse_pass",
       "qk_qkv_attention_xpu_fuse_pass",
       "block_multihead_attention_xpu_pass",

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 2 additions & 0 deletions
@@ -566,6 +566,8 @@ XPUOpMap& get_kl2_ops() {
      {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
      {"group_norm_silu_xpu",
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+     {"layer_norm_relu_xpu",
+      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
      {"hard_sigmoid",
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},

paddle/phi/infermeta/fusion.cc

Lines changed: 12 additions & 0 deletions
@@ -131,6 +131,18 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x,
   out->share_lod(x);
 }
 
+void LayerNormalizeReluXPUInferMeta(const MetaTensor& x,
+                                    const MetaTensor& scale,
+                                    const MetaTensor& bias,
+                                    int begin_norm_axis,
+                                    float epsilon,
+                                    MetaTensor* out) {
+  out->set_dims(x.dims());
+  // y->share_lod(x);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}
+
 void FusedMultiTransformerInferMeta(
     const MetaTensor& x,
     const std::vector<const MetaTensor*>& ln_scales,

paddle/phi/infermeta/fusion.h

Lines changed: 7 additions & 0 deletions
@@ -82,6 +82,13 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x,
                                     float epsilon,
                                     MetaTensor* out);
 
+void LayerNormalizeReluXPUInferMeta(const MetaTensor& x,
+                                    const MetaTensor& scale,
+                                    const MetaTensor& bias,
+                                    int begin_norm_axis,
+                                    float epsilon,
+                                    MetaTensor* out);
+
 void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder,
                             const MetaTensor& seq_lens_decoder,
                             const MetaTensor& batch_size,

(new file)

Lines changed: 99 additions & 0 deletions

// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"

namespace phi {
namespace fusion {

template <typename T, typename Context>
void LayerNormalizeReluXPUKernel(const Context& ctx,
                                 const DenseTensor& x,
                                 const paddle::optional<DenseTensor>& scale,
                                 const paddle::optional<DenseTensor>& bias,
                                 int begin_norm_axis,
                                 float epsilon,
                                 DenseTensor* y) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  const auto& x_dims = x.dims();
  auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
  int left = static_cast<int>(matrix_dim[0]);
  int right = static_cast<int>(matrix_dim[1]);
  const auto* x_data = x.data<T>();

  xpu::ctx_guard RAII_GUARD(ctx.x_context());

  // scale
  const float* scale_data_fp32 = nullptr;
  const auto* scale_ptr = scale.get_ptr();
  if (scale_ptr == nullptr) {
    // no scale, do nothing
  } else if (scale_ptr->dtype() == phi::DataType::FLOAT16) {
    float* scale_data_temp =
        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
    int r = xpu::cast<XPUType, float>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
        scale_data_temp,
        scale_ptr->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    scale_data_fp32 = scale_data_temp;
  } else {
    // no need to cast
    scale_data_fp32 = scale_ptr->data<float>();
  }

  // bias
  const float* bias_data_fp32 = nullptr;
  const auto* bias_ptr = bias.get_ptr();
  if (bias_ptr == nullptr) {
    // no bias, do nothing
  } else if (bias_ptr->dtype() == phi::DataType::FLOAT16) {
    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
    int r = xpu::cast<XPUType, float>(
        ctx.x_context(),
        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
        bias_data_temp,
        bias_ptr->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    bias_data_fp32 = bias_data_temp;
  } else {
    // no need to cast
    bias_data_fp32 = bias_ptr->data<float>();
  }

  auto* out_data = ctx.template Alloc<T>(y);

  int r = xpu::layer_norm_relu_fusion(ctx.x_context(),
                                      reinterpret_cast<const XPUType*>(x_data),
                                      reinterpret_cast<XPUType*>(out_data),
                                      left,
                                      right,
                                      epsilon,
                                      scale_data_fp32,
                                      bias_data_fp32);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_relu_fusion");
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(layer_norm_relu_xpu,
                   XPU,
                   ALL_LAYOUT,
                   phi::fusion::LayerNormalizeReluXPUKernel,
                   float,
                   phi::dtype::float16) {}
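
For reference, a minimal NumPy sketch of the semantics this fused kernel is assumed to implement (an illustration only, not part of this commit): flatten x to a (left, right) matrix at begin_norm_axis exactly as the kernel does, layer-normalize each row with epsilon, apply scale and bias, then relu.

import numpy as np

def layer_norm_relu_ref(x, scale, bias, begin_norm_axis, eps=1e-5):
    # Flatten to 2-D the same way flatten_to_2d is used in the kernel above.
    left = int(np.prod(x.shape[:begin_norm_axis]))
    right = int(np.prod(x.shape[begin_norm_axis:]))
    x2d = x.reshape(left, right).astype(np.float32)
    # Per-row layer normalization.
    mean = x2d.mean(axis=1, keepdims=True)
    var = x2d.var(axis=1, keepdims=True)
    y = (x2d - mean) / np.sqrt(var + eps)
    # Affine transform followed by relu.
    y = y * scale.reshape(1, right) + bias.reshape(1, right)
    return np.maximum(y, 0.0).reshape(x.shape)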

paddle/phi/kernels/xpu/c_embedding_kernel.cc

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ void CEmbeddingKernel(const Context& dev_ctx,
   // xm: table height: number of entries of table.
   // n: embedding dim: number of float value within single entry.
   // ym: number of elements of input ids.
-
   const auto& index_type = ids.dtype();
   if (index_type == phi::DataType::INT32) {
     int r = xpu::embedding(dev_ctx.x_context(),
