Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions onnxruntime/contrib_ops/cpu/moe/moe_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters,
const Tensor* fc3_experts_bias, // optional
const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE
const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8)
const bool is_fused_swiglu) {
const bool is_fused_swiglu,
const int64_t block_size = 0) { // block size for block-wise quantization
// Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
ASSERT_TENSOR_2D_OR_3D(input);
ASSERT_TENSOR_3D(fc1_experts_weights);
Expand Down Expand Up @@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters,
CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size);
CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size);

CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
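// Validate scale tensors: accept both row-wise and block-wise quantization layouts.
// The layout is detected from fc1_experts_scales only (3D with last dimension > 1 means block-wise);
// the fc2/fc3 scale tensors are then validated against that same layout.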
bool is_row_wise_quantization = true;
if (fc1_experts_scales != nullptr) {
const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) {
is_row_wise_quantization = false;
}
}

if (block_size > 0 && !is_row_wise_quantization) {
// Block-wise quantization: 3D scale tensors
// For block-wise quantization, we calculate the number of blocks using ceiling division
// to handle cases where the dimension is not perfectly divisible by block_size
const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size;
const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size;
const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size;

CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row);
CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row);
CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row);
} else {
// Row-wise quantization: 2D scale tensors or 3D with last dimension = 1
// Handle both {num_experts, features} and {num_experts, features, 1} shapes
if (fc1_experts_scales != nullptr) {
const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims();
if (fc1_scales_dims.size() == 2) {
CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size);
} else if (fc1_scales_dims.size() == 3) {
CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1);
} else {
ORT_THROW("fc1_experts_scales must be 2D or 3D tensor");
}
}

if (fc2_experts_scales != nullptr) {
const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims();
if (fc2_scales_dims.size() == 2) {
CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size);
} else if (fc2_scales_dims.size() == 3) {
CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1);
} else {
ORT_THROW("fc2_experts_scales must be 2D or 3D tensor");
}
}

if (fc3_experts_scales != nullptr) {
const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims();
if (fc3_scales_dims.size() == 2) {
CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size);
} else if (fc3_scales_dims.size() == 3) {
CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1);
} else {
ORT_THROW("fc3_experts_scales must be 2D or 3D tensor");
}
}
}

if (fc3_experts_weights == nullptr) {
ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr);
Expand Down
Loading
Loading