Skip to content

Commit 9bb7d9d

Browse files
authored
Pydantic V2 upgrade PR2 (#1196)
1 parent 8d4d5fc commit 9bb7d9d

File tree

10 files changed

+532
-227
lines changed

10 files changed

+532
-227
lines changed

src/marqo/base_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pydantic
2+
from pydantic import ConfigDict
13
from pydantic.v1 import BaseModel
24

35

@@ -19,4 +21,8 @@ class Config(MarqoBaseModel.Config):
1921

2022
class ImmutableStrictBaseModel(StrictBaseModel, ImmutableBaseModel):
2123
class Config(StrictBaseModel.Config, ImmutableBaseModel.Config):
22-
pass
24+
pass
25+
26+
27+
class MarqoBaseModelV2(pydantic.BaseModel):
28+
model_config = ConfigDict(validate_by_name=True, validate_assignment=True)

src/marqo/core/models/marqo_index.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ class MarqoIndex(ImmutableBaseModel, ABC):
279279
marqo_version: str
280280
created_at: int = pydantic.Field(gt=0)
281281
updated_at: int = pydantic.Field(gt=0)
282+
# TODO: After upgrading to pydantic v2, _cache can be removed; we can use @cached_property instead.
282283
_cache: Dict[str, Any] = PrivateAttr()
283284
version: Optional[int] = pydantic.Field(default=None)
284285

@@ -624,7 +625,9 @@ def index_supports_partial_updates(self) -> bool:
624625
"""
625626
Check if the index supports partial updates.
626627
"""
627-
return self.parsed_marqo_version() >= self._PARTIAL_UPDATE_SUPPORTED_VERSION
628+
return self._cache_or_get(
629+
'index_supports_partial_updates',
630+
lambda: self.parsed_marqo_version() >= self._PARTIAL_UPDATE_SUPPORTED_VERSION)
628631

629632

630633
_PROTECTED_FIELD_NAMES = ['_id', '_tensor_facets', '_highlights', '_score', '_found']

src/marqo/core/semi_structured_vespa_index/semi_structured_document.py

Lines changed: 57 additions & 111 deletions
Large diffs are not rendered by default.

src/marqo/core/semi_structured_vespa_index/semi_structured_vespa_index.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def to_marqo_document(self, vespa_document: Dict[str, Any], return_highlights: b
4646
vespa_doc = SemiStructuredVespaDocument.from_vespa_document(vespa_document, marqo_index=self.get_marqo_index())
4747
marqo_doc = vespa_doc.to_marqo_document(marqo_index=self.get_marqo_index())
4848

49-
if return_highlights and vespa_doc.match_features:
49+
if return_highlights and vespa_doc.fixed_fields.match_features:
5050
# Since tensor fields are stored in each individual field, we need to use same logic in structured
5151
# index to extract highlights
5252
marqo_doc[MARQO_DOC_HIGHLIGHTS] = StructuredVespaIndex._extract_highlights(

src/marqo/tensor_search/api.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,28 @@ async def api_validation_exception_handler(request: Request, exc: RequestValidat
207207
)
208208

209209

210+
# Handles validation errors raised by Pydantic V1 model classes
210211
@app.exception_handler(pydantic.v1.ValidationError)
211212
async def validation_exception_handler(request, exc: pydantic.v1.ValidationError) -> JSONResponse:
213+
"""Catch pydantic v1 validation errors and rewrite as an InvalidArgError whilst keeping error messages from the ValidationError."""
214+
error_messages = [{
215+
'loc': error.get('loc', ''),
216+
'msg': error.get('msg', ''),
217+
'type': error.get('type', '')
218+
} for error in exc.errors()]
219+
220+
body = {
221+
"message": json.dumps(error_messages),
222+
"code": InvalidArgError.code,
223+
"type": InvalidArgError.error_type,
224+
"link": InvalidArgError.link
225+
}
226+
return JSONResponse(content=body, status_code=InvalidArgError.status_code)
227+
228+
229+
# Handles validation errors raised by Pydantic V2 model classes
230+
@app.exception_handler(pydantic.ValidationError)
231+
async def validation_exception_handler(request, exc: pydantic.ValidationError) -> JSONResponse:
212232
"""Catch pydantic validation errors and rewrite as an InvalidArgError whilst keeping error messages from the ValidationError."""
213233
error_messages = [{
214234
'loc': error.get('loc', ''),

src/marqo/vespa/models/query_result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import List, Dict, Any, Optional
22

3-
from pydantic.v1 import BaseModel, Field
3+
from pydantic import BaseModel, Field
44

55

66
# See https://docs.vespa.ai/en/reference/default-result-format.html

tests/integ_tests/tensor_search/test_api.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,19 @@
99
import pydantic
1010
from fastapi.exceptions import RequestValidationError
1111
from fastapi.testclient import TestClient
12+
from pydantic.v1.error_wrappers import ErrorWrapper
13+
from pydantic_core import InitErrorDetails, PydanticCustomError
1214

1315
import marqo.tensor_search.api as api
1416
from integ_tests.marqo_test import MarqoTestCase
1517
from marqo import exceptions as base_exceptions
18+
from marqo.api.exceptions import InvalidArgError
1619
from marqo.core import exceptions as core_exceptions
1720
from marqo.core.models.marqo_add_documents_response import MarqoAddDocumentsResponse, MarqoAddDocumentsItem
1821
from marqo.core.models.marqo_index import FieldType
1922
from marqo.core.models.marqo_index_request import FieldRequest
2023
from marqo.tensor_search.enums import EnvVars
24+
from marqo.tensor_search.models.api_models import SearchQuery
2125
from marqo.vespa import exceptions as vespa_exceptions
2226

2327

@@ -609,4 +613,37 @@ class PydanticV1Model(pydantic.v1.BaseModel):
609613

610614
self.assertIn('field required', str(context.exception.errors()))
611615

616+
def test_handle_pydantic_v1_validation_errors(self):
617+
"""Test pydantic v1 ValidationError is correctly handled and converted to error response"""
618+
error = pydantic.v1.ValidationError(errors=[ErrorWrapper(ValueError("some message"), loc="doc")],
619+
model=SearchQuery)
620+
with patch("marqo.tensor_search.tensor_search.search", side_effect=error):
621+
response = self.client.post("/indexes/" + self.structured_index.name + "/search?device=cpu", json={
622+
"q": "test",
623+
"filter": ""
624+
})
625+
626+
self.assertEqual(response.status_code, 400)
627+
self.assertEqual(response.json()["code"], InvalidArgError.code)
628+
self.assertEqual(response.json()["type"], InvalidArgError.error_type)
629+
assert "some message" in response.json()["message"]
630+
631+
def test_handle_pydantic_v2_validation_errors(self):
632+
"""Test pydantic v2 ValidationError is correctly handled and converted to error response"""
633+
error = pydantic.ValidationError.from_exception_data(
634+
title='SearchQuery',
635+
line_errors=[InitErrorDetails(
636+
type=PydanticCustomError('type1', 'some message'), loc=('doc',), input=...)]
637+
)
638+
with patch("marqo.tensor_search.tensor_search.search", side_effect=error):
639+
response = self.client.post("/indexes/" + self.structured_index.name + "/search?device=cpu", json={
640+
"q": "test",
641+
"filter": ""
642+
})
643+
644+
self.assertEqual(response.status_code, 400)
645+
self.assertEqual(response.json()["code"], InvalidArgError.code)
646+
self.assertEqual(response.json()["type"], InvalidArgError.error_type)
647+
assert "some message" in response.json()["message"]
648+
612649
# TODO: Test how marqo handles generic exceptions, including Exception, RuntimeError, ValueError, etc.

0 commit comments

Comments
 (0)