Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11330.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type annotations in Synapse's test suite.
9 changes: 6 additions & 3 deletions tests/handlers/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,23 +193,26 @@ def test_mau_limits_when_disabled(self):

@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
self.store.count_monthly_users = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))

@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
self.store.get_monthly_active_count = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)

self.store.get_monthly_active_count = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
Expand Down
51 changes: 42 additions & 9 deletions tests/rest/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@
MutableMapping,
Optional,
Tuple,
Union,
overload,
)
from unittest.mock import patch

import attr
from typing_extensions import Literal

from twisted.web.resource import Resource
from twisted.web.server import Site
Expand All @@ -55,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site)
auth_user_id = attr.ib()

@overload
def create_room_as(
self,
room_creator: Optional[str] = ...,
is_public: Optional[bool] = ...,
room_version: Optional[str] = ...,
tok: Optional[str] = ...,
expect_code: Literal[200] = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> str:
...

@overload
def create_room_as(
self,
room_creator: Optional[str] = ...,
is_public: Optional[bool] = ...,
room_version: Optional[str] = ...,
tok: Optional[str] = ...,
expect_code: int = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> Optional[str]:
...

def create_room_as(
self,
room_creator: Optional[str] = None,
Expand All @@ -64,7 +91,7 @@ def create_room_as(
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> str:
) -> Optional[str]:
"""
Create a room.

Expand Down Expand Up @@ -107,6 +134,8 @@ def create_room_as(

if expect_code == 200:
return channel.json_body["room_id"]
else:
return None

def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
Expand Down Expand Up @@ -176,7 +205,7 @@ def change_membership(
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
expect_code: int = 200,
expect_errcode: str = None,
expect_errcode: Optional[str] = None,
) -> None:
"""
Send a membership state event into a room.
Expand Down Expand Up @@ -260,9 +289,7 @@ def send_event(
txn_id=None,
tok=None,
expect_code=200,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
Expand Down Expand Up @@ -509,7 +536,7 @@ def auth_via_oidc(
went.
"""

cookies = {}
cookies: Dict[str, str] = {}

# if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id:
Expand Down Expand Up @@ -631,7 +658,13 @@ def initiate_sso_login(

# hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider.
location = channel.headers.getRawHeaders("Location")[0]
def get_location(channel: FakeChannel) -> str:
location_values = channel.headers.getRawHeaders("Location")
# Keep mypy happy by asserting that location_values is nonempty
assert location_values
return location_values[0]

location = get_location(channel)
parts = urllib.parse.urlsplit(location)
channel = make_request(
self.hs.get_reactor(),
Expand All @@ -645,7 +678,7 @@ def initiate_sso_login(

assert channel.code == 302
channel.extract_cookies(cookies)
return channel.headers.getRawHeaders("Location")[0]
return get_location(channel)

def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
Expand Down
3 changes: 2 additions & 1 deletion tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MutableMapping,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -226,7 +227,7 @@ def make_request(
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
request: Request = SynapseRequest,
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
Expand Down
31 changes: 15 additions & 16 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request

from synapse import events
from synapse.api.constants import EventTypes, Membership
Expand Down Expand Up @@ -95,16 +96,13 @@ def new(*args, **kwargs):
return _around


T = TypeVar("T")


class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""

def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
def __init__(self, methodName: str):
super().__init__(methodName)

method = getattr(self, methodName)

Expand Down Expand Up @@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase):
Attributes:
servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified
hijack_auth: Whether to hijack auth to return the user specified
in user_id.
"""

hijack_auth = True
needs_threadpool = False
hijack_auth: ClassVar[bool] = True
needs_threadpool: ClassVar[bool] = False
servlets: ClassVar[List[RegisterServletsFunc]] = []

def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
def __init__(self, methodName: str):
super().__init__(methodName)

# see if we have any additional config for this test
method = getattr(self, methodName)
Expand Down Expand Up @@ -301,9 +299,10 @@ async def get_user_by_req(request, allow_guest=False, rights="access"):
None,
)

self.hs.get_auth().get_user_by_req = get_user_by_req
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
self.hs.get_auth().get_access_token_from_request = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)

Expand Down Expand Up @@ -417,7 +416,7 @@ def make_request(
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
Expand Down Expand Up @@ -596,7 +595,7 @@ def register_user(
nonce_str += b"\x00notadmin"

want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest()
want_mac_digest = want_mac.hexdigest()

body = json.dumps(
{
Expand All @@ -605,7 +604,7 @@ def register_user(
"displayname": displayname,
"password": password,
"admin": admin,
"mac": want_mac,
"mac": want_mac_digest,
"inhibit_login": True,
}
)
Expand Down