Skip to content

Commit 48fdb78

Browse files
authored
Raise in validate_server_version instead of returning a bool (#19)
1 parent 32ac321 commit 48fdb78

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

go2rtc_client/exceptions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ class Go2RtcClientError(Exception):
2323
"""Base exception for go2rtc client."""
2424

2525

26+
class Go2RtcVersionError(Exception):
27+
"""Base exception for go2rtc client."""
28+
29+
def __init__(
30+
self,
31+
server_version: str | None,
32+
min_version_supported: str,
33+
max_version_supported: str,
34+
) -> None:
35+
"""Initialize."""
36+
self._server_version = server_version
37+
self._min_version_supported = min_version_supported
38+
self._max_version_supported = max_version_supported
39+
40+
def __str__(self) -> str:
41+
"""Return exception message."""
42+
return (
43+
f"server version '{self._server_version}' not "
44+
f">= {self._min_version_supported} and < {self._max_version_supported}"
45+
)
46+
47+
2648
def handle_error[**_P, _R](
2749
func: Callable[_P, Coroutine[Any, Any, _R]],
2850
) -> Callable[_P, Coroutine[Any, Any, _R]]:

go2rtc_client/rest.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mashumaro.mixins.dict import DataClassDictMixin
1313
from yarl import URL
1414

15-
from .exceptions import handle_error
15+
from .exceptions import Go2RtcVersionError, handle_error
1616
from .models import ApplicationInfo, Stream, WebRTCSdpAnswer, WebRTCSdpOffer
1717

1818
if TYPE_CHECKING:
@@ -145,17 +145,24 @@ def __init__(self, websession: ClientSession, server_url: str) -> None:
145145
self.webrtc: Final = _WebRTCClient(self._client)
146146

147147
@handle_error
148-
async def validate_server_version(self) -> bool:
148+
async def validate_server_version(self) -> None:
149149
"""Validate the server version is compatible."""
150150
application_info = await self.application.get_info()
151151
try:
152-
return (
152+
version_supported = (
153153
_MIN_VERSION_SUPPORTED
154154
<= application_info.version
155155
< _MIN_VERSION_UNSUPPORTED
156156
)
157-
except AwesomeVersionException:
158-
_LOGGER.exception(
159-
"Invalid version received from server: %s", application_info.version
157+
except AwesomeVersionException as err:
158+
raise Go2RtcVersionError(
159+
application_info.version if application_info else "unknown",
160+
_MIN_VERSION_SUPPORTED,
161+
_MIN_VERSION_UNSUPPORTED,
162+
) from err
163+
if not version_supported:
164+
raise Go2RtcVersionError(
165+
application_info.version,
166+
_MIN_VERSION_SUPPORTED,
167+
_MIN_VERSION_UNSUPPORTED,
160168
)
161-
return False

tests/test_rest.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
from __future__ import annotations
44

5+
from contextlib import AbstractContextManager, nullcontext as does_not_raise
56
import json
6-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Any
78

89
from aiohttp.hdrs import METH_PUT
910
from awesomeversion import AwesomeVersion
1011
import pytest
1112

13+
from go2rtc_client.exceptions import Go2RtcVersionError
1214
from go2rtc_client.models import WebRTCSdpOffer
1315
from go2rtc_client.rest import _ApplicationClient, _StreamClient, _WebRTCClient
1416
from tests import load_fixture
@@ -94,22 +96,25 @@ async def test_streams_add(
9496
responses.assert_called_once_with(url, method=METH_PUT, params=params)
9597

9698

99+
VERSION_ERR = "server version '{}' not >= 1.9.5 and < 2.0.0"
100+
101+
97102
@pytest.mark.parametrize(
98103
("server_version", "expected_result"),
99104
[
100-
("0.0.0", False),
101-
("1.9.4", False),
102-
("1.9.5", True),
103-
("1.9.6", True),
104-
("2.0.0", False),
105-
("BLAH", False),
105+
("0.0.0", pytest.raises(Go2RtcVersionError, match=VERSION_ERR.format("0.0.0"))),
106+
("1.9.4", pytest.raises(Go2RtcVersionError, match=VERSION_ERR.format("1.9.4"))),
107+
("1.9.5", does_not_raise()),
108+
("1.9.6", does_not_raise()),
109+
("2.0.0", pytest.raises(Go2RtcVersionError, match=VERSION_ERR.format("2.0.0"))),
110+
("BLAH", pytest.raises(Go2RtcVersionError, match=VERSION_ERR.format("BLAH"))),
106111
],
107112
)
108113
async def test_version_supported(
109114
responses: aioresponses,
110115
rest_client: Go2RtcRestClient,
111116
server_version: str,
112-
expected_result: bool,
117+
expected_result: AbstractContextManager[Any],
113118
) -> None:
114119
"""Test webrtc offer."""
115120
payload = json.loads(load_fixture("application_info_answer.json"))
@@ -119,7 +124,8 @@ async def test_version_supported(
119124
status=200,
120125
payload=payload,
121126
)
122-
assert await rest_client.validate_server_version() == expected_result
127+
with expected_result:
128+
await rest_client.validate_server_version()
123129

124130

125131
async def test_webrtc_offer(

0 commit comments

Comments
 (0)