Skip to content

Commit

Permalink
Merge pull request #350 from LAAC-LSCP/annotations/import_any_eaf_tier
Browse files Browse the repository at this point in the history
Annotations/import any eaf tier
  • Loading branch information
marianne-m authored Feb 22, 2022
2 parents f3486ce + b4d075b commit 1475102
Show file tree
Hide file tree
Showing 6 changed files with 3,596 additions and 12 deletions.
13 changes: 9 additions & 4 deletions ChildProject/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,14 @@ def write(self):
)

def _import_annotation(
self, import_function: Callable[[str], pd.DataFrame], annotation: dict
self, import_function: Callable[[str], pd.DataFrame], params: dict, annotation: dict
):
"""import and convert ``annotation``. This function should not be called outside of this class.
:param import_function: If callable, ``import_function`` will be called to convert the input annotation into a dataframe. Otherwise, the conversion will be performed by a built-in function.
:type import_function: Callable[[str], pd.DataFrame]
:param params: Optional parameters, forwarded as keyword arguments to the built-in converter. With ``new_tiers``, the corresponding EAF tiers will be imported.
:type params: dict
:param annotation: input annotation dictionary (attributes defined according to :ref:`ChildProject.annotations.AnnotationManager.SEGMENTS_COLUMNS`)
:type annotation: dict
:return: output annotation dictionary (attributes defined according to :ref:`ChildProject.annotations.AnnotationManager.SEGMENTS_COLUMNS`)
Expand Down Expand Up @@ -427,7 +429,7 @@ def _import_annotation(
df = import_function(path)
elif annotation_format in converters:
converter = converters[annotation_format]
df = converter.convert(path, filter)
df = converter.convert(path, filter, **params)
else:
raise ValueError(
"file format '{}' unknown for '{}'".format(annotation_format, path)
Expand Down Expand Up @@ -488,6 +490,7 @@ def import_annotations(
input: pd.DataFrame,
threads: int = -1,
import_function: Callable[[str], pd.DataFrame] = None,
new_tiers: list = None,
) -> pd.DataFrame:
"""Import and convert annotations.
Expand All @@ -497,6 +500,8 @@ def import_annotations(
:type threads: int, optional
:param import_function: If specified, the custom ``import_function`` function will be used to convert all ``input`` annotations, defaults to None
:type import_function: Callable[[str], pd.DataFrame], optional
:param new_tiers: List of EAF tier names. If specified, the corresponding EAF tiers will be imported, defaults to None
:type new_tiers: list[str], optional
:return: dataframe of imported annotations, as in :ref:`format-annotations`.
:rtype: pd.DataFrame
"""
Expand Down Expand Up @@ -536,12 +541,12 @@ def import_annotations(

if threads == 1:
imported = input.apply(
partial(self._import_annotation, import_function), axis=1
partial(self._import_annotation, import_function, {"new_tiers": new_tiers}), axis=1
).to_dict(orient="records")
else:
with mp.Pool(processes=threads if threads > 0 else mp.cpu_count()) as pool:
imported = pool.map(
partial(self._import_annotation, import_function),
partial(self._import_annotation, import_function, {"new_tiers": new_tiers}),
input.to_dict(orient="records"),
)

Expand Down
18 changes: 10 additions & 8 deletions ChildProject/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class CsvConverter(AnnotationConverter):
FORMAT = "csv"

@staticmethod
def convert(filename: str, filter="") -> pd.DataFrame:
def convert(filename: str, filter="", **kwargs) -> pd.DataFrame:
return pd.read_csv(filename)


Expand All @@ -83,7 +83,7 @@ class VtcConverter(AnnotationConverter):
)

@staticmethod
def convert(filename: str, source_file: str = "") -> pd.DataFrame:
def convert(filename: str, source_file: str = "", **kwargs) -> pd.DataFrame:
rttm = pd.read_csv(
filename,
sep=" ",
Expand Down Expand Up @@ -149,7 +149,7 @@ class VcmConverter(AnnotationConverter):
)

@staticmethod
def convert(filename: str, source_file: str = "") -> pd.DataFrame:
def convert(filename: str, source_file: str = "", **kwargs) -> pd.DataFrame:
rttm = pd.read_csv(
filename,
sep=" ",
Expand Down Expand Up @@ -200,7 +200,7 @@ class AliceConverter(AnnotationConverter):
FORMAT = "alice"

@staticmethod
def convert(filename: str, source_file: str = "") -> pd.DataFrame:
def convert(filename: str, source_file: str = "", **kwargs) -> pd.DataFrame:
df = pd.read_csv(
filename,
sep=r"\s",
Expand Down Expand Up @@ -231,7 +231,7 @@ class ItsConverter(AnnotationConverter):
)

@staticmethod
def convert(filename: str, recording_num: int = None) -> pd.DataFrame:
def convert(filename: str, recording_num: int = None, **kwargs) -> pd.DataFrame:
from lxml import etree

xml = etree.parse(filename)
Expand Down Expand Up @@ -398,7 +398,7 @@ class TextGridConverter(AnnotationConverter):
FORMAT = "TextGrid"

@staticmethod
def convert(filename: str, filter=None) -> pd.DataFrame:
def convert(filename: str, filter=None, **kwargs) -> pd.DataFrame:
import pympi

textgrid = pympi.Praat.TextGrid(filename)
Expand Down Expand Up @@ -440,7 +440,7 @@ class EafConverter(AnnotationConverter):
FORMAT = "eaf"

@staticmethod
def convert(filename: str, filter=None) -> pd.DataFrame:
def convert(filename: str, filter=None, **kwargs) -> pd.DataFrame:
import pympi

eaf = pympi.Elan.Eaf(filename)
Expand Down Expand Up @@ -523,6 +523,8 @@ def convert(filename: str, filter=None) -> pd.DataFrame:
segment["vcm_type"] = value
elif label == "msc":
segment["msc_type"] = value
elif label in kwargs["new_tiers"]:
segment[label] = value

return pd.DataFrame(segments.values())

Expand Down Expand Up @@ -592,7 +594,7 @@ def role_to_addressee(role):
return ChatConverter.ADDRESSEE_TABLE[ChatConverter.SPEAKER_ROLE_TO_TYPE[role]]

@staticmethod
def convert(filename: str, filter=None) -> pd.DataFrame:
def convert(filename: str, filter=None, **kwargs) -> pd.DataFrame:

import pylangacq

Expand Down
13 changes: 13 additions & 0 deletions docs/source/api-annotations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ The contents of the output CSV file can be checked:
Users are advised to check the consistency and validity of the annotations and their index
using the validation procedure.

Importing any EAF tier
----------------------

When importing EAF annotation files, some tiers are supported by ChildProject, such as `vcm_type` or
`lex_type`.

If you want to import a tier that is not supported by ChildProject, pass its name through the ``new_tiers``
argument of :meth:`~ChildProject.annotations.AnnotationManager.import_annotations` as follows:

.. code-block:: python

    >>> am.import_annotations(input, new_tiers=['name_of_tier'])

Validating annotations
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
Loading

0 comments on commit 1475102

Please sign in to comment.