Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vits] Support WavLM Discriminator #215

Merged
merged 1 commit into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/baker/configs/v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 256
"gin_channels": 256,
"use_wd": true,
"slm_model": "exp/slm/wavlm-base-plus",
"slm_sr": 16000,
"slm_hidden": 768,
"slm_nlayers": 13,
"slm_initial_channel": 64
}
}
8 changes: 7 additions & 1 deletion examples/baker/configs/vits2_v1.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
"n_layers_q": 3,
"use_sdp": true,
"use_spectral_norm": false,
"gin_channels": 256
"gin_channels": 256,
"use_wd": true,
"slm_model": "exp/slm/wavlm-base-plus",
"slm_sr": 16000,
"slm_hidden": 768,
"slm_nlayers": 13,
"slm_initial_channel": 64
}
}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ torch
torchvision
tqdm
transformers
huggingface_hub
95 changes: 95 additions & 0 deletions wetts/vits/losses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch
import torchaudio
from transformers import AutoModel


def feature_loss(fmap_r, fmap_g):
Expand Down Expand Up @@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
kl = torch.sum(kl * z_mask)
l = kl / torch.sum(z_mask)
return l


class WavLMLoss(torch.nn.Module):
def __init__(self, model, wd, model_sr, slm_sr=16000):
super(WavLMLoss, self).__init__()
self.wavlm = AutoModel.from_pretrained(model)
self.wd = wd
self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
self.wavlm.eval()
for param in self.wavlm.parameters():
param.requires_grad = False

def forward(self, wav, y_rec):
with torch.no_grad():
wav_16 = self.resample(wav)
wav_embeddings = self.wavlm(
input_values=wav_16, output_hidden_states=True
).hidden_states
y_rec_16 = self.resample(y_rec)
y_rec_embeddings = self.wavlm(
input_values=y_rec_16.squeeze(), output_hidden_states=True
).hidden_states

floss = 0
for er, eg in zip(wav_embeddings, y_rec_embeddings):
floss += torch.mean(torch.abs(er - eg))

return floss.mean()

def generator(self, y_rec):
y_rec_16 = self.resample(y_rec)
y_rec_embeddings = self.wavlm(
input_values=y_rec_16, output_hidden_states=True
).hidden_states
y_rec_embeddings = (
torch.stack(y_rec_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)
y_df_hat_g = self.wd(y_rec_embeddings)
loss_gen = torch.mean((1 - y_df_hat_g) ** 2)

return loss_gen

def discriminator(self, wav, y_rec):
with torch.no_grad():
wav_16 = self.resample(wav)
wav_embeddings = self.wavlm(
input_values=wav_16, output_hidden_states=True
).hidden_states
y_rec_16 = self.resample(y_rec)
y_rec_embeddings = self.wavlm(
input_values=y_rec_16, output_hidden_states=True
).hidden_states

y_embeddings = (
torch.stack(wav_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)
y_rec_embeddings = (
torch.stack(y_rec_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)

y_d_rs = self.wd(y_embeddings)
y_d_gs = self.wd(y_rec_embeddings)

y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs

r_loss = torch.mean((1 - y_df_hat_r) ** 2)
g_loss = torch.mean((y_df_hat_g) ** 2)

loss_disc_f = r_loss + g_loss

return loss_disc_f.mean()

def discriminator_forward(self, wav):
with torch.no_grad():
wav_16 = self.resample(wav)
wav_embeddings = self.wavlm(
input_values=wav_16, output_hidden_states=True
).hidden_states
y_embeddings = (
torch.stack(wav_embeddings, dim=1)
.transpose(-1, -2)
.flatten(start_dim=1, end_dim=2)
)

y_d_rs = self.wd(y_embeddings)

return y_d_rs
49 changes: 49 additions & 0 deletions wetts/vits/model/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,52 @@ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
output_probs.append([output_prob])

return output_probs


class WavLMDiscriminator(nn.Module):
"""docstring for Discriminator."""

def __init__(
self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
):
super(WavLMDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.pre = norm_f(
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
)

self.convs = nn.ModuleList(
[
norm_f(
nn.Conv1d(
initial_channel, initial_channel * 2, kernel_size=5, padding=2
)
),
norm_f(
nn.Conv1d(
initial_channel * 2,
initial_channel * 4,
kernel_size=5,
padding=2,
)
),
norm_f(
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
),
]
)

self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))

def forward(self, x):
x = self.pre(x)

fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
x = torch.flatten(x, 1, -1)

return x
Loading
Loading