2929
3030# Source plugin imports moved to their respective create methods to speed up startup
3131if TYPE_CHECKING :
32- from holmes .core .llm import LLM
3332 from holmes .core .tool_calling_llm import IssueInvestigator , ToolCallingLLM
3433 from holmes .plugins .destinations .slack import SlackDestination
3534 from holmes .plugins .sources .github import GitHubSource
@@ -135,6 +134,7 @@ class Config(RobustaBaseConfig):
135134 _server_tool_executor : Optional [ToolExecutor ] = None
136135
137136 _toolset_manager : Optional [ToolsetManager ] = None
137+ _default_robusta_model : Optional [str ] = None
138138
139139 @property
140140 def toolset_manager (self ) -> ToolsetManager :
@@ -170,20 +170,28 @@ def configure_robusta_ai_model(self) -> None:
170170 self ._load_default_robusta_config ()
171171 return
172172
173- models = fetch_robusta_models (
173+ robusta_models = fetch_robusta_models (
174174 self .account_id , self .session_token .get_secret_value ()
175175 )
176- if not models :
176+ if not robusta_models or not robusta_models . models :
177177 self ._load_default_robusta_config ()
178178 return
179179
180- for model in models :
180+ for model in robusta_models . models :
181181 logging .info (f"Loading Robusta AI model: { model } " )
182182 self ._model_list [model ] = {
183+ "name" : model ,
183184 "base_url" : f"{ ROBUSTA_API_ENDPOINT } /llm/{ model } " ,
184185 "is_robusta_model" : True ,
186+ "model" : "gpt-4o" , # Robusta AI model is using openai like API.
185187 }
186188
189+ if robusta_models .default_model :
190+ logging .info (
191+ f"Setting default Robusta AI model to: { robusta_models .default_model } "
192+ )
193+ self ._default_robusta_model = robusta_models .default_model
194+
187195 except Exception :
188196 logging .exception ("Failed to get all robusta models" )
189197 # fallback to default behavior
@@ -193,9 +201,12 @@ def _load_default_robusta_config(self):
193201 if self ._should_load_robusta_ai () and self .api_key :
194202 logging .info ("Loading default Robusta AI model" )
195203 self ._model_list [ROBUSTA_AI_MODEL_NAME ] = {
204+ "name" : ROBUSTA_AI_MODEL_NAME ,
196205 "base_url" : ROBUSTA_API_ENDPOINT ,
197206 "is_robusta_model" : True ,
207+ "model" : "gpt-4o" ,
198208 }
209+ self ._default_robusta_model = ROBUSTA_AI_MODEL_NAME
199210
200211 def _should_load_robusta_ai (self ) -> bool :
201212 if not self .should_try_robusta_ai :
@@ -525,34 +536,59 @@ def create_slack_destination(self) -> "SlackDestination":
525536 raise ValueError ("--slack-channel must be specified" )
526537 return SlackDestination (self .slack_token .get_secret_value (), self .slack_channel )
527538
528- def _get_llm (self , model_key : Optional [str ] = None , tracer = None ) -> "LLM" :
529- api_key = self .api_key
530- model = self .model
539+ def _get_model_params (self , model_key : Optional [str ] = None ) -> dict :
540+ if not self ._model_list :
541+ logging .info ("No model list setup, using config model" )
542+ return {}
543+
544+ if model_key :
545+ model_params = self ._model_list .get (model_key )
546+ if model_params is not None :
547+ logging .info (f"Using model: { model_key } " )
548+ return model_params .copy ()
549+
550+ logging .error (f"Couldn't find model: { model_key } in model list" )
551+
552+ if self ._default_robusta_model :
553+ model_params = self ._model_list .get (self ._default_robusta_model )
554+ if model_params is not None :
555+ logging .info (
556+ f"Using default Robusta AI model: { self ._default_robusta_model } "
557+ )
558+ return model_params .copy ()
559+
560+ logging .error (
561+ f"Couldn't find default Robusta AI model: { self ._default_robusta_model } in model list"
562+ )
563+
564+ first_model_params = next (iter (self ._model_list .values ())).copy ()
565+ logging .info ("Using first model" )
566+ return first_model_params
567+
568+ def _get_llm (self , model_key : Optional [str ] = None , tracer = None ) -> "DefaultLLM" :
569+ model_params = self ._get_model_params (model_key )
531570 api_base = self .api_base
532571 api_version = self .api_version
533- model_params = {}
534- if self ._model_list :
535- # get requested model or the first credentials if no model requested.
536- model_params = (
537- self ._model_list .get (model_key , {}).copy ()
538- if model_key
539- else next (iter (self ._model_list .values ())).copy ()
540- )
541- is_robusta_model = model_params .pop ("is_robusta_model" , False )
542- if is_robusta_model and self .api_key :
543- # we set here the api_key since it is being refresh when exprided and not as part of the model loading.
544- api_key = self .api_key .get_secret_value () # type: ignore
545- else :
546- api_key = model_params .pop ("api_key" , api_key )
547- model = model_params .pop ("model" , model )
548- # It's ok if the model does not have api base and api version, which are defaults to None.
549- # Handle both api_base and base_url - api_base takes precedence
550- model_api_base = model_params .pop ("api_base" , None )
551- model_base_url = model_params .pop ("base_url" , None )
552- api_base = model_api_base or model_base_url or api_base
553- api_version = model_params .pop ("api_version" , api_version )
554-
555- return DefaultLLM (model , api_key , api_base , api_version , model_params , tracer ) # type: ignore
572+
573+ is_robusta_model = model_params .pop ("is_robusta_model" , False )
574+ if is_robusta_model and self .api_key :
575+ # we set here the api_key since it is being refresh when exprided and not as part of the model loading.
576+ api_key = self .api_key .get_secret_value () # type: ignore
577+ else :
578+ api_key = model_params .pop ("api_key" , None )
579+
580+ model = model_params .pop ("model" , self .model )
581+ # It's ok if the model does not have api base and api version, which are defaults to None.
582+ # Handle both api_base and base_url - api_base takes precedence
583+ model_api_base = model_params .pop ("api_base" , None )
584+ model_base_url = model_params .pop ("base_url" , None )
585+ api_base = model_api_base or model_base_url or api_base
586+ api_version = model_params .pop ("api_version" , api_version )
587+ model_name = model_params .pop ("name" , None ) or model_key or model
588+
589+ return DefaultLLM (
590+ model , api_key , api_base , api_version , model_params , tracer , model_name
591+ ) # type: ignore
556592
557593 def get_models_list (self ) -> List [str ]:
558594 if self ._model_list :
0 commit comments