Skip to content
24 changes: 14 additions & 10 deletions src/include/migraphx/raw_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ struct raw_data : raw_data_base
friend Stream& operator<<(Stream& os, const Derived& d)
{
if(not d.empty())
d.visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
d.fallback_visit([&](auto x) { os << x; },
[&](auto&& xs) {
for(auto&& x : xs)
{
os << "{ ";
os << x;
os << " }, ";
}
});
return os;
}

Expand Down Expand Up @@ -125,7 +125,11 @@ struct raw_data : raw_data_base
{
auto&& buffer = static_cast<const Derived&>(*this).data();
shape view_shape = {shape::uint8_type, {s.bytes()}};
v(make_view(view_shape, reinterpret_cast<byte*>(buffer)));
using byte_type = std::conditional_t<
std::is_const_v<std::remove_pointer_t<std::remove_reference_t<decltype(buffer)>>>,
const byte*,
byte*>;
v(make_view(view_shape, reinterpret_cast<byte_type>(buffer)));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/propagate_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static bool skip_propagate(instruction_ref ins)
{
if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name()))
return skip_propagate(ins->inputs().front());
if(ins->name() == "unpack_int4")
if(contains({"unpack_int4", "unpack_fp4"}, ins->name()))
return true;
auto&& s = ins->get_shape();
if(s.broadcasted() and s.element_space() < s.elements())
Expand Down
59 changes: 36 additions & 23 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,36 @@
return s;
}

// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
auto propagate_quantized_ins(module& m,
const instruction_ref dqins,
const instruction_ref qop_arg,
bool is_fp16_model = false)
std::vector<instruction_ref> get_inbetween_ins(const instruction_ref dqins,

Check warning on line 52 in src/simplify_qdq.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] src/simplify_qdq.cpp#L52

"inbetween" is a misspelling of "between"
Raw output
./src/simplify_qdq.cpp:52:33: "inbetween" is a misspelling of "between"
const instruction_ref qop_arg)
{
auto prev_ins = qop_arg;
std::vector<instruction_ref> ins_between;
// matcher skips contiguous, multi/broadcasts and transposes, collect all those
// instructions
while(prev_ins != dqins)
{
ins_between.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto qinp = dqins->inputs().front();
return ins_between;
}

// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
auto propagate_quantized_ins(module& m,
const instruction_ref dqins,
instruction_ref input_ins,
std::vector<instruction_ref> ins_between,
bool is_fp16_model = false)
{
for(auto ins : reverse_iterator_for(ins_between))
{
if((*ins)->name() == "convert" and is_fp16_model)
{
continue;
}
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
input_ins = m.insert_instruction(dqins, (*ins)->get_operator(), {input_ins});
}
return qinp;
return input_ins;
}

struct match_find_quantizable_ops
Expand Down Expand Up @@ -140,8 +144,13 @@
assert(dq1->get_shape().type() == migraphx::shape::float_type);
is_fp16_model = true;
}
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model);

auto qop_between_arg0 = get_inbetween_ins(dq1, qop_args[0]);

Check warning on line 148 in src/simplify_qdq.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] src/simplify_qdq.cpp#L148

"inbetween" is a misspelling of "between"
Raw output
./src/simplify_qdq.cpp:148:36: "inbetween" is a misspelling of "between"
auto qop_between_arg1 = get_inbetween_ins(dq2, qop_args[1]);

Check warning on line 149 in src/simplify_qdq.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] src/simplify_qdq.cpp#L149

"inbetween" is a misspelling of "between"
Raw output
./src/simplify_qdq.cpp:149:36: "inbetween" is a misspelling of "between"
qop_args.at(0) =
propagate_quantized_ins(m, dq1, qop_args[0], qop_between_arg0, is_fp16_model);
qop_args.at(1) =
propagate_quantized_ins(m, dq2, qop_args[1], qop_between_arg1, is_fp16_model);
auto arg1_lens = qop_args[0]->get_shape().lens();
auto arg2_lens = qop_args[1]->get_shape().lens();

Expand Down Expand Up @@ -280,15 +289,13 @@
}
};

// Note: scales are not constant b/c of dynamic quantization.
// Checks for block quantized scales by checking scales are not scalar or 1D.
inline auto dynamic_block_dq(const std::string& scale)
inline auto block_dq(const std::string& scale)
{
// clang-format off
return match::name("dequantizelinear")(
match::nargs(2),
match::arg(1)(match::skip_broadcasts(match::none_of(
match::is_constant(),
match::scalar_shape,
match::ndim(1)
).bind(scale))));
Expand All @@ -305,9 +312,9 @@
{
auto matcher() const
{
auto dq1 = match::arg(0)(skip_post_dq_ops(dynamic_block_dq("scale1").bind("dq1")));
auto dq2 = match::arg(1)(skip_post_dq_ops(dynamic_block_dq("scale2").bind("dq2")));
return match::name("dot")(dq1, dq2);
auto dq1 = match::arg(0)(skip_post_dq_ops(block_dq("scale1").bind("dq1")));
auto dq2 = match::arg(1)(skip_post_dq_ops(block_dq("scale2").bind("dq2")));
return match::name(get_quantizable_op_names())(dq1, dq2);
}

void apply(module& m, const match::matcher_result& r) const
Expand All @@ -328,10 +335,16 @@
assert(dq1->get_shape().type() == migraphx::shape::float_type);
is_fp16_model = true;
}
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model);
qop_args.push_back(scale1);
qop_args.push_back(scale2);
auto qop_between_arg0 = get_inbetween_ins(dq1, qop_args[0]);

Check warning on line 338 in src/simplify_qdq.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] src/simplify_qdq.cpp#L338

"inbetween" is a misspelling of "between"
Raw output
./src/simplify_qdq.cpp:338:36: "inbetween" is a misspelling of "between"
qop_args.at(0) =
propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model);
auto qop_between_arg1 = get_inbetween_ins(dq2, qop_args[1]);

Check warning on line 341 in src/simplify_qdq.cpp

View workflow job for this annotation

GitHub Actions / misspell

[misspell] src/simplify_qdq.cpp#L341

"inbetween" is a misspelling of "between"
Raw output
./src/simplify_qdq.cpp:341:36: "inbetween" is a misspelling of "between"
qop_args.at(1) =
propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model);
qop_args.push_back(
propagate_quantized_ins(m, dq1, scale1, qop_between_arg0, is_fp16_model));
qop_args.push_back(
propagate_quantized_ins(m, dq2, scale2, qop_between_arg1, is_fp16_model));

if(qop->name() == "convolution")
{
Expand Down
27 changes: 27 additions & 0 deletions test/propagate_constant_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,4 +535,31 @@ TEST_CASE(block_dequantize_int4)
EXPECT(m1.sort() == m2.sort());
}

// A constant float literal that is packed with pack_fp4 and immediately
// unpacked with unpack_fp4: after the pass the module is expected to hold the
// packed fp4x2 literal with the unpack_fp4 instruction still applied to it.
TEST_CASE(pack_unpack_fp4)
{
    migraphx::shape s1{migraphx::shape::float_type, {4}};
    migraphx::shape s2{migraphx::shape::fp4x2_type, {2}};

    // Input module: literal -> pack_fp4 -> unpack_fp4 -> return
    migraphx::module m1;
    {
        const std::vector<float> vec = {1.f, 0.f, 2.f, 0.f};
        auto l      = m1.add_literal(migraphx::literal(s1, vec));
        auto pack   = m1.add_instruction(migraphx::make_op("pack_fp4"), l);
        auto unpack = m1.add_instruction(migraphx::make_op("unpack_fp4"), pack);
        m1.add_return({unpack});
    }

    run_pass(m1);

    // Expected module: packed literal -> unpack_fp4 -> return.
    // {0x2, 0x4} are the packed bytes this test expects for {1, 0, 2, 0}.
    migraphx::module m2;
    {
        const std::vector<uint8_t> vec = {0x2, 0x4};
        auto l      = m2.add_literal(migraphx::literal(s2, vec.data()));
        auto unpack = m2.add_instruction(migraphx::make_op("unpack_fp4"), l);
        m2.add_return({unpack});
    }

    EXPECT(m1 == m2);
}

// Entry point of the test binary: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
98 changes: 98 additions & 0 deletions test/simplify_qdq_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,104 @@
EXPECT(m1 == m2);
}

// dot whose inputs are both block-dequantized fp4 tensors, with a transpose
// sitting between the B-side dequantizelinear and the dot. The expected result
// is a quant_dot that takes the unpacked inputs and their scales directly, with
// the same transpose applied to both the unpacked B data and the B scales.
// (Name is all lower-case to satisfy readability-identifier-naming.)
TEST_CASE(fp4x2_quant_dot_trans_b)
{
    migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}};
    migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 8, 12}};
    migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}};
    migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 8, 24}};

    // Input module: unpack_fp4 -> dequantizelinear on both sides, transpose on B, then dot.
    migraphx::module m1;
    {
        auto packed_a = m1.add_parameter("input", shape_packed_a);
        auto packed_b = m1.add_parameter("weights", shape_packed_b);
        auto scale_a  = m1.add_parameter("scale_a", shape_scales_a);
        auto scale_b  = m1.add_parameter("scale_b", shape_scales_b);

        auto unpack_a =
            m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a);
        auto unpack_b =
            m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b);
        auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a);
        auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b);
        auto trans_b = m1.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dq_b);
        auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, trans_b);
        m1.add_return({dot});
    }

    // Expected module: no dequantizelinear; the transpose is applied to the
    // unpacked B data and, separately, to the B scales before quant_dot.
    migraphx::module m2;
    {
        auto packed_a = m2.add_parameter("input", shape_packed_a);
        auto packed_b = m2.add_parameter("weights", shape_packed_b);
        auto scale_a  = m2.add_parameter("scale_a", shape_scales_a);
        auto scale_b  = m2.add_parameter("scale_b", shape_scales_b);

        auto unpack_a =
            m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a);
        auto unpack_b =
            m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b);
        auto trans_b = m2.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unpack_b);
        auto trans_scale_b = m2.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), scale_b);
        auto quant_dot = m2.add_instruction(
            migraphx::make_op("quant_dot"), unpack_a, trans_b, scale_a, trans_scale_b);
        m2.add_return({quant_dot});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}

// dot whose B input and B scales are constant literals rather than parameters.
// The expected result is a quant_dot that takes the unpacked inputs and the
// scales directly, with the dequantizelinear instructions gone.
// (Name is all lower-case to satisfy readability-identifier-naming.)
TEST_CASE(fp4x2_quant_dot_const_b)
{
    migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}};
    migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 24, 4}};
    // Same dimensions as shape_packed_b but typed uint8, used only to generate the data.
    migraphx::shape shape_packed_b_gen{migraphx::shape::uint8_type, {1, 3, 24, 4}};
    migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}};
    migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 24, 8}};
    unsigned long seed            = 826;
    migraphx::literal b_lit       = generate_literal(shape_packed_b_gen, seed);
    migraphx::literal scale_b_lit = generate_literal(shape_scales_b, seed);

    // Input module: unpack_fp4 -> dequantizelinear on both sides, then dot.
    migraphx::module m1;
    {
        auto packed_a = m1.add_parameter("input", shape_packed_a);
        // The data was generated as uint8 and is reinterpreted here through the
        // fp4x2 shape, which avoids having to visit fp4x2_type when generating it.
        auto packed_b = m1.add_literal(shape_packed_b, b_lit.data());
        auto scale_a  = m1.add_parameter("scale_a", shape_scales_a);
        auto scale_b  = m1.add_literal(scale_b_lit);

        auto unpack_a =
            m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a);
        auto unpack_b =
            m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b);
        auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a);
        auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b);
        auto dot  = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b);
        m1.add_return({dot});
    }

    // Expected module: quant_dot fed by the unpacked inputs and the scales.
    migraphx::module m2;
    {
        auto packed_a = m2.add_parameter("input", shape_packed_a);
        auto packed_b = m2.add_literal(shape_packed_b, b_lit.data());
        auto scale_a  = m2.add_parameter("scale_a", shape_scales_a);
        auto scale_b  = m2.add_literal(scale_b_lit);

        auto unpack_a =
            m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a);
        auto unpack_b =
            m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b);
        auto quant_dot = m2.add_instruction(
            migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b);
        m2.add_return({quant_dot});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}

// Test that unused qdq with pack_fp4, unpack_fp4 are removed
TEST_CASE(fp4x2_even_remove_qdq)
{
Expand Down
Loading