Skip to content

Commit

Permalink
put whisper_encode back in it's original form, disable main from whis…
Browse files Browse the repository at this point in the history
…per.cpp
  • Loading branch information
jwijffels committed Oct 6, 2024
1 parent fa2ae97 commit 25473f2
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 2 deletions.
31 changes: 31 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,36 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// whisper_encode
Rcpp::List whisper_encode(SEXP model, std::string path, std::string language, bool token_timestamps, bool translate, Rcpp::IntegerVector duration, Rcpp::IntegerVector offset, int trace, int n_threads, int n_processors, float entropy_thold, float logprob_thold, int beam_size, int best_of, bool split_on_word, int max_context, std::string prompt, bool print_special, bool diarize, float diarize_percent);
RcppExport SEXP _audio_whisper_whisper_encode(SEXP modelSEXP, SEXP pathSEXP, SEXP languageSEXP, SEXP token_timestampsSEXP, SEXP translateSEXP, SEXP durationSEXP, SEXP offsetSEXP, SEXP traceSEXP, SEXP n_threadsSEXP, SEXP n_processorsSEXP, SEXP entropy_tholdSEXP, SEXP logprob_tholdSEXP, SEXP beam_sizeSEXP, SEXP best_ofSEXP, SEXP split_on_wordSEXP, SEXP max_contextSEXP, SEXP promptSEXP, SEXP print_specialSEXP, SEXP diarizeSEXP, SEXP diarize_percentSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< SEXP >::type model(modelSEXP);
Rcpp::traits::input_parameter< std::string >::type path(pathSEXP);
Rcpp::traits::input_parameter< std::string >::type language(languageSEXP);
Rcpp::traits::input_parameter< bool >::type token_timestamps(token_timestampsSEXP);
Rcpp::traits::input_parameter< bool >::type translate(translateSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type duration(durationSEXP);
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type offset(offsetSEXP);
Rcpp::traits::input_parameter< int >::type trace(traceSEXP);
Rcpp::traits::input_parameter< int >::type n_threads(n_threadsSEXP);
Rcpp::traits::input_parameter< int >::type n_processors(n_processorsSEXP);
Rcpp::traits::input_parameter< float >::type entropy_thold(entropy_tholdSEXP);
Rcpp::traits::input_parameter< float >::type logprob_thold(logprob_tholdSEXP);
Rcpp::traits::input_parameter< int >::type beam_size(beam_sizeSEXP);
Rcpp::traits::input_parameter< int >::type best_of(best_ofSEXP);
Rcpp::traits::input_parameter< bool >::type split_on_word(split_on_wordSEXP);
Rcpp::traits::input_parameter< int >::type max_context(max_contextSEXP);
Rcpp::traits::input_parameter< std::string >::type prompt(promptSEXP);
Rcpp::traits::input_parameter< bool >::type print_special(print_specialSEXP);
Rcpp::traits::input_parameter< bool >::type diarize(diarizeSEXP);
Rcpp::traits::input_parameter< float >::type diarize_percent(diarize_percentSEXP);
rcpp_result_gen = Rcpp::wrap(whisper_encode(model, path, language, token_timestamps, translate, duration, offset, trace, n_threads, n_processors, entropy_thold, logprob_thold, beam_size, best_of, split_on_word, max_context, prompt, print_special, diarize, diarize_percent));
return rcpp_result_gen;
END_RCPP
}
// whisper_load_model
SEXP whisper_load_model(std::string model, bool use_gpu);
RcppExport SEXP _audio_whisper_whisper_load_model(SEXP modelSEXP, SEXP use_gpuSEXP) {
Expand Down Expand Up @@ -45,6 +75,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_audio_whisper_whisper_encode", (DL_FUNC) &_audio_whisper_whisper_encode, 20},
{"_audio_whisper_whisper_load_model", (DL_FUNC) &_audio_whisper_whisper_load_model, 2},
{"_audio_whisper_whisper_print_benchmark", (DL_FUNC) &_audio_whisper_whisper_print_benchmark, 2},
{"_audio_whisper_whisper_language_info", (DL_FUNC) &_audio_whisper_whisper_language_info, 0},
Expand Down
253 changes: 251 additions & 2 deletions src/rcpp_whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper

static void cb_log_disable(enum ggml_log_level , const char * , void * ) { }

/*
int main(int argc, char ** argv) {
whisper_params params;
Expand Down Expand Up @@ -373,7 +374,6 @@ int main(int argc, char ** argv) {
whisper_print_usage(argc, argv, params);
return 1;
}

// remove non-existent files
for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();) {
const auto fname_inp = it->c_str();
Expand Down Expand Up @@ -447,7 +447,6 @@ int main(int argc, char ** argv) {
// initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured
whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr);

if (!params.grammar.empty()) {
auto & grammar = params.grammar_parsed;
if (is_file_exist(params.grammar.c_str())) {
Expand Down Expand Up @@ -623,6 +622,256 @@ int main(int argc, char ** argv) {

return 0;
}
*/

// [[Rcpp::export]]
Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
bool token_timestamps = false, bool translate = false, Rcpp::IntegerVector duration = 0, Rcpp::IntegerVector offset = 0, int trace = 1,
int n_threads = 1, int n_processors = 1,
float entropy_thold = 2.40,
float logprob_thold = -1.00,
int beam_size = -1,
int best_of = 5,
bool split_on_word = false,
int max_context = -1,
std::string prompt = "",
bool print_special = false,
bool diarize = false,
float diarize_percent = 1.1) {
float audio_duration=0;

whisper_params params;
params.language = language;
params.translate = translate;
params.print_special = print_special;
params.duration_ms = duration[0];
params.offset_t_ms = offset[0];
params.fname_inp.push_back(path);
params.n_threads = n_threads;
params.n_processors = n_processors;

params.entropy_thold = entropy_thold;
params.logprob_thold = logprob_thold;
params.beam_size = beam_size;
params.best_of = best_of;
params.split_on_word = split_on_word;
params.max_context = max_context;
params.prompt = prompt;
params.diarize = diarize;
if (params.fname_inp.empty()) {
Rcpp::stop("error: no input files specified");
}

if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
Rcpp::stop("Unknown language");
}

// whisper init
Rcpp::XPtr<WhisperModel> whispermodel(model);
struct whisper_context * ctx = whispermodel->ctx;
//Rcpp::XPtr<whisper_context> ctx(model);
//struct whisper_context * ctx = whisper_init(params.model.c_str());

const auto fname_inp = params.fname_inp[0];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
Rprintf("error: failed to read WAV file '%s'\n", fname_inp.c_str());
Rcpp::stop("The input audio needs to be a 16-bit .wav file.");
}

if(trace > 0){
Rprintf("system_info: n_threads = %d / %d | %s\n", params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}

{
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
Rcpp::warning("WARNING: model is not multilingual, ignoring language and translation options");
}
}
if(trace > 0){
Rcpp::Rcout << "Processing " << fname_inp << " (" << int(pcmf32.size()) << " samples, " << float(pcmf32.size())/WHISPER_SAMPLE_RATE << " sec)" << ", lang = " << params.language << ", translate = " << params.translate << ", timestamps = " << token_timestamps << ", beam_size = " << params.beam_size << ", best_of = " << params.best_of << "\n";
}
}
audio_duration = float(pcmf32.size())/WHISPER_SAMPLE_RATE;

// Structures to get the data back in R
std::vector<int> segment_nr;
std::vector<int> segment_offset;
Rcpp::StringVector transcriptions(0);
Rcpp::StringVector transcriptions_from(0);
Rcpp::StringVector transcriptions_to(0);
Rcpp::StringVector transcriptions_speaker(0);
std::vector<int> token_segment_nr;
std::vector<int> token_segment_id;
std::vector<std::string> token_segment_text;
std::vector<float> token_segment_probability;
std::vector<std::string> token_segment_from;
std::vector<std::string> token_segment_to;
//Rcpp::StringVector token_speaker(0);
int n_segments;

for (int f = 0; f < (int) offset.size(); ++f) {
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;

wparams.print_realtime = false;
wparams.print_progress = false;
if(trace > 0){
wparams.print_progress = true;
wparams.print_realtime = true;
}
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.detect_language = params.detect_language;
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = (int) offset[f];
wparams.duration_ms = (int) duration[f];

wparams.token_timestamps = token_timestamps;
wparams.thold_pt = params.word_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;

wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode;

wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

wparams.initial_prompt = params.prompt.c_str();



wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;

wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;

whisper_print_user_data user_data = { &params, &pcmf32s, 0 };

// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
if(trace > 0 && offset.size() > 1){
Rcpp::Rcout << "Processing audio offset section " << f+1 << " (" << wparams.offset_ms << " ms - " << wparams.offset_ms+wparams.duration_ms << " ms)\n";
}

if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
Rcpp::stop("failed to process audio");
}
}
n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
segment_nr.push_back(segment_nr.size() + 1);
segment_offset.push_back(offset[f]);
const char * text = whisper_full_get_segment_text(ctx, i);
transcriptions.push_back(Rcpp::String(text));
int64_t t0 = whisper_full_get_segment_t0(ctx, i);
int64_t t1 = whisper_full_get_segment_t1(ctx, i);
transcriptions_from.push_back(Rcpp::String(to_timestamp(t0).c_str()));
transcriptions_to.push_back(Rcpp::String(to_timestamp(t1).c_str()));
Rcpp::String channel_speaker;
if (params.diarize && pcmf32s.size() == 2) {
channel_speaker = Rcpp::String(estimate_diarization_speaker(pcmf32s, t0, t1, true, diarize_percent));
}else{
channel_speaker = NA_STRING;
}
transcriptions_speaker.push_back(channel_speaker);

for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int tokenid = whisper_full_get_token_id (ctx, i, j);
token_segment_nr.push_back(i + 1);
token_segment_id.push_back(tokenid);
std::string str(text);
token_segment_text.push_back(str);
token_segment_probability.push_back(p);
if(token_timestamps){
whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
t0 = token.t0;
t1 = token.t1;
token_segment_from.push_back(Rcpp::String(to_timestamp(t0).c_str()));
token_segment_to.push_back(to_timestamp(token.t1));
}
//token_speaker.push_back(channel_speaker);
}
}
}
Rcpp::DataFrame tokens;
if(token_timestamps){
tokens = Rcpp::DataFrame::create(
Rcpp::Named("segment") = token_segment_nr,
Rcpp::Named("token_id") = token_segment_id,
Rcpp::Named("token") = token_segment_text,
Rcpp::Named("token_prob") = token_segment_probability,
Rcpp::Named("token_from") = token_segment_from,
Rcpp::Named("token_to") = token_segment_to,
//Rcpp::Named("token_speaker") = token_speaker,
Rcpp::Named("stringsAsFactors") = false);
}else{
tokens = Rcpp::DataFrame::create(
Rcpp::Named("segment") = token_segment_nr,
Rcpp::Named("token_id") = token_segment_id,
Rcpp::Named("token") = token_segment_text,
Rcpp::Named("token_prob") = token_segment_probability,
//Rcpp::Named("token_speaker") = token_speaker,
Rcpp::Named("stringsAsFactors") = false);
}

//whisper_free(ctx);
Rcpp::List output = Rcpp::List::create(Rcpp::Named("n_segments") = segment_nr.size(),
Rcpp::Named("data") = Rcpp::DataFrame::create(
Rcpp::Named("segment") = segment_nr,
Rcpp::Named("segment_offset") = segment_offset,
Rcpp::Named("from") = transcriptions_from,
Rcpp::Named("to") = transcriptions_to,
Rcpp::Named("text") = transcriptions,
Rcpp::Named("speaker") = transcriptions_speaker,
Rcpp::Named("stringsAsFactors") = false),
Rcpp::Named("tokens") = tokens,
Rcpp::Named("params") = Rcpp::List::create(
Rcpp::Named("audio") = path,
Rcpp::Named("audio_duration_seconds") = audio_duration,
Rcpp::Named("language") = params.language,
Rcpp::Named("offset") = offset,
Rcpp::Named("duration") = duration,
Rcpp::Named("translate") = params.translate,
Rcpp::Named("token_timestamps") = token_timestamps,
Rcpp::Named("word_threshold") = params.word_thold,
Rcpp::Named("entropy_thold") = params.entropy_thold,
Rcpp::Named("logprob_thold") = params.logprob_thold,
Rcpp::Named("beam_size") = params.beam_size,
Rcpp::Named("best_of") = params.best_of,
Rcpp::Named("split_on_word") = params.split_on_word,
Rcpp::Named("diarize") = params.diarize,
Rcpp::Named("system_info") = Rcpp::List::create(
Rcpp::Named("n_threads") = params.n_threads,
Rcpp::Named("n_processors") = params.n_processors,
Rcpp::Named("available_concurrency") = std::thread::hardware_concurrency(),
Rcpp::Named("optimisations") = whisper_print_system_info())));
return output;
}



Expand Down

0 comments on commit 25473f2

Please sign in to comment.