Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ if(NOT DEFINED XPU_XRE_BASE_VERSION)
set(XPU_XRE_BASE_VERSION "4.32.0.1")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "eb35/20240927")
set(XPU_XHPC_BASE_DATE "eb35/20241015")
endif()
set(XPU_XCCL_BASE_VERSION "1.2.11e")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(group_norm_silu_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(layer_norm_relu_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(xpu_delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"delete_elementwise_mul_op_pass",
"generate_sequence_xpu_fuse_pass",
"group_norm_silu_xpu_fuse_pass",
"layer_norm_relu_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_adaptive_seqlen_fuse_pass",
Expand Down
214 changes: 214 additions & 0 deletions paddle/fluid/framework/ir/xpu/layer_norm_relu_xpu_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
} // namespace phi

namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

/*
fuse ln + activation block in to xpu_ele_fusion op
For example:
graph:
X
Scale | Bias
\ | /
layer norm
/ | \
/ | \
variance | mean
|
relu
|
output
------------------------------------------------------
After the pass is applied:
X
Scale | Bias
\ | /
ln_relu_fusion
|
Out
*/
struct LayerNormalizeReluXPUPattern : public PatternBase {
LayerNormalizeReluXPUPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(ln);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(ln_x);
PATTERN_DECL_NODE(ln_bias);
PATTERN_DECL_NODE(ln_scale);
PATTERN_DECL_NODE(ln_y);
PATTERN_DECL_NODE(ln_mean);
PATTERN_DECL_NODE(ln_variance);
PATTERN_DECL_NODE(relu_out);
};

LayerNormalizeReluXPUPattern::LayerNormalizeReluXPUPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto ln = pattern->NewNode(ln_repr())->assert_is_op("layer_norm");
auto ln_x = pattern->NewNode(ln_x_repr())
->assert_is_op_input("layer_norm", "X")
->AsInput();
auto ln_bias = pattern->NewNode(ln_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var()
->AsInput();
auto ln_scale = pattern->NewNode(ln_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var()
->AsInput();
auto ln_y = pattern->NewNode(ln_y_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("relu", "X")
->assert_has_n_outputs(1);
auto ln_mean = pattern->NewNode(ln_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_has_n_outputs(0);
auto ln_variance = pattern->NewNode(ln_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_has_n_outputs(0);
ln->LinksFrom({ln_x, ln_bias, ln_scale})
.LinksTo({ln_y, ln_mean, ln_variance});

auto relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
auto relu_out = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu", "Out");
relu->LinksFrom({ln_y}).LinksTo({relu_out});
}

} // namespace patterns

class LayerNormalizeReluXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
void FuseLayerNormalizeRelu(ir::Graph* graph) const;

const std::string name_scope_{"layer_norm_relu_xpu_fuse_pass"};
};

void LayerNormalizeReluXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, common::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
auto* dev_ctx = static_cast<phi::CPUContext*>(
phi::DeviceContextPool::Instance().Get(phi::XPUPlace()));
auto version =
phi::backends::xpu::get_xpu_version(dev_ctx->GetPlace().GetDeviceId());
if (version == phi::backends::xpu::XPUVersion::XPU2) {
FuseLayerNormalizeRelu(graph);
}
}

void LayerNormalizeReluXPUFusePass::FuseLayerNormalizeRelu(
ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::LayerNormalizeReluXPUPattern pattern(gpd.mutable_pattern(),
name_scope_);

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle LayerNormalizeReluXPUFusePass fuse";
// declare operator node's name
GET_IR_NODE(ln);
GET_IR_NODE(relu);
// declare variable node's name
GET_IR_NODE(ln_x);
GET_IR_NODE(ln_bias);
GET_IR_NODE(ln_scale);
GET_IR_NODE(ln_y);
GET_IR_NODE(ln_mean);
GET_IR_NODE(ln_variance);
GET_IR_NODE(relu_out);

auto* block = ln->Op()->Block();
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, common::errors::InvalidArgument("Scope cannot be nullptr."));
// delete useless node
std::unordered_set<const Node*> delete_nodes;

float eps = PADDLE_GET_CONST(float, ln->Op()->GetAttr("epsilon"));
int begin_norm_axis =
PADDLE_GET_CONST(int, ln->Op()->GetAttr("begin_norm_axis"));

std::string fused_op_out_name;
fused_op_out_name = relu_out->Name();
// Generate add_layernorm fused op
framework::OpDesc fused_op_desc(block);

fused_op_desc.SetType("layer_norm_relu_xpu");
// set attrs for fused op
fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
fused_op_desc.SetInput("x", {ln_x->Name()});
fused_op_desc.SetInput("bias", {ln_bias->Name()});
fused_op_desc.SetInput("scale", {ln_scale->Name()});
fused_op_desc.SetAttr("epsilon", eps);
fused_op_desc.SetOutput("out", {fused_op_out_name});
// relink fused op
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(ln_x, fused_op);
IR_NODE_LINK_TO(ln_bias, fused_op);
IR_NODE_LINK_TO(ln_scale, fused_op);
IR_NODE_LINK_TO(fused_op, relu_out);

delete_nodes.insert({ln, relu, ln_y, ln_mean, ln_variance});
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};

gpd(graph, handler);
AddStatis(found_subgraph_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(layer_norm_relu_xpu_fuse_pass,
paddle::framework::ir::LayerNormalizeReluXPUFusePass);

REGISTER_PASS_CAPABILITY(layer_norm_relu_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"layer_norm_relu_xpu", 0));
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"delete_elementwise_mul_op_pass",
"generate_sequence_xpu_fuse_pass",
"group_norm_silu_xpu_fuse_pass",
"layer_norm_relu_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"qk_qkv_attention_xpu_fuse_pass",
"block_multihead_attention_xpu_pass",
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ XPUOpMap& get_kl2_ops() {
{"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
{"group_norm_silu_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"layer_norm_relu_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void LayerNormalizeReluXPUInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
int begin_norm_axis,
float epsilon,
MetaTensor* out) {
out->set_dims(x.dims());
// y->share_lod(x);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}

void FusedMultiTransformerInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& ln_scales,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ void GroupNormalizeSiluXPUInferMeta(const MetaTensor& x,
float epsilon,
MetaTensor* out);

void LayerNormalizeReluXPUInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
int begin_norm_axis,
float epsilon,
MetaTensor* out);

void BlhaGetMaxLenInferMeta(const MetaTensor& seq_lens_encoder,
const MetaTensor& seq_lens_decoder,
const MetaTensor& batch_size,
Expand Down
99 changes: 99 additions & 0 deletions paddle/phi/kernels/fusion/xpu/layer_norm_relu_xpu_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"

namespace phi {
namespace fusion {

template <typename T, typename Context>
void LayerNormalizeReluXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
int begin_norm_axis,
float epsilon,
DenseTensor* y) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x.dims();
auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
const auto* x_data = x.data<T>();

xpu::ctx_guard RAII_GUARD(ctx.x_context());

// scale
const float* scale_data_fp32 = nullptr;
const auto* scale_ptr = scale.get_ptr();
if (scale_ptr == nullptr) {
// no scale, do nothing
} else if (scale_ptr->dtype() == phi::DataType::FLOAT16) {
float* scale_data_temp =
RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
int r = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
scale_data_temp,
scale_ptr->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
scale_data_fp32 = scale_data_temp;
} else {
// no need to cast
scale_data_fp32 = scale_ptr->data<float>();
}

// bias
const float* bias_data_fp32 = nullptr;
const auto* bias_ptr = bias.get_ptr();
if (bias_ptr == nullptr) {
// no bias, do nothing
} else if (bias_ptr->dtype() == phi::DataType::FLOAT16) {
float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
int r = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
bias_data_temp,
bias_ptr->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
bias_data_fp32 = bias_data_temp;
} else {
// no need to cast
bias_data_fp32 = bias_ptr->data<float>();
}

auto* out_data = ctx.template Alloc<T>(y);

int r = xpu::layer_norm_relu_fusion(ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(out_data),
left,
right,
epsilon,
scale_data_fp32,
bias_data_fp32);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_relu_fusion");
}

} // namespace fusion
} // namespace phi

PD_REGISTER_KERNEL(layer_norm_relu_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::LayerNormalizeReluXPUKernel,
float,
phi::dtype::float16) {}
1 change: 0 additions & 1 deletion paddle/phi/kernels/xpu/c_embedding_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ void CEmbeddingKernel(const Context& dev_ctx,
// xm: table height: number of entries of table.
// n: embedding dim: number of float value within single entry.
// ym: number of elements of input ids.

const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
int r = xpu::embedding(dev_ctx.x_context(),
Expand Down
Loading