Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions jwt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from .api_jws import PyJWS
from .api_jwt import (
PyJWT,
decode,
encode,
from .api_jws import (
PyJWS,
get_unverified_header,
register_algorithm,
unregister_algorithm,
)
from .api_jwt import PyJWT, decode, encode
from .exceptions import (
DecodeError,
ExpiredSignatureError,
Expand Down
32 changes: 20 additions & 12 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import binascii
import json
from collections.abc import Mapping
from typing import Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type

from .algorithms import (
Algorithm,
Expand Down Expand Up @@ -77,7 +77,7 @@ def get_algorithms(self):

def encode(
self,
payload: Union[Dict, bytes],
payload: bytes,
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
Expand Down Expand Up @@ -127,15 +127,14 @@ def encode(

return encoded_string.decode("utf-8")

def decode(
def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
complete: bool = False,
**kwargs,
):
) -> Dict[str, Any]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a perfectly fine solution, but if you wanted to maintain backwards-compatibility, this could remain the same, but use @overload.
It would look something like:

@overload
def decode(self, jwt: str, key: str = ..., algorithms: Optional[List[str]] = ..., options: Optional[Dict] = ..., complete: Literal[True], **kwargs: Any) -> Dict[str, Any]: ...
@overload
def decode(self, jwt: str, key: str = ..., algorithms: Optional[List[str]] = ..., options: Optional[Dict] = ..., complete: Literal[False] = ..., **kwargs: Any) -> str: ...
def decode(self, jwt: str, key: str = "", algorithms: Optional[List[str]] = None, options: Optional[Dict] = None, complete: bool = False, **kwargs) -> Union[str, Dict[str, Any]]:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're prepping for version 2.0, so this is the chance to break backward compatibility if we must. IMO, we should aim for the desired interface rather than force support for something we can drop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I played around with the overload idea a bit. Unfortunately, Literal was only introduced in Python 3.8.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right. Not sure you can do the overload without it.
If you wanted to go that route, you could import from typing_extensions for backwards compatibility (adding that to the dependencies). Alternatively, only do the overload in Python 3.8+ (if sys.version_info >= (3, 8), and users of older versions will have to put up with the Union (though I suspect most Mypy users are already using 3.8+).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your current proposal is probably the easier option though.

if options is None:
options = {}
merged_options = {**self.options, **options}
Expand All @@ -153,14 +152,22 @@ def decode(
payload, signing_input, header, signature, key, algorithms
)

if complete:
return {
"payload": payload,
"header": header,
"signature": signature,
}
return {
"payload": payload,
"header": header,
"signature": signature,
}

return payload
def decode(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
**kwargs,
) -> str:
decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
return decoded["payload"]

def get_unverified_header(self, jwt):
"""Returns back the JWT header parameters as a dict()
Expand Down Expand Up @@ -249,6 +256,7 @@ def _validate_kid(self, kid):

_jws_global_obj = PyJWS()
encode = _jws_global_obj.encode
decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm
Expand Down
45 changes: 24 additions & 21 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Type, Union

from .api_jws import PyJWS
from . import api_jws
from .exceptions import (
DecodeError,
ExpiredSignatureError,
Expand All @@ -16,8 +16,11 @@
)


class PyJWT(PyJWS):
header_type = "JWT"
class PyJWT:
def __init__(self, options=None):
if options is None:
options = {}
self.options = {**self._get_default_options(), **options}

@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
Expand All @@ -33,7 +36,7 @@ def _get_default_options() -> Dict[str, Union[bool, List[str]]]:

def encode(
self,
payload: Union[Dict, bytes],
payload: Dict[str, Any],
key: str,
algorithm: str = "HS256",
headers: Optional[Dict] = None,
Expand All @@ -59,20 +62,18 @@ def encode(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")

return super().encode(
return api_jws.encode(
json_payload, key, algorithm, headers, json_encoder
)

def decode(
def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
complete: bool = False,
**kwargs,
) -> Dict[str, Any]:

if options is None:
options = {"verify_signature": True}
else:
Expand All @@ -83,20 +84,16 @@ def decode(
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
)

decoded = super().decode(
decoded = api_jws.decode_complete(
jwt,
key=key,
algorithms=algorithms,
options=options,
complete=complete,
**kwargs,
)

try:
if complete:
payload = json.loads(decoded["payload"])
else:
payload = json.loads(decoded)
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError("Invalid payload string: %s" % e)
if not isinstance(payload, dict):
Expand All @@ -106,11 +103,19 @@ def decode(
merged_options = {**self.options, **options}
self._validate_claims(payload, merged_options, **kwargs)

if complete:
decoded["payload"] = payload
return decoded
decoded["payload"] = payload
return decoded

return payload
def decode(
self,
jwt: str,
key: str = "",
algorithms: List[str] = None,
options: Dict = None,
**kwargs,
) -> Dict[str, Any]:
decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs)
return decoded["payload"]

def _validate_claims(
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
Expand Down Expand Up @@ -215,7 +220,5 @@ def _validate_iss(self, payload, issuer):

_jwt_global_obj = PyJWT()
encode = _jwt_global_obj.encode
decode_complete = _jwt_global_obj.decode_complete
decode = _jwt_global_obj.decode
register_algorithm = _jwt_global_obj.register_algorithm
unregister_algorithm = _jwt_global_obj.unregister_algorithm
get_unverified_header = _jwt_global_obj.get_unverified_header
6 changes: 2 additions & 4 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import urllib.request

from .api_jwk import PyJWKSet
from .api_jwt import decode as decode_token
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError


Expand Down Expand Up @@ -50,8 +50,6 @@ def get_signing_key(self, kid):
return signing_key

def get_signing_key_from_jwt(self, token):
unverified = decode_token(
token, complete=True, options={"verify_signature": False}
)
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))
2 changes: 1 addition & 1 deletion jwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def base64url_decode(input):
return base64.urlsafe_b64decode(input)


def base64url_encode(input):
def base64url_encode(input: bytes) -> bytes:
return base64.urlsafe_b64encode(input).replace(b"=", b"")


Expand Down
21 changes: 21 additions & 0 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,27 @@ def test_decodes_valid_jws(self, jws, payload):

assert decoded_payload == payload

def test_decodes_complete_valid_jws(self, jws, payload):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you start enforcing Mypy on CI, I'd suggest adding some annotations to the tests (and adding tests/ to the Mypy run), as they can help catch API mistakes, as this is the only place in the code that actually uses much of the library. We've done this recently on aiohttp-jinja2, where some annotations were incorrect and only getting noticed by users previously.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Certainly would be useful, but it seems a much bigger issue to tackle and so I think is outside the scope for this particular PR. Help there would be welcome.

P.S. I'm not the maintainer of this project, just a contributor.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course, just giving some ideas for the future.

If you want to tackle it, I'd start by adding a .mypy.ini file, for example:
https://github.com/aio-libs/aiohttp-jinja2/blob/master/.mypy.ini

That may be a little too strict for this project, so play around with the options (but, you'll want to change disallow_untyped_defs to True under tests, we skip it because we don't have any important things in the parameters).

Then, adding CI support can be done with something as simple as:
https://github.com/mlowijs/tesla_api/blob/typing/.github/workflows/ci.yaml#L16-L23

Let me know if you need any other help.

example_secret = "secret"
example_jws = (
b"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9."
b"aGVsbG8gd29ybGQ."
b"gEW0pdU4kxPthjtehYdhxB9mMOGajt1xCKlGGXDJ8PM"
)

decoded = jws.decode_complete(
example_jws, example_secret, algorithms=["HS256"]
)

assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"\x80E\xb4\xa5\xd58\x93\x13\xed\x86;^\x85\x87a\xc4"
b"\x1ff0\xe1\x9a\x8e\xddq\x08\xa9F\x19p\xc9\xf0\xf3"
),
}

# 'Control' Elliptic Curve jws created by another library.
# Used to test for regressions that could affect both
# encoding / decoding operations equally (causing tests
Expand Down
26 changes: 23 additions & 3 deletions tests/test_api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@ def test_decodes_valid_jwt(self, jwt):

assert decoded_payload == example_payload

def test_decodes_complete_valid_jwt(self, jwt):
example_payload = {"hello": "world"}
example_secret = "secret"
example_jwt = (
b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
b".eyJoZWxsbyI6ICJ3b3JsZCJ9"
b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8"
)
decoded = jwt.decode_complete(
example_jwt, example_secret, algorithms=["HS256"]
)

assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": example_payload,
"signature": (
b'\xb6\xf6\xa0,2\xe8j"J\xc4\xe2\xaa\xa4\x15\xd2'
b"\x10l\xbbI\x84\xa2}\x98c\x9e\xd8&\xf5\xcbi\xca?"
),
}

def test_load_verify_valid_jwt(self, jwt):
example_payload = {"hello": "world"}
example_secret = "secret"
Expand Down Expand Up @@ -313,13 +334,12 @@ def test_decode_with_expiration_with_leeway(self, jwt, payload):
secret = "secret"
jwt_message = jwt.encode(payload, secret)

decoded_payload, signing, header, signature = jwt._load(jwt_message)

# With 3 seconds leeway, should be ok
for leeway in (3, timedelta(seconds=3)):
jwt.decode(
decoded = jwt.decode(
jwt_message, secret, leeway=leeway, algorithms=["HS256"]
)
assert decoded == payload

# With 1 seconds, should fail
for leeway in (1, timedelta(seconds=1)):
Expand Down