Skip to content

Commit fad5adc

Browse files
multiple storage support for annotations: add gzip
(proof of concept) fixes
1 parent cde134e commit fad5adc

File tree

8 files changed

+165
-98
lines changed

8 files changed

+165
-98
lines changed

ChildProject/annotations.py

+90-47
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from . import __version__
1313
from .projects import ChildProject
14-
from .converters import is_thread_safe
14+
from .converters import *
1515
from .tables import IndexTable, IndexColumn
1616
from .utils import Segment, intersect_ranges
1717

@@ -31,10 +31,11 @@ class AnnotationManager:
3131
IndexColumn(name = 'range_onset', description = 'covered range start time in milliseconds, measured since `time_seek`', regex = r"([0-9]+)", required = True),
3232
IndexColumn(name = 'range_offset', description = 'covered range end time in milliseconds, measured since `time_seek`', regex = r"([0-9]+)", required = True),
3333
IndexColumn(name = 'raw_filename', description = 'annotation input filename location, relative to `annotations/<set>/raw`', filename = True, required = True),
34-
IndexColumn(name = 'format', description = 'input annotation format', choices = ['TextGrid', 'eaf', 'vtc_rttm', 'vcm_rttm', 'alice', 'its', 'cha'], required = False),
34+
IndexColumn(name = 'format', description = 'input annotation format', choices = [*converters.keys(), 'NA'], required = False),
3535
IndexColumn(name = 'filter', description = 'source file to filter in (for rttm and alice only)', required = False),
3636
IndexColumn(name = 'annotation_filename', description = 'output formatted annotation location, relative to `annotations/<set>/converted (automatic column, don\'t specify)', filename = True, required = False, generated = True),
3737
IndexColumn(name = 'imported_at', description = 'importation date (automatic column, don\'t specify)', datetime = "%Y-%m-%d %H:%M:%S", required = False, generated = True),
38+
IndexColumn(name = 'type', description = 'annotation storage format', choices = ['csv', 'gz'], required = False),
3839
IndexColumn(name = 'package_version', description = 'version of the package used when the importation was performed', regex = r"[0-9]+\.[0-9]+\.[0-9]+", required = False, generated = True),
3940
IndexColumn(name = 'error', description = 'error message in case the annotation could not be imported', required = False, generated = True)
4041
]
@@ -117,19 +118,32 @@ def read(self) -> Tuple[List[str], List[str]]:
117118

118119
return errors, warnings
119120

121+
def write(self):
122+
"""Update the annotations index,
123+
while enforcing its good shape.
124+
"""
125+
self.annotations[['time_seek', 'range_onset', 'range_offset']].fillna(0, inplace = True)
126+
self.annotations[['time_seek', 'range_onset', 'range_offset']] = self.annotations[['time_seek', 'range_onset', 'range_offset']].astype(int)
127+
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
128+
120129
def validate_annotation(self, annotation: dict) -> Tuple[List[str], List[str]]:
121-
print("validating {}...".format(annotation['annotation_filename']))
130+
print("validating {} from {}...".format(annotation['annotation_filename'], annotation['set']))
122131

123132
segments = IndexTable(
124133
'segments',
125-
path = os.path.join(self.project.path, 'annotations', annotation['set'], 'converted', str(annotation['annotation_filename'])),
134+
path = os.path.join(self.project.path, 'annotations', annotation['set'], 'converted', annotation['annotation_filename']),
126135
columns = self.SEGMENTS_COLUMNS
127136
)
128137

129138
try:
130139
segments.read()
131140
except Exception as e:
132-
return [str(e)], []
141+
error_message = "error while trying to read {} from {}:\n\t{}".format(
142+
annotation['annotation_filename'],
143+
annotation['set'],
144+
str(e)
145+
)
146+
return [error_message], []
133147

134148
return segments.validate()
135149

@@ -158,6 +172,32 @@ def validate(self, annotations: pd.DataFrame = None, threads: int = 0) -> Tuple[
158172

159173
return errors, warnings
160174

175+
def _read_annotation(self, set: str, filename: str, annotation_type: str = None):
176+
path = os.path.join(self.project.path, 'annotations', set, 'converted', filename)
177+
ext = os.path.splitext(filename)[1]
178+
179+
if ext == '.gz':
180+
return pd.read_csv(path, compression = 'gzip')
181+
elif ext == '.csv':
182+
return pd.read_csv(path)
183+
else:
184+
raise ValueError(f"invalid extension '{ext}' for annotation {set}/{filename}'")
185+
186+
def _write_annotation(self, df: pd.DataFrame, set: str, filename: str):
187+
path = os.path.join(self.project.path, 'annotations', set, 'converted', filename)
188+
ext = os.path.splitext(filename)[1]
189+
190+
os.makedirs(os.path.dirname(path), exist_ok = True)
191+
192+
if ext == '.gz':
193+
df.to_csv(path, index = False, compression = 'gzip')
194+
elif ext == '.csv':
195+
df.to_csv(path, index = False)
196+
else:
197+
raise ValueError(f"invalid extension '{ext}' for annotation {set}/{filename}'")
198+
199+
200+
return os.path.basename(path)
161201

162202
def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], annotation: dict):
163203
"""import and convert ``annotation``. This function should not be called outside of this class.
@@ -171,9 +211,6 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
171211
"""
172212

173213
source_recording = os.path.splitext(annotation['recording_filename'])[0]
174-
annotation_filename = "{}_{}_{}.csv".format(source_recording, annotation['time_seek'], annotation['range_onset'])
175-
output_filename = os.path.join('annotations', annotation['set'], 'converted', annotation_filename)
176-
177214
path = os.path.join(self.project.path, 'annotations', annotation['set'], 'raw', annotation['raw_filename'])
178215
annotation_format = annotation['format']
179216

@@ -183,27 +220,9 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
183220
try:
184221
if callable(import_function):
185222
df = import_function(path)
186-
elif annotation_format == 'TextGrid':
187-
from .converters import TextGridConverter
188-
df = TextGridConverter.convert(path)
189-
elif annotation_format == 'eaf':
190-
from .converters import EafConverter
191-
df = EafConverter.convert(path)
192-
elif annotation_format == 'vtc_rttm':
193-
from .converters import VtcConverter
194-
df = VtcConverter.convert(path, source_file = filter)
195-
elif annotation_format == 'vcm_rttm':
196-
from .converters import VcmConverter
197-
df = VcmConverter.convert(path, source_file = filter)
198-
elif annotation_format == 'its':
199-
from .converters import ItsConverter
200-
df = ItsConverter.convert(path, recording_num = filter)
201-
elif annotation_format == 'alice':
202-
from .converters import AliceConverter
203-
df = AliceConverter.convert(path, source_file = filter)
204-
elif annotation_format == 'cha':
205-
from .converters import ChatConverter
206-
df = ChatConverter.convert(path)
223+
elif annotation_format in converters:
224+
converter = converters[annotation_format]
225+
df = converter.convert(path, filter)
207226
else:
208227
raise ValueError("file format '{}' unknown for '{}'".format(annotation_format, path))
209228
except:
@@ -233,13 +252,29 @@ def _import_annotation(self, import_function: Callable[[str], pd.DataFrame], ann
233252

234253
df.sort_values(sort_columns, inplace = True)
235254

236-
os.makedirs(os.path.dirname(os.path.join(self.project.path, output_filename)), exist_ok = True)
237-
df.to_csv(os.path.join(self.project.path, output_filename), index = False)
255+
if 'type' not in annotation or pd.isnull(annotation['type']):
256+
annotation['type'] = 'csv'
257+
258+
annotation_filename = "{}_{}_{}.{}".format(
259+
source_recording,
260+
annotation['time_seek'],
261+
annotation['range_onset'],
262+
annotation['type']
263+
)
264+
265+
self._write_annotation(
266+
df,
267+
annotation['set'],
268+
annotation_filename
269+
)
238270

239271
annotation['annotation_filename'] = annotation_filename
240272
annotation['imported_at'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
241273
annotation['package_version'] = __version__
242274

275+
if pd.isnull(annotation['format']):
276+
annotation['format'] = 'NA'
277+
243278
return annotation
244279

245280
def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_function: Callable[[str], pd.DataFrame] = None) -> pd.DataFrame:
@@ -264,8 +299,8 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
264299
input['range_onset'] = input['range_onset'].astype(int)
265300
input['range_offset'] = input['range_offset'].astype(int)
266301

267-
builtin = input[input['format'].isin(is_thread_safe.keys())]
268-
if not builtin['format'].map(is_thread_safe).all():
302+
builtin = input[input['format'].isin(converters.keys())]
303+
if not builtin['format'].map(lambda f: converters[f].THREAD_SAFE).all():
269304
print('warning: some of the converters do not support multithread importation; running on 1 thread')
270305
threads = 1
271306

@@ -283,7 +318,7 @@ def import_annotations(self, input: pd.DataFrame, threads: int = -1, import_func
283318

284319
self.read()
285320
self.annotations = pd.concat([self.annotations, imported], sort = False)
286-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
321+
self.write()
287322

288323
return imported
289324

@@ -342,7 +377,7 @@ def remove_set(self, annotation_set: str, recursive: bool = False):
342377
pass
343378

344379
self.annotations = self.annotations[self.annotations['set'] != annotation_set]
345-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
380+
self.write()
346381

347382
def rename_set(self, annotation_set: str, new_set: str, recursive: bool = False, ignore_errors: bool = False):
348383
"""Rename a set of annotations, moving all related files
@@ -395,20 +430,20 @@ def rename_set(self, annotation_set: str, new_set: str, recursive: bool = False,
395430
move(os.path.join(current_path, 'converted'), os.path.join(new_path, 'converted'))
396431

397432
self.annotations.loc[(self.annotations['set'] == annotation_set), 'set'] = new_set
433+
self.write()
398434

399-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
400-
401-
def merge_annotations(self, left_columns, right_columns, columns, output_set, input):
435+
def merge_annotations(self, left_columns, right_columns, columns, output_set, type, input):
402436
left_annotations = input['left_annotations']
403437
right_annotations = input['right_annotations']
404438

405439
annotations = left_annotations.copy()
406440
annotations['format'] = ''
407441
annotations['annotation_filename'] = annotations.apply(
408-
lambda annotation: "{}_{}_{}.csv".format(
442+
lambda annotation: "{}_{}_{}.{}".format(
409443
os.path.splitext(annotation['recording_filename'])[0],
410444
annotation['time_seek'],
411-
annotation['range_onset']
445+
annotation['range_onset'],
446+
type
412447
)
413448
, axis = 1)
414449

@@ -482,16 +517,19 @@ def merge_annotations(self, left_columns, right_columns, columns, output_set, in
482517

483518
segments = output_segments[output_segments['interval'] == interval]
484519
segments.drop(columns = list(set(segments.columns)-{c.name for c in self.SEGMENTS_COLUMNS}), inplace = True)
485-
segments.to_csv(
486-
os.path.join(self.project.path, 'annotations', annotation_set, 'converted', annotation_filename),
487-
index = False
520+
521+
self._write_annotation(
522+
segments,
523+
annotation_set,
524+
annotation_filename
488525
)
489526

490527
return annotations
491528

492529
def merge_sets(self, left_set: str, right_set: str,
493530
left_columns: List[str], right_columns: List[str],
494531
output_set: str, columns: dict = {},
532+
type = 'csv',
495533
threads = -1
496534
):
497535
"""Merge columns from ``left_set`` and ``right_set`` annotations,
@@ -541,15 +579,19 @@ def merge_sets(self, left_set: str, right_set: str,
541579
for recording in left_annotations['recording_filename'].unique()
542580
]
543581

544-
pool = mp.Pool(processes = threads if threads > 0 else mp.cpu_count())
545-
annotations = pool.map(partial(self.merge_annotations, left_columns, right_columns, columns, output_set), input_annotations)
582+
with mp.Pool(processes = threads if threads > 0 else mp.cpu_count()) as pool:
583+
annotations = pool.map(
584+
partial(self.merge_annotations, left_columns, right_columns, columns, output_set, type),
585+
input_annotations
586+
)
587+
546588
annotations = pd.concat(annotations)
547589
annotations.drop(columns = list(set(annotations.columns)-{c.name for c in self.INDEX_COLUMNS}), inplace = True)
548590
annotations.fillna({'raw_filename': 'NA'}, inplace = True)
549591

550592
self.read()
551593
self.annotations = pd.concat([self.annotations, annotations], sort = False)
552-
self.annotations.to_csv(os.path.join(self.project.path, 'metadata/annotations.csv'), index = False)
594+
self.write()
553595

554596
def get_segments(self, annotations: pd.DataFrame) -> pd.DataFrame:
555597
"""get all segments associated to the annotations referenced in ``annotations``.
@@ -565,7 +607,8 @@ def get_segments(self, annotations: pd.DataFrame) -> pd.DataFrame:
565607
segments = []
566608
for index, _annotations in annotations.groupby(['set', 'annotation_filename']):
567609
s, annotation_filename = index
568-
df = pd.read_csv(os.path.join(self.project.path, 'annotations', s, 'converted', annotation_filename))
610+
annotation_type = _annotations['type'].iloc[0] if 'type' in _annotations.columns else 'csv'
611+
df = self._read_annotation(s, annotation_filename, annotation_type)
569612

570613
for annotation in _annotations.to_dict(orient = 'records'):
571614
segs = df.copy()

ChildProject/converters.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pandas as pd
33
import re
44

5+
converters = {}
6+
57
class AnnotationConverter:
68
SPEAKER_ID_TO_TYPE = defaultdict(lambda: 'NA', {
79
'C1': 'OCH',
@@ -50,7 +52,12 @@ class AnnotationConverter:
5052

5153
THREAD_SAFE = True
5254

55+
def __init_subclass__(cls, **kwargs):
56+
super().__init_subclass__(**kwargs)
57+
converters[cls.FORMAT] = cls
58+
5359
class VtcConverter(AnnotationConverter):
60+
FORMAT = 'vtc_rttm'
5461

5562
SPEAKER_TYPE_TRANSLATION = defaultdict(lambda: 'NA', {
5663
'CHI': 'OCH',
@@ -80,6 +87,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame:
8087
return df
8188

8289
class VcmConverter(AnnotationConverter):
90+
FORMAT = 'vcm_rttm'
8391

8492
SPEAKER_TYPE_TRANSLATION = defaultdict(lambda: 'NA', {
8593
'CHI': 'OCH',
@@ -119,6 +127,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame:
119127
return df
120128

121129
class AliceConverter(AnnotationConverter):
130+
FORMAT = 'alice'
122131

123132
@staticmethod
124133
def convert(filename: str, source_file: str = '') -> pd.DataFrame:
@@ -142,6 +151,7 @@ def convert(filename: str, source_file: str = '') -> pd.DataFrame:
142151
return df
143152

144153
class ItsConverter(AnnotationConverter):
154+
FORMAT = 'its'
145155

146156
SPEAKER_TYPE_TRANSLATION = defaultdict(lambda: 'NA', {
147157
'CHN': 'CHI',
@@ -277,9 +287,10 @@ def extract_from_regex(pattern, subject):
277287
return df
278288

279289
class TextGridConverter(AnnotationConverter):
290+
FORMAT = 'TextGrid'
280291

281292
@staticmethod
282-
def convert(filename: str) -> pd.DataFrame:
293+
def convert(filename: str, filter = None) -> pd.DataFrame:
283294
import pympi
284295
textgrid = pympi.Praat.TextGrid(filename)
285296

@@ -316,9 +327,10 @@ def ling_type(s):
316327
return pd.DataFrame(segments)
317328

318329
class EafConverter(AnnotationConverter):
330+
FORMAT = 'eaf'
319331

320332
@staticmethod
321-
def convert(filename: str) -> pd.DataFrame:
333+
def convert(filename: str, filter = None) -> pd.DataFrame:
322334
import pympi
323335
eaf = pympi.Elan.Eaf(filename)
324336

@@ -387,6 +399,7 @@ def convert(filename: str) -> pd.DataFrame:
387399
return pd.DataFrame(segments.values())
388400

389401
class ChatConverter(AnnotationConverter):
402+
FORMAT = 'cha'
390403
THREAD_SAFE = False
391404

392405
SPEAKER_ROLE_TO_TYPE = defaultdict(lambda: 'NA', {
@@ -450,7 +463,7 @@ def role_to_addressee(role):
450463
return ChatConverter.ADDRESSEE_TABLE[ChatConverter.SPEAKER_ROLE_TO_TYPE[role]]
451464

452465
@staticmethod
453-
def convert(filename: str) -> pd.DataFrame:
466+
def convert(filename: str, filter = None) -> pd.DataFrame:
454467

455468
import pylangacq
456469

@@ -487,14 +500,4 @@ def convert(filename: str) -> pd.DataFrame:
487500
df.drop(columns = ['participant', 'tokens', 'time_marks'], inplace = True)
488501
df.fillna('NA', inplace = True)
489502

490-
return df
491-
492-
is_thread_safe = {
493-
'its': ItsConverter.THREAD_SAFE,
494-
'vtc_rttm': VtcConverter.THREAD_SAFE,
495-
'vcm_rttm': VcmConverter.THREAD_SAFE,
496-
'eaf': EafConverter.THREAD_SAFE,
497-
'TextGrid': TextGridConverter.THREAD_SAFE,
498-
'alice': AliceConverter.THREAD_SAFE,
499-
'cha': ChatConverter.THREAD_SAFE
500-
}
503+
return df

0 commit comments

Comments
 (0)