Refactor the code in train.py #212

Closed · wants to merge 2 commits
103 changes: 60 additions & 43 deletions wetts/vits/train.py
@@ -1,10 +1,11 @@
import os

import torch
import torch.distributed as dist

from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

@@ -30,30 +31,17 @@

def main():
hps = task.get_hparams()
# Set random seed
torch.manual_seed(hps.train.seed)
global global_step
# Initialize distributed
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
torch.torch.cuda.set_device(local_rank)
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
if rank == 0:
logger = task.get_logger(hps.model_dir)
logger.info(hps)
writer = SummaryWriter(log_dir=hps.model_dir)
writer_eval = SummaryWriter(
log_dir=os.path.join(hps.model_dir, "eval"))

if ("use_mel_posterior_encoder" in hps.model.keys()
and hps.model.use_mel_posterior_encoder):
print("Using mel posterior encoder for VITS2")
posterior_channels = hps.data.n_mel_channels # vits2
hps.data.use_mel_posterior_encoder = True
else:
print("Using lin posterior encoder for VITS1")
posterior_channels = hps.data.filter_length // 2 + 1
hps.data.use_mel_posterior_encoder = False

# Get the dataset and data loader
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
train_sampler = DistributedBucketSampler(
train_dataset,
@@ -85,6 +73,17 @@ def main():
collate_fn=collate_fn,
)

# Get the tts model
if ("use_mel_posterior_encoder" in hps.model.keys()
and hps.model.use_mel_posterior_encoder):
print("Using mel posterior encoder for VITS2")
posterior_channels = hps.data.n_mel_channels # vits2
hps.data.use_mel_posterior_encoder = True
else:
print("Using lin posterior encoder for VITS1")
posterior_channels = hps.data.filter_length // 2 + 1
hps.data.use_mel_posterior_encoder = False

# some of these flags are not being used in the code and directly set in hps
# json file. they are kept here for reference and prototyping.
if ("use_transformer_flows" in hps.model.keys()
@@ -144,7 +143,7 @@ def main():
0.1,
gin_channels=hps.model.gin_channels
if hps.data.n_speakers != 0 else 0,
).cuda(rank)
).cuda(local_rank)
elif duration_discriminator_type == "dur_disc_2":
net_dur_disc = DurationDiscriminatorV2(
hps.model.hidden_channels,
@@ -153,7 +152,7 @@
0.1,
gin_channels=hps.model.gin_channels
if hps.data.n_speakers != 0 else 0,
).cuda(rank)
).cuda(local_rank)
else:
print("NOT using any duration discriminator like VITS1")
net_dur_disc = None
@@ -164,15 +163,29 @@
n_speakers=hps.data.n_speakers,
mas_noise_scale_initial=mas_noise_scale_initial,
noise_scale_delta=noise_scale_delta,
**hps.model).cuda(rank)
**hps.model).cuda(local_rank)
if ("use_mrd_disc" in hps.model.keys()
and hps.model.use_mrd_disc):
print("Using MultiPeriodMultiResolutionDiscriminator")
net_d = MultiPeriodMultiResolutionDiscriminator(
hps.model.use_spectral_norm).cuda(rank)
hps.model.use_spectral_norm).cuda(local_rank)
else:
print("Using MPD")
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)

# Dispatch the model from cpu to gpu
# comment - choihkk
# if we comment out unused parameter like DurationDiscriminator's
# self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
# self.post_transformer we don't have to set find_unused_parameters=True
# but I will not proceed with commenting out for compatibility with the
# latest work for others
net_g = DDP(net_g, device_ids=[local_rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[local_rank], find_unused_parameters=True)
if net_dur_disc:
net_dur_disc = DDP(net_dur_disc, device_ids=[local_rank], find_unused_parameters=True)

# Get the optimizer
optim_g = torch.optim.AdamW(
net_g.parameters(),
hps.train.learning_rate,
@@ -195,17 +208,7 @@ def main():
else:
optim_dur_disc = None

# comment - choihkk
# if we comment out unused parameter like DurationDiscriminator's
# self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
# self.post_transformer we don't have to set find_unused_parameters=True
# but I will not proceed with commenting out for compatibility with the
# latest work for others
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
if net_dur_disc:
net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)

# Load the checkpoint
try:
_, _, _, epoch_str = task.load_checkpoint(
task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
@@ -219,11 +222,14 @@ def main():
net_dur_disc,
optim_dur_disc,
)
global_step = (epoch_str - 1) * len(train_loader)
global_step = int(
task.get_steps(task.latest_checkpoint_path(hps.model_dir, "G_*.pth"))
)
except Exception as e:
epoch_str = 1
global_step = 0

# Get the scheduler
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
@@ -234,12 +240,22 @@ def main():
else:
scheduler_dur_disc = None

# Get the tensorboard summary
writer = None
if rank == 0:
logger = task.get_logger(hps.model_dir)
logger.info(hps)
writer = SummaryWriter(log_dir=hps.model_dir)
writer_eval = SummaryWriter(
log_dir=os.path.join(hps.model_dir, "eval"))

scaler = GradScaler(enabled=hps.train.fp16_run)

for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
rank,
local_rank,
epoch,
hps,
[net_g, net_d, net_dur_disc],
@@ -253,6 +269,7 @@ def main():
else:
train_and_evaluate(
rank,
local_rank,
epoch,
hps,
[net_g, net_d, net_dur_disc],
@@ -269,7 +286,7 @@ def main():
scheduler_dur_disc.step()


def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
def train_and_evaluate(rank, local_rank, epoch, hps, nets, optims, schedulers, scaler,
loaders, logger, writers):
net_g, net_d, net_dur_disc = nets
optim_g, optim_d, optim_dur_disc = optims
@@ -301,14 +318,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
net_g.module.noise_scale_delta * global_step)
net_g.module.current_mas_noise_scale = max(current_mas_noise_scale,
0.0)
x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
rank, non_blocking=True)
x, x_lengths = x.cuda(local_rank, non_blocking=True), x_lengths.cuda(
local_rank, non_blocking=True)
spec, spec_lengths = spec.cuda(
rank, non_blocking=True), spec_lengths.cuda(rank,
non_blocking=True)
y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
rank, non_blocking=True)
speakers = speakers.cuda(rank, non_blocking=True)
local_rank, non_blocking=True), spec_lengths.cuda(local_rank,
non_blocking=True)
y, y_lengths = y.cuda(local_rank, non_blocking=True), y_lengths.cuda(
local_rank, non_blocking=True)
speakers = speakers.cuda(local_rank, non_blocking=True)

with autocast(enabled=hps.train.fp16_run):
(
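
The thread running through the train.py changes is the split between the global rank and the local rank: device placement (.cuda(...), DDP device_ids, tensor transfers) now consistently uses local_rank, while rank is reserved for rank-0-only work such as logging, TensorBoard writers, and checkpoint handling. Below is a minimal standalone sketch of that pattern, assuming a torchrun-style launcher that sets RANK, LOCAL_RANK, and WORLD_SIZE; the Linear model is a placeholder, not the VITS generator.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node
rank = int(os.environ.get("RANK", 0))              # global process index

torch.cuda.set_device(local_rank)  # device placement always uses local_rank
dist.init_process_group("nccl")

# Models, tensors, and DDP wrappers go to the local GPU. The PR also passes
# find_unused_parameters=True because some VITS2 submodules never receive
# gradients; a plain model like this one would not need it.
model = torch.nn.Linear(8, 8).cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

# Only the globally first process logs and writes checkpoints.
if rank == 0:
    print(f"initialized {world_size} processes")

dist.destroy_process_group()
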
6 changes: 6 additions & 0 deletions wetts/vits/utils/task.py
@@ -3,6 +3,7 @@
import json
import logging
import os
import re
from pathlib import Path

import torch
Expand Down Expand Up @@ -259,6 +260,11 @@ def get_logger(model_dir, filename="train.log"):
return logger


def get_steps(model_path):
matches = re.findall(r"\d+", model_path)
return matches[-1] if matches else None


class HParams:

def __init__(self, **kwargs):
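
The new task.get_steps helper recovers the global step directly from the checkpoint filename instead of estimating it as (epoch_str - 1) * len(train_loader). Note that re.findall returns strings, which is why train.py wraps the call in int(). A small usage sketch with a hypothetical checkpoint path:

import re

def get_steps(model_path):
    # Return the last run of digits in the path, e.g. the step count in "G_153000.pth".
    matches = re.findall(r"\d+", model_path)
    return matches[-1] if matches else None

print(get_steps("exp/vits2_base/G_153000.pth"))       # '153000' (a string)
print(int(get_steps("exp/vits2_base/G_153000.pth")))  # 153000, as used in train.py

If the path contains no digits, the helper returns None and int(None) raises, which lands in the except branch in train.py that resets epoch_str to 1 and global_step to 0.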