Skip to content

Commit d9a3ad6

Browse files
committed
Merge remote-tracking branch 'PaddleNLP_feiyue/intel_hpu' into support_intel_hpu_backend
2 parents 6569e3f + 53f4fdc commit d9a3ad6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddlenlp/transformers/llama/fusion_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
147147
elif get_env_device() == "gcu":
148148
return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
149149
elif get_env_device() == "intel_hpu":
150-
return paddle.incubate.nn.functional.fused_rms_norm(hidden_states, weight, None, variance_epsilon, 2)[0]
150+
return paddle.incubate.nn.functional.fused_rms_norm(
151+
hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
152+
)[0]
151153
elif get_env_device() == "xpu":
152154
try:
153155
import paddle_xpu_nn # noqa: F821

0 commit comments

Comments
 (0)