Skip to content

Commit 0465f71

Browse files
authored
[luci/pass] Added pass 'FuseMulWithRmsNorm' (#15932)
This commit adds a pass for fusing Mul with preceding RmsNorm. ONE-DCO-1.0-Signed-off-by: Seockho Kim [email protected]
1 parent 185f998 commit 0465f71

File tree

5 files changed

+164
-0
lines changed

5 files changed

+164
-0
lines changed

compiler/luci/pass/include/luci/CircleOptimizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class CircleOptimizer final
5050
FuseMulWithConv,
5151
FuseMulWithDiv,
5252
FuseMulWithFullyConnected,
53+
FuseMulWithRmsNorm,
5354
FuseTransposeWithMean,
5455
ResolveCustomOpAdd,
5556
ResolveCustomOpBatchMatMul,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#ifndef __LUCI_FUSE_MUL_WITH_RMSNORM_PASS_H__
18+
#define __LUCI_FUSE_MUL_WITH_RMSNORM_PASS_H__
19+
20+
#include <logo/Pass.h>
21+
22+
namespace luci
23+
{
24+
25+
/**
26+
* @brief Class to fuse Mul into CircleRmsNorm
27+
*/
28+
struct FuseMulWithRmsNormPass final : public logo::Pass
29+
{
30+
const char *name(void) const final { return "luci::FuseMulWithRmsNormPass"; }
31+
32+
bool run(loco::Graph *g) final;
33+
};
34+
35+
} // namespace luci
36+
37+
#endif // __LUCI_FUSE_MUL_WITH_RMSNORM_PASS_H__

compiler/luci/pass/src/CircleOptimizer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "luci/Pass/FuseMulWithConvPass.h"
5050
#include "luci/Pass/FuseMulWithDivPass.h"
5151
#include "luci/Pass/FuseMulWithFullyConnectedPass.h"
52+
#include "luci/Pass/FuseMulWithRmsNormPass.h"
5253
#include "luci/Pass/FusePreActivationBatchNormPass.h"
5354
#include "luci/Pass/FusePReluPass.h"
5455
#include "luci/Pass/FuseGeluPass.h"
@@ -333,6 +334,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const
333334
option_to_pass[Options::Algorithm::FuseMulWithConv] = &createPassInstance<luci::FuseMulWithConvPass>;
334335
option_to_pass[Options::Algorithm::FuseMulWithDiv] = &createPassInstance<luci::FuseMulWithDivPass>;
335336
option_to_pass[Options::Algorithm::FuseMulWithFullyConnected] = &createPassInstance<luci::FuseMulWithFullyConnectedPass>;
337+
option_to_pass[Options::Algorithm::FuseMulWithRmsNorm] = &createPassInstance<luci::FuseMulWithRmsNormPass>;
336338
option_to_pass[Options::Algorithm::ResolveCustomOpMaxPoolWithArgmax] = &createPassInstance<luci::ResolveCustomOpMaxPoolWithArgmaxPass>;
337339
option_to_pass[Options::Algorithm::ResolveCustomOpSplitV] = &createPassInstance<luci::ResolveCustomOpSplitVPass>;
338340
option_to_pass[Options::Algorithm::FuseInstanceNorm] = &createPassInstance<luci::FuseInstanceNormPass>;
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "luci/Pass/FuseMulWithRmsNormPass.h"
18+
19+
#include "helpers/NodeFiller.h"
20+
21+
#include <luci/IR/CircleNodes.h>
22+
#include <luci/Service/Nodes/CircleConst.h>
23+
#include <luci/Profile/CircleNodeOrigin.h>
24+
25+
namespace
26+
{
27+
28+
// Early-exit guard: leaves the enclosing bool-returning function with `false`
// when `cond` does not hold. It expands to a complete `if` statement, so a call
// site compiles whether or not it adds a trailing ';'.
#define RETURN_FALSE_UNLESS(cond) \
  if (!(cond))                    \
    return false;
31+
32+
/**
33+
* Fuse Mul to RmsNorm if the RmsNorm's weight value is 1.0
34+
*
35+
* BEFORE
36+
* |
37+
* [CircleRmsNorm] (gamma=1.0)
38+
* |
39+
* [CircleMul]
40+
* |
41+
*
42+
* AFTER
43+
* |
44+
* [CircleRmsNorm]
45+
* |
46+
*
47+
*/
48+
bool fuse_mul_with_rmsnorm(luci::CircleMul *mul)
49+
{
50+
RETURN_FALSE_UNLESS(mul);
51+
RETURN_FALSE_UNLESS(mul->dtype() == loco::DataType::FLOAT32);
52+
53+
luci::CircleRmsNorm *rmsnorm = nullptr;
54+
luci::CircleConst *weight = nullptr;
55+
RETURN_FALSE_UNLESS(luci::fill(&rmsnorm, &weight).with_commutative_args_of(mul));
56+
57+
RETURN_FALSE_UNLESS(loco::succs(rmsnorm).size() == 1);
58+
RETURN_FALSE_UNLESS(rmsnorm->dtype() == loco::DataType::FLOAT32);
59+
60+
auto norm_gamma = dynamic_cast<luci::CircleConst *>(rmsnorm->gamma());
61+
RETURN_FALSE_UNLESS(norm_gamma);
62+
RETURN_FALSE_UNLESS(norm_gamma->size<loco::DataType::FLOAT32>() == 1);
63+
RETURN_FALSE_UNLESS(norm_gamma->at<loco::DataType::FLOAT32>(0) == 1.0f);
64+
65+
RETURN_FALSE_UNLESS(weight->rank() == 1)
66+
RETURN_FALSE_UNLESS(weight->dim(0) == rmsnorm->dim(rmsnorm->rank() - 1));
67+
68+
auto fused_weight = luci::clone(weight);
69+
rmsnorm->gamma(fused_weight);
70+
71+
luci::add_origin(rmsnorm, luci::get_origin(mul));
72+
73+
replace(mul).with(rmsnorm);
74+
75+
return true;
76+
}
77+
78+
} // namespace
79+
80+
namespace luci
81+
{
82+
83+
/**
 * Walks the active nodes reachable from the outputs of graph g and attempts
 * the Mul/RmsNorm fusion on every CircleMul found.
 *
 * @return true if at least one Mul was fused, false otherwise
 */
bool FuseMulWithRmsNormPass::run(loco::Graph *g)
{
  bool any_fused = false;

  for (auto candidate : loco::active_nodes(loco::output_nodes(g)))
  {
    auto mul_node = dynamic_cast<luci::CircleMul *>(candidate);
    if (mul_node == nullptr)
      continue; // not a Mul, nothing to fuse

    if (fuse_mul_with_rmsnorm(mul_node))
      any_fused = true;
  }

  return any_fused;
}
97+
98+
} // namespace luci
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "luci/Pass/FuseMulWithRmsNormPass.h"
18+
19+
#include <gtest/gtest.h>
20+
21+
// Smoke test: a default-constructed pass must expose a non-null name string.
TEST(FuseMulWithRmsNormPassTest, name)
{
  luci::FuseMulWithRmsNormPass pass;

  const char *const pass_name = pass.name();

  ASSERT_NE(nullptr, pass_name);
}

0 commit comments

Comments
 (0)