diff --git a/lua/cmp/config/compare.lua b/lua/cmp/config/compare.lua index ec79b7336..ef975f391 100644 --- a/lua/cmp/config/compare.lua +++ b/lua/cmp/config/compare.lua @@ -198,7 +198,8 @@ compare.locality = setmetatable({ ---scopes: Entries defined in a closer scope will be ranked higher (e.g., prefer local variables to globals). ---@type cmp.ComparatorFunctor compare.scopes = setmetatable({ - scopes_map = {}, + definition_depths = {}, + has_nvim_0_9_features = vim.fn.has('nvim-0.9') == 1, update = function(self) local config = require('cmp').get_config() if not vim.tbl_contains(config.sorting.comparators, compare.scopes) then @@ -207,64 +208,51 @@ compare.scopes = setmetatable({ local ok, locals = pcall(require, 'nvim-treesitter.locals') if ok then - local win, buf = vim.api.nvim_get_current_win(), vim.api.nvim_get_current_buf() - local cursor_row = vim.api.nvim_win_get_cursor(win)[1] - 1 - - -- Cursor scope. - local cursor_scope = nil - -- Prioritize the older get_scopes method from nvim-treesitter `master` over get from `main` - local scopes = locals.get_scopes and locals.get_scopes(buf) or select(3, locals.get(buf)) - for _, scope in ipairs(scopes) do - if scope:start() <= cursor_row and cursor_row <= scope:end_() then - if not cursor_scope then - cursor_scope = scope - else - if cursor_scope:start() <= scope:start() and scope:end_() <= cursor_scope:end_() then - cursor_scope = scope - end - end - elseif cursor_scope and cursor_scope:end_() <= scope:start() then - break - end + self.definition_depths = {} + local buf = vim.api.nvim_get_current_buf() + if self.has_nvim_0_9_features and not vim.b[buf].cmp_buf_has_ts_parser then + return end - -- Definitions. - local definitions = locals.get_definitions_lookup_table(buf) - - -- Narrow definitions. + local get_cursor_node = vim.treesitter.get_node or require('nvim-treesitter.ts_utils').get_node_at_cursor + local cursor_node = get_cursor_node() + local scope_depths = {} local depth = 0 - for scope in locals.iter_scope_tree(cursor_scope, buf) do - local s, e = scope:start(), scope:end_() + -- If there's no cursor node, no iterations are made. + ---@diagnostic disable-next-line: param-type-mismatch + for scope in locals.iter_scope_tree(cursor_node, buf) do + scope_depths[scope:id()] = depth + depth = depth + 1 + end - -- Check scope's direct child. - for _, definition in pairs(definitions) do - if s <= definition.node:start() and definition.node:end_() <= e then - if scope:id() == locals.containing_scope(definition.node, buf):id() then - local get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text - local text = get_node_text(definition.node, buf) or '' - if not self.scopes_map[text] then - self.scopes_map[text] = depth - end - end + -- Map definitions based on their scope relative to the cursor. + local definitions = locals.get_definitions_lookup_table(buf) + local get_node_text = vim.treesitter.get_node_text or vim.treesitter.query.get_node_text + for _, definition in pairs(definitions) do + local definition_depth = scope_depths[locals.containing_scope(definition.node, buf):id()] + local def_text = get_node_text(definition.node, buf) or '' + if definition_depth then + -- Prefer the closest scoped definitions. + if not self.definition_depths[def_text] or self.definition_depths[def_text] > definition_depth then + self.definition_depths[def_text] = definition_depth end end - depth = depth + 1 end end end, }, { ---@type fun(self: table, entry1: cmp.Entry, entry2: cmp.Entry): boolean|nil __call = function(self, entry1, entry2) - local local1 = self.scopes_map[entry1.word] - local local2 = self.scopes_map[entry2.word] - if local1 ~= local2 then - if local1 == nil then + local def_depth1 = self.definition_depths[entry1.word] + local def_depth2 = self.definition_depths[entry2.word] + if def_depth1 ~= def_depth2 then + if def_depth1 == nil then return false end - if local2 == nil then + if def_depth2 == nil then return true end - return local1 < local2 + return def_depth1 < def_depth2 end end, }) diff --git a/lua/cmp/utils/autocmd.lua b/lua/cmp/utils/autocmd.lua index 438e23190..0d980a358 100644 --- a/lua/cmp/utils/autocmd.lua +++ b/lua/cmp/utils/autocmd.lua @@ -19,8 +19,8 @@ autocmd.subscribe = function(events, callback) vim.api.nvim_create_autocmd(event, { desc = ('nvim-cmp: autocmd: %s'):format(event), group = autocmd.group, - callback = function() - autocmd.emit(event) + callback = function(details) + autocmd.emit(event, details) end, }) end @@ -41,12 +41,13 @@ end ---Emit autocmd ---@param event string -autocmd.emit = function(event) +---@param details table|nil +autocmd.emit = function(event, details) debug.log(' ') debug.log(string.format('>>> %s', event)) autocmd.events[event] = autocmd.events[event] or {} for _, callback in ipairs(autocmd.events[event]) do - callback() + callback(details) end end diff --git a/plugin/cmp.lua b/plugin/cmp.lua index 611b5c96a..110fe01a8 100644 --- a/plugin/cmp.lua +++ b/plugin/cmp.lua @@ -53,6 +53,30 @@ if vim.on_key then end, vim.api.nvim_create_namespace('cmp.plugin')) end +-- see compare.scopes +if vim.fn.has('nvim-0.9') == 1 then + local ts = vim.treesitter + local has_ts_parser = ts.language.get_lang + -- vim.treesitter.language.add is recommended for checking treesitter in 0.11 nightly + if vim.fn.has('nvim-0.11') then + has_ts_parser = function(filetype) + local lang = ts.language.get_lang(filetype) + return lang and ts.language.add(lang) + end + end + autocmd.subscribe({ 'FileType' }, function(details) + if has_ts_parser(details.match) then + vim.b[details.buf].cmp_buf_has_ts_parser = true + else + vim.b[details.buf].cmp_buf_has_ts_parser = false + end + end) + autocmd.subscribe({ 'BufUnload' }, function(details) + if vim.treesitter.language.get_lang(details.match) then + vim.b[details.buf].cmp_buf_has_ts_parser = false + end + end) +end vim.api.nvim_create_user_command('CmpStatus', function() require('cmp').status()