Skip to content

Commit

Permalink
add flash_attn,
Browse files Browse the repository at this point in the history
add audio_ctx, suppress_regex, temperature, no_timestamps - todo: DTW/grammar
remove unused functions (related to parsing command line arguments)
  • Loading branch information
jwijffels committed Oct 7, 2024
1 parent 2c24bbc commit 208f5d0
Showing 1 changed file with 14 additions and 150 deletions.
164 changes: 14 additions & 150 deletions src/rcpp_whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,152 +85,6 @@ struct whisper_params {
grammar_parser::parse_state grammar_parsed;
};

static void whisper_print_usage(int argc, char ** argv, const whisper_params & params);

static char * whisper_param_turn_lowercase(char * in){
int string_len = strlen(in);
for (int i = 0; i < string_len; i++){
*(in+i) = tolower((unsigned char)*(in+i));
}
return in;
}

static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];

if (arg == "-"){
params.fname_inp.push_back(arg);
continue;
}

if (arg[0] != '-') {
params.fname_inp.push_back(arg);
continue;
}

if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; }
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}

return true;
}

static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false");
fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false");
fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false");
fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str());
fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false");
fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false");
fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false");
fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", "");
fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}

struct whisper_print_user_data {
const whisper_params * params;
Expand Down Expand Up @@ -322,9 +176,10 @@ static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
class WhisperModel {
public:
struct whisper_context * ctx;
WhisperModel(std::string model, bool use_gpu = false){
WhisperModel(std::string model, bool use_gpu = false, bool flash_attn = false){
struct whisper_context_params cparams;
cparams.use_gpu = use_gpu;
cparams.flash_attn = flash_attn;
ctx = whisper_init_from_file_with_params(model.c_str(), cparams);
}
~WhisperModel(){
Expand Down Expand Up @@ -384,6 +239,10 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
Rcpp::stop("Unknown language");
}

if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}

// whisper init
Rcpp::XPtr<WhisperModel> whispermodel(model);
struct whisper_context * ctx = whispermodel->ctx;
Expand Down Expand Up @@ -460,25 +319,30 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;

//wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

wparams.initial_prompt = params.prompt.c_str();
wparams.suppress_regex = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();

wparams.initial_prompt = params.prompt.c_str();


wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;

wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc;
wparams.temperature = params.temperature;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;

whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
wparams.no_timestamps = params.no_timestamps;

whisper_print_user_data user_data = { &params, &pcmf32s, 0 };

// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
Expand Down

0 comments on commit 208f5d0

Please sign in to comment.