Skip to content

Commit

Permalink
Merge pull request #1593 from windstamp/npu_dev_20220322
Browse files Browse the repository at this point in the history
[NPU] Add NPU support for TransformerTTS
  • Loading branch information
yt605155624 authored Mar 23, 2022
2 parents 6813002 + 59b3de6 commit 26ef478
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions paddlespeech/t2s/exps/transformer_tts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
if paddle.is_compiled_with_cuda() and args.ngpu > 0:
paddle.set_device("gpu")
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
paddle.set_device("npu")
else:
paddle.set_device("cpu")
world_size = paddle.distributed.get_world_size()
if world_size > 1:
paddle.distributed.init_parallel_env()
Expand Down

0 comments on commit 26ef478

Please sign in to comment.