Skip to content

add sliding window support for webgpu gqa #25372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AdditionalImplementation() << "var<workgroup> tileQ: array<q_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "var<workgroup> tileK: array<key_value_t, " << tile_size_ * tile_size_ << ">;\n"
<< "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";

shader.MainFunctionBody() << "// x holds the N and y holds the M\n"
<< "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n"
<< "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n"
Expand Down Expand Up @@ -224,6 +225,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
}

Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
bool has_sliding_window = local_window_size_ != -1;

if (has_seqlen_k_) {
shader.AddInput("seqlen_k", ShaderUsage::UseUniform);
}
Expand All @@ -241,15 +244,33 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
std::ostringstream oss;
InitVarStub(oss, has_seqlen_k_);
shader.MainFunctionBody() << oss.str()
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"
<< "let seq_causal_length = " << (has_seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n"
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
<< "}\n"
<< "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
<< "workgroupBarrier();\n";
<< "let local_offset = local_idx * uniforms.elements_per_thread;\n"
<< "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n";
if (has_sliding_window) {
// Sliding window
shader.MainFunctionBody()
<< "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n"
<< "let start_offset = select(0, seq_causal_length - uniforms.local_window_size, should_apply_local_window);\n"
<< "let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);\n";
} else {
// No sliding window: we keep the code for sliding window in the shader but
// using const for start_offset and should_apply_local_window will make the compiler optimize it out.
shader.MainFunctionBody()
<< "const start_offset = 0;\n"
<< "const should_apply_local_window = false;\n"
<< "let effective_seq_length = seq_causal_length;\n";
}
shader.MainFunctionBody()
<< "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n"
<< " let actual_pos = local_offset + i + start_offset;\n"
<< " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n"
<< " thread_max_vector = max(f32_val_t(x[offset + i + start_offset]), thread_max_vector);\n"
<< " }\n"
<< "}\n"
<< "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
<< "workgroupBarrier();\n";

if (has_head_sink_) {
// Handle head sink
Expand All @@ -265,8 +286,11 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " max_value = max(thread_max[i], max_value);\n"
<< "}\n"
<< "var sum_vector = f32_val_t(0);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n"
<< "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n"
<< " let actual_pos = local_offset + i + start_offset;\n"
<< " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n"
<< " sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);\n"
<< " }\n"
<< "}\n"
<< "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n"
<< "workgroupBarrier();\n"
Expand All @@ -282,15 +306,33 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

shader.MainFunctionBody() << "if (sum == 0) {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n"
<< " let actual_pos = local_offset + i + start_offset;\n"
<< " if (actual_pos < seq_causal_length) {\n"
<< " x[offset + i + start_offset] = x_value_t(x_element_t(1.0)/x_element_t(effective_seq_length));\n"
<< " }\n"
<< " }\n"
<< "} else {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n"
<< " let actual_pos = local_offset + i + start_offset;\n"
<< " let pos = offset + i + start_offset;\n"
<< " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n"
<< " var f32input = f32_val_t(x[pos]);\n"
<< " x[pos] = x_value_t(exp(f32input - max_value) / sum);\n"
<< " }\n"
<< " }\n"
<< "}\n";

// zero out elements outsize the sliding window
shader.MainFunctionBody() << "if (should_apply_local_window) {\n"
<< " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n"
<< " var f32input = f32_val_t(x[offset + i]);\n"
<< " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
<< " let global_pos = i + local_offset;\n"
<< " if (global_pos < start_offset) {\n"
<< " x[offset + i] = x_value_t(x_element_t(0));\n"
<< " }\n"
<< " }\n"
<< "}\n";

if (has_seqlen_k_) {
shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n"
<< " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n"
Expand All @@ -301,7 +343,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length,
const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) {
const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink, int local_window_size) {
const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 2 : 1));
int work_group_size = 64;
const int total_sequence_length_comp = (total_sequence_length + components - 1) / components;
Expand All @@ -310,15 +352,15 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
}
const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size;

InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr};
InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr, local_window_size};
if (seqlen_k != nullptr) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank});
}
if (head_sink != nullptr) {
program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
}
program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
.CacheHint(work_group_size, use_smooth_softmax)
.CacheHint(work_group_size, use_smooth_softmax, local_window_size != -1)
.SetDispatchGroupSize(batch_size * num_heads * sequence_length)
.SetWorkgroupSize(work_group_size)
.AddUniformVariables({{static_cast<uint32_t>(batch_size)},
Expand All @@ -327,7 +369,8 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
{static_cast<uint32_t>(sequence_length)},
{static_cast<uint32_t>(total_sequence_length_comp)},
{static_cast<uint32_t>(elementsPerThread)},
{static_cast<uint32_t>(is_first_prompt ? 1 : 0)}});
{static_cast<uint32_t>(is_first_prompt ? 1 : 0)},
{static_cast<uint32_t>(local_window_size)}});

return context.RunProgram(program);
}
Expand Down Expand Up @@ -467,7 +510,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int

Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) {
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink,
const Tensor* seqlen_k, int local_window_size) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length =
Expand All @@ -481,7 +525,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
parameters, past_sequence_length, total_sequence_length, seqlen_k));

ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink));
parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink, local_window_size));

ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
parameters, past_sequence_length, total_sequence_length, seqlen_k));
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {

class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
public:
InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink)
: Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) {
InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink, int local_window_size)
: Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink), local_window_size_(local_window_size) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
Expand All @@ -81,14 +81,16 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
{"sequence_length", ProgramUniformVariableDataType::Uint32},
{"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32},
{"elements_per_thread", ProgramUniformVariableDataType::Uint32},
{"is_first_prompt", ProgramUniformVariableDataType::Uint32});
{"is_first_prompt", ProgramUniformVariableDataType::Uint32},
{"local_window_size", ProgramUniformVariableDataType::Uint32});

private:
int work_group_size_;
int components_;
bool use_smooth_softmax_;
bool has_seqlen_k_;
bool has_head_sink_;
int local_window_size_;
};

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,
const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr);
const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr, int local_window_size = -1);

} // namespace webgpu
} // namespace contrib
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&

if (!do_rotary_ &&
head_sink == nullptr && !use_smooth_softmax_ &&
local_window_size_ == -1 &&
CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) {
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,
present_value, parameters, context);
Expand Down Expand Up @@ -241,7 +242,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q));
if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format
return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context, head_sink, seqlen_k);
present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
}

TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_,
Expand All @@ -258,7 +259,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_,
parameters.v_head_size_, value, nullptr, 0, &V));
return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
present_value, parameters, context, head_sink, seqlen_k);
present_value, parameters, context, head_sink, seqlen_k, local_window_size_);
}

} // namespace webgpu
Expand Down
Loading