Skip to content

Commit

Permalink
整理: テスト入力生成ユーティリティ (#979)
Browse files Browse the repository at this point in the history
* refactor: テスト入力 gen util

* refactor: 入力データ関数切り出し
  • Loading branch information
tarepan authored Jan 6, 2024
1 parent 77a3343 commit c14f131
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions test/test_tts_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
from typing import Union
from unittest import TestCase
from unittest.mock import Mock
Expand Down Expand Up @@ -462,57 +461,60 @@ def test_raw_wave_to_output_wave_without_resample():
assert numpy.allclose(wave, true_wave)


def _gen_hello_hiho_phonemes() -> list[Phoneme]:
hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil"
return [Phoneme(p) for p in hello_hiho.split()]


def _gen_hello_hiho_accent_phrases() -> list[AccentPhrase]:
return [
AccentPhrase(
moras=[
_gen_mora("コ", "k", 0.0, "o", 0.0, 0.0),
_gen_mora("ン", None, None, "N", 0.0, 0.0),
_gen_mora("ニ", "n", 0.0, "i", 0.0, 0.0),
_gen_mora("チ", "ch", 0.0, "i", 0.0, 0.0),
_gen_mora("ワ", "w", 0.0, "a", 0.0, 0.0),
],
accent=5,
pause_mora=_gen_mora("、", None, None, "pau", 0.0, 0.0),
),
AccentPhrase(
moras=[
_gen_mora("ヒ", "h", 0.0, "i", 0.0, 0.0),
_gen_mora("ホ", "h", 0.0, "o", 0.0, 0.0),
_gen_mora("デ", "d", 0.0, "e", 0.0, 0.0),
_gen_mora("ス", "s", 0.0, "U", 0.0, 0.0),
],
accent=1,
pause_mora=None,
),
]


class TestTTSEngine(TestCase):
def setUp(self):
super().setUp()
self.str_list_hello_hiho = (
"sil k o N n i ch i w a pau h i h o d e s U sil".split()
)
self.phoneme_data_list_hello_hiho = [
Phoneme(p) for p in "pau k o N n i ch i w a pau h i h o d e s U pau".split()
]
self.accent_phrases_hello_hiho = [
AccentPhrase(
moras=[
_gen_mora("コ", "k", 0.0, "o", 0.0, 0.0),
_gen_mora("ン", None, None, "N", 0.0, 0.0),
_gen_mora("ニ", "n", 0.0, "i", 0.0, 0.0),
_gen_mora("チ", "ch", 0.0, "i", 0.0, 0.0),
_gen_mora("ワ", "w", 0.0, "a", 0.0, 0.0),
],
accent=5,
pause_mora=_gen_mora("、", None, None, "pau", 0.0, 0.0),
),
AccentPhrase(
moras=[
_gen_mora("ヒ", "h", 0.0, "i", 0.0, 0.0),
_gen_mora("ホ", "h", 0.0, "o", 0.0, 0.0),
_gen_mora("デ", "d", 0.0, "e", 0.0, 0.0),
_gen_mora("ス", "s", 0.0, "U", 0.0, 0.0),
],
accent=1,
pause_mora=None,
),
]
core = MockCore()
self.yukarin_s_mock = core.yukarin_s_forward
self.yukarin_sa_mock = core.yukarin_sa_forward
self.decode_mock = core.decode_forward
self.tts_engine = TTSEngine(core=core) # type: ignore[arg-type]

def test_to_flatten_moras(self):
flatten_moras = to_flatten_moras(self.accent_phrases_hello_hiho)
flatten_moras = to_flatten_moras(_gen_hello_hiho_accent_phrases())
true_accent_phrases_hello_hiho = _gen_hello_hiho_accent_phrases()
self.assertEqual(
flatten_moras,
self.accent_phrases_hello_hiho[0].moras
+ [self.accent_phrases_hello_hiho[0].pause_mora]
+ self.accent_phrases_hello_hiho[1].moras,
true_accent_phrases_hello_hiho[0].moras
+ [true_accent_phrases_hello_hiho[0].pause_mora]
+ true_accent_phrases_hello_hiho[1].moras,
)

def test_split_mora(self):
# Outputs
consonant_phoneme_list, vowel_phoneme_list = split_mora(
self.phoneme_data_list_hello_hiho
_gen_hello_hiho_phonemes()
)

ps = ["pau", "o", "N", "i", "i", "a", "pau", "i", "o", "e", "U", "pau"]
Expand Down Expand Up @@ -541,15 +543,13 @@ def test_split_mora(self):
)

def test_pre_process(self):
flatten_moras, phoneme_data_list = pre_process(
deepcopy(self.accent_phrases_hello_hiho)
)
flatten_moras, phoneme_data_list = pre_process(_gen_hello_hiho_accent_phrases())

mora_index = 0
phoneme_index = 1

self.assertTrue(is_same_phoneme(phoneme_data_list[0], Phoneme("pau")))
for accent_phrase in self.accent_phrases_hello_hiho:
for accent_phrase in _gen_hello_hiho_accent_phrases():
moras = accent_phrase.moras
for mora in moras:
self.assertEqual(flatten_moras[mora_index], mora)
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_pre_process(self):

def test_update_length(self):
# Inputs
hello_hiho = deepcopy(self.accent_phrases_hello_hiho)
hello_hiho = _gen_hello_hiho_accent_phrases()
# Outputs & Indirect Outputs(yukarin_sに渡される値)
result = self.tts_engine.update_length(hello_hiho, StyleId(1))
yukarin_s_args = self.yukarin_s_mock.call_args[1]
Expand All @@ -593,7 +593,7 @@ def test_update_length(self):
true_phoneme_list_1 = [0, 23, 30, 4, 28, 21, 10, 21, 42, 7]
true_phoneme_list_2 = [0, 19, 21, 19, 30, 12, 14, 35, 6, 0]
true_phoneme_list = true_phoneme_list_1 + true_phoneme_list_2
true_result = deepcopy(self.accent_phrases_hello_hiho)
true_result = _gen_hello_hiho_accent_phrases()
index = 1

def result_value(i: int) -> float:
Expand Down Expand Up @@ -632,7 +632,7 @@ def test_update_pitch(self):
self.assertEqual(result, true_result)

# Inputs
hello_hiho = deepcopy(self.accent_phrases_hello_hiho)
hello_hiho = _gen_hello_hiho_accent_phrases()
# Outputs & Indirect Outputs(yukarin_saに渡される値)
result = self.tts_engine.update_pitch(hello_hiho, StyleId(1))
yukarin_sa_args = self.yukarin_sa_mock.call_args[1]
Expand All @@ -651,7 +651,7 @@ def test_update_pitch(self):
true_accent_ends = numpy.array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0])
true_phrase_starts = numpy.array([0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
true_phrase_ends = numpy.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0])
true_result = deepcopy(self.accent_phrases_hello_hiho)
true_result = _gen_hello_hiho_accent_phrases()
index = 1

def result_value(i: int) -> float:
Expand Down

0 comments on commit c14f131

Please sign in to comment.