From 116eea622a330f19ed3b05ee5044bbcab8d2abf1 Mon Sep 17 00:00:00 2001
From: chottolabs <171991982+chottolabs@users.noreply.github.com>
Date: Sat, 7 Sep 2024 22:13:25 -0400
Subject: [PATCH] implement prefill

---
 lua/kznllm/presets.lua                   | 48 +++++++++++++------
 .../fill_mode_user_prompt.xml.jinja      |  4 ++
 2 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/lua/kznllm/presets.lua b/lua/kznllm/presets.lua
index 765ca13..b698015 100644
--- a/lua/kznllm/presets.lua
+++ b/lua/kznllm/presets.lua
@@ -24,6 +24,7 @@ M.PROMPT_ARGS_STATE = {
   user_query = nil,
   replace = nil,
   context_files = nil,
+  prefill = nil,
 }
 
 M.NS_ID = api.nvim_create_namespace 'kznllm_ns'
@@ -31,25 +32,33 @@ M.NS_ID = api.nvim_create_namespace 'kznllm_ns'
 local group = api.nvim_create_augroup('LLM_AutoGroup', { clear = true })
 ---Example implementation of a `make_data_fn` compatible with `kznllm.invoke_llm` for groq spec
 ---@param prompt_args any
----@param opts { model: string, data_params: table, template_directory: Path, debug: boolean }
+---@param opts { model: string, prefill: string, data_params: table, stop_param: table, template_directory: Path, debug: boolean }
 ---@return table
 ---
 local function make_data_for_openai_chat(prompt_args, opts)
-  local data = {
-    messages = {
-      {
-        role = 'system',
-        content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_system_prompt.xml.jinja', prompt_args),
-      },
-      {
-        role = 'user',
-        content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_user_prompt.xml.jinja', prompt_args),
-      },
+  local messages = {
+    {
+      role = 'system',
+      content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_system_prompt.xml.jinja', prompt_args),
+    },
+    {
+      role = 'user',
+      content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_user_prompt.xml.jinja', prompt_args),
     },
+  }
+
+  local data = {
+    messages = messages,
     model = opts.model,
     stream = true,
   }
-  data = vim.tbl_extend('keep', data, opts.data_params)
+  if M.PROMPT_ARGS_STATE.replace and opts.prefill and opts.stop_param then
+    table.insert(messages, {
+      role = 'assistant',
+      content = opts.prefill .. prompt_args.current_buffer_filetype .. '\n',
+    })
+  end
+  data = vim.tbl_extend('keep', data, opts.data_params, opts.stop_param or {})
 
   return data
 end
@@ -93,11 +102,14 @@ end
 
 local function openai_debug_fn(data, ns_id, extmark_id, opts)
   kznllm.write_content_at_extmark('model: ' .. opts.model, ns_id, extmark_id)
-  kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
   for _, message in ipairs(data.messages) do
+    kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
     kznllm.write_content_at_extmark(message.role .. ':\n\n', ns_id, extmark_id)
     kznllm.write_content_at_extmark(message.content, ns_id, extmark_id)
-    kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
+
+    if not (M.BUFFER_STATE.SCRATCH and opts.prefill) then
+      kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
+    end
     vim.cmd 'normal! G'
   end
 end
@@ -134,7 +146,7 @@ end
 ---@param make_data_fn fun(prompt_args: table, opts: table)
 ---@param make_curl_args_fn fun(data: table, opts: table)
 ---@param make_job_fn fun(data: table, writer_fn: fun(content: string), on_exit_fn: fun())
----@param opts { debug: string?, debug_fn: fun(data: table, ns_id: integer, extmark_id: integer, opts: table)?, stop_dir: Path?, context_dir_id: string?, data_params: table }
+---@param opts { debug: string?, debug_fn: fun(data: table, ns_id: integer, extmark_id: integer, opts: table)?, stop_dir: Path?, context_dir_id: string?, data_params: table, prefill: boolean }
 function M.invoke_llm(make_data_fn, make_curl_args_fn, make_job_fn, opts)
   api.nvim_clear_autocmds { group = group }
 
@@ -162,6 +174,7 @@ function M.invoke_llm(make_data_fn, make_curl_args_fn, make_job_fn, opts)
       M.PROMPT_ARGS_STATE.current_buffer_path = buf_path
       M.PROMPT_ARGS_STATE.current_buffer_context = buf_context
     end
+    M.PROMPT_ARGS_STATE.prefill = opts.prefill
 
     local data = make_data_fn(M.PROMPT_ARGS_STATE, opts)
 
@@ -232,6 +245,9 @@ local presets = {
         -- max_tokens = 8192,
         temperature = 0.7,
       },
+      -- doesn't support prefill
+      -- stop_param = { stop = { '```' } },
+      -- prefill = '```',
       debug_fn = openai_debug_fn,
       base_url = 'https://api.groq.com',
       endpoint = '/openai/v1/chat/completions',
@@ -250,6 +266,8 @@ local presets = {
         min_p = 0.05,
         logprobs = 1,
       },
+      stop_param = { stop_token_ids = { 74694 } },
+      prefill = '```',
       debug_fn = openai_debug_fn,
       base_url = 'https://api.lambdalabs.com',
       endpoint = '/v1/chat/completions',
diff --git a/templates/nous_research/fill_mode_user_prompt.xml.jinja b/templates/nous_research/fill_mode_user_prompt.xml.jinja
index 92b0c37..d39f0b6 100644
--- a/templates/nous_research/fill_mode_user_prompt.xml.jinja
+++ b/templates/nous_research/fill_mode_user_prompt.xml.jinja
@@ -24,7 +24,11 @@ Code:
 ```
 
 QUERY: {{user_query}}
+{% if prefill -%}
+INSTRUCTION: Replace the code in the block given above. ONLY return the valid code fragment that is requested in the following snippet surrounded by a code fence with backticks. DO NOT provide anything except the content inside the code fence.
+{%- else -%}
 INSTRUCTION: Replace the code in the block given above. ONLY return the code fragment that is requested in the following snippet WITHOUT backticks. DO NOT surround the code fragment in backticks.
+{%- endif %}
 {%- else -%}
 QUERY: {{user_query}}
 {%- endif -%}
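
Illustrative note (not part of the patch itself): the plain-Lua sketch below mirrors what the new prefill branch in make_data_for_openai_chat appends to the request when replace mode, opts.prefill, and opts.stop_param are all set. The values are the ones this patch adds to the lambdalabs preset; the 'lua' filetype is a hypothetical example, and reading token id 74694 as the closing code fence is an assumption based on the paired prefill/stop_param values.

-- Illustrative sketch only; it reproduces the assistant-prefill message built
-- by make_data_for_openai_chat in this patch. Preset values are taken from the
-- lambdalabs entry above; 'lua' is a hypothetical filetype for demonstration.
local opts = {
  prefill = '```',
  stop_param = { stop_token_ids = { 74694 } },
}
local prompt_args = { current_buffer_filetype = 'lua' }

-- The seeded assistant turn makes the model continue inside an already-opened
-- code fence instead of writing prose around one.
local prefill_message = {
  role = 'assistant',
  content = opts.prefill .. prompt_args.current_buffer_filetype .. '\n',
}
assert(prefill_message.content == '```lua\n')

-- stop_param is merged into the request body via vim.tbl_extend, so generation
-- presumably halts once the model emits the token for the closing fence.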