6969 }
7070end
7171
72- -- Tells us which leading child node type to skip when highlighting a
73- -- multi-line node.
74- local skip_leading_types = {
75- [word_pattern (' class' )] = {
76- php = ' attribute_list' ,
77- },
78- [word_pattern (' method' )] = {
79- php = ' attribute_list' ,
80- },
81- }
82-
83- -- There are language-specific
84- local DEFAULT_TYPE_PATTERNS = {
85- -- These catch most generic groups, eg "function_declaration" or "function_block"
86- default = {
87- ' class' ,
88- ' function' ,
89- ' method' ,
90- ' for' ,
91- ' while' ,
92- ' if' ,
93- ' switch' ,
94- ' case' ,
95- ' interface' ,
96- ' struct' ,
97- ' enum' ,
98- },
99- elixir = {
100- ' anonymous_function' ,
101- ' arguments' ,
102- ' block' ,
103- ' do_block' ,
104- ' list' ,
105- ' map' ,
106- ' tuple' ,
107- ' quoted_content' ,
108- },
109- haskell = {
110- ' adt'
111- },
112- json = {
113- ' pair' ,
114- },
115- markdown = {
116- ' section' ,
117- },
118- rust = {
119- ' impl_item' ,
120- },
121- scala = {
122- ' object_definition' ,
123- },
124- terraform = {
125- ' block' ,
126- ' object_elem' ,
127- ' attribute' ,
128- },
129- tex = {
130- ' chapter' ,
131- ' section' ,
132- ' subsection' ,
133- ' subsubsection' ,
134- },
135- typescript = {
136- ' export_statement' ,
137- },
138- verilog = {
139- ' always_construct' ,
140- ' statement_or_null' ,
141- },
142- vhdl = {
143- ' process_statement' ,
144- ' architecture_body' ,
145- ' entity_declaration' ,
146- },
147- yaml = {
148- ' block_mapping_pair' ,
149- },
150- exact_patterns = {},
151- }
152-
153- local DEFAULT_TYPE_EXCLUDE_PATTERNS = {
154- default = {},
155- teal = {
156- ' function_body' ,
157- },
158- }
159-
16072local INDENT_PATTERN = ' ^%s+'
16173
16274-- Script variables
@@ -173,49 +85,28 @@ local function get_root_node()
17385 return tree :root ()
17486end
17587
176- local function is_excluded (node , filetype )
177- local node_type = node :type ()
178- for _ , rgx in ipairs (config .exclude_patterns .default ) do
179- if node_type :find (rgx ) then
180- return true
88+ --- @return boolean
89+ local function compare_ranges (node1 , node2 )
90+ local range1 = {node1 :range ()}
91+ local range2 = {node2 :range ()}
92+ for i = 1 , 4 do
93+ if range1 [i ] ~= range2 [i ] then
94+ return false
18195 end
18296 end
183- local filetype_patterns = config .exclude_patterns [filetype ]
184- for _ , rgx in ipairs (filetype_patterns or {}) do
185- if node_type :find (rgx ) then
186- return true
187- end
188- end
189- return false
97+ return true
19098end
19199
192- local function is_valid (node , filetype )
193- if is_excluded (node , filetype ) then
194- return false
195- end
196-
197- local node_type = node :type ()
198- for _ , rgx in ipairs (config .patterns .default ) do
199- if node_type :find (rgx ) then
200- return true
201- end
202- end
203- local filetype_patterns = config .patterns [filetype ]
204- for _ , rgx in ipairs (filetype_patterns or {}) do
205- if node_type :find (rgx ) then
206- return true
100+ local function is_valid (node , query )
101+ local bufnr = api .nvim_get_current_buf ()
102+ for id , node0 in query :iter_captures (node , bufnr , 0 , - 1 ) do
103+ local name = query .captures [id ] -- name of the capture in the query
104+ if name == ' context' then
105+ return compare_ranges (node , node0 )
207106 end
208107 end
209- return false
210- end
211108
212- local function get_type_pattern (node , type_patterns )
213- local node_type = node :type ()
214- for _ , rgx in ipairs (type_patterns ) do
215- if node_type :find (rgx ) then
216- return rgx
217- end
218- end
109+ return false
219110end
220111
221112local function find_node (node , query )
@@ -235,23 +126,20 @@ local function find_node(node, query)
235126end
236127
237128local function get_text_for_node (node )
238- local type = get_type_pattern (node , config .patterns .default ) or node :type ()
239- local filetype = vim .bo .filetype
240-
241- local start_row , start_col = node :start ()
242- local end_row , end_col = node :end_ ()
243-
244129 local node_text = vim .treesitter .query .get_node_text (node , 0 )
245- if node_text == nil then return nil , nil end
130+ if node_text == nil then
131+ return nil , nil
132+ end
246133
134+ local start_row , start_col , end_row , end_col = node :range ()
247135 local lines = vim .split (node_text , ' \n ' )
248136
249137 if start_col ~= 0 then
250138 lines [1 ] = api .nvim_buf_get_lines (0 , start_row , start_row + 1 , false )[1 ]
251139 end
252140 start_col = 0
253141
254- local queries = (last_nodes [type ] or {})[filetype ]
142+ local queries = (last_nodes [node : type () ] or {})[vim . bo . filetype ]
255143
256144 local last_position
257145
@@ -403,6 +291,13 @@ local function get_parent_matches(max_lines)
403291 lnum , col = unpack (api .nvim_win_get_cursor (0 ))
404292 end
405293
294+ local lang = parsers .ft_to_lang (vim .bo .filetype )
295+ local query = vim .treesitter .query .get_query (lang , ' context' )
296+
297+ if not query then
298+ return
299+ end
300+
406301 local last_matches
407302 local parent_matches = {}
408303 local line_offset = 0
@@ -431,7 +326,7 @@ local function get_parent_matches(max_lines)
431326 local row = parent :start ()
432327
433328 local height = math.min (max_lines , # parent_matches )
434- if is_valid (parent , vim . bo . filetype )
329+ if is_valid (parent , query )
435330 and row >= 0
436331 and row < (topline + height - 1 ) then
437332
@@ -632,24 +527,6 @@ local function horizontal_scroll_contexts()
632527 end
633528end
634529
635- local function normalize_node (node )
636- local type = get_type_pattern (node , config .patterns .default ) or node :type ()
637- local filetype = vim .bo .filetype
638-
639- local skip_leading_type = (skip_leading_types [type ] or {})[filetype ]
640- if skip_leading_type then
641- local children = ts_utils .get_named_children (node )
642- for _ , child in ipairs (children ) do
643- if child :type () ~= skip_leading_type then
644- node = child
645- break
646- end
647- end
648- end
649-
650- return node
651- end
652-
653530local function open (ctx_nodes )
654531 local bufnr = api .nvim_get_current_buf ()
655532
@@ -676,8 +553,6 @@ local function open(ctx_nodes)
676553 local contexts = {}
677554
678555 for _ , node in ipairs (ctx_nodes ) do
679- node = normalize_node (node )
680-
681556 local lines , range = get_text_for_node (node )
682557 if lines == nil or range == nil or range [1 ] == nil then return end
683558 local text = merge_lines (lines )
@@ -823,17 +698,7 @@ function M.setup(options)
823698
824699 local userOptions = options or {}
825700
826- config = vim .tbl_deep_extend (' force' , {}, defaultConfig , userOptions )
827- config .patterns = vim .tbl_deep_extend (' force' , {}, DEFAULT_TYPE_PATTERNS , userOptions .patterns or {})
828- config .exclude_patterns = vim .tbl_deep_extend (' force' , {}, DEFAULT_TYPE_EXCLUDE_PATTERNS , userOptions .exclude_patterns or {})
829- config .exact_patterns = vim .tbl_deep_extend (' force' , {}, userOptions .exact_patterns or {})
830-
831- for filetype , patterns in pairs (config .patterns ) do
832- -- Map with word_pattern only if users don't need exact pattern matching
833- if not config .exact_patterns [filetype ] then
834- config .patterns [filetype ] = vim .tbl_map (word_pattern , patterns )
835- end
836- end
701+ config = vim .tbl_deep_extend (' force' , {}, defaultConfig , userOptions )
837702
838703 if config .enable then
839704 M .enable ()
0 commit comments