Skip to content

Commit

Permalink
Added STFT class to move slightly closer to UVR implementation parity
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed Dec 28, 2023
1 parent f9fc5c0 commit f2902bf
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 31 deletions.
38 changes: 7 additions & 31 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import onnxruntime as ort
from pydub import AudioSegment
from audio_separator.separator import spec_utils
from audio_separator.separator.stft import STFT


class Separator:
Expand Down Expand Up @@ -327,6 +328,7 @@ def initialize_model_settings(self):
self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins - self.dim_f, self.dim_t]).to(self.device)
self.gen_size = self.chunk_size - 2 * self.trim
self.stft = STFT(self.logger, self.n_fft, self.hop, self.dim_f, self.device)
self.logger.debug(
f"Model settings initialized: n_bins={self.n_bins}, trim={self.trim}, chunk_size={self.chunk_size}, gen_size={self.gen_size}"
)
Expand Down Expand Up @@ -407,9 +409,9 @@ def demix_base(self, mix, is_ckpt=False, is_match_mix=False):

return sources

# This function is called by demix_base for each audio chunk.
# It applies a Short-Time Fourier Transform (STFT) to the chunk, processes it through the neural network model,
# and then applies an inverse STFT to convert it back to the time domain.
# This function is called by demix_base for each audio chunk.
# It applies a Short-Time Fourier Transform (STFT) to the chunk, processes it through the neural network model,
# and then applies an inverse STFT to convert it back to the time domain.
# This function is where the model infers the separation of vocals and instrumentals from the mixed audio.
def run_model(self, mix, is_ckpt=False, is_match_mix=False):
self.logger.debug(f"Running model on mix_wave with is_ckpt={is_ckpt}, is_match_mix={is_match_mix}")
Expand All @@ -426,42 +428,16 @@ def run_model(self, mix, is_ckpt=False, is_match_mix=False):
spec_pred = -self.model_run(-spek) * 0.5 + self.model_run(spek) * 0.5 if self.denoise_enabled else self.model_run(spek)

if is_ckpt:
return self.istft(spec_pred).cpu().detach().numpy()
return self.stft.inverse(spec_pred).cpu().detach().numpy()
else:
return (
self.istft(torch.tensor(spec_pred).to(self.device))
self.stft.inverse(torch.tensor(spec_pred).to(self.device))
.to(self.cpu)[:, :, self.trim : -self.trim]
.transpose(0, 1)
.reshape(2, -1)
.numpy()
)

# These functions perform the Short-Time Fourier Transform (stft) and its inverse (istft).
# They are essential for converting the audio between the time domain and the frequency domain,
# which is a crucial aspect of audio processing in neural networks.
def stft(self, x):
initial_shape = x.shape
x = x.reshape([-1, self.chunk_size])
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True, return_complex=True)
x = torch.view_as_real(x)
x = x.permute([0, 3, 1, 2])
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
self.logger.debug(f"STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}")
return x[:, :, : self.dim_f]

def istft(self, x, freq_pad=None):
initial_shape = x.shape
freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
x = torch.cat([x, freq_pad], -2)
x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
x = x.permute([0, 2, 3, 1])
x = x.contiguous()
x = torch.view_as_complex(x)
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
x = x.reshape([-1, 2, self.chunk_size])
self.logger.debug(f"ISTFT applied. Initial shape: {initial_shape} Returning shape: {x.shape}")
return x

# This function handles the initial processing of the audio file. It involves loading the audio file (or array),
# ensuring it's in stereo format, and then segmenting it into manageable chunks based on the specified chunk size and margin.
# This segmentation is crucial for efficient processing of the audio, especially for longer tracks.
Expand Down
58 changes: 58 additions & 0 deletions audio_separator/separator/stft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch


# These functions perform the Short-Time Fourier Transform (stft) and its inverse (istft).
# They are essential for converting the audio between the time domain and the frequency domain,
# which is a crucial aspect of audio processing in neural networks.
class STFT:
def __init__(self, logger, n_fft, hop_length, dim_f, device):
self.logger = logger
self.n_fft = n_fft
self.hop_length = hop_length
self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
self.dim_f = dim_f
self.device = device

def __call__(self, x):
x_is_mps = not x.device.type in ["cuda", "cpu"]
if x_is_mps:
x = x.cpu()

initial_shape = x.shape
window = self.window.to(x.device)
batch_dims = x.shape[:-2]
c, t = x.shape[-2:]
x = x.reshape([-1, t])
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True, return_complex=False)
x = x.permute([0, 3, 1, 2])
x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])

if x_is_mps:
x = x.to(self.device)

self.logger.debug(f"STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}")
return x[..., : self.dim_f, :]

def inverse(self, x):
x_is_mps = not x.device.type in ["cuda", "cpu"]
if x_is_mps:
x = x.cpu()

initial_shape = x.shape
window = self.window.to(x.device)
batch_dims = x.shape[:-3]
c, f, t = x.shape[-3:]
n = self.n_fft // 2 + 1
f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
x = torch.cat([x, f_pad], -2)
x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
x = x.permute([0, 2, 3, 1])
x = x[..., 0] + x[..., 1] * 1.0j
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
x = x.reshape([*batch_dims, 2, -1])

if x_is_mps:
x = x.to(self.device)

self.logger.debug(f"Inverse STFT applied. Initial shape: {initial_shape} Resulting shape: {x.shape}")
return x

0 comments on commit f2902bf

Please sign in to comment.