Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ if(WITH_TENSORRT)
pass_library(remove_padding_recover_padding_pass inference)
pass_library(delete_remove_padding_recover_padding_pass inference)
pass_library(layernorm_shift_partition_fuse_pass inference)
pass_library(preln_layernorm_x_fuse_pass inference)
endif()

if(WITH_TENSORRT AND NOT WIN32)
Expand Down
189 changes: 189 additions & 0 deletions paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

struct PrelnLayerNormX : public PatternBase {
PrelnLayerNormX(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_layernorm_x") {}

void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(elementwise_bias);
PATTERN_DECL_NODE(elementwise0);
PATTERN_DECL_NODE(elementwise1);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(elementwise1_out);

PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
};

void PrelnLayerNormX::operator()(PDNode *x, PDNode *y) {
auto *elementwise1 =
pattern->NewNode(elementwise1_repr())->assert_is_op("elementwise_add");
auto *elementwise1_out_var =
pattern->NewNode(elementwise1_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("layernorm_shift_partition", "X");

elementwise1->LinksFrom({x, y}).LinksTo({elementwise1_out_var});
// Create nodes for layer_norm op.
auto *layer_norm = pattern->NewNode(layer_norm_repr())
->assert_is_op("layernorm_shift_partition");
auto *layer_norm_bias_var =
pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layernorm_shift_partition", "Bias");

auto *layer_norm_scale_var =
pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layernorm_shift_partition", "Scale");

auto *layer_norm_out_var =
pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layernorm_shift_partition", "Y");

// Add links for layer_norm op.
layer_norm
->LinksFrom(
{elementwise1_out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out_var});
}

} // namespace patterns

int PrelnLayerNormXFusePass::ApplyPattern(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_layernorm_x_fuse", graph);

int found_subgraph_count = 0;

GraphPatternDetector gpd;
PDNode *x = nullptr;
PDNode *y = nullptr;

x = gpd.mutable_pattern()
->NewNode("preln_layernorm_x_fuse/x")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "X");

y = gpd.mutable_pattern()
->NewNode("preln_layernorm_x_fuse/y")
->AsInput()
->assert_var_not_persistable()
->assert_is_op_input("elementwise_add", "Y");
patterns::PrelnLayerNormX fused_pattern(gpd.mutable_pattern(),
"preln_layernorm_x_fuse");
fused_pattern(x, y);

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}

VLOG(4) << "handle preln layernorm x fuse";

GET_IR_NODE_FROM_SUBGRAPH(elementwise1, elementwise1, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
elementwise1_out, elementwise1_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);

if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln_layernorm_x_fuse pass in op compat failed.";
return;
}
static int cnt = 0;
if (cnt++ > 0) {
// return;
}
std::unordered_set<const Node *> del_node_set;
// Create an PrelnLayerNormX op node
OpDesc new_desc(*layer_norm->Op());
new_desc.SetType("preln_layernorm_shift_partition");
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetOutput("Out_0", {elementwise1_out->Name()});
new_desc.SetOutput("Out_1", {layer_norm_out->Name()});
new_desc.RemoveOutput("Y");
new_desc.Flush();

auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.

del_node_set.insert(elementwise1);
del_node_set.insert(layer_norm);
GraphSafeRemoveNodes(graph, del_node_set);

IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise1_out);
found_subgraph_count++;
};

gpd(graph, handler);
return found_subgraph_count;
}

void PrelnLayerNormXFusePass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("preln_layernorm_x_fuse", graph);
int found_subgraph_count = ApplyPattern(graph);
AddStatis(found_subgraph_count);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(preln_layernorm_x_fuse_pass,
paddle::framework::ir::PrelnLayerNormXFusePass);
REGISTER_PASS_CAPABILITY(preln_layernorm_x_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"elementwise_add", 1));
60 changes: 60 additions & 0 deletions paddle/fluid/framework/ir/preln_layernorm_x_fuse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {
//
// | | | |
// other_op1 other_op2 other_op1 other_op2
// | | fuse \ /
// |------elementwise_add -> preln_layernorm_shift_partition
// | | | |
// other_op4 layernorm_shift_partition other_op4 other_op3
// |
// other_op3
class Graph;

class PrelnLayerNormXFusePass : public FusePassBase {
public:
PrelnLayerNormXFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1, 2})
.End();
}

virtual ~PrelnLayerNormXFusePass() {}

protected:
void ApplyImpl(ir::Graph* graph) const override;
int ApplyPattern(ir::Graph* graph) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2260,6 +2260,7 @@ USE_TRT_CONVERTER(shape)
USE_TRT_CONVERTER(fill_constant)
USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(preln_layernorm_shift_partition)
USE_TRT_CONVERTER(merge_layernorm)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"layernorm_shift_partition_fuse_pass", //
"merge_layernorm_fuse_pass", //
"preln_residual_bias_fuse_pass", //
"preln_layernorm_x_fuse_pass", //
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个pass是不是也能放在原生gpu里面,如果把新增的plugin放到phi算子里,通过通用plugin应该也能接进来

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后面 @weishengying 完善后,可以按照新方案接入plugin

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为目前只有trt算子 所以没有放到原生pass里

// "set_transformer_input_convert_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ list(
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc
preln_layernorm_shift_partition_op.cc
merge_layernorm_op.cc
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelnlayernorm_shift_partition_op.h"

namespace paddle {
namespace inference {
namespace tensorrt {

class PrelnLayerNormShiftPartitionOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(4) << "convert a fluid preln_layernorm_shift_partition op to tensorrt "
"preln_layernorm_shift_partition plugin";
framework::OpDesc op_desc(op, nullptr);

auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Y = engine_->GetITensor(op_desc.Input("Y").front());

std::vector<nvinfer1::ITensor*> inputs{X, Y};

auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front());
auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front());

const float eps = op_desc.HasAttr("epsilon")
? PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"))
: 1e-5f;
const int window_size =
PADDLE_GET_CONST(int, op_desc.GetAttr("window_size"));
const int input_resolution =
PADDLE_GET_CONST(int, op_desc.GetAttr("input_resolution"));

const int shift_size =
op_desc.HasAttr("shift_size")
? PADDLE_GET_CONST(int, op_desc.GetAttr("shift_size"))
: 0;

auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();

auto bias_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t);
auto scale_weight =
engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t);
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
nvinfer1::ILayer* layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::PrelnLnormShiftPartitionPluginDynamic* plugin =
new plugin::PrelnLnormShiftPartitionPluginDynamic(
static_cast<const float*>(scale_weight.get().values),
static_cast<const float*>(bias_weight.get().values),
bias_weight.get().count,
shift_size,
window_size,
input_resolution,
eps,
with_fp16);
layernorm_layer =
engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin);
}

std::vector<std::string> output_names;
output_names.emplace_back(op_desc.Output("Out_0").front());
output_names.emplace_back(op_desc.Output("Out_1").front());
RreplenishLayerAndOutput(layernorm_layer,
"preln_layernorm_shift_partition",
output_names,
test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(preln_layernorm_shift_partition,
PrelnLayerNormShiftPartitionOpConverter);
Loading