Skip to content

Commit 75fa8ad

Browse files
authored
Load default robusta model from API (#946)
1 parent 24adc77 commit 75fa8ad

File tree

12 files changed

+270
-43
lines changed

12 files changed

+270
-43
lines changed

conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from tests.llm.conftest import show_llm_summary_report
44
from holmes.core.tracing import readable_timestamp, get_active_branch_name
55
from tests.llm.utils.braintrust import get_braintrust_url
6+
from unittest.mock import MagicMock, patch
7+
import pytest
68

79

810
def pytest_addoption(parser):
@@ -126,3 +128,32 @@ def pytest_report_header(config):
126128
# due to pytest quirks, we need to define this in the main conftest.py - when defined in the llm conftest.py it
127129
# is SOMETIMES picked up and sometimes not, depending on how the test was invoked
128130
pytest_terminal_summary = show_llm_summary_report
131+
132+
133+
@pytest.fixture(autouse=True)
134+
def patch_supabase(monkeypatch):
135+
monkeypatch.setattr("holmes.core.supabase_dal.ROBUSTA_ACCOUNT_ID", "test-cluster")
136+
monkeypatch.setattr(
137+
"holmes.core.supabase_dal.STORE_URL", "https://fakesupabaseref.supabase.co"
138+
)
139+
monkeypatch.setattr(
140+
"holmes.core.supabase_dal.STORE_API_KEY",
141+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiYW5vbiIsImlhdCI6MTYzNTAwODQ4NywiZXhwIjoxOTUwNTg0NDg3fQ.l8IgkO7TQokGSc9OJoobXIVXsOXkilXl4Ak6SCX5qI8",
142+
)
143+
monkeypatch.setattr("holmes.core.supabase_dal.STORE_EMAIL", "mock_store_user")
144+
monkeypatch.setattr(
145+
"holmes.core.supabase_dal.STORE_PASSWORD", "mock_store_password"
146+
)
147+
148+
149+
@pytest.fixture(autouse=True, scope="session")
150+
def storage_dal_mock():
151+
with patch("holmes.config.SupabaseDal") as MockSupabaseDal:
152+
mock_supabase_dal_instance = MagicMock()
153+
MockSupabaseDal.return_value = mock_supabase_dal_instance
154+
mock_supabase_dal_instance.sign_in.return_value = "mock_supabase_user_id"
155+
mock_supabase_dal_instance.get_ai_credentials.return_value = (
156+
"mock_account_id",
157+
"mock_session_token",
158+
)
159+
yield mock_supabase_dal_instance

holmes/clients/robusta_client.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,16 @@ class HolmesInfo(BaseModel):
1414
latest_version: Optional[str] = None
1515

1616

17+
class RobustaModelsResponse(BaseModel):
18+
model_config = ConfigDict(extra="ignore")
19+
models: List[str]
20+
default_model: Optional[str] = None
21+
22+
1723
@cache
18-
def fetch_robusta_models(account_id, token) -> Optional[List[str]]:
24+
def fetch_robusta_models(
25+
account_id: str, token: str
26+
) -> Optional[RobustaModelsResponse]:
1927
try:
2028
session_request = {"session_token": token, "account_id": account_id}
2129
resp = requests.post(
@@ -25,7 +33,7 @@ def fetch_robusta_models(account_id, token) -> Optional[List[str]]:
2533
)
2634
resp.raise_for_status()
2735
response_json = resp.json()
28-
return response_json.get("models")
36+
return RobustaModelsResponse(**response_json)
2937
except Exception:
3038
logging.exception("Failed to fetch robusta models")
3139
return None

holmes/config.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
# Source plugin imports moved to their respective create methods to speed up startup
3131
if TYPE_CHECKING:
32-
from holmes.core.llm import LLM
3332
from holmes.core.tool_calling_llm import IssueInvestigator, ToolCallingLLM
3433
from holmes.plugins.destinations.slack import SlackDestination
3534
from holmes.plugins.sources.github import GitHubSource
@@ -135,6 +134,7 @@ class Config(RobustaBaseConfig):
135134
_server_tool_executor: Optional[ToolExecutor] = None
136135

137136
_toolset_manager: Optional[ToolsetManager] = None
137+
_default_robusta_model: Optional[str] = None
138138

139139
@property
140140
def toolset_manager(self) -> ToolsetManager:
@@ -170,20 +170,28 @@ def configure_robusta_ai_model(self) -> None:
170170
self._load_default_robusta_config()
171171
return
172172

173-
models = fetch_robusta_models(
173+
robusta_models = fetch_robusta_models(
174174
self.account_id, self.session_token.get_secret_value()
175175
)
176-
if not models:
176+
if not robusta_models or not robusta_models.models:
177177
self._load_default_robusta_config()
178178
return
179179

180-
for model in models:
180+
for model in robusta_models.models:
181181
logging.info(f"Loading Robusta AI model: {model}")
182182
self._model_list[model] = {
183+
"name": model,
183184
"base_url": f"{ROBUSTA_API_ENDPOINT}/llm/{model}",
184185
"is_robusta_model": True,
186+
"model": "gpt-4o", # Robusta AI model is using openai like API.
185187
}
186188

189+
if robusta_models.default_model:
190+
logging.info(
191+
f"Setting default Robusta AI model to: {robusta_models.default_model}"
192+
)
193+
self._default_robusta_model = robusta_models.default_model
194+
187195
except Exception:
188196
logging.exception("Failed to get all robusta models")
189197
# fallback to default behavior
@@ -193,9 +201,12 @@ def _load_default_robusta_config(self):
193201
if self._should_load_robusta_ai() and self.api_key:
194202
logging.info("Loading default Robusta AI model")
195203
self._model_list[ROBUSTA_AI_MODEL_NAME] = {
204+
"name": ROBUSTA_AI_MODEL_NAME,
196205
"base_url": ROBUSTA_API_ENDPOINT,
197206
"is_robusta_model": True,
207+
"model": "gpt-4o",
198208
}
209+
self._default_robusta_model = ROBUSTA_AI_MODEL_NAME
199210

200211
def _should_load_robusta_ai(self) -> bool:
201212
if not self.should_try_robusta_ai:
@@ -525,34 +536,59 @@ def create_slack_destination(self) -> "SlackDestination":
525536
raise ValueError("--slack-channel must be specified")
526537
return SlackDestination(self.slack_token.get_secret_value(), self.slack_channel)
527538

528-
def _get_llm(self, model_key: Optional[str] = None, tracer=None) -> "LLM":
529-
api_key = self.api_key
530-
model = self.model
539+
def _get_model_params(self, model_key: Optional[str] = None) -> dict:
540+
if not self._model_list:
541+
logging.info("No model list setup, using config model")
542+
return {}
543+
544+
if model_key:
545+
model_params = self._model_list.get(model_key)
546+
if model_params is not None:
547+
logging.info(f"Using model: {model_key}")
548+
return model_params.copy()
549+
550+
logging.error(f"Couldn't find model: {model_key} in model list")
551+
552+
if self._default_robusta_model:
553+
model_params = self._model_list.get(self._default_robusta_model)
554+
if model_params is not None:
555+
logging.info(
556+
f"Using default Robusta AI model: {self._default_robusta_model}"
557+
)
558+
return model_params.copy()
559+
560+
logging.error(
561+
f"Couldn't find default Robusta AI model: {self._default_robusta_model} in model list"
562+
)
563+
564+
first_model_params = next(iter(self._model_list.values())).copy()
565+
logging.info("Using first model")
566+
return first_model_params
567+
568+
def _get_llm(self, model_key: Optional[str] = None, tracer=None) -> "DefaultLLM":
569+
model_params = self._get_model_params(model_key)
531570
api_base = self.api_base
532571
api_version = self.api_version
533-
model_params = {}
534-
if self._model_list:
535-
# get requested model or the first credentials if no model requested.
536-
model_params = (
537-
self._model_list.get(model_key, {}).copy()
538-
if model_key
539-
else next(iter(self._model_list.values())).copy()
540-
)
541-
is_robusta_model = model_params.pop("is_robusta_model", False)
542-
if is_robusta_model and self.api_key:
543-
# we set the api_key here since it is refreshed when expired, and not as part of the model loading.
544-
api_key = self.api_key.get_secret_value() # type: ignore
545-
else:
546-
api_key = model_params.pop("api_key", api_key)
547-
model = model_params.pop("model", model)
548-
# It's ok if the model does not have api base and api version, which are defaults to None.
549-
# Handle both api_base and base_url - api_base takes precedence
550-
model_api_base = model_params.pop("api_base", None)
551-
model_base_url = model_params.pop("base_url", None)
552-
api_base = model_api_base or model_base_url or api_base
553-
api_version = model_params.pop("api_version", api_version)
554-
555-
return DefaultLLM(model, api_key, api_base, api_version, model_params, tracer) # type: ignore
572+
573+
is_robusta_model = model_params.pop("is_robusta_model", False)
574+
if is_robusta_model and self.api_key:
575+
# we set the api_key here since it is refreshed when expired, and not as part of the model loading.
576+
api_key = self.api_key.get_secret_value() # type: ignore
577+
else:
578+
api_key = model_params.pop("api_key", None)
579+
580+
model = model_params.pop("model", self.model)
581+
# It's ok if the model does not have api base and api version, which are defaults to None.
582+
# Handle both api_base and base_url - api_base takes precedence
583+
model_api_base = model_params.pop("api_base", None)
584+
model_base_url = model_params.pop("base_url", None)
585+
api_base = model_api_base or model_base_url or api_base
586+
api_version = model_params.pop("api_version", api_version)
587+
model_name = model_params.pop("name", None) or model_key or model
588+
589+
return DefaultLLM(
590+
model, api_key, api_base, api_version, model_params, tracer, model_name
591+
) # type: ignore
556592

557593
def get_models_list(self) -> List[str]:
558594
if self._model_list:

holmes/core/llm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,16 @@ def __init__(
7272
api_base: Optional[str] = None,
7373
api_version: Optional[str] = None,
7474
args: Optional[Dict] = None,
75-
tracer=None,
75+
tracer: Optional[Any] = None,
76+
name: Optional[str] = None,
7677
):
7778
self.model = model
7879
self.api_key = api_key
7980
self.api_base = api_base
8081
self.api_version = api_version
8182
self.args = args or {}
8283
self.tracer = tracer
84+
self.name = name
8385

8486
self.check_llm(self.model, self.api_key, self.api_base, self.api_version)
8587

poetry.lock

Lines changed: 18 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ pytest-cov = "^6.2.1"
7373
types-python-dateutil = "^2.9.0.20250708"
7474
pytest-dotenv = "^0.5.2"
7575
pytest-shared-session-scope = "^0.4.0"
76+
pytest-responses = "^0.5.1"
7677

7778
[build-system]
7879
requires = ["poetry-core"]

tests/config_class/test_config_api_base_version.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def test_config_get_llm_with_api_base_version():
3535
"""Test that Config._get_llm passes api_base and api_version to DefaultLLM."""
3636
config = Config(
3737
model="test-model",
38-
api_key="test-key",
3938
api_base="https://test.api.base",
4039
api_version="2023-12-01",
4140
)
@@ -49,7 +48,7 @@ def test_config_get_llm_with_api_base_version():
4948
# Check that DefaultLLM was called with the right positional arguments
5049
call_args = mock_default_llm.call_args[0]
5150
assert call_args[0] == "test-model"
52-
assert call_args[1].get_secret_value() == "test-key" # api_key is SecretStr
51+
assert call_args[1] is None
5352
assert call_args[2] == "https://test.api.base"
5453
assert call_args[3] == "2023-12-01"
5554
assert call_args[4] == {}
@@ -85,7 +84,8 @@ def test_config_get_llm_with_model_list_api_base_version(monkeypatch, tmp_path):
8584
"https://model.api.base",
8685
"2024-02-01",
8786
{},
88-
None, # tracer
87+
None, # tracer,
88+
"test-model",
8989
)
9090
assert result == mock_llm_instance
9191

@@ -122,6 +122,7 @@ def test_config_get_llm_model_list_overrides_config_values(monkeypatch, tmp_path
122122
"2024-03-01", # from model list
123123
{},
124124
None, # tracer
125+
"test-model",
125126
)
126127

127128

@@ -156,6 +157,7 @@ def test_config_get_llm_model_list_defaults_to_config_values(monkeypatch, tmp_pa
156157
"2023-01-01", # from config
157158
{},
158159
None, # tracer
160+
"test-model",
159161
)
160162

161163

@@ -202,6 +204,7 @@ def test_config_get_llm_with_non_none_model_list_first_model_fallback(
202204
"2024-01-01", # from first model
203205
{},
204206
None, # tracer
207+
"gpt-4",
205208
)
206209

207210

@@ -255,7 +258,8 @@ def test_config_get_llm_with_specific_model_from_model_list(monkeypatch, tmp_pat
255258
"https://openai.api.base", # from openai-gpt35 model
256259
"2024-04-01", # from openai-gpt35 model
257260
{},
258-
None, # tracer
261+
None, # tracer,
262+
"openai-gpt35",
259263
)
260264

261265

@@ -289,6 +293,7 @@ def test_config_get_llm_with_base_url_only(monkeypatch, tmp_path):
289293
"2024-01-01",
290294
{},
291295
None, # tracer
296+
"test-model",
292297
)
293298

294299

@@ -323,6 +328,7 @@ def test_config_get_llm_api_base_overrides_base_url(monkeypatch, tmp_path):
323328
"2024-01-01",
324329
{},
325330
None, # tracer
331+
"test-model",
326332
)
327333

328334

@@ -358,6 +364,7 @@ def test_config_get_llm_neither_api_base_nor_base_url_uses_config(
358364
"2024-01-01",
359365
{},
360366
None, # tracer
367+
"test-model",
361368
)
362369

363370

0 commit comments

Comments
 (0)