From f777a73e18c218848bca0748581c043987348a5d Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Sun, 23 Feb 2025 13:14:32 +0000 Subject: [PATCH] Some llama-run cleanups (#11973) Use consolidated open function call from File class. Change read_all to to_string(). Remove exclusive locking, the intent for that lock is to avoid multiple processes writing to the same file, it's not an issue for readers, although we may want to consider adding a shared lock. Remove passing nullptr as reference, references are never supposed to be null. clang-format the code for consistent styling. Signed-off-by: Eric Curtin --- examples/run/run.cpp | 91 ++++++++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 4da1e50251600..de736c7d5a3d9 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -323,25 +323,17 @@ class File { return 0; } - std::string read_all(const std::string & filename){ - open(filename, "r"); - lock(); - if (!file) { - printe("Error opening file '%s': %s", filename.c_str(), strerror(errno)); - return ""; - } - + std::string to_string() { fseek(file, 0, SEEK_END); - size_t size = ftell(file); + const size_t size = ftell(file); fseek(file, 0, SEEK_SET); - std::string out; out.resize(size); - size_t read_size = fread(&out[0], 1, size, file); + const size_t read_size = fread(&out[0], 1, size, file); if (read_size != size) { - printe("Error reading file '%s': %s", filename.c_str(), strerror(errno)); - return ""; + printe("Error reading file: %s", strerror(errno)); } + return out; } @@ -1098,59 +1090,66 @@ static int get_user_input(std::string & user_input, const std::string & user) { // Reads a chat template file to be used static std::string read_chat_template_file(const std::string & chat_template_file) { - if(chat_template_file.empty()){ - return ""; - } - File file; - std::string chat_template = ""; - chat_template = file.read_all(chat_template_file); - if(chat_template.empty()){ + if (!file.open(chat_template_file, "r")) { printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno)); return ""; } - return chat_template; + + return file.to_string(); +} + +static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data, + const common_chat_templates_ptr & chat_templates, int & prev_len, + const bool stdout_a_terminal) { + add_message("user", opt.user.empty() ? user_input : opt.user, llama_data); + int new_len; + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) { + return 1; + } + + std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); + std::string response; + if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { + return 1; + } + + if (!opt.user.empty()) { + return 2; + } + + add_message("assistant", response, llama_data); + if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) { + return 1; + } + + return 0; } // Main chat loop function -static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) { +static int chat_loop(LlamaData & llama_data, const Opt & opt) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - - std::string chat_template = ""; - if(!chat_template_file.empty()){ - chat_template = read_chat_template_file(chat_template_file); + std::string chat_template; + if (!opt.chat_template_file.empty()) { + chat_template = read_chat_template_file(opt.chat_template_file); } - auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template); + common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input std::string user_input; - if (get_user_input(user_input, user) == 1) { + if (get_user_input(user_input, opt.user) == 1) { return 0; } - add_message("user", user.empty() ? user_input : user, llama_data); - int new_len; - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) { - return 1; - } - - std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); - std::string response; - if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { + const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal); + if (ret == 1) { return 1; - } - - if (!user.empty()) { + } else if (ret == 2) { break; } - - add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) { - return 1; - } } return 0; @@ -1208,7 +1207,7 @@ int main(int argc, const char ** argv) { return 1; } - if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) { + if (chat_loop(llama_data, opt)) { return 1; }