Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 98 additions & 4 deletions akd/tools/code_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pydantic.networks import HttpUrl
from scipy.spatial.distance import cdist
from tenacity import retry, stop_after_attempt
from sentence_transformers import CrossEncoder

from akd.errors import SchemaValidationError
from akd.structures import SearchResultItem
Expand Down Expand Up @@ -120,11 +121,11 @@ def __get_sort_key(result):

# Then check if 'extra' field exists and contains the sort_by key
if (
"extra" in result
and isinstance(result["extra"], dict)
and sort_by in result["extra"]
result.extra
and isinstance(result.extra, dict)
and sort_by in result.extra
):
return result["extra"][sort_by]
return result.extra[sort_by]

# If key not found anywhere, return a default value that will sort last
# Using float('inf') for numerical sorting or empty string for string sorting
Expand All @@ -139,6 +140,99 @@ def __get_sort_key(result):
return results


class CombinedCodeSearchToolInputSchema(CodeSearchToolInputSchema):
"""
Input schema for the combined code search tool.
"""

@computed_field
def top_k(self) -> int:
"""Returns the number of top results to return."""
return self.max_results


class CombinedCodeSearchToolConfig(CodeSearchToolConfig):
"""
Configuration for the combined code search tool.
"""

reranker_model_name: str = Field(
"cross-encoder/ms-marco-MiniLM-L6-v2",
description="The model to use for reranking the combined results.",
)


class CombinedCodeSearchTool(CodeSearchTool):
"""
Tool for performing combined code search using multiple sub-tools and CrossEncoder reranking.
"""

input_schema = CombinedCodeSearchToolInputSchema
output_schema = CodeSearchToolOutputSchema
config_schema = CombinedCodeSearchToolConfig

def __init__(
self,
config: CombinedCodeSearchToolConfig | None = None,
tools: Optional[list[CodeSearchTool]] = None,
debug: bool = False,
):
config = config or self.config_schema()
self.tools = tools or [
LocalRepoCodeSearchTool(debug=debug),
GitHubCodeSearchTool(debug=debug),
SDECodeSearchTool(debug=debug),
]
self.reranker_model = CrossEncoder(config.reranker_model_name)
super().__init__(config, debug)

async def _arun(
self,
params: CodeSearchToolInputSchema,
**kwargs,
) -> CodeSearchToolOutputSchema:
"""Run the combined code search tool and aggregate results from all tools."""
all_results: list[SearchResultItem] = []

for tool in self.tools:
try:
if self.debug:
logger.debug(f"Running tool: {tool.__class__.__name__}")
result = await tool._arun(params)
for res in result.results:
res.extra["tool"] = tool.__class__.__name__
all_results.extend(result.results)
except Exception as e:
logger.error(f"Error running tool {tool.__class__.__name__}: {e}")

reranked = self._rerank_results(all_results, params.queries[0])[: params.top_k]
return self.output_schema(results=reranked, category="technology")

def _rerank_results(
self, results: list[SearchResultItem], query: str
) -> list[SearchResultItem]:
"""
Rerank results using a CrossEncoder model based on the query and result content.
"""

try:
# Generate (query, content) pairs
pairs = [(query, result.content) for result in results]

# Get similarity scores from CrossEncoder
scores = self.reranker_model.predict(pairs)

# Attach scores
for score, result in zip(scores, results):
result.extra["score"] = score

# Sort results by score
return self._sort_results(results, sort_by="score")
except Exception as e:
logger.error(f"Reranking failed: {e}")
return results


class LocalRepoCodeSearchToolInputSchema(CodeSearchToolInputSchema):
"""
Input schema for the local repository code search tool.
Expand Down
31 changes: 31 additions & 0 deletions examples/code_search_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
GitHubCodeSearchTool,
SDECodeSearchTool,
SDECodeSearchToolConfig,
CombinedCodeSearchTool,
CombinedCodeSearchToolConfig,
CombinedCodeSearchToolInputSchema,
)


Expand Down Expand Up @@ -82,10 +85,38 @@ async def sde_search_test():
print("-" * 100)


# Combined Code Search Tool
async def combined_code_search_test():
"""An async function to run the tool."""

print("Initializing the tool...")
cfg = CombinedCodeSearchToolConfig()
tool = CombinedCodeSearchTool(config=cfg)

search_input = CombinedCodeSearchToolInputSchema(
queries=["landslide nepal"], max_results=10
)

print("Running the search...")
output = await tool._arun(search_input)

print("\n--- Search Results ---")
for result in output.results:
print(result.url)
print(result.content)
print(result.extra["tool"])
print(result.extra["score"])
print("-" * 100)


if __name__ == "__main__":
"""
print("Running local repo search test...")
asyncio.run(local_repo_search_test())
print("Running GitHub search test...")
asyncio.run(github_search_test())
print("Running SDE search test...")
asyncio.run(sde_search_test())
"""
print("Running combined code search test...")
asyncio.run(combined_code_search_test())