1212from pydantic .networks import HttpUrl
1313from scipy .spatial .distance import cdist
1414from tenacity import retry , stop_after_attempt
15+ from sentence_transformers import CrossEncoder
1516
1617from akd .errors import SchemaValidationError
1718from akd .structures import SearchResultItem
@@ -31,7 +32,10 @@ class CodeSearchToolInputSchema(SearchToolInputSchema):
3132 Input schema for the code search tool.
3233 """
3334
34- pass
35+ @computed_field
36+ def top_k (self ) -> int :
37+ """Returns the number of top results to return."""
38+ return self .max_results
3539
3640
3741class CodeSearchToolOutputSchema (SearchToolOutputSchema ):
@@ -120,11 +124,11 @@ def __get_sort_key(result):
120124
121125 # Then check if 'extra' field exists and contains the sort_by key
122126 if (
123- "extra" in result
124- and isinstance (result [ " extra" ] , dict )
125- and sort_by in result [ " extra" ]
127+ result . extra
128+ and isinstance (result . extra , dict )
129+ and sort_by in result . extra
126130 ):
127- return result [ " extra" ] [sort_by ]
131+ return result . extra [sort_by ]
128132
129133 # If key not found anywhere, return a default value that will sort last
130134 # Using float('inf') for numerical sorting or empty string for string sorting
@@ -139,15 +143,85 @@ def __get_sort_key(result):
139143 return results
140144
141145
142- class LocalRepoCodeSearchToolInputSchema ( CodeSearchToolInputSchema ):
146+ class CombinedCodeSearchToolConfig ( CodeSearchToolConfig ):
143147 """
144- Input schema for the local repository code search tool.
148+ Configuration for the combined code search tool.
145149 """
146150
147- @computed_field
148- def top_k (self ) -> int :
149- """Returns the number of top results to return."""
150- return self .max_results
151+ reranker_model_name : str = Field (
152+ "cross-encoder/ms-marco-MiniLM-L6-v2" ,
153+ description = "The model to use for reranking the combined results." ,
154+ )
155+
156+
157+ class CombinedCodeSearchTool (CodeSearchTool ):
158+ """
159+ Tool for performing combined code search using multiple sub-tools and CrossEncoder reranking.
160+ """
161+
162+ input_schema = CodeSearchToolInputSchema
163+ output_schema = CodeSearchToolOutputSchema
164+ config_schema = CombinedCodeSearchToolConfig
165+
166+ def __init__ (
167+ self ,
168+ config : CombinedCodeSearchToolConfig | None = None ,
169+ tools : Optional [list [CodeSearchTool ]] = None ,
170+ debug : bool = False ,
171+ ):
172+ super ().__init__ (config , debug )
173+ self .tools = tools or [
174+ LocalRepoCodeSearchTool (debug = debug ),
175+ GitHubCodeSearchTool (debug = debug ),
176+ SDECodeSearchTool (debug = debug ),
177+ ]
178+ self .reranker_model = CrossEncoder (config .reranker_model_name )
179+
180+ async def _arun (
181+ self ,
182+ params : CodeSearchToolInputSchema ,
183+ ** kwargs ,
184+ ) -> CodeSearchToolOutputSchema :
185+ """Run the combined code search tool and aggregate results from all tools."""
186+ all_results : list [SearchResultItem ] = []
187+
188+ for tool in self .tools :
189+ try :
190+ if self .debug :
191+ logger .debug (f"Running tool: { tool .__class__ .__name__ } " )
192+ result = await tool ._arun (params )
193+ for res in result .results :
194+ res .extra ["tool" ] = tool .__class__ .__name__
195+ all_results .extend (result .results )
196+ except Exception as e :
197+ logger .error (f"Error running tool { tool .__class__ .__name__ } : { e } " )
198+
199+ reranked = self ._rerank_results (all_results , params .queries [0 ])[: params .top_k ]
200+ return self .output_schema (results = reranked , category = "technology" )
201+
202+ def _rerank_results (
203+ self , results : list [SearchResultItem ], query : str
204+ ) -> list [SearchResultItem ]:
205+ """
206+ Rerank results using a CrossEncoder model based on the query and result content.
207+ """
208+
209+ try :
210+ # Generate (query, content) pairs
211+ pairs = [(query , result .content ) for result in results ]
212+
213+ # Get similarity scores from CrossEncoder
214+ scores = self .reranker_model .predict (pairs )
215+
216+ # Attach scores
217+ for score , result in zip (scores , results ):
218+ result .extra ["score" ] = score
219+
220+ # Sort results by score
221+ return self ._sort_results (results , sort_by = "score" )
222+ except Exception as e :
223+ logger .error (f"Reranking failed: { e } " )
224+ return results
151225
152226
153227class LocalRepoCodeSearchToolConfig (CodeSearchToolConfig ):
@@ -173,7 +247,7 @@ class LocalRepoCodeSearchTool(CodeSearchTool):
173247 It automatically downloads the necessary data file if it's not found locally.
174248 """
175249
176- input_schema = LocalRepoCodeSearchToolInputSchema
250+ input_schema = CodeSearchToolInputSchema
177251 output_schema = CodeSearchToolOutputSchema
178252 config_schema = LocalRepoCodeSearchToolConfig
179253
0 commit comments