Skip to content

Optimizations and fixes in QMoE CPU kernel #25642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 20 additions & 45 deletions onnxruntime/contrib_ops/cpu/moe/moe_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "contrib_ops/cpu/moe/moe_utils.h"
#include <cmath>
#include <algorithm>
#include "core/common/common.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -26,67 +27,41 @@ float ApplyActivation(float x, ActivationType activation_type) {
}
}

// Helper method for applying SwiGLU activation with different memory layouts
// Helper method for applying SwiGLU activation with different memory layouts - optimized version
// Applies the SwiGLU activation in place over `data`.
//
// For each i in [0, inter_size) the value written to data[i] is
//   g * sigmoid(swiglu_alpha * g) * (l + 1)
// where g is the gate value capped from above at clamp_limit (no lower bound) and l is the
// linear value clamped to [-clamp_limit, clamp_limit].
//
// Parameters:
//   data                  - buffer that is read and overwritten in place. The code below reads
//                           indices up to 2 * inter_size - 1 and writes indices [0, inter_size).
//   inter_size            - number of (gate, linear) pairs, i.e. number of outputs produced.
//   is_interleaved_format - true selects the pairwise (interleaved) layout; false selects the
//                           else-branch below.
//
// NOTE(review): this listing carries two variants of several statements (gate_idx, linear_idx,
// gate_val and linear_val are each declared twice in the same scope, and the result is stored
// twice). That looks like the removed and added lines of a diff shown together, so the text is
// not compilable as shown. The two variants disagree on the interleaved order: one reads linear
// at 2*i and gate at 2*i+1, the other reads gate at 2*i and linear at 2*i+1. Confirm which
// variant is current before relying on either.
void ApplySwiGLUActivation(float* data, int64_t inter_size, bool is_interleaved_format) {
constexpr float swiglu_alpha = 1.702f;
constexpr float clamp_limit = 7.0f; // Clamping limit as specified

if (is_interleaved_format) {
// For interleaved format [linear, gate, linear, gate, ...], process directly
// Make a temporary copy of each pair of values before modifying them
// For interleaved format [gate, linear, gate, linear, ...], process directly
// Plain scalar loop over the inter_size pairs; no explicit SIMD is used here, so any
// vectorization is left to the compiler.
for (int64_t i = 0; i < inter_size; ++i) {
const size_t idx = static_cast<size_t>(i);
const size_t linear_idx = 2 * idx;
const size_t gate_idx = linear_idx + 1;
const size_t gate_idx = 2 * static_cast<size_t>(i); // Interleaved: even index (gate)
const size_t linear_idx = gate_idx + 1; // Interleaved: odd index (linear)

// Store original values
float linear_val = data[linear_idx]; // Interleaved: even index
float gate_val = data[gate_idx]; // Interleaved: odd index
// Load original values
float gate_val = data[gate_idx];
float linear_val = data[linear_idx];

// Apply clamping to the values
if (gate_val > clamp_limit) gate_val = clamp_limit; // Clamp gate max only
if (linear_val > clamp_limit) linear_val = clamp_limit; // Clamp linear min/max
if (linear_val < -clamp_limit) linear_val = -clamp_limit;
// Same clamping expressed with the standard library: std::min caps the gate from above,
// std::clamp bounds the linear value on both sides.
gate_val = std::min(gate_val, clamp_limit); // Clamp gate max only
linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit); // Clamp linear min/max

// SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1)
// SwiGLU: gate * sigmoid(alpha * gate) * (linear + 1) - optimized computation
float sigmoid_arg = swiglu_alpha * gate_val;

// Exact logistic sigmoid evaluated with std::exp; no approximation is applied on the
// following line.
float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
float swish_out = gate_val * sigmoid_out;
float result = swish_out * (linear_val + 1.0f);

// Store result in first element (linear position)
data[idx] = result;
// Store result at output slot i. Overwriting in place is safe: slots 2*i and 2*i+1 were
// already read above, and later iterations only read slots >= 2*(i+1), which are > i.
data[static_cast<size_t>(i)] = result;
}
} else {
// For chunked layout [linear..., gate...], handle separately
// Need to work with original data in-place
// First, store all the gate computations since they depend on original gate values
std::vector<float> computed_gates(static_cast<size_t>(inter_size));

// First pass: read the gate half (indices inter_size .. 2*inter_size-1) and keep
// gate * sigmoid(alpha * gate) for each position in the temporary buffer.
for (int64_t i = 0; i < inter_size; ++i) {
const size_t idx = static_cast<size_t>(i);
float gate_val = data[idx + static_cast<size_t>(inter_size)];

// Apply clamping to the gate value (max only)
if (gate_val > clamp_limit) gate_val = clamp_limit;

// Compute the gate part of SwiGLU
float sigmoid_arg = swiglu_alpha * gate_val;
float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg));
computed_gates[idx] = gate_val * sigmoid_out;
}

// Now apply the full activation with the precomputed gate values
for (int64_t i = 0; i < inter_size; ++i) {
const size_t idx = static_cast<size_t>(i);
float linear_val = data[idx];

// Apply clamping to the linear value (min/max)
if (linear_val > clamp_limit) linear_val = clamp_limit;
if (linear_val < -clamp_limit) linear_val = -clamp_limit;

// Overwrites the linear half (indices 0 .. inter_size-1) with the final outputs.
data[idx] = computed_gates[idx] * (linear_val + 1.0f);
}
// Non-interleaved format not implemented
// NOTE(review): ORT_NOT_IMPLEMENTED is defined outside this listing; presumably it raises a
// not-implemented error - confirm. As listed it sits after the chunked implementation above,
// which is consistent with the two being alternative variants of this branch.
ORT_NOT_IMPLEMENTED("Non-interleaved format not supported for SwiGLU activation");
}
}

Expand Down
Loading
Loading