Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/onnx/onnx_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ shape::type_t get_type(int dtype)
case 18: return shape::fp8e4m3fnuz_type;
case 21: return shape::uint8_type;
case 22: return shape::int8_type;
case 23: return shape::fp4x2_type;
case 14:
case 15:
case 16: return shape::bf16_type;
Expand Down
139 changes: 139 additions & 0 deletions src/onnx/parse_dynamicscale.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/**
* Operator from Brevitas to calculate dynamic quantization scales.
*/
struct parse_dynamicscale : op_parser<parse_dynamicscale>
{
    std::vector<op_desc> operators() const { return {{"DynamicScale"}}; }

    /**
     * Parse a Brevitas DynamicScale node into instructions computing per-block
     * quantization scales for MXFP4 (fp4e2m1) block quantization.
     *
     * Supported configuration (anything else throws):
     *   group_size == 32, output_dtype == fp4x2, scale_selection_method == "floor",
     *   and no zero point.
     *
     * Returns the instruction producing the block scales, with the reduction
     * axis squeezed out so the result plugs into block-quantized quantizelinear.
     */
    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        // Validate the argument count BEFORE dereferencing args.front();
        // calling front() on an empty vector is undefined behavior.
        if(args.size() != 1)
        {
            MIGRAPHX_THROW("DynamicScale: must have only 1 input");
        }
        const instruction_ref input = args.front();
        instruction_ref tmp_in      = input;
        const auto input_lens       = input->get_shape().lens();

        int block_axis = info.attributes.at("group_dim").i();
        block_axis     = tune_axis(input->get_shape().ndim(), block_axis, "DynamicScale");
        int block_size = info.attributes.at("group_size").i();
        if(block_size != 32)
        {
            MIGRAPHX_THROW("DynamicScale: only group_size of 32 is supported");
        }
        migraphx::shape::type_t output_type = get_type(info.attributes.at("output_dtype").i());

        // TODO expand this to handle other MX types
        if(output_type != migraphx::shape::fp4x2_type)
        {
            MIGRAPHX_THROW("DynamicScale: only support MXFP4 type");
        }

        std::string scale_selection_method = info.attributes.at("scale_selection_method").s();
        if(scale_selection_method != "floor")
        {
            MIGRAPHX_THROW("DynamicScale: only support floor scale selection");
        }

        std::string zero_point_selection_method = "None";
        if(contains(info.attributes, "zero_point_selection_method"))
            zero_point_selection_method =
                info.attributes.at("zero_point_selection_method").s();

        if(zero_point_selection_method != "None")
        {
            MIGRAPHX_THROW("DynamicScale: zero_point not supported");
        }

        // make reduction axes for calculating block scales
        // tmp_lens != input_lens if runt block is padded
        auto tmp_lens  = input_lens;
        auto block_dim = tmp_lens.at(block_axis);
        // Amount needed to round block_dim up to a multiple of block_size.
        // Pure integer arithmetic: exact, and avoids a <cmath>/std::ceil dependency.
        const std::size_t block_padding =
            (std::size_t(block_size) - block_dim % std::size_t(block_size)) %
            std::size_t(block_size);
        // handle runt block by padding
        if(block_padding != 0)
        {
            // pad op pads format is {begin_0..begin_n-1, end_0..end_n-1};
            // only pad the end of the block axis
            std::vector<std::size_t> pads_vec(2 * tmp_lens.size(), 0);
            pads_vec.at(block_axis + tmp_lens.size()) = block_padding;
            tmp_in   = info.add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in);
            tmp_lens = tmp_in->get_shape().lens();
        }
        // reshape block dimension to {num_blocks, block_size}
        std::size_t num_blocks               = tmp_lens.at(block_axis) / std::size_t(block_size);
        std::vector<std::size_t> reduct_dims = tmp_lens;
        reduct_dims.at(block_axis)           = block_size;
        reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks);
        instruction_ref reshape_ins =
            info.add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in);

        // dynamic quantization for MX types:
        // V_k = fp32 vector input of block size k
        // B_k = pow(2, floor(log2(reduce_max(abs(V_k))))) # largest power of 2 less than V
        // X_k = block scale k = B_k / (largest power of 2 in fp4e2m1) = B_k / 4
        auto abs_ins = info.add_instruction(make_op("abs"), reshape_ins);
        auto reduce_max_ins =
            info.add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins);
        auto log2_ins  = info.add_instruction(make_op("log2"), reduce_max_ins);
        auto floor_ins = info.add_instruction(make_op("floor"), log2_ins);
        auto lit_2_ins = info.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}});
        auto broadcast_lit_2_ins = info.add_instruction(
            make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}),
            lit_2_ins);
        auto pow_ins   = info.add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins);
        auto lit_4_ins = info.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}});
        auto broadcast_lit_4_ins = info.add_instruction(
            make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}),
            lit_4_ins);
        auto block_scales_ins = info.add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins);

        // squeeze reduction axis for use in block quantized quantizelinear
        block_scales_ins = info.add_instruction(make_op("squeeze", {{"axes", {block_axis + 1}}}),
                                                block_scales_ins);

        return block_scales_ins;
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
42 changes: 41 additions & 1 deletion src/onnx/parse_quantizelinear.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -91,6 +91,46 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
args = transform_quantize_dequantize_linear_inputs(
info, opd.onnx_name, block_size, axis, args);

if(output_type == migraphx::shape::fp4x2_type)
{
// Parsing in pack_fp4 and unpack_fp4 for the FP4 case
auto q_ins = info.add_instruction(
make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), args);

// packing axis set to fastest dimension
auto quantized_shape = q_ins->get_shape();
const auto& qs_strides = quantized_shape.strides();
if(qs_strides.empty())
{
MIGRAPHX_THROW("QuantizeLinear: MX type quantized_shape has no strides");
}
int fast_axis =
std::min_element(qs_strides.cbegin(), qs_strides.cend()) - qs_strides.cbegin();
bool odd_fast_axis = (quantized_shape.lens().at(fast_axis) % 2 == 1);
if(odd_fast_axis)
{
// pad fastest dimension by 1 if it is odd
std::vector<int64_t> padding(2 * quantized_shape.ndim(), 0);
padding.at(fast_axis * 2 + 1) = 1;
q_ins = info.add_instruction(make_op("pad", {{"pads", padding}}), q_ins);
}
auto pack_ins = info.add_instruction(make_op("pack_fp4", {{"axis", fast_axis}}),
q_ins); // output is fp4x2_type
auto unpack_ins = info.add_instruction(make_op("unpack_fp4", {{"axis", fast_axis}}),
pack_ins); // output is fp8e4m3fn_type
if(odd_fast_axis)
{
// slice off padded values
unpack_ins = info.add_instruction(
make_op("slice",
{{"axes", {fast_axis}},
{"starts", {0}},
{"ends", {quantized_shape.lens().at(fast_axis)}}}),
unpack_ins);
}
return unpack_ins;
}

if(parser.opset_version < 19)
{
auto common_type = common_shape({args[0]->get_shape(), args[1]->get_shape()}).type();
Expand Down
21 changes: 21 additions & 0 deletions test/onnx/dynamicscale_even_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
 dynamicscale_even_test:
£
inputoutput" DynamicScale*
group_dim *

group_size  *
output_dtype *"
scale_selection_method"floor *&
zero_point_selection_method"None dynamicscale_even_testZ
input


@

b
output


@

B
Binary file added test/onnx/dynamicscale_odd_test.onnx
Binary file not shown.
17 changes: 17 additions & 0 deletions test/onnx/dynamicscale_small_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
 dynamicscale_small_test:û
¬
inputoutput" DynamicScale*
group_dimÿÿÿÿÿÿÿÿÿ *

group_size  *
output_dtype *"
scale_selection_method"floor *&
zero_point_selection_method"None dynamicscale_small_testZ
input


b
output


B
81 changes: 81 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9735,6 +9735,51 @@ def mxfixneuron_small_test():
return ([node], [in_tv], [out_tv])


@onnx_test()
def dynamicscale_even_test():
    # DynamicScale over axis 1 (length 64 -> two blocks of 32), MXFP4 output.
    node = onnx.helper.make_node('DynamicScale',
                                 inputs=['input'],
                                 outputs=['output'],
                                 group_dim=1,
                                 group_size=32,
                                 output_dtype=23,
                                 scale_selection_method='floor',
                                 zero_point_selection_method='None')
    input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT,
                                               [3, 64, 4, 4])
    output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                                [3, 64, 4, 4])
    return ([node], [input_info], [output_info])


@onnx_test()
def dynamicscale_odd_test():
    # Axis 0 has length 71, not a multiple of 32, exercising runt-block padding.
    node = onnx.helper.make_node('DynamicScale',
                                 inputs=['input'],
                                 outputs=['output'],
                                 group_dim=0,
                                 group_size=32,
                                 output_dtype=23,
                                 scale_selection_method='floor',
                                 zero_point_selection_method='None')
    input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT,
                                               [71, 5, 5])
    output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                                [71, 5, 5])
    return ([node], [input_info], [output_info])


@onnx_test()
def dynamicscale_small_test():
    # Negative group_dim (-1) exercises axis normalization; dim 4 < group_size.
    node = onnx.helper.make_node('DynamicScale',
                                 inputs=['input'],
                                 outputs=['output'],
                                 group_dim=-1,
                                 group_size=32,
                                 output_dtype=23,
                                 scale_selection_method='floor',
                                 zero_point_selection_method='None')
    input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT,
                                               [4, 4])
    output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                                [4, 4])
    return ([node], [input_info], [output_info])


@onnx_test()
def neg_test():
x = helper.make_tensor_value_info('0', TensorProto.INT64, [2, 3])
Expand Down Expand Up @@ -11418,6 +11463,42 @@ def quantizelinear_neg_axis_test():
return make_quantizelinear_axis_graph(-2)


@onnx_test()
def quantizelinear_mxfp4_even_test():
    # Block-quantized QuantizeLinear to FLOAT4E2M1 (dtype 23); even fast axis.
    node = onnx.helper.make_node('QuantizeLinear',
                                 inputs=['0', '1'],
                                 outputs=['out'],
                                 axis=1,
                                 block_size=32,
                                 output_dtype=23)
    data = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 64, 4, 4])
    scales = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 4, 4])
    result = helper.make_tensor_value_info('out', TensorProto.FLOAT4E2M1,
                                           [3, 64, 4, 4])
    return ([node], [data, scales], [result])

@onnx_test()
def quantizelinear_mxfp4_odd_test():
    # Odd fastest dimension (7) exercises the pad/slice path around fp4 packing.
    node = onnx.helper.make_node('QuantizeLinear',
                                 inputs=['0', '1'],
                                 outputs=['out'],
                                 axis=1,
                                 block_size=32,
                                 output_dtype=23)
    data = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 64, 4, 7])
    scales = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 2, 4, 7])
    result = helper.make_tensor_value_info('out', TensorProto.FLOAT4E2M1,
                                           [3, 64, 4, 7])
    return ([node], [data, scales], [result])



@onnx_test()
def randomnormal_test():
dtype = 11
Expand Down
Loading
Loading