Skip to content

Commit f9b606f

Browse files
authored
Merge pull request #77 from NASA-IMPACT/feature/combined-code-agent
Add CombinedCodeSearchTool for unified code search
2 parents 73133e2 + 6d60c42 commit f9b606f

File tree

3 files changed

+117
-19
lines changed

3 files changed

+117
-19
lines changed

akd/tools/code_search.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic.networks import HttpUrl
1313
from scipy.spatial.distance import cdist
1414
from tenacity import retry, stop_after_attempt
15+
from sentence_transformers import CrossEncoder
1516

1617
from akd.errors import SchemaValidationError
1718
from akd.structures import SearchResultItem
@@ -31,7 +32,10 @@ class CodeSearchToolInputSchema(SearchToolInputSchema):
3132
Input schema for the code search tool.
3233
"""
3334

34-
pass
35+
@computed_field
36+
def top_k(self) -> int:
37+
"""Returns the number of top results to return."""
38+
return self.max_results
3539

3640

3741
class CodeSearchToolOutputSchema(SearchToolOutputSchema):
@@ -120,11 +124,11 @@ def __get_sort_key(result):
120124

121125
# Then check if 'extra' field exists and contains the sort_by key
122126
if (
123-
"extra" in result
124-
and isinstance(result["extra"], dict)
125-
and sort_by in result["extra"]
127+
result.extra
128+
and isinstance(result.extra, dict)
129+
and sort_by in result.extra
126130
):
127-
return result["extra"][sort_by]
131+
return result.extra[sort_by]
128132

129133
# If key not found anywhere, return a default value that will sort last
130134
# Using float('inf') for numerical sorting or empty string for string sorting
@@ -139,15 +143,85 @@ def __get_sort_key(result):
139143
return results
140144

141145

142-
class LocalRepoCodeSearchToolInputSchema(CodeSearchToolInputSchema):
146+
class CombinedCodeSearchToolConfig(CodeSearchToolConfig):
143147
"""
144-
Input schema for the local repository code search tool.
148+
Configuration for the combined code search tool.
145149
"""
146150

147-
@computed_field
148-
def top_k(self) -> int:
149-
"""Returns the number of top results to return."""
150-
return self.max_results
151+
reranker_model_name: str = Field(
152+
"cross-encoder/ms-marco-MiniLM-L6-v2",
153+
description="The model to use for reranking the combined results.",
154+
)
155+
156+
157+
class CombinedCodeSearchTool(CodeSearchTool):
158+
"""
159+
Tool for performing combined code search using multiple sub-tools and CrossEncoder reranking.
160+
"""
161+
162+
input_schema = CodeSearchToolInputSchema
163+
output_schema = CodeSearchToolOutputSchema
164+
config_schema = CombinedCodeSearchToolConfig
165+
166+
def __init__(
167+
self,
168+
config: CombinedCodeSearchToolConfig | None = None,
169+
tools: Optional[list[CodeSearchTool]] = None,
170+
debug: bool = False,
171+
):
172+
super().__init__(config, debug)
173+
self.tools = tools or [
174+
LocalRepoCodeSearchTool(debug=debug),
175+
GitHubCodeSearchTool(debug=debug),
176+
SDECodeSearchTool(debug=debug),
177+
]
178+
self.reranker_model = CrossEncoder(config.reranker_model_name)
179+
180+
async def _arun(
181+
self,
182+
params: CodeSearchToolInputSchema,
183+
**kwargs,
184+
) -> CodeSearchToolOutputSchema:
185+
"""Run the combined code search tool and aggregate results from all tools."""
186+
all_results: list[SearchResultItem] = []
187+
188+
for tool in self.tools:
189+
try:
190+
if self.debug:
191+
logger.debug(f"Running tool: {tool.__class__.__name__}")
192+
result = await tool._arun(params)
193+
for res in result.results:
194+
res.extra["tool"] = tool.__class__.__name__
195+
all_results.extend(result.results)
196+
except Exception as e:
197+
logger.error(f"Error running tool {tool.__class__.__name__}: {e}")
198+
199+
reranked = self._rerank_results(all_results, params.queries[0])[: params.top_k]
200+
return self.output_schema(results=reranked, category="technology")
201+
202+
def _rerank_results(
203+
self, results: list[SearchResultItem], query: str
204+
) -> list[SearchResultItem]:
205+
"""
206+
Rerank results using a CrossEncoder model based on the query and result content.
207+
"""
208+
209+
try:
210+
# Generate (query, content) pairs
211+
pairs = [(query, result.content) for result in results]
212+
213+
# Get similarity scores from CrossEncoder
214+
scores = self.reranker_model.predict(pairs)
215+
216+
# Attach scores
217+
for score, result in zip(scores, results):
218+
result.extra["score"] = score
219+
220+
# Sort results by score
221+
return self._sort_results(results, sort_by="score")
222+
except Exception as e:
223+
logger.error(f"Reranking failed: {e}")
224+
return results
151225

152226

153227
class LocalRepoCodeSearchToolConfig(CodeSearchToolConfig):
@@ -173,7 +247,7 @@ class LocalRepoCodeSearchTool(CodeSearchTool):
173247
It automatically downloads the necessary data file if it's not found locally.
174248
"""
175249

176-
input_schema = LocalRepoCodeSearchToolInputSchema
250+
input_schema = CodeSearchToolInputSchema
177251
output_schema = CodeSearchToolOutputSchema
178252
config_schema = LocalRepoCodeSearchToolConfig
179253

examples/code_search_test.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
CodeSearchToolInputSchema,
1111
LocalRepoCodeSearchTool,
1212
LocalRepoCodeSearchToolConfig,
13-
LocalRepoCodeSearchToolInputSchema,
1413
GitHubCodeSearchTool,
1514
SDECodeSearchTool,
1615
SDECodeSearchToolConfig,
16+
CombinedCodeSearchTool,
17+
CombinedCodeSearchToolConfig,
1718
)
1819

1920

@@ -25,9 +26,7 @@ async def local_repo_search_test():
2526
cfg = LocalRepoCodeSearchToolConfig()
2627
tool = LocalRepoCodeSearchTool(config=cfg)
2728

28-
search_input = LocalRepoCodeSearchToolInputSchema(
29-
queries=["landslide nepal"], max_results=5
30-
)
29+
search_input = CodeSearchToolInputSchema(queries=["landslide nepal"], max_results=5)
3130

3231
print("Running the search...")
3332
output = await tool._arun(search_input)
@@ -82,10 +81,36 @@ async def sde_search_test():
8281
print("-" * 100)
8382

8483

84+
# Combined Code Search Tool
85+
async def combined_code_search_test():
86+
"""An async function to run the tool."""
87+
88+
print("Initializing the tool...")
89+
cfg = CombinedCodeSearchToolConfig()
90+
tool = CombinedCodeSearchTool(config=cfg)
91+
92+
search_input = CodeSearchToolInputSchema(
93+
queries=["landslide nepal"], max_results=10
94+
)
95+
96+
print("Running the search...")
97+
output = await tool._arun(search_input)
98+
99+
print("\n--- Search Results ---")
100+
for result in output.results:
101+
print(result.url)
102+
print(result.content)
103+
print(result.extra["tool"])
104+
print(result.extra["score"])
105+
print("-" * 100)
106+
107+
85108
if __name__ == "__main__":
86109
print("Running local repo search test...")
87110
asyncio.run(local_repo_search_test())
88111
print("Running GitHub search test...")
89112
asyncio.run(github_search_test())
90113
print("Running SDE search test...")
91114
asyncio.run(sde_search_test())
115+
print("Running combined code search test...")
116+
asyncio.run(combined_code_search_test())

tests/code_search_tool_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111

1212
import asyncio
1313
from akd.tools.misc import Embedder
14-
from akd.tools.search import SearxNGSearchToolConfig, SearxNGSearchToolInputSchema
14+
from akd.tools.search import SearxNGSearchToolConfig
1515
from akd.tools.code_search import (
1616
CodeSearchToolInputSchema,
1717
LocalRepoCodeSearchTool,
1818
LocalRepoCodeSearchToolConfig,
19-
LocalRepoCodeSearchToolInputSchema,
2019
GitHubCodeSearchTool,
2120
SDECodeSearchTool,
2221
SDECodeSearchToolConfig,
@@ -99,7 +98,7 @@ def test_vector_embedding(embedder=Embedder(model_name="all-MiniLM-L6-v2")):
9998

10099
@pytest.mark.asyncio
101100
async def test_local_repo_search(local_tool):
102-
input_params = LocalRepoCodeSearchToolInputSchema(
101+
input_params = CodeSearchToolInputSchema(
103102
queries=["landslide nepal"],
104103
max_results=3,
105104
)

0 commit comments

Comments
 (0)