Skip to content

Commit 1bfd490

Browse files
committed
Add GenericApiView
1 parent 82a1338 commit 1bfd490

File tree

4 files changed

+258
-3
lines changed

4 files changed

+258
-3
lines changed

channels_api/generics.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from typing import Dict, Type, Optional
2+
3+
from asgiref.sync import async_to_sync
4+
from channels.db import database_sync_to_async
5+
from django.db.models import QuerySet, Model
6+
from rest_framework.generics import get_object_or_404, GenericAPIView
7+
from rest_framework.serializers import Serializer
8+
9+
from channels_api.views import AsyncWebsocketAPIView
10+
11+
GenericAPIView
12+
13+
class GenericAsyncWebsocketAPIView(AsyncWebsocketAPIView):
14+
"""
15+
Base class for all other generic views.
16+
"""
17+
18+
# You'll need to either set these attributes,
19+
# or override `get_queryset()`/`get_serializer_class()`.
20+
# If you are overriding a view method, it is important that you call
21+
# `get_queryset()` instead of accessing the `queryset` property directly,
22+
# as `queryset` will get evaluated only once, and those results are cached
23+
# for all subsequent requests.
24+
25+
queryset = None # type: QuerySet
26+
serializer_class = None # type: Type[Serializer]
27+
28+
# If you want to use object lookups other than pk, set 'lookup_field'.
29+
# For more complex lookup requirements override `get_object()`.
30+
lookup_field = 'pk' # type: str
31+
lookup_url_kwarg = None # type: Optional[str]
32+
33+
# TODO filter_backends
34+
35+
# TODO pagination_class
36+
37+
async def get_queryset(self, action: str, **kwargs) -> QuerySet:
38+
"""
39+
Get the list of items for this view.
40+
This must be an iterable, and may be a queryset.
41+
Defaults to using `self.queryset`.
42+
43+
This method should always be used rather than accessing `self.queryset`
44+
directly, as `self.queryset` gets evaluated only once, and those results
45+
are cached for all subsequent requests.
46+
47+
You may want to override this if you need to provide different
48+
querysets depending on the incoming request.
49+
50+
(Eg. return a list of items that is specific to the user)
51+
"""
52+
assert self.queryset is not None, (
53+
"'%s' should either include a `queryset` attribute, "
54+
"or override the `get_queryset()` method."
55+
% self.__class__.__name__
56+
)
57+
58+
queryset = self.queryset
59+
if isinstance(queryset, QuerySet):
60+
# Ensure queryset is re-evaluated on each request.
61+
queryset = queryset.all()
62+
return queryset
63+
64+
async def get_object(self, action: str, **kwargs) ->Model:
65+
"""
66+
Returns the object the view is displaying.
67+
68+
You may want to override this if you need to provide non-standard
69+
queryset lookups. Eg if objects are referenced using multiple
70+
keyword arguments in the url conf.
71+
"""
72+
queryset = await self.filter_queryset(
73+
queryset=await self.get_queryset(action=action, **kwargs),
74+
action=action,
75+
**kwargs
76+
)
77+
78+
# Perform the lookup filtering.
79+
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
80+
81+
assert lookup_url_kwarg in kwargs, (
82+
'Expected view %s to be called with a URL keyword argument '
83+
'named "%s". Fix your URL conf, or set the `.lookup_field` '
84+
'attribute on the view correctly.' %
85+
(self.__class__.__name__, lookup_url_kwarg)
86+
)
87+
88+
filter_kwargs = {self.lookup_field: kwargs[lookup_url_kwarg]}
89+
90+
obj = await database_sync_to_async(get_object_or_404)(queryset, **filter_kwargs)
91+
# TODO check_object_permissions
92+
93+
return obj
94+
95+
async def get_serializer(
96+
self, action: str,
97+
action_kwargs: Dict=None,
98+
*args, **kwargs):
99+
"""
100+
Return the serializer instance that should be used for validating and
101+
deserializing input, and for serializing output.
102+
"""
103+
serializer_class = await self.get_serializer_class(
104+
action=action, **action_kwargs
105+
)
106+
107+
kwargs['context'] = await self.get_serializer_context(
108+
action=action, **action_kwargs
109+
)
110+
111+
return serializer_class(*args, **kwargs)
112+
113+
async def get_serializer_class(self, action: str, **kwargs) -> Type[Serializer]:
114+
"""
115+
Return the class to use for the serializer.
116+
Defaults to using `self.serializer_class`.
117+
118+
You may want to override this if you need to provide different
119+
serializations depending on the incoming request.
120+
121+
(Eg. admins get full serialization, others get basic serialization)
122+
"""
123+
assert self.serializer_class is not None, (
124+
"'%s' should either include a `serializer_class` attribute, "
125+
"or override the `get_serializer_class()` method."
126+
% self.__class__.__name__
127+
)
128+
129+
return self.serializer_class
130+
131+
async def get_serializer_context(self, action: str, **kwargs):
132+
"""
133+
Extra context provided to the serializer class.
134+
"""
135+
return {
136+
'scope': self.scope,
137+
'consumer': self
138+
}
139+
140+
async def filter_queryset(self, queryset: QuerySet, action: str, **kwargs):
141+
"""
142+
Given a queryset, filter it with whichever filter backend is in use.
143+
144+
You are unlikely to want to override this method, although you may need
145+
to call it either from a list view, or from a custom `get_object`
146+
method if you want to apply the configured filtering backend to the
147+
default queryset.
148+
"""
149+
# TODO filter_backends
150+
151+
return queryset

channels_api/views.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
WebsocketConsumer
1313
from django.conf.urls import url
1414
from django.http import HttpRequest, HttpResponse
15-
from django.http.response import HttpResponseBase
15+
from django.http.response import HttpResponseBase, Http404
1616
from django.template.response import SimpleTemplateResponse
1717
from django.urls import Resolver404, reverse, resolve
1818

@@ -92,6 +92,13 @@ async def handle_exception(self, exc: Exception, action: str, request_id):
9292
status=exc.status_code,
9393
request_id=request_id
9494
)
95+
elif exc == Http404 or isinstance(exc, Http404):
96+
await self.reply(
97+
action=action,
98+
errors=self._format_errors('Not found'),
99+
status=404,
100+
request_id=request_id
101+
)
95102
else:
96103
raise exc
97104

@@ -266,7 +273,7 @@ def view_as_consumer(
266273
wrapped_view: typing.Callable[[HttpRequest], HttpResponse],
267274
mapped_actions: typing.Optional[
268275
typing.Dict[str, str]
269-
] =None) -> Type[AsyncConsumer]:
276+
]=None) -> Type[AsyncConsumer]:
270277
"""
271278
Wrap a django View so that it will be triggered by actions over this json
272279
websocket consumer.

tests/test_generic_view.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pytest
2+
from channels.testing import WebsocketCommunicator
3+
from django.contrib.auth import get_user_model
4+
from rest_framework import serializers
5+
6+
from channels_api.decorators import action
7+
from channels_api.generics import GenericAsyncWebsocketAPIView
8+
9+
10+
@pytest.mark.django_db(transaction=True)
11+
@pytest.mark.asyncio
12+
async def test_generic_view():
13+
14+
results = {}
15+
16+
class UserSerializer(serializers.ModelSerializer):
17+
class Meta:
18+
model = get_user_model()
19+
fields = ('id', 'username', 'email',)
20+
21+
class AView(GenericAsyncWebsocketAPIView):
22+
queryset = get_user_model().objects.all()
23+
serializer_class = UserSerializer
24+
25+
@action()
26+
async def test_async_action(self, reply, pk=None):
27+
user = await self.get_object(action='test_async_action', pk=pk)
28+
29+
s = await self.get_serializer(
30+
action='test_async_action',
31+
action_kwargs={'pk': pk},
32+
instance=user
33+
)
34+
await reply(data=s.data, status=200)
35+
36+
@action()
37+
def test_sync_action(self, pk=None):
38+
results['test_sync_action'] = pk
39+
return {'pk': pk, 'sync': True}, 200
40+
41+
# Test a normal connection
42+
communicator = WebsocketCommunicator(AView, "/testws/")
43+
connected, _ = await communicator.connect()
44+
assert connected
45+
46+
await communicator.send_json_to(
47+
{
48+
"action": "test_async_action",
49+
"pk": 2,
50+
"request_id": 1
51+
}
52+
)
53+
54+
response = await communicator.receive_json_from()
55+
56+
assert response == {
57+
"action": "test_async_action",
58+
"errors": ["Not found"],
59+
"response_status": 404,
60+
"request_id": 1,
61+
"data": None,
62+
}
63+
64+
user = get_user_model().objects.create(
65+
username='test1', email='[email protected]'
66+
)
67+
68+
pk = user.id
69+
70+
assert get_user_model().objects.filter(pk=pk).exists()
71+
72+
await communicator.disconnect()
73+
74+
communicator = WebsocketCommunicator(AView, "/testws/")
75+
connected, _ = await communicator.connect()
76+
77+
assert connected
78+
79+
await communicator.send_json_to(
80+
{
81+
"action": "test_async_action",
82+
"pk": pk,
83+
"request_id": 2
84+
}
85+
)
86+
87+
response = await communicator.receive_json_from()
88+
89+
assert response == {
90+
"action": "test_async_action",
91+
"errors": [],
92+
"response_status": 200,
93+
"request_id": 2,
94+
"data": {'email': '[email protected]', 'id': 1, 'username': 'test1'}
95+
}
96+
97+
await communicator.disconnect()

tests/test_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from channels_api.decorators import action
55
from channels_api.views import AsyncWebsocketAPIView
66

7-
7+
@pytest.mark.django_db(transaction=True)
88
@pytest.mark.asyncio
99
async def test_decorator():
1010

0 commit comments

Comments
 (0)