Main script refactoring: dataloading
Added multiprocessing and engine choice (currently pd.cudf supported)

Signed-off-by: Sasha Meister <[email protected]>
ssh-meister authored Nov 27, 2023
1 parent 278fd66 commit df20f08
Showing 1 changed file with 238 additions and 0 deletions.
238 changes: 238 additions & 0 deletions tools/speech_data_explorer/sde/dataloader/dataset.py
@@ -0,0 +1,238 @@
import json
from tqdm import tqdm
import multiprocessing as mp
import pickle
import os
from datetime import datetime

from nemo.utils import logging
from .sample import Sample
from .engines.cudf_engine import cuDF

class Dataset:
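    """Parses a JSON-lines manifest into chunks and aggregates dataset-level
    statistics: word/char counts, duration, charset, word frequencies and
    per-hypothesis error metrics."""
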
    def __init__(self, manifest_filepath: str, chunksize: int = 10000, data_engine: object = None, n_jobs: int = -1,
                 reference_field: str = "text", hypothesis_fields: list[str] = None, hypothesis_labels: list[str] = None,
                 estimate_audio_metrics: bool = False,
                 enable_pkl: bool = True, pkl_filepath: str = None):
self.manifest_filepath = manifest_filepath
self.chunksize = chunksize
self.data_engine = data_engine
self.n_jobs = n_jobs

max_jobs = mp.cpu_count()
        if self.n_jobs == -1 or self.n_jobs > max_jobs:
self.n_jobs = max_jobs

        self.reference_field = reference_field
        # Avoid a shared mutable default: _check_hypotheses pops from this list in place.
        self.hypothesis_fields = hypothesis_fields if hypothesis_fields is not None else ["pred_text"]
self.hypothesis_labels = hypothesis_labels
self.hypotheses = dict()
self.estimate_audio_metrics = estimate_audio_metrics
        self.enable_pkl = enable_pkl
        self.pkl_filepath = pkl_filepath
self.chunks = []

self.num_words = 0
self.num_chars = 0
self.duration = 0
self.charset = set()
self.words_frequencies = dict()

self.samples_data = []
self.vocabulary_data = []

def _check_hypotheses(self, manifest_line: str):
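        """Validate hypothesis fields/labels against the first manifest line,
        dropping any field that is absent from the sample."""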
if self.hypothesis_fields is not None:
if self.hypothesis_labels is None:
if len(self.hypothesis_fields) == 1:
self.hypothesis_labels = [""]
else:
                    self.hypothesis_labels = [str(i) for i in range(1, len(self.hypothesis_fields) + 1)]

if len(self.hypothesis_labels) != len(self.hypothesis_fields):
logging.error(f"Amount of hypothesis_labels ({len(self.hypothesis_labels)}) is not equal to amount of hypothesis_fields ({len(self.hypothesis_fields)}).")
raise
else:
sample_to_check = json.loads(manifest_line)

i = 0
while i < len(self.hypothesis_fields):
hypothesis_field = self.hypothesis_fields[i]
if hypothesis_field not in sample_to_check:
logging.warning(f"Field '{hypothesis_field}' not found in sample.")
self.hypothesis_fields.pop(i)
self.hypothesis_labels.pop(i)
else:
logging.info(f"Field '{hypothesis_field}' was found (labeled as '{self.hypothesis_labels[i]}').")
self.hypotheses[hypothesis_field] = HypothesisMetrics(hypothesis_label = self.hypothesis_labels[i])
i += 1

def _read_manifest(self):
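        """Read the manifest and split its lines into DataChunk objects of
        self.chunksize lines each."""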
logging.info("Reading manifest..")
with open(self.manifest_filepath, 'r', encoding = "utf8") as manifest:
lines = manifest.readlines()

self._check_hypotheses(lines[0])

lines_amount = len(lines)
logging.info(f"Lines amount: {lines_amount}. Splitting to chunks ({self.chunksize} lines per chunk)..")

        start_chunk_indices = list(range(0, lines_amount, self.chunksize))
        end_chunk_indices = list(range(self.chunksize, lines_amount, self.chunksize)) + [lines_amount]

        for start_idx, end_idx in tqdm(zip(start_chunk_indices, end_chunk_indices), total = len(start_chunk_indices)):
chunk = DataChunk(manifest_lines = lines[start_idx : end_idx], data_engine = self.data_engine, reference_field = self.reference_field,
hypothesis_fields = self.hypothesis_fields, hypothesis_labels = self.hypothesis_labels,
estimate_audio_metrics = self.estimate_audio_metrics)
self.chunks.append(chunk)

    def _get_pkl_filepath(self):
        timestamp = datetime.fromtimestamp(os.path.getmtime(self.manifest_filepath)).strftime('%Y-%m-%d_%H-%M-%S')
        return f"{self.manifest_filepath.replace('.json', '')}_{timestamp}.pkl"

def _read_pickle(self):
        with open(self.pkl_filepath, 'rb') as pkl:
return pickle.load(pkl)

def _write_pickle(self):
        logging.info('Saving .pkl file..')
        with open(self.pkl_filepath, 'wb') as pkl:
            pickle.dump(self, pkl, pickle.HIGHEST_PROTOCOL)

        logging.info(f'{self.pkl_filepath} saved.')

def process(self):
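        """Return a cached pickled Dataset if one exists; otherwise process all
        chunks in parallel and aggregate the global metrics."""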
        if self.enable_pkl:
            if self.pkl_filepath is None:
                self.pkl_filepath = self._get_pkl_filepath()

            logging.info(f'Looking for .pkl file ({self.pkl_filepath})')

            if os.path.exists(self.pkl_filepath):
                logging.info(f'{self.pkl_filepath} found.')
                return self._read_pickle()
            else:
                logging.info(f'{self.pkl_filepath} not found. Loading data from manifest..')

self._read_manifest()

processed_chunks = []
        logging.info(f'Processing samples ({self.n_jobs} processes)..')
with mp.Pool(self.n_jobs) as pool:
for processed_chunk in tqdm(pool.imap(DataChunk.process, self.chunks), total=len(self.chunks)):
processed_chunks.append(processed_chunk)

self.chunks = processed_chunks

        logging.info('Computing global metrics..')
for chunk in tqdm(self.chunks):
self.num_words += chunk.num_words
self.num_chars += chunk.num_chars
self.duration += chunk.duration
self.charset.update(chunk.charset)

for hypothesis_field in chunk.hypotheses:
self.hypotheses[hypothesis_field].update(chunk.hypotheses[hypothesis_field])

for word in chunk.words_frequencies:
self.words_frequencies[word] = self.words_frequencies.get(word, 0) + chunk.words_frequencies[word]

if self.data_engine is not None:
self.samples_data.append(chunk.samples_data)

for hypothesis_field in self.hypotheses:
self.hypotheses[hypothesis_field].compute(dataset_num_words = self.num_words, dataset_num_chars = self.num_chars)

        self.duration = round(self.duration / 3600, 2)  # seconds -> hours

if self.data_engine is not None:
            logging.info('Loading samples datatable..')
self.samples_data = self.data_engine.concat_samples_chunks(self.samples_data)
self.vocabulary_data = self.data_engine.process_vocabulary(words_frequencies = self.words_frequencies,
hypotheses_metrics = self.hypotheses.values())

        if self.enable_pkl:
            self._write_pickle()

return self


class DataChunk:
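    """A slice of manifest lines processed as a unit inside a worker process;
    accumulates the same counters as Dataset, but only for its own lines."""
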
    def __init__(self, manifest_lines: list[str], data_engine: object = None,
                 reference_field: str = "text", hypothesis_fields: list[str] = None, hypothesis_labels: list[str] = None,
                 estimate_audio_metrics: bool = False):
self.manifest_lines = manifest_lines
self.reference_field = reference_field
self.estimate_audio_metrics = estimate_audio_metrics
self.samples_dicts = []
self.num_words = 0
self.num_chars = 0
self.duration = 0
self.charset = set()
self.words_frequencies = dict()

        # Avoid a shared mutable default; Dataset passes these explicitly.
        self.hypothesis_fields = hypothesis_fields if hypothesis_fields is not None else ["pred_text"]
        self.hypothesis_labels = hypothesis_labels
self.hypotheses = dict()

for field, label in zip(hypothesis_fields, hypothesis_labels):
self.hypotheses[field] = HypothesisMetrics(hypothesis_label = label)

self.data_engine = data_engine
self.samples_data = None

def process(self):
sample = Sample()
for manifest_line in self.manifest_lines:
sample.parse_line(manifest_line, reference_field = self.reference_field,
hypothesis_fields = self.hypothesis_fields,
hypothesis_labels = self.hypothesis_labels)
sample.compute(estimate_audio_metrics = self.estimate_audio_metrics)

self.samples_dicts.append(sample.sample_dict)
self.num_words += sample.num_words
self.num_chars += sample.num_chars
self.duration += sample.duration
self.charset.update(sample.charset)

for word in sample.words_frequencies:
self.words_frequencies[word] = self.words_frequencies.get(word, 0) + sample.words_frequencies[word]

for hypothesis_field in sample.hypotheses:
self.hypotheses[hypothesis_field].update(sample.hypotheses[hypothesis_field])

sample.reset()

if self.data_engine is not None:
self.samples_data = self.data_engine.load_samples_chunk(self.samples_dicts)
            self.samples_dicts = []  # release per-sample dicts once loaded into the engine

return self

class HypothesisMetrics:
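    """Accumulates word/char distances and match counts for one hypothesis field;
    compute() converts them into WER, WMR and CER percentages."""
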
def __init__(self, hypothesis_label: str = None):
self.hypothesis_label = hypothesis_label
self.word_distance = 0
self.word_match = 0
self.char_distance = 0

self.wer = None
self.wmr = None
self.cer = None
self.mwa = None

self.match_words_frequencies = dict()

def update(self, hypothesis: object):
assert self.hypothesis_label == hypothesis.hypothesis_label, "Hypothesis label mismatch!"

self.word_distance += hypothesis.word_distance
self.word_match += hypothesis.word_match
self.char_distance += hypothesis.char_distance

for word in hypothesis.match_words_frequencies:
self.match_words_frequencies[word] = self.match_words_frequencies.get(word, 0) + hypothesis.match_words_frequencies[word]

def compute(self, dataset_num_words: int, dataset_num_chars: int):
self.wer = round(self.word_distance / dataset_num_words * 100.0, 2)
self.wmr = round(self.word_match / dataset_num_words * 100.0, 2)
self.cer = round(self.char_distance / dataset_num_chars * 100.0, 2)
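
For reference, a minimal usage sketch (not part of the diff). It assumes a NeMo-style JSON-lines manifest with "text", "pred_text" and "duration" fields, that the module is importable as sde.dataloader, and that the cuDF engine class takes no constructor arguments; the manifest path is illustrative.

# Minimal usage sketch; the path, import layout and cuDF() arguments are assumptions.
from sde.dataloader.dataset import Dataset
from sde.dataloader.engines.cudf_engine import cuDF

dataset = Dataset(
    manifest_filepath="manifest.json",  # hypothetical manifest path
    chunksize=10000,
    data_engine=cuDF(),                 # optional; None skips datatable loading
    n_jobs=4,                           # -1 uses all available CPU cores
    hypothesis_fields=["pred_text"],
    enable_pkl=True,                    # cache the processed Dataset next to the manifest
)
dataset = dataset.process()             # returns the cached object if a matching .pkl exists
print(dataset.duration, "hours,", len(dataset.charset), "distinct characters")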
