Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions fastapi_jsonapi/data_layers/filtering/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Helper to create sqlalchemy filters according to filter querystring parameter"""
import inspect
import logging
from typing import (
Any,
Expand Down Expand Up @@ -133,10 +134,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)
pydantic_types, userspace_types = self._separate_types(types)

if pydantic_types:
func = self._cast_value_with_pydantic
if isinstance(value, list):
clear_value, errors = self._cast_iterable_with_pydantic(pydantic_types, value)
else:
clear_value, errors = self._cast_value_with_pydantic(pydantic_types, value)
func = self._cast_iterable_with_pydantic
clear_value, errors = func(pydantic_types, value, schema_field)

if clear_value is None and userspace_types:
log.warning("Filtering by user type values is not properly tested yet. Use this on your own risk.")
Expand All @@ -151,7 +152,10 @@ def create_filter(self, schema_field: ModelField, model_column, operator, value)

# Если None, при этом поле обязательное (среди типов в аннотации нет None, то кидаем ошибку)
if clear_value is None and not can_be_none:
raise InvalidType(detail=", ".join(errors))
raise InvalidType(
detail=", ".join(errors),
pointer=schema_field.name,
)

return getattr(model_column, self.operator)(clear_value)

Expand Down Expand Up @@ -179,32 +183,65 @@ def _separate_types(self, types: List[Type]) -> Tuple[List[Type], List[Type]]:
]
return pydantic_types, userspace_types

def _validator_requires_model_field(self, validator: Callable) -> bool:
"""
Check if validator accepts the `field` param

:param validator:
:return:
"""
signature = inspect.signature(validator)
parameters = signature.parameters

if "field" not in parameters:
return False

field_param = parameters["field"]
field_type = field_param.annotation

return field_type == "ModelField" or field_type is ModelField

def _cast_value_with_pydantic(
self,
types: List[Type],
value: Any,
schema_field: ModelField,
) -> Tuple[Optional[Any], List[str]]:
result_value, errors = None, []

for type_to_cast in types:
for validator in find_validators(type_to_cast, BaseConfig):
args = [value]
# TODO: some other way to get all the validator's dependencies?
if self._validator_requires_model_field(validator):
args.append(schema_field)
try:
result_value = validator(value)
return result_value, errors
result_value = validator(*args)
except Exception as ex:
errors.append(str(ex))
else:
return result_value, errors

return None, errors

def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple[List, List[str]]:
def _cast_iterable_with_pydantic(
self,
types: List[Type],
values: List,
schema_field: ModelField,
) -> Tuple[List, List[str]]:
type_cast_failed = False
failed_values = []

result_values: List[Any] = []
errors: List[str] = []

for value in values:
casted_value, cast_errors = self._cast_value_with_pydantic(types, value)
casted_value, cast_errors = self._cast_value_with_pydantic(
types,
value,
schema_field,
)
errors.extend(cast_errors)

if casted_value is None:
Expand All @@ -217,7 +254,7 @@ def _cast_iterable_with_pydantic(self, types: List[Type], values: List) -> Tuple

if type_cast_failed:
msg = f"Can't parse items {failed_values} of value {values}"
raise InvalidFilters(msg)
raise InvalidFilters(msg, pointer=schema_field.name)

return result_values, errors

Expand Down
6 changes: 5 additions & 1 deletion fastapi_jsonapi/exceptions/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def __init__(
parameter = parameter or self.parameter
if not errors:
if pointer:
pointer = pointer if pointer.startswith("/") else "/data/" + pointer
pointer = (
pointer
if pointer.startswith("/")
else "/data/" + (pointer if pointer == "id" else "attributes/" + pointer)
)
self.source = {"pointer": pointer}
elif parameter:
self.source = {"parameter": parameter}
Expand Down
13 changes: 13 additions & 0 deletions tests/fixtures/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tests.models import (
Child,
Computer,
CustomUUIDItem,
Parent,
ParentToChildAssociation,
Post,
Expand All @@ -30,6 +31,7 @@
ComputerInSchema,
ComputerPatchSchema,
ComputerSchema,
CustomUUIDItemSchema,
ParentPatchSchema,
ParentSchema,
ParentToChildAssociationSchema,
Expand Down Expand Up @@ -178,6 +180,17 @@ def add_routers(app_plain: FastAPI):
schema_in_post=TaskInSchema,
)

RoutersJSONAPI(
router=router,
path="/custom-uuid-item",
tags=["Custom UUID Item"],
class_detail=DetailViewBaseGeneric,
class_list=ListViewBaseGeneric,
model=CustomUUIDItem,
schema=CustomUUIDItemSchema,
resource_type="custom_uuid_item",
)

atomic = AtomicOperations()

app_plain.include_router(router, prefix="")
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/db_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ async def async_session_plain(async_engine):

@async_fixture(scope="class")
async def async_session(async_session_plain):
async with async_session_plain() as session:
async with async_session_plain() as session: # type: AsyncSession
yield session
await session.rollback()
22 changes: 16 additions & 6 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.orm import declared_attr, relationship
from sqlalchemy.types import CHAR, TypeDecorator

from tests.common import sqla_uri
from tests.common import is_postgres_tests, sqla_uri


class Base:
Expand Down Expand Up @@ -253,33 +253,43 @@ def load_dialect_impl(self, dialect):
return CHAR(32)

def process_bind_param(self, value, dialect):
if value is None:
return value

if not isinstance(value, UUID):
msg = f"Incorrect type got {type(value).__name__}, expected {UUID.__name__}"
raise Exception(msg)

return str(value)

def process_result_value(self, value, dialect):
return UUID(value)
return value and UUID(value)

@property
def python_type(self):
return UUID if self.as_uuid else str


db_uri = sqla_uri()
if "postgres" in db_uri:
if is_postgres_tests():
# noinspection PyPep8Naming
from sqlalchemy.dialects.postgresql import UUID as UUIDType
from sqlalchemy.dialects.postgresql.asyncpg import AsyncpgUUID as UUIDType
elif "sqlite" in db_uri:
UUIDType = CustomUUIDType
else:
msg = "unsupported dialect (custom uuid?)"
raise ValueError(msg)


class IdCast(Base):
id = Column(UUIDType, primary_key=True)
class CustomUUIDItem(Base):
__tablename__ = "custom_uuid_item"
id = Column(UUIDType(as_uuid=True), primary_key=True)

extra_id = Column(
UUIDType(as_uuid=True),
nullable=True,
unique=True,
)


class SelfRelationship(Base):
Expand Down
9 changes: 8 additions & 1 deletion tests/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,14 @@ class TaskSchema(TaskBaseSchema):
# uuid below


class IdCastSchema(BaseModel):
class CustomUUIDItemAttributesSchema(BaseModel):
extra_id: Optional[UUID] = None

class Config:
orm_mode = True


class CustomUUIDItemSchema(CustomUUIDItemAttributesSchema):
id: UUID = Field(client_can_set_id=True)


Expand Down
Loading