16 changes: 11 additions & 5 deletions onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc
@@ -12,6 +12,7 @@
#include "core/framework/float16.h"
#include "core/framework/allocator.h"
#include "core/platform/threadpool.h"
#include "core/common/narrow.h"

#include <algorithm>
#include <vector>
@@ -120,7 +121,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
int num_routing_threads = 1;
if (tp != nullptr && num_tokens >= 1024) {
int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp);
- num_routing_threads = std::min(static_cast<int>(num_tokens / 512), max_threads);
+ num_routing_threads = std::min(narrow<int>(num_tokens / 512), max_threads);
num_routing_threads = std::max(1, num_routing_threads);
}

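Note on the cast change in the hunk above: narrow<int> from core/common/narrow.h replaces static_cast<int> so that a lossy int64-to-int conversion fails loudly rather than silently truncating. A minimal sketch of that behavior, assuming narrow follows the usual gsl::narrow semantics; the checked_narrow helper below is illustrative only, not ORT's actual implementation:

#include <cstdint>
#include <stdexcept>

// Illustrative helper: perform the cast, then verify the value survives the round trip.
template <typename To, typename From>
To checked_narrow(From value) {
  const To result = static_cast<To>(value);
  if (static_cast<From>(result) != value) {
    throw std::runtime_error("narrowing conversion changed the value");
  }
  return result;
}

// checked_narrow<int>(num_tokens / 512) would throw for values above INT_MAX,
// where static_cast<int> would silently wrap.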
@@ -133,7 +134,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
}

concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) {
- auto work = concurrency::ThreadPool::PartitionWork(static_cast<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
+ auto work = concurrency::ThreadPool::PartitionWork(narrow<int>(thread_id), num_routing_threads, static_cast<std::ptrdiff_t>(num_tokens));
auto& local_expert_token_map = thread_local_expert_token_maps[thread_id];

std::vector<std::pair<float, int64_t>> sorted_logits(static_cast<size_t>(num_experts));
@@ -173,7 +174,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
int64_t route_idx = i * k_ + j;
float normalized_weight = sorted_logits[static_cast<size_t>(j)].first * inv_top_k_sum;

- route_expert[route_idx] = static_cast<int>(expert_idx);
+ route_expert[route_idx] = narrow<int>(expert_idx);
route_scale[route_idx] = normalized_weight;
if (normalized_weight > 0.0f) {
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
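For context on the hunk above: when routing weights are normalized, each selected expert's weight is scaled by inv_top_k_sum, the reciprocal of the sum over the top-k selections, so the k weights for a token sum to one. A minimal sketch of that step with hypothetical names (top_k_values stands in for the selected sorted_logits entries):

#include <numeric>
#include <vector>

// Illustrative only: scale the top-k weights so they sum to 1, guarding against a zero sum.
std::vector<float> normalize_top_k(const std::vector<float>& top_k_values) {
  const float sum = std::accumulate(top_k_values.begin(), top_k_values.end(), 0.0f);
  const float inv_sum = sum > 0.0f ? 1.0f / sum : 0.0f;
  std::vector<float> normalized;
  normalized.reserve(top_k_values.size());
  for (float v : top_k_values) {
    normalized.push_back(v * inv_sum);
  }
  return normalized;
}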
@@ -185,7 +186,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,
int64_t route_idx = i * k_ + j;
float weight = sorted_logits[static_cast<size_t>(j)].first;

- route_expert[route_idx] = static_cast<int>(expert_idx);
+ route_expert[route_idx] = narrow<int>(expert_idx);
route_scale[route_idx] = weight;
if (weight > 0.0f) {
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
@@ -319,7 +320,7 @@ Status MoE<T>::ComputeMoE(const OpKernelContext* context,

// Optimized expert processing with thread-local buffer reuse
concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) {
- int thread_id = static_cast<int>(thread_id_pd);
+ int thread_id = narrow<int>(thread_id_pd);
auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast<std::ptrdiff_t>(num_experts));

float* local_output = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size;
@@ -440,6 +441,11 @@ Status MoE<T>::ProcessExpertBatch(const T* input_tokens,
int64_t inter_size,
T* fc1_output_buffer,
T* activation_output_buffer) const {
+ ORT_UNUSED_PARAMETER(token_expert_ids);
+ ORT_UNUSED_PARAMETER(token_weights);
+ ORT_UNUSED_PARAMETER(expert_id);
+ ORT_UNUSED_PARAMETER(fc1_output_buffer);
+ ORT_UNUSED_PARAMETER(activation_output_buffer);
const bool is_swiglu = activation_type_ == ActivationType::SwiGLU;
const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size;

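The ORT_UNUSED_PARAMETER calls added in the last hunk mark parameters the shown code path does not read, keeping the signature stable while silencing unused-parameter warnings. The macro conventionally reduces to a void cast; a minimal sketch of the idiom, assuming the common definition (the SKETCH_UNUSED_PARAMETER macro below is hypothetical, not ORT's actual one):

#include <cstdint>

// Hypothetical stand-in for the usual unused-parameter idiom.
#define SKETCH_UNUSED_PARAMETER(x) (void)(x)

// The parameter stays in the signature (e.g. to keep the interface stable) but is not read;
// the void cast keeps -Wunused-parameter quiet.
int64_t rows_to_process(int64_t num_tokens, int64_t inter_size) {
  SKETCH_UNUSED_PARAMETER(inter_size);
  return num_tokens;
}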