diff --git a/train.py b/train.py index ac56216..d9251fa 100644 --- a/train.py +++ b/train.py @@ -383,12 +383,15 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade utils.save_checkpoint(net_dur_disc, optim_dur_disc, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step))) - prev_g = os.path.join(hps.model_dir, "G_{}.pth".format(global_step - 2 * hps.train.eval_interval)) - prev_d = os.path.join(hps.model_dir, "D_{}.pth".format(global_step - 2 * hps.train.eval_interval)) + prev_g = os.path.join(hps.model_dir, "G_{}.pth".format(global_step - 3 * hps.train.eval_interval)) + prev_d = os.path.join(hps.model_dir, "D_{}.pth".format(global_step - 3 * hps.train.eval_interval)) + prev_dur = os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step - 3 * hps.train.eval_interval)) if os.path.exists(prev_g): os.remove(prev_g) if os.path.exists(prev_d): os.remove(prev_d) + if os.path.exists(prev_dur): + os.remove(prev_dur) global_step += 1