implement prefill
chottolabs committed Sep 8, 2024
1 parent 23cdc39 commit 116eea6
Showing 2 changed files with 37 additions and 15 deletions.
lua/kznllm/presets.lua (48 changes: 33 additions & 15 deletions)
@@ -24,32 +24,41 @@ M.PROMPT_ARGS_STATE = {
   user_query = nil,
   replace = nil,
   context_files = nil,
+  prefill = nil,
 }
 
 M.NS_ID = api.nvim_create_namespace 'kznllm_ns'
 
 local group = api.nvim_create_augroup('LLM_AutoGroup', { clear = true })
 ---Example implementation of a `make_data_fn` compatible with `kznllm.invoke_llm` for groq spec
 ---@param prompt_args any
----@param opts { model: string, data_params: table, template_directory: Path, debug: boolean }
+---@param opts { model: string, prefill:string, data_params: table, stop_param: table, template_directory: Path, debug: boolean }
 ---@return table
 ---
 local function make_data_for_openai_chat(prompt_args, opts)
-  local data = {
-    messages = {
-      {
-        role = 'system',
-        content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_system_prompt.xml.jinja', prompt_args),
-      },
-      {
-        role = 'user',
-        content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_user_prompt.xml.jinja', prompt_args),
-      },
-    },
+  local messages = {
+    {
+      role = 'system',
+      content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_system_prompt.xml.jinja', prompt_args),
+    },
+    {
+      role = 'user',
+      content = kznllm.make_prompt_from_template(opts.template_directory / 'nous_research/fill_mode_user_prompt.xml.jinja', prompt_args),
+    },
+  }
+
+  local data = {
+    messages = messages,
     model = opts.model,
     stream = true,
   }
-  data = vim.tbl_extend('keep', data, opts.data_params)
+  if M.PROMPT_ARGS_STATE.replace and opts.prefill and opts.stop_param then
+    table.insert(messages, {
+      role = 'assistant',
+      content = opts.prefill .. prompt_args.current_buffer_filetype .. '\n',
+    })
+  end
+  data = vim.tbl_extend('keep', data, opts.data_params, opts.stop_param or {})
 
   return data
 end
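
The prefill trick, in short: when `replace` is active and the preset defines both `prefill` and `stop_param`, the request ends with a partially written assistant turn, so the model continues after the opening code fence and the stop parameter cuts generation at the closing fence; the stream then contains only code. A minimal sketch of the resulting payload (placeholder model name and rendered templates, not values from this commit):

local data = {
  messages = {
    { role = 'system', content = '<rendered fill_mode_system_prompt>' },
    { role = 'user', content = '<rendered fill_mode_user_prompt>' },
    -- prefilled turn for a Lua buffer with prefill = '```':
    { role = 'assistant', content = '```lua\n' },
  },
  model = '<preset model>',   -- placeholder
  stream = true,
  stop_token_ids = { 74694 }, -- merged in from the preset's stop_param
}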
@@ -93,11 +102,14 @@ end
 
 local function openai_debug_fn(data, ns_id, extmark_id, opts)
   kznllm.write_content_at_extmark('model: ' .. opts.model, ns_id, extmark_id)
-  kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
   for _, message in ipairs(data.messages) do
+    kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
     kznllm.write_content_at_extmark(message.role .. ':\n\n', ns_id, extmark_id)
     kznllm.write_content_at_extmark(message.content, ns_id, extmark_id)
-    kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
+
+    if not (M.BUFFER_STATE.SCRATCH and opts.prefill) then
+      kznllm.write_content_at_extmark('\n\n---\n\n', ns_id, extmark_id)
+    end
     vim.cmd 'normal! G'
   end
 end
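
With prefill active in a scratch buffer, the separator is now written before each message and the trailing one is skipped, so the streamed completion continues directly from the prefilled assistant content. Roughly what the debug buffer looks like in that case (illustrative layout, not captured output):

model: <preset model>

---

system:
<rendered system prompt>

---

user:
<rendered user prompt>

---

assistant:
```lua
<completion streams here and stops at the closing fence>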
@@ -134,7 +146,7 @@ end
 ---@param make_data_fn fun(prompt_args: table, opts: table)
 ---@param make_curl_args_fn fun(data: table, opts: table)
 ---@param make_job_fn fun(data: table, writer_fn: fun(content: string), on_exit_fn: fun())
----@param opts { debug: string?, debug_fn: fun(data: table, ns_id: integer, extmark_id: integer, opts: table)?, stop_dir: Path?, context_dir_id: string?, data_params: table }
+---@param opts { debug: string?, debug_fn: fun(data: table, ns_id: integer, extmark_id: integer, opts: table)?, stop_dir: Path?, context_dir_id: string?, data_params: table, prefill: boolean }
 function M.invoke_llm(make_data_fn, make_curl_args_fn, make_job_fn, opts)
   api.nvim_clear_autocmds { group = group }
 
@@ -162,6 +174,7 @@ function M.invoke_llm(make_data_fn, make_curl_args_fn, make_job_fn, opts)
     M.PROMPT_ARGS_STATE.current_buffer_path = buf_path
     M.PROMPT_ARGS_STATE.current_buffer_context = buf_context
   end
+  M.PROMPT_ARGS_STATE.prefill = opts.prefill
 
   local data = make_data_fn(M.PROMPT_ARGS_STATE, opts)
 
@@ -232,6 +245,9 @@ local presets = {
       -- max_tokens = 8192,
       temperature = 0.7,
     },
+    -- doesn't support prefill
+    -- stop_param = { stop = { '```' } },
+    -- prefill = '```',
     debug_fn = openai_debug_fn,
     base_url = 'https://api.groq.com',
     endpoint = '/openai/v1/chat/completions',
@@ -250,6 +266,8 @@ local presets = {
       min_p = 0.05,
       logprobs = 1,
     },
+    stop_param = { stop_token_ids = { 74694 } },
+    prefill = '```',
     debug_fn = openai_debug_fn,
     base_url = 'https://api.lambdalabs.com',
     endpoint = '/v1/chat/completions',
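
Because `vim.tbl_extend('keep', ...)` keeps the leftmost value on key conflicts, `data_params` and `stop_param` can add fields but never override `messages`, `model`, or `stream`. Note that stop token ids are tokenizer-specific; 74694 is presumably the id of the '```' token for this model's tokenizer (an assumption, the commit doesn't say). An illustrative merge mirroring `make_data_for_openai_chat`:

local merged = vim.tbl_extend('keep',
  { model = 'x', stream = true }, -- data (wins on conflicts)
  { min_p = 0.05, logprobs = 1 }, -- opts.data_params
  { stop_token_ids = { 74694 } }) -- opts.stop_param
-- merged = { model = 'x', stream = true, min_p = 0.05,
--            logprobs = 1, stop_token_ids = { 74694 } }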
templates/nous_research/fill_mode_user_prompt.xml.jinja (4 changes: 4 additions & 0 deletions)
@@ -24,7 +24,11 @@ Code:
 ```
 
 QUERY: {{user_query}}
+{% if prefill -%}
+INSTRUCTION: Replace the code in the block given above. ONLY return the valid code fragment that is requested in the following snippet surrounded by a code fence with backticks. DO NOT provide anything except the content inside the code fence.
+{%- else -%}
 INSTRUCTION: Replace the code in the block given above. ONLY return the code fragment that is requested in the following snippet WITHOUT backticks. DO NOT surround the code fragment in backticks.
+{%- endif %}
 {%- else -%}
 QUERY: {{user_query}}
 {%- endif -%}
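
The `prefill` variable tested here is the preset flag routed through shared state; condensed from the Lua changes above:

-- invoke_llm copies the preset's flag into the shared template state...
M.PROMPT_ARGS_STATE.prefill = opts.prefill
-- ...and the user template is rendered with that state, so {% if prefill %}
-- selects the fenced-output instruction exactly when the preset defines a
-- prefill string (e.g. the Lambda Labs preset above, but not groq's).
local content = kznllm.make_prompt_from_template(
  opts.template_directory / 'nous_research/fill_mode_user_prompt.xml.jinja',
  M.PROMPT_ARGS_STATE)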
