Skip to content

Commit

Permalink
refactor: dynamic draw columns
Browse files Browse the repository at this point in the history
  • Loading branch information
Saghen committed Jan 22, 2025
1 parent 1facba6 commit 55db7e1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 33 deletions.
2 changes: 1 addition & 1 deletion lua/blink/cmp/completion/windows/menu.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function menu.open_with_items(context, items)
menu.items = items
menu.selected_item_idx = menu.selected_item_idx ~= nil and math.min(menu.selected_item_idx, #items) or nil

if not menu.renderer then menu.renderer = require('blink.cmp.completion.windows.render').new(context, config.draw) end
if not menu.renderer then menu.renderer = require('blink.cmp.completion.windows.render').new(config.draw) end
menu.renderer:draw(context, menu.win:get_buf(), items, config.draw)

local auto_show = menu.auto_show
Expand Down
6 changes: 4 additions & 2 deletions lua/blink/cmp/completion/windows/render/column.lua
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
--- @class blink.cmp.DrawColumn
--- @field component_names string[]
--- @field components blink.cmp.DrawComponent[]
--- @field gap number
--- @field lines string[][]
--- @field width number
--- @field ctxs blink.cmp.DrawItemContext[]
---
--- @field new fun(components: blink.cmp.DrawComponent[], gap: number): blink.cmp.DrawColumn
--- @field new fun(component_names: string[], components: blink.cmp.DrawComponent[], gap: number): blink.cmp.DrawColumn
--- @field render fun(self: blink.cmp.DrawColumn,context: blink.cmp.Context, ctxs: blink.cmp.DrawItemContext[])
--- @field get_line_text fun(self: blink.cmp.DrawColumn, line_idx: number): string
--- @field get_line_highlights fun(self: blink.cmp.DrawColumn, line_idx: number): blink.cmp.DrawHighlight[]
Expand All @@ -16,8 +17,9 @@ local text_lib = require('blink.cmp.completion.windows.render.text')
--- @diagnostic disable-next-line: missing-fields
local column = {}

function column.new(components, gap)
function column.new(component_names, components, gap)
local self = setmetatable({}, { __index = column })
self.component_names = component_names
self.components = components
self.gap = gap
self.lines = {}
Expand Down
59 changes: 30 additions & 29 deletions lua/blink/cmp/completion/windows/render/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,34 @@
--- @field gap number
--- @field columns blink.cmp.DrawColumn[]
---
--- @field new fun(context: blink.cmp.Context, draw: blink.cmp.Draw): blink.cmp.Renderer
--- @field draw fun(self: blink.cmp.Renderer, context: blink.cmp.Context, bufnr: number, items: blink.cmp.CompletionItem[], draw: blink.cmp.Draw)
--- @field pre_column fun(self: blink.cmp.Renderer, context: blink.cmp.Context, draw: blink.cmp.Draw)
--- @field get_component_column_location fun(self: blink.cmp.Renderer, component_name: string): { column_idx: number, component_idx: number }
--- @field get_component_start_col fun(self: blink.cmp.Renderer, component_name: string): number
--- @field new fun(draw: blink.cmp.Draw): blink.cmp.Renderer
--- @field draw fun(self: blink.cmp.Renderer, context: blink.cmp.Context, bufnr: number, items: blink.cmp.CompletionItem[], draw: blink.cmp.Draw): blink.cmp.DrawColumn[]
--- @field get_columns fun(self: blink.cmp.Renderer, context: blink.cmp.Context, draw: blink.cmp.Draw): blink.cmp.DrawColumn[]
--- @field get_component_column_location fun(self: blink.cmp.Renderer, columns: blink.cmp.DrawColumn[], component_name: string): { column_idx: number, component_idx: number }
--- @field get_component_start_col fun(self: blink.cmp.Renderer, columns: blink.cmp.DrawColumn[], component_name: string): number
--- @field get_alignment_start_col fun(self: blink.cmp.Renderer): number

local ns = vim.api.nvim_create_namespace('blink_cmp_renderer')

--- @type blink.cmp.Renderer
--- @diagnostic disable-next-line: missing-fields
local renderer = {}
--- @type blink.cmp.DrawColumn[]
local columns = {}

function renderer.new(context, draw)
--- Convert the component names in the columns to the component definitions
columns = draw.columns
if type(columns) == 'function' then draw.columns = columns(context) end

function renderer.new(draw)
local padding = type(draw.padding) == 'number' and { draw.padding, draw.padding } or draw.padding
--- @cast padding number[]

local self = setmetatable({}, { __index = renderer })
self.padding = padding
self.gap = draw.gap
self.def = draw
renderer:pre_column(context, draw)
return self
end

function renderer:pre_column(context, draw)
if type(columns) == 'function' then draw.columns = columns(context) end
function renderer:get_columns(context, draw)
local columns = draw.columns
if type(columns) == 'function' then columns = columns(context) end
--- @cast columns blink.cmp.DrawColumnDefinition[]

--- @type blink.cmp.DrawComponent[][]
local columns_definitions = vim.tbl_map(function(column)
Expand All @@ -47,14 +42,16 @@ function renderer:pre_column(context, draw)
table.insert(components, draw.components[component_name])
end
return {
component_names = column,
components = components,
gap = column.gap or 0,
}
end, draw.columns)
end, columns)

self.columns = vim.tbl_map(
return vim.tbl_map(
function(column_definition)
return require('blink.cmp.completion.windows.render.column').new(
column_definition.component_names,
column_definition.components,
column_definition.gap
)
Expand All @@ -64,11 +61,11 @@ function renderer:pre_column(context, draw)
end

function renderer:draw(context, bufnr, items)
renderer:pre_column(context, self.def)
local columns = self:get_columns(context, self.def)
local draw_contexts = require('blink.cmp.completion.windows.render.context').get_from_items(context, self.def, items)

-- render the columns
for _, column in ipairs(self.columns) do
for _, column in ipairs(columns) do
column:render(context, draw_contexts)
end

Expand All @@ -78,7 +75,7 @@ function renderer:draw(context, bufnr, items)
local line = ''
if self.padding[1] > 0 then line = string.rep(' ', self.padding[1]) end

for _, column in ipairs(self.columns) do
for _, column in ipairs(columns) do
local text = column:get_line_text(idx)
if #text > 0 then line = line .. text .. string.rep(' ', self.gap) end
end
Expand All @@ -98,7 +95,7 @@ function renderer:draw(context, bufnr, items)
on_win = function(_, _, win_bufnr) return bufnr == win_bufnr end,
on_line = function(_, _, _, line)
local offset = self.padding[1]
for _, column in ipairs(self.columns) do
for _, column in ipairs(columns) do
local text = column:get_line_text(line + 1)
if #text > 0 then
local highlights = column:get_line_highlights(line + 1)
Expand All @@ -118,28 +115,30 @@ function renderer:draw(context, bufnr, items)
end
end,
})

self.columns = columns
end

function renderer:get_component_column_location(component_name)
for column_idx, column in ipairs(self.def.columns) do
for component_idx, other_component_name in ipairs(column) do
function renderer:get_component_column_location(columns, component_name)
for column_idx, column in ipairs(columns) do
for component_idx, other_component_name in ipairs(column.component_names) do
if other_component_name == component_name then return { column_idx, component_idx } end
end
end
error('No component found with name: ' .. component_name)
end

function renderer:get_component_start_col(component_name)
local column_idx, component_idx = unpack(self:get_component_column_location(component_name))
function renderer:get_component_start_col(columns, component_name)
local column_idx, component_idx = unpack(self:get_component_column_location(columns, component_name))

-- add previous columns
local start_col = self.padding[1]
for i = 1, column_idx - 1 do
start_col = start_col + self.columns[i].width + self.gap
start_col = start_col + columns[i].width + self.gap
end

-- add previous components
local line = self.columns[column_idx].lines[1]
local line = columns[column_idx].lines[1]
if not line then return start_col end
for i = 1, component_idx - 1 do
start_col = start_col + #line[i]
Expand All @@ -151,7 +150,9 @@ end
function renderer:get_alignment_start_col()
local component_name = self.def.align_to
if component_name == nil or component_name == 'none' or component_name == 'cursor' then return 0 end
return self:get_component_start_col(component_name)

assert(self.columns ~= nil, 'Attempted to get alignment start col before drawing')
return self:get_component_start_col(self.columns, component_name)
end

return renderer
4 changes: 3 additions & 1 deletion lua/blink/cmp/completion/windows/render/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
--- @field align_to? string | 'none' | 'cursor' Align the window to the component with the given name, or to the cursor
--- @field padding? number | number[] Padding on the left and right of the grid
--- @field gap? number Gap between columns
--- @field columns? { [number]: string, gap?: number }[] Components to render, grouped by column
--- @field columns? blink.cmp.DrawColumnDefinition[] | fun(context: blink.cmp.Context): blink.cmp.DrawColumnDefinition[] Components to render, grouped by column
--- @field components? table<string, blink.cmp.DrawComponent> Component definitions
--- @field treesitter? string[] Use treesitter to highlight the label text of completions from these sources
---
Expand All @@ -22,3 +22,5 @@
--- @field ellipsis? boolean Whether to add an ellipsis when truncating the text
--- @field text? fun(ctx: blink.cmp.DrawItemContext): string? Renders the text of the component
--- @field highlight? string | fun(ctx: blink.cmp.DrawItemContext, text: string): string | blink.cmp.DrawHighlight[] Renders the highlights of the component
---
--- @alias blink.cmp.DrawColumnDefinition { [number]: string, gap?: number }

0 comments on commit 55db7e1

Please sign in to comment.