Skip to content

Commit eb57641

Browse files
authored
Add retrieving types in codequery (#260)
* Add retrieving types in codequery * Add fuzzy search option (#271)
1 parent 2a6a9af commit eb57641

File tree

5 files changed

+207
-37
lines changed

5 files changed

+207
-37
lines changed

program-model/src/buttercup/program_model/api/tree_sitter.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from dataclasses import dataclass
55
from pathlib import Path
66
from functools import lru_cache
7-
from enum import Enum
8-
97
from buttercup.common.challenge_task import ChallengeTask
10-
from buttercup.program_model.utils.common import Function, FunctionBody
8+
from buttercup.program_model.utils.common import (
9+
Function,
10+
FunctionBody,
11+
TypeDefinition,
12+
TypeDefinitionType,
13+
)
1114
from tree_sitter_language_pack import get_language, get_parser
1215
from buttercup.common.project_yaml import ProjectYaml
1316

@@ -81,26 +84,6 @@
8184
"""
8285

8386

84-
@dataclass
85-
class TypeDefinitionType(str, Enum):
86-
"""Enum to store type definition type."""
87-
88-
STRUCT = "struct"
89-
UNION = "union"
90-
ENUM = "enum"
91-
TYPEDEF = "typedef"
92-
PREPROC_TYPE = "preproc_type"
93-
94-
95-
@dataclass
96-
class TypeDefinition:
97-
"""Class to store type definition information."""
98-
99-
name: str
100-
type: TypeDefinitionType
101-
definition: str
102-
103-
10487
@dataclass
10588
class CodeTS:
10689
"""Class to extract information about functions in a challenge project using TreeSitter."""
@@ -199,7 +182,9 @@ def get_function(self, function_name: str, file_path: Path) -> Function | None:
199182
functions = self.get_functions(file_path)
200183
return functions.get(function_name)
201184

202-
def parse_types_in_code(self, file_path: Path) -> dict[str, TypeDefinition]:
185+
def parse_types_in_code(
186+
self, file_path: Path, typename: str | None = None, fuzzy: bool | None = False
187+
) -> dict[str, TypeDefinition]:
203188
"""Parse the definition of a type in a piece of code."""
204189
logger.debug("Parsing types in code")
205190
code = self.challenge_task.task_dir.joinpath(file_path).read_bytes()
@@ -233,6 +218,10 @@ def parse_types_in_code(self, file_path: Path) -> dict[str, TypeDefinition]:
233218

234219
type_definition = code[start_byte : definition_node.end_byte].decode()
235220
name = name_node.text.decode()
221+
if typename and not fuzzy and name != typename:
222+
continue
223+
if typename and fuzzy and typename not in name:
224+
continue
236225
logger.debug("Type name: %s", name)
237226
logger.debug("Type definition: %s", type_definition)
238227

@@ -255,6 +244,7 @@ def parse_types_in_code(self, file_path: Path) -> dict[str, TypeDefinition]:
255244
name=name,
256245
type=type_def_type,
257246
definition=type_definition,
247+
definition_line=definition_node.start_point[0],
258248
)
259249

260250
return res

program-model/src/buttercup/program_model/codequery.py

Lines changed: 105 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
from dataclasses import dataclass, field
99
from pathlib import Path
1010
from itertools import groupby
11-
from typing import ClassVar
11+
from typing import ClassVar, Union
12+
1213

1314
from buttercup.common.challenge_task import ChallengeTask
1415
from buttercup.program_model.api.tree_sitter import CodeTS
15-
from buttercup.program_model.utils.common import Function
16+
from buttercup.program_model.utils.common import (
17+
Function,
18+
TypeDefinition,
19+
)
1620
from buttercup.common.project_yaml import ProjectYaml
1721

1822
logger = logging.getLogger(__name__)
@@ -215,7 +219,10 @@ def _run_cqsearch(self, *args: str) -> list[CQSearchResult]:
215219
return [result for result in results if result is not None]
216220

217221
def get_functions(
218-
self, function_name: str, file_path: Path | None = None
222+
self,
223+
function_name: str,
224+
file_path: Path | None = None,
225+
fuzzy: bool | None = False,
219226
) -> list[Function]:
220227
"""Get the definition(s) of a function in the codebase or in a specific file."""
221228
cqsearch_args = [
@@ -225,31 +232,117 @@ def get_functions(
225232
"2",
226233
"-t",
227234
function_name,
228-
"-e",
235+
"-f" if fuzzy else "-e",
229236
"-u",
230237
]
231238
if file_path:
232239
cqsearch_args += ["-b", file_path.as_posix()]
233240

234241
results = self._run_cqsearch(*cqsearch_args)
242+
logger.debug("cqsearch output: %s", results)
235243

236-
res = []
244+
res: list[Function] = []
237245
results_by_file = groupby(results, key=lambda x: x.file)
238246
for file, results in results_by_file:
239-
if not all(result.value == function_name for result in results):
247+
functions_found = [result.value for result in results]
248+
249+
if not fuzzy and not all(function_name == f for f in functions_found):
240250
logger.warning(
241-
"Function name mismatch, this should not happen: %s != %s",
242-
results[0].value,
251+
"Function name mismatch, this should not happen: %s",
243252
function_name,
244253
)
245254
continue
255+
if fuzzy and not all(function_name in f for f in functions_found):
256+
logger.warning(
257+
"Function name mismatch, this should not happen: %s",
258+
function_name,
259+
)
260+
continue
261+
262+
for function in functions_found:
263+
f = self.ts.get_function(function, file)
264+
if f is None:
265+
logger.warning("Function not found in tree-sitter: %s", function)
266+
continue
267+
res.append(f)
246268

247-
function = self.ts.get_function(function_name, file)
248-
if function is None:
249-
logger.warning("Function not found in tree-sitter: %s", function_name)
269+
return res
270+
271+
def get_types(
272+
self,
273+
type_name: Union[bytes, str],
274+
file_path: Path | None = None,
275+
function_name: str | None = None,
276+
fuzzy: bool | None = False,
277+
) -> list[TypeDefinition]:
278+
"""Find and return the definitions of the type named `type_name`."""
279+
# Build the cqsearch command to find occurrences of the type name in the code
280+
cqsearch_args = [
281+
"-s",
282+
self.CODEQUERY_DB, # Specify the database file path
283+
"-p",
284+
"1", # '1' for symbol
285+
"-t",
286+
type_name, # The name of the type
287+
"-f" if fuzzy else "-e",
288+
"-u", # use full paths
289+
]
290+
if file_path:
291+
cqsearch_args += ["-b", file_path.as_posix()]
292+
293+
results = self._run_cqsearch(*cqsearch_args)
294+
logger.debug("cqsearch output: %s", results)
295+
296+
res: list[TypeDefinition] = []
297+
results_by_file = groupby(results, key=lambda x: x.file)
298+
for file, results in results_by_file:
299+
types_found = [result.value for result in results]
300+
301+
if not fuzzy and not all(type_name == t for t in types_found):
302+
logger.warning(
303+
"Type name mismatch, this should not happen: %s",
304+
type_name,
305+
)
306+
continue
307+
if fuzzy and not all(type_name in t for t in types_found):
308+
logger.warning(
309+
"Type name mismatch, this should not happen: %s",
310+
type_name,
311+
)
250312
continue
251313

252-
res.append(function)
314+
typedefs: dict[str, TypeDefinition] = {}
315+
316+
for typename in types_found:
317+
t = self.ts.parse_types_in_code(file, typename, fuzzy)
318+
if not t:
319+
logger.warning(
320+
"Type definition not found in tree-sitter: %s", typename
321+
)
322+
continue
323+
typedefs.update(t)
324+
325+
if function_name:
326+
# Get the function definition to find its scope
327+
function = self.ts.get_function(function_name, file)
328+
if function:
329+
# Filter type definitions to only include those within the function's scope
330+
filtered_typedefs = {}
331+
for name, typedef in typedefs.items():
332+
# Check if the type definition is within the function's scope
333+
for body in function.bodies:
334+
if (
335+
body.start_line
336+
<= typedef.definition_line
337+
<= body.end_line
338+
):
339+
filtered_typedefs[name] = typedef
340+
break
341+
typedefs = filtered_typedefs
342+
else:
343+
typedefs = {}
344+
345+
res.extend(typedefs.values())
253346

254347
return res
255348

program-model/src/buttercup/program_model/utils/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from dataclasses import dataclass, field
3+
from enum import Enum
34
from pathlib import Path
45

56

@@ -35,3 +36,31 @@ class Function:
3536

3637
bodies: list[FunctionBody] = field(default_factory=list)
3738
"""List of function bodies."""
39+
40+
41+
@dataclass
42+
class TypeDefinitionType(str, Enum):
43+
"""Enum to store type definition type."""
44+
45+
STRUCT = "struct"
46+
UNION = "union"
47+
ENUM = "enum"
48+
TYPEDEF = "typedef"
49+
PREPROC_TYPE = "preproc_type"
50+
51+
52+
@dataclass
53+
class TypeDefinition:
54+
"""Class to store type definition information."""
55+
56+
name: str
57+
"""Name of the type."""
58+
59+
type: TypeDefinitionType
60+
"""Kind of type definition (struct, union, enum, typedef, or preprocessor type)."""
61+
62+
definition: str
63+
"""Definition of the type."""
64+
65+
definition_line: int
66+
"""Zero-based line number at which the type definition starts."""

program-model/tests/test_codequery.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from buttercup.common.challenge_task import ChallengeTask
88
from buttercup.program_model.codequery import CodeQuery, CodeQueryPersistent
99
from buttercup.common.task_meta import TaskMeta
10+
from buttercup.program_model.utils.common import TypeDefinitionType
1011

1112

1213
def setup_dirs(tmp_path: Path) -> Path:
@@ -49,6 +50,13 @@ def setup_dirs(tmp_path: Path) -> Path:
4950
int function4(char *s) {
5051
return strlen(s);
5152
}
53+
""")
54+
(source / "test4.c").write_text("""typedef int myInt;
55+
myInt function5(myInt a, myInt b) {
56+
typedef int myOtherInt;
57+
myOtherInt c = a + b;
58+
return a + b + c;
59+
}
5260
""")
5361

5462
# Create task metadata
@@ -123,6 +131,17 @@ def test_get_functions_multiple(mock_challenge_task: ChallengeTask):
123131
)
124132

125133

134+
def test_get_functions_fuzzy(mock_challenge_task: ChallengeTask):
135+
"""Test that we can get functions (fuzzy search) in codebase"""
136+
codequery = CodeQuery(mock_challenge_task)
137+
functions = codequery.get_functions("function", fuzzy=True)
138+
assert len(functions) == 4
139+
functions = codequery.get_functions("function", Path("test3.c"), fuzzy=True)
140+
assert len(functions) == 2
141+
functions = codequery.get_functions("function3", Path("test3.c"), fuzzy=True)
142+
assert len(functions) == 1
143+
144+
126145
def test_keep_status(
127146
mock_challenge_task: ChallengeTask,
128147
mock_challenge_task_ro: ChallengeTask,
@@ -161,6 +180,46 @@ def test_keep_status(
161180
assert mock_challenge_task_ro.task_dir.exists()
162181

163182

183+
def test_get_types(mock_challenge_task: ChallengeTask):
184+
"""Test that we can get types in codebase"""
185+
codequery = CodeQuery(mock_challenge_task)
186+
types = codequery.get_types("myInt", Path("test3.c"))
187+
assert len(types) == 0
188+
types = codequery.get_types("myInt")
189+
assert len(types) == 1
190+
types = codequery.get_types("myInt", Path("test4.c"))
191+
assert len(types) == 1
192+
assert types[0].name == "myInt"
193+
assert types[0].type == TypeDefinitionType.TYPEDEF
194+
assert types[0].definition == "typedef int myInt;"
195+
assert types[0].definition_line == 0
196+
types = codequery.get_types("myInt", Path("test4.c"), function_name="function5")
197+
assert len(types) == 0
198+
types = codequery.get_types(
199+
"myOtherInt", Path("test4.c"), function_name="function5"
200+
)
201+
assert len(types) == 1
202+
assert types[0].name == "myOtherInt"
203+
assert types[0].type == TypeDefinitionType.TYPEDEF
204+
assert types[0].definition == " typedef int myOtherInt;"
205+
assert types[0].definition_line == 2
206+
207+
208+
def test_get_types_fuzzy(mock_challenge_task: ChallengeTask):
209+
"""Test that we can get types (fuzzy search) in codebase"""
210+
codequery = CodeQuery(mock_challenge_task)
211+
types = codequery.get_types("my", Path("test4.c"), fuzzy=True)
212+
assert len(types) == 2
213+
types = codequery.get_types("myInt", Path("test4.c"), fuzzy=True)
214+
assert len(types) == 1
215+
types = codequery.get_types("myOtherInt", Path("test4.c"), fuzzy=True)
216+
assert len(types) == 1
217+
types = codequery.get_types("my", fuzzy=True)
218+
assert len(types) == 2
219+
types = codequery.get_types("myOtherInt", Path("test4.c"), "function5", fuzzy=True)
220+
assert len(types) == 1
221+
222+
164223
@pytest.fixture
165224
def libjpeg_oss_fuzz_task(tmp_path: Path) -> ChallengeTask:
166225
"""Create a challenge task using a real OSS-Fuzz repository."""

program-model/uv.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)