11
11
12
12
from . import __version__
13
13
from .projects import ChildProject
14
- from .converters import is_thread_safe
14
+ from .converters import *
15
15
from .tables import IndexTable , IndexColumn
16
16
from .utils import Segment , intersect_ranges
17
17
@@ -31,10 +31,11 @@ class AnnotationManager:
31
31
IndexColumn (name = 'range_onset' , description = 'covered range start time in milliseconds, measured since `time_seek`' , regex = r"([0-9]+)" , required = True ),
32
32
IndexColumn (name = 'range_offset' , description = 'covered range end time in milliseconds, measured since `time_seek`' , regex = r"([0-9]+)" , required = True ),
33
33
IndexColumn (name = 'raw_filename' , description = 'annotation input filename location, relative to `annotations/<set>/raw`' , filename = True , required = True ),
34
- IndexColumn (name = 'format' , description = 'input annotation format' , choices = ['TextGrid' , 'eaf' , 'vtc_rttm' , 'vcm_rttm' , 'alice' , 'its' , 'cha ' ], required = False ),
34
+ IndexColumn (name = 'format' , description = 'input annotation format' , choices = [* converters . keys () , 'NA ' ], required = False ),
35
35
IndexColumn (name = 'filter' , description = 'source file to filter in (for rttm and alice only)' , required = False ),
36
36
IndexColumn (name = 'annotation_filename' , description = 'output formatted annotation location, relative to `annotations/<set>/converted (automatic column, don\' t specify)' , filename = True , required = False , generated = True ),
37
37
IndexColumn (name = 'imported_at' , description = 'importation date (automatic column, don\' t specify)' , datetime = "%Y-%m-%d %H:%M:%S" , required = False , generated = True ),
38
+ IndexColumn (name = 'type' , description = 'annotation storage format' , choices = ['csv' , 'gz' ], required = False ),
38
39
IndexColumn (name = 'package_version' , description = 'version of the package used when the importation was performed' , regex = r"[0-9]+\.[0-9]+\.[0-9]+" , required = False , generated = True ),
39
40
IndexColumn (name = 'error' , description = 'error message in case the annotation could not be imported' , required = False , generated = True )
40
41
]
@@ -117,19 +118,32 @@ def read(self) -> Tuple[List[str], List[str]]:
117
118
118
119
return errors , warnings
119
120
121
def write(self):
    """Update the annotations index, while enforcing its good shape.

    Missing values in the ``time_seek``, ``range_onset`` and
    ``range_offset`` columns are replaced with 0 and these columns are
    cast to integers. The index is then saved, without the dataframe
    index, to ``metadata/annotations.csv`` under the project path.
    """
    time_columns = ['time_seek', 'range_onset', 'range_offset']

    # The filled values must be assigned back: calling
    # ``fillna(..., inplace=True)`` on a column selection only alters a
    # temporary copy, which leaves NaN in place and breaks the int cast.
    self.annotations[time_columns] = (
        self.annotations[time_columns].fillna(0).astype(int)
    )

    self.annotations.to_csv(
        os.path.join(self.project.path, 'metadata/annotations.csv'),
        index=False
    )
128
+
120
129
def validate_annotation(self, annotation: dict) -> Tuple[List[str], List[str]]:
    """Validate the converted segments file of a single annotation.

    :param annotation: index entry of the annotation; its ``set`` and
        ``annotation_filename`` keys are used to locate the file under
        ``annotations/<set>/converted`` in the project.
    :return: an ``(errors, warnings)`` pair. When the file cannot be
        located or read, a single error message describing the failure
        is returned with no warnings; otherwise the result of the
        segments table validation is returned.
    """
    print("validating {} from {}...".format(annotation['annotation_filename'], annotation['set']))

    try:
        # Build the path inside the guarded section: a non-string
        # filename (e.g. NaN for an entry that was never converted)
        # makes os.path.join raise, and that must be reported as a
        # validation error rather than abort the whole validation.
        path = os.path.join(
            self.project.path,
            'annotations',
            annotation['set'],
            'converted',
            annotation['annotation_filename']
        )

        segments = IndexTable(
            'segments',
            path=path,
            columns=self.SEGMENTS_COLUMNS
        )
        segments.read()
    except Exception as e:
        error_message = "error while trying to read {} from {}:\n\t{}".format(
            annotation['annotation_filename'],
            annotation['set'],
            str(e)
        )
        return [error_message], []

    return segments.validate()
135
149
@@ -158,6 +172,32 @@ def validate(self, annotations: pd.DataFrame = None, threads: int = 0) -> Tuple[
158
172
159
173
return errors , warnings
160
174
175
+ def _read_annotation (self , set : str , filename : str , annotation_type : str = None ):
176
+ path = os .path .join (self .project .path , 'annotations' , set , 'converted' , filename )
177
+ ext = os .path .splitext (filename )[1 ]
178
+
179
+ if ext == '.gz' :
180
+ return pd .read_csv (path , compression = 'gzip' )
181
+ elif ext == '.csv' :
182
+ return pd .read_csv (path )
183
+ else :
184
+ raise ValueError (f"invalid extension '{ ext } ' for annotation { set } /{ filename } '" )
185
+
186
+ def _write_annotation (self , df : pd .DataFrame , set : str , filename : str ):
187
+ path = os .path .join (self .project .path , 'annotations' , set , 'converted' , filename )
188
+ ext = os .path .splitext (filename )[1 ]
189
+
190
+ os .makedirs (os .path .dirname (path ), exist_ok = True )
191
+
192
+ if ext == '.gz' :
193
+ df .to_csv (path , index = False , compression = 'gzip' )
194
+ elif ext == '.csv' :
195
+ df .to_csv (path , index = False )
196
+ else :
197
+ raise ValueError (f"invalid extension '{ ext } ' for annotation { set } /{ filename } '" )
198
+
199
+
200
+ return os .path .basename (path )
161
201
162
202
def _import_annotation (self , import_function : Callable [[str ], pd .DataFrame ], annotation : dict ):
163
203
"""import and convert ``annotation``. This function should not be called outside of this class.
@@ -171,9 +211,6 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
171
211
"""
172
212
173
213
source_recording = os .path .splitext (annotation ['recording_filename' ])[0 ]
174
- annotation_filename = "{}_{}_{}.csv" .format (source_recording , annotation ['time_seek' ], annotation ['range_onset' ])
175
- output_filename = os .path .join ('annotations' , annotation ['set' ], 'converted' , annotation_filename )
176
-
177
214
path = os .path .join (self .project .path , 'annotations' , annotation ['set' ], 'raw' , annotation ['raw_filename' ])
178
215
annotation_format = annotation ['format' ]
179
216
@@ -183,27 +220,9 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
183
220
try :
184
221
if callable (import_function ):
185
222
df = import_function (path )
186
- elif annotation_format == 'TextGrid' :
187
- from .converters import TextGridConverter
188
- df = TextGridConverter .convert (path )
189
- elif annotation_format == 'eaf' :
190
- from .converters import EafConverter
191
- df = EafConverter .convert (path )
192
- elif annotation_format == 'vtc_rttm' :
193
- from .converters import VtcConverter
194
- df = VtcConverter .convert (path , source_file = filter )
195
- elif annotation_format == 'vcm_rttm' :
196
- from .converters import VcmConverter
197
- df = VcmConverter .convert (path , source_file = filter )
198
- elif annotation_format == 'its' :
199
- from .converters import ItsConverter
200
- df = ItsConverter .convert (path , recording_num = filter )
201
- elif annotation_format == 'alice' :
202
- from .converters import AliceConverter
203
- df = AliceConverter .convert (path , source_file = filter )
204
- elif annotation_format == 'cha' :
205
- from .converters import ChatConverter
206
- df = ChatConverter .convert (path )
223
+ elif annotation_format in converters :
224
+ converter = converters [annotation_format ]
225
+ df = converter .convert (path , filter )
207
226
else :
208
227
raise ValueError ("file format '{}' unknown for '{}'" .format (annotation_format , path ))
209
228
except :
@@ -233,13 +252,29 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
233
252
234
253
df .sort_values (sort_columns , inplace = True )
235
254
236
- os .makedirs (os .path .dirname (os .path .join (self .project .path , output_filename )), exist_ok = True )
237
- df .to_csv (os .path .join (self .project .path , output_filename ), index = False )
255
+ if 'type' not in annotation or pd .isnull (annotation ['type' ]):
256
+ annotation ['type' ] = 'csv'
257
+
258
+ annotation_filename = "{}_{}_{}.{}" .format (
259
+ source_recording ,
260
+ annotation ['time_seek' ],
261
+ annotation ['range_onset' ],
262
+ annotation ['type' ]
263
+ )
264
+
265
+ self ._write_annotation (
266
+ df ,
267
+ annotation ['set' ],
268
+ annotation_filename
269
+ )
238
270
239
271
annotation ['annotation_filename' ] = annotation_filename
240
272
annotation ['imported_at' ] = datetime .datetime .now ().strftime ("%Y-%m-%d %H:%M:%S" )
241
273
annotation ['package_version' ] = __version__
242
274
275
+ if pd .isnull (annotation ['format' ]):
276
+ annotation ['format' ] = 'NA'
277
+
243
278
return annotation
244
279
245
280
def import_annotations (self , input : pd .DataFrame , threads : int = - 1 , import_function : Callable [[str ], pd .DataFrame ] = None ) -> pd .DataFrame :
@@ -264,8 +299,8 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
264
299
input ['range_onset' ] = input ['range_onset' ].astype (int )
265
300
input ['range_offset' ] = input ['range_offset' ].astype (int )
266
301
267
- builtin = input [input ['format' ].isin (is_thread_safe .keys ())]
268
- if not builtin ['format' ].map (is_thread_safe ).all ():
302
+ builtin = input [input ['format' ].isin (converters .keys ())]
303
+ if not builtin ['format' ].map (lambda f : converters [ f ]. THREAD_SAFE ).all ():
269
304
print ('warning: some of the converters do not support multithread importation; running on 1 thread' )
270
305
threads = 1
271
306
@@ -283,7 +318,7 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
283
318
284
319
self .read ()
285
320
self .annotations = pd .concat ([self .annotations , imported ], sort = False )
286
- self .annotations . to_csv ( os . path . join ( self . project . path , 'metadata/annotations.csv' ), index = False )
321
+ self .write ( )
287
322
288
323
return imported
289
324
@@ -342,7 +377,7 @@ def remove_set(self, annotation_set: str, recursive: bool = False):
342
377
pass
343
378
344
379
self .annotations = self .annotations [self .annotations ['set' ] != annotation_set ]
345
- self .annotations . to_csv ( os . path . join ( self . project . path , 'metadata/annotations.csv' ), index = False )
380
+ self .write ( )
346
381
347
382
def rename_set (self , annotation_set : str , new_set : str , recursive : bool = False , ignore_errors : bool = False ):
348
383
"""Rename a set of annotations, moving all related files
@@ -395,20 +430,20 @@ def rename_set(self, annotation_set: str, new_set: str, recursive: bool = False,
395
430
move (os .path .join (current_path , 'converted' ), os .path .join (new_path , 'converted' ))
396
431
397
432
self .annotations .loc [(self .annotations ['set' ] == annotation_set ), 'set' ] = new_set
433
+ self .write ()
398
434
399
- self .annotations .to_csv (os .path .join (self .project .path , 'metadata/annotations.csv' ), index = False )
400
-
401
- def merge_annotations (self , left_columns , right_columns , columns , output_set , input ):
435
+ def merge_annotations (self , left_columns , right_columns , columns , output_set , type , input ):
402
436
left_annotations = input ['left_annotations' ]
403
437
right_annotations = input ['right_annotations' ]
404
438
405
439
annotations = left_annotations .copy ()
406
440
annotations ['format' ] = ''
407
441
annotations ['annotation_filename' ] = annotations .apply (
408
- lambda annotation : "{}_{}_{}.csv " .format (
442
+ lambda annotation : "{}_{}_{}.{} " .format (
409
443
os .path .splitext (annotation ['recording_filename' ])[0 ],
410
444
annotation ['time_seek' ],
411
- annotation ['range_onset' ]
445
+ annotation ['range_onset' ],
446
+ type
412
447
)
413
448
, axis = 1 )
414
449
@@ -482,16 +517,19 @@ def merge_annotations(self, left_columns, right_columns, columns, output_set, in
482
517
483
518
segments = output_segments [output_segments ['interval' ] == interval ]
484
519
segments .drop (columns = list (set (segments .columns )- {c .name for c in self .SEGMENTS_COLUMNS }), inplace = True )
485
- segments .to_csv (
486
- os .path .join (self .project .path , 'annotations' , annotation_set , 'converted' , annotation_filename ),
487
- index = False
520
+
521
+ self ._write_annotation (
522
+ segments ,
523
+ annotation_set ,
524
+ annotation_filename
488
525
)
489
526
490
527
return annotations
491
528
492
529
def merge_sets (self , left_set : str , right_set : str ,
493
530
left_columns : List [str ], right_columns : List [str ],
494
531
output_set : str , columns : dict = {},
532
+ type = 'csv' ,
495
533
threads = - 1
496
534
):
497
535
"""Merge columns from ``left_set`` and ``right_set`` annotations,
@@ -541,15 +579,19 @@ def merge_sets(self, left_set: str, right_set: str,
541
579
for recording in left_annotations ['recording_filename' ].unique ()
542
580
]
543
581
544
- pool = mp .Pool (processes = threads if threads > 0 else mp .cpu_count ())
545
- annotations = pool .map (partial (self .merge_annotations , left_columns , right_columns , columns , output_set ), input_annotations )
582
+ with mp .Pool (processes = threads if threads > 0 else mp .cpu_count ()) as pool :
583
+ annotations = pool .map (
584
+ partial (self .merge_annotations , left_columns , right_columns , columns , output_set , type ),
585
+ input_annotations
586
+ )
587
+
546
588
annotations = pd .concat (annotations )
547
589
annotations .drop (columns = list (set (annotations .columns )- {c .name for c in self .INDEX_COLUMNS }), inplace = True )
548
590
annotations .fillna ({'raw_filename' : 'NA' }, inplace = True )
549
591
550
592
self .read ()
551
593
self .annotations = pd .concat ([self .annotations , annotations ], sort = False )
552
- self .annotations . to_csv ( os . path . join ( self . project . path , 'metadata/annotations.csv' ), index = False )
594
+ self .write ( )
553
595
554
596
def get_segments (self , annotations : pd .DataFrame ) -> pd .DataFrame :
555
597
"""get all segments associated to the annotations referenced in ``annotations``.
@@ -565,7 +607,8 @@ def get_segments(self, annotations: pd.DataFrame) -> pd.DataFrame:
565
607
segments = []
566
608
for index , _annotations in annotations .groupby (['set' , 'annotation_filename' ]):
567
609
s , annotation_filename = index
568
- df = pd .read_csv (os .path .join (self .project .path , 'annotations' , s , 'converted' , annotation_filename ))
610
+ annotation_type = _annotations ['type' ].iloc [0 ] if 'type' in _annotations .columns else 'csv'
611
+ df = self ._read_annotation (s , annotation_filename , annotation_type )
569
612
570
613
for annotation in _annotations .to_dict (orient = 'records' ):
571
614
segs = df .copy ()
0 commit comments