Skip to content

Commit d155697

Browse files
authored
[XPU] hybrid_parallel_optimizer xpu support bf16 (#60213)
1 parent d6f6402 commit d155697

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -225,12 +225,8 @@ def _comm_and_clip(
225225
)
226226
clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
227227

228-
# bf16 is not supported on XPU now
229-
if not (
230-
paddle.is_compiled_with_xpu()
231-
or isinstance(
232-
paddle.framework._current_expected_place(), paddle.CustomPlace
233-
)
228+
if not isinstance(
229+
paddle.framework._current_expected_place(), paddle.CustomPlace
234230
):
235231
clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
236232
for p, g in params_grads:
@@ -241,10 +237,6 @@ def _comm_and_clip(
241237
if g.dtype == paddle.float16:
242238
g.multiply_(clip_var_fp16)
243239
elif g.dtype == paddle.bfloat16:
244-
if paddle.is_compiled_with_xpu():
245-
raise NotImplementedError(
246-
"BF16 is not supported on XPU now"
247-
)
248240
g.multiply_(clip_var_bf16)
249241
else:
250242
g.multiply_(clip_var)

0 commit comments

Comments (0)