Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BasePhoneme 不使用メソッドの削除 #782

Merged
merged 3 commits into from
Nov 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 0 additions & 72 deletions test/test_acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
from pathlib import Path
from typing import List, Type
from unittest import TestCase

from voicevox_engine.acoustic_feature_extractor import BasePhoneme, OjtPhoneme
@@ -13,32 +10,6 @@ def setUp(self):
self.base_hello_hiho = [
BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split())
]
self.lab_str = """
0.00 1.00 pau
1.00 2.00 k
2.00 3.00 o
3.00 4.00 N
4.00 5.00 n
5.00 6.00 i
6.00 7.00 ch
7.00 8.00 i
8.00 9.00 w
9.00 10.00 a
10.00 11.00 pau
11.00 12.00 h
12.00 13.00 i
13.00 14.00 h
14.00 15.00 o
15.00 16.00 d
16.00 17.00 e
17.00 18.00 s
18.00 19.00 U
19.00 20.00 pau
""".replace(
" ", ""
)[
1:-1
] # ダブルクオーテーションx3で囲われている部分で、空白をすべて置き換え、先頭と最後の"\n"を除外する

def test_repr_(self):
self.assertEqual(
@@ -53,34 +24,6 @@ def test_convert(self):
with self.assertRaises(NotImplementedError):
BasePhoneme.convert(self.base_hello_hiho)

def test_duration(self):
self.assertEqual(self.base_hello_hiho[1].duration, 1)

def test_parse(self):
parse_str_1 = "0 1 pau"
parse_str_2 = "32.67543 33.48933 e"
parsed_base_1 = BasePhoneme.parse(parse_str_1)
parsed_base_2 = BasePhoneme.parse(parse_str_2)
self.assertEqual(parsed_base_1.phoneme, "pau")
self.assertEqual(parsed_base_1.start, 0.0)
self.assertEqual(parsed_base_1.end, 1.0)
self.assertEqual(parsed_base_2.phoneme, "e")
self.assertEqual(parsed_base_2.start, 32.68)
self.assertEqual(parsed_base_2.end, 33.49)

def lab_test_base(
self,
file_path: str,
phonemes: List["BasePhoneme"],
phoneme_class: Type["BasePhoneme"],
):
phoneme_class.save_lab_list(phonemes, Path(file_path))
with open(file_path, mode="r") as f:
self.assertEqual(f.read(), self.lab_str)
result_phoneme = phoneme_class.load_lab_list(Path(file_path))
self.assertEqual(result_phoneme, phonemes)
os.remove(file_path)


class TestOjtPhoneme(TestBasePhoneme):
def setUp(self):
@@ -118,10 +61,6 @@ def test_equal(self):
self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_1)
self.assertFalse(self.ojt_hello_hiho[9] == false_ojt_phoneme_2)

def test_verify(self):
for phoneme in self.ojt_hello_hiho:
phoneme.verify()

def test_phoneme_id(self):
ojt_str_hello_hiho = " ".join([str(p.phoneme_id) for p in self.ojt_hello_hiho])
self.assertEqual(
@@ -157,14 +96,3 @@ def test_onehot(self):
self.assertEqual(phoneme.onehot[j], True)
else:
self.assertEqual(phoneme.onehot[j], False)

def test_parse(self):
parse_str_1 = "0 1 pau"
parse_str_2 = "32.67543 33.48933 e"
parsed_ojt_1 = OjtPhoneme.parse(parse_str_1)
parsed_ojt_2 = OjtPhoneme.parse(parse_str_2)
self.assertEqual(parsed_ojt_1.phoneme_id, 0)
self.assertEqual(parsed_ojt_2.phoneme_id, 14)

def tes_lab_list(self):
self.lab_test_base("./ojt_lab_test", self.ojt_hello_hiho, OjtPhoneme)
86 changes: 0 additions & 86 deletions voicevox_engine/acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from pathlib import Path
from typing import List, Sequence

import numpy
@@ -41,12 +40,6 @@ def __eq__(self, o: object):
self.phoneme == o.phoneme and self.start == o.start and self.end == o.end
)

def verify(self):
"""
音素クラスとして、データが正しいかassertする
"""
assert self.phoneme in self.phoneme_list, f"{self.phoneme} is not defined."

@property
def phoneme_id(self):
"""
@@ -58,17 +51,6 @@ def phoneme_id(self):
"""
return self.phoneme_list.index(self.phoneme)

@property
def duration(self):
"""
音素継続期間を取得する
Returns
-------
duration : int
音素継続期間を返す
"""
return self.end - self.start

@property
def onehot(self):
"""
@@ -82,79 +64,11 @@ def onehot(self):
array[self.phoneme_id] = True
return array

@classmethod
def parse(cls, s: str):
"""
文字列をパースして音素クラスを作る
Parameters
----------
s : str
パースしたい文字列
Returns
-------
phoneme : BasePhoneme
パース結果を用いた音素クラスを返す
Examples
--------
>>> BasePhoneme.parse('1.7425000 1.9125000 o:')
Phoneme(phoneme='o:', start=1.74, end=1.91)
"""
words = s.split()
return cls(
start=float(words[0]),
end=float(words[1]),
phoneme=words[2],
)

@classmethod
@abstractmethod
def convert(cls, phonemes: List["BasePhoneme"]) -> List["BasePhoneme"]:
raise NotImplementedError

@classmethod
def load_lab_list(cls, path: Path):
"""
labファイルを読み込む
Parameters
----------
path : Path
読み込みたいlabファイルのパス
Returns
-------
phonemes : List[BasePhoneme]
パース結果を用いた音素クラスを返す
"""
phonemes = [cls.parse(s) for s in path.read_text().split("\n") if len(s) > 0]
phonemes = cls.convert(phonemes)

for phoneme in phonemes:
phoneme.verify()
return phonemes

@classmethod
def save_lab_list(cls, phonemes: List["BasePhoneme"], path: Path):
"""
音素クラスのリストをlabファイル形式で保存する
Parameters
----------
phonemes : List[BasePhoneme]
保存したい音素クラスのリスト
path : Path
labファイルの保存先パス
"""
text = "\n".join(
[
f"{numpy.round(p.start, decimals=2):.2f}\t"
f"{numpy.round(p.end, decimals=2):.2f}\t"
f"{p.phoneme}"
for p in phonemes
]
)
path.write_text(text)


class OjtPhoneme(BasePhoneme):
"""