From a9c86a8d25e998429a197a5c5efade3e9e0b8ad4 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Wed, 22 May 2024 21:48:36 +0000 Subject: [PATCH] fix: use enum field serializers --- dfaligner/config/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dfaligner/config/__init__.py b/dfaligner/config/__init__.py index 6012a97..cf82f12 100644 --- a/dfaligner/config/__init__.py +++ b/dfaligner/config/__init__.py @@ -15,7 +15,7 @@ from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel from everyvoice.config.utils import load_partials from everyvoice.utils import load_config_from_json_or_yaml_path -from pydantic import Field, FilePath, ValidationInfo, model_validator +from pydantic import Field, FilePath, ValidationInfo, field_serializer, model_validator class DFAlignerExtractionMethod(Enum): @@ -34,6 +34,12 @@ class DFAlignerModelConfig(ConfigModel): 512, description="The number of dimensions in the convolutional layers." ) + @field_serializer("target_text_representation_level") + def convert_training_enum( + self, target_text_representation_level: TargetTrainingTextRepresentationLevel + ): + return target_text_representation_level.value + class DFAlignerTrainingConfig(BaseTrainingConfig): optimizer: AdamOptimizer | AdamWOptimizer = Field( @@ -47,6 +53,12 @@ class DFAlignerTrainingConfig(BaseTrainingConfig): description="The alignment extraction algorithm to use. 'beam' will be quicker but possibly less accurate than 'dijkstra'", ) + @field_serializer("extraction_method") + def convert_extraction_method_enum( + self, extraction_method: DFAlignerExtractionMethod + ): + return extraction_method.value + class DFAlignerConfig(BaseModelWithContact): # TODO FastSpeech2Config and DFAlignerConfig are almost identical.