diff --git a/dfaligner/config/__init__.py b/dfaligner/config/__init__.py
index 6012a97..cf82f12 100644
--- a/dfaligner/config/__init__.py
+++ b/dfaligner/config/__init__.py
@@ -15,7 +15,7 @@
 from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel
 from everyvoice.config.utils import load_partials
 from everyvoice.utils import load_config_from_json_or_yaml_path
-from pydantic import Field, FilePath, ValidationInfo, model_validator
+from pydantic import Field, FilePath, ValidationInfo, field_serializer, model_validator
 
 
 class DFAlignerExtractionMethod(Enum):
@@ -34,6 +34,12 @@ class DFAlignerModelConfig(ConfigModel):
         512, description="The number of dimensions in the convolutional layers."
     )
 
+    @field_serializer("target_text_representation_level")
+    def convert_training_enum(
+        self, target_text_representation_level: TargetTrainingTextRepresentationLevel
+    ):
+        return target_text_representation_level.value
+
 
 class DFAlignerTrainingConfig(BaseTrainingConfig):
     optimizer: AdamOptimizer | AdamWOptimizer = Field(
@@ -47,6 +53,12 @@ class DFAlignerTrainingConfig(BaseTrainingConfig):
         description="The alignment extraction algorithm to use. 'beam' will be quicker but possibly less accurate than 'dijkstra'",
     )
 
+    @field_serializer("extraction_method")
+    def convert_extraction_method_enum(
+        self, extraction_method: DFAlignerExtractionMethod
+    ):
+        return extraction_method.value
+
 
 class DFAlignerConfig(BaseModelWithContact):
     # TODO FastSpeech2Config and DFAlignerConfig are almost identical.
diff --git a/dfaligner/dataset.py b/dfaligner/dataset.py
index 4326c66..347650c 100644
--- a/dfaligner/dataset.py
+++ b/dfaligner/dataset.py
@@ -77,13 +77,16 @@ def predict_dataloader(self):
         )
 
     def prepare_data(self):
-        (
-            self.train_dataset,
-            self.val_dataset,
-        ) = filter_dataset_based_on_target_text_representation_level(
+        self.train_dataset = filter_dataset_based_on_target_text_representation_level(
             self.config.model.target_text_representation_level,
             self.train_dataset,
+            "training",
+            self.batch_size,
+        )
+        self.val_dataset = filter_dataset_based_on_target_text_representation_level(
+            self.config.model.target_text_representation_level,
             self.val_dataset,
+            "validation",
             self.batch_size,
         )
         train_samples = len(self.train_dataset)