-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtranscribe.py
78 lines (66 loc) · 2.08 KB
/
transcribe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import torch
from torch import nn
import torch.utils.data as data
import torch.nn.functional as F
import torchaudio
from deep_speech import (
TextTransform,
data_processing,
SpeechRecognitionModel,
GreedyDecoder,
)
def asr(data, sample_test_example=True):
    """Transcribe audio with the pretrained speech model and print the result.

    The input is converted to mel spectrograms and encoded labels through
    ``data_processing``, run through a ``SpeechRecognitionModel`` restored
    from ``./model_checkpoint/deep_speech.pth``, and decoded with
    ``GreedyDecoder``. Both the reference labels and the model's
    transcriptions are printed to stdout.

    Args:
        data: Collection of examples forwarded unchanged to
            ``data_processing``. The script entry point passes a torchaudio
            ``LIBRISPEECH`` dataset.
        sample_test_example: When true, only the first two processed
            examples are transcribed.

    Returns:
        None. Results are reported via ``print``.
    """
    # Architecture hyper-parameters; presumably these must match the values
    # the checkpoint was trained with -- verify before changing any of them.
    h_params = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
    }

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(7)
    device = torch.device("cuda" if use_cuda else "cpu")

    audio_transforms = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_mels=128
    )
    text_transform = TextTransform()

    # NOTE(review): the whole of `data` is processed before the two-example
    # slice below is taken; if `data_processing` accepts any iterable, slicing
    # first would avoid the extra work -- confirm against its implementation.
    spectrograms, labels, input_lengths, label_lengths = data_processing(
        data, audio_transforms, text_transform
    )

    if sample_test_example:
        spectrograms, labels, input_lengths, label_lengths = (
            spectrograms[:2],
            labels[:2],
            input_lengths[:2],
            label_lengths[:2],
        )

    model = SpeechRecognitionModel(
        h_params["n_cnn_layers"],
        h_params["n_rnn_layers"],
        h_params["rnn_dim"],
        h_params["n_class"],
        h_params["n_feats"],
        h_params["stride"],
        h_params["dropout"],
    ).to(device)

    # map_location remaps tensors stored from a CUDA device so the checkpoint
    # also loads on CPU-only hosts.
    state_dict = torch.load(
        "./model_checkpoint/deep_speech.pth", map_location=device
    )
    model.load_state_dict(state_dict)
    model.eval()

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(spectrograms.to(device))
        output = F.log_softmax(output, dim=2)

    decoded_preds, decoded_targets = GreedyDecoder(
        output, labels, label_lengths, text_transform
    )

    print("The original labels is", decoded_targets)
    print("**********")
    print("The transcription by the model is", decoded_preds)
if __name__ == "__main__":
    data_dir = "./data"
    # os.listdir alone raises FileNotFoundError when the directory is missing,
    # so check for existence as well to always surface the intended message.
    if not os.path.isdir(data_dir) or not os.listdir(data_dir):
        raise ValueError("upload audio files for transcribing")

    # Named `dataset` so the module-level `torch.utils.data as data` alias is
    # not rebound at global scope.
    dataset = torchaudio.datasets.LIBRISPEECH(
        data_dir, url="test-clean", download=False
    )
    asr(dataset)