From 97b83e8b014f143aa00f5e5b3b8c003cc54cc93b Mon Sep 17 00:00:00 2001
From: whoami <48873278+lsrami@users.noreply.github.com>
Date: Wed, 22 May 2024 20:07:21 +0800
Subject: [PATCH] [dataloader] Fix text filtering bug and speed up spectrum length calc (#216)

* [dataloader] Fix text filtering bug and speed up spectrum length calculation

* [fix] Fix code style check

---------

Co-authored-by: lsrami
---
 examples/aishell-3/run.sh    |  7 ++++
 requirements.txt             |  1 +
 tools/compute_spec_length.py | 72 ++++++++++++++++++++++++++++++++++++
 wetts/vits/data_utils.py     | 36 +++++++++++-------
 4 files changed, 102 insertions(+), 14 deletions(-)
 create mode 100755 tools/compute_spec_length.py

diff --git a/examples/aishell-3/run.sh b/examples/aishell-3/run.sh
index c884bce..6a092b1 100755
--- a/examples/aishell-3/run.sh
+++ b/examples/aishell-3/run.sh
@@ -40,6 +40,13 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     $dataset_dir/data_aishell3 \
     $data/all.txt
 
+  # Compute spec length (optional, but recommended)
+  python tools/compute_spec_length.py \
+    $data/all.txt \
+    $config \
+    $data/all_spec_length.txt
+  mv $data/all_spec_length.txt $data/all.txt
+
   cat $data/all.txt | awk -F '|' '{print $2}' | \
     sort | uniq | awk '{print $0, NR-1}' > $data/speaker.txt
   echo 'sil 0' > $data/phones.txt
diff --git a/requirements.txt b/requirements.txt
index c41c9c6..0606d6e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,4 @@ torchvision
 tqdm
 transformers
 huggingface_hub
+soundfile
diff --git a/tools/compute_spec_length.py b/tools/compute_spec_length.py
new file mode 100755
index 0000000..539f1c9
--- /dev/null
+++ b/tools/compute_spec_length.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+# author: @lsrami
+
+import os
+import sys
+import json
+from tqdm import tqdm
+import soundfile as sf
+from concurrent.futures import ThreadPoolExecutor
+
+
+def load_filepaths_and_text(filename, split="|"):
+    with open(filename, encoding="utf-8") as f:
+        filepaths_and_text = [line.strip().split(split) for line in f]
+    return filepaths_and_text
+
+
+def process_item(item):
+    audiopath = item[0]
+    src_sampling_rate = sf.info(audiopath).samplerate
+    text = item[2]
+    text = text.strip().split()
+    if min_text_len <= len(text) and len(text) <= max_text_len:
+        length = int(os.path.getsize(audiopath) * sampling_rate /
+                     src_sampling_rate) // (2 * hop_length)
+        item.append(length)
+        return item
+    else:
+        return None
+
+
+def main(in_file, out_file):
+    """
+    Filter text & store spec lengths
+    """
+
+    audiopaths_sid_text = load_filepaths_and_text(in_file, split="|")
+
+    with ThreadPoolExecutor(max_workers=32) as executor:
+        results = list(
+            tqdm(
+                executor.map(process_item, audiopaths_sid_text),
+                total=len(audiopaths_sid_text),
+            )
+        )
+
+    # Filter out None results
+    results = [result for result in results if result is not None]
+
+    with open(out_file, "w", encoding="utf-8") as f:
+        for item in results:
+            f.write("|".join([str(i) for i in item]) + "\n")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 4:
+        print(f"Usage: {sys.argv[0]} <in_file> <config_file> <out_file>")
+        sys.exit(1)
+    in_file, config_file, out_file = sys.argv[1:4]
+
+    with open(config_file, "r", encoding="utf8") as f:
+        data = f.read()
+    config = json.loads(data)
+    hparams = config["data"]
+
+    min_text_len = hparams.get("min_text_len", 1)
+    max_text_len = hparams.get("max_text_len", 190)
+    sampling_rate = hparams.get("sampling_rate", 22050)
+    hop_length = hparams.get("hop_length", 256)
+    print(min_text_len, max_text_len, sampling_rate, hop_length)
+
+    main(in_file, out_file)
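Note on the length estimate in process_item: it never decodes any audio. It treats the WAV file size as a proxy for the sample count (2 bytes per sample, i.e. it assumes 16-bit PCM and ignores the header), rescales to the training sampling rate, and divides by the hop length. A minimal sketch of the same arithmetic, using the default hparams above and a hypothetical 10-second 44.1 kHz mono source file:

    # Rough spec-length estimate, assuming 16-bit PCM mono (2 bytes/sample).
    src_sampling_rate = 44100                 # hypothetical source wav rate
    file_size = 10 * src_sampling_rate * 2    # ~10 s of 16-bit mono audio, header ignored
    sampling_rate, hop_length = 22050, 256    # defaults read from the config
    num_frames = int(file_size * sampling_rate / src_sampling_rate) // (2 * hop_length)
    print(num_frames)                         # 861, i.e. roughly 10 s * 22050 / 256 frames
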
diff --git a/wetts/vits/data_utils.py b/wetts/vits/data_utils.py
index 4886a82..4ccdcaf 100644
--- a/wetts/vits/data_utils.py
+++ b/wetts/vits/data_utils.py
@@ -4,6 +4,8 @@
 import torch
 import torchaudio
 import torch.utils.data
+from tqdm import tqdm
+import soundfile as sf
 
 from utils.mel_processing import spectrogram_torch, mel_spectrogram_torch
 from utils.task import load_filepaths_and_text
@@ -60,20 +62,26 @@ def _filter(self):
         """
         Filter text & store spec lengths
         """
-        audiopaths_sid_text_new = []
-        lengths = []
-        for item in self.audiopaths_sid_text:
-            audiopath = item[0]
-            src_sampling_rate = torchaudio.info(audiopath).sample_rate
-            # filename|speaker|text
-            text = item[2]
-            if self.min_text_len <= len(text) and len(
-                    text) <= self.max_text_len:
-                audiopaths_sid_text_new.append(item)
-                lengths.append(
-                    int(
-                        os.path.getsize(audiopath) * self.sampling_rate /
-                        src_sampling_rate) // (2 * self.hop_length))
+        if len(self.audiopaths_sid_text[0]) > 3:
+            # spec length is provided
+            audiopaths_sid_text_new = [item[:3] for item in self.audiopaths_sid_text]
+            lengths = [int(item[3]) for item in self.audiopaths_sid_text]
+        else:
+            audiopaths_sid_text_new = []
+            lengths = []
+            for item in tqdm(self.audiopaths_sid_text, desc="Filtering data"):
+                audiopath = item[0]
+                src_sampling_rate = sf.info(audiopath).samplerate
+                # filename|speaker|text
+                text = item[2]
+                text = text.strip().split()
+                if self.min_text_len <= len(text) and len(
+                        text) <= self.max_text_len:
+                    audiopaths_sid_text_new.append(item)
+                    lengths.append(
+                        int(
+                            os.path.getsize(audiopath) * self.sampling_rate /
+                            src_sampling_rate) // (2 * self.hop_length))
         self.audiopaths_sid_text = audiopaths_sid_text_new
         self.lengths = lengths
 
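After the run.sh change above, each line of all.txt carries a fourth |-separated field holding the precomputed spectrogram length, which is the branch the new _filter takes, so the startup pass becomes a simple parse instead of probing every wav with sf.info() and os.path.getsize(). A small sketch of that consumption path; the sample line is invented for illustration:

    # Invented metadata line in the format wav_path|speaker|phones|spec_length.
    line = "wavs/SSB0005/SSB00050001.wav|SSB0005|sil n i2 h ao3 sil|245"
    item = line.strip().split("|")
    has_length = len(item) > 3                       # same check as the new _filter branch
    audiopath, sid, text = item[:3]
    spec_len = int(item[3]) if has_length else None  # no audio probing needed
    print(audiopath, sid, spec_len)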