Skip to content

Commit

Permalink
Cleanup.
Browse files Browse the repository at this point in the history
Remove track_running_stats = False for now. Needs input from other developers.
  • Loading branch information
galv committed May 17, 2024
1 parent 9a92820 commit 3ffbcf9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 58 deletions.
45 changes: 21 additions & 24 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import os
from dataclasses import dataclass, is_dataclass
from tempfile import NamedTemporaryFile
import time
from typing import List, Optional, Union

import pytorch_lightning as pl
import soundfile as sf
import torch
from omegaconf import OmegaConf, open_dict

Expand Down Expand Up @@ -83,6 +83,8 @@
langid: Str used for convert_num_to_words during groundtruth cleaning
use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER)
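calculate_rtfx: Bool to calculate and log RTFx (inverse real-time factor: total audio duration divided by transcription wall-clock time) for the input dataset. Requires every manifest entry to have a 'duration' field.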
# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
Expand Down Expand Up @@ -199,6 +201,8 @@ class TranscriptionConfig:
allow_partial_transcribe: bool = False
extract_nbest: bool = False # Extract n-best hypotheses from the model

calculate_rtfx: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]:
Expand Down Expand Up @@ -393,11 +397,15 @@ def autocast(dtype=None, enabled=True):

# transcribe audio

total_duration = 0.0
if cfg.calculate_rtfx:
total_duration = 0.0

with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
total_duration += json.loads(line)["duration"]
with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
item = json.loads(line)
if "duration" not in item:
raise ValueError(f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.")
total_duration += item["duration"]

with autocast(dtype=amp_dtype, enabled=cfg.amp):
with torch.no_grad():
Expand All @@ -422,33 +430,19 @@ def autocast(dtype=None, enabled=True):
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name

for i in range(2):
# if i == 1:
# import nvtx
# pr = nvtx.Profile()
# pr.enable() # begin annotating function calls
# ctx = torch.autograd.profiler.emit_nvtx()
# ctx.__enter__()
# torch.cuda.cudart().cudaProfilerStart()
import time

# import ipdb; ipdb.set_trace()

if cfg.calculate_rtfx:
start_time = time.time()
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
end_time = time.time()
print("RTFx=", total_duration / (end_time - start_time))
# if i == 1:
# pr.disable()
# ctx.__exit__(None, None, None)
# torch.cuda.cudart().cudaProfilerStop()
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
if cfg.calculate_rtfx:
transcribe_time = time.time() - start_time

if cfg.dataset_manifest is not None:
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
if cfg.presort_manifest:
transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions)
else:
logging.info(f"Finished transcribing {len(filepaths)} files !")

logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

# if transcriptions form a tuple of (best_hypotheses, all_hypotheses)
Expand Down Expand Up @@ -493,6 +487,9 @@ def autocast(dtype=None, enabled=True):
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

if cfg.calculate_rtfx:
    # RTFx (inverse real-time factor) is seconds of audio transcribed per
    # second of wall-clock time, so the audio duration is the numerator.
    # The message must be an f-string for the value to be interpolated.
    logging.info(f"Dataset RTFx {(total_duration / transcribe_time):.2f}")

return cfg


Expand Down
3 changes: 0 additions & 3 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,10 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
q = q.transpose(1, 2) # (batch, time1, head, d_k)

n_batch_pos = pos_emb.size(0)

p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
# We are stuck casting this up to float32... ugh.
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
Expand Down Expand Up @@ -952,7 +950,6 @@ def extend_pe(self, length, device, dtype):
return
# positions run from positive (length - 1) down to negative (-(length - 1))
# positive positions would be used for left positions and negative for right positions
# fix this
positions = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device).unsqueeze(1)
self.create_pe(positions=positions, dtype=dtype)

Expand Down
27 changes: 11 additions & 16 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,24 +286,19 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:

audio_key = cfg.get('audio_key', 'audio_filepath')

with NamedTemporaryFile("w", suffix=".json", delete=False) as durations_f, open(cfg.dataset_manifest, "rt") as f:
for item in f:
item = json.loads(item)
with open(cfg.dataset_manifest, "rt") as fh:
for line in fh:
item = json.loads(line)
item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest)
if item.get("duration") is None:
sound = sf.SoundFile(item["audio_filepath"])
duration = sound.frames / sound.samplerate
item["duration"] = duration
print(json.dumps(item), file=durations_f, flush=True)

all_entries_have_offset_and_duration = True
for item in read_and_maybe_sort_manifest(durations_f.name, try_sort=cfg.presort_manifest):
if not ("offset" in item and "duration" in item):
all_entries_have_offset_and_duration = False
audio_file = item[audio_key]
filepaths.append(audio_file)
if item.get("duration") is None and cfg.presort_manifest:
raise ValueError(f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.")
all_entries_have_offset_and_duration = True
for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=cfg.presort_manifest):
if not ("offset" in item and "duration" in item):
all_entries_have_offset_and_duration = False
audio_file = get_full_path(audio_file=item[audio_key], manifest_file=cfg.dataset_manifest)
filepaths.append(audio_file)
partial_audio = all_entries_have_offset_and_duration
cfg.dataset_manifest = durations_f.name
logging.info(f"\nTranscribing {len(filepaths)} files...\n")

return filepaths, partial_audio
Expand Down
15 changes: 0 additions & 15 deletions nemo/core/classes/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,6 @@ def freeze(self) -> None:
for param in self.parameters():
param.requires_grad = False

# By default, if you run a batch norm module with
# training=False (calling self.eval() below makes
# training=False for a module and all its descendents),
# pytorch will still update batch norm statistics. This can
# cause a denial of service attack, because an adversary can
# provide an input that causes the batch norm statistics to
# get set to NaN or infinity. Setting track_running_stats to
# False prevents updates of the batch norm statistics.
for sub_module in self.modules():
if isinstance(sub_module,
(torch.nn.BatchNorm1d, torch.nn.LazyBatchNorm1d, torch.nn.BatchNorm2d,
torch.nn.LazyBatchNorm2d, torch.nn.BatchNorm3d, torch.nn.LazyBatchNorm3d,
torch.nn.SyncBatchNorm)):
sub_module.track_running_stats = False

self.eval()

def unfreeze(self) -> None:
Expand Down

0 comments on commit 3ffbcf9

Please sign in to comment.