Skip to content

Commit d35f43c

Browse files
committed
feat: 106: Extract Function
Just the bare bones of the project getting started here.
1 parent 3dc51eb commit d35f43c

File tree

10 files changed

+408
-0
lines changed

10 files changed

+408
-0
lines changed

lua/refactoring/106.lua

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
local dev = require("refactoring.dev")
2+
dev.reload()
3+
4+
local ts_utils = require("nvim-treesitter.ts_utils")
5+
local utils = require("refactoring.utils")
6+
local vim_helpers = require("refactoring.vim-helpers")
7+
8+
local REFACTORING = {}
9+
local REFACTORING_OPTIONS = {
10+
code_generation = {
11+
lua = {
12+
extract_function = function(opts)
13+
return {
14+
create = table.concat(vim.tbl_flatten({
15+
string.format("local function %s(%s)", opts.name, table.concat(opts.args, ", ")),
16+
opts.body,
17+
"end",
18+
""
19+
}), "\n"),
20+
21+
call = string.format("%s(%s)", opts.name, table.concat(opts.args, ", ")),
22+
}
23+
end,
24+
}
25+
}
26+
}
27+
28+
local function get_selection_range()
29+
local _, start_row, start_col, _ = unpack(vim.fn.getpos("'<"))
30+
local _, end_row, _, _ = unpack(vim.fn.getpos("'>"))
31+
local end_col = vim.fn.col("'>")
32+
33+
-- end_col :: TS is 0 based, and '> on line selections is char_count + 1
34+
-- I think - 2 is correct on
35+
--
36+
-- end_row : end_row is exclusive in TS, so we don't minus
37+
return start_row, start_col, end_row, end_col
38+
end
39+
40+
local function vim_range_to_ts_range(start_row, start_col, end_row, end_col)
41+
return start_row - 1, start_col, end_row - 1, end_col
42+
end
43+
44+
-- 106
45+
local function get_text_edits(selected_local_references, end_row, lang, start_col, start_row, end_col, scope_range, function_name)
46+
-- local declaration within the selection range.
47+
local lsp_text_edits = {}
48+
local extract_function = REFACTORING_OPTIONS.code_generation[lang].extract_function({
49+
args = vim.tbl_keys(selected_local_references),
50+
body = vim.api.nvim_buf_get_lines(0, start_row, end_row, false),
51+
name = function_name,
52+
})
53+
table.insert(lsp_text_edits, {
54+
range = scope_range,
55+
newText = string.format("\n%s", extract_function.create)
56+
})
57+
table.insert(lsp_text_edits, {
58+
range = {
59+
start = {line = start_row, character = start_col},
60+
["end"] = {line = end_row, character = end_col}
61+
},
62+
newText = string.format("\n%s", extract_function.call)
63+
})
64+
return lsp_text_edits
65+
end
66+
67+
local function get_scope_range(scope)
68+
-- vim_helpers.move_text(0, start_row, end_row, scope:range())
69+
local scope_range = ts_utils.node_to_lsp_range(scope)
70+
scope_range.start.line = scope_range.start.line - 1
71+
scope_range["end"] = scope_range.start
72+
73+
return scope_range
74+
end
75+
76+
local function get_local_definitions(local_defs, function_args)
77+
local local_def_map = {}
78+
79+
for _, def in pairs(local_defs) do
80+
local_def_map[ts_utils.get_node_text(def)[1]] = true
81+
end
82+
for _, def in pairs(function_args) do
83+
local_def_map[ts_utils.get_node_text(def)[1]] = true
84+
end
85+
86+
return local_def_map
87+
end
88+
89+
REFACTORING.extract = function(bufnr)
90+
-- lua 1 based index
91+
-- vim apis are 1 based
92+
-- treesitter is 0 based
93+
-- first entry (1), line 1, row 0
94+
bufnr = bufnr or 0
95+
96+
local lang = vim.bo.filetype
97+
local start_row, start_col, end_row, end_col = get_selection_range()
98+
local ts_start_row, ts_start_col, ts_end_row, ts_end_col =
99+
vim_range_to_ts_range(start_row, start_col, end_row, end_col)
100+
local root = utils.get_root(lang)
101+
local scope = utils.get_scope_over_selection(root, start_row, start_col, end_row + 1, end_col, lang)
102+
103+
if scope == nil then
104+
error("Scope is nil")
105+
end
106+
107+
local local_defs = vim.tbl_filter(function(node)
108+
return not utils.range_contains_node(node, ts_start_row, ts_start_col, ts_end_row, ts_end_col)
109+
end, utils.get_locals_defs(scope, lang))
110+
111+
local function_args = utils.get_function_args(scope, lang)
112+
local local_def_map = get_local_definitions(local_defs, function_args)
113+
local local_references = utils.get_all_identifiers(scope, lang)
114+
local selected_local_references = {}
115+
116+
for _, local_ref in pairs(local_references) do
117+
local local_name = ts_utils.get_node_text(local_ref)[1]
118+
if utils.range_contains_node(local_ref, ts_start_row, ts_start_col, ts_end_row, ts_end_col) and
119+
local_def_map[local_name] then
120+
selected_local_references[local_name] = true
121+
end
122+
end
123+
124+
-- TODO: Probably use text edit
125+
local scope_range = get_scope_range(scope)
126+
127+
-- TODO: Polar, nvim_buf_get_lines doesn't actually get the highlighted
128+
-- region, instead the highlighted rows
129+
local function_name = vim.fn.input("106: Extract Function Name > ")
130+
131+
-- TODO: Polor, could you also make the variable that is returned the first
132+
local text_edits = get_text_edits(selected_local_references, end_row, lang, start_col, start_row, end_col, scope_range, function_name)
133+
vim.lsp.util.apply_text_edits(text_edits, 0)
134+
end
135+
136+
REFACTORING.extract()
137+
138+
return REFACTORING

lua/refactoring/106_spec.lua

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
vim.cmd("set rtp+=" .. vim.loop.cwd())
2+
3+
local eq = assert.are.same
4+
local refactoring = require("refactoring.106")
5+
6+
describe("Refactoring", function()
7+
it("should test selection grabbing", function(done)
8+
vim.cmd(":new")
9+
vim.bo.filetype = "lua"
10+
vim.api.nvim_buf_set_lines(0, 0, 3, false, {
11+
"local foo = 5",
12+
"",
13+
"foo = foo + 5 + foo",
14+
})
15+
16+
eq(3, #vim.api.nvim_buf_get_lines(0, 0, 3, false))
17+
vim.cmd(":norm! ggVjj")
18+
19+
eq(0, vim.fn.col("'<"), "Selections start at 0")
20+
eq(#"foo = foo + 5 + foo" + 1, vim.fn.col("'>"), "Selection stops + 1 after last line.")
21+
22+
-- print("COL", vim.fn.col("'<"), vim.fn.col("'>"))
23+
end)
24+
end)
25+
26+
27+

lua/refactoring/dev.lua

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
local M = {}
2+
M.reload = function()
3+
require("plenary.reload").reload_module("refactoring")
4+
end
5+
6+
return M
7+

lua/refactoring/init.lua

Whitespace-only changes.

lua/refactoring/intersect.lua

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
-- Array<node_wrapper>
3+
local intersect_nodes = function(nodes, row, col)
4+
local found = {}
5+
for idx = 1, #nodes do
6+
local node = nodes[idx]
7+
local sRow = node.dim.s.r
8+
local sCol = node.dim.s.c
9+
local eRow = node.dim.e.r
10+
local eCol = node.dim.e.c
11+
if utils.intersects(row, col, sRow, sCol, eRow, eCol) then
12+
table.insert(found, node)
13+
end
14+
end
15+
return found
16+
end
17+

lua/refactoring/query.lua

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
function find_identifiers()
2+
end
3+
4+

lua/refactoring/refactoring.lua

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+

lua/refactoring/utils.lua

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
local ts_utils = require("nvim-treesitter.ts_utils")
2+
local ts_query = require("nvim-treesitter.query")
3+
local parsers = require("nvim-treesitter.parsers")
4+
local locals = require("nvim-treesitter.locals")
5+
6+
local M = {}
7+
8+
M.get_root = function(lang)
9+
local parser = parsers.get_parser(0, lang)
10+
return parser:parse()[1]:root()
11+
end
12+
13+
M.get_bounded_query = function(query, lang, startR, stopR)
14+
local success, parsed_query = pcall(function()
15+
return vim.treesitter.parse_query(lang, query)
16+
end)
17+
18+
if not success then
19+
error("Unsuccessful successful first try")
20+
end
21+
22+
local root = M.get_root(lang)
23+
24+
local out = {}
25+
for match in ts_query.iter_prepared_matches(parsed_query, root, 0, startR - 1, stopR) do
26+
locals.recurse_local_nodes(match, function(_, node, path)
27+
table.insert(out, node)
28+
end)
29+
end
30+
return out;
31+
end
32+
33+
local refactor_constants = {
34+
lua = {
35+
scope = {
36+
["function"] = true,
37+
["function_definition"] = true
38+
}
39+
}
40+
}
41+
42+
-- determines if a contains node b.
43+
-- @param a the containing node
44+
-- @param b the node to be contained
45+
function M.node_contains(a, b)
46+
if a == nil or b == nil then
47+
return false
48+
end
49+
50+
local start_row, start_col, end_row, end_col = b:range()
51+
return ts_utils.is_in_node_range(a, start_row, start_col) and
52+
ts_utils.is_in_node_range(a, end_row, end_col)
53+
end
54+
55+
-- determines if a node exists within a range. Imagine a range selection
56+
-- across '<,'> and an identifier. Does the identifier exist within the
57+
-- selection?
58+
--
59+
-- @param a the containing node
60+
-- @param b the node to be contained
61+
M.range_contains_node = function(node, start_row, start_col, end_row, end_col)
62+
local node_start_row, node_start_col, node_end_row, node_end_col = node:range()
63+
64+
-- There are five possible conditions
65+
-- 1. node start/end row are contained exclusively within the range.
66+
-- 2. The range is a single line range
67+
-- - the node start/end row must equal start_row and cols have to exist
68+
-- within range, inclusive
69+
-- 3. The node exists solely within the first line
70+
-- - node_start_col has to be inclusive with start_col, end col doesn't
71+
-- matter.
72+
-- 4. The node exists solely within the last line
73+
-- - node_start_col doesn't matter whereas node_end_col has to be
74+
-- inclusive with end_col
75+
-- 5. The node starts / ends on the same rows and has to have each column
76+
-- considered
77+
if start_row < node_start_row and end_row > node_end_row then
78+
return true
79+
elseif start_row == end_row then
80+
return start_row == node_start_row and
81+
end_row == node_end_row and
82+
start_col <= node_start_col and
83+
end_col >= node_end_col
84+
85+
elseif start_row == node_start_row and start_row == node_end_row then
86+
return start_col <= node_start_col
87+
elseif end_row == node_start_row and end_row == node_end_row then
88+
return end_col >= node_end_col
89+
elseif start_row <= node_start_row and end_row >= node_end_row then
90+
return start_col <= node_start_col and end_col >= node_end_col
91+
end
92+
93+
return false
94+
end
95+
96+
M.get_scope_over_selection = function(root, start_line, start_col, end_line, end_col, lang)
97+
local start_scope = M.get_scope(root, start_line, start_col, lang)
98+
local end_scope = M.get_scope(root, end_line, end_col, lang)
99+
100+
if start_scope ~= end_scope then
101+
error("Selection spans over two scopes, cannot determine scope")
102+
end
103+
104+
return start_scope
105+
end
106+
107+
M.get_scope = function(root, line, col, lang)
108+
local function_scopes = {}
109+
local query = vim.treesitter.get_query(lang, "locals")
110+
111+
for id, n, _ in query:iter_captures(root, 0, 0, -1) do
112+
if query.captures[id] == "scope" and refactor_constants[lang].scope[n:type()] then
113+
table.insert(function_scopes, n)
114+
end
115+
end
116+
117+
local out = nil
118+
for _, scope in pairs(function_scopes) do
119+
-- TODO: This is a confusing issue
120+
-- should a scope that contains another scope but terminates at the
121+
-- same point be the outer or inner? Should potentially be considered
122+
-- a list of scopes...
123+
if ts_utils.is_in_node_range(scope, line, col) and
124+
(out == nil or M.node_contains(out, scope)) then
125+
126+
out = scope
127+
end
128+
end
129+
130+
return out
131+
end
132+
133+
local function get_refactoring_query(lang)
134+
local query = vim.treesitter.get_query(lang, "refactoring")
135+
if not query then
136+
error("refactoring not supported in this language. Please provide a queries/<lang>/refactoring.scm")
137+
end
138+
return query
139+
end
140+
141+
local function pluck_by_capture(scope, lang, query, capture_name)
142+
local local_defs = {}
143+
local root = M.get_root(lang)
144+
for id, node, _ in query:iter_captures(root, 0, 0, -1) do
145+
if query.captures[id] == capture_name and M.node_contains(scope, node) then
146+
table.insert(local_defs, node)
147+
end
148+
end
149+
150+
return local_defs
151+
end
152+
153+
M.get_function_args = function(scope, lang)
154+
return pluck_by_capture(scope, lang, get_refactoring_query(lang), "definition.function_argument")
155+
end
156+
157+
M.get_locals_defs = function(scope, lang)
158+
return pluck_by_capture(scope, lang, get_refactoring_query(lang), "definition.local_var")
159+
end
160+
161+
M.get_all_identifiers = function(scope, lang)
162+
return pluck_by_capture(scope, lang, vim.treesitter.get_query(lang, "locals"), "reference")
163+
end
164+
165+
-- is there a better way?
166+
M.range_to_table = function(node)
167+
if node == nil then
168+
return "range nil"
169+
end
170+
local a, b, c, d = node:range()
171+
return {a, b, c, d}
172+
end
173+
174+
return M

0 commit comments

Comments
 (0)