
Commit 6d4d355

add fsdpa
1 parent d2da44f commit 6d4d355

File tree

2 files changed: +23 −1 lines changed


paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 18 additions & 0 deletions
@@ -221,6 +221,24 @@ def fusion_flash_attention(
             attention_mask is None,
             True,
         )[0]
+    elif get_env_device() == "intel_hpu":
+        if config.context_parallel_degree > 1:
+            raise ValueError("Context parallel is not implemented for intel_hpu")
+        scaling_factor = query_states.shape[3] ** -0.5
+        attention_mask = attention_mask.astype("bfloat16")
+        attn_output = paddle.incubate.nn.functional.fused_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            scaling_factor,
+            0.0,
+            False,
+            attention_mask is None,
+            None,
+            False,
+        )
+        attn_output = paddle.transpose(attn_output, [0, 2, 1, 3])
     else:
         if config.context_parallel_degree > 1:
             attn_output = RingFlashAttention.apply(
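
For readers not familiar with the fused scaled-dot-product-attention (FSDPA) path, the sketch below restates the new intel_hpu branch as a standalone helper. The helper name intel_hpu_sdpa and its packaging as a free function are made up for illustration; the scaling, the bfloat16 mask cast, the fused_dot_product_attention call, and the output transpose mirror the hunk above, and the trailing positional flags are passed exactly as the commit passes them.

import paddle

def intel_hpu_sdpa(query_states, key_states, value_states, attention_mask):
    # Illustrative helper, not part of the commit: it repackages the new
    # intel_hpu branch of fusion_flash_attention as a standalone function.
    # Standard softmax scaling by 1/sqrt(head_dim); head_dim is the last axis.
    scaling_factor = query_states.shape[3] ** -0.5
    # The commit casts the mask to bfloat16 before handing it to the fused kernel.
    attention_mask = attention_mask.astype("bfloat16")
    # Positional arguments below are passed exactly as in the diff.
    attn_output = paddle.incubate.nn.functional.fused_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attention_mask,
        scaling_factor,
        0.0,
        False,
        attention_mask is None,
        None,
        False,
    )
    # Swap the two middle axes of the fused op's output, as the commit does.
    return paddle.transpose(attn_output, [0, 2, 1, 3])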

paddlenlp/transformers/llama/modeling.py

Lines changed: 5 additions & 1 deletion
@@ -1697,7 +1697,11 @@ def forward(
 
         is_casual = False
 
-        if attn_mask_startend_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu":
+        if (
+            attn_mask_startend_row_indices is None
+            and self.config.use_flash_attention
+            and get_env_device() not in ["gcu", "intel_hpu"]
+        ):
             if self.config.use_flash_attention_for_generation or use_casual_mask:
                 is_casual = True
             else:
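
Read together with the fusion_ops.py hunk, the widened guard means intel_hpu, like gcu, no longer takes the is_casual shortcut; presumably the fused branch above handles masking itself via the attention_mask it is given. A hypothetical predicate restating the new condition:

def uses_casual_shortcut(attn_mask_startend_row_indices, use_flash_attention, device):
    # Mirrors the updated guard in modeling.py: intel_hpu now joins gcu in the
    # exclusion list, so both devices skip the is_casual fast path.
    return (
        attn_mask_startend_row_indices is None
        and use_flash_attention
        and device not in ["gcu", "intel_hpu"]
    )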
