Skip to content

Commit

Permalink
Add support for BS-RoFormer and Mel-Band RoFormer models (#72)
Browse files Browse the repository at this point in the history
* First pass at adding bs and mel-roformer models to MDXC, not yet tested

* Got some things kinda working

* Added working implementation of Mel-Roformer

* Made mel-roformer the default model
  • Loading branch information
beveradb authored May 22, 2024
1 parent 2d83f46 commit 7e2ba4d
Show file tree
Hide file tree
Showing 11 changed files with 1,638 additions and 224 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ usage: audio-separator [-h] [-v] [-d] [-e] [-l] [--log_level LOG_LEVEL] [-m MODE
[--normalization NORMALIZATION] [--single_stem SINGLE_STEM] [--sample_rate SAMPLE_RATE] [--mdx_segment_size MDX_SEGMENT_SIZE] [--mdx_overlap MDX_OVERLAP] [--mdx_batch_size MDX_BATCH_SIZE]
[--mdx_hop_length MDX_HOP_LENGTH] [--mdx_enable_denoise] [--vr_batch_size VR_BATCH_SIZE] [--vr_window_size VR_WINDOW_SIZE] [--vr_aggression VR_AGGRESSION] [--vr_enable_tta]
[--vr_high_end_process] [--vr_enable_post_process] [--vr_post_process_threshold VR_POST_PROCESS_THRESHOLD] [--demucs_segment_size DEMUCS_SEGMENT_SIZE] [--demucs_shifts DEMUCS_SHIFTS]
[--demucs_overlap DEMUCS_OVERLAP] [--demucs_segments_enabled DEMUCS_SEGMENTS_ENABLED] [--mdxc_segment_size MDXC_SEGMENT_SIZE] [--mdxc_use_model_segment_size] [--mdxc_overlap MDXC_OVERLAP]
[--mdxc_batch_size MDXC_BATCH_SIZE] [--mdxc_pitch_shift MDXC_PITCH_SHIFT]
[--demucs_overlap DEMUCS_OVERLAP] [--demucs_segments_enabled DEMUCS_SEGMENTS_ENABLED] [--mdxc_segment_size MDXC_SEGMENT_SIZE] [--mdxc_override_model_segment_size]
[--mdxc_overlap MDXC_OVERLAP] [--mdxc_batch_size MDXC_BATCH_SIZE] [--mdxc_pitch_shift MDXC_PITCH_SHIFT]
[audio_file]

Separate audio file into different stems.
Expand Down Expand Up @@ -194,7 +194,7 @@ Demucs Architecture Parameters:

MDXC Architecture Parameters:
--mdxc_segment_size MDXC_SEGMENT_SIZE larger consumes more resources, but may give better results (default: 256). Example: --mdxc_segment_size=256
--mdxc_use_model_segment_size use model default segment size instead of the value from the config file. Example: --mdxc_use_model_segment_size
--mdxc_override_model_segment_size override model default segment size instead of using the model default value. Example: --mdxc_override_model_segment_size
--mdxc_overlap MDXC_OVERLAP amount of overlap between prediction windows, 2-50. higher is better but slower (default: 8). Example: --mdxc_overlap=8
--mdxc_batch_size MDXC_BATCH_SIZE larger consumes more RAM but may process slightly faster (default: 1). Example: --mdxc_batch_size=4
--mdxc_pitch_shift MDXC_PITCH_SHIFT shift audio pitch by a number of semitones while processing. may improve output for deep/high vocals. (default: 0). Example: --mdxc_pitch_shift=2
Expand Down
268 changes: 189 additions & 79 deletions audio_separator/separator/architectures/mdxc_separator.py

Large diffs are not rendered by default.

26 changes: 22 additions & 4 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,21 @@ def list_supported_model_files(self):
# "MDX23C Model: MDX23C-InstVoc HQ": {
# "MDX23C-8KFFT-InstVoc_HQ.ckpt": "model_2_stem_full_band_8k.yaml"
# }
# }
# },
# "roformer_download_list": {
# "Roformer Model: BS-Roformer-Viperx-1297": {
# "model_bs_roformer_ep_317_sdr_12.9755.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.yaml"
# },
# "Roformer Model: BS-Roformer-Viperx-1296": {
# "model_bs_roformer_ep_368_sdr_12.9628.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.yaml"
# },
# "Roformer Model: BS-Roformer-Viperx-1053": {
# "model_bs_roformer_ep_937_sdr_10.5309.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.yaml"
# },
# "Roformer Model: Mel-Roformer-Viperx-1143": {
# "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.yaml"
# }
# },
# }

# Only show Demucs v4 models as we've only implemented support for v4
Expand All @@ -354,7 +368,7 @@ def list_supported_model_files(self):
"VR": model_downloads_list["vr_download_list"],
"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]},
"Demucs": filtered_demucs_v4,
"MDXC": {**model_downloads_list["mdx23c_download_list"], **model_downloads_list["mdx23c_download_vip_list"]},
"MDXC": {**model_downloads_list["mdx23c_download_list"], **model_downloads_list["mdx23c_download_vip_list"], **model_downloads_list["roformer_download_list"]},
}
return model_files_grouped_by_type

Expand Down Expand Up @@ -461,11 +475,15 @@ def load_model_data_from_yaml(self, yaml_config_filename):
model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
else:
model_data_yaml_filepath = yaml_config_filename

self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")

model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
self.logger.debug(f"Model data loaded from YAML file: {model_data}")

if "roformer" in model_data_yaml_filepath:
model_data["is_roformer"] = True

return model_data

def load_model_data_using_hash(self, model_path):
Expand Down Expand Up @@ -585,7 +603,7 @@ def load_model_data_using_hash(self, model_path):

return model_data

def load_model(self, model_filename="UVR-MDX-NET-Inst_HQ_3.onnx"):
def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
"""
This method instantiates the architecture-specific separation class,
loading the separation model into memory, downloading it first if necessary.
Expand Down
109 changes: 109 additions & 0 deletions audio_separator/separator/uvr_lib_v5/attend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, reduce

# constants

FlashAttentionConfig = namedtuple("FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])

# helpers


def exists(val):
return val is not None


def once(fn):
called = False

@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)

return inner


print_once = once(print)

# main class


class Attend(nn.Module):
def __init__(self, dropout=0.0, flash=False):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse("2.0.0")), "in order to use flash attention, you must be using pytorch 2.0 or above"

# determine efficient attention configs for cuda and cpu

self.cpu_config = FlashAttentionConfig(True, True, True)
self.cuda_config = None

if not torch.cuda.is_available() or not flash:
return

device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

if device_properties.major == 8 and device_properties.minor == 0:
print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
self.cuda_config = FlashAttentionConfig(True, False, False)
else:
self.cuda_config = FlashAttentionConfig(False, True, True)

def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config

# pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale

with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)

return out

def forward(self, q, k, v):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""

q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

scale = q.shape[-1] ** -0.5

if self.flash:
return self.flash_attn(q, k, v)

# similarity

sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

# attention

attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)

# aggregate values

out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

return out
Loading

0 comments on commit 7e2ba4d

Please sign in to comment.