From 9a9c3b75b5b3359701844a91a9fae6d2979866cd Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Wed, 17 Jan 2024 18:28:28 +0800 Subject: [PATCH 1/4] Funasr1.0 (#1261) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --- .../paraformer/finetune.sh | 4 +- .../seaco_paraformer/demo.py | 2 +- funasr/auto/auto_model.py | 34 ++- funasr/bin/train.py | 17 +- funasr/datasets/audio_datasets/index_ds.py | 6 +- funasr/datasets/audio_datasets/samplers.py | 3 +- funasr/models/paraformer/model.py | 1 + funasr/models/paraformer/template.yaml | 1 + funasr/train_utils/average_nbest_models.py | 266 +++++++++++------- funasr/train_utils/trainer.py | 107 +++++-- 10 files changed, 296 insertions(+), 145 deletions(-) diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh index 93cce73f3..7d8987602 100644 --- a/examples/industrial_data_pretraining/paraformer/finetune.sh +++ b/examples/industrial_data_pretraining/paraformer/finetune.sh @@ -9,9 +9,11 @@ python funasr/bin/train.py \ +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \ +model_revision="v2.0.2" \ -+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \ ++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \ ++valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \ ++dataset_conf.batch_size=2 \ ++dataset_conf.batch_type="example" \ +++train_conf.max_epoch=2 \ +output_dir="outputs/debug/ckpt/funasr2/exp2" \ +device="cpu" \ +debug="true" \ No newline at end of file diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 5f17252f9..19ad1c9c5 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -15,6 +15,6 @@ spk_model_revision="v2.0.2", ) -res = model.generate(input=f"{model.model_path}/example/asr_example.wav", +res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", hotword='达摩院 魔搭') print(res) \ No newline at end of file diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 740614c74..bedc17d16 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -221,7 +221,8 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** speed_stats = {} asr_result_list = [] num_samples = len(data_list) - pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) + disable_pbar = kwargs.get("disable_pbar", False) + pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None time_speech_total = 0.0 time_escape_total = 0.0 for beg_idx in range(0, num_samples, batch_size): @@ -239,8 +240,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** time2 = time.perf_counter() asr_result_list.extend(results) - pbar.update(1) - + # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item() batch_data_time = meta_data.get("batch_data_time", -1) time_escape = time2 - time1 @@ -252,12 +252,15 @@ def inference(self, input, input_len=None, 
model=None, kwargs=None, key=None, ** description = ( f"{speed_stats}, " ) - pbar.set_description(description) + if pbar: + pbar.update(1) + pbar.set_description(description) time_speech_total += batch_data_time time_escape_total += time_escape - - pbar.update(1) - pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") + + if pbar: + pbar.update(1) + pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") torch.cuda.empty_cache() return asr_result_list @@ -309,8 +312,11 @@ def inference_with_vad(self, input, input_len=None, **cfg): time_speech_total_per_sample = speech_lengths/16000 time_speech_total_all_samples += time_speech_total_per_sample + pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True) + all_segments = [] for j, _ in enumerate(range(0, n)): + pbar_sample.update(1) batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0]) if j < n - 1 and ( batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and ( @@ -319,13 +325,14 @@ def inference_with_vad(self, input, input_len=None, **cfg): batch_size_ms_cum = 0 end_idx = j + 1 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx]) - results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg) + results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg) if self.spk_model is not None: - + + # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] for _b in range(len(speech_j)): - vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \ - sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \ + vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, + sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, speech_j[_b]]] segments = sv_chunk(vad_segments) all_segments.extend(segments) @@ -338,12 +345,13 @@ def inference_with_vad(self, input, input_len=None, **cfg): results_sorted.extend(results) - pbar_total.update(1) + end_asr_total = time.time() time_escape_total_per_sample = end_asr_total - beg_asr_total - pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, " f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}") + restored_data = [0] * n for j in range(n): diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 7ae687ef9..0334006c5 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -141,30 +141,37 @@ def main(**kwargs): scheduler_class = scheduler_classes.get(scheduler) scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) - # import pdb; - # pdb.set_trace() + # dataset dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) + dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, + **kwargs.get("dataset_conf")) # dataloader batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") - batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) + batch_sampler_val = None if batch_sampler is not None: + batch_sampler_class = 
tables.batch_sampler_classes.get(batch_sampler) batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) + batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf")) dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, batch_sampler=batch_sampler, num_workers=kwargs.get("dataset_conf").get("num_workers", 4), pin_memory=True) - + dataloader_val = torch.utils.data.DataLoader(dataset_val, + collate_fn=dataset_val.collator, + batch_sampler=batch_sampler_val, + num_workers=kwargs.get("dataset_conf").get("num_workers", 4), + pin_memory=True) trainer = Trainer( model=model, optim=optim, scheduler=scheduler, dataloader_train=dataloader_tr, - dataloader_val=None, + dataloader_val=dataloader_val, local_rank=local_rank, use_ddp=use_ddp, use_fsdp=use_fsdp, diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py index 8e5b05cf3..c94d20961 100644 --- a/funasr/datasets/audio_datasets/index_ds.py +++ b/funasr/datasets/audio_datasets/index_ds.py @@ -54,7 +54,11 @@ def __len__(self): return len(self.contents) def __getitem__(self, index): - return self.contents[index] + try: + data = self.contents[index] + except: + print(index) + return data def get_source_len(self, data_dict): return data_dict["source_len"] diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index 4af35e9cd..e170c681b 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -13,6 +13,7 @@ def __init__(self, dataset, buffer_size: int = 30, drop_last: bool = False, shuffle: bool = True, + is_training: bool = True, **kwargs): self.drop_last = drop_last @@ -24,7 +25,7 @@ def __init__(self, dataset, self.buffer_size = buffer_size self.max_token_length = kwargs.get("max_token_length", 5000) self.shuffle_idx = np.arange(self.total_samples) - self.shuffle = shuffle + self.shuffle = shuffle and is_training def __len__(self): return self.total_samples diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index f92441d39..9f3c3f3b6 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -164,6 +164,7 @@ def __init__( self.use_1st_decoder_loss = use_1st_decoder_loss self.length_normalized_loss = length_normalized_loss self.beam_search = None + self.error_calculator = None def forward( self, diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml index 94eebf7bd..3972caaa1 100644 --- a/funasr/models/paraformer/template.yaml +++ b/funasr/models/paraformer/template.yaml @@ -95,6 +95,7 @@ train_conf: - acc - max keep_nbest_models: 10 + avg_nbest_model: 5 log_interval: 50 optim: adam diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py index 96e138428..f117804f3 100644 --- a/funasr/train_utils/average_nbest_models.py +++ b/funasr/train_utils/average_nbest_models.py @@ -9,117 +9,173 @@ import torch from typing import Collection +import os +import torch +import re +from collections import OrderedDict +from functools import cmp_to_key -from funasr.train.reporter import Reporter +# @torch.no_grad() +# def average_nbest_models( +# output_dir: Path, +# best_model_criterion: Sequence[Sequence[str]], +# nbest: Union[Collection[int], int], +# suffix: Optional[str] = None, +# oss_bucket=None, +# pai_output_dir=None, +# ) -> None: +# """Generate averaged model from n-best models +# 
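+#     (Note: this reporter-based implementation is kept here commented out for
+#     reference only; the active code path is average_checkpoints() at the bottom
+#     of this file, which averages the last N "model.pt.e*" checkpoints found in
+#     output_dir. Minimal usage sketch, with a hypothetical output directory:
+#
+#         from funasr.train_utils.average_nbest_models import average_checkpoints
+#         average_checkpoints("outputs/exp1", last_n=5)  # writes model.pt.avg5
+#     )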
+# Args: +# output_dir: The directory contains the model file for each epoch +# reporter: Reporter instance +# best_model_criterion: Give criterions to decide the best model. +# e.g. [("valid", "loss", "min"), ("train", "acc", "max")] +# nbest: Number of best model files to be averaged +# suffix: A suffix added to the averaged model file name +# """ +# if isinstance(nbest, int): +# nbests = [nbest] +# else: +# nbests = list(nbest) +# if len(nbests) == 0: +# warnings.warn("At least 1 nbest values are required") +# nbests = [1] +# if suffix is not None: +# suffix = suffix + "." +# else: +# suffix = "" +# +# # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]] +# nbest_epochs = [ +# (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)]) +# for ph, k, m in best_model_criterion +# if reporter.has(ph, k) +# ] +# +# _loaded = {} +# for ph, cr, epoch_and_values in nbest_epochs: +# _nbests = [i for i in nbests if i <= len(epoch_and_values)] +# if len(_nbests) == 0: +# _nbests = [1] +# +# for n in _nbests: +# if n == 0: +# continue +# elif n == 1: +# # The averaged model is same as the best model +# e, _ = epoch_and_values[0] +# op = output_dir / f"{e}epoch.pb" +# sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb" +# if sym_op.is_symlink() or sym_op.exists(): +# sym_op.unlink() +# sym_op.symlink_to(op.name) +# else: +# op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb" +# logging.info( +# f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}' +# ) +# +# avg = None +# # 2.a. Averaging model +# for e, _ in epoch_and_values[:n]: +# if e not in _loaded: +# if oss_bucket is None: +# _loaded[e] = torch.load( +# output_dir / f"{e}epoch.pb", +# map_location="cpu", +# ) +# else: +# buffer = BytesIO( +# oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read()) +# _loaded[e] = torch.load(buffer) +# states = _loaded[e] +# +# if avg is None: +# avg = states +# else: +# # Accumulated +# for k in avg: +# avg[k] = avg[k] + states[k] +# for k in avg: +# if str(avg[k].dtype).startswith("torch.int"): +# # For int type, not averaged, but only accumulated. +# # e.g. BatchNorm.num_batches_tracked +# # (If there are any cases that requires averaging +# # or the other reducing method, e.g. max/min, for integer type, +# # please report.) +# pass +# else: +# avg[k] = avg[k] / n +# +# # 2.b. Save the ave model and create a symlink +# if oss_bucket is None: +# torch.save(avg, op) +# else: +# buffer = BytesIO() +# torch.save(avg, buffer) +# oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"), +# buffer.getvalue()) +# +# # 3. *.*.ave.pb is a symlink to the max ave model +# if oss_bucket is None: +# op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb" +# sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb" +# if sym_op.is_symlink() or sym_op.exists(): +# sym_op.unlink() +# sym_op.symlink_to(op.name) -@torch.no_grad() -def average_nbest_models( - output_dir: Path, - reporter: Reporter, - best_model_criterion: Sequence[Sequence[str]], - nbest: Union[Collection[int], int], - suffix: Optional[str] = None, - oss_bucket=None, - pai_output_dir=None, -) -> None: - """Generate averaged model from n-best models - Args: - output_dir: The directory contains the model file for each epoch - reporter: Reporter instance - best_model_criterion: Give criterions to decide the best model. - e.g. 
[("valid", "loss", "min"), ("train", "acc", "max")] - nbest: Number of best model files to be averaged - suffix: A suffix added to the averaged model file name +def _get_checkpoint_paths(output_dir: str, last_n: int=5): """ - if isinstance(nbest, int): - nbests = [nbest] - else: - nbests = list(nbest) - if len(nbests) == 0: - warnings.warn("At least 1 nbest values are required") - nbests = [1] - if suffix is not None: - suffix = suffix + "." - else: - suffix = "" - - # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]] - nbest_epochs = [ - (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)]) - for ph, k, m in best_model_criterion - if reporter.has(ph, k) - ] - - _loaded = {} - for ph, cr, epoch_and_values in nbest_epochs: - _nbests = [i for i in nbests if i <= len(epoch_and_values)] - if len(_nbests) == 0: - _nbests = [1] - - for n in _nbests: - if n == 0: - continue - elif n == 1: - # The averaged model is same as the best model - e, _ = epoch_and_values[0] - op = output_dir / f"{e}epoch.pb" - sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb" - if sym_op.is_symlink() or sym_op.exists(): - sym_op.unlink() - sym_op.symlink_to(op.name) - else: - op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb" - logging.info( - f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}' - ) + Get the paths of the last 'last_n' checkpoints by parsing filenames + in the output directory. + """ + # List all files in the output directory + files = os.listdir(output_dir) + # Filter out checkpoint files and extract epoch numbers + checkpoint_files = [f for f in files if f.startswith("model.pt.e")] + # Sort files by epoch number in descending order + checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True) + # Get the last 'last_n' checkpoint paths + checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]] + return checkpoint_paths - avg = None - # 2.a. Averaging model - for e, _ in epoch_and_values[:n]: - if e not in _loaded: - if oss_bucket is None: - _loaded[e] = torch.load( - output_dir / f"{e}epoch.pb", - map_location="cpu", - ) - else: - buffer = BytesIO( - oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read()) - _loaded[e] = torch.load(buffer) - states = _loaded[e] +@torch.no_grad() +def average_checkpoints(output_dir: str, last_n: int=5): + """ + Average the last 'last_n' checkpoints' model state_dicts. + If a tensor is of type torch.int, perform sum instead of average. + """ + checkpoint_paths = _get_checkpoint_paths(output_dir, last_n) + state_dicts = [] - if avg is None: - avg = states - else: - # Accumulated - for k in avg: - avg[k] = avg[k] + states[k] - for k in avg: - if str(avg[k].dtype).startswith("torch.int"): - # For int type, not averaged, but only accumulated. - # e.g. BatchNorm.num_batches_tracked - # (If there are any cases that requires averaging - # or the other reducing method, e.g. max/min, for integer type, - # please report.) - pass - else: - avg[k] = avg[k] / n + # Load state_dicts from checkpoints + for path in checkpoint_paths: + if os.path.isfile(path): + state_dicts.append(torch.load(path, map_location='cpu')['state_dict']) + else: + print(f"Checkpoint file {path} not found.") + continue - # 2.b. 
Save the ave model and create a symlink - if oss_bucket is None: - torch.save(avg, op) - else: - buffer = BytesIO() - torch.save(avg, buffer) - oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"), - buffer.getvalue()) + # Check if we have any state_dicts to average + if not state_dicts: + raise RuntimeError("No checkpoints found for averaging.") - # 3. *.*.ave.pb is a symlink to the max ave model - if oss_bucket is None: - op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb" - sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb" - if sym_op.is_symlink() or sym_op.exists(): - sym_op.unlink() - sym_op.symlink_to(op.name) + # Average or sum weights + avg_state_dict = OrderedDict() + for key in state_dicts[0].keys(): + tensors = [state_dict[key].cpu() for state_dict in state_dicts] + # Check the type of the tensor + if str(tensors[0].dtype).startswith("torch.int"): + # Perform sum for integer tensors + summed_tensor = sum(tensors) + avg_state_dict[key] = summed_tensor + else: + # Perform average for other types of tensors + stacked_tensors = torch.stack(tensors) + avg_state_dict[key] = torch.mean(stacked_tensors, dim=0) + + torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}")) + return avg_state_dict \ No newline at end of file diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index da346c39c..91b30b0a8 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -7,10 +7,11 @@ from contextlib import nullcontext # from torch.utils.tensorboard import SummaryWriter from tensorboardX import SummaryWriter +from pathlib import Path from funasr.train_utils.device_funcs import to_device from funasr.train_utils.recursive_op import recursive_average - +from funasr.train_utils.average_nbest_models import average_checkpoints class Trainer: """ @@ -66,10 +67,9 @@ def __init__(self, model, self.use_ddp = use_ddp self.use_fsdp = use_fsdp self.device = next(model.parameters()).device + self.avg_nbest_model = kwargs.get("avg_nbest_model", 5) self.kwargs = kwargs - if self.resume: - self._resume_checkpoint(self.resume) try: rank = dist.get_rank() @@ -102,9 +102,17 @@ def _save_checkpoint(self, epoch): } # Create output directory if it does not exist os.makedirs(self.output_dir, exist_ok=True) - filename = os.path.join(self.output_dir, f'model.e{epoch}.pb') + filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}') torch.save(state, filename) + print(f'Checkpoint saved to {filename}') + latest = Path(os.path.join(self.output_dir, f'model.pt')) + try: + latest.unlink() + except: + pass + + latest.symlink_to(filename) def _resume_checkpoint(self, resume_path): """ @@ -114,29 +122,50 @@ def _resume_checkpoint(self, resume_path): Args: resume_path (str): The file path to the checkpoint to resume from. 
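+
+        resume_path is expected to be the training output directory: the
+        checkpoint is read from "<resume_path>/model.pt", the symlink that
+        _save_checkpoint leaves pointing at the most recent "model.pt.ep{epoch}"
+        file.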
""" - if os.path.isfile(resume_path): - checkpoint = torch.load(resume_path) + ckpt = os.path.join(resume_path, "model.pt") + if os.path.isfile(ckpt): + checkpoint = torch.load(ckpt) self.start_epoch = checkpoint['epoch'] + 1 self.model.load_state_dict(checkpoint['state_dict']) self.optim.load_state_dict(checkpoint['optimizer']) self.scheduler.load_state_dict(checkpoint['scheduler']) - print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})") + print(f"Checkpoint loaded successfully from '{ckpt}'") else: - print(f"No checkpoint found at '{resume_path}', starting from scratch") + print(f"No checkpoint found at '{ckpt}', starting from scratch") + + if self.use_ddp or self.use_fsdp: + dist.barrier() def run(self): """ Starts the training process, iterating over epochs, training the model, and saving checkpoints at the end of each epoch. """ + if self.resume: + self._resume_checkpoint(self.output_dir) + for epoch in range(self.start_epoch, self.max_epoch + 1): + self._train_epoch(epoch) - # self._validate_epoch(epoch) + + self._validate_epoch(epoch) + if self.rank == 0: self._save_checkpoint(epoch) + + if self.use_ddp or self.use_fsdp: + dist.barrier() + self.scheduler.step() + + + if self.rank == 0: + average_checkpoints(self.output_dir, self.avg_nbest_model) + if self.use_ddp or self.use_fsdp: + dist.barrier() self.writer.close() + def _train_epoch(self, epoch): """ @@ -157,8 +186,7 @@ def _train_epoch(self, epoch): for batch_idx, batch in enumerate(self.dataloader_train): time1 = time.perf_counter() speed_stats["data_load"] = f"{time1-time5:0.3f}" - # import pdb; - # pdb.set_trace() + batch = to_device(batch, self.device) my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext @@ -211,13 +239,12 @@ def _train_epoch(self, epoch): speed_stats["optim_time"] = f"{time5 - time4:0.3f}" speed_stats["total_time"] = total_time - - # import pdb; - # pdb.set_trace() + + pbar.update(1) if self.local_rank == 0: description = ( - f"Epoch: {epoch + 1}/{self.max_epoch}, " + f"Epoch: {epoch}/{self.max_epoch}, " f"step {batch_idx}/{len(self.dataloader_train)}, " f"{speed_stats}, " f"(loss: {loss.detach().cpu().item():.3f}), " @@ -248,6 +275,50 @@ def _validate_epoch(self, epoch): """ self.model.eval() with torch.no_grad(): - for data, target in self.dataloader_val: - # Implement the model validation steps here - pass + pbar = tqdm(colour="red", desc=f"Training Epoch: {epoch + 1}", total=len(self.dataloader_val), + dynamic_ncols=True) + speed_stats = {} + time5 = time.perf_counter() + for batch_idx, batch in enumerate(self.dataloader_val): + time1 = time.perf_counter() + speed_stats["data_load"] = f"{time1 - time5:0.3f}" + batch = to_device(batch, self.device) + time2 = time.perf_counter() + retval = self.model(**batch) + time3 = time.perf_counter() + speed_stats["forward_time"] = f"{time3 - time2:0.3f}" + loss, stats, weight = retval + stats = {k: v for k, v in stats.items() if v is not None} + if self.use_ddp or self.use_fsdp: + # Apply weighted averaging for loss and stats + loss = (loss * weight.type(loss.dtype)).sum() + # if distributed, this method can also apply all_reduce() + stats, weight = recursive_average(stats, weight, distributed=True) + # Now weight is summation over all workers + loss /= weight + # Multiply world_size because DistributedDataParallel + # automatically normalizes the gradient by world_size. 
+ loss *= self.world_size + # Scale the loss since we're not updating for every mini-batch + loss = loss + time4 = time.perf_counter() + + pbar.update(1) + if self.local_rank == 0: + description = ( + f"validation: \nEpoch: {epoch}/{self.max_epoch}, " + f"step {batch_idx}/{len(self.dataloader_train)}, " + f"{speed_stats}, " + f"(loss: {loss.detach().cpu().item():.3f}), " + f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}" + ) + pbar.set_description(description) + if self.writer: + self.writer.add_scalar('Loss/val', loss.item(), + epoch*len(self.dataloader_train) + batch_idx) + for key, var in stats.items(): + self.writer.add_scalar(f'{key}/val', var.item(), + epoch * len(self.dataloader_train) + batch_idx) + for key, var in speed_stats.items(): + self.writer.add_scalar(f'{key}/val', eval(var), + epoch * len(self.dataloader_train) + batch_idx) \ No newline at end of file From 7458e39ff0756d0bae38b139e0e534e61e1fa0cf Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 17 Jan 2024 19:21:08 +0800 Subject: [PATCH 2/4] bug fix --- .../paraformer/demo.py | 4 ++- funasr/models/bicif_paraformer/model.py | 34 +++++++++---------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py index ef33bf40d..78af3aa1d 100644 --- a/examples/industrial_data_pretraining/paraformer/demo.py +++ b/examples/industrial_data_pretraining/paraformer/demo.py @@ -11,6 +11,7 @@ print(res) +''' can not use currently from funasr import AutoFrontend frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2") @@ -19,4 +20,5 @@ for batch_idx, fbank_dict in enumerate(fbanks): res = model.generate(**fbank_dict) - print(res) \ No newline at end of file + print(res) +''' \ No newline at end of file diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py index 01f19c697..0069b8c98 100644 --- a/funasr/models/bicif_paraformer/model.py +++ b/funasr/models/bicif_paraformer/model.py @@ -235,23 +235,23 @@ def inference(self, self.nbest = kwargs.get("nbest", 1) meta_data = {} - if isinstance(data_in, torch.Tensor): # fbank - speech, speech_lengths = data_in, data_lengths - if len(speech.shape) < 3: - speech = speech[None, :, :] - if speech_lengths is None: - speech_lengths = speech.shape[1] - else: - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), - frontend=frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + # if isinstance(data_in, torch.Tensor): # fbank + # speech, speech_lengths = data_in, data_lengths + # if len(speech.shape) < 3: + # speech = speech[None, :, :] + # if speech_lengths is None: + # speech_lengths = speech.shape[1] + # else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, 
data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) From 704db424a1aa1e92272f5aa4457efbcfefefc205 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Thu, 18 Jan 2024 11:10:25 +0800 Subject: [PATCH 3/4] Funasr1.0 (#1265) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * update with main (#1264) * Funasr1.0 (#1261) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * bug fix --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi From b28f3c9da94ae72a3a0b7bb5982b587be7cf4cd6 Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Thu, 18 Jan 2024 22:00:58 +0800 Subject: [PATCH 4/4] fsmn-vad bugfix (#1270) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * update with main (#1264) * Funasr1.0 (#1261) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * bug fix --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi * funasr1.0 sanm scama * funasr1.0 infer_after_finetune * funasr1.0 fsmn-vad bug fix * funasr1.0 fsmn-vad bug fix --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --- README.md | 11 +- README_zh.md | 11 +- .../paraformer/infer_after_finetune.sh | 12 + .../industrial_data_pretraining/scama/demo.py | 42 + .../scama/infer.sh | 11 + funasr/models/fsmn_vad_streaming/model.py | 1376 +++++++++-------- funasr/models/paraformer/model.py | 1 - funasr/models/paraformer/template.yaml | 8 - funasr/models/sanm/decoder.py | 10 +- funasr/models/sanm/encoder.py | 8 +- funasr/models/sanm/model.py | 11 +- funasr/models/sanm/template.yaml | 121 ++ .../scama/{sanm_decoder.py => decoder.py} | 11 +- .../scama/{sanm_encoder.py => encoder.py} | 10 +- funasr/models/scama/model.py | 669 ++++++++ funasr/models/scama/template.yaml | 127 ++ .../uniasr/{e2e_uni_asr.py => model.py} | 118 +- funasr/models/uniasr/template.yaml | 178 +++ 18 files changed, 1952 insertions(+), 783 deletions(-) create mode 100644 examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh create mode 100644 examples/industrial_data_pretraining/scama/demo.py create mode 100644 examples/industrial_data_pretraining/scama/infer.sh create mode 100644 funasr/models/sanm/template.yaml rename funasr/models/scama/{sanm_decoder.py => decoder.py} (99%) rename funasr/models/scama/{sanm_encoder.py => encoder.py} (98%) create mode 100644 funasr/models/scama/model.py create mode 100644 funasr/models/scama/template.yaml rename funasr/models/uniasr/{e2e_uni_asr.py => model.py} (95%) create mode 100644 funasr/models/uniasr/template.yaml diff --git a/README.md b/README.md index 0094dc4d4..c9b9e8916 100644 --- a/README.md +++ b/README.md @@ -91,12 +91,13 @@ Notes: 
Support recognition of single audio file, as well as file list in Kaldi-s from funasr import AutoModel # paraformer-zh is a multi-functional asr model # use vad, punc, spk or not as you need -model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \ - vad_model="fsmn-vad", vad_model_revision="v2.0.2", \ - punc_model="ct-punc-c", punc_model_revision="v2.0.2", \ - spk_model="cam++", spk_model_revision="v2.0.2") +model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", + vad_model="fsmn-vad", vad_model_revision="v2.0.2", + punc_model="ct-punc-c", punc_model_revision="v2.0.2", + # spk_model="cam++", spk_model_revision="v2.0.2", + ) res = model.generate(input=f"{model.model_path}/example/asr_example.wav", - batch_size=64, + batch_size_s=300, hotword='魔搭') print(res) ``` diff --git a/README_zh.md b/README_zh.md index 57a6bbb21..9cd18977d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -87,12 +87,13 @@ funasr +model=paraformer-zh +vad_model="fsmn-vad" +punc_model="ct-punc" +input=a from funasr import AutoModel # paraformer-zh is a multi-functional asr model # use vad, punc, spk or not as you need -model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \ - vad_model="fsmn-vad", vad_model_revision="v2.0.2", \ - punc_model="ct-punc-c", punc_model_revision="v2.0.2", \ - spk_model="cam++", spk_model_revision="v2.0.2") +model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", + vad_model="fsmn-vad", vad_model_revision="v2.0.2", + punc_model="ct-punc-c", punc_model_revision="v2.0.2", + # spk_model="cam++", spk_model_revision="v2.0.2", + ) res = model.generate(input=f"{model.model_path}/example/asr_example.wav", - batch_size=64, + batch_size_s=300, hotword='魔搭') print(res) ``` diff --git a/examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh b/examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh new file mode 100644 index 000000000..df1e54a4e --- /dev/null +++ b/examples/industrial_data_pretraining/paraformer/infer_after_finetune.sh @@ -0,0 +1,12 @@ + + +python funasr/bin/inference.py \ +--config-path="/Users/zhifu/funasr_github/test_local/funasr_cli_egs" \ +--config-name="config.yaml" \ +++init_param="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/model.pt" \ ++tokenizer_conf.token_list="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/tokens.txt" \ ++frontend_conf.cmvn_file="/Users/zhifu/funasr_github/test_local/funasr_cli_egs/am.mvn" \ ++input="data/wav.scp" \ ++output_dir="./outputs/debug" \ ++device="cuda" \ + diff --git a/examples/industrial_data_pretraining/scama/demo.py b/examples/industrial_data_pretraining/scama/demo.py new file mode 100644 index 000000000..c80599368 --- /dev/null +++ b/examples/industrial_data_pretraining/scama/demo.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +chunk_size = [5, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms +encoder_chunk_look_back = 0 #number of chunks to lookback for encoder self-attention +decoder_chunk_look_back = 0 #number of encoder chunks to lookback for decoder cross-attention + +model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_SCAMA_asr-zh-cn-16k-common-vocab8358-streaming", model_revision="v2.0.2") +cache = {} +res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", + chunk_size=chunk_size, + encoder_chunk_look_back=encoder_chunk_look_back, + decoder_chunk_look_back=decoder_chunk_look_back, + ) +print(res) + + +import soundfile +import os + +wav_file = os.path.join(model.model_path, "example/asr_example.wav") +speech, sample_rate = soundfile.read(wav_file) + +chunk_stride = chunk_size[1] * 960 # 600ms、480ms + +cache = {} +total_chunk_num = int(len((speech)-1)/chunk_stride+1) +for i in range(total_chunk_num): + speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] + is_final = i == total_chunk_num - 1 + res = model.generate(input=speech_chunk, + cache=cache, + is_final=is_final, + chunk_size=chunk_size, + encoder_chunk_look_back=encoder_chunk_look_back, + decoder_chunk_look_back=decoder_chunk_look_back, + ) + print(res) diff --git a/examples/industrial_data_pretraining/scama/infer.sh b/examples/industrial_data_pretraining/scama/infer.sh new file mode 100644 index 000000000..225f2a953 --- /dev/null +++ b/examples/industrial_data_pretraining/scama/infer.sh @@ -0,0 +1,11 @@ + +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online" +model_revision="v2.0.2" + +python funasr/bin/inference.py \ ++model=${model} \ ++model_revision=${model_revision} \ ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \ ++output_dir="./outputs/debug" \ ++device="cpu" \ + diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 193feb08d..943cb476a 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -19,714 +19,718 @@ class VadStateMachine(Enum): - kVadInStateStartPointNotDetected = 1 - kVadInStateInSpeechSegment = 2 - kVadInStateEndPointDetected = 3 + kVadInStateStartPointNotDetected = 1 + kVadInStateInSpeechSegment = 2 + kVadInStateEndPointDetected = 3 class FrameState(Enum): - kFrameStateInvalid = -1 - kFrameStateSpeech = 1 - kFrameStateSil = 0 + kFrameStateInvalid = -1 + kFrameStateSpeech = 1 + kFrameStateSil = 0 # final voice/unvoice state per frame class AudioChangeState(Enum): - kChangeStateSpeech2Speech = 0 - kChangeStateSpeech2Sil = 1 - kChangeStateSil2Sil = 2 - kChangeStateSil2Speech = 3 - kChangeStateNoBegin = 4 - kChangeStateInvalid = 5 + kChangeStateSpeech2Speech = 0 + kChangeStateSpeech2Sil = 1 + kChangeStateSil2Sil = 2 + kChangeStateSil2Speech = 3 + kChangeStateNoBegin = 4 + kChangeStateInvalid = 5 class VadDetectMode(Enum): - kVadSingleUtteranceDetectMode = 0 - kVadMutipleUtteranceDetectMode = 1 + kVadSingleUtteranceDetectMode = 0 + kVadMutipleUtteranceDetectMode = 1 class VADXOptions: - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__( - self, - sample_rate: int = 16000, - detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value, - snr_mode: int = 0, - 
max_end_silence_time: int = 800, - max_start_silence_time: int = 3000, - do_start_point_detection: bool = True, - do_end_point_detection: bool = True, - window_size_ms: int = 200, - sil_to_speech_time_thres: int = 150, - speech_to_sil_time_thres: int = 150, - speech_2_noise_ratio: float = 1.0, - do_extend: int = 1, - lookback_time_start_point: int = 200, - lookahead_time_end_point: int = 100, - max_single_segment_time: int = 60000, - nn_eval_block_size: int = 8, - dcd_block_size: int = 4, - snr_thres: int = -100.0, - noise_frame_num_used_for_snr: int = 100, - decibel_thres: int = -100.0, - speech_noise_thres: float = 0.6, - fe_prior_thres: float = 1e-4, - silence_pdf_num: int = 1, - sil_pdf_ids: List[int] = [0], - speech_noise_thresh_low: float = -0.1, - speech_noise_thresh_high: float = 0.3, - output_frame_probs: bool = False, - frame_in_ms: int = 10, - frame_length_ms: int = 25, - **kwargs, - ): - self.sample_rate = sample_rate - self.detect_mode = detect_mode - self.snr_mode = snr_mode - self.max_end_silence_time = max_end_silence_time - self.max_start_silence_time = max_start_silence_time - self.do_start_point_detection = do_start_point_detection - self.do_end_point_detection = do_end_point_detection - self.window_size_ms = window_size_ms - self.sil_to_speech_time_thres = sil_to_speech_time_thres - self.speech_to_sil_time_thres = speech_to_sil_time_thres - self.speech_2_noise_ratio = speech_2_noise_ratio - self.do_extend = do_extend - self.lookback_time_start_point = lookback_time_start_point - self.lookahead_time_end_point = lookahead_time_end_point - self.max_single_segment_time = max_single_segment_time - self.nn_eval_block_size = nn_eval_block_size - self.dcd_block_size = dcd_block_size - self.snr_thres = snr_thres - self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr - self.decibel_thres = decibel_thres - self.speech_noise_thres = speech_noise_thres - self.fe_prior_thres = fe_prior_thres - self.silence_pdf_num = silence_pdf_num - self.sil_pdf_ids = sil_pdf_ids - self.speech_noise_thresh_low = speech_noise_thresh_low - self.speech_noise_thresh_high = speech_noise_thresh_high - self.output_frame_probs = output_frame_probs - self.frame_in_ms = frame_in_ms - self.frame_length_ms = frame_length_ms + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__( + self, + sample_rate: int = 16000, + detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value, + snr_mode: int = 0, + max_end_silence_time: int = 800, + max_start_silence_time: int = 3000, + do_start_point_detection: bool = True, + do_end_point_detection: bool = True, + window_size_ms: int = 200, + sil_to_speech_time_thres: int = 150, + speech_to_sil_time_thres: int = 150, + speech_2_noise_ratio: float = 1.0, + do_extend: int = 1, + lookback_time_start_point: int = 200, + lookahead_time_end_point: int = 100, + max_single_segment_time: int = 60000, + nn_eval_block_size: int = 8, + dcd_block_size: int = 4, + snr_thres: int = -100.0, + noise_frame_num_used_for_snr: int = 100, + decibel_thres: int = -100.0, + speech_noise_thres: float = 0.6, + fe_prior_thres: float = 1e-4, + silence_pdf_num: int = 1, + sil_pdf_ids: List[int] = [0], + speech_noise_thresh_low: float = -0.1, + speech_noise_thresh_high: float = 0.3, + output_frame_probs: bool = False, + frame_in_ms: int = 10, + frame_length_ms: int = 25, + **kwargs, + ): + self.sample_rate = sample_rate + self.detect_mode = detect_mode + 
self.snr_mode = snr_mode + self.max_end_silence_time = max_end_silence_time + self.max_start_silence_time = max_start_silence_time + self.do_start_point_detection = do_start_point_detection + self.do_end_point_detection = do_end_point_detection + self.window_size_ms = window_size_ms + self.sil_to_speech_time_thres = sil_to_speech_time_thres + self.speech_to_sil_time_thres = speech_to_sil_time_thres + self.speech_2_noise_ratio = speech_2_noise_ratio + self.do_extend = do_extend + self.lookback_time_start_point = lookback_time_start_point + self.lookahead_time_end_point = lookahead_time_end_point + self.max_single_segment_time = max_single_segment_time + self.nn_eval_block_size = nn_eval_block_size + self.dcd_block_size = dcd_block_size + self.snr_thres = snr_thres + self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr + self.decibel_thres = decibel_thres + self.speech_noise_thres = speech_noise_thres + self.fe_prior_thres = fe_prior_thres + self.silence_pdf_num = silence_pdf_num + self.sil_pdf_ids = sil_pdf_ids + self.speech_noise_thresh_low = speech_noise_thresh_low + self.speech_noise_thresh_high = speech_noise_thresh_high + self.output_frame_probs = output_frame_probs + self.frame_in_ms = frame_in_ms + self.frame_length_ms = frame_length_ms class E2EVadSpeechBufWithDoa(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self): - self.start_ms = 0 - self.end_ms = 0 - self.buffer = [] - self.contain_seg_start_point = False - self.contain_seg_end_point = False - self.doa = 0 - - def Reset(self): - self.start_ms = 0 - self.end_ms = 0 - self.buffer = [] - self.contain_seg_start_point = False - self.contain_seg_end_point = False - self.doa = 0 + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self): + self.start_ms = 0 + self.end_ms = 0 + self.buffer = [] + self.contain_seg_start_point = False + self.contain_seg_end_point = False + self.doa = 0 + + def Reset(self): + self.start_ms = 0 + self.end_ms = 0 + self.buffer = [] + self.contain_seg_start_point = False + self.contain_seg_end_point = False + self.doa = 0 class E2EVadFrameProb(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self): - self.noise_prob = 0.0 - self.speech_prob = 0.0 - self.score = 0.0 - self.frame_id = 0 - self.frm_state = 0 + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self): + self.noise_prob = 0.0 + self.speech_prob = 0.0 + self.score = 0.0 + self.frame_id = 0 + self.frm_state = 0 class WindowDetector(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self, window_size_ms: int, - sil_to_speech_time: int, - speech_to_sil_time: int, - frame_size_ms: int): - self.window_size_ms = window_size_ms - self.sil_to_speech_time = sil_to_speech_time - self.speech_to_sil_time = speech_to_sil_time - self.frame_size_ms = frame_size_ms - - self.win_size_frame = int(window_size_ms / frame_size_ms) - self.win_sum = 0 - self.win_state = [0] * self.win_size_frame # 初始化窗 - - 
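+        # win_state is a circular buffer of per-frame voice/silence decisions and
+        # win_sum counts the voiced frames currently inside the window;
+        # DetectOneFrame flips sil->speech once win_sum reaches
+        # sil_to_speech_frmcnt_thres and speech->sil once it falls to
+        # speech_to_sil_frmcnt_thres.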
self.cur_win_pos = 0 - self.pre_frame_state = FrameState.kFrameStateSil - self.cur_frame_state = FrameState.kFrameStateSil - self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) - self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) - - self.voice_last_frame_count = 0 - self.noise_last_frame_count = 0 - self.hydre_frame_count = 0 - - def Reset(self) -> None: - self.cur_win_pos = 0 - self.win_sum = 0 - self.win_state = [0] * self.win_size_frame - self.pre_frame_state = FrameState.kFrameStateSil - self.cur_frame_state = FrameState.kFrameStateSil - self.voice_last_frame_count = 0 - self.noise_last_frame_count = 0 - self.hydre_frame_count = 0 - - def GetWinSize(self) -> int: - return int(self.win_size_frame) - - def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState: - cur_frame_state = FrameState.kFrameStateSil - if frameState == FrameState.kFrameStateSpeech: - cur_frame_state = 1 - elif frameState == FrameState.kFrameStateSil: - cur_frame_state = 0 - else: - return AudioChangeState.kChangeStateInvalid - self.win_sum -= self.win_state[self.cur_win_pos] - self.win_sum += cur_frame_state - self.win_state[self.cur_win_pos] = cur_frame_state - self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self, window_size_ms: int, + sil_to_speech_time: int, + speech_to_sil_time: int, + frame_size_ms: int): + self.window_size_ms = window_size_ms + self.sil_to_speech_time = sil_to_speech_time + self.speech_to_sil_time = speech_to_sil_time + self.frame_size_ms = frame_size_ms + + self.win_size_frame = int(window_size_ms / frame_size_ms) + self.win_sum = 0 + self.win_state = [0] * self.win_size_frame # 初始化窗 + + self.cur_win_pos = 0 + self.pre_frame_state = FrameState.kFrameStateSil + self.cur_frame_state = FrameState.kFrameStateSil + self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) + self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) + + self.voice_last_frame_count = 0 + self.noise_last_frame_count = 0 + self.hydre_frame_count = 0 + + def Reset(self) -> None: + self.cur_win_pos = 0 + self.win_sum = 0 + self.win_state = [0] * self.win_size_frame + self.pre_frame_state = FrameState.kFrameStateSil + self.cur_frame_state = FrameState.kFrameStateSil + self.voice_last_frame_count = 0 + self.noise_last_frame_count = 0 + self.hydre_frame_count = 0 + + def GetWinSize(self) -> int: + return int(self.win_size_frame) + + def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState: + cur_frame_state = FrameState.kFrameStateSil + if frameState == FrameState.kFrameStateSpeech: + cur_frame_state = 1 + elif frameState == FrameState.kFrameStateSil: + cur_frame_state = 0 + else: + return AudioChangeState.kChangeStateInvalid + self.win_sum -= self.win_state[self.cur_win_pos] + self.win_sum += cur_frame_state + self.win_state[self.cur_win_pos] = cur_frame_state + self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame + + if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres: + self.pre_frame_state = FrameState.kFrameStateSpeech + return AudioChangeState.kChangeStateSil2Speech + + if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres: + self.pre_frame_state 
= FrameState.kFrameStateSil + return AudioChangeState.kChangeStateSpeech2Sil + + if self.pre_frame_state == FrameState.kFrameStateSil: + return AudioChangeState.kChangeStateSil2Sil + if self.pre_frame_state == FrameState.kFrameStateSpeech: + return AudioChangeState.kChangeStateSpeech2Speech + return AudioChangeState.kChangeStateInvalid + + def FrameSizeMs(self) -> int: + return int(self.frame_size_ms) + +class Stats(object): + def __init__(self, + sil_pdf_ids, + max_end_sil_frame_cnt_thresh, + speech_noise_thres, + ): + + self.data_buf_start_frame = 0 + self.frm_cnt = 0 + self.latest_confirmed_speech_frame = 0 + self.lastest_confirmed_silence_frame = -1 + self.continous_silence_frame_count = 0 + self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + self.confirmed_start_frame = -1 + self.confirmed_end_frame = -1 + self.number_end_time_detected = 0 + self.sil_frame = 0 + self.sil_pdf_ids = sil_pdf_ids + self.noise_average_decibel = -100.0 + self.pre_end_silence_detected = False + self.next_seg = True + + self.output_data_buf = [] + self.output_data_buf_offset = 0 + self.frame_probs = [] + self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh + self.speech_noise_thres = speech_noise_thres + self.scores = None + self.max_time_out = False + self.decibel = [] + self.data_buf = None + self.data_buf_all = None + self.waveform = None + self.last_drop_frames = 0 - if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres: - self.pre_frame_state = FrameState.kFrameStateSpeech - return AudioChangeState.kChangeStateSil2Speech - if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres: - self.pre_frame_state = FrameState.kFrameStateSil - return AudioChangeState.kChangeStateSpeech2Sil - - if self.pre_frame_state == FrameState.kFrameStateSil: - return AudioChangeState.kChangeStateSil2Sil - if self.pre_frame_state == FrameState.kFrameStateSpeech: - return AudioChangeState.kChangeStateSpeech2Speech - return AudioChangeState.kChangeStateInvalid - - def FrameSizeMs(self) -> int: - return int(self.frame_size_ms) - - -@dataclass -class StatsItem: - - # init variables - data_buf_start_frame = 0 - frm_cnt = 0 - latest_confirmed_speech_frame = 0 - lastest_confirmed_silence_frame = -1 - continous_silence_frame_count = 0 - vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - confirmed_start_frame = -1 - confirmed_end_frame = -1 - number_end_time_detected = 0 - sil_frame = 0 - sil_pdf_ids: list - noise_average_decibel = -100.0 - pre_end_silence_detected = False - next_seg = True # unused - - output_data_buf = [] - output_data_buf_offset = 0 - frame_probs = [] # unused - max_end_sil_frame_cnt_thresh: int - speech_noise_thres: float - scores = None - max_time_out = False #unused - decibel = [] - data_buf = None - data_buf_all = None - waveform = None - last_drop_frames = 0 - @tables.register("model_classes", "FsmnVADStreaming") class FsmnVADStreaming(nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self, - encoder: str = None, - encoder_conf: Optional[Dict] = None, - vad_post_args: Dict[str, Any] = None, - **kwargs, - ): - super().__init__() - self.vad_opts = VADXOptions(**kwargs) - - encoder_class = tables.encoder_classes.get(encoder) - encoder = encoder_class(**encoder_conf) - self.encoder = encoder - - - def ResetDetection(self, 
cache: dict = {}): - cache["stats"].continous_silence_frame_count = 0 - cache["stats"].latest_confirmed_speech_frame = 0 - cache["stats"].lastest_confirmed_silence_frame = -1 - cache["stats"].confirmed_start_frame = -1 - cache["stats"].confirmed_end_frame = -1 - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - cache["windows_detector"].Reset() - cache["stats"].sil_frame = 0 - cache["stats"].frame_probs = [] - - if cache["stats"].output_data_buf: - assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True - drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms) - real_drop_frames = drop_frames - cache["stats"].last_drop_frames - cache["stats"].last_drop_frames = drop_frames - cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] - cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:] - cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :] - - def ComputeDecibel(self, cache: dict = {}) -> None: - frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) - frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) - if cache["stats"].data_buf_all is None: - cache["stats"].data_buf_all = cache["stats"].waveform[0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0] - cache["stats"].data_buf = cache["stats"].data_buf_all - else: - cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0])) - for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length): - cache["stats"].decibel.append( - 10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \ - 0.000001)) - - def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None: - scores = self.encoder(feats, cache=cache["encoder"]).to('cpu') # return B * T * D - assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" - self.vad_opts.nn_eval_block_size = scores.shape[1] - cache["stats"].frm_cnt += scores.shape[1] # count total frames - if cache["stats"].scores is None: - cache["stats"].scores = scores # the first calculation - else: - cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1) - - def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None: # need check again - while cache["stats"].data_buf_start_frame < frame_idx: - if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): - cache["stats"].data_buf_start_frame += 1 - cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int( - self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] - - def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, - last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None: - self.PopDataBufTillFrame(start_frm, cache=cache) - expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) - if last_frm_is_end_point: - extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \ - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) - expected_sample_number += int(extra_sample) - if end_point_is_sent_end: - 
expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf)) - if len(cache["stats"].data_buf) < expected_sample_number: - print('error in calling pop data_buf\n') - - if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point: - cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa()) - cache["stats"].output_data_buf[-1].Reset() - cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms - cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms - cache["stats"].output_data_buf[-1].doa = 0 - cur_seg = cache["stats"].output_data_buf[-1] - if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: - print('warning\n') - out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作 - data_to_pop = 0 - if end_point_is_sent_end: - data_to_pop = expected_sample_number - else: - data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) - if data_to_pop > len(cache["stats"].data_buf): - print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n') - data_to_pop = len(cache["stats"].data_buf) - expected_sample_number = len(cache["stats"].data_buf) - - cur_seg.doa = 0 - for sample_cpy_out in range(0, data_to_pop): - # cur_seg.buffer[out_pos ++] = data_buf_.back(); - out_pos += 1 - for sample_cpy_out in range(data_to_pop, expected_sample_number): - # cur_seg.buffer[out_pos++] = data_buf_.back() - out_pos += 1 - if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: - print('Something wrong with the VAD algorithm\n') - cache["stats"].data_buf_start_frame += frm_cnt - cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms - if first_frm_is_start_point: - cur_seg.contain_seg_start_point = True - if last_frm_is_end_point: - cur_seg.contain_seg_end_point = True - - def OnSilenceDetected(self, valid_frame: int, cache: dict = {}): - cache["stats"].lastest_confirmed_silence_frame = valid_frame - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - self.PopDataBufTillFrame(valid_frame, cache=cache) - # silence_detected_callback_ - # pass - - def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None: - cache["stats"].latest_confirmed_speech_frame = valid_frame - self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache) - - def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None: - if self.vad_opts.do_start_point_detection: - pass - if cache["stats"].confirmed_start_frame != -1: - print('not reset vad properly\n') - else: - cache["stats"].confirmed_start_frame = start_frame - - if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache) - - def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None: - for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame): - self.OnVoiceDetected(t, cache=cache) - if self.vad_opts.do_end_point_detection: - pass - if cache["stats"].confirmed_end_frame != -1: - print('not reset vad properly\n') - else: - cache["stats"].confirmed_end_frame = end_frame - if not fake_result: - cache["stats"].sil_frame = 0 - self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache) - cache["stats"].number_end_time_detected += 1 - - def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, 
cur_frm_idx: int, cache: dict = {}) -> None: - if is_final_frame: - self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - - def GetLatency(self, cache: dict = {}) -> int: - return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms) - - def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int: - vad_latency = cache["windows_detector"].GetWinSize() - if self.vad_opts.do_extend: - vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms) - return vad_latency - - def GetFrameState(self, t: int, cache: dict = {}): - frame_state = FrameState.kFrameStateInvalid - cur_decibel = cache["stats"].decibel[t] - cur_snr = cur_decibel - cache["stats"].noise_average_decibel - # for each frame, calc log posterior probability of each state - if cur_decibel < self.vad_opts.decibel_thres: - frame_state = FrameState.kFrameStateSil - self.DetectOneFrame(frame_state, t, False, cache=cache) - return frame_state - - sum_score = 0.0 - noise_prob = 0.0 - assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num - if len(cache["stats"].sil_pdf_ids) > 0: - assert len(cache["stats"].scores) == 1 # 只支持batch_size = 1的测试 - sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids] - sum_score = sum(sil_pdf_scores) - noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio - total_score = 1.0 - sum_score = total_score - sum_score - speech_prob = math.log(sum_score) - if self.vad_opts.output_frame_probs: - frame_prob = E2EVadFrameProb() - frame_prob.noise_prob = noise_prob - frame_prob.speech_prob = speech_prob - frame_prob.score = sum_score - frame_prob.frame_id = t - cache["stats"].frame_probs.append(frame_prob) - if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres: - if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres: - frame_state = FrameState.kFrameStateSpeech - else: - frame_state = FrameState.kFrameStateSil - else: - frame_state = FrameState.kFrameStateSil - if cache["stats"].noise_average_decibel < -99.9: - cache["stats"].noise_average_decibel = cur_decibel - else: - cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * ( - self.vad_opts.noise_frame_num_used_for_snr - - 1)) / self.vad_opts.noise_frame_num_used_for_snr - - return frame_state - - def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {}, - is_final: bool = False - ): - # if len(cache) == 0: - # self.AllResetDetection() - # self.waveform = waveform # compute decibel for each frame - cache["stats"].waveform = waveform - self.ComputeDecibel(cache=cache) - self.ComputeScores(feats, cache=cache) - if not is_final: - self.DetectCommonFrames(cache=cache) - else: - self.DetectLastFrames(cache=cache) - segments = [] - for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now - segment_batch = [] - if len(cache["stats"].output_data_buf) > 0: - for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)): - if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[ - i].contain_seg_end_point): - continue - segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms] - segment_batch.append(segment) - cache["stats"].output_data_buf_offset += 1 # need update this parameter - if 
segment_batch: - segments.append(segment_batch) - # if is_final: - # # reset class variables and clear the dict for the next query - # self.AllResetDetection() - return segments - - def init_cache(self, cache: dict = {}, **kwargs): - cache["frontend"] = {} - cache["prev_samples"] = torch.empty(0) - cache["encoder"] = {} - windows_detector = WindowDetector(self.vad_opts.window_size_ms, - self.vad_opts.sil_to_speech_time_thres, - self.vad_opts.speech_to_sil_time_thres, - self.vad_opts.frame_in_ms) - - stats = StatsItem(sil_pdf_ids=self.vad_opts.sil_pdf_ids, - max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres, - speech_noise_thres=self.vad_opts.speech_noise_thres, - ) - cache["windows_detector"] = windows_detector - cache["stats"] = stats - return cache - - def inference(self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - cache: dict = {}, - **kwargs, - ): - - if len(cache) == 0: - self.init_cache(cache, **kwargs) - - meta_data = {} - chunk_size = kwargs.get("chunk_size", 60000) # 50ms - chunk_stride_samples = int(chunk_size * frontend.fs / 1000) - - time1 = time.perf_counter() - cfg = {"is_final": kwargs.get("is_final", False)} - audio_sample_list = load_audio_text_image_video(data_in, - fs=frontend.fs, - audio_fs=kwargs.get("fs", 16000), - data_type=kwargs.get("data_type", "sound"), - tokenizer=tokenizer, - cache=cfg, - ) - _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True - - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - assert len(audio_sample_list) == 1, "batch_size must be set 1" - - audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0])) - - n = int(len(audio_sample) // chunk_stride_samples + int(_is_final)) - m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))) - segments = [] - for i in range(n): - kwargs["is_final"] = _is_final and i == n - 1 - audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples] - - # extract fbank feats - speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"), - frontend=frontend, cache=cache["frontend"], - is_final=kwargs["is_final"]) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - speech = speech.to(device=kwargs["device"]) - speech_lengths = speech_lengths.to(device=kwargs["device"]) - - batch = { - "feats": speech, - "waveform": cache["frontend"]["waveforms"], - "is_final": kwargs["is_final"], - "cache": cache - } - segments_i = self.forward(**batch) - if len(segments_i) > 0: - segments.extend(*segments_i) - - - cache["prev_samples"] = audio_sample[:-m] - if _is_final: - self.init_cache(cache, **kwargs) - - ibest_writer = None - if ibest_writer is None and kwargs.get("output_dir") is not None: - writer = DatadirWriter(kwargs.get("output_dir")) - ibest_writer = writer[f"{1}best_recog"] - - results = [] - result_i = {"key": key[0], "value": segments} - if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas": - result_i = json.dumps(result_i) - - results.append(result_i) - - if ibest_writer is not None: - ibest_writer["text"][key[0]] = segments - - - return results, meta_data - - - def DetectCommonFrames(self, cache: dict = {}) -> int: - if cache["stats"].vad_state_machine == 
VadStateMachine.kVadInStateEndPointDetected: - return 0 - for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): - frame_state = FrameState.kFrameStateInvalid - frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) - self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) - - return 0 - - def DetectLastFrames(self, cache: dict = {}) -> int: - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: - return 0 - for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): - frame_state = FrameState.kFrameStateInvalid - frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) - if i != 0: - self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) - else: - self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache) - - return 0 - - def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None: - tmp_cur_frm_state = FrameState.kFrameStateInvalid - if cur_frm_state == FrameState.kFrameStateSpeech: - if math.fabs(1.0) > self.vad_opts.fe_prior_thres: - tmp_cur_frm_state = FrameState.kFrameStateSpeech - else: - tmp_cur_frm_state = FrameState.kFrameStateSil - elif cur_frm_state == FrameState.kFrameStateSil: - tmp_cur_frm_state = FrameState.kFrameStateSil - state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache) - frm_shift_in_ms = self.vad_opts.frame_in_ms - if AudioChangeState.kChangeStateSil2Speech == state_change: - silence_frame_count = cache["stats"].continous_silence_frame_count - cache["stats"].continous_silence_frame_count = 0 - cache["stats"].pre_end_silence_detected = False - start_frame = 0 - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache)) - self.OnVoiceStart(start_frame, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment - for t in range(start_frame + 1, cur_frm_idx + 1): - self.OnVoiceDetected(t, cache=cache) - elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx): - self.OnVoiceDetected(t, cache=cache) - if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif not is_final_frame: - self.OnVoiceDetected(cur_frm_idx, cache=cache) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) - else: - pass - elif AudioChangeState.kChangeStateSpeech2Sil == state_change: - cache["stats"].continous_silence_frame_count = 0 - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - pass - elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif not is_final_frame: - self.OnVoiceDetected(cur_frm_idx, cache=cache) - 
else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) - else: - pass - elif AudioChangeState.kChangeStateSpeech2Speech == state_change: - cache["stats"].continous_silence_frame_count = 0 - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - cache["stats"].max_time_out = True - self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif not is_final_frame: - self.OnVoiceDetected(cur_frm_idx, cache=cache) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) - else: - pass - elif AudioChangeState.kChangeStateSil2Sil == state_change: - cache["stats"].continous_silence_frame_count += 1 - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - # silence timeout, return zero length decision - if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and ( - cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ - or (is_final_frame and cache["stats"].number_end_time_detected == 0): - for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx): - self.OnSilenceDetected(t, cache=cache) - self.OnVoiceStart(0, True, cache=cache) - self.OnVoiceEnd(0, True, False, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - else: - if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache): - self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache) - elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh: - lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms) - if self.vad_opts.do_extend: - lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms) - lookback_frame -= 1 - lookback_frame = max(0, lookback_frame) - self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) - cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif self.vad_opts.do_extend and not is_final_frame: - if cache["stats"].continous_silence_frame_count <= int( - self.vad_opts.lookahead_time_end_point / frm_shift_in_ms): - self.OnVoiceDetected(cur_frm_idx, cache=cache) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) - else: - pass - - if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \ - self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value: - self.ResetDetection(cache=cache) + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Deep-FSMN for Large Vocabulary Continuous Speech Recognition + https://arxiv.org/abs/1803.05030 + """ + def __init__(self, + encoder: str = None, + encoder_conf: Optional[Dict] = None, + vad_post_args: Dict[str, Any] = None, + **kwargs, + ): + super().__init__() + self.vad_opts = VADXOptions(**kwargs) + + 
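The VADXOptions constructed above carry the frame geometry (sample_rate, frame_length_ms, frame_in_ms) that ComputeDecibel below turns into per-frame energies. As a reference point, here is a minimal standalone sketch of that computation, assuming 16 kHz input and a 25 ms window with a 10 ms shift; the helper name frame_decibels is illustrative and not part of FunASR.

import math
import torch

def frame_decibels(waveform: torch.Tensor,
                   sample_rate: int = 16000,
                   frame_length_ms: int = 25,
                   frame_in_ms: int = 10) -> list:
    """Per-frame energy in dB over a sliding window, mirroring the loop in ComputeDecibel."""
    frame_sample_length = int(frame_length_ms * sample_rate / 1000)  # e.g. 400 samples
    frame_shift_length = int(frame_in_ms * sample_rate / 1000)       # e.g. 160 samples
    decibels = []
    for offset in range(0, waveform.shape[0] - frame_sample_length + 1, frame_shift_length):
        frame = waveform[offset: offset + frame_sample_length]
        # the small epsilon keeps log10 finite on all-zero (silent) frames
        decibels.append(10 * math.log10(frame.square().sum().item() + 1e-6))
    return decibels

# usage: one second of audio at 16 kHz yields (16000 - 400) // 160 + 1 = 98 frames
print(len(frame_decibels(torch.randn(16000))))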
encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(**encoder_conf) + self.encoder = encoder + + + def ResetDetection(self, cache: dict = {}): + cache["stats"].continous_silence_frame_count = 0 + cache["stats"].latest_confirmed_speech_frame = 0 + cache["stats"].lastest_confirmed_silence_frame = -1 + cache["stats"].confirmed_start_frame = -1 + cache["stats"].confirmed_end_frame = -1 + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + cache["windows_detector"].Reset() + cache["stats"].sil_frame = 0 + cache["stats"].frame_probs = [] + + if cache["stats"].output_data_buf: + assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True + drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms) + real_drop_frames = drop_frames - cache["stats"].last_drop_frames + cache["stats"].last_drop_frames = drop_frames + cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] + cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:] + cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :] + + def ComputeDecibel(self, cache: dict = {}) -> None: + frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) + frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) + if cache["stats"].data_buf_all is None: + cache["stats"].data_buf_all = cache["stats"].waveform[0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0] + cache["stats"].data_buf = cache["stats"].data_buf_all + else: + cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0])) + for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length): + cache["stats"].decibel.append( + 10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \ + 0.000001)) + + def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None: + scores = self.encoder(feats, cache=cache["encoder"]).to('cpu') # return B * T * D + assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" + self.vad_opts.nn_eval_block_size = scores.shape[1] + cache["stats"].frm_cnt += scores.shape[1] # count total frames + if cache["stats"].scores is None: + cache["stats"].scores = scores # the first calculation + else: + cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1) + + def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None: # need check again + while cache["stats"].data_buf_start_frame < frame_idx: + if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): + cache["stats"].data_buf_start_frame += 1 + cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int( + self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] + + def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, + last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None: + self.PopDataBufTillFrame(start_frm, cache=cache) + expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) + if last_frm_is_end_point: + extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \ + 
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) + expected_sample_number += int(extra_sample) + if end_point_is_sent_end: + expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf)) + if len(cache["stats"].data_buf) < expected_sample_number: + print('error in calling pop data_buf\n') + + if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point: + cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa()) + cache["stats"].output_data_buf[-1].Reset() + cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms + cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms + cache["stats"].output_data_buf[-1].doa = 0 + cur_seg = cache["stats"].output_data_buf[-1] + if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: + print('warning\n') + out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作 + data_to_pop = 0 + if end_point_is_sent_end: + data_to_pop = expected_sample_number + else: + data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) + if data_to_pop > len(cache["stats"].data_buf): + print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n') + data_to_pop = len(cache["stats"].data_buf) + expected_sample_number = len(cache["stats"].data_buf) + + cur_seg.doa = 0 + for sample_cpy_out in range(0, data_to_pop): + # cur_seg.buffer[out_pos ++] = data_buf_.back(); + out_pos += 1 + for sample_cpy_out in range(data_to_pop, expected_sample_number): + # cur_seg.buffer[out_pos++] = data_buf_.back() + out_pos += 1 + if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: + print('Something wrong with the VAD algorithm\n') + cache["stats"].data_buf_start_frame += frm_cnt + cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms + if first_frm_is_start_point: + cur_seg.contain_seg_start_point = True + if last_frm_is_end_point: + cur_seg.contain_seg_end_point = True + + def OnSilenceDetected(self, valid_frame: int, cache: dict = {}): + cache["stats"].lastest_confirmed_silence_frame = valid_frame + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: + self.PopDataBufTillFrame(valid_frame, cache=cache) + # silence_detected_callback_ + # pass + + def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None: + cache["stats"].latest_confirmed_speech_frame = valid_frame + self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache) + + def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None: + if self.vad_opts.do_start_point_detection: + pass + if cache["stats"].confirmed_start_frame != -1: + print('not reset vad properly\n') + else: + cache["stats"].confirmed_start_frame = start_frame + + if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: + self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache) + + def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None: + for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame): + self.OnVoiceDetected(t, cache=cache) + if self.vad_opts.do_end_point_detection: + pass + if cache["stats"].confirmed_end_frame != -1: + print('not reset vad properly\n') + else: + cache["stats"].confirmed_end_frame = end_frame + if not fake_result: + cache["stats"].sil_frame = 0 + self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, 
is_last_frame, cache=cache) + cache["stats"].number_end_time_detected += 1 + + def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None: + if is_final_frame: + self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected + + def GetLatency(self, cache: dict = {}) -> int: + return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms) + + def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int: + vad_latency = cache["windows_detector"].GetWinSize() + if self.vad_opts.do_extend: + vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms) + return vad_latency + + def GetFrameState(self, t: int, cache: dict = {}): + frame_state = FrameState.kFrameStateInvalid + cur_decibel = cache["stats"].decibel[t] + cur_snr = cur_decibel - cache["stats"].noise_average_decibel + # for each frame, calc log posterior probability of each state + if cur_decibel < self.vad_opts.decibel_thres: + frame_state = FrameState.kFrameStateSil + self.DetectOneFrame(frame_state, t, False, cache=cache) + return frame_state + + sum_score = 0.0 + noise_prob = 0.0 + assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num + if len(cache["stats"].sil_pdf_ids) > 0: + assert len(cache["stats"].scores) == 1 # 只支持batch_size = 1的测试 + sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids] + sum_score = sum(sil_pdf_scores) + noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio + total_score = 1.0 + sum_score = total_score - sum_score + speech_prob = math.log(sum_score) + if self.vad_opts.output_frame_probs: + frame_prob = E2EVadFrameProb() + frame_prob.noise_prob = noise_prob + frame_prob.speech_prob = speech_prob + frame_prob.score = sum_score + frame_prob.frame_id = t + cache["stats"].frame_probs.append(frame_prob) + if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres: + if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres: + frame_state = FrameState.kFrameStateSpeech + else: + frame_state = FrameState.kFrameStateSil + else: + frame_state = FrameState.kFrameStateSil + if cache["stats"].noise_average_decibel < -99.9: + cache["stats"].noise_average_decibel = cur_decibel + else: + cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * ( + self.vad_opts.noise_frame_num_used_for_snr + - 1)) / self.vad_opts.noise_frame_num_used_for_snr + + return frame_state + + def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {}, + is_final: bool = False + ): + # if len(cache) == 0: + # self.AllResetDetection() + # self.waveform = waveform # compute decibel for each frame + cache["stats"].waveform = waveform + self.ComputeDecibel(cache=cache) + self.ComputeScores(feats, cache=cache) + if not is_final: + self.DetectCommonFrames(cache=cache) + else: + self.DetectLastFrames(cache=cache) + segments = [] + for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now + segment_batch = [] + if len(cache["stats"].output_data_buf) > 0: + for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)): + if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[ + i].contain_seg_end_point): + continue + segment = [cache["stats"].output_data_buf[i].start_ms, 
cache["stats"].output_data_buf[i].end_ms] + segment_batch.append(segment) + cache["stats"].output_data_buf_offset += 1 # need update this parameter + if segment_batch: + segments.append(segment_batch) + # if is_final: + # # reset class variables and clear the dict for the next query + # self.AllResetDetection() + return segments + + def init_cache(self, cache: dict = {}, **kwargs): + cache["frontend"] = {} + cache["prev_samples"] = torch.empty(0) + cache["encoder"] = {} + windows_detector = WindowDetector(self.vad_opts.window_size_ms, + self.vad_opts.sil_to_speech_time_thres, + self.vad_opts.speech_to_sil_time_thres, + self.vad_opts.frame_in_ms) + windows_detector.Reset() + + stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids, + max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres, + speech_noise_thres=self.vad_opts.speech_noise_thres + ) + cache["windows_detector"] = windows_detector + cache["stats"] = stats + return cache + + def inference(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + cache: dict = {}, + **kwargs, + ): + + if len(cache) == 0: + self.init_cache(cache, **kwargs) + + meta_data = {} + chunk_size = kwargs.get("chunk_size", 60000) # 50ms + chunk_stride_samples = int(chunk_size * frontend.fs / 1000) + + time1 = time.perf_counter() + cfg = {"is_final": kwargs.get("is_final", False)} + audio_sample_list = load_audio_text_image_video(data_in, + fs=frontend.fs, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + cache=cfg, + ) + _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True + + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + assert len(audio_sample_list) == 1, "batch_size must be set 1" + + audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0])) + + n = int(len(audio_sample) // chunk_stride_samples + int(_is_final)) + m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))) + segments = [] + for i in range(n): + kwargs["is_final"] = _is_final and i == n - 1 + audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples] + + # extract fbank feats + speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"), + frontend=frontend, cache=cache["frontend"], + is_final=kwargs["is_final"]) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + batch = { + "feats": speech, + "waveform": cache["frontend"]["waveforms"], + "is_final": kwargs["is_final"], + "cache": cache + } + segments_i = self.forward(**batch) + if len(segments_i) > 0: + segments.extend(*segments_i) + + + cache["prev_samples"] = audio_sample[:-m] + if _is_final: + cache = {} + + ibest_writer = None + if ibest_writer is None and kwargs.get("output_dir") is not None: + writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = writer[f"{1}best_recog"] + + results = [] + result_i = {"key": key[0], "value": segments} + if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas": + result_i = json.dumps(result_i) + + results.append(result_i) + + if ibest_writer is not None: + ibest_writer["text"][key[0]] = segments + + + return 
results, meta_data + + + def DetectCommonFrames(self, cache: dict = {}) -> int: + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: + return 0 + for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): + frame_state = FrameState.kFrameStateInvalid + frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) + self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) + + return 0 + + def DetectLastFrames(self, cache: dict = {}) -> int: + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: + return 0 + for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): + frame_state = FrameState.kFrameStateInvalid + frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache) + if i != 0: + self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache) + else: + self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache) + + return 0 + + def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None: + tmp_cur_frm_state = FrameState.kFrameStateInvalid + if cur_frm_state == FrameState.kFrameStateSpeech: + if math.fabs(1.0) > self.vad_opts.fe_prior_thres: + tmp_cur_frm_state = FrameState.kFrameStateSpeech + else: + tmp_cur_frm_state = FrameState.kFrameStateSil + elif cur_frm_state == FrameState.kFrameStateSil: + tmp_cur_frm_state = FrameState.kFrameStateSil + state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache) + frm_shift_in_ms = self.vad_opts.frame_in_ms + if AudioChangeState.kChangeStateSil2Speech == state_change: + silence_frame_count = cache["stats"].continous_silence_frame_count + cache["stats"].continous_silence_frame_count = 0 + cache["stats"].pre_end_silence_detected = False + start_frame = 0 + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: + start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache)) + self.OnVoiceStart(start_frame, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment + for t in range(start_frame + 1, cur_frm_idx + 1): + self.OnVoiceDetected(t, cache=cache) + elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: + for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx): + self.OnVoiceDetected(t, cache=cache) + if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ + self.vad_opts.max_single_segment_time / frm_shift_in_ms: + self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected + elif not is_final_frame: + self.OnVoiceDetected(cur_frm_idx, cache=cache) + else: + self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) + else: + pass + elif AudioChangeState.kChangeStateSpeech2Sil == state_change: + cache["stats"].continous_silence_frame_count = 0 + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: + pass + elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: + if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ + self.vad_opts.max_single_segment_time / frm_shift_in_ms: + self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) + cache["stats"].vad_state_machine = 
VadStateMachine.kVadInStateEndPointDetected + elif not is_final_frame: + self.OnVoiceDetected(cur_frm_idx, cache=cache) + else: + self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) + else: + pass + elif AudioChangeState.kChangeStateSpeech2Speech == state_change: + cache["stats"].continous_silence_frame_count = 0 + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: + if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ + self.vad_opts.max_single_segment_time / frm_shift_in_ms: + cache["stats"].max_time_out = True + self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected + elif not is_final_frame: + self.OnVoiceDetected(cur_frm_idx, cache=cache) + else: + self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) + else: + pass + elif AudioChangeState.kChangeStateSil2Sil == state_change: + cache["stats"].continous_silence_frame_count += 1 + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: + # silence timeout, return zero length decision + if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and ( + cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ + or (is_final_frame and cache["stats"].number_end_time_detected == 0): + for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx): + self.OnSilenceDetected(t, cache=cache) + self.OnVoiceStart(0, True, cache=cache) + self.OnVoiceEnd(0, True, False, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected + else: + if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache): + self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache) + elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: + if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh: + lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms) + if self.vad_opts.do_extend: + lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms) + lookback_frame -= 1 + lookback_frame = max(0, lookback_frame) + self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected + elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \ + self.vad_opts.max_single_segment_time / frm_shift_in_ms: + self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache) + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected + elif self.vad_opts.do_extend and not is_final_frame: + if cache["stats"].continous_silence_frame_count <= int( + self.vad_opts.lookahead_time_end_point / frm_shift_in_ms): + self.OnVoiceDetected(cur_frm_idx, cache=cache) + else: + self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache) + else: + pass + + if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \ + self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value: + self.ResetDetection(cache=cache) diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index 9f3c3f3b6..468d23f39 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -33,7 +33,6 @@ class Paraformer(torch.nn.Module): 
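Stripped of the buffer bookkeeping, the DetectOneFrame logic above is a small state machine: the window detector emits one change event per frame, kChangeStateSil2Speech opens a segment (OnVoiceStart), and a long enough run of kChangeStateSil2Sil while inside a speech segment closes it (OnVoiceEnd). The sketch below shows only those transitions; the enum and function names are illustrative stand-ins rather than the FunASR classes, and the max-segment timeouts and lookback handling are omitted.

from enum import Enum, auto

class VadState(Enum):
    START_NOT_DETECTED = auto()
    IN_SPEECH_SEGMENT = auto()
    END_POINT_DETECTED = auto()

class Change(Enum):
    SIL2SPEECH = auto()
    SPEECH2SIL = auto()
    SPEECH2SPEECH = auto()
    SIL2SIL = auto()

def step(state: VadState, change: Change,
         trailing_silence_ms: int, max_end_silence_ms: int) -> VadState:
    """One frame of the simplified transition logic (timeout and lookback branches omitted)."""
    if change is Change.SIL2SPEECH and state is VadState.START_NOT_DETECTED:
        return VadState.IN_SPEECH_SEGMENT      # roughly where OnVoiceStart fires
    if change is Change.SIL2SIL and state is VadState.IN_SPEECH_SEGMENT \
            and trailing_silence_ms >= max_end_silence_ms:
        return VadState.END_POINT_DETECTED     # roughly where OnVoiceEnd fires
    return state

# usage: speech onset, then 900 ms of trailing silence against an 800 ms threshold
s = VadState.START_NOT_DETECTED
s = step(s, Change.SIL2SPEECH, 0, 800)
s = step(s, Change.SIL2SIL, 900, 800)
print(s)  # VadState.END_POINT_DETECTED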
def __init__( self, - # token_list: Union[Tuple[str, ...], List[str]], specaug: Optional[str] = None, specaug_conf: Optional[Dict] = None, normalize: str = None, diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml index 3972caaa1..bccf63871 100644 --- a/funasr/models/paraformer/template.yaml +++ b/funasr/models/paraformer/template.yaml @@ -6,7 +6,6 @@ # tables.print() # network architecture -#model: funasr.models.paraformer.model:Paraformer model: Paraformer model_conf: ctc_weight: 0.0 @@ -87,13 +86,6 @@ train_conf: accum_grad: 1 grad_clip: 5 max_epoch: 150 - val_scheduler_criterion: - - valid - - acc - best_model_criterion: - - - valid - - acc - - max keep_nbest_models: 10 avg_nbest_model: 5 log_interval: 50 diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py index 190ada0f3..35752829f 100644 --- a/funasr/models/sanm/decoder.py +++ b/funasr/models/sanm/decoder.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + from typing import List from typing import Tuple import logging @@ -193,10 +198,9 @@ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size @tables.register("decoder_classes", "FsmnDecoder") class FsmnDecoder(BaseTransformerDecoder): """ - Author: Speech Lab of DAMO Academy, Alibaba Group - SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition + Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin + San-m: Memory equipped self-attention for end-to-end speech recognition https://arxiv.org/abs/2006.01713 - """ def __init__( diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py index cb4e21af4..069c527a2 100644 --- a/funasr/models/sanm/encoder.py +++ b/funasr/models/sanm/encoder.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + from typing import List from typing import Optional from typing import Sequence @@ -156,10 +161,9 @@ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): @tables.register("encoder_classes", "SANMEncoder") class SANMEncoder(nn.Module): """ - Author: Speech Lab of DAMO Academy, Alibaba Group + Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin San-m: Memory equipped self-attention for end-to-end speech recognition https://arxiv.org/abs/2006.01713 - """ def __init__( diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py index 4dc882541..0cef54061 100644 --- a/funasr/models/sanm/model.py +++ b/funasr/models/sanm/model.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + import logging import torch @@ -7,7 +12,11 @@ @tables.register("model_classes", "SANM") class SANM(Transformer): - """CTC-attention hybrid Encoder-Decoder model""" + """ + Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin + San-m: Memory equipped self-attention for end-to-end speech recognition + https://arxiv.org/abs/2006.01713 + """ def __init__( self, diff --git a/funasr/models/sanm/template.yaml b/funasr/models/sanm/template.yaml new file mode 100644 index 000000000..156926f2c --- /dev/null +++ b/funasr/models/sanm/template.yaml @@ -0,0 +1,121 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. + +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: SANM +model_conf: + ctc_weight: 0.0 + lsm_weight: 0.1 + length_normalized_loss: true + +# encoder +encoder: SANMEncoder +encoder_conf: + output_size: 512 + attention_heads: 4 + linear_units: 2048 + num_blocks: 50 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: pe + pos_enc_class: SinusoidalPositionEncoder + normalize_before: true + kernel_size: 11 + sanm_shfit: 0 + selfattention_layer_type: sanm + +# decoder +decoder: FsmnDecoder +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 16 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + att_layer_num: 16 + kernel_size: 11 + sanm_shfit: 0 + + + +# frontend related +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 7 + lfr_n: 6 + +specaug: SpecAugLFR +specaug_conf: + apply_time_warp: false + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + lfr_rate: 6 + num_freq_mask: 1 + apply_time_mask: true + time_mask_width_range: + - 0 + - 12 + num_time_mask: 1 + +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 150 + val_scheduler_criterion: + - valid + - acc + best_model_criterion: + - - valid + - acc + - max + keep_nbest_models: 10 + avg_nbest_model: 5 + log_interval: 50 + +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 + +dataset: AudioDataset +dataset_conf: + index_ds: IndexDSJsonl + batch_sampler: DynamicBatchLocalShuffleSampler + batch_type: example # example or length + batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; + max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, + buffer_size: 500 + shuffle: True + num_workers: 0 + +tokenizer: CharTokenizer +tokenizer_conf: + unk_symbol: + split_with_space: true + + +ctc_conf: + dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: true + +normalize: null diff --git a/funasr/models/scama/sanm_decoder.py b/funasr/models/scama/decoder.py similarity index 99% rename from funasr/models/scama/sanm_decoder.py rename to funasr/models/scama/decoder.py index 4222e5f85..9dcb9da72 100644 --- a/funasr/models/scama/sanm_decoder.py +++ b/funasr/models/scama/decoder.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
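In the dataset_conf above, batch_type "example" means batch_size counts samples, while batch_type "length" treats batch_size as a source_token_len + target_token_len budget and max_token_length drops over-long samples. A rough sketch of that length-based grouping follows; it assumes each sample already knows its token lengths, and the helper name batch_by_length is illustrative rather than the actual DynamicBatchLocalShuffleSampler.

from typing import Iterable, List, Tuple

def batch_by_length(samples: Iterable[Tuple[int, int]],
                    batch_size_tokens: int = 2048,
                    max_token_length: int = 2048) -> List[List[int]]:
    """Group sample indices so each batch stays within a total-token budget."""
    batches, cur, cur_tokens = [], [], 0
    for idx, (src_len, tgt_len) in enumerate(samples):
        tokens = src_len + tgt_len
        if tokens > max_token_length:          # mirrors the max_token_length filter
            continue
        if cur and cur_tokens + tokens > batch_size_tokens:
            batches.append(cur)
            cur, cur_tokens = [], 0
        cur.append(idx)
        cur_tokens += tokens
    if cur:
        batches.append(cur)
    return batches

# usage: three short utterances share a batch, the long one starts a new batch
print(batch_by_length([(300, 20), (500, 40), (700, 60), (1500, 100)], batch_size_tokens=1700))
# -> [[0, 1, 2], [3]]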
+# MIT License (https://opensource.org/licenses/MIT) + from typing import List from typing import Tuple import logging @@ -192,11 +197,11 @@ def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size @tables.register("decoder_classes", "FsmnDecoderSCAMAOpt") class FsmnDecoderSCAMAOpt(BaseTransformerDecoder): """ - Author: Speech Lab of DAMO Academy, Alibaba Group + Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition - https://arxiv.org/abs/2006.01713 - + https://arxiv.org/abs/2006.01712 """ + def __init__( self, vocab_size: int, diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/encoder.py similarity index 98% rename from funasr/models/scama/sanm_encoder.py rename to funasr/models/scama/encoder.py index 5e28db7df..3651e6128 100644 --- a/funasr/models/scama/sanm_encoder.py +++ b/funasr/models/scama/encoder.py @@ -1,3 +1,8 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + from typing import List from typing import Optional from typing import Sequence @@ -157,10 +162,9 @@ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0): @tables.register("encoder_classes", "SANMEncoderChunkOpt") class SANMEncoderChunkOpt(nn.Module): """ - Author: Speech Lab of DAMO Academy, Alibaba Group + Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition - https://arxiv.org/abs/2006.01713 - + https://arxiv.org/abs/2006.01712 """ def __init__( diff --git a/funasr/models/scama/model.py b/funasr/models/scama/model.py new file mode 100644 index 000000000..aec6fe329 --- /dev/null +++ b/funasr/models/scama/model.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +import time +import torch +import torch.nn as nn +import torch.functional as F +import logging +from typing import Dict, Tuple +from contextlib import contextmanager +from distutils.version import LooseVersion + +from funasr.register import tables +from funasr.models.ctc.ctc import CTC +from funasr.utils import postprocess_utils +from funasr.metrics.compute_acc import th_accuracy +from funasr.utils.datadir_writer import DatadirWriter +from funasr.models.paraformer.model import Paraformer +from funasr.models.paraformer.search import Hypothesis +from funasr.models.paraformer.cif_predictor import mae_loss +from funasr.train_utils.device_funcs import force_gatherable +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.models.scama.utils import sequence_mask + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + +@tables.register("model_classes", "SCAMA") +class SCAMA(nn.Module): + """ + Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie + SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition + https://arxiv.org/abs/2006.01712 + """ + + def __init__( + self, + specaug: str = None, + specaug_conf: dict = None, + normalize: str = None, + normalize_conf: dict = None, + encoder: str = None, + encoder_conf: dict = None, + decoder: str = None, + decoder_conf: dict = None, + ctc: str = None, + ctc_conf: dict = None, + ctc_weight: float = 0.5, + predictor: str = None, + predictor_conf: dict = None, + predictor_bias: int = 0, + predictor_weight: float = 0.0, + input_size: int = 80, + vocab_size: int = -1, + ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + share_embedding: bool = False, + **kwargs, + ): + + super().__init__() + + if specaug is not None: + specaug_class = tables.specaug_classes.get(specaug) + specaug = specaug_class(**specaug_conf) + + if normalize is not None: + normalize_class = tables.normalize_classes.get(normalize) + normalize = normalize_class(**normalize_conf) + + encoder_class = tables.encoder_classes.get(encoder) + encoder = encoder_class(input_size=input_size, **encoder_conf) + encoder_output_size = encoder.output_size() + + decoder_class = tables.decoder_classes.get(decoder) + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + **decoder_conf, + ) + if ctc_weight > 0.0: + + if ctc_conf is None: + ctc_conf = {} + + ctc = CTC( + odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf + ) + + predictor_class = tables.predictor_classes.get(predictor) + predictor = predictor_class(**predictor_conf) + + # note that eos is the same as sos (equivalent ID) + self.blank_id = blank_id + self.sos = sos if sos is not None else vocab_size - 1 + self.eos = eos if eos is not None else vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + + self.specaug = specaug + self.normalize = normalize + + self.encoder = encoder + + + if ctc_weight == 1.0: + self.decoder = None + else: 
+ self.decoder = decoder + + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + + self.predictor = predictor + self.predictor_weight = predictor_weight + self.predictor_bias = predictor_bias + + self.criterion_pre = mae_loss(normalize_length=length_normalized_loss) + + self.share_embedding = share_embedding + if self.share_embedding: + self.decoder.embed = None + + self.length_normalized_loss = length_normalized_loss + self.beam_search = None + self.error_calculator = None + + if self.encoder.overlap_chunk_cls is not None: + from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder + self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder + self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk") + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + + decoding_ind = kwargs.get("decoding_ind") + if len(text_lengths.size()) > 1: + text_lengths = text_lengths[:, 0] + if len(speech_lengths.size()) > 1: + speech_lengths = speech_lengths[:, 0] + + batch_size = speech.shape[0] + + # Encoder + ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind) + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind) + + + loss_ctc, cer_ctc = None, None + loss_pre = None + stats = dict() + + # decoder: CTC branch + + if self.ctc_weight > 0.0: + + encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, + encoder_out_lens, + chunk_outs=None) + + + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths + ) + # Collect CTC branch stats + stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None + stats["cer_ctc"] = cer_ctc + + # decoder: Attention decoder branch + loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 3. CTC-Att loss definition + if self.ctc_weight == 0.0: + loss = loss_att + loss_pre * self.predictor_weight + else: + loss = self.ctc_weight * loss_ctc + ( + 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + + # Collect Attn branch stats + stats["loss_att"] = loss_att.detach() if loss_att is not None else None + stats["acc"] = acc_att + stats["cer"] = cer_att + stats["wer"] = wer_att + stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None + + stats["loss"] = torch.clone(loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + if self.length_normalized_loss: + batch_size = (text_lengths + self.predictor_bias).sum() + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + # Forward encoder + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + return encoder_out, encoder_out_lens + + def encode_chunk( + self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + ind: int + """ + with autocast(False): + + # Data augmentation + if self.specaug is not None and self.training: + speech, speech_lengths = self.specaug(speech, speech_lengths) + + # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + speech, speech_lengths = self.normalize(speech, speech_lengths) + + # Forward encoder + encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"]) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + return encoder_out, torch.tensor([encoder_out.size(1)]) + + def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs): + is_final = kwargs.get("is_final", False) + + return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final) + + def _calc_att_predictor_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype, + device=encoder_out.device)[:, None, :] + mask_chunk_predictor = None + if self.encoder.overlap_chunk_cls is not None: + mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None, + device=encoder_out.device, + batch_size=encoder_out.size( + 0)) + mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device, + batch_size=encoder_out.size(0)) + encoder_out = encoder_out * mask_shfit_chunk + pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out, + ys_out_pad, + encoder_out_mask, + ignore_id=self.ignore_id, + mask_chunk_predictor=mask_chunk_predictor, + target_label_length=ys_in_lens, + ) + predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas, + encoder_out_lens) + + + encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur + attention_chunk_center_bias = 0 + attention_chunk_size = encoder_chunk_size + decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur + mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None, + device=encoder_out.device, + batch_size=encoder_out.size( + 0)) + scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn( + predictor_alignments=predictor_alignments, + encoder_sequence_length=encoder_out_lens, + chunk_size=1, + encoder_chunk_size=encoder_chunk_size, + 
attention_chunk_center_bias=attention_chunk_center_bias, + attention_chunk_size=attention_chunk_size, + attention_chunk_type=self.decoder_attention_chunk_type, + step=None, + predictor_mask_chunk_hopping=mask_chunk_predictor, + decoder_att_look_back_factor=decoder_att_look_back_factor, + mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder, + target_length=ys_in_lens, + is_training=self.training, + ) + + + # try: + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, + encoder_out_lens, + ys_in_pad, + ys_in_lens, + chunk_mask=scama_mask, + pre_acoustic_embeds=pre_acoustic_embeds, + + ) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + # predictor loss + loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length) + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att, loss_pre + + def calc_predictor_mask( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor = None, + ys_pad_lens: torch.Tensor = None, + ): + # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + # ys_in_lens = ys_pad_lens + 1 + ys_out_pad, ys_in_lens = None, None + + encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype, + device=encoder_out.device)[:, None, :] + mask_chunk_predictor = None + + mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None, + device=encoder_out.device, + batch_size=encoder_out.size( + 0)) + mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device, + batch_size=encoder_out.size(0)) + encoder_out = encoder_out * mask_shfit_chunk + pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out, + ys_out_pad, + encoder_out_mask, + ignore_id=self.ignore_id, + mask_chunk_predictor=mask_chunk_predictor, + target_label_length=ys_in_lens, + ) + predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas, + encoder_out_lens) + + + encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur + attention_chunk_center_bias = 0 + attention_chunk_size = encoder_chunk_size + decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur + mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None, + device=encoder_out.device, + batch_size=encoder_out.size( + 0)) + scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn( + predictor_alignments=predictor_alignments, + encoder_sequence_length=encoder_out_lens, + chunk_size=1, + encoder_chunk_size=encoder_chunk_size, + attention_chunk_center_bias=attention_chunk_center_bias, + attention_chunk_size=attention_chunk_size, + attention_chunk_type=self.decoder_attention_chunk_type, + step=None, + predictor_mask_chunk_hopping=mask_chunk_predictor, + decoder_att_look_back_factor=decoder_att_look_back_factor, + mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder, + target_length=ys_in_lens, + is_training=self.training, + ) + + return pre_acoustic_embeds, pre_token_length, predictor_alignments, 
predictor_alignments_len, scama_mask + + def init_beam_search(self, + **kwargs, + ): + from funasr.models.scama.beam_search import BeamSearchScama + from funasr.models.transformer.scorers.ctc import CTCPrefixScorer + from funasr.models.transformer.scorers.length_bonus import LengthBonus + + # 1. Build ASR model + scorers = {} + + if self.ctc != None: + ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) + scorers.update( + ctc=ctc + ) + token_list = kwargs.get("token_list") + scorers.update( + decoder=self.decoder, + length_bonus=LengthBonus(len(token_list)), + ) + + # 3. Build ngram model + # ngram is not supported now + ngram = None + scorers["ngram"] = ngram + + weights = dict( + decoder=1.0 - kwargs.get("decoding_ctc_weight"), + ctc=kwargs.get("decoding_ctc_weight", 0.0), + lm=kwargs.get("lm_weight", 0.0), + ngram=kwargs.get("ngram_weight", 0.0), + length_bonus=kwargs.get("penalty", 0.0), + ) + beam_search = BeamSearchScama( + beam_size=kwargs.get("beam_size", 2), + weights=weights, + scorers=scorers, + sos=self.sos, + eos=self.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", + ) + # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() + # for scorer in scorers.values(): + # if isinstance(scorer, torch.nn.Module): + # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() + self.beam_search = beam_search + + def generate_chunk(self, + speech, + speech_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + **kwargs, + ): + cache = kwargs.get("cache", {}) + speech = speech.to(device=kwargs["device"]) + speech_lengths = speech_lengths.to(device=kwargs["device"]) + + # Encoder + encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, + is_final=kwargs.get("is_final", False)) + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # predictor + predictor_outs = self.calc_predictor_chunk(encoder_out, + encoder_out_lens, + cache=cache, + is_final=kwargs.get("is_final", False), + ) + pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ + predictor_outs[2], predictor_outs[3] + pre_token_length = pre_token_length.round().long() + + + if torch.max(pre_token_length) < 1: + return [] + decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out, + encoder_out_lens, + pre_acoustic_embeds, + pre_token_length, + cache=cache + ) + decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] + + results = [] + b, n, d = decoder_out.size() + if isinstance(key[0], (list, tuple)): + key = key[0] + for i in range(b): + x = encoder_out[i, :encoder_out_lens[i], :] + am_scores = decoder_out[i, :pre_token_length[i], :] + if self.beam_search is not None: + nbest_hyps = self.beam_search( + x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), + minlenratio=kwargs.get("minlenratio", 0.0) + ) + + nbest_hyps = nbest_hyps[: self.nbest] + else: + + yseq = am_scores.argmax(dim=-1) + score = am_scores.max(dim=-1)[0] + score = torch.sum(score, dim=-1) + # pad with mask tokens to ensure compatibility with sos/eos tokens + yseq = torch.tensor( + [self.sos] + yseq.tolist() + [self.eos], device=yseq.device + ) + nbest_hyps = [Hypothesis(yseq=yseq, score=score)] + for nbest_idx, hyp in enumerate(nbest_hyps): + + # remove sos/eos and get results + last_pos = -1 + if isinstance(hyp.yseq, list): + token_int = 
hyp.yseq[1:last_pos] + else: + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) + + # Change integer-ids to tokens + token = tokenizer.ids2tokens(token_int) + # text = tokenizer.tokens2text(token) + + result_i = token + + results.extend(result_i) + + return results + + def init_cache(self, cache: dict = {}, **kwargs): + chunk_size = kwargs.get("chunk_size", [0, 10, 5]) + encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0) + decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0) + batch_size = 1 + + enc_output_size = kwargs["encoder_conf"]["output_size"] + feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"] + cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, + "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), + "tail_chunk": False} + cache["encoder"] = cache_encoder + + cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None, + "chunk_size": chunk_size} + cache["decoder"] = cache_decoder + cache["frontend"] = {} + cache["prev_samples"] = torch.empty(0) + + return cache + + def inference(self, + data_in, + data_lengths=None, + key: list = None, + tokenizer=None, + frontend=None, + cache: dict = {}, + **kwargs, + ): + + # init beamsearch + is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None + is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None + if self.beam_search is None and (is_use_lm or is_use_ctc): + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) + + if len(cache) == 0: + self.init_cache(cache, **kwargs) + + meta_data = {} + chunk_size = kwargs.get("chunk_size", [0, 10, 5]) + chunk_stride_samples = int(chunk_size[1] * 960) # 600ms + + time1 = time.perf_counter() + cfg = {"is_final": kwargs.get("is_final", False)} + audio_sample_list = load_audio_text_image_video(data_in, + fs=frontend.fs, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + cache=cfg, + ) + _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True + + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + assert len(audio_sample_list) == 1, "batch_size must be set 1" + + audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0])) + + n = int(len(audio_sample) // chunk_stride_samples + int(_is_final)) + m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))) + tokens = [] + for i in range(n): + kwargs["is_final"] = _is_final and i == n - 1 + audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples] + + # extract fbank feats + speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"), + frontend=frontend, cache=cache["frontend"], + is_final=kwargs["is_final"]) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + + tokens_i = self.generate_chunk(speech, speech_lengths, key=key, 
tokenizer=tokenizer, cache=cache, + frontend=frontend, **kwargs) + tokens.extend(tokens_i) + + text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens) + + result_i = {"key": key[0], "text": text_postprocessed} + result = [result_i] + + cache["prev_samples"] = audio_sample[:-m] + if _is_final: + self.init_cache(cache, **kwargs) + + if kwargs.get("output_dir"): + writer = DatadirWriter(kwargs.get("output_dir")) + ibest_writer = writer[f"{1}best_recog"] + ibest_writer["token"][key[0]] = " ".join(tokens) + ibest_writer["text"][key[0]] = text_postprocessed + + return result, meta_data diff --git a/funasr/models/scama/template.yaml b/funasr/models/scama/template.yaml new file mode 100644 index 000000000..f647a9222 --- /dev/null +++ b/funasr/models/scama/template.yaml @@ -0,0 +1,127 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. + +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: SCAMA +model_conf: + ctc_weight: 0.0 + lsm_weight: 0.1 + length_normalized_loss: true + +# encoder +encoder: SANMEncoderChunkOpt +encoder_conf: + output_size: 512 + attention_heads: 4 + linear_units: 2048 + num_blocks: 50 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: pe + pos_enc_class: SinusoidalPositionEncoder + normalize_before: true + kernel_size: 11 + sanm_shfit: 0 + selfattention_layer_type: sanm + +# decoder +decoder: FsmnDecoderSCAMAOpt +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 16 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + att_layer_num: 16 + kernel_size: 11 + sanm_shfit: 0 + +predictor: CifPredictorV2 +predictor_conf: + idim: 512 + threshold: 1.0 + l_order: 1 + r_order: 1 + tail_threshold: 0.45 + +# frontend related +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 7 + lfr_n: 6 + +specaug: SpecAugLFR +specaug_conf: + apply_time_warp: false + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + lfr_rate: 6 + num_freq_mask: 1 + apply_time_mask: true + time_mask_width_range: + - 0 + - 12 + num_time_mask: 1 + +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 150 + val_scheduler_criterion: + - valid + - acc + best_model_criterion: + - - valid + - acc + - max + keep_nbest_models: 10 + avg_nbest_model: 5 + log_interval: 50 + +optim: adam +optim_conf: + lr: 0.0005 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 + +dataset: AudioDataset +dataset_conf: + index_ds: IndexDSJsonl + batch_sampler: DynamicBatchLocalShuffleSampler + batch_type: example # example or length + batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; + max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, + buffer_size: 500 + shuffle: True + num_workers: 0 + +tokenizer: CharTokenizer +tokenizer_conf: + unk_symbol: + split_with_space: true + + +ctc_conf: + dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: true + +normalize: null diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/model.py similarity index 95% rename from funasr/models/uniasr/e2e_uni_asr.py rename to funasr/models/uniasr/model.py 
index 390d27418..de80d4ac7 100644 --- a/funasr/models/uniasr/e2e_uni_asr.py +++ b/funasr/models/uniasr/model.py @@ -1,85 +1,73 @@ -import logging -from contextlib import contextmanager -from distutils.version import LooseVersion -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) +import time import torch +import logging +from torch.cuda.amp import autocast +from typing import Union, Dict, List, Tuple, Optional -from funasr.models.e2e_asr_common import ErrorCalculator +from funasr.register import tables +from funasr.models.ctc.ctc import CTC +from funasr.utils import postprocess_utils from funasr.metrics.compute_acc import th_accuracy -from funasr.models.transformer.utils.add_sos_eos import add_sos_eos -from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -from funasr.models.ctc import CTC -from funasr.models.decoder.abs_decoder import AbsDecoder -from funasr.models.encoder.abs_encoder import AbsEncoder -from funasr.frontends.abs_frontend import AbsFrontend -from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.layers.abs_normalize import AbsNormalize -from funasr.train_utils.device_funcs import force_gatherable -from funasr.models.base_model import FunASRModel -from funasr.models.scama.chunk_utilis import sequence_mask +from funasr.utils.datadir_writer import DatadirWriter +from funasr.models.paraformer.search import Hypothesis from funasr.models.paraformer.cif_predictor import mae_loss - -if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - from torch.cuda.amp import autocast -else: - # Nothing to do if torch<1.6.0 - @contextmanager - def autocast(enabled=True): - yield +from funasr.train_utils.device_funcs import force_gatherable +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.transformer.utils.add_sos_eos import add_sos_eos +from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -class UniASR(FunASRModel): +@tables.register("model_classes", "UniASR") +class UniASR(torch.nn.Module): """ Author: Speech Lab of DAMO Academy, Alibaba Group """ def __init__( self, - vocab_size: int, - token_list: Union[Tuple[str, ...], List[str]], - frontend: Optional[AbsFrontend], - specaug: Optional[AbsSpecAug], - normalize: Optional[AbsNormalize], - encoder: AbsEncoder, - decoder: AbsDecoder, - ctc: CTC, + specaug: Optional[str] = None, + specaug_conf: Optional[Dict] = None, + normalize: str = None, + normalize_conf: Optional[Dict] = None, + encoder: str = None, + encoder_conf: Optional[Dict] = None, + decoder: str = None, + decoder_conf: Optional[Dict] = None, + ctc: str = None, + ctc_conf: Optional[Dict] = None, + predictor: str = None, + predictor_conf: Optional[Dict] = None, ctc_weight: float = 0.5, - interctc_weight: float = 0.0, + input_size: int = 80, + vocab_size: int = -1, ignore_id: int = -1, + blank_id: int = 0, + sos: int = 1, + eos: int = 2, lsm_weight: float = 0.0, length_normalized_loss: bool = False, - report_cer: bool = True, - report_wer: bool = True, - sym_space: str = "", - 
sym_blank: str = "", - extract_feats_in_collect_stats: bool = True, - predictor=None, + # report_cer: bool = True, + # report_wer: bool = True, + # sym_space: str = "", + # sym_blank: str = "", + # extract_feats_in_collect_stats: bool = True, + # predictor=None, predictor_weight: float = 0.0, - decoder_attention_chunk_type: str = 'chunk', - encoder2: AbsEncoder = None, - decoder2: AbsDecoder = None, - ctc2: CTC = None, - ctc_weight2: float = 0.5, - interctc_weight2: float = 0.0, - predictor2=None, - predictor_weight2: float = 0.0, - decoder_attention_chunk_type2: str = 'chunk', - stride_conv=None, - loss_weight_model1: float = 0.5, - enable_maas_finetune: bool = False, - freeze_encoder2: bool = False, - preencoder: Optional[AbsPreEncoder] = None, - postencoder: Optional[AbsPostEncoder] = None, + predictor_bias: int = 0, + sampling_ratio: float = 0.2, + share_embedding: bool = False, + # preencoder: Optional[AbsPreEncoder] = None, + # postencoder: Optional[AbsPostEncoder] = None, + use_1st_decoder_loss: bool = False, encoder1_encoder2_joint_training: bool = True, + **kwargs, + ): assert 0.0 <= ctc_weight <= 1.0, ctc_weight assert 0.0 <= interctc_weight < 1.0, interctc_weight @@ -443,10 +431,8 @@ def forward( # force_gatherable: to-device and to-tensor if scalar for DataParallel if self.length_normalized_loss: batch_size = int((text_lengths + 1).sum()) -<<<<<<< HEAD:funasr/models/uniasr/e2e_uni_asr.py -======= ->>>>>>> main:funasr/models/e2e_uni_asr.py + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight diff --git a/funasr/models/uniasr/template.yaml b/funasr/models/uniasr/template.yaml new file mode 100644 index 000000000..f4815c131 --- /dev/null +++ b/funasr/models/uniasr/template.yaml @@ -0,0 +1,178 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. 
+ +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: UniASR +model_conf: + ctc_weight: 0.0 + lsm_weight: 0.1 + length_normalized_loss: true + predictor_weight: 1.0 + decoder_attention_chunk_type: chunk + ctc_weight2: 0.0 + predictor_weight2: 1.0 + decoder_attention_chunk_type2: chunk + loss_weight_model1: 0.5 + +# encoder +encoder: SANMEncoderChunkOpt +encoder_conf: + output_size: 320 + attention_heads: 4 + linear_units: 1280 + num_blocks: 35 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: pe + pos_enc_class: SinusoidalPositionEncoder + normalize_before: true + kernel_size: 11 + sanm_shfit: 0 + selfattention_layer_type: sanm + chunk_size: [20, 60] + stride: [10, 40] + pad_left: [5, 10] + encoder_att_look_back_factor: [0, 0] + decoder_att_look_back_factor: [0, 0] + +# decoder +decoder: FsmnDecoderSCAMAOpt +decoder_conf: + attention_dim: 256 + attention_heads: 4 + linear_units: 1024 + num_blocks: 12 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + att_layer_num: 6 + kernel_size: 11 + concat_embeds: true + +predictor: CifPredictorV2 +predictor_conf: + idim: 320 + threshold: 1.0 + l_order: 1 + r_order: 1 + +encoder2: SANMEncoderChunkOpt +encoder2_conf: + output_size: 320 + attention_heads: 4 + linear_units: 1280 + num_blocks: 20 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: pe + pos_enc_class: SinusoidalPositionEncoder + normalize_before: true + kernel_size: 21 + sanm_shfit: 0 + selfattention_layer_type: sanm + chunk_size: [45, 70] + stride: [35, 50] + pad_left: [5, 10] + encoder_att_look_back_factor: [0, 0] + decoder_att_look_back_factor: [0, 0] + +decoder2: FsmnDecoderSCAMAOpt +decoder2_conf: + attention_dim: 320 + attention_heads: 4 + linear_units: 1280 + num_blocks: 12 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + att_layer_num: 6 + kernel_size: 11 + concat_embeds: true + +predictor2: CifPredictorV2 +predictor2_conf: + idim: 320 + threshold: 1.0 + l_order: 1 + r_order: 1 + +stride_conv: stride_conv1d +stride_conv_conf: + kernel_size: 2 + stride: 2 + pad: [0, 1] + +# frontend related +frontend: WavFrontendOnline +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 7 + lfr_n: 6 + +specaug: SpecAugLFR +specaug_conf: + apply_time_warp: false + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + lfr_rate: 6 + num_freq_mask: 1 + apply_time_mask: true + time_mask_width_range: + - 0 + - 12 + num_time_mask: 1 + +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 150 + keep_nbest_models: 10 + avg_nbest_model: 5 + log_interval: 50 + +optim: adam +optim_conf: + lr: 0.0001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 30000 + +dataset: AudioDataset +dataset_conf: + index_ds: IndexDSJsonl + batch_sampler: DynamicBatchLocalShuffleSampler + batch_type: example # example or length + batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; + max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, + buffer_size: 500 + shuffle: True + num_workers: 0 + +tokenizer: CharTokenizer +tokenizer_conf: + unk_symbol: + split_with_space: true + + +ctc_conf: + 
dropout_rate: 0.0 + ctc_type: builtin + reduce: true + ignore_nan_grad: true +normalize: null
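
Note on the streaming path added in funasr/models/scama/model.py: inference() splits incoming audio into chunks of chunk_size[1] * 960 samples (600 ms at 16 kHz for the default chunk_size=[0, 10, 5]), keeps per-session state in the cache dict built by init_cache() (keys "encoder", "decoder", "frontend" and "prev_samples"), and carries leftover samples between calls through cache["prev_samples"]. A minimal caller-side sketch is below; it assumes the AutoModel front door and an illustrative streaming model id, so treat names and revisions as placeholders rather than the exact surface of this patch.

    # Caller-side streaming sketch (assumptions: AutoModel front door, an
    # illustrative streaming model id, soundfile for audio I/O).
    import soundfile
    from funasr import AutoModel

    chunk_size = [0, 10, 5]             # 10 * 60 ms = 600 ms of new audio per chunk
    chunk_stride = chunk_size[1] * 960  # samples per chunk at 16 kHz, as in inference()

    model = AutoModel(model="paraformer-zh-streaming")  # placeholder model id

    speech, _ = soundfile.read("asr_example.wav")       # 16 kHz mono audio
    cache = {}                                          # per-session state, filled by init_cache()
    total_chunks = (len(speech) - 1) // chunk_stride + 1

    for i in range(total_chunks):
        chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
        is_final = i == total_chunks - 1                # flush the tail on the last call
        print(model.generate(input=chunk, cache=cache, is_final=is_final,
                             chunk_size=chunk_size))

With a 16 000-sample (1 s) input and the 9 600-sample stride above, a non-final call yields n = 1 full chunk and m = 6 400 remaining samples, which matches the n/m bookkeeping in inference().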
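
The attention branch in _calc_att_predictor_loss returns two training signals: the attention-decoder loss (loss_att) and a predictor term (loss_pre) that compares the CIF predictor's continuous token-count estimate (pre_token_length) against the true target length (ys_in_lens) through criterion_pre. A stand-alone sketch of such a length term is below, using torch's plain L1 (MAE) loss; treat it as an illustration of the idea rather than the exact mae_loss helper imported from funasr.models.paraformer.cif_predictor, and note that the weighting of the two terms is model configuration, not shown here.

    # Stand-in for the predictor length loss: a mean-absolute-error penalty between
    # true target lengths and the predictor's continuous token-count estimate.
    import torch
    import torch.nn.functional as F

    target_lengths = torch.tensor([7.0, 12.0, 5.0])    # ys_in_lens for three utterances
    pre_token_length = torch.tensor([6.6, 12.4, 5.1])  # what the predictor integrated to

    loss_pre = F.l1_loss(pre_token_length, target_lengths)
    print(loss_pre)  # ~tensor(0.3000): mean absolute error over the batch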
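
Both new template.yaml files document the same dataset_conf knobs in passing: with batch_type: example, batch_size counts utterances; with batch_type: length it is a token budget of source_token_len + target_token_len; and samples whose combined length exceeds max_token_length are filtered out. The stand-alone sketch below (a hypothetical make_batches helper, not the real DynamicBatchLocalShuffleSampler) just restates those rules so the two modes are easy to compare.

    # Hypothetical illustration of the dataset_conf semantics described in the
    # templates; not the real DynamicBatchLocalShuffleSampler.
    def make_batches(samples, batch_type="example", batch_size=1, max_token_length=2048):
        """samples: list of (source_token_len, target_token_len) pairs."""
        # Drop over-long samples, as documented for max_token_length.
        kept = [(s, t) for s, t in samples if s + t <= max_token_length]

        batches, cur, budget = [], [], 0
        for s, t in kept:
            cost = 1 if batch_type == "example" else s + t
            if cur and budget + cost > batch_size:
                batches.append(cur)
                cur, budget = [], 0
            cur.append((s, t))
            budget += cost
        if cur:
            batches.append(cur)
        return batches

    data = [(120, 18), (4000, 60), (300, 45), (80, 10)]
    print(make_batches(data, batch_type="example", batch_size=2))   # two utterances per batch
    print(make_batches(data, batch_type="length", batch_size=500))  # ~500-token budget per batch

In the shipped templates batch_type is example with batch_size: 1, so each batch holds a single utterance unless these values are overridden at launch time.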