11 changes: 9 additions & 2 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -213,10 +213,17 @@ endif()

if(WITH_XPU)
cc_library(
quant_utils
xpu_quant_utils
SRCS xpu/quant_utils.cc
DEPS pass)
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS quant_utils)
cc_library(
xpu_pass_utils
SRCS xpu/pass_utils.cc
DEPS pass)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()

cc_library(
79 changes: 32 additions & 47 deletions paddle/fluid/framework/ir/delete_dropout_op_pass.cc
@@ -25,71 +25,52 @@ namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(dropout_op); \
GET_IR_NODE(dropout_op_out); \
GET_IR_NODE(dropout_op_outmask); \
GET_IR_NODE(any_op2);
#define GET_IR_NODE(node_) GET_IR_NODE_FROM_SUBGRAPH(node_, node_, pattern)

void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_dropout_op_pattern";
FusePassBase::Init(pattern_name, graph);

GraphPatternDetector gpd;

patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
pattern();

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string dropout_op_out_name = dropout_op_out->Var()->Name();

// any_op2
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
dropout_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
if (arg_name.size() == 0) {
LOG(INFO) << "Delete dropout op pass: can not find the input "
<< dropout_op_out_name;
return;
}

// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
dropout_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != dropout_op_out_name) {
new_inputs.push_back(i_n);
}
GET_IR_NODE(dropout_op_x);
GET_IR_NODE(dropout_op);
GET_IR_NODE(dropout_op_out);
GET_IR_NODE(dropout_op_mask);

// link dropout_op_out to pre_op
auto dropout_op_x_name = dropout_op_x->Var()->Name();
auto dropout_op_out_name = dropout_op_out->Var()->Name();
auto pre_ops = dropout_op_x->inputs;
if (pre_ops.empty()) return;
auto pre_op_desc = pre_ops[0]->Op();
auto pre_op_outs = pre_op_desc->Outputs();
for (auto& out_var : pre_op_outs) {
auto names = out_var.second;
for (size_t i = 0; i < names.size(); i++) {
if (names[i] == dropout_op_x_name) {
names[i] = dropout_op_out_name;
pre_op_desc->SetOutput(out_var.first, names);
break;
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
}
}
any_op2_desc->Flush();
IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);

// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{dropout_op, dropout_op_out, dropout_op_outmask});
// delete useless node
std::unordered_set<const Node*> delete_nodes{
dropout_op_x, dropout_op, dropout_op_mask};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};

gpd(graph, handler);
AddStatis(found_subgraph_count);
}

DeleteDropoutOpXPass::DeleteDropoutOpXPass() {
@@ -279,6 +260,10 @@ void DeleteDropoutOpXPass::ReplaceOutputVar(Node* op,

REGISTER_PASS(delete_dropout_op_pass,
paddle::framework::ir::DeleteDropoutOpPass);
REGISTER_PASS_CAPABILITY(delete_dropout_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"dropout", 0));

REGISTER_PASS(delete_dropout_op_x_pass,
paddle::framework::ir::DeleteDropoutOpXPass);
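Reviewer note: the net effect of the reworked handler is that the op which produced dropout's input X is re-pointed at dropout's Out variable, so downstream consumers keep their input names unchanged, and dropout_op_x, dropout_op, and dropout_op_mask are removed (the old code instead rewired every consumer's inputs back onto the dropout input). Below is a minimal toy sketch of that renaming step using plain C++ containers rather than the Paddle IR types; ToyOpDesc, RewirePreOp, and the variable names are invented for illustration only.

// Toy model of the rewiring step in the new handler (not the Paddle IR API):
// the op that produced dropout's input "X" is re-pointed at dropout's "Out",
// so downstream ops keep their input names and the dropout op can be dropped.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyOpDesc {
  std::map<std::string, std::vector<std::string>> outputs;
  void SetOutput(const std::string& slot, const std::vector<std::string>& v) {
    outputs[slot] = v;
  }
};

// Mirrors the loop in the new handler: find dropout's input name among the
// pre-op's output names and replace it with dropout's output name.
void RewirePreOp(ToyOpDesc* pre_op,
                 const std::string& dropout_x,
                 const std::string& dropout_out) {
  for (auto& kv : pre_op->outputs) {
    auto names = kv.second;
    for (size_t i = 0; i < names.size(); ++i) {
      if (names[i] == dropout_x) {
        names[i] = dropout_out;
        pre_op->SetOutput(kv.first, names);
        break;
      }
    }
  }
}

int main() {
  ToyOpDesc matmul;                    // pretend this op feeds the dropout
  matmul.outputs["Out"] = {"fc_out"};  // "fc_out" is dropout's X
  RewirePreOp(&matmul, "fc_out", "dropout_out");
  std::cout << matmul.outputs["Out"][0] << "\n";  // prints "dropout_out"
}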
31 changes: 12 additions & 19 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -3034,26 +3034,19 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
}

void patterns::DeleteDropoutOpPattern::operator()() {
auto any_op_out = pattern->NewNode(any_op_out_repr())
->assert_is_op_input("dropout", "X")
->AsInput();

auto dropout_op =
pattern->NewNode(dropout_op_repr())->assert_is_op("dropout");

auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
->assert_is_op_input("dropout", "X")
->AsInput();
auto dropout_op = pattern->NewNode(dropout_op_repr())
->assert_is_op("dropout")
->assert_op_attr("dropout_implementation",
std::string("upscale_in_train"));
auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
->assert_is_op_output("dropout", "Out")
->AsIntermediate();

auto dropout_op_outmask = pattern->NewNode(dropout_op_outmask_repr())
->assert_is_op_output("dropout", "Mask")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();

dropout_op->LinksFrom({any_op_out});
dropout_op_out->LinksFrom({dropout_op});
dropout_op_outmask->LinksFrom({dropout_op});
any_op2->LinksFrom({dropout_op_out});
->assert_is_op_output("dropout", "Out");
auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
->assert_is_op_output("dropout", "Mask");
dropout_op->LinksFrom({dropout_op_x})
.LinksTo({dropout_op_out, dropout_op_mask});
}

void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
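Reviewer note: the pattern now only matches dropout ops whose dropout_implementation attribute is "upscale_in_train". As I understand Paddle's dropout semantics, that mode scales by 1/(1 - p) during training and is an identity at inference, so deleting the op cannot change inference results; the default "downgrade_in_infer" mode instead multiplies by (1 - p) at inference and could not simply be removed. A small illustrative comparison (values are made up):

// Inference-time dropout under the two dropout_implementation modes
// (formulas per the op's documented semantics; x and p are illustrative).
#include <iostream>

int main() {
  const float x = 2.0f;  // some activation value
  const float p = 0.1f;  // dropout_prob
  float upscale_in_train = x;                  // identity at inference
  float downgrade_in_infer = x * (1.0f - p);   // scaled by (1 - p) at inference
  std::cout << upscale_in_train << " vs " << downgrade_in_infer << "\n";
}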
5 changes: 2 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1763,11 +1763,10 @@ struct DeleteDropoutOpPattern : public PatternBase {

void operator()();

PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(dropout_op_x);
PATTERN_DECL_NODE(dropout_op);
PATTERN_DECL_NODE(dropout_op_out);
PATTERN_DECL_NODE(dropout_op_outmask);
PATTERN_DECL_NODE(any_op2);
PATTERN_DECL_NODE(dropout_op_mask);
};

struct DeleteQuantDequantOpPattern : public PatternBase {
17 changes: 2 additions & 15 deletions paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -176,15 +176,6 @@ class FcXPUFusePass : public FusePassBase {
const std::string& act_type) const;

const std::string name_scope_{"fc_xpu_fuse_pass"};
const std::map<std::string, int> act_map_{{"", 0},
{"relu", 1},
{"sigmoid", 2},
{"tanh", 3},
{"gelu", 4},
{"leaky_relu", 5},
{"hard_swish", 14},
{"hard_sigmoid", 15},
{"relu6", 17}};
};

void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -246,17 +237,13 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
mul_w_max_var->SetPersistable(true);
auto mul_w_max_tensor =
scope->Var(mul_w_max_name)->GetMutable<phi::DenseTensor>();
auto* xpu_ctx = static_cast<phi::XPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::XPUPlace()));
int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
bool transpose_w = false;
if (mul_type == "matmul") {
transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
} else if (mul_type == "matmul_v2") {
transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
}
QuantWeight<int16_t>(
mul_w_tensor, mul_w_max_tensor, !transpose_w, max_ptr_size);
QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
}

// Generate fc_xpu op
@@ -288,7 +275,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
fc_xpu_op_desc.SetAttr("act_type", 0);
fc_xpu_op_desc.SetAttr("act_alpha", 0.f);
if (act) {
fc_xpu_op_desc.SetAttr("act_type", act_map_.at(act_type));
fc_xpu_op_desc.SetAttr("act_type", ConvertActivationType(act_type));
if (act_type == "leaky_relu") {
fc_xpu_op_desc.SetAttr(
"act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("alpha")));
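Reviewer note: in fc_xpu_fuse_pass.cc the per-pass act_map_ table is gone and act_type is now converted via ConvertActivationType(act_type), presumably a helper shared through the new xpu_pass_utils target; the XPU context lookup for max_ptr_size also moves out of the pass, with QuantWeight now taking three arguments instead of four. A hypothetical sketch of what such a helper could look like, using the values from the deleted table (the real implementation may differ):

// Hypothetical shape of the shared helper that replaces the removed act_map_;
// the mapping values are copied from the deleted table above, but the actual
// ConvertActivationType (presumably in xpu/pass_utils) may differ.
#include <map>
#include <stdexcept>
#include <string>

int ConvertActivationType(const std::string& act_type) {
  static const std::map<std::string, int> kActMap = {
      {"", 0},           {"relu", 1},          {"sigmoid", 2},
      {"tanh", 3},       {"gelu", 4},          {"leaky_relu", 5},
      {"hard_swish", 14}, {"hard_sigmoid", 15}, {"relu6", 17}};
  auto it = kActMap.find(act_type);
  if (it == kActMap.end()) {
    throw std::invalid_argument("unsupported act_type: " + act_type);
  }
  return it->second;
}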