Skip to content

Commit 59b3de6

Browse files
committed
[NPU] test TransformerTTS with NPU
1 parent 3568bb6 commit 59b3de6

File tree

1 file changed

+5
-3
lines changed
  • paddlespeech/t2s/exps/transformer_tts

1 file changed

+5
-3
lines changed

paddlespeech/t2s/exps/transformer_tts/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@
4242
def train_sp(args, config):
4343
# decides device type and whether to run in parallel
4444
# setup running environment correctly
45-
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
46-
paddle.set_device("cpu")
47-
else:
45+
if paddle.is_compiled_with_cuda() and args.ngpu > 0:
4846
paddle.set_device("gpu")
47+
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
48+
paddle.set_device("npu")
49+
else:
50+
paddle.set_device("cpu")
4951
world_size = paddle.distributed.get_world_size()
5052
if world_size > 1:
5153
paddle.distributed.init_parallel_env()

0 commit comments

Comments
 (0)