Skip to content

Commit

Permalink
Add experimental MPS device support for ASR inference (NVIDIA#6289)
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Bataev <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
  • Loading branch information
artbataev authored and hsiehjackson committed Jun 2, 2023
1 parent 976be2b commit b2bee61
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
batch_size: batch size during inference
cuda: Optional int to enable or disable execution of model on certain CUDA device.
allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available
amp: Bool to decide if Automatic Mixed Precision should be used during inference
audio_type: Str filetype of the audio. Supported = wav, flac, mp3
Expand Down Expand Up @@ -129,6 +130,7 @@ class TranscriptionConfig:
# device anyway, and do inference on CPU only if CUDA device is not found.
# If `cuda` is a negative number, inference will be on CPU only.
cuda: Optional[int] = None
allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU)
amp: bool = False
audio_type: str = "wav"

Expand Down Expand Up @@ -175,14 +177,25 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
if torch.cuda.is_available():
device = [0] # use 0th CUDA device
accelerator = 'gpu'
map_location = torch.device('cuda:0')
elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
logging.warning(
"MPS device (Apple Silicon M-series GPU) support is experimental."
" Env variable `PYTORCH_ENABLE_MPS_FALLBACK=1` should be set in most cases to avoid failures."
)
device = [0]
accelerator = 'mps'
map_location = torch.device('mps')
else:
device = 1
accelerator = 'cpu'
map_location = torch.device('cpu')
else:
device = [cfg.cuda]
accelerator = 'gpu'
map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
logging.info(f"Inference will be done on device : {device}")
map_location = torch.device(f'cuda:{cfg.cuda}')

logging.info(f"Inference will be done on device: {map_location}")

asr_model, model_name = setup_model(cfg, map_location)

Expand Down

0 comments on commit b2bee61

Please sign in to comment.