Skip to content

Commit

Permalink
Fixed typos, optimized plots loading for gpu mode
Browse files Browse the repository at this point in the history
Signed-off-by: Sasha Meister <[email protected]>
  • Loading branch information
ssh-meister authored Dec 1, 2023
1 parent fcb7aa9 commit 98cf987
Showing 1 changed file with 59 additions and 56 deletions.
115 changes: 59 additions & 56 deletions tools/speech_data_explorer/data_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

from sde.dataloader.dataset import Dataset
from sde.dataloader.engines.cudf_engine import cuDF
from sde.pages.statistics.plot import gpu_plot_histogram, gpu_plot_word_accuracy

# number of items in a table per page
DATA_PAGE_SIZE = 10
Expand Down Expand Up @@ -121,7 +122,10 @@ def parse_args():
help='field name for which you want to see statistics (optional). Example: pred_text_contextnet.',
)
parser.add_argument(
'--gpu', '-gpu', action='store_true', help='use GPU-acceleration',
'--gpu',
'-gpu',
action='store_true',
help='use GPU-acceleration',
)
args = parser.parse_args()

Expand Down Expand Up @@ -482,12 +486,9 @@ def append_data(


# plot histogram of specified field in data list
def plot_histogram(data, key, label, gpu_acceleration=False):
if gpu_acceleration:
data_frame = data[key].to_list()
else:
data_frame = [item[key] for item in data]

def plot_histogram(data, key, label):
data_frame = [item[key] for item in data]

fig = px.histogram(
data_frame=data_frame,
nbins=50,
Expand All @@ -501,10 +502,10 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
return fig


def plot_word_accuracy(vocabulary_data):
def plot_word_accuracy(vocabulary_data):
labels = ['Unrecognized', 'Sometimes recognized', 'Always recognized']
counts = [0, 0, 0]

if args.gpu:
counts[0] = (vocabulary_data['Accuracy'] == 0).sum()
counts[1] = (vocabulary_data['Accuracy'] < 100).sum()
Expand Down Expand Up @@ -572,28 +573,25 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
print('Loading data...')
if args.gpu:
if args.names_compared is not None:
raise Exception(f"Currently comparision mode is not available with gpu acceleation.")

raise Exception(f"Currently, comparison mode is not available with GPU acceleration.")
hypothesis_fields = ["pred_text"]
if args.show_statistics is not None:
hypothesis_fields = [args.show_statistics]

enable_plk = True
enable_pkl = True
if args.disable_caching_metrics:
enable_plk = False

enable_pkl = False
cu_df = cuDF()

dataset = Dataset(
manifest_filepath=args.manifest,
data_engine=cu_df,
hypothesis_fields=hypothesis_fields,
estimate_audio_metrics=args.estimate_audio_metrics,
enable_plk=enable_plk,
)

dataset = Dataset(manifest_filepath = args.manifest, data_engine = cu_df,
hypothesis_fields = hypothesis_fields,
estimate_audio_metrics = args.estimate_audio_metrics,
enable_pkl = enable_pkl)

dataset = dataset.process()

data = dataset.samples_data
num_hours = dataset.duration
vocabulary = dataset.vocabulary_data
Expand All @@ -602,8 +600,8 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
metrics_available = len(dataset.hypotheses) != 0
if metrics_available:
wer = dataset.hypotheses[hypothesis_fields[0]].wer
cer = dataset.hypotheses[hypothesis_fields[0]].cer
wmr = dataset.hypotheses[hypothesis_fields[0]].wmr
cer = dataset.hypotheses[hypothesis_fields[0]].cer
wmr = dataset.hypotheses[hypothesis_fields[0]].wmr
mwa = dataset.hypotheses[hypothesis_fields[0]].mwa
else:
if not comparison_mode:
Expand Down Expand Up @@ -682,31 +680,36 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):

if args.gpu:
fields = data.columns
for field in fields:
val = data[field][0]
if isinstance(val, (int, float, np.int64, np.float64)) and not isinstance(val, bool):
if field in figures_labels:
title = figures_labels[field][0] + ' (per utterance)'
else:
title = field.replace('_', ' ')
title = title[0].upper() + title[1:].lower()
figures_hist[field] = [title, gpu_plot_histogram(data, field)]
if metrics_available:
figure_word_acc = gpu_plot_word_accuracy(vocabulary_data, "Accuracy")

else:
fields = data[0].keys()

for k in fields:
if args.gpu:
val = data[k][0]
else:
for k in fields:
val = data[0][k]
if isinstance(val, (int, float)) and not isinstance(val, bool):
if k in figures_labels:
ylabel = figures_labels[k][0]
xlabel = figures_labels[k][1]
else:
title = k.replace('_', ' ')
title = title[0].upper() + title[1:].lower()
ylabel = title
xlabel = title
figures_hist[k] = [ylabel + ' (per utterance)', plot_histogram(data, k, xlabel, args.gpu)]

if metrics_available:
if args.gpu:
figure_word_acc = plot_word_accuracy(vocabulary_data)
else:
figure_word_acc = plot_word_accuracy(vocabulary)
if isinstance(val, (int, float)) and not isinstance(val, bool):
if k in figures_labels:
ylabel = figures_labels[k][0]
xlabel = figures_labels[k][1]
else:
title = k.replace('_', ' ')
title = title[0].upper() + title[1:].lower()
ylabel = title
xlabel = title
figures_hist[k] = [ylabel + ' (per utterance)', plot_histogram(data, k, xlabel)]

if metrics_available:
figure_word_acc = plot_word_accuracy(vocabulary)

stats_layout = [
dbc.Row(dbc.Col(html.H5(children='Global Statistics'), class_name='text-secondary'), class_name='mt-3'),
dbc.Row(
Expand Down Expand Up @@ -813,6 +816,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
dbc.Col(html.Div('{}'.format(sorted(alphabet))),), class_name='mt-2 bg-light font-monospace rounded border'
),
]

for k in figures_hist:
stats_layout += [
dbc.Row(dbc.Col(html.H5(figures_hist[k][0]), class_name='text-secondary'), class_name='mt-3'),
Expand All @@ -827,7 +831,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):

wordstable_columns = [{'name': 'Word', 'id': 'Word'}, {'name': 'Count', 'id': 'Amount'}]

if args.gpu:
if args.gpu:
vocabulary_columns = vocabulary.columns
else:
vocabulary_columns = vocabulary[0].keys()
Expand Down Expand Up @@ -910,22 +914,22 @@ def update_wordstable(page_current, sort_by, filter_query):
if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
if args.gpu:
vocabulary_view = vocabulary_view.loc[getattr(operator, op)(vocabulary_view[col_name], filter_value)]
else:
else:
vocabulary_view = [x for x in vocabulary_view if getattr(operator, op)(x[col_name], filter_value)]
elif op == 'contains':
vocabulary_view = [x for x in vocabulary_view if filter_value in str(x[col_name])]

if len(sort_by):
col = sort_by[0]['column_id']
ascending = sort_by[0]['direction'] != 'desc'

if args.gpu:
vocabulary_view = vocabulary_view.sort_values(col, ascending=ascending)
else:
vocabulary_view = sorted(vocabulary_view, key=lambda x: x[col], reverse=descending)
vocabulary_view = sorted(vocabulary_view, key=lambda x: x[col], reverse=ascending)
if page_current * DATA_PAGE_SIZE >= len(vocabulary_view):
page_current = len(vocabulary_view) // DATA_PAGE_SIZE

if args.gpu:
return [
vocabulary_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
Expand All @@ -937,7 +941,6 @@ def update_wordstable(page_current, sort_by, filter_query):
math.ceil(len(vocabulary_view) / DATA_PAGE_SIZE),
]


if args.gpu:
col_names = data.columns
else:
Expand Down Expand Up @@ -1565,22 +1568,22 @@ def update_datatable(page_current, sort_by, filter_query):
if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
if args.gpu:
data_view = data_view.loc[getattr(operator, op)(data_view[col_name], filter_value)]
else:
else:
data_view = [x for x in data_view if getattr(operator, op)(x[col_name], filter_value)]
elif op == 'contains':
data_view = [x for x in data_view if filter_value in str(x[col_name])]

if len(sort_by):
col = sort_by[0]['column_id']
ascending = sort_by[0]['direction'] != 'desc'

if args.gpu:
data_view = data_view.sort_values(col, ascending=ascending)
else:
data_view = sorted(data_view, key=lambda x: x[col], reverse=descending)
if page_current * DATA_PAGE_SIZE >= len(data_view):
page_current = len(data_view) // DATA_PAGE_SIZE

if args.gpu:
return [
data_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
Expand Down

0 comments on commit 98cf987

Please sign in to comment.