
fix device setting to allow using accelerator cpu (NVIDIA#8084)
* fix device setting to allow using accelerator cpu

Signed-off-by: Oren Amsalem <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Oren Amsalem <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Sasha Meister <[email protected]>
2 people authored and sashameister committed Feb 15, 2024
1 parent 825a357 commit a2a0857
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions nemo/collections/asr/modules/wav2vec_modules.py
@@ -32,8 +32,6 @@
 from nemo.core.classes.module import NeuralModule
 from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 
 class TransposeLast(torch.nn.Module):
     """
@@ -341,7 +339,7 @@ def apply_transformer(self, x, padding_mask=None):
     def create_padding_mask(self, length):
         # Broadcast to vectorize creating the padding mask
         max_len = max(length)
-        padding_mask = torch.arange(max_len, device=DEVICE)
+        padding_mask = torch.arange(max_len, device=length.device)
 
         # Switch to binary for transformer, 1 for valid tokens, 0 for padding
         padding_mask = (padding_mask.expand(len(length), max_len) < length.unsqueeze(1)).type(torch.uint8)
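The change replaces a module-level DEVICE constant (which always picked CUDA when available) with the device of the incoming length tensor, so the padding mask lands on whatever device the inputs live on and the module also runs with accelerator="cpu" on a CUDA machine. A minimal standalone sketch of the fixed behavior, mirroring create_padding_mask from the diff (the free function here is for illustration; in NeMo it is a method of the wav2vec transformer module):

```python
import torch


def create_padding_mask(length: torch.Tensor) -> torch.Tensor:
    """Build a binary padding mask: 1 for valid tokens, 0 for padding.

    Using length.device instead of a global DEVICE constant keeps the
    mask on the same device as the input tensor, avoiding an implicit
    CPU/CUDA mismatch when the model is intentionally run on CPU.
    """
    max_len = int(length.max())
    # Broadcast an arange of positions against per-sample lengths to
    # vectorize mask creation (same trick as in the patched code).
    positions = torch.arange(max_len, device=length.device)
    mask = (positions.expand(len(length), max_len) < length.unsqueeze(1)).type(torch.uint8)
    return mask


lengths = torch.tensor([3, 1, 2])
print(create_padding_mask(lengths))
# tensor([[1, 1, 1],
#         [1, 0, 0],
#         [1, 1, 0]], dtype=torch.uint8)
```

Because the mask inherits its device from `length`, the same code works unchanged whether `lengths` is a CPU tensor or has been moved to a GPU with `.to("cuda")`.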
