diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py b/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py new file mode 100644 index 00000000000..14cb19ddc31 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_audio/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm_omni.diffusion.models.hyperclovax_audio.hyperclovax_audio_decoder import ( + HyperCLOVAXAudioDecoderModel, +) +from vllm_omni.diffusion.models.hyperclovax_audio.pipeline_hyperclovax_audio import ( + HyperCLOVAXAudioPipeline, + get_hyperclovax_audio_post_process_func, +) + +__all__ = [ + "HyperCLOVAXAudioPipeline", + "HyperCLOVAXAudioDecoderModel", + "get_hyperclovax_audio_post_process_func", +] diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/activations.py b/vllm_omni/diffusion/models/hyperclovax_audio/activations.py new file mode 100644 index 00000000000..45761c86138 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_audio/activations.py @@ -0,0 +1,272 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# See NOTICE file for license details. + +import math + +import torch +import torch.nn.functional as F +from torch import nn, pow, sin +from torch.nn import Parameter + +if "sinc" in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # See NOTICE file for license details. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: Cut-off of the low-pass filter.  Presumably a normalised
            frequency in cycles per sample (the visible caller rejects
            values above 0.5) -- confirm against callers.
        half_width: Transition-band half width; it is only used to derive the
            Kaiser ``beta`` through ``delta_f = 4 * half_width``.
        kernel_size: Number of filter taps (even or odd).

    Returns:
        A float tensor of shape ``[1, 1, kernel_size]``.  For a non-zero
        ``cutoff`` the taps are normalised to sum to one; for ``cutoff == 0``
        an all-zero kernel is returned.
    """
    half_size = kernel_size // 2

    # Kaiser window: pick beta from the attenuation implied by the
    # requested transition width.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # A zero cut-off has no pass band.  Return early: normalising an all-zero
    # kernel would divide by zero, so this case must not reach the code below.
    if cutoff == 0:
        return torch.zeros(1, 1, kernel_size, dtype=window.dtype)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if kernel_size % 2 == 0:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size

    taps = 2 * cutoff * window * sinc(2 * cutoff * time)
    # Normalize filter to have sum = 1, otherwise a small part of the
    # constant (DC) component of the input signal leaks.
    taps /= taps.sum()
    return taps.view(1, 1, kernel_size)
class UpSample1d(nn.Module):
    """Upsample a ``[B, C, T]`` signal by ``ratio`` with a Kaiser-windowed sinc filter.

    The filter is applied as a depthwise transposed convolution with stride
    ``ratio``; the extra samples produced by the padding and by the filter
    length are cropped afterwards.  ``forward`` also returns the trailing
    input frames so that a caller can feed them back as ``hidden_states``
    when processing the next chunk.
    """

    def __init__(self, ratio=2, kernel_size=None, causal=False):
        """
        Args:
            ratio: Upsampling factor; also used as the transposed-conv stride.
            kernel_size: Filter length.  When ``None`` it defaults to
                ``int(6 * ratio // 2) * 2``.
            causal: When true, ``forward`` adds no right-hand padding.
        """
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.causal = causal

        # Derive the crop sizes from the *resolved* kernel size: the raw
        # ``kernel_size`` argument may be None, which used to raise TypeError.
        excess = self.kernel_size - ratio
        self.half_left = excess // 2
        self.half_right = (excess + 1) // 2

        taps = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
        self.register_buffer("filter", taps)

    # x: [B, C, T]
    def forward(self, x, hidden_states=None):
        """Upsample ``x`` and return ``(upsampled, state)``.

        Args:
            x: Input of shape ``[B, C, T]``.
            hidden_states: Optional left context from a previous call; its
                last ``self.pad`` frames are prepended instead of replicate
                padding on the left.

        Returns:
            The upsampled tensor and the detached last ``self.pad`` frames of
            the (un-padded) input.
        """
        _, C, _ = x.shape
        hs = x[..., -self.pad :]

        pad_left = self.pad
        pad_right = 0 if self.causal else self.pad

        if hidden_states is not None:
            assert hidden_states.shape[-1] >= self.pad
            hidden_states = hidden_states[..., -self.pad :]
            x = torch.cat([hidden_states, x], dim=-1)
            if pad_right > 0:
                x = F.pad(x, (0, pad_right), mode="replicate")
        else:
            x = F.pad(x, (pad_left, pad_right), mode="replicate")

        # Scale by ``ratio`` so the zero-stuffing of the transposed
        # convolution does not reduce the signal level.
        x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        crop_left = pad_left * self.stride + self.half_left
        crop_right = pad_right * self.stride + self.half_right
        if crop_right > 0:
            x = x[..., crop_left:-crop_right]
        else:
            x = x[..., crop_left:]

        return x, hs.detach()
+ INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super().__init__() + self.in_features = in_features + + # Initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # Log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # Linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. 
class Activation1d(nn.Module):
    """Apply an activation between an upsampling and a downsampling stage.

    The input is upsampled with ``UpSample1d``, passed through ``activation``
    and brought back down with ``DownSample1d``.  Both resamplers return a
    state tensor, which is handed back to the caller as a 2-tuple.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
        causal: bool = False,
    ):
        """
        Args:
            activation: Callable applied to the upsampled signal.
            up_ratio: Ratio passed to ``UpSample1d``.
            down_ratio: Ratio passed to ``DownSample1d``.
            up_kernel_size: Kernel size passed to ``UpSample1d``.
            down_kernel_size: Kernel size passed to ``DownSample1d``.
            causal: Forwarded to both resamplers.
        """
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size, causal)
        self.downsample = DownSample1d(down_ratio, down_kernel_size, causal)

    # x: [B,C,T]
    def forward(self, x, hidden_states=None):
        """Return ``(activated, (up_state, down_state))``.

        ``hidden_states``, when given, must hold exactly two entries: the
        state for the upsampler first and the state for the downsampler last.
        """
        if hidden_states is None:
            up_state, down_state = None, None
        else:
            assert len(hidden_states) == 2
            up_state, down_state = hidden_states[0], hidden_states[-1]

        upsampled, new_up_state = self.upsample(x, up_state)
        activated = self.act(upsampled)
        result, new_down_state = self.downsample(activated, down_state)

        return result, (new_up_state.detach(), new_down_state.detach())
+SPEAKERS_LIST = [ + "fkms", + "msij", +] + +FORMAT_MIME_MAP = { + "mp3": "audio/mpeg", + "wav": "audio/wav", + "flac": "audio/flac", + "ogg": "audio/ogg", + "aac": "audio/aac", + "pcm": "audio/pcm", +} + +DEFAULT_FORMAT = "wav" + +AUDIO_FORMAT_MAP = [ + (b"RIFF", "wav"), # WAV (RIFF container) + (b"\x1a\x45\xdf\xa3", "webm"), # WebM / MKV (EBML header) + (b"OggS", "ogg"), # OGG + (b"fLaC", "flac"), # FLAC + (b"ID3", "mp3"), # MP3 with ID3 tag + (b"\xff\xfb", "mp3"), # MP3 without ID3 + (b"\x00\x00\x00\x1c", "mp4"), # MP4 / M4A + (b"\x00\x00\x00\x20", "mp4"), # MP4 / M4A +] + +VOLUME_LEVEL_DB = -26 + +VOLUME_LEVEL = 10 ** (VOLUME_LEVEL_DB / 20) diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py b/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py new file mode 100644 index 00000000000..d9eb67c0ba6 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_audio/ecapa_tdnn.py @@ -0,0 +1,251 @@ +# Omni Chainer - Multimodal LLM Inference System +# Copyright (c) 2025-present NAVER Cloud Corp. +# Apache-2.0 +# +# This is the ECAPA-TDNN model. 
class Res2Conv1dReluBn(nn.Module):
    """
    Res2-style grouped Conv1d where each group is followed by ReLU and then BatchNorm1d.
    NOTE: in_channels == out_channels == channels
    """

    def __init__(
        self,
        channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=False,
        scale=4,
    ):
        """
        Args:
            channels: Total channel count; must be divisible by ``scale``.
            kernel_size, stride, padding, dilation, bias: Passed to every
                per-group ``nn.Conv1d``.
            scale: Number of channel groups.  All groups but the last are
                convolved (a single group is convolved when ``scale == 1``).
        """
        super().__init__()
        assert channels % scale == 0, f"{channels} % {scale} != 0"
        self.scale = scale
        self.width = channels // scale
        self.nums = scale - 1 if scale != 1 else scale

        group_convs = [
            nn.Conv1d(
                self.width,
                self.width,
                kernel_size,
                stride,
                padding,
                dilation,
                bias=bias,
            )
            for _ in range(self.nums)
        ]
        group_norms = [nn.BatchNorm1d(self.width) for _ in range(self.nums)]
        self.convs = nn.ModuleList(group_convs)
        self.bns = nn.ModuleList(group_norms)

    def forward(self, x):
        """Process the channel groups of ``x`` and concatenate them again.

        Each processed group is added to the next input group before that
        group is convolved; when ``scale != 1`` the last group is passed
        through untouched.
        """
        groups = torch.split(x, self.width, 1)
        pieces = []
        carried = None
        for group, conv, norm in zip(groups, self.convs, self.bns):
            carried = group if carried is None else carried + group
            # Order: conv -> relu -> bn
            carried = norm(F.relu(conv(carried)))
            pieces.append(carried)
        if self.scale != 1:
            pieces.append(groups[self.nums])
        return torch.cat(pieces, dim=1)
class AttentiveStatsPool(nn.Module):
    """Attention-weighted mean / standard-deviation pooling over the last axis.

    For an input of shape ``[B, C, T]`` the output has shape ``[B, 2 * C]``:
    the weighted mean followed by the weighted standard deviation.
    """

    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        """
        Args:
            in_dim: Channel count ``C`` of the input.
            attention_channels: Width of the attention bottleneck.
            global_context_att: When true, the attention network additionally
                sees the per-channel mean and standard deviation of the input
                (so its input has ``3 * in_dim`` channels).
        """
        super().__init__()
        self.global_context_att = global_context_att

        attention_in = in_dim * 3 if global_context_att else in_dim
        self.linear1 = nn.Conv1d(attention_in, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """Return ``cat([weighted_mean, weighted_std], dim=1)`` for ``x``."""
        features = x
        if self.global_context_att:
            global_mean = torch.mean(x, dim=-1, keepdim=True)
            global_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10)
            features = torch.cat((x, global_mean.expand_as(x), global_std.expand_as(x)), dim=1)

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        scores = self.linear2(torch.tanh(self.linear1(features)))
        weights = torch.softmax(scores, dim=2)

        mean = torch.sum(weights * x, dim=2)
        second_moment = torch.sum(weights * (x**2), dim=2)
        # Clamp before the square root so rounding cannot produce a negative variance.
        std = torch.sqrt((second_moment - mean**2).clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)
class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN encoder that maps frame-level features to one embedding per item.

    A convolutional front layer is followed by three SE-Res2Blocks with
    dilations 2, 3 and 4.  The outputs of the three blocks are concatenated,
    projected, pooled with attentive statistics pooling and mapped to
    ``emb_dim`` dimensions.
    """

    def __init__(self, in_channel=100, hidden_channel=512, emb_dim=256, global_context_att=False):
        """
        Args:
            in_channel: Number of input feature channels.
            hidden_channel: Channel width of the frame-level layers.
            emb_dim: Size of the returned embedding.
            global_context_att: Forwarded to ``AttentiveStatsPool``.
        """
        super().__init__()
        # NOTE(review): instance_norm is created here but forward() never calls it.
        self.instance_norm = nn.InstanceNorm1d(in_channel)
        self.channels = [hidden_channel] * 4 + [hidden_channel * 3]

        def se_block(stage, rate):
            # padding == dilation keeps the sequence length for kernel_size=3.
            return SE_Res2Block(
                self.channels[stage - 1],
                self.channels[stage],
                kernel_size=3,
                stride=1,
                padding=rate,
                dilation=rate,
                scale=8,
                se_bottleneck_dim=128,
            )

        self.layer1 = Conv1dReluBn(in_channel, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = se_block(1, 2)
        self.layer3 = se_block(2, 3)
        self.layer4 = se_block(3, 4)

        pooled_channels = self.channels[-1]
        self.conv = nn.Conv1d(hidden_channel * 3, pooled_channels, kernel_size=1)
        self.pooling = AttentiveStatsPool(
            pooled_channels,
            attention_channels=128,
            global_context_att=global_context_att,
        )
        self.bn = nn.BatchNorm1d(pooled_channels * 2)
        self.linear = nn.Linear(pooled_channels * 2, emb_dim)
        self.bn_out = nn.BatchNorm1d(emb_dim)

    def forward(self, x):
        """Encode ``x`` and return the batch-normalised embedding."""
        frames = self.layer1(x)
        block_outputs = []
        for block in (self.layer2, self.layer3, self.layer4):
            frames = block(frames)
            block_outputs.append(frames)

        # Only the three SE-Res2Block outputs are aggregated, not layer1's.
        fused = F.relu(self.conv(torch.cat(block_outputs, dim=1)))
        pooled = self.bn(self.pooling(fused))
        return self.bn_out(self.linear(pooled))
+# +# Adapted from https://github.com/jik876/hifi-gan under the MIT license. +# Portions from https://github.com/NVIDIA/BigVGAN under the MIT license. +# See NOTICE file for license details. + +import json +import math +from pathlib import Path + +import torch +import torch.nn as nn +from torch.nn.utils import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm +from vllm.logger import init_logger + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.models.hyperclovax_audio.activations import Activation1d, SnakeBeta +from vllm_omni.diffusion.models.hyperclovax_audio.ecapa_tdnn import ECAPA_TDNN + + +# Dataclass for model hyper-parameters +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +logger = init_logger(__name__) + + +def load_hparams_from_json(path) -> AttrDict: + with open(path) as f: + data = f.read() + return AttrDict(json.loads(data)) + + +# Functions for model initialization +def init_weights(m, mean=0.0, std=0.01): + if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class CausalConv1d(nn.Module): + """1D causal convloution w/ 1-side padding.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + pad_buffer=None, + ): + super().__init__() + self.conv = weight_norm( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + ) + self.stride = stride + self.pad_length = (kernel_size - 1) * dilation + + # TODO: deprecate pad_buffer and inference. 
Remove in the future + if pad_buffer is None: + pad_buffer = torch.zeros(1, in_channels, self.pad_length) + self.register_buffer("pad_buffer", pad_buffer) + + def forward(self, x, hidden_states=None): + if hidden_states is None: + x = nn.functional.pad(x, (self.pad_length, 0), "constant", value=0.0) + else: + assert hidden_states.shape[-1] >= self.pad_length + hidden_states = hidden_states[:, :, -self.pad_length :] + x = torch.cat((hidden_states, x), -1) + return self.conv(x), x[:, :, -self.pad_length :].detach() + + +class CausalConvTranspose1d(nn.Module): + """1D causal transpose convloution.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding=0, + output_padding=0, + groups=1, + bias=True, + pad_buffer=None, + ): + super().__init__() + self.deconv = weight_norm( + nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + output_padding=0, + groups=groups, + bias=bias, + ) + ) + self.stride = stride + self.pad_length = math.ceil(kernel_size / stride) - 1 + self.pad = nn.ReplicationPad1d((self.pad_length, 0)) + + # TODO: deprecate pad_buffer and inference. 
Remove in the future + if pad_buffer is None: + pad_buffer = torch.zeros(1, in_channels, self.pad_length) + self.register_buffer("pad_buffer", pad_buffer) + + def forward(self, x, hidden_states=None): + if hidden_states is None: + x = self.pad(x) + else: + assert hidden_states.shape[-1] >= self.pad_length + hidden_states = hidden_states[:, :, -self.pad_length :] + x = torch.cat((hidden_states, x), -1) + return ( + self.deconv(x)[:, :, self.stride : -self.stride], + x[:, :, -self.pad_length :].detach(), + ) + + +class NonCausalConv1d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + **kwargs, + ): + super().__init__() + self.conv = weight_norm( + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + **kwargs, + ) + ) + self.pad_length = ((kernel_size - 1) * dilation) // 2 + + def forward(self, x, hidden_states=None): + if hidden_states is None: + out = self.conv(x) + else: + assert hidden_states.shape[-1] >= self.pad_length + hidden_states = hidden_states[:, :, -self.pad_length :] + x_ = torch.cat((hidden_states, x), -1) + out = self.conv(x_)[:, :, self.pad_length :] + return out, x[:, :, -self.pad_length :].detach() + + +class NonCausalConvTranspose1d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + output_padding=0, + groups=1, + bias=True, + **kwargs, + ): + super().__init__() + self.deconv = weight_norm( + nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + **kwargs, + ) + ) + self.stride = stride + self.pad_length = (kernel_size - stride) // 2 + + def forward(self, x, hidden_states=None): + if hidden_states is None: + out = 
class AMPBlock1(torch.nn.Module):
    """
    AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters
    that control periodicity, defined for each layer.
    AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1
    followed by each layer in self.convs1

    Args:
        h (AttrDict): Hyperparameters; ``h.snake_logscale`` is read here.
        channels (int): Number of convolution channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions.
            Default is (1, 3, 5).
        activation (str): Activation function type. Only 'snakebeta' is implemented;
            any other value raises NotImplementedError.
        causal (bool): Selects CausalConv1d instead of NonCausalConv1d for the convolutions.
    """

    def __init__(
        self,
        h: AttrDict,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple = (1, 3, 5),
        activation: str | None = None,
        causal: bool = False,
    ):
        super().__init__()
        conv1d = CausalConv1d if causal else NonCausalConv1d

        self.h = h

        # First convolution of every residual layer: one per dilation rate.
        self.convs1 = nn.ModuleList(
            [
                conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # Second convolution of every residual layer: always dilation 1.
        self.convs2 = nn.ModuleList(
            [
                conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )
        self.convs2.apply(init_weights)

        self.num_layers = len(self.convs1) + len(self.convs2)  # Total number of conv layers

        # Activation functions: one per convolution.
        # NOTE(review): the activations are built with causal=False even when
        # ``causal`` is true -- confirm this is intentional.
        if activation == "snakebeta":
            self.activations = nn.ModuleList(
                [
                    Activation1d(
                        activation=SnakeBeta(channels, alpha_logscale=h.snake_logscale),
                        causal=False,
                    )
                    for _ in range(self.num_layers)
                ]
            )
        else:
            raise NotImplementedError(
                "activation incorrectly specified. check the config file and look for 'activation'."
            )

    def forward(self, x, hidden_states=None):
        """Run all residual layers and return ``(output, new_hidden_states)``.

        Args:
            x: Input tensor handed to the first activation.
            hidden_states: Optional list with one 4-tuple per residual layer,
                ordered ``(act1, conv1, act2, conv2)``.  ``None`` starts from
                empty state.

        Returns:
            The block output and a list of 4-tuples in the same layout as
            ``hidden_states``.
        """
        if hidden_states is None:
            hidden_states = [(None, None, None, None)] * len(self.convs1)

        hidden_states_new = []
        # Activations alternate: even indices precede convs1, odd ones convs2.
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for c1, c2, a1, a2, (h_a1, h_c1, h_a2, h_c2) in zip(self.convs1, self.convs2, acts1, acts2, hidden_states):
            xt, ht_a1 = a1(x, h_a1)
            xt, ht_c1 = c1(xt, h_c1)
            xt, ht_a2 = a2(xt, h_a2)
            xt, ht_c2 = c2(xt, h_c2)
            x = xt + x
            hidden_states_new.append((ht_a1, ht_c1, ht_a2, ht_c2))

        return x, hidden_states_new

    def remove_weight_norm(self):
        """Fold weight normalisation into plain weights for every convolution.

        The entries of ``convs1`` / ``convs2`` are wrapper modules whose
        ``conv`` attribute holds the weight-normalised ``nn.Conv1d``, so the
        normalisation has to be removed from that inner module.  Parametrised
        weight norm is removed with ``remove_parametrizations``; otherwise the
        legacy hook-based removal is attempted.

        Raises:
            ValueError: If a convolution carries no weight normalisation.
        """
        from torch.nn.utils import parametrize

        for wrapper in (*self.convs1, *self.convs2):
            conv = getattr(wrapper, "conv", wrapper)
            if parametrize.is_parametrized(conv, "weight"):
                parametrize.remove_parametrizations(conv, "weight")
            else:
                remove_weight_norm(conv)
+ """ + + def __init__( + self, + od_config: OmniDiffusionConfig, + resblock: str = "1", + causal: bool = False, + finetune: bool = True, + upsample_rates: list[int] = [5, 4, 4, 3, 2, 2], + upsample_kernel_sizes: list[int] = [10, 8, 8, 6, 4, 4], + upsample_initial_channel: int = 1536, + resblock_kernel_sizes: list[int] = [3, 7, 11], + resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_tanh_at_final: bool = False, + use_bias_at_final: bool = False, + activation: str = "snakebeta", + snake_logscale: bool = True, + num_units: int = 6561, + unit_emb_dim: int = 1280, + num_mels: int = 100, + n_fft: int = 1024, + hop_size: int = 256, + win_size: int = 1024, + spk_emb_dim: int = 256, + spk_hidden_dim: int = 512, + global_context_att: bool = False, + sampling_rate: int = 24000, + fmin: int = 0, + fmax: int = 8000, + num_spk: int = 26, + pad_multiple: int = 100, + pad_token_id: int = 3894, + ): + super().__init__() + + self.h = AttrDict( + { + "resblock": resblock, + "causal": causal, + "finetune": finetune, + "upsample_rates": upsample_rates, + "upsample_kernel_sizes": upsample_kernel_sizes, + "upsample_initial_channel": upsample_initial_channel, + "resblock_kernel_sizes": resblock_kernel_sizes, + "resblock_dilation_sizes": resblock_dilation_sizes, + "use_tanh_at_final": use_tanh_at_final, + "use_bias_at_final": use_bias_at_final, + "activation": activation, + "snake_logscale": snake_logscale, + "num_units": num_units, + "unit_emb_dim": unit_emb_dim, + "num_mels": num_mels, + "n_fft": n_fft, + "hop_size": hop_size, + "win_size": win_size, + "spk_emb_dim": spk_emb_dim, + "spk_hidden_dim": spk_hidden_dim, + "global_context_att": global_context_att, + "sampling_rate": sampling_rate, + "fmin": fmin, + "fmax": fmax, + "num_spk": num_spk, + "pad_multiple": pad_multiple, + "pad_token_id": pad_token_id, + } + ) + + self.causal = self.h.get("causal", True) + conv1d = CausalConv1d if self.causal else NonCausalConv1d + convtranspose1d = 
CausalConvTranspose1d if self.causal else NonCausalConvTranspose1d + + self.num_kernels = len(self.h.resblock_kernel_sizes) + self.num_upsamples = len(self.h.upsample_rates) + + self.finetune = getattr(self.h, "finetune", False) + # Speaker embedding + if not self.finetune: + self.spk_emb = ECAPA_TDNN( + in_channel=self.h.num_mels, + hidden_channel=self.h.spk_hidden_dim, + emb_dim=self.h.spk_emb_dim, + global_context_att=self.h.global_context_att, + ) + else: + self.spk_emb = nn.Embedding(self.h.num_spk, self.h.spk_emb_dim) + + # Unit embedding + self.unit_emb = nn.Embedding(self.h.num_units, self.h.unit_emb_dim) + self.unit_emb_dim = self.h.unit_emb_dim + + # Pre-conv + self.conv_pre = conv1d( + self.h.unit_emb_dim + self.h.spk_emb_dim, self.h.upsample_initial_channel, 7, 1, padding=3 + ) + + # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + if self.h.resblock == "1": + resblock_class = AMPBlock1 + else: + raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {self.h.resblock}") + + # Transposed conv-based upsamplers. 
does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(self.h.upsample_rates, self.h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList( + [ + convtranspose1d( + self.h.upsample_initial_channel // (2**i), + self.h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=math.ceil((k - u) / 2), + output_padding=(k - u) % 2, + ) + ] + ) + ) + + # Residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = self.h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(self.h.resblock_kernel_sizes, self.h.resblock_dilation_sizes)): + self.resblocks.append( + resblock_class(self.h, ch, k, d, activation=self.h.activation, causal=self.causal) + ) + + # Post-conv + activation_post = ( + SnakeBeta(ch, alpha_logscale=self.h.snake_logscale) if self.h.activation == "snakebeta" else None + ) + if activation_post is None: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.activation_post = Activation1d(activation=activation_post, causal=False) + + # Whether to use bias for the final conv_post. Default to True for backward compatibility + self.use_bias_at_final = self.h.get("use_bias_at_final", True) + self.conv_post = conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final) + + # Weight initialization + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + # Final tanh activation. 
def remove_weight_norm(self):
    """Fold weight normalisation into plain weights for all convolutions of the model.

    The layers stored in ``ups``, ``resblocks``, ``conv_pre`` and
    ``conv_post`` are wrapper modules; the weight-normalised convolution
    lives inside them.  Calling the legacy ``remove_weight_norm`` on the
    wrappers raised ``ValueError`` immediately, which was then reported as
    "already removed" without removing anything.  This version walks every
    ``nn.Conv1d`` / ``nn.ConvTranspose1d`` sub-module instead.

    A warning is logged when no weight normalisation was found; no exception
    is raised.
    """
    from torch.nn.utils import parametrize

    logger.info("Removing weight norm...")
    removed = 0
    # Snapshot the module list: removing a parametrization edits the module tree.
    for module in list(self.modules()):
        if not isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            continue
        if parametrize.is_parametrized(module, "weight"):
            parametrize.remove_parametrizations(module, "weight")
            removed += 1
            continue
        try:
            # Legacy hook-based weight norm, kept for older checkpoints/modules.
            remove_weight_norm(module)
        except ValueError:
            # This convolution carries no weight norm at all.
            continue
        removed += 1

    if not removed:
        logger.warning("Model already removed weight norm. Skipping!")
+ ) + model.remove_weight_norm() + model.load_state_dict(checkpoint_dict["generator"]) + + return model diff --git a/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py b/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py new file mode 100644 index 00000000000..81e7584e219 --- /dev/null +++ b/vllm_omni/diffusion/models/hyperclovax_audio/pipeline_hyperclovax_audio.py @@ -0,0 +1,436 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import io +import json +import math +import os +import tempfile +import zipfile +from collections.abc import Iterable +from pathlib import Path +from typing import Any + +import librosa +import numpy as np +import pydub +import scipy.signal +import torch +import torch.nn as nn +from librosa.filters import mel as librosa_mel_fn +from pydub import AudioSegment +from vllm.logger import init_logger +from vllm.model_executor.models.utils import AutoWeightsLoader + +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from .constants import AUDIO_FORMAT_MAP, DEFAULT_FORMAT, FORMAT_MIME_MAP, SPEAKERS_LIST, VOLUME_LEVEL +from .hyperclovax_audio_decoder import HyperCLOVAXAudioDecoderModel + +logger = init_logger(__name__) + +# Global caches for mel filter banks and Hann windows. +mel_basis = {} +hann_window = {} + + +def get_hyperclovax_audio_post_process_func(od_config: OmniDiffusionConfig): + """ + Get post-processing function for HyperCLOVAX Audio pipeline. + + Returns a function that converts model output tensors to audio file. 
+ """ + + def post_process_func(output: list[tuple[torch.Tensor, str]]) -> list[bytes]: + response = [] + for wav_tensor, fmt in output: + wav = wav_tensor.squeeze().cpu().numpy() + pcm = (wav * 32767.0).astype(np.int16) + + if fmt == "pcm": + response.append(pcm.tobytes()) + continue + + segment = AudioSegment(pcm.tobytes(), frame_rate=24000, sample_width=pcm.dtype.itemsize, channels=1) + + buf = io.BytesIO() + export_kwargs = {"format": fmt} + segment.export(buf, **export_kwargs) + response.append(buf.getvalue()) + return response + + return post_process_func + + +class HyperCLOVAXAudioPipeline(nn.Module): + def __init__( + self, + *, + od_config: OmniDiffusionConfig, + prefix: str = "", + ): + super().__init__() + self.od_config = od_config + self.device = get_local_device() + self._dtype = od_config.dtype + + self.model = self.od_config.model + self.weights_sources = [] + + # Note: + # Audio decoder checkpoint of HyperCLOVAX-Omni is currently provided as .mar format + mar_path = self._resolve_mar_path(self.model) + if mar_path is not None: + self._using_mar_checkpoint = True + self.weights_sources = [] + ckpt_path, config_path = self._extract_mar_checkpoint(mar_path) + self.bigvgan = HyperCLOVAXAudioDecoderModel.from_pretrained( + ckpt_path=ckpt_path, + config_path=config_path, + map_location="cpu", + ).to(self.device) + else: + self.bigvgan = HyperCLOVAXAudioDecoderModel(od_config=od_config).to(self.device) + + self.spk_emb = self.bigvgan.spk_emb.to(self.device) + self._vocab = int(getattr(self.bigvgan.h, "num_units", 0)) + + speakers = SPEAKERS_LIST + self.speaker_map = {spk: i for i, spk in enumerate(speakers)} + + def _resolve_mar_path(self, model: str | None) -> Path | None: + if model is None: + return None + + model_path = Path(model) + if model_path.is_file() and model_path.suffix == ".mar": + return model_path + + if not model_path.is_dir(): + return None + + candidates = [ + model_path / "NCCosybigvganDecoder.mar", + model_path / 
"NCZSCosybigvganDecoder.mar", + model_path / "decoder" / "audio" / "NCCosybigvganDecoder.mar", + model_path / "decoder" / "audio" / "NCZSCosybigvganDecoder.mar", + ] + for candidate in candidates: + if candidate.exists(): + return candidate + return None + + def _extract_mar_checkpoint(self, mar_path: Path) -> tuple[str, str]: + extract_dir = Path(tempfile.mkdtemp(prefix="hcx_audio_decoder_")) + self._mar_extract_dir = str(extract_dir) + + with zipfile.ZipFile(mar_path) as zf: + manifest = json.loads(zf.read("MAR-INF/MANIFEST.json")) + serialized_file = manifest.get("model", {}).get("serializedFile") + if not serialized_file: + raise ValueError(f"serializedFile not found in {mar_path}") + + zf.extract(serialized_file, path=extract_dir) + zf.extract("config.json", path=extract_dir) + + return str(extract_dir / serialized_file), str(extract_dir / "config.json") + + def _prepare_batch( + self, audio_tokens: list[list[int]], speakers: list[str], formats: list[str], ref_audio_tokens: list[str] + ) -> list[tuple[torch.Tensor, torch.Tensor, str]]: + """ + Construct batch to forward through the model. + + Args: + - audio_tokens: List[List[int]]: discrete audio tokens to decode. + - speakers: List[str]: speaker IDs for output audio. + - formats: List[str]: output audio formats. + - ref_audio_tokens: List[str]: + List of base64 encoded reference audio. + If provided, speaker and format will be ignored. 
+ + Returns: + batch: List of tuples of (audio_tokens, speaker_id or ref_mel, format) + """ + batch = [] + for units, speaker, fmt, ref_audio in zip(audio_tokens, speakers, formats, ref_audio_tokens): + units = torch.tensor(units, dtype=torch.long, device=self.device) + + if self._vocab > 0: + mask = (units < 0) | (units >= self._vocab) + if mask.any(): + bad_idxs = units[mask].tolist() + raise ValueError(f"Unit indices out of range [0-{self._vocab - 1}]: {bad_idxs}") + + if ref_audio is not None: + ref_audio_bytes = base64.b64decode(ref_audio.encode("ascii"), validate=True) + ref_mel = ( + self._get_reference_mel_spectrogram(ref_audio_bytes, self.bigvgan.h).to(self.device).to(self._dtype) + ) + batch.append((units, ref_mel, None)) + else: + speaker = "fkms" if speaker is None else speaker + fmt = DEFAULT_FORMAT.lower() if fmt is None else fmt.lower() + if fmt not in FORMAT_MIME_MAP: + raise ValueError(f"Unsupported format '{fmt}'. Choose from {list(FORMAT_MIME_MAP)}") + speaker_id = torch.tensor([self.speaker_map[speaker]], dtype=torch.long) + speaker_id = speaker_id.unsqueeze(0).to(self.device) + + batch.append((units, speaker_id, fmt)) + + return batch + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + """ + Generate audio from audio tokens. + + Args: + req: OmniDiffusionRequest must containing: + - audio_tokens: List[List[int]]: [B, L] or [L, ] audio token ids. + - speakers: List[str]: speaker for each audio sample. + - extra["formats"]: List[str]: output audio format for each audio sample. + - extra["ref_audio_tokens"]: List[str]: base64 encoded reference audio for each audio sample. + + Returns: + OmniDiffusionResponse: The diffusion response. + """ + + # 1. 
Validate inputs exist in request + request_input = req.prompts[0].get("additional_information", None) + if request_input is None: + return DiffusionOutput(output=None, error="additional_information required in request_input") + + audio_tokens = request_input.get("audio_tokens") + if audio_tokens is None: + return DiffusionOutput(output=None, error="audio_tokens required in request_input") + + speakers = request_input.get("speakers", ["fkms"]) * len(audio_tokens) + + if len(audio_tokens) != len(speakers): + return DiffusionOutput(output=None, error="length of speakers and audio_tokens must be the same") + + if not all(speaker in SPEAKERS_LIST for speaker in speakers): + return DiffusionOutput(output=None, error=f"speakers must be one of {SPEAKERS_LIST}") + + # Optional: audio format. If not provided, use wav format as default. + formats = req.extra.get("formats", [DEFAULT_FORMAT.lower()] * len(audio_tokens)) + if len(audio_tokens) != len(formats): + return DiffusionOutput(output=None, error="length of formats and audio_tokens must be the same") + + ref_audio_tokens = req.extra.get("ref_audio_tokens") + if ref_audio_tokens is None: + ref_audio_tokens = [None] * len(audio_tokens) + if len(audio_tokens) != len(ref_audio_tokens): + return DiffusionOutput(output=None, error="length of ref_audio_tokens and audio_tokens must be the same") + + # 2. Construct batch from given request inputs + batch = self._prepare_batch(audio_tokens, speakers, formats, ref_audio_tokens) + results: list[tuple[torch.Tensor, str]] = [] + + for units, speaker, fmt in batch: + # 3. 
Convert to tensor if needed + if isinstance(units, list): + units = torch.tensor(units, dtype=torch.long) + elif isinstance(units, np.ndarray): + units = torch.from_numpy(units).long() + + if len(units.size()) == 2 and units.size(0) == 1: + return DiffusionOutput(output=None, error="the underlying decoder does not support batch inference yet") + + units = units.unsqueeze(0) + units = units.to(self.device) + padded_unit, original_portion = self.pad(units) + + # 4. Generate speaker embedding + spk_emb = self.spk_emb(speaker) + + # 5. Decode audio + padded_out, hidden = self.bigvgan(padded_unit, spk_emb=spk_emb) + del hidden + out = self.unpad(padded_out, original_portion) + + # 6. Append decoded audio to result + results.append((out.to(torch.float32), fmt)) + + return DiffusionOutput( + output=results, post_process_func=get_hyperclovax_audio_post_process_func(self.od_config) + ) + + def pad(self, unit: torch.Tensor) -> tuple[torch.Tensor, float]: + """ + Pad the `unit` tensor to AUDIOLLM_PAD_MULTIPLE environment variable. + + Args: + unit: int tensor of shape [1, L] + """ + + pad_multiple = self._get_pad_multiple() + if not pad_multiple: + return unit, 1.0 + + pad_token_id = self._get_pad_token_id() + if pad_token_id is None: + return unit, 1.0 + + overflow = unit.shape[1] % pad_multiple + if overflow == 0: + return unit, 1.0 + pad_amount = pad_multiple - overflow + padded = torch.nn.functional.pad(unit, (0, pad_amount), mode="constant", value=pad_token_id) + return padded, unit.shape[-1] / padded.shape[-1] + + def unpad(self, x: torch.Tensor, original_portion: float) -> torch.Tensor: + """ + Unpad the `x` tensor by retaining only the `original_portion`. + + Args: + x: tensor of shape [..., T] + original_portion: ratio of original unit length over padded unit length + """ + return x[..., : math.ceil(x.shape[-1] * original_portion)] + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load model weights using AutoWeightsLoader. 
+ """ + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def _get_pad_multiple(self) -> int | None: + pad_multiple = int(getattr(self.bigvgan.h, "pad_multiple", 0)) + if pad_multiple <= 0: + return None + return pad_multiple + + def _get_pad_token_id(self) -> int | None: + pad_token_id = int(getattr(self.bigvgan.h, "pad_token_id", -1)) + if pad_token_id < 0: + return None + return pad_token_id + + def _get_down_sample_rate(self) -> float | None: + down_sample_rate_str = os.getenv("AUDIOLLM_DOWN_SAMPLE_RATE") + if not down_sample_rate_str: + return None + + try: + down_sample_rate = float(down_sample_rate_str) + except ValueError: + logger.warning( + "AUDIOLLM_DOWN_SAMPLE_RATE environment variable is not a valid float. Skipping down-sampling..." + ) + return None + + if down_sample_rate <= 0: + logger.warning( + "AUDIOLLM_DOWN_SAMPLE_RATE environment variable is not a positive float. Skipping down-sampling..." + ) + return None + + return down_sample_rate + + def _detect_audio_format(self, header_bytes: bytes) -> str | None: + """ + Detect audio format from header bytes of audio file. + + Args: + header_bytes: first 4 bytes of audio file. 
+ """ + for prefix_bytes, fmt in AUDIO_FORMAT_MAP: + if header_bytes.startswith(prefix_bytes): + return fmt + return None + + def _hpf_normalize(self, pcm: np.ndarray, sr: int | float, volume_level: float) -> np.ndarray: + assert (pcm**2).mean() > 0, "Error in the wav file" + assert np.issubdtype(pcm.dtype, np.floating) + + # highpass filter + filter_ = scipy.signal.butter(2, 70, "highpass", fs=sr, output="sos") + pcm = scipy.signal.sosfilt(filter_, pcm) + pcm = pcm.astype(np.float32) + + # volume normalize + gain = min(volume_level / (pcm**2).mean() ** 0.5, 1 / np.max(np.abs(pcm))) + pcm *= gain + return pcm + + def _load_reference_audio(self, audio: bytes, sample_rate: float) -> np.ndarray: + fmt = self._detect_audio_format(audio[:4]) + audio = io.BytesIO(audio) + + if fmt: + segment = pydub.AudioSegment.from_file(audio, format=fmt) + else: + segment = pydub.AudioSegment.from_file(audio) + + wav_file = io.BytesIO() + segment.export(wav_file, format="wav") + wav_file.seek(0) + + # Down-sample to reduce noise in final result. 
+ load_sr = self._get_down_sample_rate() + if load_sr is None: + load_sr = sample_rate + pcm, sr = librosa.load(wav_file, sr=load_sr, mono=True) + pcm = librosa.resample(pcm, orig_sr=sr, target_sr=sample_rate) + + pcm = self._hpf_normalize(pcm, sample_rate, VOLUME_LEVEL) + return pcm + + def _compute_mel_spectrogram(self, y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + global mel_basis, hann_window + # Create a unique key based on fmax and device + key = f"{fmax}_{y.device}" + if key not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[key] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + # Pad the signal for STFT + pad_amount = int((n_fft - hop_size) / 2) + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_amount, pad_amount), mode="reflect").squeeze(1) + + # Compute the Short-Time Fourier Transform (STFT) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + # Compute the magnitude spectrogram with a small epsilon to avoid log(0) + spec = torch.sqrt(torch.real(spec * spec.conj() + 1e-9)) + + # Map the linear-frequency spectrogram to the mel scale + spec = torch.matmul(mel_basis[key], spec) + + # Apply spectral normalization (dynamic range compression) + spec = torch.log(torch.clamp(spec, min=1e-5)) + + return spec + + def _get_reference_mel_spectrogram(self, ref_audio: bytes, h: dict[str, Any]) -> torch.Tensor: + pcm = self._load_reference_audio(ref_audio, h.sampling_rate) + pcm = torch.from_numpy(pcm).unsqueeze(0) + + mel = self._compute_mel_spectrogram( + pcm, + h.n_fft, + h.num_mels, + h.sampling_rate, + h.hop_size, + h.win_size, + h.fmin, + h.fmax, + ) + return mel diff --git a/vllm_omni/diffusion/registry.py 
b/vllm_omni/diffusion/registry.py index 97bc7fa2925..6258a8f3416 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -188,6 +188,11 @@ "pipeline_omnivoice", "OmniVoicePipeline", ), + "HyperCLOVAXAudioPipeline": ( + "hyperclovax_audio", + "pipeline_hyperclovax_audio", + "HyperCLOVAXAudioPipeline", + ), } @@ -375,6 +380,7 @@ def _apply_sequence_parallel_if_enabled(model, od_config: OmniDiffusionConfig) - "HunyuanVideo15ImageToVideoPipeline": "get_hunyuan_video_15_i2v_post_process_func", "MagiHumanPipeline": "get_magi_human_post_process_func", "OmniVoicePipeline": "get_omnivoice_post_process_func", + "HyperCLOVAXAudioPipeline": "get_hyperclovax_audio_post_process_func", } _DIFFUSION_PRE_PROCESS_FUNCS = {