diff --git a/README.md b/README.md index 2aa4e9a..0be0b4b 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,13 @@ # spleeter-pytorch Spleeter implementation in pytorch. +## Requirements + +To install requirements, run `pip install -r requirements.txt` + ## Usage -See [example](./test_estimator.py) for the usage how to use. +See [example](run_estimator.py) for how to use it. ## Note diff --git a/output/out_0.wav b/output/out_0.wav deleted file mode 100644 index 490de92..0000000 Binary files a/output/out_0.wav and /dev/null differ diff --git a/output/out_1.wav b/output/out_1.wav deleted file mode 100644 index 7b27ef6..0000000 Binary files a/output/out_1.wav and /dev/null differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6068feb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +numpy==1.18.5 +tensorflow==2.3.1 +torch==1.7.0 +torchaudio==0.7.0 +soundfile==0.10.3.post1 +librosa==0.8.0 \ No newline at end of file diff --git a/run_estimator.py b/run_estimator.py new file mode 100644 index 0000000..cb60f71 --- /dev/null +++ b/run_estimator.py @@ -0,0 +1,29 @@ +import torchaudio +import soundfile as sf + +from spleeter.estimator import Estimator +import os + + +es = Estimator(2, './checkpoints/2stems/model') + + +def main(original_audio='./audio_example.mp3', out_dir='./output'): + # load wav audio + wav, sr = torchaudio.load(original_audio) + + # normalize audio + wav_torch = wav / (wav.max() + 1e-8) + + wavs = es.separate(wav_torch) + for i in range(len(wavs)): + fname = os.path.join(out_dir, f'out_{i}.wav') + print('Writing:', fname) + new_wav = wavs[i].squeeze() + new_wav = new_wav.permute(1, 0) + new_wav = new_wav.numpy() + sf.write(fname, new_wav, sr) + + +if __name__ == '__main__': + main() diff --git a/spleeter/estimator.py b/spleeter/estimator.py index 7efd802..99657ad 100644 --- a/spleeter/estimator.py +++ b/spleeter/estimator.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F from torch import nn -from 
torchaudio.functional import istft from .unet import UNet from .util import tf2pytorch @@ -88,7 +87,7 @@ def inverse_stft(self, stft): pad = self.win_length // 2 + 1 - stft.size(1) stft = F.pad(stft, (0, 0, 0, 0, 0, pad)) - wav = istft(stft, self.win_length, hop_length=self.hop_length, + wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, window=self.win) return wav.detach() diff --git a/spleeter/util.py b/spleeter/util.py index 36d54bb..8c24f93 100644 --- a/spleeter/util.py +++ b/spleeter/util.py @@ -1,8 +1,6 @@ import numpy as np import tensorflow as tf -from .unet import UNet - def tf2pytorch(checkpoint_path, num_instrumments): tf_vars = {} diff --git a/test_estimator.py b/test_estimator.py deleted file mode 100644 index d28f867..0000000 --- a/test_estimator.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -import torchaudio -from librosa.core import load -from librosa.output import write_wav -import numpy as np - -from spleeter.estimator import Estimator - -es = Estimator(2, './checkpoints/2stems/model') - -# load wav audio -wav, sr = torchaudio.load_wav('./audio_example.mp3') - -# normalize audio -wav_torch = wav / (wav.max() + 1e-8) - -wavs = es.separate(wav_torch) -for i in range(len(wavs)): - fname = 'output/out_{}.wav'.format(i) - print('Writing ',fname) - write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)