Skip to content

Commit f1d1ac0

Browse files
committed
Add observer for signals
1 parent 4b989f4 commit f1d1ac0

File tree

5 files changed

+136
-1
lines changed

5 files changed

+136
-1
lines changed

channels_api/observer.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from functools import partial
2+
from typing import Dict, Any
3+
4+
from asgiref.sync import async_to_sync
5+
from channels.consumer import AsyncConsumer
6+
from channels.layers import get_channel_layer
7+
from django.dispatch import Signal
8+
9+
10+
class Observer:
11+
def __init__(self, func, signal: Signal=None, kwargs=None):
12+
if kwargs is None:
13+
kwargs = {}
14+
self.func = func
15+
self.signal = signal
16+
self.signal_kwargs = kwargs
17+
self._serializer = None
18+
self.signal.connect(
19+
self.handle, **self.signal_kwargs
20+
)
21+
22+
async def __call__(self, *args, **kwargs):
23+
return await self.func(*args, **kwargs)
24+
25+
def serialize(self, signal, *args, **kwargs) -> Dict[str, Any]:
26+
message = {}
27+
if self._serializer:
28+
message = self._serializer(self, signal, *args, **kwargs)
29+
message['type'] = self.func.__name__.replace('_', '.')
30+
31+
return message
32+
33+
def serializer(self, func):
34+
self._serializer = func
35+
36+
def handle(self, signal, *args, **kwargs):
37+
message = self.serialize(signal, *args, **kwargs)
38+
channel_layer = get_channel_layer()
39+
group_name = self.channel_name(signal, *args, **kwargs)
40+
async_to_sync(channel_layer.group_send)(group_name, message)
41+
42+
def channel_name(self, *args, **kwargs):
43+
return '{}-signal-{}'.format(
44+
self.func.__name__.replace('_', '.'),
45+
'.'.join(
46+
arg.lower().replace('_', '.') for arg in
47+
self.signal.providing_args
48+
)
49+
)
50+
51+
def __get__(self, parent, objtype):
52+
53+
if parent is None:
54+
return self
55+
56+
return partial(self.__call__, parent)
57+
58+
async def subscribe(self, consumer: AsyncConsumer, *args, **kwargs):
59+
await consumer.channel_layer.group_add(
60+
self.channel_name(*args, **kwargs),
61+
consumer.channel_name
62+
)
63+
print('subscribed', consumer, self.channel_name(*args, **kwargs), consumer.channel_name)
64+
65+
66+
def observer(signal, **kwargs):
67+
return partial(Observer, signal=signal, kwargs=kwargs)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
django>=1.8
22
djangorestframework>=3
3-
channels>=2.0.2
3+
git+https://github.com/django/channels.git
44
pytest-django
55
pytest-asyncio

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
'coverage~=4.4',
2626
],
2727
},
28+
dependency_links=[
29+
'git+https://github.com/django/channels.git'
30+
],
2831
classifiers=[
2932
'Programming Language :: Python :: 3',
3033
'Programming Language :: Python :: 3.4',

tests/test_consumer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def test_sync_action(self, pk=None):
2525

2626
# Test a normal connection
2727
communicator = WebsocketCommunicator(AConsumer, "/testws/")
28+
2829
connected, _ = await communicator.connect()
30+
2931
assert connected
3032

3133
await communicator.send_json_to(
@@ -63,3 +65,5 @@ def test_sync_action(self, pk=None):
6365
'response_status': 200,
6466
'request_id': 10
6567
}
68+
69+
await communicator.disconnect()

tests/test_observer.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pytest
2+
from channels import DEFAULT_CHANNEL_LAYER
3+
from channels.db import database_sync_to_async
4+
from channels.generic.websocket import AsyncJsonWebsocketConsumer
5+
from channels.layers import channel_layers
6+
from channels.testing import WebsocketCommunicator
7+
from django.contrib.auth import user_logged_in, get_user_model
8+
9+
from channels_api.observer import observer
10+
11+
12+
@pytest.mark.django_db(transaction=True)
13+
@pytest.mark.asyncio
14+
async def test_observer_wrapper(settings):
15+
settings.CHANNEL_LAYERS={
16+
"default": {
17+
"BACKEND": "channels.layers.InMemoryChannelLayer",
18+
"TEST_CONFIG": {
19+
"expiry": 100500,
20+
},
21+
},
22+
}
23+
24+
layer = channel_layers.make_test_backend(DEFAULT_CHANNEL_LAYER)
25+
26+
class TestConsumer(AsyncJsonWebsocketConsumer):
27+
28+
async def dispatch(self, message):
29+
"""
30+
Works out what to do with a message.
31+
"""
32+
await super().dispatch(message)
33+
34+
async def accept(self):
35+
await TestConsumer.handle_user_logged_in.subscribe(self)
36+
await super().accept()
37+
38+
@observer(user_logged_in)
39+
async def handle_user_logged_in(self, *args, **kwargs):
40+
await self.send_json({'message': kwargs,})
41+
42+
communicator = WebsocketCommunicator(TestConsumer, "/testws/")
43+
44+
connected, _ = await communicator.connect()
45+
46+
assert connected
47+
48+
user = await database_sync_to_async(get_user_model().objects.create)(
49+
username='test',
50+
51+
)
52+
53+
await database_sync_to_async(user_logged_in.send)(
54+
sender=user.__class__,
55+
request=None,
56+
user=user
57+
)
58+
59+
response = await communicator.receive_json_from()
60+
61+
assert {'message': {}} == response

0 commit comments

Comments
 (0)