From 95f218a220e1979a19308f238f2f0453b735b0db Mon Sep 17 00:00:00 2001 From: tarepan Date: Tue, 18 Jun 2024 21:56:55 +0900 Subject: [PATCH] =?UTF-8?q?=E6=95=B4=E7=90=86:=20=E3=82=B3=E3=82=A2?= =?UTF-8?q?=E3=83=BB=E3=82=A8=E3=83=B3=E3=82=B8=E3=83=B3=E3=81=A7=E3=83=90?= =?UTF-8?q?=E3=83=BC=E3=82=B8=E3=83=A7=E3=83=B3=E3=82=92=E6=8C=87=E5=AE=9A?= =?UTF-8?q?=E3=81=97=E3=81=AA=E3=81=84=E5=A0=B4=E5=90=88=E3=80=81=E6=9A=97?= =?UTF-8?q?=E9=BB=99=E7=9A=84=E3=81=AB=E6=9C=80=E6=96=B0=E7=89=88=E3=82=92?= =?UTF-8?q?=E5=8F=96=E5=BE=97=E3=81=99=E3=82=8B=E5=87=A6=E7=90=86=E3=82=92?= =?UTF-8?q?=E5=89=8A=E9=99=A4=20(#1317)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: コアバージョン変換メソッドを追加 * refactor: latest 自動取得を削除 * refactor: ルーターとドメインを分離 * refactor: 使われなくなった型を削除 * refactor: lint * fix: lint * refactor: `convert_version_format()` を削除 * fix conflict --------- Co-authored-by: Hiroshiba Kazuyuki --- test/unit/test_core_initializer.py | 21 +-------- test/unit/tts_pipeline/test_tts_engines.py | 36 +--------------- voicevox_engine/app/routers/engine_info.py | 3 +- voicevox_engine/app/routers/morphing.py | 9 ++-- voicevox_engine/app/routers/speaker.py | 10 +++-- voicevox_engine/app/routers/tts_pipeline.py | 48 +++++++++++++-------- voicevox_engine/cancellable_engine.py | 14 +++--- voicevox_engine/core/core_initializer.py | 8 ++-- voicevox_engine/tts_pipeline/tts_engine.py | 13 ++---- 9 files changed, 62 insertions(+), 100 deletions(-) diff --git a/test/unit/test_core_initializer.py b/test/unit/test_core_initializer.py index 954c5ac1e..8046bf7f7 100644 --- a/test/unit/test_core_initializer.py +++ b/test/unit/test_core_initializer.py @@ -52,8 +52,8 @@ def test_cores_latest_version() -> None: assert true_latest_version == latest_version -def test_cores_get_core_specified() -> None: - """CoreManager.get_core() で登録済みコアをバージョン指定して取得できる。""" +def test_cores_get_core_existing() -> None: + """CoreManager.get_core() で登録済みコアを取得できる。""" # Inputs core_manager = CoreManager() core1 = CoreAdapter(MockCoreWrapper()) @@ -69,23 +69,6 @@ def test_cores_get_core_specified() -> None: assert true_acquired_core == acquired_core -def test_cores_get_core_latest() -> None: - """CoreManager.get_core() で最新版コアをバージョン未指定で取得できる。""" - # Inputs - core_manager = CoreManager() - core1 = CoreAdapter(MockCoreWrapper()) - core2 = CoreAdapter(MockCoreWrapper()) - core_manager.register_core(core1, "0.0.1") - core_manager.register_core(core2, "0.0.2") - # Expects - true_acquired_core = core2 - # Outputs - acquired_core = core_manager.get_core() - - # Test - assert true_acquired_core == acquired_core - - def test_cores_get_core_missing() -> None: """CoreManager.get_core() で存在しないコアを取得しようとするとエラーになる。""" # Inputs diff --git a/test/unit/tts_pipeline/test_tts_engines.py b/test/unit/tts_pipeline/test_tts_engines.py index f488fcb0b..1fa32c904 100644 --- a/test/unit/tts_pipeline/test_tts_engines.py +++ b/test/unit/tts_pipeline/test_tts_engines.py @@ -31,23 +31,8 @@ def test_tts_engines_versions() -> None: assert true_versions == versions -def test_tts_engines_latest_version() -> None: - """TTSEngineManager.latest_version() で最新バージョンを取得できる。""" - # Inputs - tts_engines = TTSEngineManager() - tts_engines.register_engine(MockTTSEngine(), "0.0.1") - tts_engines.register_engine(MockTTSEngine(), "0.0.2") - # Expects - true_latest_version = "0.0.2" - # Outputs - latest_version = tts_engines.latest_version() - - # Test - assert true_latest_version == latest_version - - -def test_tts_engines_get_engine_specified() -> None: - """TTSEngineManager.get_engine() で登録済み TTS エンジンをバージョン指定して取得できる。""" +def test_tts_engines_get_engine_existing() -> None: + """TTSEngineManager.get_engine() で登録済み TTS エンジンを取得できる。""" # Inputs tts_engines = TTSEngineManager() tts_engine1 = MockTTSEngine() @@ -63,23 +48,6 @@ def test_tts_engines_get_engine_specified() -> None: assert true_acquired_tts_engine == acquired_tts_engine -def test_tts_engines_get_engine_latest() -> None: - """TTSEngineManager.get_engine() で最新版 TTS エンジンをバージョン未指定で取得できる。""" - # Inputs - tts_engines = TTSEngineManager() - tts_engine1 = MockTTSEngine() - tts_engine2 = MockTTSEngine() - tts_engines.register_engine(tts_engine1, "0.0.1") - tts_engines.register_engine(tts_engine2, "0.0.2") - # Expects - true_acquired_tts_engine = tts_engine2 - # Outputs - acquired_tts_engine = tts_engines.get_engine() - - # Test - assert true_acquired_tts_engine == acquired_tts_engine - - def test_tts_engines_get_engine_missing() -> None: """TTSEngineManager.get_engine() で存在しない TTS エンジンを取得しようとするとエラーになる。""" # Inputs diff --git a/voicevox_engine/app/routers/engine_info.py b/voicevox_engine/app/routers/engine_info.py index b50546c1a..96f22b4d0 100644 --- a/voicevox_engine/app/routers/engine_info.py +++ b/voicevox_engine/app/routers/engine_info.py @@ -52,7 +52,8 @@ def supported_devices( core_version: str | SkipJsonSchema[None] = None, ) -> SupportedDevicesInfo: """対応デバイスの一覧を取得します。""" - supported_devices = core_manager.get_core(core_version).supported_devices + version = core_version or core_manager.latest_version() + supported_devices = core_manager.get_core(version).supported_devices if supported_devices is None: raise HTTPException(status_code=422, detail="非対応の機能です。") return SupportedDevicesInfo.generate_from(supported_devices) diff --git a/voicevox_engine/app/routers/morphing.py b/voicevox_engine/app/routers/morphing.py index 45a572899..364a3ebbb 100644 --- a/voicevox_engine/app/routers/morphing.py +++ b/voicevox_engine/app/routers/morphing.py @@ -54,7 +54,9 @@ def morphable_targets( プロパティが存在しない場合は、モーフィングが許可されているとみなします。 返り値のスタイルIDはstring型なので注意。 """ - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + core = core_manager.get_core(version) + speakers = metas_store.load_combined_metas(core.speakers) try: morphable_targets = get_morphable_targets(speakers, base_style_ids) @@ -90,8 +92,9 @@ def _synthesis_morphing( 指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。 モーフィングの割合は`morph_rate`で指定でき、0.0でベースのスタイル、1.0でターゲットのスタイルに近づきます。 """ - engine = tts_engines.get_engine(core_version) - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) + core = core_manager.get_core(version) # モーフィングが許可されないキャラクターペアを拒否する speakers = metas_store.load_combined_metas(core.speakers) diff --git a/voicevox_engine/app/routers/speaker.py b/voicevox_engine/app/routers/speaker.py index 46fa312f7..d32afe8c1 100644 --- a/voicevox_engine/app/routers/speaker.py +++ b/voicevox_engine/app/routers/speaker.py @@ -27,7 +27,8 @@ def generate_speaker_router( @router.get("/speakers") def speakers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]: """話者情報の一覧を取得します。""" - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + core = core_manager.get_core(version) speakers = metas_store.load_combined_metas(core.speakers) return filter_speakers_and_styles(speakers, "speaker") @@ -74,8 +75,10 @@ def _speaker_info( # {speaker_uuid_1}/ # ... + version = core_version or core_manager.latest_version() + # 該当話者を検索する - core_speakers = core_manager.get_core(core_version).speakers + core_speakers = core_manager.get_core(version).speakers speakers = metas_store.load_combined_metas(core_speakers) speakers = filter_speakers_and_styles(speakers, speaker_or_singer) speaker = next( @@ -138,7 +141,8 @@ def _speaker_info( @router.get("/singers") def singers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]: """歌手情報の一覧を取得します""" - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + core = core_manager.get_core(version) singers = metas_store.load_combined_metas(core.speakers) return filter_speakers_and_styles(singers, "singer") diff --git a/voicevox_engine/app/routers/tts_pipeline.py b/voicevox_engine/app/routers/tts_pipeline.py index c22755642..4a2159a09 100644 --- a/voicevox_engine/app/routers/tts_pipeline.py +++ b/voicevox_engine/app/routers/tts_pipeline.py @@ -85,8 +85,9 @@ def audio_query( """ 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - engine = tts_engines.get_engine(core_version) - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) + core = core_manager.get_core(version) accent_phrases = engine.create_accent_phrases(text, style_id) return AudioQuery( accent_phrases=accent_phrases, @@ -116,8 +117,9 @@ def audio_query_from_preset( """ 音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - engine = tts_engines.get_engine(core_version) - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) + core = core_manager.get_core(version) try: presets = preset_manager.load_presets() except PresetInputError as err: @@ -175,7 +177,8 @@ def accent_phrases( * アクセント位置を`'`で指定する。全てのアクセント句にはアクセント位置を1つ指定する必要がある。 * アクセント句末に`?`(全角)を入れることにより疑問文の発音ができる。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) if is_kana: try: return engine.create_accent_phrases_from_kana(text, style_id) @@ -196,7 +199,8 @@ def mora_data( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) return engine.update_length_and_pitch(accent_phrases, style_id) @router.post( @@ -209,7 +213,8 @@ def mora_length( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) return engine.update_length(accent_phrases, style_id) @router.post( @@ -222,7 +227,8 @@ def mora_pitch( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[AccentPhrase]: - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) return engine.update_pitch(accent_phrases, style_id) @router.post( @@ -249,7 +255,8 @@ def synthesis( ] = True, core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) wave = engine.synthesize_wave( query, style_id, enable_interrogative_upspeak=enable_interrogative_upspeak ) @@ -289,9 +296,10 @@ def cancellable_synthesis( status_code=404, detail="実験的機能はデフォルトで無効になっています。使用するには引数を指定してください。", ) + version = core_version or core_manager.latest_version() try: f_name = cancellable_engine._synthesis_impl( - query, style_id, request, core_version=core_version + query, style_id, request, version=version ) except CancellableEngineInternalError as e: raise HTTPException(status_code=500, detail=str(e)) @@ -325,7 +333,8 @@ def multi_synthesis( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> FileResponse: - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) sampling_rate = queries[0].outputSamplingRate with NamedTemporaryFile(delete=False) as f: @@ -367,8 +376,9 @@ def sing_frame_audio_query( """ 歌唱音声合成用のクエリの初期値を得ます。ここで得られたクエリはそのまま歌唱音声合成に利用できます。各値の意味は`Schemas`を参照してください。 """ - engine = tts_engines.get_engine(core_version) - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) + core = core_manager.get_core(version) try: phonemes, f0, volume = engine.create_sing_phoneme_and_f0_and_volume( score, style_id @@ -396,7 +406,8 @@ def sing_frame_volume( style_id: Annotated[StyleId, Query(alias="speaker")], core_version: str | SkipJsonSchema[None] = None, ) -> list[float]: - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) try: return engine.create_sing_volume_from_phoneme_and_f0( score, frame_audio_query.phonemes, frame_audio_query.f0, style_id @@ -424,7 +435,8 @@ def frame_synthesis( """ 歌唱音声合成を行います。 """ - engine = tts_engines.get_engine(core_version) + version = core_version or core_manager.latest_version() + engine = tts_engines.get_engine(version) try: wave = engine.frame_synthsize_wave(query, style_id) except TalkSingInvalidInputError as e: @@ -519,7 +531,8 @@ def initialize_speaker( 指定されたスタイルを初期化します。 実行しなくても他のAPIは使用できますが、初回実行時に時間がかかることがあります。 """ - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + core = core_manager.get_core(version) core.initialize_style_id_synthesis(style_id, skip_reinit=skip_reinit) @router.get("/is_initialized_speaker", tags=["その他"]) @@ -530,7 +543,8 @@ def is_initialized_speaker( """ 指定されたスタイルが初期化されているかどうかを返します。 """ - core = core_manager.get_core(core_version) + version = core_version or core_manager.latest_version() + core = core_manager.get_core(version) return core.is_initialized_style_id_synthesis(style_id) return router diff --git a/voicevox_engine/cancellable_engine.py b/voicevox_engine/cancellable_engine.py index 70da62c5f..2899fb5e4 100644 --- a/voicevox_engine/cancellable_engine.py +++ b/voicevox_engine/cancellable_engine.py @@ -149,7 +149,7 @@ def _synthesis_impl( query: AudioQuery, style_id: StyleId, request: Request, - core_version: str | None, + version: str, ) -> str: """ 音声合成を行う関数 @@ -163,7 +163,7 @@ def _synthesis_impl( request: fastapi.Request 接続確立時に受け取ったものをそのまま渡せばよい https://fastapi.tiangolo.com/advanced/using-request-directly/ - core_version: str + version: str Returns ------- @@ -173,7 +173,7 @@ def _synthesis_impl( proc, sub_proc_con1 = self.procs_and_cons.get() self.watch_con_list.append((request, proc)) try: - sub_proc_con1.send((query, style_id, core_version)) + sub_proc_con1.send((query, style_id, version)) f_name = sub_proc_con1.recv() if isinstance(f_name, str): audio_file_name = f_name @@ -244,11 +244,9 @@ def start_synthesis_subprocess( assert len(tts_engines.versions()) != 0, "音声合成エンジンがありません。" while True: try: - query, style_id, core_version = sub_proc_con.recv() - if core_version is None: - _engine = tts_engines.get_engine() - elif tts_engines.has_engine(core_version): - _engine = tts_engines.get_engine(core_version) + query, style_id, version = sub_proc_con.recv() + if tts_engines.has_engine(version): + _engine = tts_engines.get_engine(version) else: # バージョンが見つからないエラー sub_proc_con.send("") diff --git a/voicevox_engine/core/core_initializer.py b/voicevox_engine/core/core_initializer.py index 5f7a9dc2a..5e645dfb1 100644 --- a/voicevox_engine/core/core_initializer.py +++ b/voicevox_engine/core/core_initializer.py @@ -44,11 +44,9 @@ def register_core(self, core: CoreAdapter, version: str) -> None: """コアを登録する。""" self._cores[version] = core - def get_core(self, version: str | None = None) -> CoreAdapter: - """指定バージョンのコアを取得する。指定が無い場合、最新バージョンを返す。""" - if version is None: - return self._cores[self.latest_version()] - elif version in self._cores: + def get_core(self, version: str) -> CoreAdapter: + """指定バージョンのコアを取得する。""" + if version in self._cores: return self._cores[version] raise CoreNotFound(f"バージョン {version} のコアが見つかりません") diff --git a/voicevox_engine/tts_pipeline/tts_engine.py b/voicevox_engine/tts_pipeline/tts_engine.py index e5b09d903..9d3c57248 100644 --- a/voicevox_engine/tts_pipeline/tts_engine.py +++ b/voicevox_engine/tts_pipeline/tts_engine.py @@ -13,7 +13,6 @@ from ..core.core_wrapper import CoreWrapper from ..metas.Metas import StyleId from ..model import AudioQuery -from ..utility.core_version_utility import get_latest_version from .kana_converter import parse_kana from .model import AccentPhrase, FrameAudioQuery, FramePhoneme, Mora, Note, Score from .mora_mapping import mora_kana_to_mora_phonemes, mora_phonemes_to_mora_kana @@ -702,19 +701,13 @@ def versions(self) -> list[str]: """登録されたエンジンのバージョン一覧を取得する。""" return list(self._engines.keys()) - def latest_version(self) -> str: - """登録された最新版エンジンのバージョンを取得する。""" - return get_latest_version(self.versions()) - def register_engine(self, engine: TTSEngine, version: str) -> None: """エンジンを登録する。""" self._engines[version] = engine - def get_engine(self, version: str | None = None) -> TTSEngine: - """指定バージョンのエンジンを取得する。指定が無い場合、最新バージョンを返す。""" - if version is None: - return self._engines[self.latest_version()] - elif version in self._engines: + def get_engine(self, version: str) -> TTSEngine: + """指定バージョンのエンジンを取得する。""" + if version in self._engines: return self._engines[version] raise HTTPException(status_code=422, detail="不明なバージョンです")