JSON output from diarization now includes sentences. Optimized senten… #3897

Merged
7 commits merged on Apr 1, 2022
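
For context, a minimal sketch of what the per-session output dict could look like once the new `sentences` field is populated. The field names ('sentence', 'start_time', 'end_time', 'speaker_label') are taken from the diff below; the session id, timestamps, and text are invented for illustration.

# Hypothetical per-session output after this change; values are illustrative only.
riva_dict_example = {
    'status': 'Success',
    'session_id': 'session_0',
    'transcription': 'hello there hi how are you',
    'speaker_count': 2,
    'words': [
        {'word': 'hello', 'start_time': 0.0, 'end_time': 0.4, 'speaker_label': 'speaker_0'},
        # ... one entry per word ...
    ],
    'sentences': [
        {'sentence': 'hello there', 'start_time': 0.0, 'end_time': 0.9, 'speaker_label': 'speaker_0'},
        {'sentence': 'hi how are you', 'start_time': 1.0, 'end_time': 2.1, 'speaker_label': 'speaker_1'},
    ],
}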
133 changes: 86 additions & 47 deletions nemo/collections/asr/parts/utils/diarization_utils.py
@@ -486,14 +486,23 @@ def make_json_output(self, uniq_id, diar_hyp, word_dict_seq_list, total_riva_dic
labels = diar_hyp[uniq_id]
n_spk = self.get_num_of_spk_from_labels(labels)
riva_dict = od(
{'status': 'Success', 'session_id': uniq_id, 'transcription': '', 'speaker_count': n_spk, 'words': [],}
{
'status': 'Success',
'session_id': uniq_id,
'transcription': '',
'speaker_count': n_spk,
'words': [],
'sentences': [],
}
)
gecko_dict = od({'schemaVersion': 2.0, 'monologues': []})
start_point, end_point, speaker = labels[0].split()
string_out = self.print_time(speaker, start_point, end_point, self.params, previous_string='')
prev_speaker = speaker
terms_list = []

sentences = []
sentence = {'speaker': speaker, 'start_point': float(start_point), 'end_point': float(end_point), 'text': ''}

logging.info(f"Creating results for Session: {uniq_id} n_spk: {n_spk} ")
for k, line_dict in enumerate(word_dict_seq_list):
word, speaker = line_dict['word'], line_dict['speaker_label']
@@ -505,24 +514,33 @@ def make_json_output(self, uniq_id, diar_hyp, word_dict_seq_list, total_riva_dic
{'speaker': {'name': None, 'id': prev_speaker}, 'terms': terms_list}
)
terms_list = []
string_out = self.print_time(speaker, start_point, end_point, self.params, previous_string=string_out)

# remove trailing space in text
sentence['text'] = sentence['text'].strip()

# store last sentence
sentences.append(sentence)

# start construction of a new sentence
sentence = {'speaker': speaker, 'start_point': start_point, 'end_point': end_point, 'text': ''}
else:
string_out = self.print_time(
speaker, start_point, end_point, self.params, previous_string=string_out, replace_time=True
)
stt_sec, end_sec = round(start_point, 2), round(end_point, 2)
# correct the ending time
sentence['end_point'] = end_point

stt_sec, end_sec = start_point, end_point
terms_list.append({'start': stt_sec, 'end': end_sec, 'text': word, 'type': 'WORD'})
string_out = self.print_word(string_out, word, self.params)

# add current word to sentence
sentence['text'] += word.strip() + ' '

self.add_json_to_dict(riva_dict, word, stt_sec, end_sec, speaker)
audacity_label_words.append(self.get_audacity_label(word, stt_sec, end_sec, speaker))
total_riva_dict[uniq_id] = riva_dict
prev_speaker = speaker

if self.params['break_lines']:
string_out = self.break_lines(string_out)
gecko_dict['monologues'].append({'speaker': {'name': None, 'id': speaker}, 'terms': terms_list})
riva_dict['transcription'] = ' '.join(word_seq_list)
self.write_and_log(uniq_id, riva_dict, string_out, audacity_label_words, gecko_dict)
self.write_and_log(uniq_id, riva_dict, audacity_label_words, gecko_dict, sentences)
return total_riva_dict

def get_realignment_ranges(self, k, word_seq_len):
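
In short, the new code above builds sentence-level entries by closing the running sentence whenever the diarization speaker label changes and opening a new one. A simplified, self-contained sketch of that pattern follows; the word-entry keys are illustrative rather than NeMo's exact schema, and the trailing-sentence flush is shown here for completeness even though it falls outside the visible hunk.

# Simplified sketch of the speaker-change segmentation added above.
# The word-entry keys below are illustrative, not NeMo's exact schema.
word_entries = [
    {'word': 'hello', 'start': 0.0, 'end': 0.4, 'speaker_label': 'speaker_0'},
    {'word': 'there', 'start': 0.5, 'end': 0.9, 'speaker_label': 'speaker_0'},
    {'word': 'hi', 'start': 1.0, 'end': 1.2, 'speaker_label': 'speaker_1'},
]

sentences = []
first = word_entries[0]
sentence = {'speaker': first['speaker_label'], 'start_point': first['start'],
            'end_point': first['end'], 'text': ''}
prev_speaker = first['speaker_label']

for entry in word_entries:
    speaker = entry['speaker_label']
    if speaker != prev_speaker:
        # speaker changed: close the running sentence and start a new one
        sentence['text'] = sentence['text'].strip()
        sentences.append(sentence)
        sentence = {'speaker': speaker, 'start_point': entry['start'],
                    'end_point': entry['end'], 'text': ''}
    else:
        # same speaker: extend the running sentence's end time
        sentence['end_point'] = entry['end']
    sentence['text'] += entry['word'].strip() + ' '
    prev_speaker = speaker

# flush the trailing sentence (not shown in the visible hunk; added here for completeness)
sentence['text'] = sentence['text'].strip()
sentences.append(sentence)
# sentences == [
#     {'speaker': 'speaker_0', 'start_point': 0.0, 'end_point': 0.9, 'text': 'hello there'},
#     {'speaker': 'speaker_1', 'start_point': 1.0, 'end_point': 1.2, 'text': 'hi'},
# ]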
@@ -946,10 +964,18 @@ def break_lines(self, string_out, max_chars_in_line=90):
return_string_out.append(org_chunk)
return '\n'.join(return_string_out)

def write_and_log(self, uniq_id, riva_dict, string_out, audacity_label_words, gecko_dict):
def write_and_log(self, uniq_id, riva_dict, audacity_label_words, gecko_dict, sentences):
"""
Write output files and display logging messages.
"""
# print the sentences in the .txt output
string_out = self.print_sentences(sentences, self.params)
if self.params['break_lines']:
string_out = self.break_lines(string_out)

# add sentences to the json array
self.add_sentences_to_dict(riva_dict, sentences)

ROOT = self.root_path
dump_json_to_file(f'{ROOT}/pred_rttms/{uniq_id}.json', riva_dict)
dump_json_to_file(f'{ROOT}/pred_rttms/{uniq_id}_gecko.json', gecko_dict)
@@ -986,41 +1012,44 @@ def print_errors(self, DER_result_dict, WDER_dict):
\nSpk. counting acc.: {DER_result_dict['total']['spk_counting_acc']:.4f}"
)

def print_time(self, speaker, start_point, end_point, params, previous_string=None, replace_time=False):
def print_sentences(self, sentences, params):
"""
Print a transcript with speaker labels and timestamps.
"""
if not previous_string:
string_out = ''
else:
string_out = previous_string
if params['colored_text']:
color = self.color_palette.get(speaker, '\033[0;37m')
else:
color = ''
# init output
string_out = ''

for sentence in sentences:
# extract info
speaker = sentence['speaker']
start_point = sentence['start_point']
end_point = sentence['end_point']
text = sentence['text']

if params['colored_text']:
color = self.color_palette.get(speaker, '\033[0;37m')
else:
color = ''

datetime_offset = 16 * 3600
if float(start_point) > 3600:
time_str = "%H:%M:%S.%f"
else:
time_str = "%M:%S.%f"
start_point, end_point = max(float(start_point), 0), max(float(end_point), 0)
start_point_str = datetime.fromtimestamp(start_point - datetime_offset).strftime(time_str)[:-4]
end_point_str = datetime.fromtimestamp(end_point - datetime_offset).strftime(time_str)[:-4]

if replace_time:
old_start_point_str = string_out.split('\n')[-1].split(' - ')[0].split('[')[-1]
word_sequence = string_out.split('\n')[-1].split(' - ')[-1].split(':')[-1].strip() + ' '
string_out = '\n'.join(string_out.split('\n')[:-1])
time_str = "[{} - {}] ".format(old_start_point_str, end_point_str)
else:
time_str = "[{} - {}] ".format(start_point_str, end_point_str)
word_sequence = ''
# cast timestamp to the correct format
datetime_offset = 16 * 3600
if float(start_point) > 3600:
time_str = '%H:%M:%S.%f'
else:
time_str = '%M:%S.%f'
start_point, end_point = max(float(start_point), 0), max(float(end_point), 0)
start_point_str = datetime.fromtimestamp(start_point - datetime_offset).strftime(time_str)[:-4]
end_point_str = datetime.fromtimestamp(end_point - datetime_offset).strftime(time_str)[:-4]

if params['print_time']:
time_str = f'[{start_point_str} - {end_point_str}] '
else:
time_str = ''

# string out concatenation
string_out += f'{color}{time_str}{speaker}: {text}\n'

if not params['print_time']:
time_str = ''
strd = "\n{}{}{}: {}".format(color, time_str, speaker, word_sequence.lstrip())
return string_out + strd
return string_out

@staticmethod
def threshold_non_speech(source_list, params):
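
As a rough usage sketch of what print_sentences now renders: given a sentences list, it produces one "[start - end] speaker: text" line per sentence. The helper below is a simplified stand-in, using a plain MM:SS.ff formatter instead of the datetime-based conversion above and skipping the colored_text and break_lines handling, just to show the expected shape of the .txt output.

# Simplified stand-in for the .txt rendering done by print_sentences above.
def render_transcript(sentences, print_time=True):
    def fmt(seconds):
        # plain MM:SS.ff formatting; the real code uses datetime.fromtimestamp
        seconds = max(float(seconds), 0)
        minutes, secs = divmod(seconds, 60)
        return f'{int(minutes):02d}:{secs:05.2f}'

    string_out = ''
    for s in sentences:
        time_str = f"[{fmt(s['start_point'])} - {fmt(s['end_point'])}] " if print_time else ''
        string_out += f"{time_str}{s['speaker']}: {s['text']}\n"
    return string_out


print(render_transcript([
    {'speaker': 'speaker_0', 'start_point': 0.0, 'end_point': 0.9, 'text': 'hello there'},
    {'speaker': 'speaker_1', 'start_point': 1.0, 'end_point': 1.2, 'text': 'hi'},
]))
# [00:00.00 - 00:00.90] speaker_0: hello there
# [00:01.00 - 00:01.20] speaker_1: hi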
@@ -1049,11 +1078,6 @@ def get_audacity_label(word, stt_sec, end_sec, speaker):
spk = speaker.split('_')[-1]
return f'{stt_sec}\t{end_sec}\t[{spk}] {word}'

@staticmethod
def print_word(string_out, word, params):
word = word.strip()
return string_out + word + " "

@staticmethod
def softmax(logits):
e = np.exp(logits - np.max(logits))
@@ -1067,3 +1091,18 @@ def get_num_of_spk_from_labels(labels):
@staticmethod
def add_json_to_dict(riva_dict, word, stt, end, speaker):
riva_dict['words'].append({'word': word, 'start_time': stt, 'end_time': end, 'speaker_label': speaker})

@staticmethod
def add_sentences_to_dict(riva_dict, sentences):
# iterate over sentences
for sentence in sentences:
# extract info
speaker = sentence['speaker']
start_point = sentence['start_point']
end_point = sentence['end_point']
text = sentence['text']

# save to riva_dict
riva_dict['sentences'].append(
{'sentence': text, 'start_time': start_point, 'end_time': end_point, 'speaker_label': speaker}
)