Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: コア・エンジンでバージョンを指定しない場合、暗黙的に最新版を取得する処理を削除 #1317

Merged
merged 18 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 2 additions & 19 deletions test/test_core_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_cores_latest_version() -> None:
assert true_latest_version == latest_version


def test_cores_get_core_specified() -> None:
"""CoreManager.get_core() で登録済みコアをバージョン指定して取得できる。"""
def test_cores_get_core_existing() -> None:
"""CoreManager.get_core() で登録済みコアを取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
Expand All @@ -69,23 +69,6 @@ def test_cores_get_core_specified() -> None:
assert true_acquired_core == acquired_core


def test_cores_get_core_latest() -> None:
"""CoreManager.get_core() で最新版コアをバージョン未指定で取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
core2 = CoreAdapter(MockCoreWrapper())
core_manager.register_core(core1, "0.0.1")
core_manager.register_core(core2, "0.0.2")
# Expects
true_acquired_core = core2
# Outputs
acquired_core = core_manager.get_core()

# Test
assert true_acquired_core == acquired_core


def test_cores_get_core_missing() -> None:
"""CoreManager.get_core() で存在しないコアを取得しようとするとエラーになる。"""
# Inputs
Expand Down
36 changes: 36 additions & 0 deletions test/test_router_commons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
""" ルーター共通処理のテスト"""

from voicevox_engine.app.routers.commons import convert_version_format
from voicevox_engine.core.core_adapter import CoreAdapter
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.dev.core.mock import MockCoreWrapper


def test_convert_version_format_non_latest() -> None:
"""convert_version_format() で明示的バージョンが維持される。"""
# Inputs
core_manager = CoreManager()
api_format_version = "0.0.2"
# Expects
true_version = "0.0.2"
# Outputs
version = convert_version_format(api_format_version, core_manager)

# Test
assert true_version == version


def test_cores_convert_version_format_latest() -> None:
"""convert_version_format() で latest 表現が変換される。"""
# Inputs
core_manager = CoreManager()
core_manager.register_core(CoreAdapter(MockCoreWrapper()), "0.0.1")
core_manager.register_core(CoreAdapter(MockCoreWrapper()), "0.0.2")
api_format_version = None
# Expects
true_version = "0.0.2"
# Outputs
version = convert_version_format(api_format_version, core_manager)

# Test
assert true_version == version
36 changes: 2 additions & 34 deletions test/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,8 @@ def test_tts_engines_versions() -> None:
assert true_versions == versions


def test_tts_engines_latest_version() -> None:
"""TTSEngineManager.latest_version() で最新バージョンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "0.0.1")
tts_engines.register_engine(MockTTSEngine(), "0.0.2")
# Expects
true_latest_version = "0.0.2"
# Outputs
latest_version = tts_engines.latest_version()

# Test
assert true_latest_version == latest_version


def test_tts_engines_get_engine_specified() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンをバージョン指定して取得できる。"""
def test_tts_engines_get_engine_existing() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
Expand All @@ -63,23 +48,6 @@ def test_tts_engines_get_engine_specified() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine() で最新版 TTS エンジンをバージョン未指定で取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")
# Expects
true_acquired_tts_engine = tts_engine2
# Outputs
acquired_tts_engine = tts_engines.get_engine()

# Test
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
Expand Down
20 changes: 20 additions & 0 deletions voicevox_engine/app/routers/commons.py
tarepan marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""ルーター間で共通する処理"""

from typing import TypeAlias

from voicevox_engine.core.core_initializer import CoreManager

APICoreVersion: TypeAlias = str | None # `None` は latest を意味する
EngineCoreVersion: TypeAlias = str


def convert_version_format(
core_version: APICoreVersion, core_manager: CoreManager
tarepan marked this conversation as resolved.
Show resolved Hide resolved
) -> EngineCoreVersion:
"""
バージョンの形式を API 形式から ENGINE 形式へ変換する。

API 形式は latest を指定でき、それは `None` で表現される。
ENGINE 形式は latest を指定できず、ゆえに `None` を持たない。
tarepan marked this conversation as resolved.
Show resolved Hide resolved
"""
return core_manager.latest_version() if core_version is None else core_version
4 changes: 3 additions & 1 deletion voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel, Field

from voicevox_engine import __version__
from voicevox_engine.app.routers.commons import convert_version_format
from voicevox_engine.core.core_adapter import DeviceSupport
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.engine_manifest import EngineManifest
Expand Down Expand Up @@ -49,7 +50,8 @@ async def core_versions() -> list[str]:
@router.get("/supported_devices")
def supported_devices(core_version: str | None = None) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
supported_devices = core_manager.get_core(core_version).supported_devices
version = convert_version_format(core_version, core_manager)
supported_devices = core_manager.get_core(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)
Expand Down
9 changes: 6 additions & 3 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.app.routers.commons import convert_version_format
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import StyleId
from voicevox_engine.metas.MetasStore import MetasStore, construct_lookup
Expand Down Expand Up @@ -51,7 +52,8 @@ def morphable_targets(
プロパティが存在しない場合は、モーフィングが許可されているとみなします。
返り値のスタイルIDはstring型なので注意。
"""
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
core = core_manager.get_core(version)

try:
speakers = metas_store.load_combined_metas(core.speakers)
Expand Down Expand Up @@ -92,8 +94,9 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)

try:
speakers = metas_store.load_combined_metas(core.speakers)
Expand Down
19 changes: 12 additions & 7 deletions voicevox_engine/app/routers/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import APIRouter, HTTPException, Query
from pydantic import parse_obj_as

from voicevox_engine.app.routers.commons import convert_version_format
from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import Speaker, SpeakerInfo, StyleId
from voicevox_engine.metas.MetasStore import MetasStore, filter_speakers_and_styles
Expand All @@ -27,7 +28,8 @@ def generate_speaker_router(
@router.get("/speakers")
def speakers(core_version: str | None = None) -> list[Speaker]:
"""話者情報の一覧を取得します。"""
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
core = core_manager.get_core(version)
speakers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(speakers, "speaker")

Expand Down Expand Up @@ -72,10 +74,10 @@ def _speaker_info(
# {speaker_uuid_1}/
# ...

version = convert_version_format(core_version, core_manager)

# 該当話者を検索する
speakers = parse_obj_as(
list[Speaker], core_manager.get_core(core_version).speakers
)
speakers = parse_obj_as(list[Speaker], core_manager.get_core(version).speakers)
speakers = filter_speakers_and_styles(speakers, speaker_or_singer)
speaker = next(
filter(lambda spk: spk.speaker_uuid == speaker_uuid, speakers), None
Expand Down Expand Up @@ -137,7 +139,8 @@ def _speaker_info(
@router.get("/singers")
def singers(core_version: str | None = None) -> list[Speaker]:
"""歌手情報の一覧を取得します"""
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
core = core_manager.get_core(version)
singers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(singers, "singer")

Expand Down Expand Up @@ -168,7 +171,8 @@ def initialize_speaker(
指定されたスタイルを初期化します。
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
core = core_manager.get_core(version)
core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)

@router.get("/is_initialized_speaker")
Expand All @@ -179,7 +183,8 @@ def is_initialized_speaker(
"""
指定されたスタイルが初期化されているかどうかを返します。
"""
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
core = core_manager.get_core(version)
return core.is_initialized_style_id_synthesis(style_id)

return router
43 changes: 28 additions & 15 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

from voicevox_engine.app.routers.commons import convert_version_format
from voicevox_engine.cancellable_engine import (
CancellableEngine,
CancellableEngineInternalError,
Expand Down Expand Up @@ -84,8 +85,9 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
Expand Down Expand Up @@ -113,8 +115,9 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
presets = preset_manager.load_presets()
except PresetInputError as err:
Expand Down Expand Up @@ -170,7 +173,8 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
if is_kana:
try:
return engine.create_accent_phrases_from_kana(text, style_id)
Expand All @@ -191,7 +195,8 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
return engine.update_length_and_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -204,7 +209,8 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
return engine.update_length(accent_phrases, style_id)

@router.post(
Expand All @@ -217,7 +223,8 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
return engine.update_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -244,7 +251,8 @@ def synthesis(
] = True,
core_version: str | None = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)
Expand Down Expand Up @@ -284,9 +292,10 @@ def cancellable_synthesis(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
version = convert_version_format(core_version, core_manager)
try:
f_name = cancellable_engine._synthesis_impl(
query, style_id, request, core_version=core_version
query, style_id, request, version=version
)
except CancellableEngineInternalError as e:
raise HTTPException(status_code=500, detail=str(e))
Expand Down Expand Up @@ -320,7 +329,8 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -362,8 +372,9 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
Expand Down Expand Up @@ -391,7 +402,8 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | None = None,
) -> list[float]:
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
try:
return engine.create_sing_volume_from_phoneme_and_f0(
score, frame_audio_query.phonemes, frame_audio_query.f0, style_id
Expand Down Expand Up @@ -419,7 +431,8 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
engine = tts_engines.get_engine(core_version)
version = convert_version_format(core_version, core_manager)
engine = tts_engines.get_engine(version)
try:
wave = engine.frame_synthsize_wave(query, style_id)
except TalkSingInvalidInputError as e:
Expand Down
Loading