From f2902bf6861f16e14513be9b159cff4dd0bb730a Mon Sep 17 00:00:00 2001 From: Andrew Beveridge Date: Thu, 28 Dec 2023 00:50:28 -0600 Subject: [PATCH] Added STFT class to move slightly closer to UVR implementation parity --- audio_separator/separator/separator.py | 38 ++++------------- audio_separator/separator/stft.py | 58 ++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 31 deletions(-) create mode 100644 audio_separator/separator/stft.py diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py index 5ba45cc..bb693e9 100644 --- a/audio_separator/separator/separator.py +++ b/audio_separator/separator/separator.py @@ -11,6 +11,7 @@ import onnxruntime as ort from pydub import AudioSegment from audio_separator.separator import spec_utils +from audio_separator.separator.stft import STFT class Separator: @@ -327,6 +328,7 @@ def initialize_model_settings(self): self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device) self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins - self.dim_f, self.dim_t]).to(self.device) self.gen_size = self.chunk_size - 2 * self.trim + self.stft = STFT(self.logger, self.n_fft, self.hop, self.dim_f, self.device) self.logger.debug( f"Model settings initialized: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}" ) @@ -407,9 +409,9 @@ def demix_base(self, mix, is_ckpt=False, is_match_mix=False): return sources - # This function is called by demix_base for each audio chunk. - # It applies a Short-Time Fourier Transform (STFT) to the chunk, processes it through the neural network model, - # and then applies an inverse STFT to convert it back to the time domain. + # This function is called by demix_base for each audio chunk. + # It applies a Short-Time Fourier Transform (STFT) to the chunk, processes it through the neural network model, + # and then applies an inverse STFT to convert it back to the time domain. # This function is where the model infers the separation of vocals and instrumentals from the mixed audio. def run_model(self, mix, is_ckpt=False, is_match_mix=False): self.logger.debug(f"Running model on mix_wave with is_ckpt={is_ckpt}, is_match_mix={is_match_mix}") @@ -426,42 +428,16 @@ def run_model(self, mix, is_ckpt=False, is_match_mix=False): spec_pred = -self.model_run(-spek) * 0.5 + self.model_run(spek) * 0.5 if self.denoise_enabled else self.model_run(spek) if is_ckpt: - return self.istft(spec_pred).cpu().detach().numpy() + return self.stft.inverse(spec_pred).cpu().detach().numpy() else: return ( - self.istft(torch.tensor(spec_pred).to(self.device)) + self.stft.inverse(torch.tensor(spec_pred).to(self.device)) .to(self.cpu)[:, :, self.trim : -self.trim] .transpose(0, 1) .reshape(2, -1) .numpy() ) - # These functions perform the Short-Time Fourier Transform (stft) and its inverse (istft). - # They are essential for converting the audio between the time domain and the frequency domain, - # which is a crucial aspect of audio processing in neural networks. - def stft(self, x): - initial_shape = x.shape - x = x.reshape([-1, self.chunk_size]) - x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True) - x = torch.view_as_real(x) - x = x.permute([0, 3, 1, 2]) - x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t]) - self.logger.debug(f"STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}") - return x[:, :, : self.dim_f] - - def istft(self, x, freq_pad=None): - initial_shape = x.shape - freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad - x = torch.cat([x, freq_pad], -2) - x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t]) - x = x.permute([0, 2, 3, 1]) - x = x.contiguous() - x = torch.view_as_complex(x) - x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True) - x = x.reshape([-1, 2, self.chunk_size]) - self.logger.debug(f"ISTFT applied. Initial shape: {initial_shape} Returning shape: {x.shape}") - return x - # This function handles the initial processing of the audio file. It involves loading the audio file (or array), # ensuring it's in stereo format, and then segmenting it into manageable chunks based on the specified chunk size and margin. # This segmentation is crucial for efficient processing of the audio, especially for longer tracks. diff --git a/audio_separator/separator/stft.py b/audio_separator/separator/stft.py new file mode 100644 index 0000000..8b03110 --- /dev/null +++ b/audio_separator/separator/stft.py @@ -0,0 +1,58 @@ +import torch + + +# These functions perform the Short-Time Fourier Transform (stft) and its inverse (istft). +# They are essential for converting the audio between the time domain and the frequency domain, +# which is a crucial aspect of audio processing in neural networks. +class STFT: + def __init__(self, logger, n_fft, hop_length, dim_f, device): + self.logger = logger + self.n_fft = n_fft + self.hop_length = hop_length + self.window = torch.hann_window(window_length=self.n_fft, periodic=True) + self.dim_f = dim_f + self.device = device + + def __call__(self, x): + x_is_mps = not x.device.type in ["cuda", "cpu"] + if x_is_mps: + x = x.cpu() + + initial_shape = x.shape + window = self.window.to(x.device) + batch_dims = x.shape[:-2] + c, t = x.shape[-2:] + x = x.reshape([-1, t]) + x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False) + x = x.permute([0, 3, 1, 2]) + x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) + + if x_is_mps: + x = x.to(self.device) + + self.logger.debug(f"STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}") + return x[..., : self.dim_f, :] + + def inverse(self, x): + x_is_mps = not x.device.type in ["cuda", "cpu"] + if x_is_mps: + x = x.cpu() + + initial_shape = x.shape + window = self.window.to(x.device) + batch_dims = x.shape[:-3] + c, f, t = x.shape[-3:] + n = self.n_fft // 2 + 1 + f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) + x = torch.cat([x, f_pad], -2) + x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) + x = x.permute([0, 2, 3, 1]) + x = x[..., 0] + x[..., 1] * 1.0j + x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) + x = x.reshape([*batch_dims, 2, -1]) + + if x_is_mps: + x = x.to(self.device) + + self.logger.debug(f"Inverse STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}") + return x