Skip to content
10 changes: 6 additions & 4 deletions onnxruntime/core/providers/cpu/llm/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
}

// Handling GQA (grouped-query attention): select the K head that query head head_i reads from.
std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
const T* k = K + k_input_chunk_length * ki;

if (nullptr != present_key) {
Expand Down Expand Up @@ -347,7 +348,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
alpha,
Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
parameters.head_size * parameters.q_num_heads, // lda
transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_i * parameters.head_size : k,
transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb
beta,
output,
Expand Down Expand Up @@ -555,7 +556,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
// Handling GQA (grouped-query attention): select the V head that this query head reads from.
std::ptrdiff_t batch_i = i / num_heads;
std::ptrdiff_t head_i = i % num_heads;
std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
const T* v = V + v_input_chunk_length * vi;

if (nullptr != present_value) {
Expand Down Expand Up @@ -585,7 +587,7 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
1.f, // alpha
attention_probs + attention_probs_offset, // QK
total_sequence_length, // lda
transposed_v ? V + head_i * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
0.f, // beta
output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
Expand Down
Loading
Loading