// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <unordered_set>

#include "glog/logging.h"

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

/*
Fuse a layer_norm + relu block into a single layer_norm_relu_xpu op.
For example:
graph:
                    X
          Scale     |     Bias
               \    |    /
               layer_norm
              /     |     \
             /      |      \
      variance      |      mean
                    |
                  relu
                    |
                 output
------------------------------------------------------
After the pass is applied:
                    X
          Scale     |     Bias
               \    |    /
          layer_norm_relu_xpu
                    |
                   Out
*/
struct LayerNormalizeReluXPUPattern : public PatternBase {
  LayerNormalizeReluXPUPattern(PDPattern* pattern,
                               const std::string& name_scope);
  // declare operator node's name
  PATTERN_DECL_NODE(ln);
  PATTERN_DECL_NODE(relu);
  // declare variable node's name
  PATTERN_DECL_NODE(ln_x);
  PATTERN_DECL_NODE(ln_bias);
  PATTERN_DECL_NODE(ln_scale);
  PATTERN_DECL_NODE(ln_y);
  PATTERN_DECL_NODE(ln_mean);
  PATTERN_DECL_NODE(ln_variance);
  PATTERN_DECL_NODE(relu_out);
};

LayerNormalizeReluXPUPattern::LayerNormalizeReluXPUPattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto ln = pattern->NewNode(ln_repr())->assert_is_op("layer_norm");
  auto ln_x = pattern->NewNode(ln_x_repr())
                  ->assert_is_op_input("layer_norm", "X")
                  ->AsInput();
  auto ln_bias = pattern->NewNode(ln_bias_repr())
                     ->assert_is_op_input("layer_norm", "Bias")
                     ->assert_is_persistable_var()
                     ->AsInput();
  auto ln_scale = pattern->NewNode(ln_scale_repr())
                      ->assert_is_op_input("layer_norm", "Scale")
                      ->assert_is_persistable_var()
                      ->AsInput();
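  // Y must be consumed only by the relu below, so the intermediate tensor
  // can be folded into the fused op.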
  auto ln_y = pattern->NewNode(ln_y_repr())
                  ->assert_is_op_output("layer_norm", "Y")
                  ->assert_is_op_input("relu", "X")
                  ->assert_has_n_outputs(1);
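  // Mean and Variance must have no consumers: the fused op does not expose
  // them, and they are removed together with the original layer_norm.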
  auto ln_mean = pattern->NewNode(ln_mean_repr())
                     ->assert_is_op_output("layer_norm", "Mean")
                     ->assert_has_n_outputs(0);
  auto ln_variance = pattern->NewNode(ln_variance_repr())
                         ->assert_is_op_output("layer_norm", "Variance")
                         ->assert_has_n_outputs(0);
  ln->LinksFrom({ln_x, ln_bias, ln_scale})
      .LinksTo({ln_y, ln_mean, ln_variance});

  auto relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
  auto relu_out = pattern->NewNode(relu_out_repr())
                      ->AsOutput()
                      ->assert_is_op_output("relu", "Out");
  relu->LinksFrom({ln_y}).LinksTo({relu_out});
}

}  // namespace patterns

class LayerNormalizeReluXPUFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  void FuseLayerNormalizeRelu(ir::Graph* graph) const;

  const std::string name_scope_{"layer_norm_relu_xpu_fuse_pass"};
};

void LayerNormalizeReluXPUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, common::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  auto* dev_ctx = static_cast<phi::CPUContext*>(
      phi::DeviceContextPool::Instance().Get(phi::XPUPlace()));
  auto version =
      phi::backends::xpu::get_xpu_version(dev_ctx->GetPlace().GetDeviceId());
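  // Only apply the fusion on XPU2 devices.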
  if (version == phi::backends::xpu::XPUVersion::XPU2) {
    FuseLayerNormalizeRelu(graph);
  }
}

void LayerNormalizeReluXPUFusePass::FuseLayerNormalizeRelu(
    ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::LayerNormalizeReluXPUPattern pattern(gpd.mutable_pattern(),
                                                 name_scope_);

  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle LayerNormalizeReluXPUFusePass fuse";
    // declare operator node's name
    GET_IR_NODE(ln);
    GET_IR_NODE(relu);
    // declare variable node's name
    GET_IR_NODE(ln_x);
    GET_IR_NODE(ln_bias);
    GET_IR_NODE(ln_scale);
    GET_IR_NODE(ln_y);
    GET_IR_NODE(ln_mean);
    GET_IR_NODE(ln_variance);
    GET_IR_NODE(relu_out);

    auto* block = ln->Op()->Block();
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, common::errors::InvalidArgument("Scope cannot be nullptr."));
    // nodes that become dead after the fusion and must be removed
    std::unordered_set<const Node*> delete_nodes;

    float eps = PADDLE_GET_CONST(float, ln->Op()->GetAttr("epsilon"));
    int begin_norm_axis =
        PADDLE_GET_CONST(int, ln->Op()->GetAttr("begin_norm_axis"));

    std::string fused_op_out_name = relu_out->Name();
    // Generate the layer_norm_relu_xpu fused op
    framework::OpDesc fused_op_desc(block);

    fused_op_desc.SetType("layer_norm_relu_xpu");
    // set inputs and attrs for the fused op
    fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
    fused_op_desc.SetInput("x", {ln_x->Name()});
    fused_op_desc.SetInput("bias", {ln_bias->Name()});
    fused_op_desc.SetInput("scale", {ln_scale->Name()});
    fused_op_desc.SetAttr("epsilon", eps);
    fused_op_desc.SetOutput("out", {fused_op_out_name});
    // relink the fused op
    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
    IR_NODE_LINK_TO(ln_x, fused_op);
    IR_NODE_LINK_TO(ln_bias, fused_op);
    IR_NODE_LINK_TO(ln_scale, fused_op);
    IR_NODE_LINK_TO(fused_op, relu_out);

    delete_nodes.insert({ln, relu, ln_y, ln_mean, ln_variance});
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(layer_norm_relu_xpu_fuse_pass,
              paddle::framework::ir::LayerNormalizeReluXPUFusePass);

REGISTER_PASS_CAPABILITY(layer_norm_relu_xpu_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
            "layer_norm_relu_xpu", 0));