Skip to content

Commit fbf6fc1

Browse files
committed
[NPU] test TransformerTTS with NPU
1 parent 3568bb6 commit fbf6fc1

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

examples/ljspeech/tts1/local/train.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ python3 ${BIN_DIR}/train.py \
88
--dev-metadata=dump/dev/norm/metadata.jsonl \
99
--config=${config_path} \
1010
--output-dir=${train_output_path} \
11-
--ngpu=2 \
11+
--ngpu=1 \
1212
--phones-dict=dump/phone_id_map.txt

examples/ljspeech/tts1/run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
set -e
44
source path.sh
55

6-
gpus=0,1
6+
gpus=0
77
stage=0
88
stop_stage=100
99

paddlespeech/t2s/exps/transformer_tts/train.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@
4242
def train_sp(args, config):
4343
# decides device type and whether to run in parallel
4444
# setup running environment correctly
45-
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
46-
paddle.set_device("cpu")
47-
else:
45+
if paddle.is_compiled_with_cuda() and args.ngpu > 0:
4846
paddle.set_device("gpu")
47+
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
48+
paddle.set_device("npu")
49+
else:
50+
paddle.set_device("cpu")
4951
world_size = paddle.distributed.get_world_size()
5052
if world_size > 1:
5153
paddle.distributed.init_parallel_env()

0 commit comments

Comments
 (0)