Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
165869d
Adding lower_lrn_to_pooling function
aarushjain29 Sep 10, 2025
65d5b70
Adding lower_lrn_to_pooling function
aarushjain29 Sep 10, 2025
83651ec
Adding invert permutation header
aarushjain29 Sep 10, 2025
75d6370
changes in apply
aarushjain29 Sep 19, 2025
6d94d79
Update src/rewrite_pooling.cpp
aarushjain29 Sep 19, 2025
acafcc0
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Sep 19, 2025
0b717ab
test case added
aarushjain29 Sep 22, 2025
bec9dac
test case added
aarushjain29 Sep 23, 2025
ea39019
more test case
aarushjain29 Sep 23, 2025
8b36261
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Sep 23, 2025
69efb32
remove comment
aarushjain29 Sep 23, 2025
fb8b708
remove spaces
aarushjain29 Sep 23, 2025
6dae737
Update test/rewrite_pooling_test.cpp
aarushjain29 Sep 23, 2025
09ef691
Update test/rewrite_pooling_test.cpp
aarushjain29 Sep 23, 2025
6ac18ee
tidy errors
aarushjain29 Sep 24, 2025
4780e5e
cpp errors
aarushjain29 Sep 24, 2025
3858f86
tidy errors
aarushjain29 Sep 24, 2025
2c1b76f
tidy errors
aarushjain29 Sep 24, 2025
07a2730
test case for comparing two models
aarushjain29 Sep 24, 2025
c2dc92b
test case for comparing two models
aarushjain29 Sep 24, 2025
ea7159d
accepting both even and odd sizes
aarushjain29 Sep 24, 2025
cc383cd
remove tidy errors
aarushjain29 Sep 24, 2025
6d17986
remove tidy errors
aarushjain29 Sep 24, 2025
506e2e6
tidy error
aarushjain29 Sep 24, 2025
6e5597d
formatting
aarushjain29 Sep 24, 2025
393e613
formatting
aarushjain29 Sep 24, 2025
38a39c2
license
aarushjain29 Sep 24, 2025
2b6cc85
formatting
aarushjain29 Sep 24, 2025
d4d5452
reverting back to even size
aarushjain29 Sep 24, 2025
5a783c9
logic for both evenn and odd sizes
aarushjain29 Sep 26, 2025
02a7ce7
calculate padding
aarushjain29 Sep 30, 2025
84d2105
formatting
aarushjain29 Sep 30, 2025
f4df624
formatting
aarushjain29 Sep 30, 2025
89fc424
formatting
aarushjain29 Sep 30, 2025
e489de0
test case added
aarushjain29 Sep 30, 2025
b0335f2
verify test case
aarushjain29 Sep 30, 2025
77ca91d
formatting
aarushjain29 Sep 30, 2025
2043239
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Sep 30, 2025
e8c2547
Update test/rewrite_pooling_test.cpp
aarushjain29 Sep 30, 2025
32821aa
licensing
aarushjain29 Sep 30, 2025
50b429f
combine line 89 and 90
aarushjain29 Oct 2, 2025
5e51bc1
compiler warning unused param
aarushjain29 Oct 2, 2025
e0355c5
remove transposed lens
aarushjain29 Oct 2, 2025
b8fc4ee
Adding the check for size and combining all the checks in if
aarushjain29 Oct 2, 2025
1b073b9
changing the test to simplify_algebra like test
aarushjain29 Oct 2, 2025
72934da
tidy error
aarushjain29 Oct 3, 2025
ea31d5e
tidy error
aarushjain29 Oct 3, 2025
3d7b480
tidy error
aarushjain29 Oct 3, 2025
080ac81
tidy error
aarushjain29 Oct 3, 2025
85c15cd
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Oct 3, 2025
91eb3c7
remove try catch and add all conditions
aarushjain29 Oct 5, 2025
a351f31
formatting
aarushjain29 Oct 5, 2025
310c8e7
new tests added
aarushjain29 Oct 5, 2025
8745a74
removing test case from test_relu_lrn
aarushjain29 Oct 5, 2025
808a56e
MIGRAPHX_REWRITE_LRN flag
aarushjain29 Oct 5, 2025
355bec1
license
aarushjain29 Oct 5, 2025
6d8929f
formatting
aarushjain29 Oct 5, 2025
c69157a
formatting
aarushjain29 Oct 5, 2025
1dadd05
formatting
aarushjain29 Oct 5, 2025
f5318f2
license
aarushjain29 Oct 5, 2025
a546fd4
enable flag in test case
aarushjain29 Oct 5, 2025
afabc81
test case accepting flag
aarushjain29 Oct 5, 2025
8b17b1d
formatting
aarushjain29 Oct 6, 2025
2e244c5
formatting
aarushjain29 Oct 6, 2025
878b135
simplify code
aarushjain29 Oct 6, 2025
04a7705
simplify code
aarushjain29 Oct 6, 2025
7ef1891
Merge branch 'develop' into lower-lrn-to-pooling
aarushjain29 Oct 6, 2025
526e25e
updated the doc
aarushjain29 Oct 6, 2025
c7cf3ee
remove flag in test and formatting
aarushjain29 Oct 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/rewrite_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/op/lrn.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/program.hpp>

namespace migraphx {
Expand All @@ -55,6 +57,75 @@ static void replace_with_reduce(module& m, instruction_ref ins)
}
}

static void lower_lrn_to_pooling(module& m, instruction_ref ins)
{
auto v = ins->get_operator().to_value();

float alpha = v.at("alpha").to<float>();
float beta = v.at("beta").to<float>();
float k = v.at("bias").to<float>();
int size = v.at("size").to<int>();
const unsigned int axis = 1;

auto x = ins->inputs().at(0);
const auto& xshape = x->get_shape();
auto lens = xshape.lens();
const int64_t rank = static_cast<int64_t>(lens.size());
if(rank < 2)
return;
if(size <= 0)
return;

auto x2 = m.insert_instruction(ins, make_op("mul"), x, x);

std::vector<int64_t> perm(rank);
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[static_cast<std::size_t>(axis)], perm.back());
auto moved = m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), x2);
auto moved_lens = moved->get_shape().lens();
auto b =
std::accumulate(moved_lens.begin(), moved_lens.end() - 1, 1, std::multiplies<size_t>());
const int64_t c = static_cast<int64_t>(moved_lens.back());
auto pooled_in = m.insert_instruction(
ins,
make_op("reshape", {{"dims", std::vector<int64_t>{static_cast<int64_t>(b), 1, 1, c}}}),
moved);

auto avg = m.insert_instruction(ins,
make_op("pooling",
{{"mode", op::pooling_mode::average},
{"lengths", std::vector<int64_t>{1, size}},
{"stride", std::vector<int64_t>{1, 1}},
{"padding", std::vector<int64_t>{0, size / 2}},
{"count_include_pad", true}}),
pooled_in);

auto moved_shape_back = std::vector<int64_t>(moved_lens.begin(), moved_lens.end());
auto avg_moved =
m.insert_instruction(ins, make_op("reshape", {{"dims", moved_shape_back}}), avg);


auto invp = invert_permutation(perm);
auto avg_ch =
m.insert_instruction(ins, make_op("transpose", {{"permutation", invp}}), avg_moved);

auto k_lit = m.add_literal(k);
auto a_lit = m.add_literal(alpha);
auto k_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), k_lit);
auto a_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), a_lit);
auto alpha_avg = m.insert_instruction(ins, make_op("mul"), a_mb, avg_ch);
auto den = m.insert_instruction(ins, make_op("add"), k_mb, alpha_avg);

auto b_lit = m.add_literal(beta);
auto b_mb = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", lens}}), b_lit);
auto denpow = m.insert_instruction(ins, make_op("pow"), den, b_mb);
auto y = m.insert_instruction(ins, make_op("div"), ins->inputs().front(), denpow);

m.replace_instruction(ins, y);
}



static void replace_dilations_with_gather_pooling(module& m, instruction_ref ins)
{
// TODO remove this when MIOpen supports dilated pooling
Expand Down Expand Up @@ -143,10 +214,16 @@ void rewrite_pooling::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "pooling")
continue;
if(ins->inputs().empty())
continue;
if(ins->name() == "lrn")
{
lower_lrn_to_pooling(m, ins);
continue;
}
if(ins->name() != "pooling")
continue;

auto&& s = ins->inputs().front()->get_shape();
auto&& op = any_cast<op::pooling>(ins->get_operator());
bool same_kernel_as_shape = std::equal(
Expand Down
60 changes: 60 additions & 0 deletions test/rewrite_pooling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

#include <migraphx/verify.hpp>

#include <migraphx/iterator.hpp>

static void opt_pooling(migraphx::module& m)
{
migraphx::rewrite_pooling rp;
Expand Down Expand Up @@ -309,6 +311,64 @@ TEST_CASE(rewrite_pooling_dialtions_test5)
test_rewrite(migraphx::op::pooling_mode::max);
}

TEST_CASE(lower_lrn_to_pooling)
{
migraphx::shape s{migraphx::shape::float_type, {1, 64, 55, 55}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto lrn = m1.add_instruction(
migraphx::make_op("lrn",
{{"alpha", 0.0001}, {"beta", 0.75}, {"bias", 1.0}, {"size", 5}}),
x);
m1.add_return({lrn});
}
opt_pooling(m1);

migraphx::module m2;
{
auto x = m2.add_parameter("x", s);

auto x_squared = m2.add_instruction(migraphx::make_op("mul"), x, x);
auto transpose1 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), x_squared);
auto reshape1 = m2.add_instruction(
migraphx::make_op("reshape", {{"dims", {3025, 1, 1, 64}}}), transpose1);
auto pooling =
m2.add_instruction(migraphx::make_op("pooling",
{{"mode", migraphx::op::pooling_mode::average},
{"lengths", std::vector<int64_t>{1, 5}},
{"stride", std::vector<int64_t>{1, 1}},
{"padding", std::vector<int64_t>{0, 2}},
{"count_include_pad", true}}),
reshape1);
auto reshape2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 55, 55, 64}}}), pooling);
auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 2, 1}}}), reshape2);

auto beta_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.75}});
auto alpha_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {0.0001}});
auto bias_lit = m2.add_literal(migraphx::literal{migraphx::shape::float_type, {1.0}});

auto bias_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), bias_lit);
auto alpha_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), alpha_lit);
auto beta_mb = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 55, 55}}}), beta_lit);

auto alpha_avg = m2.add_instruction(migraphx::make_op("mul"), alpha_mb, transpose2);
auto denominator = m2.add_instruction(migraphx::make_op("add"), bias_mb, alpha_avg);
auto powered = m2.add_instruction(migraphx::make_op("pow"), denominator, beta_mb);
auto result = m2.add_instruction(migraphx::make_op("div"), x, powered);

m2.add_return({result});
}

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(rewrite_avgpool_rank3_dil_test)
{
// 1D case 1, input is 3D
Expand Down
Loading