Skip to content

Commit f1ae7bd

Browse files
committed
feat!: use queries for determining context
1 parent 4842abe commit f1ae7bd

File tree

4 files changed

+50
-164
lines changed

4 files changed

+50
-164
lines changed

lua/treesitter-context.lua

Lines changed: 29 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -69,94 +69,6 @@ do
6969
}
7070
end
7171

72-
-- Tells us which leading child node type to skip when highlighting a
73-
-- multi-line node.
74-
local skip_leading_types = {
75-
[word_pattern('class')] = {
76-
php = 'attribute_list',
77-
},
78-
[word_pattern('method')] = {
79-
php = 'attribute_list',
80-
},
81-
}
82-
83-
-- There are language-specific
84-
local DEFAULT_TYPE_PATTERNS = {
85-
-- These catch most generic groups, eg "function_declaration" or "function_block"
86-
default = {
87-
'class',
88-
'function',
89-
'method',
90-
'for',
91-
'while',
92-
'if',
93-
'switch',
94-
'case',
95-
'interface',
96-
'struct',
97-
'enum',
98-
},
99-
elixir = {
100-
'anonymous_function',
101-
'arguments',
102-
'block',
103-
'do_block',
104-
'list',
105-
'map',
106-
'tuple',
107-
'quoted_content',
108-
},
109-
haskell = {
110-
'adt'
111-
},
112-
json = {
113-
'pair',
114-
},
115-
markdown = {
116-
'section',
117-
},
118-
rust = {
119-
'impl_item',
120-
},
121-
scala = {
122-
'object_definition',
123-
},
124-
terraform = {
125-
'block',
126-
'object_elem',
127-
'attribute',
128-
},
129-
tex = {
130-
'chapter',
131-
'section',
132-
'subsection',
133-
'subsubsection',
134-
},
135-
typescript = {
136-
'export_statement',
137-
},
138-
verilog = {
139-
'always_construct',
140-
'statement_or_null',
141-
},
142-
vhdl = {
143-
'process_statement',
144-
'architecture_body',
145-
'entity_declaration',
146-
},
147-
yaml = {
148-
'block_mapping_pair',
149-
},
150-
exact_patterns = {},
151-
}
152-
153-
local DEFAULT_TYPE_EXCLUDE_PATTERNS = {
154-
default = {},
155-
teal = {
156-
'function_body',
157-
},
158-
}
159-
16072
local INDENT_PATTERN = '^%s+'
16173

16274
-- Script variables
@@ -173,49 +85,28 @@ local function get_root_node()
17385
return tree:root()
17486
end
17587

176-
local function is_excluded(node, filetype)
177-
local node_type = node:type()
178-
for _, rgx in ipairs(config.exclude_patterns.default) do
179-
if node_type:find(rgx) then
180-
return true
88+
--- @return boolean
89+
local function compare_ranges(node1, node2)
90+
local range1 = {node1:range()}
91+
local range2 = {node2:range()}
92+
for i = 1, 4 do
93+
if range1[i] ~= range2[i] then
94+
return false
18195
end
18296
end
183-
local filetype_patterns = config.exclude_patterns[filetype]
184-
for _, rgx in ipairs(filetype_patterns or {}) do
185-
if node_type:find(rgx) then
186-
return true
187-
end
188-
end
189-
return false
97+
return true
19098
end
19199

192-
local function is_valid(node, filetype)
193-
if is_excluded(node, filetype) then
194-
return false
195-
end
196-
197-
local node_type = node:type()
198-
for _, rgx in ipairs(config.patterns.default) do
199-
if node_type:find(rgx) then
200-
return true
201-
end
202-
end
203-
local filetype_patterns = config.patterns[filetype]
204-
for _, rgx in ipairs(filetype_patterns or {}) do
205-
if node_type:find(rgx) then
206-
return true
100+
local function is_valid(node, query)
101+
local bufnr = api.nvim_get_current_buf()
102+
for id, node0 in query:iter_captures(node, bufnr, 0, -1) do
103+
local name = query.captures[id] -- name of the capture in the query
104+
if name == 'context' then
105+
return compare_ranges(node, node0)
207106
end
208107
end
209-
return false
210-
end
211108

212-
local function get_type_pattern(node, type_patterns)
213-
local node_type = node:type()
214-
for _, rgx in ipairs(type_patterns) do
215-
if node_type:find(rgx) then
216-
return rgx
217-
end
218-
end
109+
return false
219110
end
220111

221112
local function find_node(node, query)
@@ -235,23 +126,20 @@ local function find_node(node, query)
235126
end
236127

237128
local function get_text_for_node(node)
238-
local type = get_type_pattern(node, config.patterns.default) or node:type()
239-
local filetype = vim.bo.filetype
240-
241-
local start_row, start_col = node:start()
242-
local end_row, end_col = node:end_()
243-
244129
local node_text = vim.treesitter.query.get_node_text(node, 0)
245-
if node_text == nil then return nil, nil end
130+
if node_text == nil then
131+
return nil, nil
132+
end
246133

134+
local start_row, start_col, end_row, end_col = node:range()
247135
local lines = vim.split(node_text, '\n')
248136

249137
if start_col ~= 0 then
250138
lines[1] = api.nvim_buf_get_lines(0, start_row, start_row + 1, false)[1]
251139
end
252140
start_col = 0
253141

254-
local queries = (last_nodes[type] or {})[filetype]
142+
local queries = (last_nodes[node:type()] or {})[vim.bo.filetype]
255143

256144
local last_position
257145

@@ -403,6 +291,13 @@ local function get_parent_matches(max_lines)
403291
lnum, col = unpack(api.nvim_win_get_cursor(0))
404292
end
405293

294+
local lang = parsers.ft_to_lang(vim.bo.filetype)
295+
local query = vim.treesitter.query.get_query(lang, 'context')
296+
297+
if not query then
298+
return
299+
end
300+
406301
local last_matches
407302
local parent_matches = {}
408303
local line_offset = 0
@@ -431,7 +326,7 @@ local function get_parent_matches(max_lines)
431326
local row = parent:start()
432327

433328
local height = math.min(max_lines, #parent_matches)
434-
if is_valid(parent, vim.bo.filetype)
329+
if is_valid(parent, query)
435330
and row >= 0
436331
and row < (topline + height - 1) then
437332

@@ -632,24 +527,6 @@ local function horizontal_scroll_contexts()
632527
end
633528
end
634529

635-
local function normalize_node(node)
636-
local type = get_type_pattern(node, config.patterns.default) or node:type()
637-
local filetype = vim.bo.filetype
638-
639-
local skip_leading_type = (skip_leading_types[type] or {})[filetype]
640-
if skip_leading_type then
641-
local children = ts_utils.get_named_children(node)
642-
for _, child in ipairs(children) do
643-
if child:type() ~= skip_leading_type then
644-
node = child
645-
break
646-
end
647-
end
648-
end
649-
650-
return node
651-
end
652-
653530
local function open(ctx_nodes)
654531
local bufnr = api.nvim_get_current_buf()
655532

@@ -676,8 +553,6 @@ local function open(ctx_nodes)
676553
local contexts = {}
677554

678555
for _, node in ipairs(ctx_nodes) do
679-
node = normalize_node(node)
680-
681556
local lines, range = get_text_for_node(node)
682557
if lines == nil or range == nil or range[1] == nil then return end
683558
local text = merge_lines(lines)
@@ -823,17 +698,7 @@ function M.setup(options)
823698

824699
local userOptions = options or {}
825700

826-
config = vim.tbl_deep_extend('force', {}, defaultConfig, userOptions)
827-
config.patterns = vim.tbl_deep_extend('force', {}, DEFAULT_TYPE_PATTERNS, userOptions.patterns or {})
828-
config.exclude_patterns = vim.tbl_deep_extend('force', {}, DEFAULT_TYPE_EXCLUDE_PATTERNS, userOptions.exclude_patterns or {})
829-
config.exact_patterns = vim.tbl_deep_extend('force', {}, userOptions.exact_patterns or {})
830-
831-
for filetype, patterns in pairs(config.patterns) do
832-
-- Map with word_pattern only if users don't need exact pattern matching
833-
if not config.exact_patterns[filetype] then
834-
config.patterns[filetype] = vim.tbl_map(word_pattern, patterns)
835-
end
836-
end
701+
config = vim.tbl_deep_extend('force', {}, defaultConfig, userOptions)
837702

838703
if config.enable then
839704
M.enable()

queries/c/context.scm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
([
3+
(function_definition)
4+
(if_statement)
5+
] @context)

queries/lua/context.scm

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
([
2+
(function_declaration)
3+
(function_definition)
4+
(if_statement)
5+
(while_statement)
6+
(table_constructor)
7+
(for_statement)
8+
] @context)

queries/teal/context.scm

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
([
3+
(if_statement)
4+
(while_statement)
5+
(anon_function)
6+
(function_statement)
7+
(generic_for_statement)
8+
] @context)

0 commit comments

Comments
 (0)