Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for huggingface ASR models in hg recipes #751

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
feff2cd
refactor the hg interface to support multiple models through presets
Ahmedsaed Aug 15, 2024
d9f753e
Refactor ASR evaluation code for improved extensibility
Ahmedsaed Aug 15, 2024
a3329b7
correctly manage and free cuda memory
Ahmedsaed Aug 15, 2024
da5d8d3
fix type hints and lint
Ahmedsaed Aug 15, 2024
107aa5e
Use custom `EvalSeqbatch` instead of Seq2Seq2Batch
Ahmedsaed Aug 15, 2024
721baa4
Introduce AsrDatasetConfig to handle different ASR datasets
Ahmedsaed Aug 15, 2024
1f6772c
lint and fix type hints
Ahmedsaed Aug 15, 2024
45e6571
move split to AsrDatasetConfig and move tokenizer back to AsrEvalConfig
Ahmedsaed Aug 15, 2024
7fb74ed
update default split
Ahmedsaed Aug 15, 2024
657f40f
Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed Aug 15, 2024
d80d154
Add whisper integration
Ahmedsaed Aug 16, 2024
9f035ce
Intorduce ModelConfig for dynamically loading huggingface models
Ahmedsaed Aug 16, 2024
58dd05c
Lint and update type hints
Ahmedsaed Aug 16, 2024
f92c551
lint and fix type hints
Ahmedsaed Aug 16, 2024
dc684f0
Refactor AsrDatasetConfig
Ahmedsaed Aug 20, 2024
e081e77
Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed Aug 20, 2024
775f191
Merge branch 'main' into hg/whisper
Ahmedsaed Aug 20, 2024
433f463
refactor dynamically imported libraries
Ahmedsaed Aug 22, 2024
0f08e00
Merge branch 'hg/whisper' of https://github.com/Ahmedsaed/fairseq2 in…
Ahmedsaed Aug 22, 2024
460c777
Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed Aug 26, 2024
58d4e03
Merge branch 'main' into hg/whisper
Ahmedsaed Aug 26, 2024
a0a266b
refactor code and lint
Ahmedsaed Sep 4, 2024
058bdf0
Merge branch 'hg/AsrDatasetConfig' of https://github.com/Ahmedsaed/fa…
Ahmedsaed Sep 4, 2024
201127c
lint
Ahmedsaed Sep 4, 2024
667194f
refactor code and lint
Ahmedsaed Sep 4, 2024
5fc5fed
lint
Ahmedsaed Sep 4, 2024
c645950
Merge branch 'hg/AsrDatasetConfig' of https://github.com/Ahmedsaed/fa…
Ahmedsaed Sep 4, 2024
0aaa6bb
Merge branch 'hg/AsrDatasetConfig' of https://github.com/Ahmedsaed/fa…
Ahmedsaed Sep 4, 2024
9d860f5
Merge branch 'main' into hg/whisper
Ahmedsaed Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor AsrDatasetConfig
Ahmedsaed committed Aug 20, 2024
commit dc684f03a6b91f858d8a5f106e30013cdcee7d79
25 changes: 7 additions & 18 deletions src/fairseq2/recipes/hg/asr_eval.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, List, Optional, Union
from typing import Any, List, Optional

import torch
from datasets import ( # type: ignore[attr-defined,import-untyped,import-not-found]
@@ -39,30 +39,19 @@ class AsrDatasetConfig:
"""Configuration for an automatic speech recognition dataset."""

dataset_path: str
"""The name of the dataset."""
"""The path to the dataset."""

dataset_name: Optional[str] = None
"""The name of the dataset split."""
"""The name of the dataset configuration."""

split: str = "test"
"""The name of the dataset split to evaluate with."""
"""Which split of the data to load."""

source_column: List[str] = field(default_factory=list)
"""The path of the column containing the source audio."""
"""The path to the column containing the source audio."""

target_column: List[str] = field(default_factory=list)
"""The path of the column containing the target text."""

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AsrDatasetConfig":
"""Create an AsrDatasetConfig instance from a configuration dictionary."""
return cls(
dataset_path=config_dict.get("dataset_path", ""),
dataset_name=config_dict.get("dataset_name"),
source_column=config_dict.get("source_column", []),
target_column=config_dict.get("target_column", []),
split=config_dict.get("split", "test"),
)
"""The path to the column containing the target text."""

def get_source_data(self, ds: Example) -> Any:
"""Retrieve the source (audio) data from the dataset."""
@@ -75,7 +64,7 @@ def get_target_data(self, ds: Example) -> Any:
return results

@staticmethod
def _get_data(ds: Example, path: List[str]) -> Union[Example, List[int], str]:
def _get_data(ds: Example, path: List[str]) -> Example | List[int] | str:
"""Retrieve data from the dataset using the specified path."""
current = ds
for key in path: