Skip to content

Commit 6485cfb

Browse files
fixes
1 parent 984461a commit 6485cfb

File tree

7 files changed

+158
-109
lines changed

7 files changed

+158
-109
lines changed

ChildProject/annotations.py

+84-58
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from . import __version__
1313
from .projects import ChildProject
14-
from .converters import is_thread_safe
14+
from .converters import *
1515
from .tables import IndexTable, IndexColumn
1616
from .utils import Segment, intersect_ranges
1717

@@ -31,7 +31,7 @@ class AnnotationManager:
3131
IndexColumn(name = 'range_onset', description = 'covered range start time in milliseconds, measured since `time_seek`', regex = r"([0-9]+)", required = True),
3232
IndexColumn(name = 'range_offset', description = 'covered range end time in milliseconds, measured since `time_seek`', regex = r"([0-9]+)", required = True),
3333
IndexColumn(name = 'raw_filename', description = 'annotation input filename location, relative to `annotations/<set>/raw`', filename = True, required = True),
34-
IndexColumn(name = 'format', description = 'input annotation format', choices = ['TextGrid', 'eaf', 'vtc_rttm', 'vcm_rttm', 'alice', 'its', 'cha'], required = False),
34+
IndexColumn(name = 'format', description = 'input annotation format', choices = [*converters.keys(), 'NA'], required = False),
3535
IndexColumn(name = 'filter', description = 'source file to filter in (for rttm and alice only)', required = False),
3636
IndexColumn(name = 'annotation_filename', description = 'output formatted annotation location, relative to `annotations/<set>/converted (automatic column, don\'t specify)', filename = True, required = False, generated = True),
3737
IndexColumn(name = 'imported_at', description = 'importation date (automatic column, don\'t specify)', datetime = "%Y-%m-%d %H:%M:%S", required = False, generated = True),
@@ -118,19 +118,32 @@ def read(self) -> Tuple[List[str], List[str]]:
118118

119119
return errors, warnings
120120

121+
def write(self):
122+
"""Update the annotations index,
123+
while enforcing its good shape.
124+
"""
125+
self.annotations[['time_seek', 'range_onset', 'range_offset']].fillna(0, inplace = True)
126+
self.annotations[['time_seek', 'range_onset', 'range_offset']] = self.annotations[['time_seek', 'range_onset', 'range_offset']].astype(int)
127+
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
128+
121129
def validate_annotation(self, annotation: dict) -> Tuple[List[str], List[str]]:
122-
print("validating {}...".format(annotation['annotation_filename']))
130+
print("validating {} from {}...".format(annotation['annotation_filename'], annotation['set']))
123131

124132
segments = IndexTable(
125133
'segments',
126-
path = os.path.join(self.project.path, 'annotations', annotation['set'], 'converted', str(annotation['annotation_filename'])),
134+
path = os.path.join(self.project.path, 'annotations', annotation['set'], 'converted', annotation['annotation_filename']),
127135
columns = self.SEGMENTS_COLUMNS
128136
)
129137

130138
try:
131139
segments.read()
132140
except Exception as e:
133-
return [str(e)], []
141+
error_message = "error while trying to read {} from {}:\n\t{}".format(
142+
annotation['annotation_filename'],
143+
annotation['set'],
144+
str(e)
145+
)
146+
return [error_message], []
134147

135148
return segments.validate()
136149

@@ -159,20 +172,32 @@ def validate(self, annotations: pd.DataFrame = None, threads: int = 0) -> Tuple[
159172

160173
return errors, warnings
161174

162-
def _read_annotation(self, filename, output_type):
163-
if output_type == 'gz':
164-
return pd.read_csv(filename, compression = 'gzip')
175+
def _read_annotation(self, set: str, filename: str, annotation_type: str = None):
176+
path = os.path.join(self.project.path, 'annotations', set, 'converted', filename)
177+
ext = os.path.splitext(filename)[1]
178+
179+
if ext == '.gz':
180+
return pd.read_csv(path, compression = 'gzip')
181+
elif ext == '.csv':
182+
return pd.read_csv(path)
165183
else:
166-
return pd.read_csv(filename)
167-
184+
raise ValueError(f"invalid extension '{ext}' for annotation {set}/{filename}'")
168185

169-
def _write_annotation(self, df, filename, annotation_type = None):
170-
os.makedirs(os.path.dirname(filename), exist_ok = True)
186+
def _write_annotation(self, df: pd.DataFrame, set: str, filename: str):
187+
path = os.path.join(self.project.path, 'annotations', set, 'converted', filename)
188+
ext = os.path.splitext(filename)[1]
171189

172-
if annotation_type == 'gz':
173-
df.to_csv(filename, index = False, compression = 'gzip')
190+
os.makedirs(os.path.dirname(path), exist_ok = True)
191+
192+
if ext == '.gz':
193+
df.to_csv(path, index = False, compression = 'gzip')
194+
elif ext == '.csv':
195+
df.to_csv(path, index = False)
174196
else:
175-
df.to_csv(filename, index = False)
197+
raise ValueError(f"invalid extension '{ext}' for annotation {set}/{filename}'")
198+
199+
200+
return os.path.basename(path)
176201

177202
def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], annotation: dict):
178203
"""import and convert ``annotation``. This function should not be called outside of this class.
@@ -186,9 +211,6 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
186211
"""
187212

188213
source_recording = os.path.splitext(annotation['recording_filename'])[0]
189-
annotation_filename = "{}_{}_{}.csv".format(source_recording, annotation['time_seek'], annotation['range_onset'])
190-
output_filename = os.path.join('annotations', annotation['set'], 'converted', annotation_filename)
191-
192214
path = os.path.join(self.project.path, 'annotations', annotation['set'], 'raw', annotation['raw_filename'])
193215
annotation_format = annotation['format']
194216

@@ -198,27 +220,9 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
198220
try:
199221
if callable(import_function):
200222
df = import_function(path)
201-
elif annotation_format == 'TextGrid':
202-
from .converters import TextGridConverter
203-
df = TextGridConverter.convert(path)
204-
elif annotation_format == 'eaf':
205-
from .converters import EafConverter
206-
df = EafConverter.convert(path)
207-
elif annotation_format == 'vtc_rttm':
208-
from .converters import VtcConverter
209-
df = VtcConverter.convert(path, source_file = filter)
210-
elif annotation_format == 'vcm_rttm':
211-
from .converters import VcmConverter
212-
df = VcmConverter.convert(path, source_file = filter)
213-
elif annotation_format == 'its':
214-
from .converters import ItsConverter
215-
df = ItsConverter.convert(path, recording_num = filter)
216-
elif annotation_format == 'alice':
217-
from .converters import AliceConverter
218-
df = AliceConverter.convert(path, source_file = filter)
219-
elif annotation_format == 'cha':
220-
from .converters import ChatConverter
221-
df = ChatConverter.convert(path)
223+
elif annotation_format in converters:
224+
converter = converters[annotation_format]
225+
df = converter.convert(path, filter)
222226
else:
223227
raise ValueError("file format '{}' unknown for '{}'".format(annotation_format, path))
224228
except:
@@ -248,13 +252,29 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
248252

249253
df.sort_values(sort_columns, inplace = True)
250254

251-
annotation_type = annotation['type'] if 'type' in annotation else 'csv'
252-
self._write_annotation(df, os.path.join(self.project.path, output_filename), annotation_type)
255+
if 'type' not in annotation or pd.isnull(annotation['type']):
256+
annotation['type'] = 'csv'
257+
258+
annotation_filename = "{}_{}_{}.{}".format(
259+
source_recording,
260+
annotation['time_seek'],
261+
annotation['range_onset'],
262+
annotation['type']
263+
)
264+
265+
self._write_annotation(
266+
df,
267+
annotation['set'],
268+
annotation_filename
269+
)
253270

254271
annotation['annotation_filename'] = annotation_filename
255272
annotation['imported_at'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
256273
annotation['package_version'] = __version__
257274

275+
if pd.isnull(annotation['format']):
276+
annotation['format'] = 'NA'
277+
258278
return annotation
259279

260280
def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_function: Callable[[str], pd.DataFrame] = None) -> pd.DataFrame:
@@ -279,8 +299,8 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
279299
input['range_onset'] = input['range_onset'].astype(int)
280300
input['range_offset'] = input['range_offset'].astype(int)
281301

282-
builtin = input[input['format'].isin(is_thread_safe.keys())]
283-
if not builtin['format'].map(is_thread_safe).all():
302+
builtin = input[input['format'].isin(converters.keys())]
303+
if not builtin['format'].map(lambda f: converters[f].THREAD_SAFE).all():
284304
print('warning: some of the converters do not support multithread importation; running on 1 thread')
285305
threads = 1
286306

@@ -298,7 +318,7 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
298318

299319
self.read()
300320
self.annotations = pd.concat([self.annotations, imported], sort = False)
301-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
321+
self.write()
302322

303323
return imported
304324

@@ -357,7 +377,7 @@ def remove_set(self, annotation_set: str, recursive: bool = False):
357377
pass
358378

359379
self.annotations = self.annotations[self.annotations['set'] != annotation_set]
360-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
380+
self.write()
361381

362382
def rename_set(self, annotation_set: str, new_set: str, recursive: bool = False, ignore_errors: bool = False):
363383
"""Rename a set of annotations, moving all related files
@@ -410,20 +430,20 @@ def rename_set(self, annotation_set: str, new_set: str, recursive: bool = False,
410430
move(os.path.join(current_path, 'converted'), os.path.join(new_path, 'converted'))
411431

412432
self.annotations.loc[(self.annotations['set'] == annotation_set), 'set'] = new_set
433+
self.write()
413434

414-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
415-
416-
def merge_annotations(self, left_columns, right_columns, columns, output_set, input):
435+
def merge_annotations(self, left_columns, right_columns, columns, output_set, type, input):
417436
left_annotations = input['left_annotations']
418437
right_annotations = input['right_annotations']
419438

420439
annotations = left_annotations.copy()
421440
annotations['format'] = ''
422441
annotations['annotation_filename'] = annotations.apply(
423-
lambda annotation: "{}_{}_{}.csv".format(
442+
lambda annotation: "{}_{}_{}.{}".format(
424443
os.path.splitext(annotation['recording_filename'])[0],
425444
annotation['time_seek'],
426-
annotation['range_onset']
445+
annotation['range_onset'],
446+
type
427447
)
428448
, axis = 1)
429449

@@ -497,16 +517,19 @@ def merge_annotations(self, left_columns, right_columns, columns, output_set, in
497517

498518
segments = output_segments[output_segments['interval'] == interval]
499519
segments.drop(columns = list(set(segments.columns)-{c.name for c in self.SEGMENTS_COLUMNS}), inplace = True)
500-
segments.to_csv(
501-
os.path.join(self.project.path, 'annotations', annotation_set, 'converted', annotation_filename),
502-
index = False
520+
521+
self._write_annotation(
522+
segments,
523+
annotation_set,
524+
annotation_filename
503525
)
504526

505527
return annotations
506528

507529
def merge_sets(self, left_set: str, right_set: str,
508530
left_columns: List[str], right_columns: List[str],
509531
output_set: str, columns: dict = {},
532+
type = 'csv',
510533
threads = -1
511534
):
512535
"""Merge columns from ``left_set`` and ``right_set`` annotations,
@@ -556,15 +579,19 @@ def merge_sets(self, left_set: str, right_set: str,
556579
for recording in left_annotations['recording_filename'].unique()
557580
]
558581

559-
pool = mp.Pool(processes = threads if threads > 0 else mp.cpu_count())
560-
annotations = pool.map(partial(self.merge_annotations, left_columns, right_columns, columns, output_set), input_annotations)
582+
with mp.Pool(processes = threads if threads > 0 else mp.cpu_count()) as pool:
583+
annotations = pool.map(
584+
partial(self.merge_annotations, left_columns, right_columns, columns, output_set, type),
585+
input_annotations
586+
)
587+
561588
annotations = pd.concat(annotations)
562589
annotations.drop(columns = list(set(annotations.columns)-{c.name for c in self.INDEX_COLUMNS}), inplace = True)
563590
annotations.fillna({'raw_filename': 'NA'}, inplace = True)
564591

565592
self.read()
566593
self.annotations = pd.concat([self.annotations, annotations], sort = False)
567-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
594+
self.write()
568595

569596
def get_segments(self, annotations: pd.DataFrame) -> pd.DataFrame:
570597
"""get all segments associated to the annotations referenced in ``annotations``.
@@ -580,9 +607,8 @@ def get_segments(self, annotations: pd.DataFrame) -> pd.DataFrame:
580607
segments = []
581608
for index, _annotations in annotations.groupby(['set', 'annotation_filename']):
582609
s, annotation_filename = index
583-
path = os.path.join(self.project.path, 'annotations', s, 'converted', annotation_filename)
584610
annotation_type = _annotations['type'].iloc[0] if 'type' in _annotations.columns else 'csv'
585-
df = self._read_annotation(path, annotation_type)
611+
df = self._read_annotation(s, annotation_filename, annotation_type)
586612

587613
for annotation in _annotations.to_dict(orient = 'records'):
588614
segs = df.copy()

ChildProject/converters.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pandas as pd
33
import re
44

5+
converters = {}
6+
57
class AnnotationConverter:
68
SPEAKER_ID_TO_TYPE = defaultdict(lambda: 'NA', {
79
'C1': 'OCH',
@@ -50,7 +52,12 @@ class AnnotationConverter:
5052

5153
THREAD_SAFE = True
5254

55+
def __init_subclass__(cls, **kwargs):
56+
super().__init_subclass__(**kwargs)
57+
converters[cls.FORMAT] = cls
58+
5359
class VtcConverter(AnnotationConverter):
60+
FORMAT = 'vtc_rttm'
5461

5562
SPEAKER_TYPE_TRANSLATION = defaultdict(lambda: 'NA', {
5663
'CHI': 'OCH',
@@ -80,6 +87,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame:
8087
return df
8188

8289
class VcmConverter(AnnotationConverter):
90+
FORMAT = 'vcm_rttm'
8391

8492
SPEAKER_TYPE_TRANSLATION = defaultdict(lambda: 'NA', {
8593
'CHI': 'OCH',
@@ -119,6 +127,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame:
119127
return df
120128

121129
class AliceConverter(AnnotationConverter):
130+
FORMAT = 'alice'
122131

123132
@staticmethod
124133
def convert(filename: str, source_file: str = '') -> pd.DataFrame:
@@ -142,6 +151,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame:
142151
return df
143152

144153
class ItsConverter(AnnotationConverter):
154+
FORMAT = 'its'
145155

146156
SPEAKER_TYPE_TRANSLATION = defaultdict(lambda: 'NA', {
147157
'CHN': 'CHI',
@@ -277,9 +287,10 @@ def extract_from_regex(pattern, subject):
277287
return df
278288

279289
class TextGridConverter(AnnotationConverter):
290+
FORMAT = 'TextGrid'
280291

281292
@staticmethod
282-
def convert(filename: str) -> pd.DataFrame:
293+
def convert(filename: str, filter = None) -> pd.DataFrame:
283294
import pympi
284295
textgrid = pympi.Praat.TextGrid(filename)
285296

@@ -316,9 +327,10 @@ def ling_type(s):
316327
return pd.DataFrame(segments)
317328

318329
class EafConverter(AnnotationConverter):
330+
FORMAT = 'eaf'
319331

320332
@staticmethod
321-
def convert(filename: str) -> pd.DataFrame:
333+
def convert(filename: str, filter = None) -> pd.DataFrame:
322334
import pympi
323335
eaf = pympi.Elan.Eaf(filename)
324336

@@ -387,6 +399,7 @@ def convert(filename: str) -> pd.DataFrame:
387399
return pd.DataFrame(segments.values())
388400

389401
class ChatConverter(AnnotationConverter):
402+
FORMAT = 'cha'
390403
THREAD_SAFE = False
391404

392405
SPEAKER_ROLE_TO_TYPE = defaultdict(lambda: 'NA', {
@@ -450,7 +463,7 @@ def role_to_addressee(role):
450463
return ChatConverter.ADDRESSEE_TABLE[ChatConverter.SPEAKER_ROLE_TO_TYPE[role]]
451464

452465
@staticmethod
453-
def convert(filename: str) -> pd.DataFrame:
466+
def convert(filename: str, filter = None) -> pd.DataFrame:
454467

455468
import pylangacq
456469

@@ -487,14 +500,4 @@ def convert(filename: str) -> pd.DataFrame:
487500
df.drop(columns = ['participant', 'tokens', 'time_marks'], inplace = True)
488501
df.fillna('NA', inplace = True)
489502

490-
return df
491-
492-
is_thread_safe = {
493-
'its': ItsConverter.THREAD_SAFE,
494-
'vtc_rttm': VtcConverter.THREAD_SAFE,
495-
'vcm_rttm': VcmConverter.THREAD_SAFE,
496-
'eaf': EafConverter.THREAD_SAFE,
497-
'TextGrid': TextGridConverter.THREAD_SAFE,
498-
'alice': AliceConverter.THREAD_SAFE,
499-
'cha': ChatConverter.THREAD_SAFE
500-
}
503+
return df

0 commit comments

Comments
 (0)