Skip to content

Commit

Permalink
Merge branch 'main' into funasr1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
LauraGPT authored Jan 19, 2024
2 parents 391d762 + 12496e5 commit 82c7277
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
10 changes: 5 additions & 5 deletions funasr/models/fsmn_vad_streaming/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def __init__(self,
self.waveform = None
self.last_drop_frames = 0


@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
"""
Expand Down Expand Up @@ -500,8 +499,9 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
# # reset class variables and clear the dict for the next query
# self.AllResetDetection()
return segments

def init_cache(self, cache: dict = {}):

def init_cache(self, cache: dict = {}, **kwargs):

cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
Expand All @@ -528,9 +528,9 @@ def inference(self,
cache: dict = {},
**kwargs,
):
# cache = kwargs.get("cache", {})

if len(cache) == 0:
self.init_cache(cache)
self.init_cache(cache, **kwargs)

meta_data = {}
chunk_size = kwargs.get("chunk_size", 60000) # 50ms
Expand Down
4 changes: 3 additions & 1 deletion funasr/train_utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def run(self):
for epoch in range(self.start_epoch, self.max_epoch + 1):

self._train_epoch(epoch)


if self.use_ddp or self.use_fsdp:
dist.barrier()
Expand All @@ -156,6 +157,7 @@ def run(self):
if self.use_ddp or self.use_fsdp:
dist.barrier()


if self.rank == 0:
self._save_checkpoint(epoch)

Expand All @@ -170,7 +172,7 @@ def run(self):

if self.use_ddp or self.use_fsdp:
dist.barrier()

if self.writer:
self.writer.close()

Expand Down

0 comments on commit 82c7277

Please sign in to comment.