diff --git a/lua/blink/cmp/completion/windows/menu.lua b/lua/blink/cmp/completion/windows/menu.lua index b4c26b36..c82c4a96 100644 --- a/lua/blink/cmp/completion/windows/menu.lua +++ b/lua/blink/cmp/completion/windows/menu.lua @@ -51,7 +51,7 @@ function menu.open_with_items(context, items) menu.items = items menu.selected_item_idx = menu.selected_item_idx ~= nil and math.min(menu.selected_item_idx, #items) or nil - if not menu.renderer then menu.renderer = require('blink.cmp.completion.windows.render').new(context, config.draw) end + if not menu.renderer then menu.renderer = require('blink.cmp.completion.windows.render').new(config.draw) end menu.renderer:draw(context, menu.win:get_buf(), items, config.draw) local auto_show = menu.auto_show diff --git a/lua/blink/cmp/completion/windows/render/column.lua b/lua/blink/cmp/completion/windows/render/column.lua index 34374710..8f5699c7 100644 --- a/lua/blink/cmp/completion/windows/render/column.lua +++ b/lua/blink/cmp/completion/windows/render/column.lua @@ -1,11 +1,12 @@ --- @class blink.cmp.DrawColumn +--- @field component_names string[] --- @field components blink.cmp.DrawComponent[] --- @field gap number --- @field lines string[][] --- @field width number --- @field ctxs blink.cmp.DrawItemContext[] --- ---- @field new fun(components: blink.cmp.DrawComponent[], gap: number): blink.cmp.DrawColumn +--- @field new fun(component_names: string[], components: blink.cmp.DrawComponent[], gap: number): blink.cmp.DrawColumn --- @field render fun(self: blink.cmp.DrawColumn,context: blink.cmp.Context, ctxs: blink.cmp.DrawItemContext[]) --- @field get_line_text fun(self: blink.cmp.DrawColumn, line_idx: number): string --- @field get_line_highlights fun(self: blink.cmp.DrawColumn, line_idx: number): blink.cmp.DrawHighlight[] @@ -16,8 +17,9 @@ local text_lib = require('blink.cmp.completion.windows.render.text') --- @diagnostic disable-next-line: missing-fields local column = {} -function column.new(components, gap) +function column.new(component_names, components, gap) local self = setmetatable({}, { __index = column }) + self.component_names = component_names self.components = components self.gap = gap self.lines = {} diff --git a/lua/blink/cmp/completion/windows/render/init.lua b/lua/blink/cmp/completion/windows/render/init.lua index 700ac29d..c7dc2961 100644 --- a/lua/blink/cmp/completion/windows/render/init.lua +++ b/lua/blink/cmp/completion/windows/render/init.lua @@ -4,11 +4,11 @@ --- @field gap number --- @field columns blink.cmp.DrawColumn[] --- ---- @field new fun(context: blink.cmp.Context, draw: blink.cmp.Draw): blink.cmp.Renderer ---- @field draw fun(self: blink.cmp.Renderer, context: blink.cmp.Context, bufnr: number, items: blink.cmp.CompletionItem[], draw: blink.cmp.Draw) ---- @field pre_column fun(self: blink.cmp.Renderer, context: blink.cmp.Context, draw: blink.cmp.Draw) ---- @field get_component_column_location fun(self: blink.cmp.Renderer, component_name: string): { column_idx: number, component_idx: number } ---- @field get_component_start_col fun(self: blink.cmp.Renderer, component_name: string): number +--- @field new fun(draw: blink.cmp.Draw): blink.cmp.Renderer +--- @field draw fun(self: blink.cmp.Renderer, context: blink.cmp.Context, bufnr: number, items: blink.cmp.CompletionItem[], draw: blink.cmp.Draw): blink.cmp.DrawColumn[] +--- @field get_columns fun(self: blink.cmp.Renderer, context: blink.cmp.Context, draw: blink.cmp.Draw): blink.cmp.DrawColumn[] +--- @field get_component_column_location fun(self: blink.cmp.Renderer, columns: blink.cmp.DrawColumn[], component_name: string): { column_idx: number, component_idx: number } +--- @field get_component_start_col fun(self: blink.cmp.Renderer, columns: blink.cmp.DrawColumn[], component_name: string): number --- @field get_alignment_start_col fun(self: blink.cmp.Renderer): number local ns = vim.api.nvim_create_namespace('blink_cmp_renderer') @@ -16,14 +16,8 @@ local ns = vim.api.nvim_create_namespace('blink_cmp_renderer') --- @type blink.cmp.Renderer --- @diagnostic disable-next-line: missing-fields local renderer = {} ---- @type blink.cmp.DrawColumn[] -local columns = {} - -function renderer.new(context, draw) - --- Convert the component names in the columns to the component definitions - columns = draw.columns - if type(columns) == 'function' then draw.columns = columns(context) end +function renderer.new(draw) local padding = type(draw.padding) == 'number' and { draw.padding, draw.padding } or draw.padding --- @cast padding number[] @@ -31,12 +25,13 @@ function renderer.new(context, draw) self.padding = padding self.gap = draw.gap self.def = draw - renderer:pre_column(context, draw) return self end -function renderer:pre_column(context, draw) - if type(columns) == 'function' then draw.columns = columns(context) end +function renderer:get_columns(context, draw) + local columns = draw.columns + if type(columns) == 'function' then columns = columns(context) end + --- @cast columns blink.cmp.DrawColumnDefinition[] --- @type blink.cmp.DrawComponent[][] local columns_definitions = vim.tbl_map(function(column) @@ -47,14 +42,16 @@ function renderer:pre_column(context, draw) table.insert(components, draw.components[component_name]) end return { + component_names = column, components = components, gap = column.gap or 0, } - end, draw.columns) + end, columns) - self.columns = vim.tbl_map( + return vim.tbl_map( function(column_definition) return require('blink.cmp.completion.windows.render.column').new( + column_definition.component_names, column_definition.components, column_definition.gap ) @@ -64,11 +61,11 @@ function renderer:pre_column(context, draw) end function renderer:draw(context, bufnr, items) - renderer:pre_column(context, self.def) + local columns = self:get_columns(context, self.def) local draw_contexts = require('blink.cmp.completion.windows.render.context').get_from_items(context, self.def, items) -- render the columns - for _, column in ipairs(self.columns) do + for _, column in ipairs(columns) do column:render(context, draw_contexts) end @@ -78,7 +75,7 @@ function renderer:draw(context, bufnr, items) local line = '' if self.padding[1] > 0 then line = string.rep(' ', self.padding[1]) end - for _, column in ipairs(self.columns) do + for _, column in ipairs(columns) do local text = column:get_line_text(idx) if #text > 0 then line = line .. text .. string.rep(' ', self.gap) end end @@ -98,7 +95,7 @@ function renderer:draw(context, bufnr, items) on_win = function(_, _, win_bufnr) return bufnr == win_bufnr end, on_line = function(_, _, _, line) local offset = self.padding[1] - for _, column in ipairs(self.columns) do + for _, column in ipairs(columns) do local text = column:get_line_text(line + 1) if #text > 0 then local highlights = column:get_line_highlights(line + 1) @@ -118,28 +115,30 @@ function renderer:draw(context, bufnr, items) end end, }) + + self.columns = columns end -function renderer:get_component_column_location(component_name) - for column_idx, column in ipairs(self.def.columns) do - for component_idx, other_component_name in ipairs(column) do +function renderer:get_component_column_location(columns, component_name) + for column_idx, column in ipairs(columns) do + for component_idx, other_component_name in ipairs(column.component_names) do if other_component_name == component_name then return { column_idx, component_idx } end end end error('No component found with name: ' .. component_name) end -function renderer:get_component_start_col(component_name) - local column_idx, component_idx = unpack(self:get_component_column_location(component_name)) +function renderer:get_component_start_col(columns, component_name) + local column_idx, component_idx = unpack(self:get_component_column_location(columns, component_name)) -- add previous columns local start_col = self.padding[1] for i = 1, column_idx - 1 do - start_col = start_col + self.columns[i].width + self.gap + start_col = start_col + columns[i].width + self.gap end -- add previous components - local line = self.columns[column_idx].lines[1] + local line = columns[column_idx].lines[1] if not line then return start_col end for i = 1, component_idx - 1 do start_col = start_col + #line[i] @@ -151,7 +150,9 @@ end function renderer:get_alignment_start_col() local component_name = self.def.align_to if component_name == nil or component_name == 'none' or component_name == 'cursor' then return 0 end - return self:get_component_start_col(component_name) + + assert(self.columns ~= nil, 'Attempted to get alignment start col before drawing') + return self:get_component_start_col(self.columns, component_name) end return renderer diff --git a/lua/blink/cmp/completion/windows/render/types.lua b/lua/blink/cmp/completion/windows/render/types.lua index 186b3dc3..63310bdf 100644 --- a/lua/blink/cmp/completion/windows/render/types.lua +++ b/lua/blink/cmp/completion/windows/render/types.lua @@ -2,7 +2,7 @@ --- @field align_to? string | 'none' | 'cursor' Align the window to the component with the given name, or to the cursor --- @field padding? number | number[] Padding on the left and right of the grid --- @field gap? number Gap between columns ---- @field columns? { [number]: string, gap?: number }[] Components to render, grouped by column +--- @field columns? blink.cmp.DrawColumnDefinition[] | fun(context: blink.cmp.Context): blink.cmp.DrawColumnDefinition[] Components to render, grouped by column --- @field components? table Component definitions --- @field treesitter? string[] Use treesitter to highlight the label text of completions from these sources --- @@ -22,3 +22,5 @@ --- @field ellipsis? boolean Whether to add an ellipsis when truncating the text --- @field text? fun(ctx: blink.cmp.DrawItemContext): string? Renders the text of the component --- @field highlight? string | fun(ctx: blink.cmp.DrawItemContext, text: string): string | blink.cmp.DrawHighlight[] Renders the highlights of the component +--- +--- @alias blink.cmp.DrawColumnDefinition { [number]: string, gap?: number }