45 changes: 28 additions & 17 deletions onnxruntime/core/providers/cpu/llm/attention.cc
@@ -62,14 +62,14 @@ void make_copy<MLFloat16, MLFloat16>(MLFloat16* mask_data, const MLFloat16* mask
template <>
void make_copy<float, bool>(float* mask_data, const bool* mask_index, size_t size) {
for (size_t i = 0; i < size; ++i) {
-mask_data[i] = mask_index[i] ? 0.0f : std::numeric_limits<float>::lowest();
+mask_data[i] = mask_index[i] ? 0.0f : -std::numeric_limits<float>::infinity();
}
}

template <>
void make_copy<MLFloat16, bool>(MLFloat16* mask_data, const bool* mask_index, size_t size) {
for (size_t i = 0; i < size; ++i) {
-mask_data[i] = mask_index[i] ? MLFloat16(0.f) : std::numeric_limits<MLFloat16>::lowest();
+mask_data[i] = mask_index[i] ? MLFloat16(0.f) : MLFloat16(-std::numeric_limits<float>::infinity());
}
}
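
Aside (not part of the diff): the mask value switch from lowest() to negative infinity matters because -inf is absorbing under addition, so a masked logit (score + mask) stays -inf and its softmax weight is exactly zero no matter how large the raw attention score is, while lowest() only yields a large but finite negative logit. A minimal standalone C++ sketch of that difference; the values here are illustrative only:

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  // Additive attention masks are applied as: logit = score + mask_value.
  const float neg_inf = -std::numeric_limits<float>::infinity();
  const float lowest = std::numeric_limits<float>::lowest();
  const float score = 1e30f;  // deliberately huge score, for illustration only
  std::printf("score + (-inf)   = %f\n", score + neg_inf);    // stays -inf
  std::printf("score + lowest() = %g\n", score + lowest);     // large but finite
  std::printf("exp(-inf)        = %f\n", std::exp(neg_inf));  // exactly 0
  return 0;
}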

@@ -188,6 +188,16 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
);
}

+template <typename T>
+T negative_infinity() {
+return -std::numeric_limits<T>::infinity();
+}
+
+template <>
+MLFloat16 negative_infinity() {
+return MLFloat16(-std::numeric_limits<float>::infinity());
+}
+
template <typename T>
void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
@@ -251,7 +261,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
mask_data = static_cast<T*>(allocated_ptr);
for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-mask_data[s_i * parameters.total_sequence_length + m_i] = std::numeric_limits<T>::lowest();
+mask_data[s_i * parameters.total_sequence_length + m_i] = negative_infinity<T>();
}
}
delete_mask_data = true;
@@ -277,7 +287,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
for (int i = 0; i < n_iter; ++i) {
for (int s_i = 0; s_i < parameters.q_sequence_length; s_i++) {
for (int m_i = parameters.past_sequence_length + s_i + 1; m_i < parameters.total_sequence_length; m_i++) {
-mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = std::numeric_limits<T>::lowest();
+mask_data[s_i * parameters.total_sequence_length + m_i + probs_matrix_size * i] = negative_infinity<T>();
}
}
}
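
Aside (not part of the diff): the loops above implement the causal mask, filling every key position beyond the current query's diagonal with negative infinity. A tiny standalone sketch of the resulting pattern, using made-up sizes:

#include <cstdio>

int main() {
  // Illustrative sizes only: 3 query positions attending over 5 total positions,
  // the first 2 of which come from the past. 'x' marks entries the fill loop
  // sets to negative infinity (future positions), '.' marks visible positions.
  const int q_sequence_length = 3, past_sequence_length = 2, total_sequence_length = 5;
  for (int s_i = 0; s_i < q_sequence_length; ++s_i) {
    for (int m_i = 0; m_i < total_sequence_length; ++m_i) {
      bool masked = m_i > past_sequence_length + s_i;  // same condition as m_i >= past + s_i + 1
      std::printf("%c ", masked ? 'x' : '.');
    }
    std::printf("\n");
  }
  return 0;
}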
Expand Down Expand Up @@ -332,7 +342,8 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
}

// handling GQA
-std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_i % parameters.kv_num_heads;
+std::ptrdiff_t head_ki = head_i * parameters.kv_num_heads / parameters.q_num_heads;
+std::ptrdiff_t ki = batch_i * parameters.kv_num_heads + head_ki;
const T* k = K + k_input_chunk_length * ki;

if (nullptr != present_key) {
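
Aside (not part of the diff): the GQA indexing fix replaces head_i % kv_num_heads with head_i * kv_num_heads / q_num_heads, so consecutive blocks of query heads share one KV head instead of being interleaved across KV heads. A small standalone sketch of the two mappings, with made-up head counts:

#include <cstdio>

int main() {
  // Illustrative head counts only: 8 query heads sharing 2 KV heads (group size 4).
  const int q_num_heads = 8, kv_num_heads = 2;
  for (int head_i = 0; head_i < q_num_heads; ++head_i) {
    int grouped = head_i * kv_num_heads / q_num_heads;  // new mapping: 0 0 0 0 1 1 1 1
    int modulo = head_i % kv_num_heads;                 // old mapping: 0 1 0 1 0 1 0 1
    std::printf("query head %d -> kv head %d (grouped) vs %d (modulo)\n", head_i, grouped, modulo);
  }
  return 0;
}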
@@ -362,7 +373,7 @@ void AttentionBase<T>::ComputeAttentionProbs(T* attention_probs,
alpha,
Q + q_input_chunk_length * parameters.q_num_heads * batch_i + head_i * parameters.head_size,
parameters.head_size * parameters.q_num_heads, // lda
-transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + (head_i % parameters.kv_num_heads) * parameters.head_size : k,
+transposed_k ? K + k_input_chunk_length * parameters.kv_num_heads * batch_i + head_ki * parameters.head_size : k,
transposed_k ? parameters.head_size * parameters.kv_num_heads : parameters.head_size, // ldb
beta,
output,
@@ -568,7 +579,8 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
// handling GQA
std::ptrdiff_t batch_i = i / num_heads;
std::ptrdiff_t head_i = i % num_heads;
-std::ptrdiff_t vi = batch_i * kv_num_heads + head_i % kv_num_heads;
+std::ptrdiff_t head_vi = head_i * kv_num_heads / num_heads;
+std::ptrdiff_t vi = batch_i * kv_num_heads + head_vi;
const T* v = V + v_input_chunk_length * vi;

if (nullptr != present_value) {
@@ -592,16 +604,15 @@ void AttentionBase<T>::ComputeVxAttentionScore(T* output, // bu
// V is transposed but not QK. We use GemmEx with a different value for ldb.
math::GemmEx<T, ThreadPool>(CblasNoTrans,
CblasNoTrans,
-sequence_length, // M
-v_head_size, // N
-total_sequence_length, // K
-1.f, // alpha
-attention_probs + attention_probs_offset, // QK
-total_sequence_length, // lda
-transposed_v ? V + (head_i % kv_num_heads) * v_head_size + v_input_chunk_length * kv_num_heads * batch_i
-: v,
-transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
-0.f, // beta
+sequence_length, // M
+v_head_size, // N
+total_sequence_length, // K
+1.f, // alpha
+attention_probs + attention_probs_offset, // QK
+total_sequence_length, // lda
+transposed_v ? V + head_vi * v_head_size + v_input_chunk_length * kv_num_heads * batch_i : v, // V
+transposed_v ? v_head_size * kv_num_heads : v_head_size, // ldb
+0.f, // beta
output + ((batch_i * sequence_length * num_heads + head_i) * v_head_size),
v_head_size * num_heads, // ldc
nullptr);
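
Aside (not part of the diff): this GemmEx call computes, per head, a [sequence_length x total_sequence_length] by [total_sequence_length x v_head_size] product, and when V is head-interleaved (transposed_v) consecutive rows of one head sit v_head_size * kv_num_heads elements apart, which is why ldb switches between the two layouts. A hedged plain-loop sketch of that stride handling; names and layout here are illustrative, not the ORT implementation:

// C[S x Hv] = P[S x T] * V_head[T x Hv], where consecutive rows of V_head are
// ldv floats apart: ldv == Hv when the head's values are packed contiguously,
// ldv == Hv * kv_num_heads when all KV heads are interleaved row by row.
void probs_times_v(const float* P, const float* V_head, float* C,
                   int S, int T, int Hv, int ldv) {
  for (int s = 0; s < S; ++s) {
    for (int h = 0; h < Hv; ++h) {
      float acc = 0.f;
      for (int t = 0; t < T; ++t) {
        acc += P[s * T + t] * V_head[t * ldv + h];  // ldv plays the role of ldb in GemmEx
      }
      C[s * Hv + h] = acc;
    }
  }
}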
18 changes: 18 additions & 0 deletions onnxruntime/test/onnx/main.cc
@@ -795,6 +795,24 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
// Please make no more changes to the list
static const ORTCHAR_T* immutable_broken_tests[] =
{
+// pending ONNX update
+ORT_TSTR("attention_3d_gqa"),
+ORT_TSTR("attention_3d_gqa_attn_mask"),
+ORT_TSTR("attention_3d_gqa_causal"),
+ORT_TSTR("attention_3d_gqa_scaled"),
+ORT_TSTR("attention_3d_gqa_softcap"),
+ORT_TSTR("attention_3d_gqa_with_past_and_present"),
+ORT_TSTR("attention_4d_gqa"),
+ORT_TSTR("attention_4d_gqa_attn_mask"),
+ORT_TSTR("attention_4d_gqa_causal"),
+ORT_TSTR("attention_4d_gqa_scaled"),
+ORT_TSTR("attention_4d_gqa_softcap"),
+ORT_TSTR("attention_4d_gqa_with_past_and_present"),
+ORT_TSTR("attention_4d_diff_heads_mask4d_padded_kv"),
+ORT_TSTR("attention_4d_gqa_with_past_and_present_fp16"),
+ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal"),
+ORT_TSTR("attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal"),
+// unsupported case
ORT_TSTR("AvgPool1d"),
ORT_TSTR("AvgPool1d_stride"),
ORT_TSTR("AvgPool2d"),