Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

整理: コア・エンジンでバージョンを指定しない場合、暗黙的に最新版を取得する処理を削除 #1317

Merged
merged 18 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 2 additions & 19 deletions test/unit/test_core_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def test_cores_latest_version() -> None:
assert true_latest_version == latest_version


def test_cores_get_core_specified() -> None:
"""CoreManager.get_core() で登録済みコアをバージョン指定して取得できる。"""
def test_cores_get_core_existing() -> None:
"""CoreManager.get_core() で登録済みコアを取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
Expand All @@ -69,23 +69,6 @@ def test_cores_get_core_specified() -> None:
assert true_acquired_core == acquired_core


def test_cores_get_core_latest() -> None:
"""CoreManager.get_core() で最新版コアをバージョン未指定で取得できる。"""
# Inputs
core_manager = CoreManager()
core1 = CoreAdapter(MockCoreWrapper())
core2 = CoreAdapter(MockCoreWrapper())
core_manager.register_core(core1, "0.0.1")
core_manager.register_core(core2, "0.0.2")
# Expects
true_acquired_core = core2
# Outputs
acquired_core = core_manager.get_core()

# Test
assert true_acquired_core == acquired_core


def test_cores_get_core_missing() -> None:
"""CoreManager.get_core() で存在しないコアを取得しようとするとエラーになる。"""
# Inputs
Expand Down
36 changes: 2 additions & 34 deletions test/unit/tts_pipeline/test_tts_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,8 @@ def test_tts_engines_versions() -> None:
assert true_versions == versions


def test_tts_engines_latest_version() -> None:
"""TTSEngineManager.latest_version() で最新バージョンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engines.register_engine(MockTTSEngine(), "0.0.1")
tts_engines.register_engine(MockTTSEngine(), "0.0.2")
# Expects
true_latest_version = "0.0.2"
# Outputs
latest_version = tts_engines.latest_version()

# Test
assert true_latest_version == latest_version


def test_tts_engines_get_engine_specified() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンをバージョン指定して取得できる。"""
def test_tts_engines_get_engine_existing() -> None:
"""TTSEngineManager.get_engine() で登録済み TTS エンジンを取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
Expand All @@ -63,23 +48,6 @@ def test_tts_engines_get_engine_specified() -> None:
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_latest() -> None:
"""TTSEngineManager.get_engine() で最新版 TTS エンジンをバージョン未指定で取得できる。"""
# Inputs
tts_engines = TTSEngineManager()
tts_engine1 = MockTTSEngine()
tts_engine2 = MockTTSEngine()
tts_engines.register_engine(tts_engine1, "0.0.1")
tts_engines.register_engine(tts_engine2, "0.0.2")
# Expects
true_acquired_tts_engine = tts_engine2
# Outputs
acquired_tts_engine = tts_engines.get_engine()

# Test
assert true_acquired_tts_engine == acquired_tts_engine


def test_tts_engines_get_engine_missing() -> None:
"""TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。"""
# Inputs
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
supported_devices = core_manager.get_core(core_version).supported_devices
version = core_version or core_manager.latest_version()
supported_devices = core_manager.get_core(version).supported_devices
if supported_devices is None:
raise HTTPException(status_code=422, detail="非対応の機能です。")
return SupportedDevicesInfo.generate_from(supported_devices)
Expand Down
9 changes: 6 additions & 3 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def morphable_targets(
プロパティが存在しない場合は、モーフィングが許可されているとみなします。
返り値のスタイルIDはstring型なので注意。
"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)

speakers = metas_store.load_combined_metas(core.speakers)
try:
morphable_targets = get_morphable_targets(speakers, base_style_ids)
Expand Down Expand Up @@ -90,8 +92,9 @@ def _synthesis_morphing(
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)

# モーフィングが許可されないキャラクターペアを拒否する
speakers = metas_store.load_combined_metas(core.speakers)
Expand Down
10 changes: 7 additions & 3 deletions voicevox_engine/app/routers/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def generate_speaker_router(
@router.get("/speakers")
def speakers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""話者情報の一覧を取得します。"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
speakers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(speakers, "speaker")

Expand Down Expand Up @@ -74,8 +75,10 @@ def _speaker_info(
# {speaker_uuid_1}/
# ...

version = core_version or core_manager.latest_version()

# 該当話者を検索する
core_speakers = core_manager.get_core(core_version).speakers
core_speakers = core_manager.get_core(version).speakers
speakers = metas_store.load_combined_metas(core_speakers)
speakers = filter_speakers_and_styles(speakers, speaker_or_singer)
speaker = next(
Expand Down Expand Up @@ -138,7 +141,8 @@ def _speaker_info(
@router.get("/singers")
def singers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""歌手情報の一覧を取得します"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
singers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(singers, "singer")

Expand Down
48 changes: 31 additions & 17 deletions voicevox_engine/app/routers/tts_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def audio_query(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
accent_phrases = engine.create_accent_phrases(text, style_id)
return AudioQuery(
accent_phrases=accent_phrases,
Expand Down Expand Up @@ -116,8 +117,9 @@ def audio_query_from_preset(
"""
音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
presets = preset_manager.load_presets()
except PresetInputError as err:
Expand Down Expand Up @@ -175,7 +177,8 @@ def accent_phrases(
* アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。
* アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。
"""
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
if is_kana:
try:
return engine.create_accent_phrases_from_kana(text, style_id)
Expand All @@ -196,7 +199,8 @@ def mora_data(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
return engine.update_length_and_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -209,7 +213,8 @@ def mora_length(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
return engine.update_length(accent_phrases, style_id)

@router.post(
Expand All @@ -222,7 +227,8 @@ def mora_pitch(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[AccentPhrase]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
return engine.update_pitch(accent_phrases, style_id)

@router.post(
Expand All @@ -249,7 +255,8 @@ def synthesis(
] = True,
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
wave = engine.synthesize_wave(
query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak
)
Expand Down Expand Up @@ -289,9 +296,10 @@ def cancellable_synthesis(
status_code=404,
detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。",
)
version = core_version or core_manager.latest_version()
try:
f_name = cancellable_engine._synthesis_impl(
query, style_id, request, core_version=core_version
query, style_id, request, version=version
)
except CancellableEngineInternalError as e:
raise HTTPException(status_code=500, detail=str(e))
Expand Down Expand Up @@ -325,7 +333,8 @@ def multi_synthesis(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
sampling_rate = queries[0].outputSamplingRate

with NamedTemporaryFile(delete=False) as f:
Expand Down Expand Up @@ -367,8 +376,9 @@ def sing_frame_audio_query(
"""
歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。
"""
engine = tts_engines.get_engine(core_version)
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
core = core_manager.get_core(version)
try:
phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume(
score, style_id
Expand Down Expand Up @@ -396,7 +406,8 @@ def sing_frame_volume(
style_id: Annotated[StyleId, Query(alias="speaker")],
core_version: str | SkipJsonSchema[None] = None,
) -> list[float]:
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
try:
return engine.create_sing_volume_from_phoneme_and_f0(
score, frame_audio_query.phonemes, frame_audio_query.f0, style_id
Expand Down Expand Up @@ -424,7 +435,8 @@ def frame_synthesis(
"""
歌唱音声合成を行います。
"""
engine = tts_engines.get_engine(core_version)
version = core_version or core_manager.latest_version()
engine = tts_engines.get_engine(version)
try:
wave = engine.frame_synthsize_wave(query, style_id)
except TalkSingInvalidInputError as e:
Expand Down Expand Up @@ -519,7 +531,8 @@ def initialize_speaker(
指定されたスタイルを初期化します。
実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。
"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit)

@router.get("/is_initialized_speaker", tags=["その他"])
Expand All @@ -530,7 +543,8 @@ def is_initialized_speaker(
"""
指定されたスタイルが初期化されているかどうかを返します。
"""
core = core_manager.get_core(core_version)
version = core_version or core_manager.latest_version()
core = core_manager.get_core(version)
return core.is_initialized_style_id_synthesis(style_id)

return router
14 changes: 6 additions & 8 deletions voicevox_engine/cancellable_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _synthesis_impl(
query: AudioQuery,
style_id: StyleId,
request: Request,
core_version: str | None,
version: str,
) -> str:
"""
音声合成を行う関数
Expand All @@ -163,7 +163,7 @@ def _synthesis_impl(
request: fastapi.Request
接続確立時に受け取ったものをそのまま渡せばよい
https://fastapi.tiangolo.com/advanced/using-request-directly/
core_version: str
version: str

Returns
-------
Expand All @@ -173,7 +173,7 @@ def _synthesis_impl(
proc, sub_proc_con1 = self.procs_and_cons.get()
self.watch_con_list.append((request, proc))
try:
sub_proc_con1.send((query, style_id, core_version))
sub_proc_con1.send((query, style_id, version))
f_name = sub_proc_con1.recv()
if isinstance(f_name, str):
audio_file_name = f_name
Expand Down Expand Up @@ -244,11 +244,9 @@ def start_synthesis_subprocess(
assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。"
while True:
try:
query, style_id, core_version = sub_proc_con.recv()
if core_version is None:
_engine = tts_engines.get_engine()
elif tts_engines.has_engine(core_version):
_engine = tts_engines.get_engine(core_version)
query, style_id, version = sub_proc_con.recv()
if tts_engines.has_engine(version):
_engine = tts_engines.get_engine(version)
else:
# バージョンが見つからないエラー
sub_proc_con.send("")
Expand Down
8 changes: 3 additions & 5 deletions voicevox_engine/core/core_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@ def register_core(self, core: CoreAdapter, version: str) -> None:
"""コアを登録する。"""
self._cores[version] = core

def get_core(self, version: str | None = None) -> CoreAdapter:
"""指定バージョンのコアを取得する。指定が無い場合、最新バージョンを返す。"""
if version is None:
return self._cores[self.latest_version()]
elif version in self._cores:
def get_core(self, version: str) -> CoreAdapter:
"""指定バージョンのコアを取得する。"""
if version in self._cores:
return self._cores[version]
raise CoreNotFound(f"バージョン {version} のコアが見つかりません")

Expand Down
13 changes: 3 additions & 10 deletions voicevox_engine/tts_pipeline/tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ..core.core_wrapper import CoreWrapper
from ..metas.Metas import StyleId
from ..model import AudioQuery
from ..utility.core_version_utility import get_latest_version
from .kana_converter import parse_kana
from .model import AccentPhrase, FrameAudioQuery, FramePhoneme, Mora, Note, Score
from .mora_mapping import mora_kana_to_mora_phonemes, mora_phonemes_to_mora_kana
Expand Down Expand Up @@ -702,19 +701,13 @@ def versions(self) -> list[str]:
"""登録されたエンジンのバージョン一覧を取得する。"""
return list(self._engines.keys())

def latest_version(self) -> str:
"""登録された最新版エンジンのバージョンを取得する。"""
return get_latest_version(self.versions())

def register_engine(self, engine: TTSEngine, version: str) -> None:
"""エンジンを登録する。"""
self._engines[version] = engine

def get_engine(self, version: str | None = None) -> TTSEngine:
"""指定バージョンのエンジンを取得する。指定が無い場合、最新バージョンを返す。"""
if version is None:
return self._engines[self.latest_version()]
elif version in self._engines:
def get_engine(self, version: str) -> TTSEngine:
"""指定バージョンのエンジンを取得する。"""
if version in self._engines:
return self._engines[version]

raise HTTPException(status_code=422, detail="不明なバージョンです")
Expand Down