Cherrypick: Support dataloader as input to audio for transcription (#9201) #9235

Merged 1 commit on May 17, 2024

Conversation

@titu1994 (Collaborator) commented May 17, 2024

  • Support dataloader as input to audio for transcription

Signed-off-by: smajumdar [email protected]

  • Apply isort and black reformatting

Signed-off-by: titu1994 [email protected]

  • Support dataloader as input to audio for transcription

Signed-off-by: smajumdar [email protected]

  • Update transcribe signatures

Signed-off-by: smajumdar [email protected]

  • Apply isort and black reformatting

Signed-off-by: titu1994 [email protected]


Signed-off-by: smajumdar [email protected]
Signed-off-by: titu1994 [email protected]
(cherry picked from commit 67401ed)

What does this PR do?

Enables the use of a pre-constructed data loader as input to the model.transcribe() function.
This provides a fast path that skips all internal manifest and tensor handling, leaving that work to the user, and executes only the model forward pass and the steps after it.

Collection: [ASR]

Changelog

  • Allows the user to provide a DataLoader object, which bypasses the internal manifest processing and dataset construction
  • Trusts user-provided input as-is: a user who chooses to provide a dataloader is responsible for formatting the batches and supplying all arguments so that they match the ASR model's forward arguments.
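
The override described in the changelog can be pictured as a simple type dispatch. The following is a minimal sketch of that idea only, not NeMo's actual implementation: `BatchLoader`, `resolve_loader`, and `build_from_paths` are hypothetical names, and `BatchLoader` stands in for `torch.utils.data.DataLoader` so the example runs without PyTorch.

```python
from typing import Callable, List


class BatchLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, batches: List[tuple]):
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)


def resolve_loader(audio, build_loader: Callable[[List[str]], BatchLoader]) -> BatchLoader:
    """Return `audio` unchanged if it is already a loader, else build one."""
    if isinstance(audio, BatchLoader):
        # Fast path: the user's batches are trusted and used as-is.
        return audio
    # Default path: treat `audio` as file paths and build a loader internally.
    return build_loader(audio)


def build_from_paths(paths: List[str]) -> BatchLoader:
    # Placeholder for manifest processing and dataset construction.
    return BatchLoader([("loaded:" + p,) for p in paths])


user_loader = BatchLoader([("batch-0",), ("batch-1",)])
same = resolve_loader(user_loader, build_from_paths)     # returned untouched
built = resolve_loader(["a.wav"], build_from_paths)      # built internally
```

The point of the sketch is that nothing validates or reshapes the user's batches on the fast path, which is why the batch layout becomes the user's responsibility.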

Usage

import os
from typing import Dict, List, Optional

import numpy as np
import soundfile as sf
import torch
from torch.utils.data import DataLoader, Dataset

from nemo.collections.asr.data.audio_to_text import _speech_collate_fn
from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")

# Load audio files as float32 arrays.
# test_data_dir is assumed to point at the NeMo test data directory.
audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav")
audio, sr = sf.read(audio_file, dtype='float32')

audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an152-mwhw-b.wav")
audio2, sr = sf.read(audio_file2, dtype='float32')

# Create a dummy dataset to hold the audio arrays
class DummyDataset(Dataset):
    def __init__(self, audio_tensors: List[np.ndarray], config: Optional[Dict] = None):
        self.audio_tensors = audio_tensors
        self.config = config

    def __getitem__(self, index):
        data = self.audio_tensors[index]
        samples = torch.tensor(data)
        # Sequence length is the number of audio samples
        seq_len = torch.tensor(samples.shape[0], dtype=torch.long)

        # Dummy text tokens; transcription does not use them
        text_tokens = torch.tensor([0], dtype=torch.long)
        text_tokens_len = torch.tensor(1, dtype=torch.long)

        # Return a tuple that the ASR model's forward function can consume
        return (samples, seq_len, text_tokens, text_tokens_len)

    def __len__(self):
        return len(self.audio_tensors)

# Wrap the dataset in a data loader with a padding collate function
dataset = DummyDataset([audio, audio2])
collate_fn = lambda x: _speech_collate_fn(x, pad_id=0)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn)

# Pass the DataLoader in place of the audio argument
outputs = model.transcribe(dataloader, batch_size=1)

# One transcription string per audio file
assert len(outputs) == 2
assert isinstance(outputs[0], str)
assert isinstance(outputs[1], str)
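
The collate step in the usage above has to turn variable-length items into one padded batch. The sketch below illustrates that padding idea with plain Python lists standing in for tensors; `pad_collate` is a hypothetical helper written for this example and is not NeMo's `_speech_collate_fn`, whose exact behavior may differ.

```python
from typing import List, Sequence, Tuple

# One dataset item: (samples, num_samples, tokens, num_tokens)
Sample = Tuple[Sequence[float], int, Sequence[int], int]


def pad_collate(batch: List[Sample], pad_id: int = 0):
    """Pad audio with zeros and tokens with `pad_id` to the longest item in the batch."""
    max_audio = max(length for _, length, _, _ in batch)
    max_tokens = max(t_len for _, _, _, t_len in batch)

    audio, audio_lens, tokens, token_lens = [], [], [], []
    for samples, length, toks, t_len in batch:
        # Right-pad each field so every row in the batch has the same length.
        audio.append(list(samples) + [0.0] * (max_audio - length))
        audio_lens.append(length)
        tokens.append(list(toks) + [pad_id] * (max_tokens - t_len))
        token_lens.append(t_len)
    return audio, audio_lens, tokens, token_lens


items = [([0.1, 0.2, 0.3], 3, [5], 1), ([0.4], 1, [7, 8], 2)]
audio_b, audio_lens_b, tokens_b, token_lens_b = pad_collate(items)
```

Keeping the original lengths alongside the padded data is what lets downstream code ignore the padding.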

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs in various areas.

Additional Information

  • Related to # (issue)

@github-actions github-actions bot added the ASR label May 17, 2024
audio2, sr = sf.read(audio_file2, dtype='float32')

dataset = DummyDataset([audio, audio2])
collate_fn = lambda x: _speech_collate_fn(x, pad_id=0)

Check notice: Code scanning / CodeQL

Returning tuples with varying lengths (Note)

TestTranscriptionMixin.test_transcribe_dataloader.lambda returns tuple of size 4 and tuple of size 5.
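
The notice means that callers of the flagged lambda cannot rely on a fixed tuple size. One way a consumer might cope is to normalize both shapes before unpacking. This is a hedged sketch: `split_batch` is a hypothetical helper, and treating the fifth element as an optional extra field is an assumption made here, since the notice itself reports only the two sizes.

```python
from typing import Any, Optional, Sequence, Tuple


def split_batch(batch: Sequence[Any]) -> Tuple[tuple, Optional[Any]]:
    """Split a collated batch into its four core fields and an optional extra."""
    if len(batch) == 4:
        return tuple(batch), None
    if len(batch) == 5:
        # Assumed layout: the first four fields are the core ones.
        return tuple(batch[:4]), batch[4]
    raise ValueError(f"expected 4 or 5 fields, got {len(batch)}")


core4, extra4 = split_batch(("audio", "audio_len", "tokens", "tokens_len"))
core5, extra5 = split_batch(("audio", "audio_len", "tokens", "tokens_len", [0, 1]))
```

Failing loudly on any other size keeps a malformed user-provided batch from being silently misread.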
@titu1994 titu1994 merged commit 6a5187b into r2.0.0rc0 May 17, 2024
13 checks passed
@titu1994 titu1994 deleted the cherrypick_transcribe_dataloader branch May 17, 2024 16:29