Skip to content

Commit

Permalink
整理: コア・エンジンでバージョンを指定しない場合、暗黙的に最新版を取得する処理を削除 (#1317)
Browse files (browse the repository at this point in the history)
* refactor: コアバージョン変換メソッドを追加

* refactor: latest 自動取得を削除

* refactor: ルーターとドメインを分離

* refactor: 使われなくなった型を削除

* refactor: lint

* fix: lint

* refactor: `convert_version_format()` を削除

* fix conflict

---------

Co-authored-by: Hiroshiba Kazuyuki <[email protected]>
  • Loading branch information
tarepan and Hiroshiba authored Jun 18, 2024
1 parent c72b889 commit 95f218a
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 100 deletions.
21 changes: 2 additions & 19 deletions test/unit/test_core_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_cores_latest_version() -> None:
assert true_latest_version == latest_version


def test_cores_get_core_specified() -> None:
"""CoreManager.get_core() で登録済みコアをバージョン指定して取得できる。"""
def test_cores_get_core_existing() -> None:
"""CoreManager.get_core() で登録済みコアを取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
Expand All @@ -69,23 +69,6 @@ def test_cores_get_core_specified() -> None:
assert true_acquired_core == acquired_core


def test_cores_get_core_latest() -> None:
"""CoreManager.get_core() で最新版コアをバージョン未指定で取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
core2 = CoreAdapter(MockCoreWrapper())
core_manager.register_core(core1, "0.0.1")
core_manager.register_core(core2, "0.0.2")
# Expects
true_acquired_core = core2
# Outputs
acquired_core = core_manager.get_core()

# Test
assert true_acquired_core == acquired_core


def test_cores_get_core_missing() -> None:
"""CoreManager.get_core() で存在しないコアを取得しようとするとエラーになる。"""
# Inputs
Expand Down
36 changes: 2 additions & 34 deletions test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,8 @@ def test_tts_engines_versions() -> None:
assert true_versions == versions


def test_tts_engines_latest_version() -> None:
"""TTSEngineManager.latest_version() で最新バージョンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "0.0.1")
tts_engines.register_engine(MockTTSEngine(), "0.0.2")
# Expects
true_latest_version = "0.0.2"
# Outputs
latest_version = tts_engines.latest_version()

# Test
assert true_latest_version == latest_version


def test_tts_engines_get_engine_specified() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンをバージョン指定して取得できる。"""
def test_tts_engines_get_engine_existing() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
Expand All @@ -63,23 +48,6 @@ def test_tts_engines_get_engine_specified() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine() で最新版 TTS エンジンをバージョン未指定で取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")
# Expects
true_acquired_tts_engine = tts_engine2
# Outputs
acquired_tts_engine = tts_engines.get_engine()

# Test
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
supported_devices = core_manager.get_core(core_version).supported_devices
version = core_version or core_manager.latest_version()
supported_devices = core_manager.get_core(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)
Expand Down
9 changes: 6 additions & 3 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def morphable_targets(
プロパティが存在しない場合は、モーフィングが許可されているとみなします。
返り値のスタイルIDはstring型なので注意。
"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)

speakers = metas_store.load_combined_metas(core.speakers)
try:
morphable_targets = get_morphable_targets(speakers, base_style_ids)
Expand Down Expand Up @@ -90,8 +92,9 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)

# モーフィングが許可されないキャラクターペアを拒否する
speakers = metas_store.load_combined_metas(core.speakers)
Expand Down
10 changes: 7 additions & 3 deletions voicevox_engine/app/routers/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def generate_speaker_router(
@router.get("/speakers")
def speakers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""話者情報の一覧を取得します。"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
speakers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(speakers, "speaker")

Expand Down Expand Up @@ -74,8 +75,10 @@ def _speaker_info(
# {speaker_uuid_1}/
# ...

version = core_version or core_manager.latest_version()

# 該当話者を検索する
core_speakers = core_manager.get_core(core_version).speakers
core_speakers = core_manager.get_core(version).speakers
speakers = metas_store.load_combined_metas(core_speakers)
speakers = filter_speakers_and_styles(speakers, speaker_or_singer)
speaker = next(
Expand Down Expand Up @@ -138,7 +141,8 @@ def _speaker_info(
@router.get("/singers")
def singers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""歌手情報の一覧を取得します"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
singers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(singers, "singer")

Expand Down
48 changes: 31 additions & 17 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
Expand Down Expand Up @@ -116,8 +117,9 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
presets = preset_manager.load_presets()
except PresetInputError as err:
Expand Down Expand Up @@ -175,7 +177,8 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
if is_kana:
try:
return engine.create_accent_phrases_from_kana(text, style_id)
Expand All @@ -196,7 +199,8 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
return engine.update_length_and_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -209,7 +213,8 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
return engine.update_length(accent_phrases, style_id)

@router.post(
Expand All @@ -222,7 +227,8 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
return engine.update_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -249,7 +255,8 @@ def synthesis(
] = True,
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)
Expand Down Expand Up @@ -289,9 +296,10 @@ def cancellable_synthesis(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
version = core_version or core_manager.latest_version()
try:
f_name = cancellable_engine._synthesis_impl(
query, style_id, request, core_version=core_version
query, style_id, request, version=version
)
except CancellableEngineInternalError as e:
raise HTTPException(status_code=500, detail=str(e))
Expand Down Expand Up @@ -325,7 +333,8 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -367,8 +376,9 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
Expand Down Expand Up @@ -396,7 +406,8 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[float]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
try:
return engine.create_sing_volume_from_phoneme_and_f0(
score, frame_audio_query.phonemes, frame_audio_query.f0, style_id
Expand Down Expand Up @@ -424,7 +435,8 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
try:
wave = engine.frame_synthsize_wave(query, style_id)
except TalkSingInvalidInputError as e:
Expand Down Expand Up @@ -519,7 +531,8 @@ def initialize_speaker(
指定されたスタイルを初期化します。
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)

@router.get("/is_initialized_speaker", tags=["その他"])
Expand All @@ -530,7 +543,8 @@ def is_initialized_speaker(
"""
指定されたスタイルが初期化されているかどうかを返します。
"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
return core.is_initialized_style_id_synthesis(style_id)

return router
14 changes: 6 additions & 8 deletions voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _synthesis_impl(
query: AudioQuery,
style_id: StyleId,
request: Request,
core_version: str | None,
version: str,
) -> str:
"""
音声合成を行う関数
Expand All @@ -163,7 +163,7 @@ def _synthesis_impl(
request: fastapi.Request
接続確立時に受け取ったものをそのまま渡せばよい
https://fastapi.tiangolo.com/advanced/using-request-directly/
core_version: str
version: str
Returns
-------
Expand All @@ -173,7 +173,7 @@ def _synthesis_impl(
proc, sub_proc_con1 = self.procs_and_cons.get()
self.watch_con_list.append((request, proc))
try:
sub_proc_con1.send((query, style_id, core_version))
sub_proc_con1.send((query, style_id, version))
f_name = sub_proc_con1.recv()
if isinstance(f_name, str):
audio_file_name = f_name
Expand Down Expand Up @@ -244,11 +244,9 @@ def start_synthesis_subprocess(
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"
while True:
try:
query, style_id, core_version = sub_proc_con.recv()
if core_version is None:
_engine = tts_engines.get_engine()
elif tts_engines.has_engine(core_version):
_engine = tts_engines.get_engine(core_version)
query, style_id, version = sub_proc_con.recv()
if tts_engines.has_engine(version):
_engine = tts_engines.get_engine(version)
else:
# バージョンが見つからないエラー
sub_proc_con.send("")
Expand Down
8 changes: 3 additions & 5 deletions voicevox_engine/core/core_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ def register_core(self, core: CoreAdapter, version: str) -> None:
"""コアを登録する。"""
self._cores[version] = core

def get_core(self, version: str | None = None) -> CoreAdapter:
"""指定バージョンのコアを取得する。指定が無い場合、最新バージョンを返す。"""
if version is None:
return self._cores[self.latest_version()]
elif version in self._cores:
def get_core(self, version: str) -> CoreAdapter:
"""指定バージョンのコアを取得する。"""
if version in self._cores:
return self._cores[version]
raise CoreNotFound(f"バージョン {version} のコアが見つかりません")

Expand Down
13 changes: 3 additions & 10 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ..core.core_wrapper import CoreWrapper
from ..metas.Metas import StyleId
from ..model import AudioQuery
from ..utility.core_version_utility import get_latest_version
from .kana_converter import parse_kana
from .model import AccentPhrase, FrameAudioQuery, FramePhoneme, Mora, Note, Score
from .mora_mapping import mora_kana_to_mora_phonemes, mora_phonemes_to_mora_kana
Expand Down Expand Up @@ -702,19 +701,13 @@ def versions(self) -> list[str]:
"""登録されたエンジンのバージョン一覧を取得する。"""
return list(self._engines.keys())

def latest_version(self) -> str:
"""登録された最新版エンジンのバージョンを取得する。"""
return get_latest_version(self.versions())

def register_engine(self, engine: TTSEngine, version: str) -> None:
"""エンジンを登録する。"""
self._engines[version] = engine

def get_engine(self, version: str | None = None) -> TTSEngine:
"""指定バージョンのエンジンを取得する。指定が無い場合、最新バージョンを返す。"""
if version is None:
return self._engines[self.latest_version()]
elif version in self._engines:
def get_engine(self, version: str) -> TTSEngine:
"""指定バージョンのエンジンを取得する。"""
if version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down

0 comments on commit 95f218a

Please sign in to comment.