Skip to content

Commit e781e08

Browse files
committed
Add a simpler view that lets one map django views to consumers
1 parent ace569f commit e781e08

File tree

2 files changed

+177
-4
lines changed

2 files changed

+177
-4
lines changed

channels_api/views.py

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
import asyncio
2+
import json
3+
import re
24
import typing
35
from functools import partial
46

57
from typing import List, Type
68

7-
from asgiref.sync import async_to_sync
89
from channels.db import database_sync_to_async
9-
from channels.generic.websocket import AsyncJsonWebsocketConsumer
10-
10+
from channels.generic.websocket import AsyncJsonWebsocketConsumer, \
11+
WebsocketConsumer
12+
from django.conf.urls import url
13+
from django.http import HttpRequest, HttpResponse
14+
from django.http.response import HttpResponseBase
15+
from django.template.response import SimpleTemplateResponse
16+
from django.urls import Resolver404, reverse, resolve
1117

1218
from rest_framework.exceptions import PermissionDenied, MethodNotAllowed, \
13-
APIException
19+
APIException, NotFound
20+
from rest_framework.response import Response
21+
from rest_framework.viewsets import ViewSet
1422

1523
from channels_api.permissions import BasePermission
1624
from channels_api.settings import api_settings
@@ -150,3 +158,121 @@ async def reply(self,
150158
await self.send_json(
151159
payload
152160
)
161+
162+
163+
class DjangoViewConsumer(AsyncWebsocketAPIView):
164+
165+
view = None
166+
167+
@property
168+
def dumpy_url_config(self):
169+
return
170+
171+
# maps actions to HTTP methods
172+
actions = {} # type: Dict[str, str]
173+
174+
async def receive_json(self, content: typing.Dict, **kwargs):
175+
"""
176+
Called with decoded JSON content.
177+
"""
178+
# TODO assert format, if does not match return message.
179+
request_id = content.pop('request_id')
180+
action = content.pop('action')
181+
await self.handle_action(action, request_id=request_id, **content)
182+
183+
async def handle_action(self, action: str, request_id: str, **kwargs):
184+
"""
185+
run the action.
186+
"""
187+
try:
188+
await self.check_permissions(action, **kwargs)
189+
190+
if action not in self.actions:
191+
raise MethodNotAllowed(method=action)
192+
193+
content, status = await self.call_view(
194+
action=action,
195+
**kwargs
196+
)
197+
198+
await self.reply(
199+
action=action,
200+
request_id=request_id,
201+
data=content,
202+
status=status
203+
)
204+
205+
except Exception as exc:
206+
await self.handle_exception(
207+
exc,
208+
action=action,
209+
request_id=request_id
210+
)
211+
212+
@database_sync_to_async
213+
def call_view(self,
214+
action: str,
215+
**kwargs):
216+
217+
request = HttpRequest()
218+
request.path = self.scope.get('path')
219+
220+
request.META = {
221+
'HTTP_CONTENT_TYPE': 'application/json',
222+
'HTTP_ACCEPT': 'application/json'
223+
}
224+
225+
for (header_name, value) in self.scope.get('headers', []):
226+
request.META[header_name] = value
227+
228+
args, view_kwargs = self.get_view_args(action=action, **kwargs)
229+
230+
request.method = self.actions[action]
231+
request.POST = json.dumps(kwargs.get('data', {}))
232+
233+
view = getattr(self.__class__, 'view')
234+
235+
response = view(request, *args, **view_kwargs)
236+
237+
status = response.status_code
238+
239+
if isinstance(response, Response):
240+
data = response.data
241+
try:
242+
# check if we can json encode it!
243+
# there must be a better way fo doing this?
244+
json.dumps(data)
245+
return data, status
246+
except Exception as e:
247+
pass
248+
if isinstance(response, SimpleTemplateResponse):
249+
response.render()
250+
251+
response_content = response.content
252+
if isinstance(response_content, bytes):
253+
try:
254+
response_content = response_content.decode('utf-8')
255+
except Exception as e:
256+
response_content = response_content.hex()
257+
return response_content, status
258+
259+
def get_view_args(self, action: str, **kwargs):
260+
return [], {}
261+
262+
263+
def view_as_consumer(wrapped_view,
264+
mapped_actions=None):
265+
266+
if mapped_actions is None:
267+
mapped_actions = {
268+
'create': 'PUT',
269+
'update': 'PATCH',
270+
'list': 'GET',
271+
'retrieve': 'GET'
272+
}
273+
274+
class DV(DjangoViewConsumer):
275+
view = wrapped_view
276+
actions = mapped_actions
277+
278+
return DV

tests/test_django_view_consumer.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
from channels.testing import WebsocketCommunicator
3+
from rest_framework.response import Response
4+
from rest_framework.views import APIView
5+
6+
from channels_api.views import view_as_consumer
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_decorator():
11+
12+
results = {}
13+
14+
class TestView(APIView):
15+
16+
def get(self, request, format=None):
17+
results['TestView-get'] = True
18+
return Response(['test1', 'test2'])
19+
20+
# Test a normal connection
21+
communicator = WebsocketCommunicator(view_as_consumer(
22+
TestView.as_view()),
23+
"/testws/"
24+
)
25+
26+
connected, _ = await communicator.connect()
27+
assert connected
28+
29+
await communicator.send_json_to(
30+
{
31+
"action": "retrieve",
32+
"request_id": 1
33+
}
34+
)
35+
36+
response = await communicator.receive_json_from()
37+
38+
assert 'TestView-get' in results
39+
40+
assert response == {
41+
'errors': [],
42+
'data': ['test1', 'test2'],
43+
'action': 'retrieve',
44+
'response_status': 200,
45+
'request_id': 1
46+
}
47+

0 commit comments

Comments
 (0)