Closed
Labels: feature (New feature request)
Description
Is your feature request related to a problem?
Current kernels always assume both attn_mask and attn_bias are active. This causes:
- Unnecessary global memory loads when only one of them is needed.
- Needless dbias computation when no bias is conceptually required.
- For bias-only usage, we still have to pass a dummy all-ones mask or inefficiently disable the skipping logic.
Describe the solution you'd like
Make both attn_mask and attn_bias optional with 4 explicit modes:
| Case | attn_mask | attn_bias | Behavior | 
|---|---|---|---|
| A | None | None | Dense path, no block skip, no bias load/add, fastest | 
| B | Tensor | None | Block skip using mask, no bias add/dbias | 
| C | None | Tensor | No block skip (all blocks active), add bias + compute dbias | 
| D | Tensor | Tensor | Current behavior (mask skip + bias add + dbias) | 
Rules:
- Only load tensors that are not None.
- Only compute and return dbias if a bias is provided.
- Block skip only when mask is present (never infer from bias alone).
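A minimal sketch of the mode resolution, assuming the launcher decides everything from whether each argument is None. resolve_mode and AttnMode are hypothetical names, and any non-None object stands in for a tensor:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class AttnMode:
    case: str          # "A", "B", "C" or "D", as in the mode table
    use_mask: bool     # load the mask and build the block-skip map
    use_bias: bool     # load/add bias in forward, compute dbias in backward


def resolve_mode(attn_mask: Optional[Any], attn_bias: Optional[Any]) -> AttnMode:
    """Map the two optional tensors to one of the four cases.

    Only presence matters: an all-ones mask still selects a mask-enabled case,
    and block skipping is never inferred from the bias alone.
    """
    use_mask = attn_mask is not None
    use_bias = attn_bias is not None
    case = {(False, False): "A", (True, False): "B",
            (False, True): "C", (True, True): "D"}[(use_mask, use_bias)]
    return AttnMode(case=case, use_mask=use_mask, use_bias=use_bias)
```

Deciding purely on presence keeps the rules mechanical: the flags never depend on tensor contents, so no extra pass over the data is needed to pick a path.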
Describe alternatives you've considered
- Forcing dummy all-ones mask (bandwidth waste).
- Sentinel zero-sized tensors (confusing, and still needs conditional branches).
- Mode enum instead of optional tensors (less ergonomic for users).
Implementation details
Kernel / Launch:
- Add flags: use_mask, use_bias.
- If use_mask: load mask → build skip map (OR reduction per block).
- If !use_mask: all blocks active (skip map implicit true).
- If use_bias: load bias (loads are skipped for inactive blocks only when a mask is present).
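The skip-map construction can be sketched as a sequential reference in plain Python. This is an illustration under assumptions: build_block_skip_map is a hypothetical helper, the mask is a dense 2-D 0/1 list for one (batch, head) pair, and block_m / block_n stand in for the kernel's query and key tile sizes. A real kernel would do the OR reduction per tile in parallel.

```python
from typing import List, Optional


def build_block_skip_map(mask: Optional[List[List[int]]], seqlen_q: int,
                         seqlen_k: int, block_m: int, block_n: int) -> List[List[bool]]:
    """active[i][j] is True if key tile j must be visited for query tile i."""
    n_qblocks = (seqlen_q + block_m - 1) // block_m   # ceil division: last tile may be partial
    n_kblocks = (seqlen_k + block_n - 1) // block_n
    if mask is None:
        # !use_mask: every tile is active. A kernel would not materialize this map.
        return [[True] * n_kblocks for _ in range(n_qblocks)]
    active = [[False] * n_kblocks for _ in range(n_qblocks)]
    for r in range(seqlen_q):
        for c in range(seqlen_k):
            if mask[r][c]:
                # OR reduction: a single nonzero entry keeps the whole tile alive.
                active[r // block_m][c // block_n] = True
    return active
```

With use_bias also set, bias tiles would then be read only where active is True; without a mask, every bias tile is read.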
Forward:
- Guard bias addition with if (use_bias).
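A per-row sketch of the guarded forward update, assuming masked-out positions are handled by setting their scores to -inf before the softmax (a common convention, not something this issue specifies). apply_mask_and_bias is a hypothetical helper; in a kernel the two if statements would test use_bias and use_mask, so an absent tensor is never read.

```python
import math
from typing import List, Optional


def apply_mask_and_bias(scores: List[float], mask_row: Optional[List[int]],
                        bias_row: Optional[List[float]]) -> List[float]:
    """Update one row of scale * Q K^T the way the forward pass would."""
    out = list(scores)
    if bias_row is not None:   # if (use_bias): add bias; otherwise no bias load at all
        out = [s + b for s, b in zip(out, bias_row)]
    if mask_row is not None:   # if (use_mask): masked-out keys get -inf before softmax
        out = [s if keep else -math.inf for s, keep in zip(out, mask_row)]
    return out
```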
Backward:
- If use_bias: compute/store dbias; otherwise skip it.
- Reuse existing skip logic only when use_mask.
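The backward guard can be sketched for one score row. This assumes the standard softmax backward used by FlashAttention-style kernels, dS = P ∘ (dP − rowsum(P ∘ dP)); because the bias is added directly to the scores, dbias equals dS (before any reduction over broadcast dimensions). dS is still needed for dQ and dK, so without a bias only the dbias store is skipped, not the dS computation. softmax_grad and backward_row are hypothetical names:

```python
from typing import List, Optional, Tuple


def softmax_grad(p_row: List[float], dp_row: List[float]) -> List[float]:
    """dS = P * (dP - sum(P * dP)): gradient w.r.t. the pre-softmax scores."""
    delta = sum(p * dp for p, dp in zip(p_row, dp_row))
    return [p * (dp - delta) for p, dp in zip(p_row, dp_row)]


def backward_row(p_row: List[float], dp_row: List[float],
                 use_bias: bool) -> Tuple[List[float], Optional[List[float]]]:
    ds = softmax_grad(p_row, dp_row)          # always needed for dQ and dK
    # Bias is added to the scores, so d(loss)/d(bias) = dS for this row.
    dbias = list(ds) if use_bias else None    # no bias: skip the dbias store entirely
    return ds, dbias
```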
Performance:
- Case A removes both memory streams and the skip-logic overhead.
- Case B removes bias path (saves reads/writes + math).
- Case C removes mask loads / OR reductions (simpler control flow).
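A back-of-the-envelope estimate of the traffic each case avoids, assuming full (batch, heads, seqlen_q, seqlen_k) mask and bias tensors with no broadcasting, a 1-byte boolean mask and a 2-byte fp16/bf16 bias. It counts one read of each tensor and one dbias write, and ignores block skipping and backward re-reads, so treat it as an order-of-magnitude figure rather than a measurement. mask_bias_bytes is a hypothetical helper:

```python
def mask_bias_bytes(batch: int, heads: int, seqlen_q: int, seqlen_k: int,
                    use_mask: bool, use_bias: bool, training: bool = True,
                    mask_bytes: int = 1, bias_bytes: int = 2) -> int:
    """Rough bytes of mask/bias/dbias traffic for one attention call."""
    elems = batch * heads * seqlen_q * seqlen_k
    total = 0
    if use_mask:
        total += elems * mask_bytes      # one read of the mask
    if use_bias:
        total += elems * bias_bytes      # one read of the bias in forward
        if training:
            total += elems * bias_bytes  # one write of dbias in backward
    return total


GiB = 2 ** 30
for case, m, b in [("A", False, False), ("B", True, False),
                   ("C", False, True), ("D", True, True)]:
    print(case, mask_bias_bytes(1, 16, 8192, 8192, m, b) / GiB, "GiB")
```

For batch 1, 16 heads and 8192×8192 scores there are 2^30 elements per tensor, so under these assumptions a training step in case D moves about 5 GiB of mask/bias/dbias traffic, case C 4 GiB, case B 1 GiB and case A none.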
Autograd API:
- The gradient for an absent tensor must be None (not zeros).
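In PyTorch, torch.autograd.Function.backward returns one gradient per forward input, and None is accepted for inputs that need no gradient. The sketch below shows only the packing step, in plain Python so it runs without PyTorch; the (q, k, v, attn_mask, attn_bias) argument order and the name pack_input_grads are assumptions:

```python
from typing import Any, Optional, Tuple


def pack_input_grads(dq: Any, dk: Any, dv: Any, dbias: Optional[Any],
                     use_bias: bool) -> Tuple[Any, Any, Any, None, Optional[Any]]:
    """Gradients in the assumed forward-argument order (q, k, v, attn_mask, attn_bias)."""
    mask_grad = None  # a mask is never differentiable, present or not
    # No bias passed -> None, not a zeros tensor: nothing is allocated or written.
    bias_grad = dbias if use_bias else None
    return dq, dk, dv, mask_grad, bias_grad
```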
Python wrapper:
- Accept attn_mask: Optional[Tensor] = None, attn_bias: Optional[Tensor] = None.
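A sketch of the argument checks such a wrapper might run before launching. The q/k layout (batch, seqlen, heads, head_dim), the (batch, heads, seqlen_q, seqlen_k) target shape for mask and bias, and support for size-1 broadcast dimensions are all assumptions, as are the helper names; shapes are passed as plain tuples so the sketch runs without PyTorch:

```python
from typing import Optional, Sequence, Tuple


def expected_mask_shape(q_shape: Sequence[int],
                        k_shape: Sequence[int]) -> Tuple[int, int, int, int]:
    # Assumed layout: q is (batch, seqlen_q, heads, head_dim), k is (batch, seqlen_k, heads, head_dim).
    batch, seqlen_q, heads, _ = q_shape
    seqlen_k = k_shape[1]
    return (batch, heads, seqlen_q, seqlen_k)


def check_optional(name: str, shape: Optional[Sequence[int]],
                   full: Tuple[int, ...]) -> bool:
    """Return whether the tensor is in use; raise if its shape cannot broadcast to `full`."""
    if shape is None:
        return False  # absent: nothing to load, nothing to validate
    if len(shape) != len(full) or any(s not in (1, f) for s, f in zip(shape, full)):
        raise ValueError(f"{name} shape {tuple(shape)} is not broadcastable to {full}")
    return True
```

The return value of check_optional doubles as the use_mask / use_bias flag, so validation and flag derivation stay in one place.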
Testing:
- Case A ≈ Case D with an all-ones mask and zero bias.
- Case B ≈ Case D with zero bias (and no dbias returned).
- Case C ≈ Case D with an all-ones mask (nothing skipped) and the same bias.
- Case D: unchanged baseline regression tests.
- Gradient checks: dbias is returned only when a bias is passed.
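The equivalence tests can be sketched against a naive single-head reference. ref_attention is a hypothetical stand-in for the kernel under test (in a real suite each case's kernel output would be compared with a case D kernel run), masking via -inf is an assumption, and each row must keep at least one unmasked key:

```python
import math
import random
from typing import List, Optional

Mat = List[List[float]]


def ref_attention(q: Mat, k: Mat, v: Mat, mask: Optional[Mat] = None,
                  bias: Optional[Mat] = None) -> Mat:
    """Single head: softmax(q k^T / sqrt(d) + bias, masked to -inf) @ v."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if bias is not None:                 # bias only when provided
            s = [x + bias[i][j] for j, x in enumerate(s)]
        if mask is not None:                 # mask only when provided
            s = [x if mask[i][j] else -math.inf for j, x in enumerate(s)]
        m = max(s)                           # finite as long as one key is kept
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) / z
                    for c in range(len(v[0]))])
    return out


def close(a: Mat, b: Mat, tol: float = 1e-9) -> bool:
    return all(abs(x - y) <= tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def rand_mat(rng: random.Random, rows: int, cols: int) -> Mat:
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n, d = 4, 8
q, k, v = rand_mat(rng, n, d), rand_mat(rng, n, d), rand_mat(rng, n, d)
bias = rand_mat(rng, n, n)
ones = [[1] * n for _ in range(n)]
zeros = [[0.0] * n for _ in range(n)]
causal = [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

assert close(ref_attention(q, k, v), ref_attention(q, k, v, ones, zeros))                # A vs D
assert close(ref_attention(q, k, v, mask=causal), ref_attention(q, k, v, causal, zeros))  # B vs D
assert close(ref_attention(q, k, v, bias=bias), ref_attention(q, k, v, ones, bias))       # C vs D
```

In this reference, adding 0.0 and keeping every position are exact no-ops, so the comparisons hold trivially; the real kernels take different code paths per case, which is why their comparison needs a tolerance.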
Docs:
- Add the mode table plus guidance on when to omit each tensor.
- Note: omitting the mask disables block skipping even if the attention pattern is sparse.
Describe alternatives you've considered (API level)
- Separate functions (explodes surface area).
- Single enum “mode” (harder to combine with future extensions).
Use case
- Dense long-context causal inference (no mask/bias).
- Relative position / ALiBi style bias-only (no sparsity).
- Structured sparsity pruning (mask-only).
- Combined sparse + bias training (both).
Related work
- FlashAttention dense fast path patterns.
- Sparse block skipping in dynamic mask attention.
- Relative position bias (bias-only adopters).
Why it’s valuable
Eliminates wasted DRAM traffic and unnecessary gradient work; improves throughput and reduces latency for large sequence lengths and specialized training/inference regimes.