flash_attn_backward_fake
1 parent 0d40f50 commit 871a3cf
hopper/flash_attn_interface.py
@@ -311,9 +311,9 @@ def _flash_attn_backward_fake(
     sm_margin: int = 0,
 ) -> torch.Tensor:
 
-    is_varlen_q = bool(cu_seqlens_q)
-    is_varlen_k = bool(cu_seqlens_k)
-    is_varlen = is_varlen_q or is_varlen_k or bool(seqused_q) or bool(seqused_k)
+    is_varlen_q = cu_seqlens_q is not None
+    is_varlen_k = cu_seqlens_k is not None
+    is_varlen = is_varlen_q or is_varlen_k or seqused_q is not None or seqused_k is not None
 
     if not is_varlen_q:
         batch_size = q.size(0)
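
The change replaces truthiness tests with explicit `is not None` checks. The varlen arguments are optional tensors, and calling bool() on a multi-element tensor raises a RuntimeError instead of answering "was this argument provided"; inside a fake-tensor registration such as `_flash_attn_backward_fake`, where tensor data is unavailable at trace time, the old check would also act as a data-dependent branch. A minimal sketch of the failure mode (the tensor values below are illustrative, not taken from the commit):

    import torch

    # Cumulative sequence offsets, as passed to the varlen interface.
    cu_seqlens_q = torch.tensor([0, 128, 256], dtype=torch.int32)

    try:
        is_varlen_q = bool(cu_seqlens_q)  # old check: ambiguous for multi-element tensors
    except RuntimeError as err:
        print(err)  # Boolean value of Tensor with more than one element is ambiguous

    is_varlen_q = cu_seqlens_q is not None  # new check: tests only whether the arg was supplied
    print(is_varlen_q)  # True

The `is not None` form never inspects tensor contents, so it presumably stays safe under torch.compile's fake-tensor tracing as well, which is the context this fake implementation runs in.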