[FEATURE] Reduce compute bubbles in backward skip path #133
This PR implements a comprehensive set of optimizations to reduce compute bubbles in the backward kernel's skip branch when handling fully masked blocks. The optimizations target the key inefficiencies identified in sparse attention scenarios where many tiles are inactive.
Problem
The original backward kernel's skip branch showed substantial compute bubbles even when it successfully skipped mathematically null GEMMs: the cost of loads, barriers, and per-tile bookkeeping remained on the critical path. The phases below target these inefficiencies one by one.
Solution
Implemented a 4-phase optimization strategy:
Phase 1: Early Mask Prefetch
Moved mask loading and activity checking before heavy K/V/dO async loads, enabling skip decisions before expensive memory operations complete.
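The reordering can be modeled with a small host-side sketch (the names and callbacks below are hypothetical stand-ins; in the real kernel the heavy loads are asynchronous device copies of K/V/dO tiles):

```python
def backward_tile_step(mask_tile, issue_heavy_loads, compute_gemms):
    """Phase 1 ordering sketch: inspect the mask first, and only pay for
    the heavy K/V/dO loads when at least one element is active."""
    if not any(mask_tile):
        # Fully masked tile: skip before any expensive memory traffic starts.
        return "skipped"
    issue_heavy_loads()   # modeled as a callback; async copies on-device
    compute_gemms()
    return "computed"
```

The point of the reordering is that the cheap mask read becomes the first operation in the tile loop, so a skip decision never waits behind the large asynchronous loads it is meant to avoid.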
Phase 2: Conditional Synchronization
Bypass unnecessary `__syncthreads()` barriers for fully masked tiles when doing so is safe for pipeline correctness.
Phase 3: Next-Tile Look-Ahead
Added infrastructure for prefetching subsequent mask/bias tiles during skip cycles to hide future operation latency.
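A minimal sketch of the look-ahead idea, assuming a simple per-tile loop (the logging and list-based "fetch" are illustrative; the real kernel issues an async copy of the next mask tile while the current one is processed):

```python
def run_tiles(masks, log):
    """Phase 3 look-ahead sketch: the mask for tile i+1 is requested
    while tile i is still being handled, so the next skip decision is
    already in flight during a skip cycle."""
    prefetched = masks[0] if masks else None  # first mask fetched up front
    for i in range(len(masks)):
        mask = prefetched
        if i + 1 < len(masks):
            log.append(f"prefetch mask {i + 1}")  # overlaps work on tile i
            prefetched = masks[i + 1]
        if any(mask):
            log.append(f"compute tile {i}")
        else:
            log.append(f"skip tile {i}")
```

Because the prefetch is issued before the current tile's branch resolves, a run of skipped tiles keeps the mask pipeline full instead of stalling on each fetch.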
Phase 4: Adaptive Density Mode
Dynamically disable skip optimization when active tile density exceeds 85%, preventing overhead in high-density scenarios.
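The density gate reduces to a single ratio check; a sketch using the 85% threshold stated above (the function and constant names are hypothetical):

```python
DENSITY_THRESHOLD = 0.85  # above this, the skip path costs more than it saves

def skip_path_enabled(active_tiles: int, total_tiles: int) -> bool:
    """Phase 4 sketch: enable the skip optimization only when the workload
    is sparse enough that skipped tiles outweigh the per-tile mask check."""
    if total_tiles == 0:
        return False
    return active_tiles / total_tiles <= DENSITY_THRESHOLD
```

In dense workloads nearly every tile survives the mask check, so the extra branch and prefetch bookkeeping would be pure overhead; gating on density keeps the fast path unchanged there.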
Performance Impact
Expected speedups scale with the sparsity level: the larger the fraction of fully masked tiles, the more of the backward pass the skip path can elide.
Key Features
Testing
Added comprehensive validation covering the optimized skip path against the unoptimized baseline.
The optimizations maintain backward compatibility while delivering significant performance improvements for sparse attention workloads common in long-context models, mixture-of-experts architectures, and structured document attention.
Fixes #132.