Skip to content

Commit 21b9eb1

Browse files
committed
Mock reranking
1 parent 148dfeb commit 21b9eb1

File tree

2 files changed

+13
-15
lines changed

2 files changed

+13
-15
lines changed

llm-service/app/services/models/reranking.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,21 @@ def get_noop() -> BaseNodePostprocessor:
6767
def list_available() -> list[ModelResponse]:
6868
return get_provider_class().list_reranking_models()
6969

70+
_TEST_NODES = [
71+
NodeWithScore(node=TextNode(text="test node"), score=0.5),
72+
NodeWithScore(node=TextNode(text="another test node"), score=0.4),
73+
]
74+
7075
@classmethod
7176
def test(cls, model_name: str) -> str:
7277
models = cls.list_available()
7378
for model in models:
7479
if model.model_id == model_name:
75-
node = NodeWithScore(node=TextNode(text="test"), score=0.5)
76-
another_test_node = NodeWithScore(
77-
node=TextNode(text="another test node"), score=0.4
78-
)
7980
reranking_model: BaseNodePostprocessor | None = cls.get(
8081
model_name=model_name
8182
)
8283
if reranking_model:
83-
reranking_model.postprocess_nodes(
84-
[node, another_test_node], None, "test"
85-
)
84+
reranking_model.postprocess_nodes(cls._TEST_NODES, None, "test")
8685
return "ok"
8786
raise HTTPException(status_code=404, detail="Model not found")
8887

llm-service/app/tests/model_provider_mocks/bedrock.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from app.config import settings
5353
from app.services.caii.types import ModelResponse
5454
from app.services.models.providers import BedrockModelProvider
55+
from app.services.models import Reranking
5556
from .testing_chat_history_manager import (
5657
patch_get_chat_history_manager,
5758
)
@@ -182,10 +183,8 @@ def mock_make_api_call(
182183
elif operation_name == "Rerank":
183184
return {
184185
"results": [
185-
# TODO: Is the document store checked prior to this? Do I need to mock that too?
186-
{"index": 0, "relevanceScore": random.random()},
187-
{"index": 1, "relevanceScore": random.random()},
188-
{"index": 2, "relevanceScore": random.random()},
186+
{"index": i, "relevanceScore": random.random()}
187+
for i in range(len(Reranking._TEST_NODES))
189188
]
190189
}
191190
else:
@@ -234,7 +233,7 @@ def test_bedrock_models(client: TestClient) -> None:
234233
] == available_embedding_models
235234
for model_id in available_embedding_models:
236235
response = client.get(f"/llm-service/models/embedding/{model_id}/test")
237-
assert response.status_code == 200 # TODO
236+
assert response.status_code == 200
238237

239238
available_text_models = [
240239
model_id
@@ -258,9 +257,9 @@ def test_bedrock_models(client: TestClient) -> None:
258257
assert [
259258
model["model_id"] for model in response.json()
260259
] == available_reranking_models
261-
# for model_id in available_reranking_models:
262-
# response = client.get(f"/llm-service/models/reranking/{model_id}/test")
263-
# assert response.status_code == 200 # TODO
260+
for model_id in available_reranking_models:
261+
response = client.get(f"/llm-service/models/reranking/{model_id}/test")
262+
assert response.status_code == 200
264263

265264

266265
def test_bedrock_sessions(client: TestClient) -> None:

0 commit comments

Comments
 (0)