
Commit

refactor metrics and add a useful eval script
wq2012 committed Jul 6, 2024
1 parent 4a2e8cd commit afa0099
Showing 2 changed files with 85 additions and 9 deletions.
30 changes: 21 additions & 9 deletions DiarizationLM/diarizationlm/metrics.py
@@ -42,12 +42,9 @@ class UtteranceMetrics:
   wder_total: int = 0


-def compute_utterance_metrics(
-    hyp_text: str,
-    ref_text: str,
-    hyp_spk: Optional[str] = None,
-    ref_spk: Optional[str] = None,
-) -> UtteranceMetrics:
+def compute_wer(
+    hyp_text: str, ref_text: str
+) -> tuple[UtteranceMetrics, list[tuple[int, int]]]:
   """Compute the word error rate of an utterance."""
   result = UtteranceMetrics()
   hyp_normalized = utils.normalize_text(hyp_text)
@@ -72,10 +72,25 @@ def compute_utterance_metrics(

   result.wer_total = result.wer_correct + result.wer_sub + result.wer_delete
   assert result.wer_total == len(ref_words)
+  return result, align
+
+
+def compute_utterance_metrics(
+    hyp_text: str,
+    ref_text: str,
+    hyp_spk: Optional[str] = None,
+    ref_spk: Optional[str] = None,
+) -> UtteranceMetrics:
+  """Compute all metrics of an utterance."""
+  hyp_normalized = utils.normalize_text(hyp_text)
+  ref_normalized = utils.normalize_text(ref_text)
+  hyp_words = hyp_normalized.split()
+  ref_words = ref_normalized.split()
+
+  result, align = compute_wer(hyp_text, ref_text)

   # Compute WDER if needed.
-  compute_wder = hyp_spk or ref_spk
-  if not compute_wder:
+  compute_diarization_metrics = hyp_spk or ref_spk
+  if not compute_diarization_metrics:
     return result

   if not (hyp_spk and ref_spk):
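Note (not part of the commit): after this refactor, compute_wer() can be called on its own to get the per-utterance WER counts plus the word alignment, while compute_utterance_metrics() wraps it and only fills the speaker metrics when hyp_spk and ref_spk are provided. A minimal usage sketch based on the signatures in the hunks above; the example texts and the one-speaker-label-per-word format of hyp_spk/ref_spk are assumptions, not taken from this diff:

from diarizationlm import metrics

# compute_wer() returns the per-utterance WER counts plus the word alignment.
wer_result, align = metrics.compute_wer(
    hyp_text="hello good morning how are you",
    ref_text="hello morning how are you",
)
print(wer_result.wer_correct, wer_result.wer_sub, wer_result.wer_delete)

# compute_utterance_metrics() calls compute_wer() internally and, when both
# speaker strings are given, also fills the WDER-related fields.
result = metrics.compute_utterance_metrics(
    hyp_text="hello good morning how are you",
    ref_text="hello morning how are you",
    hyp_spk="1 1 1 2 2 2",  # assumed format: one speaker label per hypothesis word
    ref_spk="1 1 2 2 2",    # assumed format: one speaker label per reference word
)
print(result.wer_total, result.wder_total)

With text-only input, compute_utterance_metrics() returns right after the WER computation, matching the early return guarded by compute_diarization_metrics above.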
64 changes: 64 additions & 0 deletions DiarizationLM/unsloth/4_eval.py
@@ -0,0 +1,64 @@
"""Post-process completions and compute metrics."""
import os
import json
import tqdm

import config

from diarizationlm import utils
from diarizationlm import metrics


def postprocess(input_file: str, output_file: str) -> None:
  # Load data.
  with open(input_file, "rt") as f:
    data_dict = json.load(f)

  po = utils.PromptOptions(completion_suffix=config.COMPLETION_SUFFIX)

  # Process utterances.
  for utt in tqdm.tqdm(data_dict["utterances"]):
    utils.postprocess_completions_for_utt(
        utt,
        llm_text_field="llm_text",
        llm_speaker_field="llm_spk",
        transfered_llm_speaker_field="hyp_spk_llm",
        hyp_text_field="hyp_text",
        hyp_spk_field="hyp_spk",
        po=po,
    )

  # Write output.
  with open(output_file, "wt") as f:
    json.dump(data_dict, f, indent=2)
  print("Output JSON processed file written to:", output_file)


def evaluate(input_file: str, output_file: str) -> None:
  with open(input_file) as f:
    json_dict = json.load(f)
  result_dict = metrics.compute_metrics_on_json_dict(
      json_dict,
      ref_text_field="ref_text",
      hyp_text_field="hyp_text",
      ref_spk_field="ref_spk",
      hyp_spk_field="hyp_spk_llm",
  )
  with open(output_file, "wt") as f:
    json.dump(result_dict, f, indent=2)
  print("Output JSON metrics file written to:", output_file)


if __name__ == "__main__":
  for eval_dataset in config.EVAL_INPUTS:
    print("Evaluating:", eval_dataset)
    output_dir = os.path.join(config.MODEL_ID,
                              "decoded",
                              f"checkpoint-{config.CHECKPOINT}",
                              eval_dataset)
    postprocess(
        input_file=os.path.join(output_dir, "final.json"),
        output_file=os.path.join(output_dir, "postprocessed.json"))
    evaluate(
        input_file=os.path.join(output_dir, "postprocessed.json"),
        output_file=os.path.join(output_dir, "metrics.json"))
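The script reads MODEL_ID, CHECKPOINT, COMPLETION_SUFFIX, and EVAL_INPUTS from the sibling config module (unsloth/config.py, not shown in this commit). A hypothetical sketch of the shape of those values, with placeholder data only:

# Placeholder values for illustration; the real ones live in unsloth/config.py.
MODEL_ID = "path/to/finetuned_model"            # run directory that holds "decoded/"
CHECKPOINT = 1000                               # which checkpoint was used for decoding
COMPLETION_SUFFIX = " [eod]"                    # suffix appended to each LLM completion
EVAL_INPUTS = ["fisher_test", "callhome_test"]  # dataset names under decoded/checkpoint-*/

For each dataset, the script expects MODEL_ID/decoded/checkpoint-CHECKPOINT/dataset/final.json to exist and writes postprocessed.json and metrics.json next to it.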
