llama.vim: speculative fim #31

Open · wants to merge 3 commits into master

autoload/llama.vim: 143 additions & 0 deletions
@@ -365,6 +365,8 @@ function! llama#fim(is_auto, cache) abort
let s:content = []
let s:can_accept = v:false

let s:speculating = v:false

let s:pos_x = col('.') - 1
let s:pos_y = line('.')
let l:max_y = line('$')
@@ -593,6 +595,143 @@ function! s:on_move()
call llama#fim_cancel()
endfunction

function! llama#speculate()
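" speculatively request the completion that would follow if the current
" suggestion were accepted, so that the next request can be answered from
" g:result_cache

" nothing to do while a speculative request is already in flight or when the
" current suggestion is empty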
if s:speculating == v:true || ( len(s:content) == 1 && s:content[0] == "" )
return
endif

if len(s:content) > 0
" gather prompt
if len(s:content) == 1
" the suggestion is a single line: append it to the text before the cursor
let l:cur_line = s:line_cur[:(s:pos_x - 1)] . s:content[0]
let l:pos_x_spec = s:pos_x + len(s:content[0])
else
let l:cur_line = s:content[-1]
let l:pos_x_spec = len(s:content[-1])
endif

let l:line_cur_prefix_spec = strpart(l:cur_line, 0, l:pos_x_spec)

" gather prefix
let l:pos_y_spec = s:pos_y + len(s:content) - 1
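" l:pos_y_spec is the buffer line the cursor will sit on once the suggestion is accepted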
let l:lines_prefix = getline(max([1, l:pos_y_spec - g:llama_config.n_prefix]), s:pos_y - 1)
let l:content_cpy = []

call add(l:lines_prefix, s:line_cur[:(s:pos_x - 1)] . s:content[0])
call extend(l:lines_prefix, s:content[1:])

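" keep at most n_prefix lines before the speculative cursor line; the last
" element is dropped because that line is sent as the prompt instead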
let l:lines_prefix = l:lines_prefix[-(min([len(l:lines_prefix), (g:llama_config.n_prefix + 1)])):-2]

" gather suffix
let l:max_y = line('$')
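" the suffix is taken from the buffer lines below the real cursor line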
let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix]))

let l:prompt = ""
\ . l:line_cur_prefix_spec

let l:prefix = ""
\ . join(l:lines_prefix, "\n")
\ . "\n"

let l:suffix = ""
\ . "\n"
\ . join(l:lines_suffix, "\n")
\ . "\n"

" prepare the extra context data
let l:extra_context = []
for l:chunk in s:ring_chunks
call add(l:extra_context, {
\ 'text': l:chunk.str,
\ 'time': l:chunk.time,
\ 'filename': l:chunk.filename
\ })
endfor

" the indentation of the current line
let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*'))

let l:request = json_encode({
\ 'input_prefix': l:prefix,
\ 'input_suffix': l:suffix,
\ 'input_extra': l:extra_context,
\ 'prompt': l:prompt,
\ 'n_predict': g:llama_config.n_predict,
\ 'n_indent': l:indent,
\ 'top_k': 40,
\ 'top_p': 0.99,
\ 'stream': v:false,
\ 'samplers': ["top_k", "top_p", "infill"],
\ 'cache_prompt': v:true,
\ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms,
\ 't_max_predict_ms': g:llama_config.t_max_predict_ms,
\ 'response_fields': [
\ "content",
\ "timings/prompt_n",
\ "timings/prompt_ms",
\ "timings/prompt_per_token_ms",
\ "timings/prompt_per_second",
\ "timings/predicted_n",
\ "timings/predicted_ms",
\ "timings/predicted_per_token_ms",
\ "timings/predicted_per_second",
\ "truncated",
\ "tokens_cached",
\ ],
\ })

let l:curl_command = [
\ "curl",
\ "--silent",
\ "--no-buffer",
\ "--request", "POST",
\ "--url", g:llama_config.endpoint,
\ "--header", "Content-Type: application/json",
\ "--data", l:request
\ ]

if exists("g:llama_config.api_key") && len(g:llama_config.api_key) > 0
call extend(l:curl_command, ['--header', 'Authorization: Bearer ' .. g:llama_config.api_key])
endif

if s:current_job != v:null
if s:ghost_text_nvim
call jobstop(s:current_job)
elseif s:ghost_text_vim
call job_stop(s:current_job)
endif
endif

let s:speculating = v:true

" Construct hash from prefix, prompt, and suffix with separators
let l:request_context = l:prefix . 'Î' . l:prompt . 'Î' . l:suffix
let l:hash = sha256(l:request_context)

let l:cached_completion = get(g:result_cache, l:hash, v:null)

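" a previously cached result for this exact context can be handed straight to the callback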
if l:cached_completion != v:null
call s:fim_on_stdout(l:hash, v:true, l:pos_x_spec, l:pos_y_spec, v:true, 0, l:cached_completion)
else
" send the request asynchronously
if s:ghost_text_nvim
let s:current_job = jobstart(l:curl_command, {
\ 'on_stdout': function('s:fim_on_stdout', [l:hash, v:true, l:pos_x_spec, l:pos_y_spec, v:true]),
\ 'on_exit': function('s:fim_on_exit'),
\ 'stdout_buffered': v:true
\ })
elseif s:ghost_text_vim
let s:current_job = job_start(l:curl_command, {
\ 'out_cb': function('s:fim_on_stdout', [l:hash, v:true, l:pos_x_spec, l:pos_y_spec, v:true]),
\ 'exit_cb': function('s:fim_on_exit')
\ })
endif
endif

endif
endfunction

" TODO: Currently the cache uses a random eviction policy. A more clever policy could be implemented (eg. LRU).
function! s:insert_cache(key, value)
if len(keys(g:result_cache)) > (g:llama_config.max_cache_keys - 1)
@@ -852,6 +991,10 @@ function! s:fim_on_stdout(hash, cache, pos_x, pos_y, is_auto, job_id, data, even
inoremap <buffer> <C-B> <C-O>:call llama#fim_accept('word')<CR>

let s:hint_shown = v:true

" speculate the next completion
call llama#speculate()
let s:speculating = v:false
endfunction

function! s:fim_on_exit(job_id, exit_code, event = v:null)