diff --git a/wetts/vits/model/modules.py b/wetts/vits/model/modules.py index 985a68c..c33318d 100644 --- a/wetts/vits/model/modules.py +++ b/wetts/vits/model/modules.py @@ -1,7 +1,6 @@ import torch from torch import nn -from torch.nn.utils import weight_norm -from torch.nn.utils.parametrize import remove_parametrizations +from torch.nn.utils import weight_norm, remove_weight_norm from utils import commons @@ -89,11 +88,11 @@ def forward(self, x, x_mask, g=None, **kwargs): def remove_weight_norm(self): if self.gin_channels != 0: - remove_parametrizations(self.cond_layer, "weight") + remove_weight_norm(self.cond_layer, "weight") for l in self.in_layers: - remove_parametrizations(l, "weight") + remove_weight_norm(l, "weight") for l in self.res_skip_layers: - remove_parametrizations(l, "weight") + remove_weight_norm(l, "weight") class Flip(nn.Module):