Is your feature request related to a problem?
Current kernels always assume both `attn_mask` and `attn_bias` are active. This causes:
- Unnecessary global memory loads when only one input is needed.
- Needless `dbias` computation when no bias is conceptually required.
- Bias-only usage must either fake a full mask or disable the skipping logic, both of which are inefficient.
Describe the solution you'd like
Make both `attn_mask` and `attn_bias` optional, with four explicit modes:
| Case | attn_mask | attn_bias | Behavior |
|---|---|---|---|
| A | None | None | Dense path, no block skip, no bias load/add, fastest |
| B | Tensor | None | Block skip using mask, no bias add/dbias |
| C | None | Tensor | No block skip (all blocks active), add bias + compute dbias |
| D | Tensor | Tensor | Current behavior (mask skip + bias add + dbias) |
Rules (a flag-derivation sketch follows this list):
- Only load tensors that are not `None`.
- Only compute / return `dbias` if a bias is provided.
- Block skip only when a mask is present (never infer skipping from the bias alone).
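As a concrete reading of these rules, here is a minimal flag-derivation sketch. It assumes a PyTorch-style Python front end; the names `use_mask` and `use_bias` come from the implementation notes below, while the function name `resolve_flags` is purely illustrative:

```python
from typing import Optional

import torch

def resolve_flags(attn_mask: Optional[torch.Tensor],
                  attn_bias: Optional[torch.Tensor]) -> tuple[bool, bool]:
    """Map the optional inputs onto the four cases:
    (False, False) = A, (True, False) = B, (False, True) = C, (True, True) = D."""
    use_mask = attn_mask is not None  # Cases B/D: enables block skipping
    use_bias = attn_bias is not None  # Cases C/D: enables bias add + dbias
    return use_mask, use_bias
```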
Describe alternatives you've considered
- Forcing a dummy all-ones mask (wastes bandwidth).
- Sentinel zero-sized tensors (confusing, and the conditional branches remain).
- A mode enum instead of optional tensors (less ergonomic for users).
Implementation details
Kernel / Launch (a host-side skip-map sketch follows this list):
- Add flags: `use_mask`, `use_bias`.
- If `use_mask`: load the mask → build the skip map (OR reduction per block).
- If `!use_mask`: all blocks are active (skip map implicitly true).
- If `use_bias`: load the bias (subject to skipping only when a mask is present).
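A sketch of the skip-map construction, written in pure PyTorch for clarity (the real OR reduction would live in the kernel). The tile sizes `BLOCK_M`/`BLOCK_N` are illustrative, not the kernel's actual tiling:

```python
import torch

def build_skip_map(attn_mask: torch.Tensor,
                   BLOCK_M: int = 128, BLOCK_N: int = 128) -> torch.Tensor:
    """Per-(BLOCK_M, BLOCK_N) tile OR reduction: a block is active if any
    position inside it attends. attn_mask is boolean, shape (..., seq_q, seq_k)."""
    *lead, sq, sk = attn_mask.shape
    pad_q, pad_k = (-sq) % BLOCK_M, (-sk) % BLOCK_N
    m = torch.nn.functional.pad(attn_mask, (0, pad_k, 0, pad_q))  # pads with False
    tiles = m.reshape(*lead, (sq + pad_q) // BLOCK_M, BLOCK_M,
                      (sk + pad_k) // BLOCK_N, BLOCK_N)
    return tiles.any(dim=-1).any(dim=-2)  # (..., n_blocks_q, n_blocks_k)
```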
Forward:
- Guard the bias addition with `if (use_bias)` (reference semantics for both guards are sketched after this section).
Backward:
- If `use_bias`: compute/store `dbias`; otherwise skip it.
- Reuse the existing skip logic only when `use_mask` is set.
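For reference, the guarded forward semantics in eager PyTorch — a stand-in for the CUDA path, not the kernel itself; the `1/sqrt(d)` scaling convention is an assumption:

```python
import math
from typing import Optional

import torch

def ref_attention(q, k, v,
                  attn_mask: Optional[torch.Tensor] = None,
                  attn_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if attn_bias is not None:   # Cases C/D: bias add only when provided
        scores = scores + attn_bias
    if attn_mask is not None:   # Cases B/D: mask out inactive positions
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```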
Performance:
- Case A removes both memory streams (mask and bias) plus the skip-logic overhead.
- Case B removes the bias path (saves reads/writes and math).
- Case C removes mask loads / OR reductions (simpler control flow).
Autograd API:
- The gradient for an absent tensor must be `None` (not zeros); see the sketch below.
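A sketch of the autograd plumbing this rule implies; `_fwd_kernel` and `_bwd_kernel` are hypothetical placeholders for the extension entry points, which the issue does not name:

```python
from typing import Optional

import torch

class FlexAttnFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v,
                attn_mask: Optional[torch.Tensor],
                attn_bias: Optional[torch.Tensor]):
        ctx.use_bias = attn_bias is not None
        out = _fwd_kernel(q, k, v, attn_mask, attn_bias)  # hypothetical call
        ctx.save_for_backward(q, k, v, attn_mask, attn_bias, out)
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, attn_mask, attn_bias, out = ctx.saved_tensors
        dq, dk, dv, dbias = _bwd_kernel(  # hypothetical call
            dout, q, k, v, attn_mask, attn_bias, out,
            compute_dbias=ctx.use_bias)
        # Absent inputs must get None gradients, never zero-filled tensors;
        # the mask is non-differentiable, so its slot is always None.
        return dq, dk, dv, None, (dbias if ctx.use_bias else None)
```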
Python wrapper:
- Accept `attn_mask: Optional[Tensor] = None`, `attn_bias: Optional[Tensor] = None`.
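The corresponding public wrapper, building on the `FlexAttnFn` sketch above (the function name `attn_func` is illustrative, not the repo's actual export):

```python
from typing import Optional

from torch import Tensor

def attn_func(q: Tensor, k: Tensor, v: Tensor,
              attn_mask: Optional[Tensor] = None,
              attn_bias: Optional[Tensor] = None) -> Tensor:
    # Both extras default to None, so the dense fast path (Case A) is the default.
    return FlexAttnFn.apply(q, k, v, attn_mask, attn_bias)
```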
Testing (a parity-test sketch follows this list):
- Case A ≈ Case D with an all-ones mask & zero bias.
- Case B ≈ Case D with a zero bias (and no dbias returned).
- Case C ≈ Case D with an all-ones mask (no skip) and the same bias.
- Case D: unchanged baseline regression tests.
- Gradient checks: `dbias` is returned only when a bias is passed.
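One of these parity checks as a pytest-style sketch, assuming a wrapper like `attn_func` above; shapes, dtypes, and tolerances are illustrative:

```python
import torch

def test_case_b_matches_case_d_with_zero_bias():
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))
    mask = torch.rand(1, 4, 128, 128, device="cuda") > 0.5
    zeros = torch.zeros_like(mask, dtype=torch.float16)
    out_b = attn_func(q, k, v, attn_mask=mask)                   # Case B
    out_d = attn_func(q, k, v, attn_mask=mask, attn_bias=zeros)  # Case D
    torch.testing.assert_close(out_b, out_d, rtol=1e-3, atol=1e-3)
```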
Docs:
- Add mode table + guidance on when to omit tensors.
- Note: omitting the mask disables block skipping even when the data is sparse.
Describe alternatives you've considered (API level)
- Separate functions (explodes the API surface area).
- A single "mode" enum (harder to compose with future extensions).
Use cases
- Dense long-context causal inference (no mask/bias).
- Relative position / ALiBi style bias-only (no sparsity).
- Structured sparsity pruning (mask-only).
- Combined sparse + bias training (both).
Related work
- FlashAttention dense fast path patterns.
- Sparse block skipping in dynamic mask attention.
- Relative position bias (bias-only adopters).
Why it’s valuable
Eliminates wasted DRAM traffic and unnecessary gradient work; improves throughput and reduces latency for large sequence lengths and specialized training/inference regimes.