Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Funasr1.0 #1343

Merged
merged 6 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# example2
import torchaudio
import os
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
input_tensor, sample_rate = torchaudio.load(wav_file)
input_tensor = input_tensor.mean(0)
Expand All @@ -33,7 +34,7 @@

# example3
import soundfile
import os

wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
Expand Down
2 changes: 1 addition & 1 deletion funasr/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def main(**kwargs):
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
Expand Down
10 changes: 7 additions & 3 deletions funasr/datasets/audio_datasets/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def __init__(self, dataset,
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle and is_training
self.length_scale_source = kwargs.get("length_scale_source", 1.0)


def __len__(self):
return (self.total_samples-1) // self.batch_size + 1
Expand Down Expand Up @@ -53,8 +55,10 @@ def __iter__(self):

idx_map = self.shuffle_idx[idx]
# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
sample_len_cur = self.dataset.get_source_len(idx_map) + \
self.dataset.get_target_len(idx_map)
target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
sample_len_cur = source_len + target_len


datalen_with_index.append([idx, sample_len_cur])

Expand All @@ -66,7 +70,7 @@ def __iter__(self):

max_token_cur = max(max_token, sample_len_cur_raw)
max_token_padding = 1 + num_sample
if self.batch_type == 'length':
if self.batch_type != 'example':
max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
Expand Down
8 changes: 6 additions & 2 deletions funasr/models/whisper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@


from funasr.models.whisper.utils.decoding import detect_language as detect_language_function, decode as decode_function
from funasr.register import tables


@dataclass
class ModelDimensions:
Expand Down Expand Up @@ -128,6 +130,8 @@ def forward(
return x



@tables.register("encoder_classes", "WhisperEncoder")
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
Expand Down Expand Up @@ -158,7 +162,7 @@ def forward(self, x: Tensor):
x = self.ln_post(x)
return x


@tables.register("decoder_classes", "WhisperDecoder")
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
super().__init__()
Expand Down Expand Up @@ -193,7 +197,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):

return logits


@tables.register("model_classes", "Whisper")
class Whisper(nn.Module):
def __init__(self, dims: dict):
super().__init__()
Expand Down
Loading