Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Changed

- Use ``Sequence`` for parameter types rather than ``List`` where applicable by @imnotjames in `#970 <https://github.com/jpadilla/pyjwt/pull/970>`__
- Remove algorithm requirement from JWT API, instead relying on JWS API for enforcement, by @luhn in `#975 <https://github.com/jpadilla/pyjwt/pull/975>`__
- Add JWK support to JWT encode by @luhn in `#979 <https://github.com/jpadilla/pyjwt/pull/979>`__

Fixed
~~~~~
Expand Down
14 changes: 12 additions & 2 deletions jwt/api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys


DEFAULT_ALGORITHM = "DEFAULT_ALGORITHM"


class PyJWS:
header_typ = "JWT"

Expand Down Expand Up @@ -105,8 +108,8 @@ def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
def encode(
self,
payload: bytes,
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = DEFAULT_ALGORITHM,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
Expand All @@ -116,6 +119,11 @@ def encode(

# declare a new var to narrow the type for type checkers
algorithm_: str = algorithm if algorithm is not None else "none"
if algorithm_ == DEFAULT_ALGORITHM:
if isinstance(key, PyJWK):
algorithm_ = key.algorithm_name
else:
algorithm_ = "HS256"

# Prefer headers values if present to function parameters.
if headers:
Expand Down Expand Up @@ -159,6 +167,8 @@ def encode(
signing_input = b".".join(segments)

alg_obj = self.get_algorithm_by_name(algorithm_)
if isinstance(key, PyJWK):
key = key.key
key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key)

Expand Down
4 changes: 2 additions & 2 deletions jwt/api_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def _get_default_options() -> dict[str, bool | list[str]]:
def encode(
self,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
key: AllowedPrivateKeys | PyJWK | str | bytes,
algorithm: str | None = api_jws.DEFAULT_ALGORITHM,
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
Expand Down
31 changes: 31 additions & 0 deletions tests/test_api_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def test_decode_with_non_mapping_header_throws_exception(self, jws):
exception = context.value
assert str(exception) == "Invalid header string: must be a json object"

def test_encode_default_algorithm(self, jws, payload):
msg = jws.encode(payload, "secret")
decoded = jws.decode_complete(msg, "secret", algorithms=["HS256"])
assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
),
}

def test_encode_algorithm_param_should_be_case_sensitive(self, jws, payload):
jws.encode(payload, "secret", algorithm="HS256")

Expand Down Expand Up @@ -193,6 +205,25 @@ def test_encode_with_alg_hs256_and_headers_alg_es256(self, jws, payload):
msg = jws.encode(payload, priv_key, algorithm="HS256", headers={"alg": "ES256"})
assert b"hello world" == jws.decode(msg, pub_key, algorithms=["ES256"])

def test_encode_with_jwk(self, jws, payload):
jwk = PyJWK(
{
"kty": "oct",
"alg": "HS256",
"k": "c2VjcmV0", # "secret"
}
)
msg = jws.encode(payload, key=jwk)
decoded = jws.decode_complete(msg, key=jwk, algorithms=["HS256"])
assert decoded == {
"header": {"alg": "HS256", "typ": "JWT"},
"payload": payload,
"signature": (
b"H\x8a\xf4\xdf3:\xe1\xac\x16E\xd3\xeb\x00\xcf\xfa\xd5\x05\xac"
b"e\xc8@\xb6\x00\xd5\xde\x9aa|s\xcfZB"
),
}

def test_decode_algorithm_param_should_be_case_sensitive(self, jws):
example_jws = (
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256
Expand Down
Loading