We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3568bb6 commit 59b3de6Copy full SHA for 59b3de6
paddlespeech/t2s/exps/transformer_tts/train.py
@@ -42,10 +42,12 @@
42
def train_sp(args, config):
43
# decides device type and whether to run in parallel
44
# setup running environment correctly
45
- if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
46
- paddle.set_device("cpu")
47
- else:
+ if paddle.is_compiled_with_cuda() and args.ngpu > 0:
48
paddle.set_device("gpu")
+ elif paddle.is_compiled_with_npu() and args.ngpu > 0:
+ paddle.set_device("npu")
49
+ else:
50
+ paddle.set_device("cpu")
51
world_size = paddle.distributed.get_world_size()
52
if world_size > 1:
53
paddle.distributed.init_parallel_env()
0 commit comments