Skip to content

Commit

Permalink
acoustic_feature_extractor.py 不使用要素の削除 (#781)
Browse files Browse the repository at this point in the history
* Refactor phoneme handler by removing JvsPhoneme

* Refactor phoneme handler by removing unused dict
  • Loading branch information
tarepan authored Nov 26, 2023
1 parent b768811 commit 44dc4b5
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 188 deletions.
98 changes: 1 addition & 97 deletions test/test_acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
from typing import List, Type
from unittest import TestCase

from voicevox_engine.acoustic_feature_extractor import (
BasePhoneme,
JvsPhoneme,
OjtPhoneme,
)
from voicevox_engine.acoustic_feature_extractor import BasePhoneme, OjtPhoneme


class TestBasePhoneme(TestCase):
Expand Down Expand Up @@ -86,95 +82,6 @@ def lab_test_base(
os.remove(file_path)


class TestJvsPhoneme(TestBasePhoneme):
def setUp(self):
super().setUp()
base_hello_hiho = [
JvsPhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
]
self.jvs_hello_hiho = JvsPhoneme.convert(base_hello_hiho)

def test_phoneme_list(self):
self.assertEqual(JvsPhoneme.phoneme_list[1], "I")
self.assertEqual(JvsPhoneme.phoneme_list[14], "gy")
self.assertEqual(JvsPhoneme.phoneme_list[26], "p")
self.assertEqual(JvsPhoneme.phoneme_list[38], "z")

def test_const(self):
self.assertEqual(JvsPhoneme.num_phoneme, 39)
self.assertEqual(JvsPhoneme.space_phoneme, "pau")

def test_convert(self):
converted_str_hello_hiho = " ".join([p.phoneme for p in self.jvs_hello_hiho])
self.assertEqual(
converted_str_hello_hiho, "pau k o N n i ch i w a pau h i h o d e s U pau"
)

def test_equal(self):
# jvs_hello_hihoの2番目の"k"と比較
true_jvs_phoneme = JvsPhoneme("k", 1, 2)
# OjtPhonemeと比べる、比較はBasePhoneme内で実装されているので、比較結果はTrue
true_ojt_phoneme = OjtPhoneme("k", 1, 2)

false_jvs_phoneme_1 = JvsPhoneme("a", 1, 2)
false_jvs_phoneme_2 = JvsPhoneme("k", 2, 3)
self.assertTrue(self.jvs_hello_hiho[1] == true_jvs_phoneme)
self.assertTrue(self.jvs_hello_hiho[1] == true_ojt_phoneme)
self.assertFalse(self.jvs_hello_hiho[1] == false_jvs_phoneme_1)
self.assertFalse(self.jvs_hello_hiho[1] == false_jvs_phoneme_2)

def test_verify(self):
for phoneme in self.jvs_hello_hiho:
phoneme.verify()

def test_phoneme_id(self):
jvs_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.jvs_hello_hiho])
self.assertEqual(
jvs_str_hello_hiho, "0 19 25 2 23 17 7 17 36 4 0 15 17 15 25 9 11 30 3 0"
)

def test_onehot(self):
phoneme_id_list = [
0,
19,
25,
2,
23,
17,
7,
17,
36,
4,
0,
15,
17,
15,
25,
9,
11,
30,
3,
0,
]
for i, phoneme in enumerate(self.jvs_hello_hiho):
for j in range(JvsPhoneme.num_phoneme):
if phoneme_id_list[i] == j:
self.assertEqual(phoneme.onehot[j], True)
else:
self.assertEqual(phoneme.onehot[j], False)

def test_parse(self):
parse_str_1 = "0 1 pau"
parse_str_2 = "15.32654 16.39454 a"
parsed_jvs_1 = JvsPhoneme.parse(parse_str_1)
parsed_jvs_2 = JvsPhoneme.parse(parse_str_2)
self.assertEqual(parsed_jvs_1.phoneme_id, 0)
self.assertEqual(parsed_jvs_2.phoneme_id, 4)

def test_lab_list(self):
self.lab_test_base("./jvs_lab_test", self.jvs_hello_hiho, JvsPhoneme)


class TestOjtPhoneme(TestBasePhoneme):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -204,13 +111,10 @@ def test_convert(self):
def test_equal(self):
# ojt_hello_hihoの10番目の"a"と比較
true_ojt_phoneme = OjtPhoneme("a", 9, 10)
# JvsPhonemeと比べる、比較はBasePhoneme内で実装されているので、比較結果はTrue
true_jvs_phoneme = JvsPhoneme("a", 9, 10)

false_ojt_phoneme_1 = OjtPhoneme("k", 9, 10)
false_ojt_phoneme_2 = OjtPhoneme("a", 10, 11)
self.assertTrue(self.ojt_hello_hiho[9] == true_ojt_phoneme)
self.assertTrue(self.ojt_hello_hiho[9] == true_jvs_phoneme)
self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1)
self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2)

Expand Down
91 changes: 0 additions & 91 deletions voicevox_engine/acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from enum import Enum
from pathlib import Path
from typing import List, Sequence

Expand Down Expand Up @@ -157,85 +156,6 @@ def save_lab_list(cls, phonemes: List["BasePhoneme"], path: Path):
path.write_text(text)


class JvsPhoneme(BasePhoneme):
"""
JVS(Japanese versatile speech)コーパスに含まれる音素群クラス
Attributes
----------
phoneme_list : Sequence[str]
音素のリスト
num_phoneme : int
音素リストの要素数
space_phoneme : str
読点に値する音素
"""

phoneme_list = (
"pau",
"I",
"N",
"U",
"a",
"b",
"by",
"ch",
"cl",
"d",
"dy",
"e",
"f",
"g",
"gy",
"h",
"hy",
"i",
"j",
"k",
"ky",
"m",
"my",
"n",
"ny",
"o",
"p",
"py",
"r",
"ry",
"s",
"sh",
"t",
"ts",
"u",
"v",
"w",
"y",
"z",
)
num_phoneme = len(phoneme_list)
space_phoneme = "pau"

@classmethod
def convert(cls, phonemes: List["JvsPhoneme"]) -> List["JvsPhoneme"]:
"""
最初と最後のsil(silent)をspace_phoneme(pau)に置き換え(変換)する
Parameters
----------
phonemes : List[JvsPhoneme]
変換したいphonemeのリスト
Returns
-------
phonemes : List[JvsPhoneme]
変換されたphonemeのリスト
"""
if "sil" in phonemes[0].phoneme:
phonemes[0].phoneme = cls.space_phoneme
if "sil" in phonemes[-1].phoneme:
phonemes[-1].phoneme = cls.space_phoneme
return phonemes


class OjtPhoneme(BasePhoneme):
"""
OpenJTalkに含まれる音素群クラス
Expand Down Expand Up @@ -319,14 +239,3 @@ def convert(cls, phonemes: List["OjtPhoneme"]):
if "sil" in phonemes[-1].phoneme:
phonemes[-1].phoneme = cls.space_phoneme
return phonemes


class PhonemeType(str, Enum):
jvs = "jvs"
openjtalk = "openjtalk"


phoneme_type_to_class = {
PhonemeType.jvs: JvsPhoneme,
PhonemeType.openjtalk: OjtPhoneme,
}

0 comments on commit 44dc4b5

Please sign in to comment.