@@ -24,7 +24,25 @@ static bool IsSupportedDataType(const Node& node) {
24
24
}
25
25
return true ;
26
26
}
27
-
27
+ /*
28
+ This function fuses a subgraph matching either of the following patterns into a single Gelu node.
29
+ Subgraph pattern 1:
30
+           +-------Mul(0.5)---------------------+
31
+           |                                    |
32
+           |                                    v
33
+        [root] --> Div -----> Erf  --> Add --> Mul ==>
34
+                  (B=1.4142...)        (1)
35
+
36
+ Subgraph pattern 2:
37
+           +------------------------------------+
38
+           |                                    |
39
+           |                                    v
40
+        [root] --> Div -----> Erf  --> Add --> Mul -->Mul ==>
41
+                  (B=1.4142...)        (1)           (0.5)
42
+
43
+ After Fusion:
44
+ [root]--> Gelu ==>
45
+ */
28
46
Status GeluFusion::ApplyImpl (Graph& graph, bool & modified, int graph_level, const logging::Logger& logger) const {
29
47
GraphViewer graph_viewer (graph);
30
48
const auto & node_topology_list = graph_viewer.GetNodesInTopologicalOrder ();
@@ -68,13 +86,9 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
68
86
continue ;
69
87
}
70
88
71
- // Check the other input node(e.g. not of type Erf) is 1.0f.
72
- const Node& add_first_input_node = *(add_node.InputNodesBegin ());
73
- int add_const_input_index = 0 ;
74
- if (add_first_input_node.OpType ().compare (" Erf" ) == 0 ) {
75
- add_const_input_index = 1 ;
76
- }
77
- const auto & add_const_input_arg = add_node.InputDefs ()[add_const_input_index];
89
+ // Check the other input node (e.g. not the Erf) is 1.0f.
90
+ bool is_erf_first_input = (add_node.InputDefs ()[0 ]->Name () == erf_node.MutableOutputDefs ()[0 ]->Name ());
91
+ const auto & add_const_input_arg = add_node.InputDefs ()[is_erf_first_input ? 1 : 0 ];
78
92
if (!optimizer_utils::IsInitializerWithExpectedValue (graph, *add_const_input_arg, 1 .0f , true )) {
79
93
continue ;
80
94
}
@@ -87,35 +101,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
87
101
continue ;
88
102
}
89
103
90
- const Node* p_mul2_node = nullptr ;
91
- for (auto iter = mul_node.InputNodesBegin (); iter != mul_node.InputNodesEnd (); ++iter) {
92
- if ((*iter).OpType ().compare (" Mul" ) == 0 ) {
93
- // find the other input node of Mul
94
- p_mul2_node = &(*iter);
95
- break ;
104
+ bool is_pattern_1 = true ;
105
+ const Node* p_mul2_node = graph_utils::FirstParentByType (mul_node, " Mul" );
106
+ if (p_mul2_node != nullptr ) {
107
+ // Match subgraph pattern 1
108
+ Node& mul2_node = *graph.GetNode (p_mul2_node->Index ());
109
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (mul2_node, " Mul" , {7 }) ||
110
+ mul2_node.GetExecutionProviderType () != div.GetExecutionProviderType () ||
111
+ mul2_node.GetOutputEdgesCount () != 1 ||
112
+ !IsSupportedDataType (mul2_node)) {
113
+ continue ;
96
114
}
97
- }
98
- if (p_mul2_node == nullptr ) {
99
- continue ;
100
- }
101
115
102
- Node& mul2_node = *graph.GetNode (p_mul2_node->Index ());
103
- if (!graph_utils::IsSupportedOptypeVersionAndDomain (mul2_node, " Mul" , {7 }) ||
104
- mul2_node.GetExecutionProviderType () != div.GetExecutionProviderType () ||
105
- mul2_node.GetOutputEdgesCount () != 1 ||
106
- !IsSupportedDataType (mul2_node)) {
107
- continue ;
108
- }
116
+ // One input of mul2_node shall be the subgraph input
117
+ auto root_index = optimizer_utils::IndexOfNodeInput (*p_mul2_node, *div.InputDefs ()[0 ]);
118
+ if (root_index < 0 )
119
+ continue ;
109
120
110
- // Check the other input node(e.g. not of type Add) is 0.5f.
111
- int mul_const_input_index = 0 ;
112
- if (mul2_node.InputDefs ()[0 ]->Name () == div.MutableInputDefs ()[0 ]->Name ()) {
113
- mul_const_input_index = 1 ;
114
- }
121
+ // Check the other input node is 0.5f.
122
+ int mul_const_input_index = (root_index == 0 ? 1 : 0 );
115
123
116
- const auto & mul_const_input_arg = mul2_node.InputDefs ()[mul_const_input_index];
117
- if (!optimizer_utils::IsInitializerWithExpectedValue (graph, *mul_const_input_arg, 0 .5f , true )) {
118
- continue ;
124
+ const auto & mul_const_input_arg = mul2_node.InputDefs ()[mul_const_input_index];
125
+ if (!optimizer_utils::IsInitializerWithExpectedValue (graph, *mul_const_input_arg, 0 .5f , true )) {
126
+ continue ;
127
+ }
128
+ } else {
129
+ is_pattern_1 = false ;
130
+
131
+ // Match subgraph pattern 2
132
+ if (mul_node.GetOutputEdgesCount () != 1 ) {
133
+ continue ;
134
+ }
135
+
136
+ // Another input of Mul node shall be the subgraph input.
137
+ auto root_index = optimizer_utils::IndexOfNodeInput (mul_node, *div.InputDefs ()[0 ]);
138
+ if (root_index < 0 )
139
+ continue ;
140
+
141
+ Node& mul2_node = *graph.GetNode (mul_node.OutputNodesBegin ()->Index ());
142
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain (mul2_node, " Mul" , {7 }) ||
143
+ mul_node.GetExecutionProviderType () != div.GetExecutionProviderType () ||
144
+ !IsSupportedDataType (mul_node)) {
145
+ continue ;
146
+ }
147
+
148
+ int mul_const_input_index = 0 ;
149
+ if (mul2_node.InputDefs ()[0 ]->Name () == mul_node.MutableOutputDefs ()[0 ]->Name ()) {
150
+ mul_const_input_index = 1 ;
151
+ }
152
+ const auto & mul_const_input_arg = mul2_node.InputDefs ()[mul_const_input_index];
153
+ if (!optimizer_utils::IsInitializerWithExpectedValue (graph, *mul_const_input_arg, 0 .5f , true )) {
154
+ continue ;
155
+ }
156
+
157
+ p_mul2_node = &mul2_node;
119
158
}
120
159
121
160
const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs ()[0 ]};
@@ -131,7 +170,12 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
131
170
// move input edges to div (first in list) across to the gelu_node.
132
171
// move output definitions and output edges from mul_node (last in list) to gelu_node.
133
172
// remove all the other nodes.
134
- graph_utils::FinalizeNodeFusion (graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node);
173
+ Node& mul2 = *graph.GetNode (p_mul2_node->Index ());
174
+ if (is_pattern_1) {
175
+ graph_utils::FinalizeNodeFusion (graph, {div, erf_node, add_node, mul2, mul_node}, gelu_node);
176
+ } else {
177
+ graph_utils::FinalizeNodeFusion (graph, {div, erf_node, add_node, mul_node, mul2}, gelu_node);
178
+ }
135
179
136
180
modified = true ;
137
181
}
0 commit comments