11
11
12
12
from . import __version__
13
13
from .projects import ChildProject
14
- from .converters import is_thread_safe
14
+ from .converters import *
15
15
from .tables import IndexTable , IndexColumn
16
16
from .utils import Segment , intersect_ranges
17
17
@@ -31,7 +31,7 @@ class AnnotationManager:
31
31
IndexColumn (name = 'range_onset' , description = 'covered range start time in milliseconds, measured since `time_seek`' , regex = r"([0-9]+)" , required = True ),
32
32
IndexColumn (name = 'range_offset' , description = 'covered range end time in milliseconds, measured since `time_seek`' , regex = r"([0-9]+)" , required = True ),
33
33
IndexColumn (name = 'raw_filename' , description = 'annotation input filename location, relative to `annotations/<set>/raw`' , filename = True , required = True ),
34
- IndexColumn (name = 'format' , description = 'input annotation format' , choices = ['TextGrid' , 'eaf' , 'vtc_rttm' , 'vcm_rttm' , 'alice' , 'its' , 'cha ' ], required = False ),
34
+ IndexColumn (name = 'format' , description = 'input annotation format' , choices = [* converters . keys () , 'NA ' ], required = False ),
35
35
IndexColumn (name = 'filter' , description = 'source file to filter in (for rttm and alice only)' , required = False ),
36
36
IndexColumn (name = 'annotation_filename' , description = 'output formatted annotation location, relative to `annotations/<set>/converted (automatic column, don\' t specify)' , filename = True , required = False , generated = True ),
37
37
IndexColumn (name = 'imported_at' , description = 'importation date (automatic column, don\' t specify)' , datetime = "%Y-%m-%d %H:%M:%S" , required = False , generated = True ),
@@ -118,19 +118,32 @@ def read(self) -> Tuple[List[str], List[str]]:
118
118
119
119
return errors , warnings
120
120
121
+ def write (self ):
122
+ """Update the annotations index,
123
+ while enforcing its good shape.
124
+ """
125
+ self .annotations [['time_seek' , 'range_onset' , 'range_offset' ]].fillna (0 , inplace = True )
126
+ self .annotations [['time_seek' , 'range_onset' , 'range_offset' ]] = self .annotations [['time_seek' , 'range_onset' , 'range_offset' ]].astype (int )
127
+ self .annotations .to_csv (os .path .join (self .project .path , 'metadata/annotations.csv' ), index = False )
128
+
121
129
def validate_annotation (self , annotation : dict ) -> Tuple [List [str ], List [str ]]:
122
- print ("validating {}..." .format (annotation ['annotation_filename' ]))
130
+ print ("validating {} from {} ..." .format (annotation ['annotation_filename' ], annotation [ 'set ' ]))
123
131
124
132
segments = IndexTable (
125
133
'segments' ,
126
- path = os .path .join (self .project .path , 'annotations' , annotation ['set' ], 'converted' , str ( annotation ['annotation_filename' ]) ),
134
+ path = os .path .join (self .project .path , 'annotations' , annotation ['set' ], 'converted' , annotation ['annotation_filename' ]),
127
135
columns = self .SEGMENTS_COLUMNS
128
136
)
129
137
130
138
try :
131
139
segments .read ()
132
140
except Exception as e :
133
- return [str (e )], []
141
+ error_message = "error while trying to read {} from {}:\n \t {}" .format (
142
+ annotation ['annotation_filename' ],
143
+ annotation ['set' ],
144
+ str (e )
145
+ )
146
+ return [error_message ], []
134
147
135
148
return segments .validate ()
136
149
@@ -159,20 +172,32 @@ def validate(self, annotations: pd.DataFrame = None, threads: int = 0) -> Tuple[
159
172
160
173
return errors , warnings
161
174
162
- def _read_annotation (self , filename , output_type ):
163
- if output_type == 'gz' :
164
- return pd .read_csv (filename , compression = 'gzip' )
175
+ def _read_annotation (self , set : str , filename : str , annotation_type : str = None ):
176
+ path = os .path .join (self .project .path , 'annotations' , set , 'converted' , filename )
177
+ ext = os .path .splitext (filename )[1 ]
178
+
179
+ if ext == '.gz' :
180
+ return pd .read_csv (path , compression = 'gzip' )
181
+ elif ext == '.csv' :
182
+ return pd .read_csv (path )
165
183
else :
166
- return pd .read_csv (filename )
167
-
184
+ raise ValueError (f"invalid extension '{ ext } ' for annotation { set } /{ filename } '" )
168
185
169
- def _write_annotation (self , df , filename , annotation_type = None ):
170
- os .makedirs (os .path .dirname (filename ), exist_ok = True )
186
+ def _write_annotation (self , df : pd .DataFrame , set : str , filename : str ):
187
+ path = os .path .join (self .project .path , 'annotations' , set , 'converted' , filename )
188
+ ext = os .path .splitext (filename )[1 ]
171
189
172
- if annotation_type == 'gz' :
173
- df .to_csv (filename , index = False , compression = 'gzip' )
190
+ os .makedirs (os .path .dirname (path ), exist_ok = True )
191
+
192
+ if ext == '.gz' :
193
+ df .to_csv (path , index = False , compression = 'gzip' )
194
+ elif ext == '.csv' :
195
+ df .to_csv (path , index = False )
174
196
else :
175
- df .to_csv (filename , index = False )
197
+ raise ValueError (f"invalid extension '{ ext } ' for annotation { set } /{ filename } '" )
198
+
199
+
200
+ return os .path .basename (path )
176
201
177
202
def _import_annotation (self , import_function : Callable [[str ], pd .DataFrame ], annotation : dict ):
178
203
"""import and convert ``annotation``. This function should not be called outside of this class.
@@ -186,9 +211,6 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
186
211
"""
187
212
188
213
source_recording = os .path .splitext (annotation ['recording_filename' ])[0 ]
189
- annotation_filename = "{}_{}_{}.csv" .format (source_recording , annotation ['time_seek' ], annotation ['range_onset' ])
190
- output_filename = os .path .join ('annotations' , annotation ['set' ], 'converted' , annotation_filename )
191
-
192
214
path = os .path .join (self .project .path , 'annotations' , annotation ['set' ], 'raw' , annotation ['raw_filename' ])
193
215
annotation_format = annotation ['format' ]
194
216
@@ -198,27 +220,9 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
198
220
try :
199
221
if callable (import_function ):
200
222
df = import_function (path )
201
- elif annotation_format == 'TextGrid' :
202
- from .converters import TextGridConverter
203
- df = TextGridConverter .convert (path )
204
- elif annotation_format == 'eaf' :
205
- from .converters import EafConverter
206
- df = EafConverter .convert (path )
207
- elif annotation_format == 'vtc_rttm' :
208
- from .converters import VtcConverter
209
- df = VtcConverter .convert (path , source_file = filter )
210
- elif annotation_format == 'vcm_rttm' :
211
- from .converters import VcmConverter
212
- df = VcmConverter .convert (path , source_file = filter )
213
- elif annotation_format == 'its' :
214
- from .converters import ItsConverter
215
- df = ItsConverter .convert (path , recording_num = filter )
216
- elif annotation_format == 'alice' :
217
- from .converters import AliceConverter
218
- df = AliceConverter .convert (path , source_file = filter )
219
- elif annotation_format == 'cha' :
220
- from .converters import ChatConverter
221
- df = ChatConverter .convert (path )
223
+ elif annotation_format in converters :
224
+ converter = converters [annotation_format ]
225
+ df = converter .convert (path , filter )
222
226
else :
223
227
raise ValueError ("file format '{}' unknown for '{}'" .format (annotation_format , path ))
224
228
except :
@@ -248,13 +252,29 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
248
252
249
253
df .sort_values (sort_columns , inplace = True )
250
254
251
- annotation_type = annotation ['type' ] if 'type' in annotation else 'csv'
252
- self ._write_annotation (df , os .path .join (self .project .path , output_filename ), annotation_type )
255
+ if 'type' not in annotation or pd .isnull (annotation ['type' ]):
256
+ annotation ['type' ] = 'csv'
257
+
258
+ annotation_filename = "{}_{}_{}.{}" .format (
259
+ source_recording ,
260
+ annotation ['time_seek' ],
261
+ annotation ['range_onset' ],
262
+ annotation ['type' ]
263
+ )
264
+
265
+ self ._write_annotation (
266
+ df ,
267
+ annotation ['set' ],
268
+ annotation_filename
269
+ )
253
270
254
271
annotation ['annotation_filename' ] = annotation_filename
255
272
annotation ['imported_at' ] = datetime .datetime .now ().strftime ("%Y-%m-%d %H:%M:%S" )
256
273
annotation ['package_version' ] = __version__
257
274
275
+ if pd .isnull (annotation ['format' ]):
276
+ annotation ['format' ] = 'NA'
277
+
258
278
return annotation
259
279
260
280
def import_annotations (self , input : pd .DataFrame , threads : int = - 1 , import_function : Callable [[str ], pd .DataFrame ] = None ) -> pd .DataFrame :
@@ -279,8 +299,8 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
279
299
input ['range_onset' ] = input ['range_onset' ].astype (int )
280
300
input ['range_offset' ] = input ['range_offset' ].astype (int )
281
301
282
- builtin = input [input ['format' ].isin (is_thread_safe .keys ())]
283
- if not builtin ['format' ].map (is_thread_safe ).all ():
302
+ builtin = input [input ['format' ].isin (converters .keys ())]
303
+ if not builtin ['format' ].map (lambda f : converters [ f ]. THREAD_SAFE ).all ():
284
304
print ('warning: some of the converters do not support multithread importation; running on 1 thread' )
285
305
threads = 1
286
306
@@ -298,7 +318,7 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
298
318
299
319
self .read ()
300
320
self .annotations = pd .concat ([self .annotations , imported ], sort = False )
301
- self .annotations . to_csv ( os . path . join ( self . project . path , 'metadata/annotations.csv' ), index = False )
321
+ self .write ( )
302
322
303
323
return imported
304
324
@@ -357,7 +377,7 @@ def remove_set(self, annotation_set: str, recursive: bool = False):
357
377
pass
358
378
359
379
self .annotations = self .annotations [self .annotations ['set' ] != annotation_set ]
360
- self .annotations . to_csv ( os . path . join ( self . project . path , 'metadata/annotations.csv' ), index = False )
380
+ self .write ( )
361
381
362
382
def rename_set (self , annotation_set : str , new_set : str , recursive : bool = False , ignore_errors : bool = False ):
363
383
"""Rename a set of annotations, moving all related files
@@ -410,20 +430,20 @@ def rename_set(self, annotation_set: str, new_set: str, recursive: bool = False,
410
430
move (os .path .join (current_path , 'converted' ), os .path .join (new_path , 'converted' ))
411
431
412
432
self .annotations .loc [(self .annotations ['set' ] == annotation_set ), 'set' ] = new_set
433
+ self .write ()
413
434
414
- self .annotations .to_csv (os .path .join (self .project .path , 'metadata/annotations.csv' ), index = False )
415
-
416
- def merge_annotations (self , left_columns , right_columns , columns , output_set , input ):
435
+ def merge_annotations (self , left_columns , right_columns , columns , output_set , type , input ):
417
436
left_annotations = input ['left_annotations' ]
418
437
right_annotations = input ['right_annotations' ]
419
438
420
439
annotations = left_annotations .copy ()
421
440
annotations ['format' ] = ''
422
441
annotations ['annotation_filename' ] = annotations .apply (
423
- lambda annotation : "{}_{}_{}.csv " .format (
442
+ lambda annotation : "{}_{}_{}.{} " .format (
424
443
os .path .splitext (annotation ['recording_filename' ])[0 ],
425
444
annotation ['time_seek' ],
426
- annotation ['range_onset' ]
445
+ annotation ['range_onset' ],
446
+ type
427
447
)
428
448
, axis = 1 )
429
449
@@ -497,16 +517,19 @@ def merge_annotations(self, left_columns, right_columns, columns, output_set, in
497
517
498
518
segments = output_segments [output_segments ['interval' ] == interval ]
499
519
segments .drop (columns = list (set (segments .columns )- {c .name for c in self .SEGMENTS_COLUMNS }), inplace = True )
500
- segments .to_csv (
501
- os .path .join (self .project .path , 'annotations' , annotation_set , 'converted' , annotation_filename ),
502
- index = False
520
+
521
+ self ._write_annotation (
522
+ segments ,
523
+ annotation_set ,
524
+ annotation_filename
503
525
)
504
526
505
527
return annotations
506
528
507
529
def merge_sets (self , left_set : str , right_set : str ,
508
530
left_columns : List [str ], right_columns : List [str ],
509
531
output_set : str , columns : dict = {},
532
+ type = 'csv' ,
510
533
threads = - 1
511
534
):
512
535
"""Merge columns from ``left_set`` and ``right_set`` annotations,
@@ -556,15 +579,19 @@ def merge_sets(self, left_set: str, right_set: str,
556
579
for recording in left_annotations ['recording_filename' ].unique ()
557
580
]
558
581
559
- pool = mp .Pool (processes = threads if threads > 0 else mp .cpu_count ())
560
- annotations = pool .map (partial (self .merge_annotations , left_columns , right_columns , columns , output_set ), input_annotations )
582
+ with mp .Pool (processes = threads if threads > 0 else mp .cpu_count ()) as pool :
583
+ annotations = pool .map (
584
+ partial (self .merge_annotations , left_columns , right_columns , columns , output_set , type ),
585
+ input_annotations
586
+ )
587
+
561
588
annotations = pd .concat (annotations )
562
589
annotations .drop (columns = list (set (annotations .columns )- {c .name for c in self .INDEX_COLUMNS }), inplace = True )
563
590
annotations .fillna ({'raw_filename' : 'NA' }, inplace = True )
564
591
565
592
self .read ()
566
593
self .annotations = pd .concat ([self .annotations , annotations ], sort = False )
567
- self .annotations . to_csv ( os . path . join ( self . project . path , 'metadata/annotations.csv' ), index = False )
594
+ self .write ( )
568
595
569
596
def get_segments (self , annotations : pd .DataFrame ) -> pd .DataFrame :
570
597
"""get all segments associated to the annotations referenced in ``annotations``.
@@ -580,9 +607,8 @@ def get_segments(self, annotations: pd.DataFrame) -> pd.DataFrame:
580
607
segments = []
581
608
for index , _annotations in annotations .groupby (['set' , 'annotation_filename' ]):
582
609
s , annotation_filename = index
583
- path = os .path .join (self .project .path , 'annotations' , s , 'converted' , annotation_filename )
584
610
annotation_type = _annotations ['type' ].iloc [0 ] if 'type' in _annotations .columns else 'csv'
585
- df = self ._read_annotation (path , annotation_type )
611
+ df = self ._read_annotation (s , annotation_filename , annotation_type )
586
612
587
613
for annotation in _annotations .to_dict (orient = 'records' ):
588
614
segs = df .copy ()
0 commit comments