From ebd92667ccee74aa2720783f54faa0f629b18b1b Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Wed, 15 May 2024 01:14:03 +0000 Subject: [PATCH 1/2] refactor: filter one dataset at a time --- dfaligner/dataset.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dfaligner/dataset.py b/dfaligner/dataset.py index 4326c66..347650c 100644 --- a/dfaligner/dataset.py +++ b/dfaligner/dataset.py @@ -77,13 +77,16 @@ def predict_dataloader(self): ) def prepare_data(self): - ( - self.train_dataset, - self.val_dataset, - ) = filter_dataset_based_on_target_text_representation_level( + self.train_dataset = filter_dataset_based_on_target_text_representation_level( self.config.model.target_text_representation_level, self.train_dataset, + "training", + self.batch_size, + ) + self.val_dataset = filter_dataset_based_on_target_text_representation_level( + self.config.model.target_text_representation_level, self.val_dataset, + "validation", self.batch_size, ) train_samples = len(self.train_dataset) From a9c86a8d25e998429a197a5c5efade3e9e0b8ad4 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Wed, 22 May 2024 21:48:36 +0000 Subject: [PATCH 2/2] fix: use enum field serializers --- dfaligner/config/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dfaligner/config/__init__.py b/dfaligner/config/__init__.py index 6012a97..cf82f12 100644 --- a/dfaligner/config/__init__.py +++ b/dfaligner/config/__init__.py @@ -15,7 +15,7 @@ from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel from everyvoice.config.utils import load_partials from everyvoice.utils import load_config_from_json_or_yaml_path -from pydantic import Field, FilePath, ValidationInfo, model_validator +from pydantic import Field, FilePath, ValidationInfo, field_serializer, model_validator class DFAlignerExtractionMethod(Enum): @@ -34,6 +34,12 @@ class DFAlignerModelConfig(ConfigModel): 512, description="The number of dimensions in the convolutional layers." ) + @field_serializer("target_text_representation_level") + def convert_training_enum( + self, target_text_representation_level: TargetTrainingTextRepresentationLevel + ): + return target_text_representation_level.value + class DFAlignerTrainingConfig(BaseTrainingConfig): optimizer: AdamOptimizer | AdamWOptimizer = Field( @@ -47,6 +53,12 @@ class DFAlignerTrainingConfig(BaseTrainingConfig): description="The alignment extraction algorithm to use. 'beam' will be quicker but possibly less accurate than 'dijkstra'", ) + @field_serializer("extraction_method") + def convert_extraction_method_enum( + self, extraction_method: DFAlignerExtractionMethod + ): + return extraction_method.value + class DFAlignerConfig(BaseModelWithContact): # TODO FastSpeech2Config and DFAlignerConfig are almost identical.