Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 91 additions & 42 deletions fastapi_jsonapi/data_layers/sqla_orm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module is a CRUD interface between resource managers and the sqlalchemy ORM"""
import logging
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Tuple, Type, Union

from sqlalchemy import delete, func, select
from sqlalchemy.exc import DBAPIError, IntegrityError, MissingGreenlet, NoResultFound
Expand Down Expand Up @@ -44,6 +44,9 @@

log = logging.getLogger(__name__)

ModelTypeOneOrMany = Union[TypeModel, list[TypeModel]]
ActionTrigger = Literal["create", "update"]


class SqlalchemyDataLayer(BaseDataLayer):
"""Sqlalchemy data layer"""
Expand Down Expand Up @@ -134,12 +137,88 @@ def prepare_id_value(self, col: InstrumentedAttribute, value: Any) -> Any:

return value

async def apply_relationships(self, obj: TypeModel, data_create: BaseJSONAPIItemInSchema) -> None:
async def link_relationship_object(
self,
obj: TypeModel,
relation_name: str,
related_data: Optional[ModelTypeOneOrMany],
action_trigger: ActionTrigger,
):
"""
Links target object with relationship object or objects

:param obj:
:param relation_name:
:param related_data:
:param action_trigger: indicates which one operation triggered relationships applying
"""
# todo: relation name may be different?
setattr(obj, relation_name, related_data)

async def check_object_has_relationship_or_raise(self, obj: TypeModel, relation_name: str):
"""
TODO: move generic code to another method
Checks that there is relationship with relation_name in obj

:param obj:
:param relation_name:
"""
try:
hasattr(obj, relation_name)
except MissingGreenlet:
raise InternalServerError(
detail=(
f"Error of loading the {relation_name!r} relationship. "
f"Please add this relationship to include query parameter explicitly."
),
parameter="include",
)

async def get_related_data_to_link(
self,
related_model: TypeModel,
relationship_info: RelationshipInfo,
relationship_in: Union[
BaseJSONAPIRelationshipDataToOneSchema,
BaseJSONAPIRelationshipDataToManySchema,
],
) -> Optional[ModelTypeOneOrMany]:
"""
Retrieves object or objects to link from database

:param related_model:
:param relationship_info:
:param relationship_in:
"""
if not relationship_in.data:
return [] if relationship_info.many else None

if relationship_info.many:
assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToManySchema)
return await self.get_related_objects_list(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
ids=[r.id for r in relationship_in.data],
)

assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToOneSchema)
return await self.get_related_object(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
id_value=relationship_in.data.id,
)

async def apply_relationships(
self,
obj: TypeModel,
data_create: BaseJSONAPIItemInSchema,
action_trigger: ActionTrigger,
) -> None:
"""
Handles relationships passed in request

:param obj:
:param data_create:
:param action_trigger: indicates which one operation triggered relationships applying
:return:
"""
relationships: "PydanticBaseModel" = data_create.relationships
Expand Down Expand Up @@ -167,45 +246,15 @@ async def apply_relationships(self, obj: TypeModel, data_create: BaseJSONAPIItem
continue

relationship_info: RelationshipInfo = field.field_info.extra["relationship"]

# ...
related_model = get_related_model_cls(type(obj), relation_name)
related_data = await self.get_related_data_to_link(
related_model=related_model,
relationship_info=relationship_info,
relationship_in=relationship_in,
)

if relationship_info.many:
assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToManySchema)

related_data = []
if relationship_in.data:
related_data = await self.get_related_objects_list(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
ids=[r.id for r in relationship_in.data],
)
else:
assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToOneSchema)

if relationship_in.data:
related_data = await self.get_related_object(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
id_value=relationship_in.data.id,
)
else:
setattr(obj, relation_name, None)
continue
try:
hasattr(obj, relation_name)
except MissingGreenlet:
raise InternalServerError(
detail=(
f"Error of loading the {relation_name!r} relationship. "
f"Please add this relationship to include query parameter explicitly."
),
parameter="include",
)

# todo: relation name may be different?
setattr(obj, relation_name, related_data)
await self.check_object_has_relationship_or_raise(obj, relation_name)
await self.link_relationship_object(obj, relation_name, related_data, action_trigger)

async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs: dict) -> TypeModel:
"""
Expand All @@ -222,7 +271,7 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
await self.before_create_object(model_kwargs=model_kwargs, view_kwargs=view_kwargs)

obj = self.model(**model_kwargs)
await self.apply_relationships(obj, data_create)
await self.apply_relationships(obj, data_create, action_trigger="create")

self.session.add(obj)
try:
Expand Down Expand Up @@ -348,7 +397,7 @@ async def update_object(
"""
new_data = data_update.attributes.dict(exclude_unset=True)

await self.apply_relationships(obj, data_update)
await self.apply_relationships(obj, data_update, action_trigger="update")

await self.before_update_object(obj, model_kwargs=new_data, view_kwargs=view_kwargs)

Expand Down