
Conversation

guangzlu

Motivation

Found a bug when using FP8 FLA together with torch.compile.

Technical Details

When the QKV tensors are FP8, the output tensor `out` should be allocated as BF16 rather than inheriting the FP8 dtype.
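
A minimal sketch of that rule, outside the actual kernel wrapper. `select_out_dtype` is an illustrative helper (not from this repo), and `torch.float8_e4m3fn` / `torch.bfloat16` stand in for the repo's `dtypes.fp8` / `dtypes.bf16` aliases:

```python
import torch

def select_out_dtype(q: torch.Tensor) -> torch.dtype:
    # FP8 is an input/storage format here; the attention output buffer
    # should be allocated in BF16 rather than inheriting q.dtype.
    if q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return torch.bfloat16
    return q.dtype

batch_size, seqlen_q, num_heads, head_size_v = 2, 128, 8, 64
q = torch.empty(batch_size, seqlen_q, num_heads, head_size_v, dtype=torch.float8_e4m3fn)
out = torch.empty(
    (batch_size, seqlen_q, num_heads, head_size_v),
    dtype=select_out_dtype(q),
    device=q.device,
)
assert out.dtype == torch.bfloat16
```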

Copilot AI review requested due to automatic review settings · October 13, 2025 06:02
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR fixes a torch.compile bug that occurs when FP8 data types are used with Flash Attention (FLA). The fix ensures that when the QKV tensors are in FP8 format, the output tensor is created with the BF16 data type instead of inheriting the FP8 type.

  • Added conditional logic to handle FP8 input tensors by creating BF16 output tensors
  • Maintains existing behavior for non-FP8 data types


Comment on lines 113 to +123

      else:
-         out = torch.empty(
-             (batch_size, seqlen_q, num_heads, head_size_v),
-             dtype=q.dtype,
-             device=q.device,
-             requires_grad=q.requires_grad,
-         )
+         if q.dtype == dtypes.fp8:
+             out = torch.empty(
+                 (batch_size, seqlen_q, num_heads, head_size_v),
+                 dtype=dtypes.bf16,
+                 device=q.device,
+                 requires_grad=q.requires_grad,
+             )
+         else:
+             out = torch.empty(
+                 (batch_size, seqlen_q, num_heads, head_size_v),
+                 dtype=q.dtype,
+                 device=q.device,
+                 requires_grad=q.requires_grad,
+             )

Copilot AI Oct 13, 2025

The nested if-else structure creates duplicated tensor creation logic. Consider restructuring to determine the output dtype first, then create the tensor once to reduce code duplication.
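
One possible shape for that refactor, sketched against the names visible in the diff above (`dtypes.fp8` / `dtypes.bf16` are assumed to be the repo's dtype aliases; this is not an actual follow-up commit):

```python
# Decide the output dtype once, then allocate the tensor a single time.
out_dtype = dtypes.bf16 if q.dtype == dtypes.fp8 else q.dtype
out = torch.empty(
    (batch_size, seqlen_q, num_heads, head_size_v),
    dtype=out_dtype,
    device=q.device,
    requires_grad=q.requires_grad,
)
```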
