Skip to content

[webgpu] Apply Flash Attention if sliding window exceeds KV cache length #25594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,9 @@ Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext&

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, int local_window_size) {
ORT_ENFORCE(local_window_size == -1, "Sliding window is not supported yet in FlashAttention.");

ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value));
const int present_sequence_length = parameters.is_gqa_ ? parameters.seqlen_present_kv_cache_ : parameters.total_sequence_length_;
if (parameters.sequence_length_ > 1) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionD

Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
Tensor* output, const Tensor* past_key, Tensor* present_key, const Tensor* past_value, Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, int local_window_size = -1);

bool CanApplyFlashAttention(const Tensor* bias, const Tensor* present_key, const Tensor* present_value,
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
Tensor* present_value = context.Output(2, present_kv_shape);
parameters.past_present_share_buffer_ = present_key != nullptr && present_value != nullptr && past_key != nullptr && past_value != nullptr && past_key->DataRaw() == present_key->DataRaw() && past_value->DataRaw() == present_value->DataRaw();

bool use_sliding_window = (local_window_size_ != -1 && local_window_size_ < parameters.seqlen_present_kv_cache_ && local_window_size_ < parameters.total_sequence_length_);
if (!do_rotary_ &&
head_sink == nullptr && !use_smooth_softmax_ &&
local_window_size_ == -1 &&
!use_sliding_window &&
CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
Expand Down
Loading