Cleanup: typing and variable names for Core loading (#805)
tarepan authored Dec 7, 2023
1 parent 92c8211 commit 0677936
Showing 1 changed file with 91 additions and 61 deletions.
152 changes: 91 additions & 61 deletions voicevox_engine/synthesis_engine/core_wrapper.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import List, Optional
from typing import List, Literal

import numpy as np

@@ -59,139 +59,139 @@ class GPUType(Enum):


@dataclass(frozen=True)
class CoreInfo:
name: str
platform: str
arch: str
core_type: str
gpu_type: GPUType
class _CoreInfo:
name: str  # Core file name
platform: Literal["Windows", "Linux", "Darwin"]  # supported system / OS
arch: Literal["x64", "x86", "armv7l", "aarch64", "universal"]  # supported architecture
core_type: Literal["libtorch", "onnxruntime"] # `model_type`
gpu_type: GPUType # NONE | CUDA | DIRECT_ML
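# A minimal sketch of what the Literal typing buys (the entry below is illustrative only, not a real core file):
_example_entry = _CoreInfo(
    name="example_core.dll",
    platform="Windows",
    arch="x64",
    core_type="onnxruntime",
    gpu_type=GPUType.NONE,
)  # a value outside the Literal set, e.g. arch="arm64", would now be flagged by a static type checker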


# Information about cores older than version 0.12
CORE_INFOS = [
_CORE_INFOS = [
# Windows
CoreInfo(
_CoreInfo(
name="core.dll",
platform="Windows",
arch="x64",
core_type="libtorch",
gpu_type=GPUType.CUDA,
),
CoreInfo(
_CoreInfo(
name="core_cpu.dll",
platform="Windows",
arch="x64",
core_type="libtorch",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="core_gpu_x64_nvidia.dll",
platform="Windows",
arch="x64",
core_type="onnxruntime",
gpu_type=GPUType.CUDA,
),
CoreInfo(
_CoreInfo(
name="core_gpu_x64_directml.dll",
platform="Windows",
arch="x64",
core_type="onnxruntime",
gpu_type=GPUType.DIRECT_ML,
),
CoreInfo(
_CoreInfo(
name="core_cpu_x64.dll",
platform="Windows",
arch="x64",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="core_cpu_x86.dll",
platform="Windows",
arch="x86",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="core_gpu_x86_directml.dll",
platform="Windows",
arch="x86",
core_type="onnxruntime",
gpu_type=GPUType.DIRECT_ML,
),
CoreInfo(
_CoreInfo(
name="core_cpu_arm.dll",
platform="Windows",
arch="armv7l",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="core_gpu_arm_directml.dll",
platform="Windows",
arch="armv7l",
core_type="onnxruntime",
gpu_type=GPUType.DIRECT_ML,
),
CoreInfo(
_CoreInfo(
name="core_cpu_arm64.dll",
platform="Windows",
arch="aarch64",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="core_gpu_arm64_directml.dll",
platform="Windows",
arch="aarch64",
core_type="onnxruntime",
gpu_type=GPUType.DIRECT_ML,
),
# Linux
CoreInfo(
_CoreInfo(
name="libcore.so",
platform="Linux",
arch="x64",
core_type="libtorch",
gpu_type=GPUType.CUDA,
),
CoreInfo(
_CoreInfo(
name="libcore_cpu.so",
platform="Linux",
arch="x64",
core_type="libtorch",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="libcore_gpu_x64_nvidia.so",
platform="Linux",
arch="x64",
core_type="onnxruntime",
gpu_type=GPUType.CUDA,
),
CoreInfo(
_CoreInfo(
name="libcore_cpu_x64.so",
platform="Linux",
arch="x64",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="libcore_cpu_armhf.so",
platform="Linux",
arch="armv7l",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
CoreInfo(
_CoreInfo(
name="libcore_cpu_arm64.so",
platform="Linux",
arch="aarch64",
core_type="onnxruntime",
gpu_type=GPUType.NONE,
),
# macOS
CoreInfo(
_CoreInfo(
name="libcore_cpu_universal2.dylib",
platform="Darwin",
arch="universal",
@@ -204,17 +204,16 @@ class CoreInfo:
# Dictionary of core names for version 0.12 and later
# - core name in version 0.12 and 0.13: core
# - core name from version 0.14 onward: voicevox_core
CORENAME_DICT = {
_CORENAME_DICT = {
"Windows": ("voicevox_core.dll", "core.dll"),
"Linux": ("libvoicevox_core.so", "libcore.so"),
"Darwin": ("libvoicevox_core.dylib", "libcore.dylib"),
}


def find_version_0_12_core_or_later(core_dir: Path) -> Optional[str]:
def _find_version_0_12_core_or_later(core_dir: Path) -> str | None:
"""
If the core library in the directory given by core_dir is Version 0.12 or later,
return the name of the shared library that was found.
Name of the core Version 0.12+ shared library that exists directly under `core_dir` (None: not present)
A core is judged to be Version 0.12 or later when
@@ -227,20 +226,19 @@ def find_version_0_12_core_or_later(core_dir: Path) -> Optional[str]:
if (core_dir / "metas.json").exists():
return None

for core_name in CORENAME_DICT[platform.system()]:
for core_name in _CORENAME_DICT[platform.system()]:
if (core_dir / core_name).is_file():
return core_name

return None
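# A minimal usage sketch (hypothetical helper and path, for illustration only):
def _sketch_find_new_core(core_dir: Path = Path("/opt/voicevox_core")) -> None:
    name = _find_version_0_12_core_or_later(core_dir)
    # On Linux: "libvoicevox_core.so" or "libcore.so" when present and no metas.json exists;
    # None means the caller falls back to the pre-0.12 lookup implemented below.
    print(name)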


def get_arch_name() -> Optional[str]:
def _get_arch_name() -> Literal["x64", "x86", "aarch64", "armv7l"] | None:
"""
Because platform.machine() can return several different strings on a given architecture,
convert it to a single canonical string.
If the architecture is unsupported, return None.
Architecture of the running machine (None: unsupported architecture)
"""
machine = platform.machine()
# platform.machine() can return several different strings on a given architecture, so normalize it to a canonical one
if machine == "x86_64" or machine == "x64" or machine == "AMD64":
return "x64"
elif machine == "i386" or machine == "x86":
@@ -253,18 +251,33 @@ def get_arch_name() -> Optional[str]:
return None
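# A minimal sketch (hypothetical helper; only behaviour visible in this hunk is cited):
def _sketch_arch_name() -> None:
    # "x86_64", "x64" and "AMD64" all normalize to "x64"; unrecognized machines yield None.
    print(_get_arch_name())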


def get_core_name(
arch_name: str,
def _get_core_name(
arch_name: Literal["x64", "x86", "aarch64", "armv7l"],
platform_name: str,
model_type: str,
model_type: Literal["libtorch", "onnxruntime"],
gpu_type: GPUType,
) -> Optional[str]:
) -> str | None:
"""
Name of the Core that satisfies the given settings (None: unsupported)
Parameters
----------
arch_name : Literal["x64", "x86", "aarch64", "armv7l"]
Architecture of the running machine
platform_name : str
System name of the running machine
model_type: Literal["libtorch", "onnxruntime"]
gpu_type: GPUType
Returns
-------
name : str | None
Core name (None: unsupported)
"""
if platform_name == "Darwin":
if gpu_type == GPUType.NONE and (arch_name == "x64" or arch_name == "aarch64"):
arch_name = "universal"
else:
return None
for core_info in CORE_INFOS:
for core_info in _CORE_INFOS:
if (
core_info.platform == platform_name
and core_info.arch == arch_name
@@ -275,27 +288,30 @@ def get_core_name(
return None
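# A minimal usage sketch (hypothetical helper; the expected names follow from _CORE_INFOS and the Darwin branch above):
def _sketch_core_name_lookup() -> None:
    assert _get_core_name("x64", "Windows", "onnxruntime", GPUType.NONE) == "core_cpu_x64.dll"
    assert _get_core_name("x64", "Darwin", "onnxruntime", GPUType.CUDA) is None  # macOS has CPU cores only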


def get_suitable_core_name(
model_type: str,
def _get_suitable_core_name(
model_type: Literal["libtorch", "onnxruntime"],
gpu_type: GPUType,
) -> Optional[str]:
arch_name = get_arch_name()
) -> str | None:
"""実行中マシン・引数設定値でサポートされるコアのファイル名(None: サポート外)"""
# 実行中マシンのアーキテクチャ・システム名
arch_name = _get_arch_name()
platform_name = platform.system()
if arch_name is None:
return None
platform_name = platform.system()
return get_core_name(arch_name, platform_name, model_type, gpu_type)
return _get_core_name(arch_name, platform_name, model_type, gpu_type)
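# A minimal usage sketch (hypothetical helper; the result depends on the machine it runs on):
def _sketch_suitable_core_name() -> None:
    # On a Windows x64 machine this prints "core_cpu_x64.dll"; on an unsupported architecture it prints None.
    print(_get_suitable_core_name("onnxruntime", gpu_type=GPUType.NONE))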


def check_core_type(core_dir: Path) -> Optional[str]:
# The libtorch build does not support DirectML, so `gpu_type=GPUType.DIRECT_ML` is not included here
def _check_core_type(core_dir: Path) -> Literal["libtorch", "onnxruntime"] | None:
"""`core_dir`直下に存在し実行中マシンで利用可能な Core の model_type(None: 利用可能 Core 無し)"""
libtorch_core_names = [
get_suitable_core_name("libtorch", gpu_type=GPUType.CUDA),
get_suitable_core_name("libtorch", gpu_type=GPUType.NONE),
_get_suitable_core_name("libtorch", gpu_type=GPUType.CUDA),
_get_suitable_core_name("libtorch", gpu_type=GPUType.NONE),
# ("libtorch", GPUType.DIRECT_ML): libtorch版はDirectML未対応
]
onnxruntime_core_names = [
get_suitable_core_name("onnxruntime", gpu_type=GPUType.CUDA),
get_suitable_core_name("onnxruntime", gpu_type=GPUType.DIRECT_ML),
get_suitable_core_name("onnxruntime", gpu_type=GPUType.NONE),
_get_suitable_core_name("onnxruntime", gpu_type=GPUType.CUDA),
_get_suitable_core_name("onnxruntime", gpu_type=GPUType.DIRECT_ML),
_get_suitable_core_name("onnxruntime", gpu_type=GPUType.NONE),
]
if any([(core_dir / name).is_file() for name in libtorch_core_names if name]):
return "libtorch"
@@ -306,7 +322,20 @@ def check_core_type(core_dir: Path) -> Optional[str]:


def load_core(core_dir: Path, use_gpu: bool) -> CDLL:
core_name = find_version_0_12_core_or_later(core_dir)
"""
Load the core DLL that exists directly under `core_dir` and is supported on the running machine
Parameters
----------
core_dir : Path
Directory that has the core (shared library) directly under it
use_gpu
Returns
-------
core : CDLL
The core DLL
"""
# Core>=0.12
core_name = _find_version_0_12_core_or_later(core_dir)
if core_name:
try:
# NOTE: the `name` argument of the CDLL constructor must be given a string.
@@ -315,29 +344,30 @@ def load_core(core_dir: Path, use_gpu: bool) -> CDLL:
except OSError as err:
raise RuntimeError(f"コアの読み込みに失敗しました:{err}")

model_type = check_core_type(core_dir)
# Core<0.12
model_type = _check_core_type(core_dir)
if model_type is None:
raise RuntimeError("コアが見つかりません")
if use_gpu or model_type == "onnxruntime":
core_name = get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
if core_name:
try:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
except OSError:
pass
core_name = get_suitable_core_name(model_type, gpu_type=GPUType.DIRECT_ML)
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.DIRECT_ML)
if core_name:
try:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
except OSError:
pass
core_name = get_suitable_core_name(model_type, gpu_type=GPUType.NONE)
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.NONE)
if core_name:
try:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
except OSError as err:
if model_type == "libtorch":
core_name = get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
core_name = _get_suitable_core_name(model_type, gpu_type=GPUType.CUDA)
if core_name:
try:
return CDLL(str((core_dir / core_name).resolve(strict=True)))
@@ -375,7 +405,7 @@ def __init__(
self.exist_is_model_loaded = False

is_version_0_12_core_or_later = (
find_version_0_12_core_or_later(core_dir) is not None
_find_version_0_12_core_or_later(core_dir) is not None
)
if is_version_0_12_core_or_later:
model_type = "onnxruntime"
@@ -386,7 +416,7 @@ def __init__(
self.core.is_model_loaded.argtypes = (c_long,)
self.core.is_model_loaded.restype = c_bool
else:
model_type = check_core_type(core_dir)
model_type = _check_core_type(core_dir)
assert model_type is not None

if model_type == "onnxruntime":
