diff --git a/comfy/ldm/mmaudio/vae/__init__.py b/comfy/ldm/mmaudio/vae/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/comfy/ldm/mmaudio/vae/activations.py b/comfy/ldm/mmaudio/vae/activations.py
new file mode 100644
index 000000000000..db9192e3e4c4
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/activations.py
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+import comfy.model_management
+
+class Snake(nn.Module):
+    '''
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - Introduced by Liu Ziyin, Tilman Hartwig and Masahito Ueda in:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = Snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher frequency.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha (values come from the checkpoint, so both parameterizations start uninitialized)
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:
+            self.alpha = Parameter(torch.empty(in_features))
+        else:
+            self.alpha = Parameter(torch.empty(in_features))
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 1e-9
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake := x + 1/a * sin^2(x * a)
+        '''
+        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    '''
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - A modified version of the activation introduced by Liu Ziyin, Tilman Hartwig and Masahito Ueda in:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = SnakeBeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher frequency.
+            beta is initialized to 1 by default, higher values = higher magnitude.
+            alpha and beta will be trained along with the rest of your model.
+        '''
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha and beta (values come from the checkpoint, so both parameterizations start uninitialized)
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:
+            self.alpha = Parameter(torch.empty(in_features))
+            self.beta = Parameter(torch.empty(in_features))
+        else:
+            self.alpha = Parameter(torch.empty(in_features))
+            self.beta = Parameter(torch.empty(in_features))
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 1e-9
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta := x + 1/b * sin^2(x * a)
+        '''
+        alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+        beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
diff --git a/comfy/ldm/mmaudio/vae/alias_free_torch.py b/comfy/ldm/mmaudio/vae/alias_free_torch.py
new file mode 100644
index 000000000000..35c70b897120
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/alias_free_torch.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import comfy.model_management
+
+if 'sinc' in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adapted from adefossez's julius.core.sinc under the MIT License
+    # https://adefossez.github.io/julius/julius/core.html
+    # LICENSE is in incl_licenses directory.
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different from julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(x == 0,
+                           torch.tensor(1., device=x.device, dtype=x.dtype),
+                           torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+    even = (kernel_size % 2 == 0)
+    half_size = kernel_size // 2
+
+    # For the Kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.:
+        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+    else:
+        beta = 0.
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = (torch.arange(-half_size, half_size) + 0.5)
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size)
+
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+    def __init__(self,
+                 cutoff=0.5,
+                 half_width=0.6,
+                 stride: int = 1,
+                 padding: bool = True,
+                 padding_mode: str = 'replicate',
+                 kernel_size: int = 12):
+        # kernel_size should be an even number for the StyleGAN3 setup;
+        # this implementation also supports odd kernel sizes.
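+        # Design note: kaiser_sinc_filter1d above builds a windowed-sinc FIR
+        # low-pass (Kaiser beta estimated from the stopband attenuation A);
+        # cutoff is in normalized frequency, so 0.5 corresponds to Nyquist.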
+        super().__init__()
+        if cutoff < 0.:
+            raise ValueError("Minimum cutoff must not be negative.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = (kernel_size % 2 == 0)
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right),
+                      mode=self.padding_mode)
+        out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
+                       stride=self.stride, groups=C)
+
+        return out
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+                                      half_width=0.6 / ratio,
+                                      kernel_size=self.kernel_size)
+        self.register_buffer("filter", filter)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        x = F.pad(x, (self.pad, self.pad), mode='replicate')
+        x = self.ratio * F.conv_transpose1d(
+            x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
+        x = x[..., self.pad_left:-self.pad_right]
+
+        return x
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+                                       half_width=0.6 / ratio,
+                                       stride=ratio,
+                                       kernel_size=self.kernel_size)
+
+    def forward(self, x):
+        x = self.lowpass(x)
+
+        return x
+
+class Activation1d(nn.Module):
+    def __init__(self,
+                 activation,
+                 up_ratio: int = 2,
+                 down_ratio: int = 2,
+                 up_kernel_size: int = 12,
+                 down_kernel_size: int = 12):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    # x: [B, C, T]; upsampling before the nonlinearity keeps its new harmonics below Nyquist
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+
+        return x
diff --git a/comfy/ldm/mmaudio/vae/autoencoder.py b/comfy/ldm/mmaudio/vae/autoencoder.py
new file mode 100644
index 000000000000..cbb9de302092
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/autoencoder.py
@@ -0,0 +1,156 @@
+from typing import Literal
+
+import torch
+import torch.nn as nn
+
+from .distributions import DiagonalGaussianDistribution
+from .vae import VAE_16k
+from .bigvgan import BigVGANVocoder
+import logging
+
+try:
+    import torchaudio
+except ImportError:
+    logging.warning("torchaudio missing, MMAudio VAE model will be broken")
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
+    return norm_fn(torch.clamp(x, min=clip_val) * C)
+
+
+def spectral_normalize_torch(magnitudes, norm_fn):
+    output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
+    return output
+
+class MelConverter(nn.Module):
+
+    def
__init__(
+        self,
+        *,
+        sampling_rate: float,
+        n_fft: int,
+        num_mels: int,
+        hop_size: int,
+        win_size: int,
+        fmin: float,
+        fmax: float,
+        norm_fn,
+    ):
+        super().__init__()
+        self.sampling_rate = sampling_rate
+        self.n_fft = n_fft
+        self.num_mels = num_mels
+        self.hop_size = hop_size
+        self.win_size = win_size
+        self.fmin = fmin
+        self.fmax = fmax
+        self.norm_fn = norm_fn
+
+        # mel = librosa_mel_fn(sr=self.sampling_rate,
+        #                      n_fft=self.n_fft,
+        #                      n_mels=self.num_mels,
+        #                      fmin=self.fmin,
+        #                      fmax=self.fmax)
+        # mel_basis = torch.from_numpy(mel).float()
+        mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
+        hann_window = torch.hann_window(self.win_size)
+
+        self.register_buffer('mel_basis', mel_basis)
+        self.register_buffer('hann_window', hann_window)
+
+    @property
+    def device(self):
+        return self.mel_basis.device
+
+    def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
+        waveform = waveform.clamp(min=-1., max=1.).to(self.device)
+
+        waveform = torch.nn.functional.pad(
+            waveform.unsqueeze(1),
+            [int((self.n_fft - self.hop_size) / 2),
+             int((self.n_fft - self.hop_size) / 2)],
+            mode='reflect')
+        waveform = waveform.squeeze(1)
+
+        spec = torch.stft(waveform,
+                          self.n_fft,
+                          hop_length=self.hop_size,
+                          win_length=self.win_size,
+                          window=self.hann_window,
+                          center=center,
+                          pad_mode='reflect',
+                          normalized=False,
+                          onesided=True,
+                          return_complex=True)
+
+        spec = torch.view_as_real(spec)
+        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+        spec = torch.matmul(self.mel_basis, spec)
+        spec = spectral_normalize_torch(spec, self.norm_fn)
+
+        return spec
+
+class AudioAutoencoder(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        # ckpt_path: str,
+        mode: Literal['16k', '44k'] = '16k',
+        need_vae_encoder: bool = True,
+    ):
+        super().__init__()
+
+        assert mode == "16k", "Only 16k mode is supported currently."
+        self.mel_converter = MelConverter(sampling_rate=16_000,
+                                          n_fft=1024,
+                                          num_mels=80,
+                                          hop_size=256,
+                                          win_size=1024,
+                                          fmin=0,
+                                          fmax=8_000,
+                                          norm_fn=torch.log10)
+
+        self.vae = VAE_16k().eval()
+
+        bigvgan_config = {
+            "resblock": "1",
+            "num_mels": 80,
+            "upsample_rates": [4, 4, 2, 2, 2, 2],
+            "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
+            "upsample_initial_channel": 1536,
+            "resblock_kernel_sizes": [3, 7, 11],
+            "resblock_dilation_sizes": [
+                [1, 3, 5],
+                [1, 3, 5],
+                [1, 3, 5],
+            ],
+            "activation": "snakebeta",
+            "snake_logscale": True,
+        }
+
+        self.vocoder = BigVGANVocoder(
+            bigvgan_config
+        ).eval()
+
+    @torch.inference_mode()
+    def encode_audio(self, x) -> DiagonalGaussianDistribution:
+        # x: (B, L) mono waveform
+        mel = self.mel_converter(x)
+        dist = self.vae.encode(mel)
+
+        return dist
+
+    @torch.no_grad()
+    def decode(self, z):
+        mel_decoded = self.vae.decode(z)
+        audio = self.vocoder(mel_decoded)
+
+        audio = torchaudio.functional.resample(audio, 16000, 44100)
+        return audio
+
+    @torch.no_grad()
+    def encode(self, audio):
+        audio = audio.mean(dim=1) # downmix stereo to mono
+        audio = torchaudio.functional.resample(audio, 44100, 16000)
+        dist = self.encode_audio(audio)
+        return dist.mean
diff --git a/comfy/ldm/mmaudio/vae/bigvgan.py b/comfy/ldm/mmaudio/vae/bigvgan.py
new file mode 100644
index 000000000000..3a24337f6234
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/bigvgan.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2022 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+from types import SimpleNamespace
+from .
import activations +from .alias_free_torch import Activation1d +import comfy.ops +ops = comfy.ops.disable_weight_init + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + +class AMPBlock1(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): + super(AMPBlock1, self).__init__() + self.h = h + + self.convs1 = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2])) + ]) + + self.convs2 = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1)) + ]) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + +class AMPBlock2(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.h = h + + self.convs = nn.ModuleList([ + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0])), + ops.Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1])) + ]) + + self.num_layers = len(self.convs) # total number of conv layers + + if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList([ + Activation1d( + activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) + for _ in range(self.num_layers) + ]) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." 
+ ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + +class BigVGANVocoder(torch.nn.Module): + # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. + def __init__(self, h): + super().__init__() + if isinstance(h, dict): + h = SimpleNamespace(**h) + self.h = h + + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + + # pre conv + self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + + # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default + resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 + + # transposed conv-based upsamplers. does not apply anti-aliasing + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + nn.ModuleList([ + ops.ConvTranspose1d(h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2) + ])) + + # residual blocks using anti-aliased multi-periodicity composition modules (AMP) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) + + # post conv + if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing + activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing + activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." 
+            )
+
+        self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
+
+
+    def forward(self, x):
+        # pre conv
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            # upsampling
+            for i_up in range(len(self.ups[i])):
+                x = self.ups[i][i_up](x)
+            # AMP blocks
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        # post conv
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
diff --git a/comfy/ldm/mmaudio/vae/distributions.py b/comfy/ldm/mmaudio/vae/distributions.py
new file mode 100644
index 000000000000..df987c5ec3fa
--- /dev/null
+++ b/comfy/ldm/mmaudio/vae/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
+
+    def sample(self, generator=None): # optional torch.Generator, passed through by VAE.forward
+        x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])
+
+    def nll(self, sample, dims=[1,2,3]):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)
+
+    def mode(self):
+        return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two Gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
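+    # Closed form computed below: KL(N(m1, v1) || N(m2, v2))
+    #   = 0.5 * (log(v2 / v1) + v1 / v2 + (m1 - m2)^2 / v2 - 1)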
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py new file mode 100644 index 000000000000..62f24606c111 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/vae.py @@ -0,0 +1,358 @@ +import logging +from typing import Optional + +import torch +import torch.nn as nn + +from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, + Upsample1D, nonlinearity) +from .distributions import DiagonalGaussianDistribution + +import comfy.ops +ops = comfy.ops.disable_weight_init + +log = logging.getLogger() + +DATA_MEAN_80D = [ + -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927, + -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728, + -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131, + -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280, + -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643, + -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436, + -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282, + -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673 +] + +DATA_STD_80D = [ + 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263, + 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194, + 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043, + 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973, + 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939, + 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604, + 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070 +] + +DATA_MEAN_128D = [ + -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597, + -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033, + -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157, + -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782, + -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647, + -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795, + -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121, + -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960, + -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712, + -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120, + -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663, + -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628, + -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861 +] + +DATA_STD_128D = [ + 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659, + 2.1631, 
2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557, + 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182, + 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991, + 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900, + 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817, + 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609, + 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812, + 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451, + 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877, + 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164 +] + + +class VAE(nn.Module): + + def __init__( + self, + *, + data_dim: int, + embed_dim: int, + hidden_dim: int, + ): + super().__init__() + + if data_dim == 80: + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32)) + elif data_dim == 128: + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32)) + + self.data_mean = self.data_mean.view(1, -1, 1) + self.data_std = self.data_std.view(1, -1, 1) + + self.encoder = Encoder1D( + dim=hidden_dim, + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_layers=[3], + down_layers=[0], + in_dim=data_dim, + embed_dim=embed_dim, + ) + self.decoder = Decoder1D( + dim=hidden_dim, + ch_mult=(1, 2, 4), + num_res_blocks=2, + attn_layers=[3], + down_layers=[0], + in_dim=data_dim, + out_dim=data_dim, + embed_dim=embed_dim, + ) + + self.embed_dim = embed_dim + # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1) + # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1) + + self.initialize_weights() + + def initialize_weights(self): + pass + + def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution: + if normalize: + x = self.normalize(x) + moments = self.encoder(x) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor: + dec = self.decoder(z) + if unnormalize: + dec = self.unnormalize(dec) + return dec + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + + def unnormalize(self, x: torch.Tensor) -> torch.Tensor: + return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device) + + def forward( + self, + x: torch.Tensor, + sample_posterior: bool = True, + rng: Optional[torch.Generator] = None, + normalize: bool = True, + unnormalize: bool = True, + ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]: + + posterior = self.encode(x, normalize=normalize) + if sample_posterior: + z = posterior.sample(rng) + else: + z = posterior.mode() + dec = self.decode(z, unnormalize=unnormalize) + return dec, posterior + + def load_weights(self, src_dict) -> None: + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) 
-> torch.device: + return next(self.parameters()).device + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def remove_weight_norm(self): + return self + + +class Encoder1D(nn.Module): + + def __init__(self, + *, + dim: int, + ch_mult: tuple[int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_layers: list[int] = [], + down_layers: list[int] = [], + resamp_with_conv: bool = True, + in_dim: int, + embed_dim: int, + double_z: bool = True, + kernel_size: int = 3, + clip_act: float = 256.0): + super().__init__() + self.dim = dim + self.num_layers = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_dim + self.clip_act = clip_act + self.down_layers = down_layers + self.attn_layers = attn_layers + self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + in_ch_mult = (1, ) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + # downsampling + self.down = nn.ModuleList() + for i_level in range(self.num_layers): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = dim * in_ch_mult[i_level] + block_out = dim * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock1D(in_dim=block_in, + out_dim=block_out, + kernel_size=kernel_size, + use_norm=True)) + block_in = block_out + if i_level in attn_layers: + attn.append(AttnBlock1D(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level in down_layers: + down.downsample = Downsample1D(block_in, resamp_with_conv) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock1D(in_dim=block_in, + out_dim=block_in, + kernel_size=kernel_size, + use_norm=True) + self.mid.attn_1 = AttnBlock1D(block_in) + self.mid.block_2 = ResnetBlock1D(in_dim=block_in, + out_dim=block_in, + kernel_size=kernel_size, + use_norm=True) + + # end + self.conv_out = ops.Conv1d(block_in, + 2 * embed_dim if double_z else embed_dim, + kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + self.learnable_gain = nn.Parameter(torch.zeros([])) + + def forward(self, x): + + # downsampling + h = self.conv_in(x) + for i_level in range(self.num_layers): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + h = h.clamp(-self.clip_act, self.clip_act) + if i_level in self.down_layers: + h = self.down[i_level].downsample(h) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = h.clamp(-self.clip_act, self.clip_act) + + # end + h = nonlinearity(h) + h = self.conv_out(h) * (self.learnable_gain + 1) + return h + + +class Decoder1D(nn.Module): + + def __init__(self, + *, + dim: int, + out_dim: int, + ch_mult: tuple[int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_layers: list[int] = [], + down_layers: list[int] = [], + kernel_size: int = 3, + resamp_with_conv: bool = True, + in_dim: int, + embed_dim: int, + clip_act: float = 256.0): + super().__init__() + self.ch = dim + self.num_layers = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.in_channels = in_dim + self.clip_act = clip_act + self.down_layers = [i + 1 for i in down_layers] # each downlayer add one + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = dim * ch_mult[self.num_layers - 1] + + # z to block_in + self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + + # middle + self.mid = 
nn.Module() + self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True) + self.mid.attn_1 = AttnBlock1D(block_in) + self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_layers)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = dim * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True)) + block_in = block_out + if i_level in attn_layers: + attn.append(AttnBlock1D(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level in self.down_layers: + up.upsample = Upsample1D(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + self.learnable_gain = nn.Parameter(torch.zeros([])) + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = h.clamp(-self.clip_act, self.clip_act) + + # upsampling + for i_level in reversed(range(self.num_layers)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + h = h.clamp(-self.clip_act, self.clip_act) + if i_level in self.down_layers: + h = self.up[i_level].upsample(h) + + h = nonlinearity(h) + h = self.conv_out(h) * (self.learnable_gain + 1) + return h + + +def VAE_16k(**kwargs) -> VAE: + return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs) + + +def VAE_44k(**kwargs) -> VAE: + return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs) + + +def get_my_vae(name: str, **kwargs) -> VAE: + if name == '16k': + return VAE_16k(**kwargs) + if name == '44k': + return VAE_44k(**kwargs) + raise ValueError(f'Unknown model: {name}') + diff --git a/comfy/ldm/mmaudio/vae/vae_modules.py b/comfy/ldm/mmaudio/vae/vae_modules.py new file mode 100644 index 000000000000..3ad05134bc44 --- /dev/null +++ b/comfy/ldm/mmaudio/vae/vae_modules.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import vae_attention +import math +import comfy.ops +ops = comfy.ops.disable_weight_init + +def nonlinearity(x): + # swish + return torch.nn.functional.silu(x) / 0.596 + +def mp_sum(a, b, t=0.5): + return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2) + +def normalize(x, dim=None, eps=1e-4): + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + +class ResnetBlock1D(nn.Module): + + def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True): + super().__init__() + self.in_dim = in_dim + out_dim = in_dim if out_dim is None else out_dim + self.out_dim = out_dim + self.use_conv_shortcut = conv_shortcut + self.use_norm = use_norm + + self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False) + if self.in_dim != self.out_dim: + if self.use_conv_shortcut: + self.conv_shortcut = ops.Conv1d(in_dim, out_dim, 
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
+            else:
+                self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+        # pixel norm
+        if self.use_norm:
+            x = normalize(x, dim=1)
+
+        h = x
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        h = nonlinearity(h)
+        h = self.conv2(h)
+
+        if self.in_dim != self.out_dim:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return mp_sum(x, h, t=0.3)
+
+
+class AttnBlock1D(nn.Module):
+
+    def __init__(self, in_channels, num_heads=1):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.num_heads = num_heads
+        self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
+        self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
+        self.optimized_attention = vae_attention()
+
+    def forward(self, x):
+        h = x
+        y = self.qkv(h)
+        y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
+        q, k, v = normalize(y, dim=1).unbind(2)
+
+        h = self.optimized_attention(q, k, v)
+        h = self.proj_out(h)
+
+        return mp_sum(x, h, t=0.3)
+
+
+class Upsample1D(nn.Module):
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # supports 3D tensors (B, C, T)
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample1D(nn.Module):
+
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # learned 1x1 projections before and after the average pooling
+            self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
+            self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
+
+    def forward(self, x):
+
+        if self.with_conv:
+            x = self.conv1(x)
+
+        x = F.avg_pool1d(x, kernel_size=2, stride=2)
+
+        if self.with_conv:
+            x = self.conv2(x)
+
+        return x
diff --git a/comfy/sd.py b/comfy/sd.py
index f2d95f85a25a..b9c2e995e759
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -18,6 +18,7 @@
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
 import comfy.ldm.hunyuan_video.vae
+import comfy.ldm.mmaudio.vae.autoencoder
 import comfy.pixel_space_convert
 import yaml
 import math
@@ -291,6 +292,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
         self.downscale_index_formula = None
         self.upscale_index_formula = None
         self.extra_1d_channel = None
+        self.crop_input = True
 
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@@ -542,6 +544,25 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
                 self.latent_channels = 3
                 self.latent_dim = 2
                 self.output_channels = 3
+            elif "vocoder.activation_post.downsample.lowpass.filter" in sd: # MMAudio VAE
+                sample_rate = 16000 # only the 16 kHz VAE is detected for now; '44k' is a placeholder
+                if sample_rate == 16000:
+                    mode = '16k'
+                else:
+                    mode = '44k'
+
+                self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
+                self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
+                self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
+                self.latent_channels = 20
+                self.output_channels = 2
+                self.upscale_ratio = 512 * (44100 / sample_rate) # 512 samples per latent frame at 16 kHz, scaled to the 44.1 kHz output rate
+                self.downscale_ratio = 512
* (44100 / sample_rate) + self.latent_dim = 1 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.crop_input = False else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -575,6 +596,9 @@ def throw_exception_if_invalid(self): raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") def vae_encode_crop_pixels(self, pixels): + if not self.crop_input: + return pixels + downscale_ratio = self.spacial_compression_encode() dims = pixels.shape[1:-1]
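
Reviewer note: a minimal sketch (not part of the diff) of how the new audio VAE path fits together, assuming converted MMAudio VAE weights have already been loaded into the module; with random initialization the values are meaningless, but the shapes illustrate the flow, and the batch size and clip length are arbitrary.

import torch
from comfy.ldm.mmaudio.vae.autoencoder import AudioAutoencoder

ae = AudioAutoencoder(mode='16k')
audio = torch.randn(1, 2, 44100)  # (B, C, T): one second of stereo 44.1 kHz audio
latent = ae.encode(audio)         # downmix to mono, resample to 16 kHz, mel, VAE posterior mean: (B, 20, T')
recon = ae.decode(latent)         # VAE decode to mel, BigVGAN vocoding, resample back to 44.1 kHz
print(latent.shape, recon.shape)

The crop_input = False flag added to sd.py exists because this path operates on waveforms rather than images, so vae_encode_crop_pixels must pass the input through unchanged.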