
Commit 871a3cf

Fix bug in flash_attn_backward_fake

1 parent 0d40f50

File tree

1 file changed (+3 -3 lines)

hopper/flash_attn_interface.py

Lines changed: 3 additions & 3 deletions
@@ -311,9 +311,9 @@ def _flash_attn_backward_fake(
     sm_margin: int = 0,
 ) -> torch.Tensor:
-    is_varlen_q = bool(cu_seqlens_q)
-    is_varlen_k = bool(cu_seqlens_k)
-    is_varlen = is_varlen_q or is_varlen_k or bool(seqused_q) or bool(seqused_k)
+    is_varlen_q = cu_seqlens_q is not None
+    is_varlen_k = cu_seqlens_k is not None
+    is_varlen = is_varlen_q or is_varlen_k or seqused_q is not None or seqused_k is not None
     if not is_varlen_q:
         batch_size = q.size(0)
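Judging by its name, _flash_attn_backward_fake is the fake (meta) implementation used to infer output shapes during tracing (e.g. under torch.compile), where tensor data is not available. The sketch below, which is illustrative and not code from the repository, shows why the old truthiness check was a bug: bool() on a multi-element tensor raises, and even where it does not (a one-element tensor), it tests the tensor's value rather than whether the optional argument was supplied. The `is not None` check never touches tensor data at all.

```python
import torch

# Hypothetical stand-in for an optional cu_seqlens_q argument.
cu_seqlens_q = torch.tensor([0, 128, 256], dtype=torch.int32)

# The old check: bool() on a tensor with more than one element raises.
try:
    is_varlen_q = bool(cu_seqlens_q)
except RuntimeError as exc:
    print(exc)  # Boolean value of Tensor with more than one element is ambiguous

# The fixed check: tests only whether the argument was passed,
# without reading tensor data.
is_varlen_q = cu_seqlens_q is not None
print(is_varlen_q)  # True
```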
