
Optimizations and fixes in QMoE CPU kernel #25642

Open

wants to merge 8 commits into main

Conversation

apsonawane
Contributor

This pull request focuses on optimizing the SwiGLU activation implementation in the MoE (Mixture of Experts) module and updating corresponding test cases to reflect these changes. The most important changes include performance improvements in the SwiGLU activation, removal of support for non-interleaved formats, and updates to test cases for quantized weights to ensure consistency.

SwiGLU Activation Optimizations:

  • Optimized the interleaved-format processing in ApplySwiGLUActivation by introducing vectorized computations, faster clamping, and more efficient memory access. Removed support for non-interleaved formats, replacing it with an ORT_NOT_IMPLEMENTED error. (onnxruntime/contrib_ops/cpu/moe/moe_utils.cc, L29-R64)
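For reference, a minimal scalar sketch of what SwiGLU over an interleaved layout can look like. The (gate, linear) ordering and the alpha/limit constants below are placeholder assumptions (the exact variant used by the kernel may differ), and the actual kernel vectorizes this loop:

#include <algorithm>
#include <cmath>
#include <cstddef>

// Sketch only: SwiGLU over interleaved (gate, linear) pairs, compacting the
// result to one value per pair. Layout and constants are assumptions; the
// real ApplySwiGLUActivation uses its own vectorized path and clamp limits.
void ApplySwiGLUInterleavedSketch(float* data, std::size_t num_pairs,
                                  float alpha = 1.702f, float limit = 7.0f) {
  for (std::size_t i = 0; i < num_pairs; ++i) {
    const float gate = std::min(data[2 * i], limit);                  // clamp gate before exp
    const float linear = std::clamp(data[2 * i + 1], -limit, limit);  // clamp linear path
    const float swish = gate / (1.0f + std::exp(-alpha * gate));      // gate * sigmoid(alpha * gate)
    data[i] = swish * linear;  // compacted in place; writes never overtake the reads (i <= 2 * i)
  }
}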

// Generate test weights for symmetric quantization (zero point is 0)
std::vector<uint8_t> fc1_experts_weights(num_experts * fc1_weight_size_per_expert, 0x12); // 1,2 -> small positive weights
std::vector<uint8_t> fc2_experts_weights(num_experts * fc2_weight_size_per_expert, 0xFF); // -1,0 -> small mixed weights
// Generate test weights for symmetric quantization (zero point is 8 for 4-bit)
Contributor

zero point is 0 for symmetric quantization?

Contributor Author

Yes, it is 0.
The original range is -8 to 7; real zero maps to the quantized value 8, so the effective zero point is 0.
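For clarity, a small sketch of the mapping being described, assuming two 4-bit values packed per byte (low nibble first) and a per-block scale:

#include <cstdint>

// Symmetric 4-bit dequantization sketch: stored nibbles are in [0, 15] and
// real zero maps to the stored value 8, so dequantized = (q - 8) * scale,
// covering the range [-8, 7] * scale. The low-nibble-first packing order is
// an assumption for illustration.
inline float DequantizeInt4Symmetric(uint8_t packed, int nibble_index, float scale) {
  const int q = (nibble_index == 0) ? (packed & 0x0F) : (packed >> 4);
  return static_cast<float>(q - 8) * scale;
}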

ort_dtype_quant_bits_tolerance_map = {
"FP32:0": (5e-3, 1e-3),
"FP16:0": (5e-2, 1e-3),
"FP16:4": (2.0, 8e-3), # Improved tolerance with symmetric quantization
"FP16:8": (1.5, 8e-3), # Improved tolerance with symmetric quantization
"FP16:4": (8.0, 0.15), # 4-bit quantization error tolerance - improved with bug fixes
Contributor

Do we need to lower the thresholds?

onnx_dtype=self.onnx_dtype,
fc1_experts_weights=self.moe_experts_weight1,
fc2_experts_weights=self.moe_experts_weight2,
# Biases are not used in QMoE
Contributor

This comment is not true. The kernel supports biases, but this test does not use them.

@@ -959,7 +1186,7 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, use

# Skip if the session creation failed
if phi3_moe.ort_sess is None:
self.skipTest("Failed to create ONNX Runtime session - CPU MoE operator not available")
self.skipTest("Failed to create ONNX Runtime session")
Contributor

Do not skip the test for this case. Instead, throw an exception.

Comment on lines 229 to 238
// Use arena allocator for better memory management and reduced fragmentation
// This is especially beneficial for repeated kernel invocations
AllocatorPtr arena_allocator;
if (context->GetUseDeterministicCompute()) {
// For deterministic compute, use the standard temp allocator
arena_allocator = allocator;
} else {
// Try to get arena allocator for better performance
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&arena_allocator));
}
Contributor

I do not think using the arena will make a difference for deterministic compute.
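If the arena/non-arena split buys nothing, the branch could collapse to the standard temp-space path, roughly like this (a sketch of the simplification, not the PR's final code):

// Always request the temp-space allocator from the kernel context; whether it
// is arena-backed is decided by the session configuration, not by the
// deterministic-compute flag.
AllocatorPtr temp_allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&temp_allocator));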

router_probs_float_ptr = unified_conversion_buffer.get() + input_size;

// Set up smart pointers with custom deleters to avoid double-free
input_float = IAllocatorUniquePtr<float>(input_float_ptr, [](float*) {});
Contributor

@tianleiwu Aug 3, 2025

unified_conversion_buffer is already a smart pointer. Why wrap its memory in another smart pointer here?


// Set up smart pointers with custom deleters to avoid double-free
input_float = IAllocatorUniquePtr<float>(input_float_ptr, [](float*) {});
router_probs_float = IAllocatorUniquePtr<float>(router_probs_float_ptr, [](float*) {});
Contributor

@tianleiwu Aug 4, 2025

Same here.
Please make sure this code path is tested.
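One possible shape for what is being asked here: keep the single owning buffer and use plain offset pointers into it, instead of wrapping the offsets in extra smart pointers with no-op deleters (sketch only):

// unified_conversion_buffer already owns the memory, so raw offset pointers
// are enough for the sub-regions; no second smart pointer is needed.
auto unified_conversion_buffer =
    IAllocator::MakeUniquePtr<float>(arena_allocator, input_size + router_probs_size);
float* input_float_ptr = unified_conversion_buffer.get();
float* router_probs_float_ptr = unified_conversion_buffer.get() + input_size;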

const size_t total_conversion_size = input_size + router_probs_size;

// Single allocation for input and router_probs conversion
auto unified_conversion_buffer = IAllocator::MakeUniquePtr<float>(arena_allocator, total_conversion_size);
Contributor

This smart pointer will be deleted at the end of the if scope. I think it needs to be moved to the outer scope, since it must stay alive until the computation is done.
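One way to address the lifetime issue is to hoist the owning pointer out of the conditional block; a sketch, with a hypothetical condition name:

// Declared in the enclosing scope so the buffer outlives the if block and
// stays valid until the computation finishes.
IAllocatorUniquePtr<float> unified_conversion_buffer;
if (needs_float_conversion) {  // hypothetical condition, for illustration only
  unified_conversion_buffer =
      IAllocator::MakeUniquePtr<float>(arena_allocator, input_size + router_probs_size);
  // ... convert input and router_probs into the buffer ...
}
// unified_conversion_buffer is still alive here for the rest of the computation.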

if (is_decoding_scenario) {
// Decoding scenario: partition work by tokens * experts for better parallelization
// This allows multiple threads to work on different experts for the same token
double cost_per_token_expert = static_cast<double>(moe_params.hidden_size * moe_params.inter_size * 0.6); // Cost per token-expert pair
Contributor

why * 0.6 here?

Contributor Author

@apsonawane Aug 4, 2025

It should be 2, for the 2 GEMM operations.
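With that correction, the per-unit cost estimate would be roughly (sketch; only the constant changes):

// Two GEMMs per token-expert pair: FC1 (hidden_size x inter_size) and
// FC2 (inter_size x hidden_size), i.e. about 2 * hidden_size * inter_size MACs.
double cost_per_token_expert =
    2.0 * static_cast<double>(moe_params.hidden_size) * static_cast<double>(moe_params.inter_size);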

expert_idx, routing_weight, moe_params, is_swiglu,
dequant_fc1_weights, dequant_fc2_weights,
fc1_bias_float_ptr, fc2_bias_float_ptr, fc1_output_size); // Atomically accumulate results to avoid race conditions
float* token_result = output_float_ptr + token_idx * moe_params.hidden_size;
Contributor

Data race bug.

For instance, Thread 1 might handle (token_idx=0, expert_idx=0) while Thread 2 handles (token_idx=0, expert_idx=1). Both threads will eventually attempt to write to the same memory location (output_float_ptr + token_idx * hidden_size) without any synchronization.
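One way to remove the race, as a sketch: give each thread a private accumulation buffer of size num_tokens * hidden_size and reduce the buffers after the parallel section (alternatively, partition the parallel loop by token so each output row has exactly one writer). The names below are illustrative, not the PR's actual symbols:

#include <algorithm>
#include <cstddef>
#include <vector>

// Each worker adds routing_weight * expert_output into its own buffer during
// the parallel phase; this reduction then combines the buffers single-threaded,
// so only this loop ever writes to the shared output.
void ReduceThreadLocalOutputs(const std::vector<std::vector<float>>& per_thread_output,
                              float* output, std::size_t num_tokens, std::size_t hidden_size) {
  const std::size_t total = num_tokens * hidden_size;
  std::fill(output, output + total, 0.0f);
  for (const auto& local : per_thread_output) {
    for (std::size_t i = 0; i < total; ++i) {
      output[i] += local[i];
    }
  }
}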

@tianleiwu
Contributor

PrepackAndDequantizeWeights should go through the standard PrePack override functions, like:

Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) override;
Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) override;

It has the ability to release the existing (quantized) weights by setting is_packed=true, which reduces memory.
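A rough shape of what that could look like on the kernel class (a sketch; the member names and the PrepackAndDequantizeWeights signature used here are assumptions):

// Sketch: route weight dequantization through the standard PrePack hook.
// Setting is_packed = true tells ORT it may release the original quantized
// initializer, so only the prepacked copy stays resident.
Status QMoE::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
                     /*out*/ bool& is_packed,
                     /*out*/ PrePackedWeights* prepacked_weights) {
  is_packed = false;
  if (input_idx == fc1_weights_input_index_ || input_idx == fc2_weights_input_index_) {  // hypothetical members
    // Dequantize/repack `tensor` into memory owned by `alloc`; optionally
    // publish the buffer via `prepacked_weights` so sessions can share it.
    ORT_RETURN_IF_ERROR(PrepackAndDequantizeWeights(tensor, input_idx, alloc, prepacked_weights));
    is_packed = true;  // the original quantized weight tensor can now be dropped
  }
  return Status::OK();
}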

@tianleiwu
Contributor

Performance Improvement Suggestion: Batch GEMM Operations Instead of Iterative Calls
The most significant performance bottleneck is calling MlasGemm for each token-expert pair individually (with a batch size M=1). This leads to high overhead and underutilizes the CPU.

Problem: In QuantizedMoEImpl, the code loops through tokens and then through experts, performing a small matrix multiplication for each. This is highly inefficient.

Recommendation: Refactor the logic to use a batched GEMM approach. This involves a "token shuffling" strategy:

1. First, iterate through all tokens to determine which expert each token is routed to.
2. Group the tokens by their assigned expert.
3. For each expert, create a single batch of all tokens assigned to it.
4. Perform one large, efficient MlasGemm operation for each expert on its batch of tokens.
5. "Un-shuffle" the results back to their original token order.
This approach dramatically improves computational efficiency and is standard practice in high-performance MoE implementations.
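A sketch of the grouping step only; the per-expert GEMM call itself is elided (in ORT it could go through the MLAS batch GEMM path discussed below):

#include <cstddef>
#include <vector>

// "Token shuffling" sketch: bucket token indices by their routed expert, so
// each expert runs one GEMM over M = bucket-size rows instead of many M = 1
// GEMMs. A single expert per token is shown; top-k routing would push each
// token into k buckets together with its routing weight.
std::vector<std::vector<std::size_t>> GroupTokensByExpert(
    const std::vector<int>& token_to_expert, std::size_t num_experts) {
  std::vector<std::vector<std::size_t>> buckets(num_experts);
  for (std::size_t t = 0; t < token_to_expert.size(); ++t) {
    buckets[static_cast<std::size_t>(token_to_expert[t])].push_back(t);
  }
  // For each expert e: gather (or point to) the rows in buckets[e], run one
  // GEMM against expert e's weights, then scatter the results back to the
  // original token positions.
  return buckets;
}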

@hariharans29
Member


I wonder if we even need to "clump" tokens per expert but rather pass through pointers to the token rows per expert and let MLAS use the row pointers to gather data into its compute registers. This way we can avoid the "clump" and "unshuffle" operations altogether ?

@tianleiwu
Contributor


"clump" and "unshuffle" is just logically. In implementation, we can pass pointers. That's supported by the MLAS Batch GEMM API.
