diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 1858e4aaf..becfd56e3 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -255,7 +255,6 @@ def __init__(self, self.waveform = None self.last_drop_frames = 0 - @tables.register("model_classes", "FsmnVADStreaming") class FsmnVADStreaming(nn.Module): """ @@ -500,8 +499,9 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {}, # # reset class variables and clear the dict for the next query # self.AllResetDetection() return segments - - def init_cache(self, cache: dict = {}): + + def init_cache(self, cache: dict = {}, **kwargs): + cache["frontend"] = {} cache["prev_samples"] = torch.empty(0) cache["encoder"] = {} @@ -528,9 +528,9 @@ def inference(self, cache: dict = {}, **kwargs, ): - # cache = kwargs.get("cache", {}) + if len(cache) == 0: - self.init_cache(cache) + self.init_cache(cache, **kwargs) meta_data = {} chunk_size = kwargs.get("chunk_size", 60000) # 50ms diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index 6cb65dda9..62d6be80b 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -147,6 +147,7 @@ def run(self): for epoch in range(self.start_epoch, self.max_epoch + 1): self._train_epoch(epoch) + if self.use_ddp or self.use_fsdp: dist.barrier() @@ -156,6 +157,7 @@ def run(self): if self.use_ddp or self.use_fsdp: dist.barrier() + if self.rank == 0: self._save_checkpoint(epoch) @@ -170,7 +172,7 @@ def run(self): if self.use_ddp or self.use_fsdp: dist.barrier() - + if self.writer: self.writer.close()