Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -724,12 +724,13 @@ if (onnxruntime_USE_CUDA)
if (WIN32)
link_directories(${onnxruntime_CUDNN_HOME}/lib/x64)

file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*")
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll")
foreach(cuda_dll_path ${cuda_dll_paths})
get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME)
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:${cuda_dll_file_name}")
endforeach(cuda_dll_path)
# delayload causes crash on exit, so disable for now
#file(GLOB cuda_dll_paths "${onnxruntime_CUDA_HOME}/bin/cublas64_*" "${onnxruntime_CUDA_HOME}/bin/cudart64_*" "${onnxruntime_CUDA_HOME}/bin/curand64_*")
#set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:cudnn64_7.dll")
#foreach(cuda_dll_path ${cuda_dll_paths})
# get_filename_component(cuda_dll_file_name ${cuda_dll_path} NAME)
# set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:${cuda_dll_file_name}")
#endforeach(cuda_dll_path)

else()
link_directories(${onnxruntime_CUDNN_HOME}/lib64)
Expand Down
112 changes: 78 additions & 34 deletions onnxruntime/core/optimizer/gelu_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,25 @@ static bool IsSupportedDataType(const Node& node) {
}
return true;
}

/*
This function fuses subgraph like the following into one Gelu node.
Subgraph pattern 1:
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul ==>
(B=1.4142...) (1)

Subgraph pattern 2:
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul ==>
(B=1.4142...) (1) (0.5)

After Fusion:
[root]--> Gelu ==>
*/
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
Expand Down Expand Up @@ -68,13 +86,9 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
continue;
}

// Check the other input node(e.g. not of type Erf) is 1.0f.
const Node& add_first_input_node = *(add_node.InputNodesBegin());
int add_const_input_index = 0;
if (add_first_input_node.OpType().compare("Erf") == 0) {
add_const_input_index = 1;
}
const auto& add_const_input_arg = add_node.InputDefs()[add_const_input_index];
// Check the other input node (e.g. not the Erf) is 1.0f.
bool is_erf_first_input = (add_node.InputDefs()[0]->Name() == erf_node.MutableOutputDefs()[0]->Name());
const auto& add_const_input_arg = add_node.InputDefs()[is_erf_first_input ? 1 : 0];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *add_const_input_arg, 1.0f, true)) {
continue;
}
Expand All @@ -87,35 +101,60 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
continue;
}

const Node* p_mul2_node = nullptr;
for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
if ((*iter).OpType().compare("Mul") == 0) {
// find the other input node of Mul
p_mul2_node = &(*iter);
break;
bool is_pattern_1 = true;
const Node* p_mul2_node = graph_utils::FirstParentByType(mul_node, "Mul");
if (p_mul2_node != nullptr) {
// Match subgraph pattern 1
Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
mul2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul2_node)) {
continue;
}
}
if (p_mul2_node == nullptr) {
continue;
}

Node& mul2_node = *graph.GetNode(p_mul2_node->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul2_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
mul2_node.GetOutputEdgesCount() != 1 ||
!IsSupportedDataType(mul2_node)) {
continue;
}
// One input of mul2_node shall be the subgraph input
auto root_index = optimizer_utils::IndexOfNodeInput(*p_mul2_node, *div.InputDefs()[0]);
if (root_index < 0)
continue;

// Check the other input node(e.g. not of type Add) is 0.5f.
int mul_const_input_index = 0;
if (mul2_node.InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
mul_const_input_index = 1;
}
// Check the other input node is 0.5f.
int mul_const_input_index = (root_index == 0 ? 1 : 0);

const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
continue;
const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
continue;
}
} else {
is_pattern_1 = false;

// Match subgraph pattern 2
if (mul_node.GetOutputEdgesCount() != 1) {
continue;
}

// Another input of Mul node shall be the subgraph input.
auto root_index = optimizer_utils::IndexOfNodeInput(mul_node, *div.InputDefs()[0]);
if (root_index < 0)
continue;

Node& mul2_node = *graph.GetNode(mul_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul2_node, "Mul", {7}) ||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
!IsSupportedDataType(mul_node)) {
continue;
}

int mul_const_input_index = 0;
if (mul2_node.InputDefs()[0]->Name() == mul_node.MutableOutputDefs()[0]->Name()) {
mul_const_input_index = 1;
}
const auto& mul_const_input_arg = mul2_node.InputDefs()[mul_const_input_index];
if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *mul_const_input_arg, 0.5f, true)) {
continue;
}

p_mul2_node = &mul2_node;
}

const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs()[0]};
Expand All @@ -131,7 +170,12 @@ Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, cons
// move input edges to div (first in list) across to the gelu_node.
// move output definitions and output edges from mul_node (last in list) to gelu_node.
// remove all the other nodes.
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2_node, mul_node}, gelu_node);
Node& mul2 = *graph.GetNode(p_mul2_node->Index());
if (is_pattern_1) {
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul2, mul_node}, gelu_node);
} else {
graph_utils::FinalizeNodeFusion(graph, {div, erf_node, add_node, mul_node, mul2}, gelu_node);
}

modified = true;
}
Expand Down
48 changes: 35 additions & 13 deletions onnxruntime/python/tools/bert/BertOnnxModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,24 @@ def fuse_gelu(self, gelu_op_name):

"""
Fuse Gelu with Erf into one node:
+-------Mul(B=0.5)-------------------+
Pattern 1:
+-------Mul(0.5)---------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->
(B=1.4142...) (B=1)
(B=1.4142...) (1)

Pattern 2:
+------------------------------------+
| |
| v
[root] --> Div -----> Erf --> Add --> Mul -->Mul -->
(B=1.4142...) (1) (0.5)

Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
"""
def fuse_gelu_with_elf(self, gelu_op_name):
logger.debug(f"start fuse_gelu_with_elf({gelu_op_name})")
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()

Expand Down Expand Up @@ -276,25 +285,38 @@ def fuse_gelu_with_elf(self, gelu_op_name):
if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
continue

root_node = self.get_parent(div, 0, output_name_to_node)
if root_node is None:
continue
subgraph_input = div.input[0]

mul_half = self.match_parent(mul_after_erf, 'Mul', None, output_name_to_node)
if mul_half is None:
continue
another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
if subgraph_input == mul_after_erf.input[another]: # pattern 2
children = input_name_to_nodes[mul_after_erf.output[0]]
if len(children) != 1 or children[0].op_type != 'Mul':
continue
mul_half = children[0]
if not self.has_constant_input(mul_half, 0.5):
continue
subgraph_output = mul_half.output[0]
else: # pattern 1
mul_half = self.match_parent(mul_after_erf, 'Mul', another, output_name_to_node)
if mul_half is None:
continue

if not self.has_constant_input(mul_half, 0.5):
continue
if not self.has_constant_input(mul_half, 0.5):
continue

if subgraph_input not in mul_half.input:
continue

subgraph_output = mul_after_erf.output[0]

subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul_after_erf.output[0]], input_name_to_nodes, output_name_to_node):
if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
continue

nodes_to_remove.extend(subgraph_nodes)
gelu_node = onnx.helper.make_node(gelu_op_name,
inputs=[root_node.output[0]],
outputs=[mul_after_erf.output[0]])
inputs=[subgraph_input],
outputs=[subgraph_output])
gelu_node.domain = "com.microsoft"
nodes_to_add.append(gelu_node)

Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/python/tools/bert/test_bert_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
BERT_TEST_MODELS = {
"bert_pytorch_0": 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_0.onnx',
"bert_pytorch_1": 'test_data\\bert_squad_pytorch1.4_opset11\\BertForQuestionAnswering_1.onnx',
"bert_squad_pytorch1.4_opset10_fp32": 'test_data\\bert_squad_pytorch1.4_opset10_fp32\\BertForQuestionAnswering.onnx',
"bert_keras_0": 'test_data\\bert_mrpc_tensorflow2.1_opset10\\TFBertForSequenceClassification_1.onnx'
}

Expand Down Expand Up @@ -155,6 +156,13 @@ def test_pytorch_model_0_gpu(self):
}
self.verify_node_count(bert_model, expected_node_count)

def test_pytorch_model_2_cpu(self):
input = BERT_TEST_MODELS['bert_squad_pytorch1.4_opset10_fp32']
bert_model = optimize_model(input, 'bert', gpu_only=False,
num_heads=2, hidden_size=8, sequence_length=10,
input_int32=False, float16=False)
self.assertTrue(bert_model.is_fully_optimized())

def test_keras_model_1_cpu(self):
input = BERT_TEST_MODELS['bert_keras_0']

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

BstartJ(���A2I<W�1<nR�<G�;�^<�q?;���<
r�;��<
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

BendJ(���=Fڞ=L��=QR�=�w�=6\�=��=���=��=Jg�=
57 changes: 57 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,63 @@ TEST(GraphTransformationTests, GeluFusionTest) {
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, GeluFusionTestSwitchOrderFormat2) {
auto model_uri = MODEL_FOLDER "fusion/gelu_format2_0.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, GeluFusionTestFormat2) {
auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, GeluFusionTestFormat2GraphInput) {
auto model_uri = MODEL_FOLDER "fusion/gelu_format2_1_use_graph_input.onnx";
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()).IsOK());
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<GeluFusion>(), TransformerLevel::Level2);
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger());
ASSERT_TRUE(ret.IsOK());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Div"] == 0);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Erf"] == 0);
ASSERT_TRUE(op_to_count["Mul"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 1);
}

TEST(GraphTransformationTests, BiasGeluTest) {
auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion.onnx";
std::shared_ptr<Model> p_model;
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/test/shared_lib/test_model_loading.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT License.

#include "core/session/onnxruntime_cxx_api.h"
#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
#endif
#include <fstream>
#include "test_fixture.h"
#include "file_util.h"
Expand All @@ -25,6 +28,12 @@ TEST(CApiTest, model_from_array) {

Ort::SessionOptions so;
Ort::Session session(*ort_env.get(), buffer.data(), buffer.size(), so);

#ifdef USE_CUDA
// test with CUDA provider when using onnxruntime as dll
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(so, 0));
Ort::Session session_cuda(*ort_env.get(), buffer.data(), buffer.size(), so);
#endif
}
} // namespace test
} // namespace onnxruntime
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading