Skip to content

Commit

Permalink
feat: New assertions when loading configs.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelLarkin committed Oct 6, 2023
1 parent 7dd78f5 commit 717fed0
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 60 deletions.
Empty file.
Empty file.
Empty file.
158 changes: 98 additions & 60 deletions everyvoice/tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from everyvoice.config.shared_types import BaseTrainingConfig, LoggerConfig
from everyvoice.config.text_config import TextConfig
from everyvoice.model.aligner.config import AlignerConfig
from everyvoice.model.aligner.DeepForcedAligner.dfaligner.config import (
DFAlignerTrainingConfig,
)
from everyvoice.model.e2e.config import E2ETrainingConfig, EveryVoiceConfig
from everyvoice.model.feature_prediction.config import FeaturePredictionConfig
from everyvoice.model.vocoder.config import VocoderConfig
Expand Down Expand Up @@ -278,6 +281,13 @@ class LoadConfigTest(TestCase):
DATA_DIR = Path(__file__).parent / "data" / "relative" / "config"
DATASET_NAME: str = "relative"

def validate_config_path(self, path: Path):
"""
Helper method to validate a path once loaded by a config.
"""
self.assertTrue(path.is_absolute(), msg=path)
self.assertTrue(path.exists(), msg=path)

def test_aligner_config(self):
"""Create a AlignerConfig which pydantic will validate for us."""
config_path = self.DATA_DIR / f"{ALIGNER_CONFIG_FILENAME_PREFIX}.yaml"
Expand All @@ -287,18 +297,19 @@ def test_aligner_config(self):
Path(pre_test["path_to_preprocessing_config_file"]).is_absolute()
)
self.assertFalse(Path(pre_test["path_to_text_config_file"]).is_absolute())
self.assertFalse(
Path(pre_test["training"]["logger"]["save_dir"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["training_filelist"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["validation_filelist"]).is_absolute()
)
training = pre_test["training"]
self.assertFalse(Path(training["logger"]["save_dir"]).is_absolute())
self.assertFalse(Path(training["training_filelist"]).is_absolute())
self.assertFalse(Path(training["validation_filelist"]).is_absolute())
config = AlignerConfig.load_config_from_path(config_path)
# print(config.model_dump_json(indent=2))
self.assertTrue(isinstance(config, AlignerConfig))
self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME)
self.validate_config_path(config.path_to_preprocessing_config_file)
self.validate_config_path(config.path_to_text_config_file)
self.validate_config_path(config.training.logger.save_dir)
self.validate_config_path(config.training.training_filelist)
self.validate_config_path(config.training.validation_filelist)

def test_preprocessing_config(self):
"""Create a PreprocessingConfig which pydantic will validate for us."""
Expand All @@ -307,11 +318,18 @@ def test_preprocessing_config(self):
pre_test = yaml.safe_load(f)
self.assertFalse(Path(pre_test["save_dir"]).is_absolute())
self.assertEqual(len(pre_test["source_data"]), 1)
self.assertFalse(Path(pre_test["source_data"][0]["data_dir"]).is_absolute())
self.assertFalse(Path(pre_test["source_data"][0]["filelist"]).is_absolute())
for data in pre_test["source_data"]:
self.assertFalse(Path(data["data_dir"]).is_absolute())
self.assertFalse(Path(data["filelist"]).is_absolute())
config = PreprocessingConfig.load_config_from_path(config_path)
# print(config.model_dump_json(indent=2))
self.assertTrue(isinstance(config, PreprocessingConfig))
self.assertEqual(config.dataset, self.DATASET_NAME)
self.validate_config_path(config.save_dir)
self.assertEqual(len(config.source_data), 1)
for data in config.source_data:
self.validate_config_path(data.data_dir)
self.validate_config_path(data.filelist)

def test_feature_prediction_config(self):
"""Create a FeaturePredictionConfig which pydantic will validate for us."""
Expand All @@ -322,18 +340,18 @@ def test_feature_prediction_config(self):
Path(pre_test["path_to_preprocessing_config_file"]).is_absolute()
)
self.assertFalse(Path(pre_test["path_to_text_config_file"]).is_absolute())
self.assertFalse(
Path(pre_test["training"]["logger"]["save_dir"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["training_filelist"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["validation_filelist"]).is_absolute()
)
training = pre_test["training"]
self.assertFalse(Path(training["logger"]["save_dir"]).is_absolute())
self.assertFalse(Path(training["training_filelist"]).is_absolute())
self.assertFalse(Path(training["validation_filelist"]).is_absolute())
config = FeaturePredictionConfig.load_config_from_path(config_path)
# Dummy test as the real test is done during load_config_from_path().
# print(config.model_dump_json(indent=2))
self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME)
self.validate_config_path(config.path_to_text_config_file)
self.validate_config_path(config.path_to_text_config_file)
self.validate_config_path(config.training.logger.save_dir)
self.validate_config_path(config.training.training_filelist)
self.validate_config_path(config.training.validation_filelist)

def test_vocoder_config(self):
"""Create a VocoderConfig which pydantic will validate for us."""
Expand All @@ -343,18 +361,18 @@ def test_vocoder_config(self):
self.assertFalse(
Path(pre_test["path_to_preprocessing_config_file"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["logger"]["save_dir"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["training_filelist"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["validation_filelist"]).is_absolute()
)
training = pre_test["training"]
self.assertFalse(Path(training["logger"]["save_dir"]).is_absolute())
self.assertFalse(Path(training["training_filelist"]).is_absolute())
self.assertFalse(Path(training["validation_filelist"]).is_absolute())
config = VocoderConfig.load_config_from_path(config_path)
# print(config.model_dump_json(indent=2))
self.assertTrue(isinstance(config, VocoderConfig))
self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME)
self.validate_config_path(config.path_to_preprocessing_config_file)
self.validate_config_path(config.training.logger.save_dir)
self.validate_config_path(config.training.training_filelist)
self.validate_config_path(config.training.validation_filelist)

def test_everyvoice_config(self):
"""Create a EveryVoiceConfig which pydantic will validate for us."""
Expand All @@ -370,55 +388,75 @@ def test_everyvoice_config(self):
self.assertFalse(
Path(pre_test["path_to_vocoder_config_file"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["logger"]["save_dir"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["training_filelist"]).is_absolute()
)
self.assertFalse(
Path(pre_test["training"]["validation_filelist"]).is_absolute()
)
training = pre_test["training"]
self.assertFalse(Path(training["logger"]["save_dir"]).is_absolute())
self.assertFalse(Path(training["training_filelist"]).is_absolute())
self.assertFalse(Path(training["validation_filelist"]).is_absolute())
config = EveryVoiceConfig.load_config_from_path(config_path)
# print(config.model_dump_json(indent=2))
self.assertTrue(isinstance(config, EveryVoiceConfig))
self.assertEqual(
config.feature_prediction.preprocessing.dataset, self.DATASET_NAME
)
self.validate_config_path(config.path_to_aligner_config_file)
self.validate_config_path(config.path_to_feature_prediction_config_file)
self.validate_config_path(config.path_to_vocoder_config_file)
self.validate_config_path(config.training.logger.save_dir)
self.validate_config_path(config.training.training_filelist)
self.validate_config_path(config.training.validation_filelist)

def test_absolute_path(self):
"""Load a config that has absolute paths."""
with tempfile.TemporaryDirectory() as tempdir:
tempdir = Path(tempdir)
_writer_helper(AudioConfig(), tempdir / "audio.json")
config = PreprocessingConfig(
path_to_audio_config_file=(tempdir / "audio.json")
)
self.assertTrue(isinstance(config.audio, AudioConfig))
# Write shared:
# Write preprocessing:
preprocessing_config_path = tempdir / "aligner-preprocessing.json"
_writer_helper(
PreprocessingConfig(dataset=self.DATASET_NAME),
tempdir / "preprocessing.json",
preprocessing_config_path,
)
_writer_helper(TextConfig(), tempdir / "text.json")
_writer_helper(BaseTrainingConfig(), tempdir / "training.json")

# Write text:
text_config_path = tempdir / "aligner-text.json"
_writer_helper(TextConfig(), text_config_path)

# Write training:
aligner_training_path = tempdir / "aligner-training.json"
training = DFAlignerTrainingConfig(
training_filelist=tempdir / "training_filelist.psv",
validation_filelist=tempdir / "validation_filelist.psv",
)
(tempdir / training.logger.save_dir).mkdir(parents=True, exist_ok=True)
(tempdir / training.training_filelist).touch(exist_ok=True)
(tempdir / training.validation_filelist).touch(exist_ok=True)
_writer_helper(training, aligner_training_path)

# Write model:
aligner_model_path = tempdir / "aligner-model.json"
_writer_helper(AlignerConfig().model, aligner_model_path)

# Aligner Config
_writer_helper(AlignerConfig().training, tempdir / "aligner-training.json")
_writer_helper(AlignerConfig().model, tempdir / "aligner-model.json")
aligner_config = AlignerConfig(
path_to_model_config_file=tempdir / "aligner-model.json",
path_to_preprocessing_config_file=tempdir / "preprocessing.json",
path_to_text_config_file=tempdir / "text.json",
path_to_training_config_file=tempdir / "aligner-training.json",
path_to_model_config_file=aligner_model_path,
path_to_preprocessing_config_file=preprocessing_config_path,
path_to_text_config_file=text_config_path,
path_to_training_config_file=aligner_training_path,
)
_writer_helper(aligner_config, tempdir / "aligner.json")
self.assertTrue(isinstance(aligner_config, AlignerConfig))
config = AlignerConfig.load_config_from_path(tempdir / "aligner.json")
aligner_config_path = tempdir / "aligner.json"
_writer_helper(aligner_config, aligner_config_path)

# Reload and validate
config = AlignerConfig.load_config_from_path(aligner_config_path)
self.assertTrue(isinstance(config, AlignerConfig))
self.assertEqual(config.preprocessing.dataset, self.DATASET_NAME)
self.assertTrue(config.path_to_model_config_file.is_absolute())
self.assertTrue(config.path_to_preprocessing_config_file.is_absolute())
self.assertTrue(config.path_to_text_config_file.is_absolute())
self.assertTrue(config.path_to_training_config_file.is_absolute())
self.validate_config_path(config.path_to_model_config_file)
self.validate_config_path(config.path_to_preprocessing_config_file)
self.validate_config_path(config.path_to_text_config_file)
self.validate_config_path(config.path_to_training_config_file)
self.validate_config_path(config.training.logger.save_dir)
self.validate_config_path(config.training.training_filelist)
self.validate_config_path(config.training.validation_filelist)

def test_missing_path(self):
"""Load a config that is missing a partial config file."""
Expand Down

0 comments on commit 717fed0

Please sign in to comment.