Skip to content

Commit

Permalink
try a few
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Nov 5, 2024
1 parent e9d901b commit ed9f2bb
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/text/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _metric_max_over_ground_truths(

def _squad_input_check(
preds: PREDS_TYPE, targets: TARGETS_TYPE
) -> Tuple[Dict[str, str], List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]]]:
) -> Tuple[Dict[str, str], list[Dict[str, List[Dict[str, list[Dict[str, Any]]]]]]]:
"""Check for types and convert the input to necessary format to compute the input."""
if isinstance(preds, Dict):
preds = [preds]
Expand All @@ -118,7 +118,7 @@ def _squad_input_check(
f"{SQuAD_FORMAT}"
)

answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore[assignment]
answers: dict[str, Union[List[str], list[int]]] = target["answers"] # type: ignore[assignment]
if "text" not in answers:
raise KeyError(
"Expected keys in a 'answers' are 'text'."
Expand All @@ -135,7 +135,7 @@ def _squad_input_check(

def _squad_update(
preds: Dict[str, str],
target: List[Dict[str, List[Dict[str, List[Dict[str, Any]]]]]],
target: list[dict[str, list[dict[str, list[dict[str, Any]]]]]],
) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute F1 Score and Exact Match for a collection of predictions and references.
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/text/test_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2


def _reference_jiwer_wer(preds: Union[str, List[str]], target: Union[str, List[str]]):
def _reference_jiwer_wer(preds: Union[str, list[str]], target: Union[str, list[str]]):
try:
from jiwer import compute_measures
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/text/test_wip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from unittests.text._inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2


def _reference_jiwer_wip(preds: Union[str, List[str]], target: Union[str, List[str]]):
def _reference_jiwer_wip(preds: Union[str, list[str]], target: Union[str, list[str]]):
try:
from jiwer import wip
except ImportError:
Expand Down

0 comments on commit ed9f2bb

Please sign in to comment.