|
1 | 1 | import asyncio
|
| 2 | +import json |
| 3 | +import re |
2 | 4 | import typing
|
3 | 5 | from functools import partial
|
4 | 6 |
|
5 | 7 | from typing import List, Type
|
6 | 8 |
|
7 |
| -from asgiref.sync import async_to_sync |
8 | 9 | from channels.db import database_sync_to_async
|
9 |
| -from channels.generic.websocket import AsyncJsonWebsocketConsumer |
10 |
| - |
| 10 | +from channels.generic.websocket import AsyncJsonWebsocketConsumer, \ |
| 11 | + WebsocketConsumer |
| 12 | +from django.conf.urls import url |
| 13 | +from django.http import HttpRequest, HttpResponse |
| 14 | +from django.http.response import HttpResponseBase |
| 15 | +from django.template.response import SimpleTemplateResponse |
| 16 | +from django.urls import Resolver404, reverse, resolve |
11 | 17 |
|
12 | 18 | from rest_framework.exceptions import PermissionDenied, MethodNotAllowed, \
|
13 |
| - APIException |
| 19 | + APIException, NotFound |
| 20 | +from rest_framework.response import Response |
| 21 | +from rest_framework.viewsets import ViewSet |
14 | 22 |
|
15 | 23 | from channels_api.permissions import BasePermission
|
16 | 24 | from channels_api.settings import api_settings
|
@@ -150,3 +158,121 @@ async def reply(self,
|
150 | 158 | await self.send_json(
|
151 | 159 | payload
|
152 | 160 | )
|
| 161 | + |
| 162 | + |
| 163 | +class DjangoViewConsumer(AsyncWebsocketAPIView): |
| 164 | + |
| 165 | + view = None |
| 166 | + |
| 167 | + @property |
| 168 | + def dumpy_url_config(self): |
| 169 | + return |
| 170 | + |
| 171 | + # maps actions to HTTP methods |
| 172 | + actions = {} # type: Dict[str, str] |
| 173 | + |
| 174 | + async def receive_json(self, content: typing.Dict, **kwargs): |
| 175 | + """ |
| 176 | + Called with decoded JSON content. |
| 177 | + """ |
| 178 | + # TODO assert format, if does not match return message. |
| 179 | + request_id = content.pop('request_id') |
| 180 | + action = content.pop('action') |
| 181 | + await self.handle_action(action, request_id=request_id, **content) |
| 182 | + |
| 183 | + async def handle_action(self, action: str, request_id: str, **kwargs): |
| 184 | + """ |
| 185 | + run the action. |
| 186 | + """ |
| 187 | + try: |
| 188 | + await self.check_permissions(action, **kwargs) |
| 189 | + |
| 190 | + if action not in self.actions: |
| 191 | + raise MethodNotAllowed(method=action) |
| 192 | + |
| 193 | + content, status = await self.call_view( |
| 194 | + action=action, |
| 195 | + **kwargs |
| 196 | + ) |
| 197 | + |
| 198 | + await self.reply( |
| 199 | + action=action, |
| 200 | + request_id=request_id, |
| 201 | + data=content, |
| 202 | + status=status |
| 203 | + ) |
| 204 | + |
| 205 | + except Exception as exc: |
| 206 | + await self.handle_exception( |
| 207 | + exc, |
| 208 | + action=action, |
| 209 | + request_id=request_id |
| 210 | + ) |
| 211 | + |
| 212 | + @database_sync_to_async |
| 213 | + def call_view(self, |
| 214 | + action: str, |
| 215 | + **kwargs): |
| 216 | + |
| 217 | + request = HttpRequest() |
| 218 | + request.path = self.scope.get('path') |
| 219 | + |
| 220 | + request.META = { |
| 221 | + 'HTTP_CONTENT_TYPE': 'application/json', |
| 222 | + 'HTTP_ACCEPT': 'application/json' |
| 223 | + } |
| 224 | + |
| 225 | + for (header_name, value) in self.scope.get('headers', []): |
| 226 | + request.META[header_name] = value |
| 227 | + |
| 228 | + args, view_kwargs = self.get_view_args(action=action, **kwargs) |
| 229 | + |
| 230 | + request.method = self.actions[action] |
| 231 | + request.POST = json.dumps(kwargs.get('data', {})) |
| 232 | + |
| 233 | + view = getattr(self.__class__, 'view') |
| 234 | + |
| 235 | + response = view(request, *args, **view_kwargs) |
| 236 | + |
| 237 | + status = response.status_code |
| 238 | + |
| 239 | + if isinstance(response, Response): |
| 240 | + data = response.data |
| 241 | + try: |
| 242 | + # check if we can json encode it! |
| 243 | + # there must be a better way fo doing this? |
| 244 | + json.dumps(data) |
| 245 | + return data, status |
| 246 | + except Exception as e: |
| 247 | + pass |
| 248 | + if isinstance(response, SimpleTemplateResponse): |
| 249 | + response.render() |
| 250 | + |
| 251 | + response_content = response.content |
| 252 | + if isinstance(response_content, bytes): |
| 253 | + try: |
| 254 | + response_content = response_content.decode('utf-8') |
| 255 | + except Exception as e: |
| 256 | + response_content = response_content.hex() |
| 257 | + return response_content, status |
| 258 | + |
| 259 | + def get_view_args(self, action: str, **kwargs): |
| 260 | + return [], {} |
| 261 | + |
| 262 | + |
| 263 | +def view_as_consumer(wrapped_view, |
| 264 | + mapped_actions=None): |
| 265 | + |
| 266 | + if mapped_actions is None: |
| 267 | + mapped_actions = { |
| 268 | + 'create': 'PUT', |
| 269 | + 'update': 'PATCH', |
| 270 | + 'list': 'GET', |
| 271 | + 'retrieve': 'GET' |
| 272 | + } |
| 273 | + |
| 274 | + class DV(DjangoViewConsumer): |
| 275 | + view = wrapped_view |
| 276 | + actions = mapped_actions |
| 277 | + |
| 278 | + return DV |
0 commit comments