Merge pull request #9 from skirdey/1.1

skirdey authored Jan 17, 2025
2 parents 6fccf6e + 11191fe commit 834a86c
Showing 9 changed files with 146 additions and 63 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
*.pt
*.pth
test_output.wav

# Byte-compiled / optimized / DLL files
28 changes: 16 additions & 12 deletions README.md
@@ -6,17 +6,18 @@ Demo of audio restorations: [VoiceRestore](https://sparkling-rabanadas-3082be.ne

Credits: This repository is based on the [E2-TTS implementation by Lucidrains](https://github.com/lucidrains/e2-tts-pytorch)


#### Super easy usage with Transformers 🤗 by [@jadechoghari](https://github.com/jadechoghari) (Hugging Face)
<a href="https://huggingface.co/jadechoghari/VoiceRestore">
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20VoiceRestore-blue" alt="VoiceRestore" height="25">
</a>

#### Build it locally with Gradio using this [repo](https://github.com/jadechoghari/VoiceRestore-demo).

#### Try the Model here:
<a href="https://huggingface.co/spaces/jadechoghari/VoiceRestore">
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20VoiceRestore-orange" alt="VoiceRestore" height="25">
</a>
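
A minimal sketch of the Transformers-based usage, assuming the Hugging Face checkpoint exposes its custom pipeline through `trust_remote_code` (the linked model card and the Usage section below are authoritative):

```python
# Sketch only: load the community checkpoint with custom code enabled.
# Assumes the repo ships a callable model that takes (input_path, output_path);
# check https://huggingface.co/jadechoghari/VoiceRestore for the exact interface.
from transformers import AutoModel

model = AutoModel.from_pretrained("jadechoghari/VoiceRestore", trust_remote_code=True)
model("test_input.wav", "test_output.wav")  # writes the restored audio to test_output.wav
```
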
## Latest Releases

* **01/16/2025** - [Version 1.1 of the checkpoint that improves restoration.](https://drive.google.com/drive/folders/1uBJNp4mrPJQY9WEaiTI9u09IsRg1lAPR?usp=sharing)
* **09/07/2024** - Version 0.1 of the model inference and checkpoint.

## Example
### Degraded Input:
@@ -73,15 +74,15 @@ https://github.com/user-attachments/assets/fdbbb988-9bd2-4750-bddd-32bd5153d254

4. Run a test restoration:
```bash
python inference_short.py --checkpoint ./checkpoints/voice-restore-20d-16h-optim.pt --input test_input.wav --output test_output.wav --steps 32 --cfg_strength 0.5
python inference_short.py --checkpoint ./checkpoints/voicerestore-1.1.pth --input test_input.wav --output test_output.wav --steps 32 --cfg_strength 0.5
```
This will process `test_input.wav` and save the result as `test_output.wav`.

5. Run a long-form restoration; it uses window chunking (a sketch of the chunking idea appears after this step):
```bash
python inference_long.py --checkpoint ./checkpoints/voice-restore-20d-16h-optim.pt --input test_input_long.wav --output test_output_long.wav --steps 32 --cfg_strength 0.5 --window_size_sec 10.0 --overlap 0.25
python inference_long.py --checkpoint ./checkpoints/voicerestore-1.1.pth --input long_audio_file.mp3 --output test_output_long.wav --steps 8 --cfg_strength 0.5 --window_size_sec 10.0 --overlap 0.3
```
This will process `test_input_long.wav` (you need to provide it) and save the result as `test_output_long.wav`.
This will save the result as `test_output_long.wav`.
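
For intuition, here is a rough sketch of the overlapping-window chunking that the long-form script relies on. This is illustrative only: the cross-fade weighting and the `restore_fn` callback are assumptions rather than the exact logic of `inference_long.py`; only `window_size_sec` and `overlap` mirror the CLI flags above.

```python
# Illustrative overlap-window chunking with linear cross-fades (not the exact
# implementation used by inference_long.py).
import numpy as np

def chunk_and_restore(wav, sr, restore_fn, window_size_sec=10.0, overlap=0.3):
    win = int(window_size_sec * sr)        # window length in samples
    hop = int(win * (1.0 - overlap))       # hop between consecutive window starts
    out = np.zeros_like(wav, dtype=np.float32)
    weight = np.zeros_like(wav, dtype=np.float32)

    for start in range(0, len(wav), hop):
        end = min(start + win, len(wav))
        restored = restore_fn(wav[start:end])      # e.g. one model pass per window
        fade = np.ones(end - start, dtype=np.float32)
        ramp = min(int(win * overlap), end - start)
        if start > 0 and ramp > 0:                 # fade in, except for the first window
            fade[:ramp] *= np.linspace(0.0, 1.0, ramp)
        if end < len(wav) and ramp > 0:            # fade out, except for the last window
            fade[-ramp:] *= np.linspace(1.0, 0.0, ramp)
        out[start:end] += restored[: end - start] * fade
        weight[start:end] += fade
        if end == len(wav):
            break

    return out / np.maximum(weight, 1e-8)          # normalize where windows overlap
```

Averaging the overlapping, cross-faded windows keeps window boundaries from producing audible seams.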

## Usage

@@ -130,11 +131,14 @@ model("test_input.wav", "test_output.wav")
If you use VoiceRestore in your research, please cite our paper:

```
@misc{kirdey2024voicerestore,
  title={VoiceRestore: Flow-Matching Transformers for Speech Recording Quality Restoration},
  author={Kirdey, Stanislav},
  howpublished={\url{https://github.com/skirdey/voicerestore}},
  year={2024}
}
@misc{kirdey2025voicerestoreflowmatchingtransformersspeech,
  title={VoiceRestore: Flow-Matching Transformers for Speech Recording Quality Restoration},
  author={Stanislav Kirdey},
  year={2025},
  eprint={2501.00794},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  url={https://arxiv.org/abs/2501.00794},
}
```

Empty file.
2 changes: 1 addition & 1 deletion inference_long.py
@@ -112,7 +112,7 @@ def restore_audio(model, input_path, output_path, steps=16, cfg_strength=0.5, wi
start_time = time.time()

initial_gpu_memory = measure_gpu_memory(device)
wav, sr = librosa.load(input_path, sr=model.bigvgan_model.h.sampling_rate, mono=True)
wav, sr = librosa.load(input_path, mono=True)
wav = torch.FloatTensor(wav).unsqueeze(0) # Shape: [1, num_samples]

window_size_samples = int(window_size_sec * sr)
5 changes: 1 addition & 4 deletions inference_short.py
@@ -38,10 +38,7 @@ def load_model(save_path):


def restore_audio(model, input_path, output_path, steps=16, cfg_strength=0.5):
audio, sr = torchaudio.load(input_path)

if sr != model.target_sample_rate:
audio = torchaudio.functional.resample(audio, sr, model.target_sample_rate)
audio = torchaudio.load(input_path)

audio = audio.mean(dim=0, keepdim=True) if audio.dim() > 1 else audio # Convert to mono if stereo

Binary file removed long_form_sample.ogg
16 changes: 8 additions & 8 deletions model.py
@@ -10,14 +10,14 @@ def __init__(self, target_sample_rate=24000, device=None, bigvgan_model=None):
# Initialize VoiceRestore
self.voice_restore = VoiceRestore(
sigma=0.0,
transformer={
'dim': 768,
'depth': 20,
'heads': 16,
'dim_head': 64,
'skip_connect_type': 'concat',
'max_seq_len': 2000,
},
transformer=dict(
dim=768,
depth=20,
heads=16,
dim_head=64,
skip_connect_type="concat",
max_seq_len=1024,
),
num_channels=100
)

Binary file removed test_input.wav
157 changes: 119 additions & 38 deletions voice_restore.py
@@ -11,7 +11,6 @@
d - dimension
"""


from __future__ import annotations
from typing import Dict, Any, Optional
from functools import partial
@@ -43,6 +42,7 @@ def __init__(self, dim: int, dim_condition: Optional[int] = None, init_bias_valu
nn.init.constant_(self.to_gamma.bias, init_bias_value)

def forward(self, x: torch.Tensor, *, condition: torch.Tensor) -> torch.Tensor:
# condition shape: (b, d) or (b, 1, d)
if condition.ndim == 2:
condition = rearrange(condition, 'b d -> b 1 d')
gamma = self.to_gamma(condition).sigmoid()
@@ -113,8 +113,14 @@ def __init__(
skip_proj = Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None

self.layers.append(ModuleList([
gateloop, skip_proj, attn_norm, attn, attn_adaln_zero,
ff_norm, ff, ff_adaln_zero
gateloop,
skip_proj,
attn_norm,
attn,
attn_adaln_zero,
ff_norm,
ff,
ff_adaln_zero
]))

self.final_norm = RMSNorm(dim)
@@ -123,56 +129,94 @@ def forward(
self,
x: Float['b n d'],
times: Optional[Float['b'] | Float['']] = None,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
batch, seq_len, device = *x.shape[:2], x.device

assert not (exists(times) ^ self.cond_on_time), '`times` must be passed in if `cond_on_time` is set to `True` and vice versa'
"""
Args:
x: (b, n, d)
times: (b,) or scalar if cond_on_time is True
mask: (b, n) boolean or 0/1 mask for attention
"""
b, n, device = x.shape[0], x.shape[1], x.device
assert not (exists(times) ^ self.cond_on_time), (
"`times` must be passed in if `cond_on_time` is set to `True`, and vice versa."
)

norm_kwargs = {}

# Absolute positional embedding
if exists(self.abs_pos_emb):
# assert seq_len <= self.max_seq_len, f'{seq_len} exceeds the set `max_seq_len` ({self.max_seq_len}) on Transformer'
seq = torch.arange(seq_len, device=device)
x = x + self.abs_pos_emb(seq)
# you may want to guard for n <= self.max_seq_len
pos_indices = torch.arange(n, device=device)
x = x + self.abs_pos_emb(pos_indices)

# Time conditioning
if exists(times):
if times.ndim == 0:
times = repeat(times, ' -> b', b=batch)
times = self.time_cond_mlp(times)
times = repeat(times, ' -> b', b=b)
times = self.time_cond_mlp(times) # (b, d) or (b, 1, d)
norm_kwargs['condition'] = times

registers = repeat(self.registers, 'r d -> b r d', b=batch)
# Concat registers to the sequence
registers = repeat(self.registers, 'r d -> b r d', b=b)
x, registers_packed_shape = pack((registers, x), 'b * d')

# Build the rotary embeddings for this sequence length
rotary_pos_emb = self.rotary_emb.forward_from_seq_len(x.shape[-2])

# Similarly extend the mask to registers + real tokens if given
if mask is not None:
# mask: (b, n), we have total length = r + n after packing
# The first `r` (num_registers) are never "masked" out
# so we build a new mask of shape (b, r + n)
reg_mask = x.new_ones(b, self.num_registers, dtype=mask.dtype)
mask = torch.cat([reg_mask, mask], dim=1) # (b, r + n)

# We'll keep track of skip connections
skips = []

for ind, (
gateloop, maybe_skip_proj, attn_norm, attn, maybe_attn_adaln_zero,
ff_norm, ff, maybe_ff_adaln_zero
gateloop,
maybe_skip_proj,
attn_norm,
attn,
maybe_attn_adaln_zero,
ff_norm,
ff,
maybe_ff_adaln_zero
) in enumerate(self.layers):
layer = ind + 1
is_first_half = layer <= (self.depth // 2)

layer_idx = ind + 1
is_first_half = (layer_idx <= (self.depth // 2))

# If in the first half, push x onto skip stack
if is_first_half:
skips.append(x)
else:
# Retrieve matching skip
skip = skips.pop()
if self.skip_connect_type == 'concat':
x = torch.cat((x, skip), dim=-1)
x = maybe_skip_proj(x)

# GateLoop
x = gateloop(x) + x

attn_out = attn(attn_norm(x, **norm_kwargs), rotary_pos_emb=rotary_pos_emb)
# Attention
attn_out = attn(
attn_norm(x, **norm_kwargs),
rotary_pos_emb=rotary_pos_emb,
mask=mask # pass mask here
)
x = x + maybe_attn_adaln_zero(attn_out, **norm_kwargs)

# Feed-forward
ff_out = ff(ff_norm(x, **norm_kwargs))
x = x + maybe_ff_adaln_zero(ff_out, **norm_kwargs)

assert len(skips) == 0
assert len(skips) == 0, "Skip-connection stack not empty at the end!"

# Unpack back
_, x = unpack(x, registers_packed_shape, 'b * d')

return self.final_norm(x)
@@ -189,49 +233,86 @@ def __init__(
self.sigma = sigma
self.num_channels = num_channels

# For simplicity, always cond_on_time = True in transformer config
self.transformer = Transformer(**transformer, cond_on_time=True)

# Default ODE integration settings
self.odeint_kwargs = odeint_kwargs or {'atol': 1e-5, 'rtol': 1e-5, 'method': 'midpoint'}

self.proj_in = nn.Linear(num_channels, self.transformer.dim)
self.cond_proj = nn.Linear(num_channels, self.transformer.dim)
self.to_pred = nn.Linear(self.transformer.dim, num_channels)

def transformer_with_pred_head(self, x: torch.Tensor, times: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.proj_in(x)
def transformer_with_pred_head(
self,
x: torch.Tensor,
times: torch.Tensor,
cond: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Projects input x, optionally adds condition, feeds through Transformer,
and returns final prediction in dimension of num_channels.
"""
x = self.proj_in(x) # (b, n, dim)
if cond is not None:
cond_proj = self.cond_proj(cond)
x = x + cond_proj
attended = self.transformer(x, times=times)
return self.to_pred(attended)
x = x + cond_proj # broadcast if shapes match suitably

attended = self.transformer(x, times=times, mask=mask)
return self.to_pred(attended) # (b, n, num_channels)

def cfg_transformer_with_pred_head(
self,
*args,
cond=None,
mask=None,
cfg_strength: float = 0.5,
**kwargs,
x: torch.Tensor,
times: torch.Tensor,
cond: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
cfg_strength: float = 0.5
):
pred = self.transformer_with_pred_head(*args, **kwargs, cond=cond)
"""
Classifier-free guidance variant of the transformer forward.
If cfg_strength <= 0, effectively the normal forward pass.
"""
pred = self.transformer_with_pred_head(x, times=times, cond=cond, mask=mask)

if cfg_strength < 1e-5:
return pred * mask.unsqueeze(-1) if mask is not None else pred

null_pred = self.transformer_with_pred_head(*args, **kwargs, cond=None)

result = pred + (pred - null_pred) * cfg_strength
return result * mask.unsqueeze(-1) if mask is not None else result
# no guidance
return pred if mask is None else pred * mask.unsqueeze(-1)

# null (no condition) pass
null_pred = self.transformer_with_pred_head(x, times=times, cond=None, mask=mask)

guided = pred + (pred - null_pred) * cfg_strength
return guided if mask is None else guided * mask.unsqueeze(-1)

@torch.no_grad()
def sample(self, processed: torch.Tensor, steps: int = 32, cfg_strength: float = 0.5) -> torch.Tensor:
def sample(
self,
processed: torch.Tensor,
steps: int = 32,
cfg_strength: float = 0.5,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Example sampling routine using an ODE solver. For many text/audio models,
you'll use something else, but this shows how you might incorporate the
same forward pass + mask into an ODE integration.
"""
self.eval()
# times from 0 -> 1
times = torch.linspace(0, 1, steps, device=processed.device)

def ode_fn(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
return self.cfg_transformer_with_pred_head(x, times=t, cond=processed, cfg_strength=cfg_strength)
def ode_fn(t: torch.Tensor, x: torch.Tensor):
return self.cfg_transformer_with_pred_head(
x,
times=t,
cond=processed,
mask=mask,
cfg_strength=cfg_strength
)

# Starting from noise
y0 = torch.randn_like(processed)
trajectory = odeint(ode_fn, y0, times, **self.odeint_kwargs)
restored = trajectory[-1]
