Skip to content

Conversation

LoserCheems
Copy link
Collaborator

Summary

This PR fixes potential gradient flow issues and memory inefficiencies in the top-k attention mask construction logic within _flash_dynamic_mask_attention_forward. The changes address:

  1. Gradient Safety: Prevents gradients from flowing back through positions filled with -inf (masked positions) during top-k selection
  2. Memory Efficiency: Eliminates unnecessary tensor allocations when no pre-existing mask is provided
  3. Dimension Handling: Ensures attention_bias maintains its original 3D shape when paired with 4D masks, avoiding unintended dimension expansion

These improvements contribute to resolving issues like #180 by ensuring safer numerical operations during backward passes.

Root Cause

Problem 1: Unsafe Gradient Flow Through Masked Positions

Location: flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py:L91-L107 (old version)

When selecting top-k attention positions from attention_bias, the code used masked_fill(~attention_mask, min_dtype) to exclude masked positions. However, without proper gradient detachment:

  • Gradients could flow back through the masked_fill operation to masked positions
  • These positions contain -inf values, which can cause numerical instability (INF/NaN) during backward passes
  • The issue is particularly problematic with bf16/fp16 precision where the representable range is limited

Problem 2: Unnecessary Memory Allocation

Location: Same file, when attention_mask is None

The original code created a full ones_like(attention_bias) mask even when no masking was needed:

# Old code
if attention_mask is None:
    attention_mask = torch.ones_like(attention_bias, dtype=torch.bool)  # Wasteful!

This resulted in:

  • Unnecessary memory allocation (e.g., 0.125 MB for a 3D bias with shape (2, 32, 2048))
  • Redundant masked_fill operations on an all-True mask (no effect, pure overhead)
  • The final effect was equivalent to not having a mask at all

Problem 3: Dimension Expansion Side Effects

Location: 4D mask + 3D bias scenario

When handling 4D attention_mask with 3D attention_bias, the code would expand the bias and reassign it:

# Old problematic pattern
attention_bias = attention_bias.unsqueeze(-2).expand_as(attention_mask)
# Now attention_bias is 4D, but should remain 3D for kernel!

This caused the kernel to receive 4D bias instead of the intended 3D bias, potentially affecting downstream computations.

Changes

Code-Level Changes

File: flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

1. Added Gradient Detachment

Applied .detach() to the masked bias tensor before top-k selection:

topk_indices = torch.topk(
    attention_bias_for_topk.masked_fill(~attention_mask, min_dtype).detach(),  # ✅ detach here
    keep_window_size, dim=-1, largest=True, sorted=False,
).indices

Effect: Prevents gradients from flowing back through masked positions filled with -inf, eliminating a source of numerical instability.

2. Eliminated Unnecessary Allocation for None Mask

Split the logic into two branches:

if attention_mask is not None:
    # Has existing mask: select top-k from allowed positions with masked_fill
    # ... (standard path with masked_fill + detach)
else:
    # No existing mask: directly select top-k without masked_fill
    topk_indices = torch.topk(
        attention_bias.detach(),  # ✅ No masked_fill needed!
        keep_window_size, dim=-1, largest=True, sorted=False,
    ).indices
    attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool).scatter_(
        -1, topk_indices, True
    )

Effect: Saves ~36% peak memory in the None mask scenario by skipping ones_like and masked_fill.

3. Preserved Bias Dimensionality with Temporary Variable

Introduced attention_bias_for_topk to handle dimension expansion without mutating the original:

if attention_mask.dim() == 4 and attention_bias.dim() == 3:
    attention_bias_for_topk = attention_bias.unsqueeze(-2).expand_as(attention_mask)
else:
    attention_bias_for_topk = attention_bias
# Use attention_bias_for_topk only for top-k selection
# attention_bias remains 3D for kernel!

Effect: Ensures the kernel receives 3D bias even when 4D mask is present, maintaining API contract.

API / Behavioral Changes

  • No breaking changes: All function signatures remain identical
  • Improved safety: Gradients are now properly isolated from masked positions
  • Memory optimization: Reduced peak memory usage in common scenarios
  • Dimension preservation: attention_bias dimensionality is now correctly maintained

Reproduction

Minimal Example Demonstrating the Fix

import torch
from flash_dmattn.integrations.modeling_flash_dynamic_mask_attention_utils import (
    _flash_dynamic_mask_attention_forward
)

# Setup: simulate training scenario with keep_window_size
batch_size, num_heads, seq_len, head_dim = 2, 32, 4096, 128
keep_window_size = 2048

query = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')
key = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')
value = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device='cuda')

# Create learnable bias (simulating dynamic mask attention)
attention_bias = torch.randn(batch_size, num_heads, seq_len, dtype=torch.bfloat16, device='cuda', requires_grad=True)

# Forward pass
output = _flash_dynamic_mask_attention_forward(
    query, key, value,
    attention_mask=None,
    attention_bias=attention_bias,
    query_length=seq_len,
    key_length=seq_len,
    is_causal=True,
    keep_window_size=keep_window_size
)

# Backward pass - should not produce INF/NaN
loss = output.sum()
loss.backward()

# Verify gradient is finite
assert torch.isfinite(attention_bias.grad).all(), "Gradient contains INF/NaN!"
print("✅ Gradient is finite and safe")

Before vs After

Scenario Before After
Gradient Safety ❌ Gradients flow through -inf masked positions .detach() isolates gradients
None Mask Memory 1.375 MB peak (with wasteful ones_like) 0.875 MB peak (36% reduction)
4D+3D Bias Shape ❌ Bias becomes 4D after expansion ✅ Bias stays 3D (only expanded in local scope)

Tests

Validation Performed

  1. Gradient Flow Test (test_gradient_flow.py):

    • Verified that gradients correctly flow to learnable parameters (A, dt_proj)
    • Confirmed .detach() only blocks gradients from masked positions, not the main computation path
    • Result: ✅ Gradients flow normally with norm values ~24.8 for A, ~7.0 for dt_proj.weight
  2. Memory Profiling (verify_memory_three_scenarios.py):

    • Measured peak memory for three scenarios:
      • 3D mask + 3D bias: 1.375 MB
      • 4D mask + 3D bias: 2304.25 MB (bias stays 3D ✅)
      • None mask + 3D bias: 0.875 MB (36% reduction ✅)
    • Confirmed no memory leaks or unexpected allocations
  3. Dimension Integrity Test (test_topk_fix.py):

    • Verified all dimension combinations:
      • 3D + 3D: ✅ No expansion
      • 4D + 4D: ✅ No expansion
      • 4D + 3D: ✅ Bias remains 3D, only mask is 4D
      • None + 3D: ✅ New mask matches bias shape
  4. Numerical Stability Test:

Test Coverage

  • ✅ All dimension combinations (3D/4D mask × 3D/4D bias)
  • ✅ None mask optimization path
  • ✅ Gradient flow correctness
  • ✅ Memory footprint validation
  • ✅ Numerical stability under bf16/fp16

Compatibility

Backward Compatibility

Fully backward compatible - no API changes, all existing code continues to work without modification.

Performance Impact

  • Memory: Up to 36% reduction in peak memory for None mask scenarios
  • Compute: Minor reduction in redundant operations (skipped ones_like, masked_fill for None mask)
  • Gradient Safety: Improved numerical stability, especially with low-precision dtypes (bf16/fp16)

Migration Notes

No migration required. This is a drop-in replacement with identical external behavior but improved internals.

Related Issues & PRs

  • Fixes: #180 - INF in backward pass (contributes to the solution alongside kernel fixes)
  • Related: #181 - Kernel-level clamping fix
  • Related: #182 - Mask/bias memory access fixes

Checklist

Additional Notes

Design Rationale

The three-pronged approach (detach, None-mask optimization, dimension preservation) addresses distinct issues:

  1. Safety first: .detach() is a zero-cost operation that eliminates gradient hazards
  2. Efficiency: Avoiding unnecessary allocations is a clear win
  3. Correctness: Maintaining intended dimensions prevents subtle API contract violations

Future Considerations


Reviewers: Please pay special attention to the gradient flow test results and memory profiling, as these validate the core claims of the PR.

Respects existing masks when applying keep‑window top‑k selection and aligns bias/mask shapes (3D↔4D) before selection. Builds boolean masks from indices and avoids bias overwrites.

Prevents errors by casting inputs only when present, improving compatibility with PEFT/LoRA setups.

Cleans up unused imports and variables for clarity.
@Copilot Copilot AI review requested due to automatic review settings October 4, 2025 14:28
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR optimizes the top-k mask construction in flash dynamic mask attention by addressing gradient flow safety, memory efficiency, and dimension handling issues. The changes prevent unsafe gradient propagation through masked positions, eliminate unnecessary memory allocations, and preserve tensor dimensionality.

  • Adds gradient detachment using .detach() to prevent gradients flowing through masked positions filled with -inf
  • Optimizes memory usage by avoiding unnecessary tensor allocations when no pre-existing mask is provided
  • Preserves original tensor dimensions by using temporary variables for dimension expansion

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

from transformers.utils import logging
from transformers.integrations import flash_attention


Copy link

Copilot AI Oct 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Removed import statement leaves an empty line. Consider removing the blank line to maintain consistent spacing.

Suggested change

Copilot uses AI. Check for mistakes.

Comment on lines 26 to +31
if target_dtype and q.dtype == torch.float32:
logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-dmattn compatibility.")
q, k, v, bias = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype), bias.to(target_dtype)
q = q.to(target_dtype) if q is not None else None
k = k.to(target_dtype) if k is not None else None
v = v.to(target_dtype) if v is not None else None
bias = bias.to(target_dtype) if bias is not None else None
Copy link

Copilot AI Oct 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The individual None checks can be simplified. Consider using a helper function or list comprehension to reduce code duplication and improve readability.

Copilot uses AI. Check for mistakes.

@LoserCheems
Copy link
Collaborator Author

hi @ftgreat, please try the changes in this PR bench 🤗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG REPORT] INF occurs in backward phrase of the first training step

8 participants