Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ This project adheres to `Semantic Versioning <https://semver.org/>`__.
`Unreleased <https://github.com/jpadilla/pyjwt/compare/2.10.1...HEAD>`__
------------------------------------------------------------------------

Fixed
~~~~~
- Validate key against allowed types for Algorithm family in `#964 <https://github.com/jpadilla/pyjwt/pull/964>`__

Added
~~~~~

Expand Down
155 changes: 111 additions & 44 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,56 @@
load_ssh_public_key,
)

# pyjwt-964: we use these both for type checking below, as well as for validating the key passed in.
# in Py >= 3.10, we can replace this with the Union types below
ALLOWED_RSA_KEY_TYPES = (RSAPrivateKey, RSAPublicKey)
ALLOWED_EC_KEY_TYPES = (EllipticCurvePrivateKey, EllipticCurvePublicKey)
ALLOWED_OKP_KEY_TYPES = (
Ed25519PrivateKey,
Ed25519PublicKey,
Ed448PrivateKey,
Ed448PublicKey,
)
ALLOWED_KEY_TYPES = (
ALLOWED_RSA_KEY_TYPES + ALLOWED_EC_KEY_TYPES + ALLOWED_OKP_KEY_TYPES
)
ALLOWED_PRIVATE_KEY_TYPES = (
RSAPrivateKey,
EllipticCurvePrivateKey,
Ed25519PrivateKey,
Ed448PrivateKey,
)
ALLOWED_PUBLIC_KEY_TYPES = (
RSAPublicKey,
EllipticCurvePublicKey,
Ed25519PublicKey,
Ed448PublicKey,
)

has_crypto = True
except ModuleNotFoundError:
has_crypto = False


if TYPE_CHECKING:
from typing import TypeAlias

from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)

# Type aliases for convenience in algorithms method signatures
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
AllowedOKPKeys = (
AllowedRSAKeys: TypeAlias = RSAPrivateKey | RSAPublicKey
AllowedECKeys: TypeAlias = EllipticCurvePrivateKey | EllipticCurvePublicKey
AllowedOKPKeys: TypeAlias = (
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
)
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
AllowedPrivateKeys = (
AllowedKeys: TypeAlias = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
AllowedPrivateKeys: TypeAlias = (
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
)
AllowedPublicKeys = (
AllowedPublicKeys: TypeAlias = (
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
)

Expand Down Expand Up @@ -141,6 +174,9 @@ class Algorithm(ABC):
The interface for an algorithm used to sign and verify tokens.
"""

# pyjwt-964: Validate to ensure the key passed in was decoded to the correct cryptography key family
_crypto_key_types: tuple[type[AllowedKeys], ...] | None = None

def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
Expand All @@ -163,6 +199,30 @@ def compute_hash_digest(self, bytestr: bytes) -> bytes:
else:
return bytes(hash_alg(bytestr).digest())

def check_crypto_key_type(self, key: PublicKeyTypes | PrivateKeyTypes):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also considered putting this in a sub-ABC "CryptoAlgorithm", but that seemed like more changes for not as much value. Let me know if you'd rather I do that instead, or if you would prefer an alternate approach.

"""Check that the key belongs to the right cryptographic family.

Note that this method only works when `cryptography` is installed.

Args:
key (Any): Potentially a cryptography key
Raises:
ValueError: if `cryptography` is not installed, or this method is called by a non-cryptography algorithm
InvalidKeyError: if the key doesn't match the expected key classes
"""
if not has_crypto or self._crypto_key_types is None:
raise ValueError(
"This method requires the cryptography library, and should only be used by cryptography-based algorithms."
)

if not isinstance(key, self._crypto_key_types):
valid_classes = (cls.__name__ for cls in self._crypto_key_types)
actual_class = key.__class__.__name__
self_class = self.__class__.__name__
raise InvalidKeyError(
f"Expected one of {valid_classes}, got: {actual_class}. Invalid Key type for {self_class}"
)

@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Expand Down Expand Up @@ -323,11 +383,13 @@ class RSAAlgorithm(Algorithm):
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

_crypto_key_types = ALLOWED_RSA_KEY_TYPES

def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
if isinstance(key, self._crypto_key_types):
return key

if not isinstance(key, (bytes, str)):
Expand All @@ -337,14 +399,20 @@ def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:

try:
if key_bytes.startswith(b"ssh-rsa"):
return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
self.check_crypto_key_type(public_key)
return cast(RSAPublicKey, public_key)
else:
return cast(
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
private_key: PrivateKeyTypes = load_pem_private_key(
key_bytes, password=None
)
self.check_crypto_key_type(private_key)
return cast(RSAPrivateKey, private_key)
except ValueError:
try:
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
public_key = load_pem_public_key(key_bytes)
self.check_crypto_key_type(public_key)
return cast(RSAPublicKey, public_key)
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError(
"Could not parse the provided public key."
Expand Down Expand Up @@ -493,11 +561,13 @@ class ECAlgorithm(Algorithm):
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512

_crypto_key_types = ALLOWED_EC_KEY_TYPES

def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg

def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
if isinstance(key, self._crypto_key_types):
return key

if not isinstance(key, (bytes, str)):
Expand All @@ -510,21 +580,17 @@ def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
# the Verifying Key first.
try:
if key_bytes.startswith(b"ecdsa-sha2-"):
crypto_key = load_ssh_public_key(key_bytes)
public_key: PublicKeyTypes = load_ssh_public_key(key_bytes)
else:
crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
except ValueError:
crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
public_key = load_pem_public_key(key_bytes)

# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
) from None

return crypto_key
# Explicit check the key to prevent confusing errors from cryptography
self.check_crypto_key_type(public_key)
return cast(EllipticCurvePublicKey, public_key)
except ValueError:
private_key = load_pem_private_key(key_bytes, password=None)
self.check_crypto_key_type(private_key)
return cast(EllipticCurvePrivateKey, private_key)

def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
Expand Down Expand Up @@ -715,31 +781,32 @@ class OKPAlgorithm(Algorithm):
This class requires ``cryptography>=2.6`` to be installed.
"""

_crypto_key_types = ALLOWED_OKP_KEY_TYPES

def __init__(self, **kwargs: Any) -> None:
pass

def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if isinstance(key, (bytes, str)):
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key_bytes = key.encode("utf-8") if isinstance(key, str) else key

if "-----BEGIN PUBLIC" in key_str:
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
elif "-----BEGIN PRIVATE" in key_str:
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
elif key_str[0:4] == "ssh-":
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
if not isinstance(key, (str, bytes)):
self.check_crypto_key_type(key)
return cast("AllowedOKPKeys", key)

key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key_bytes = key.encode("utf-8") if isinstance(key, str) else key

loaded_key: PublicKeyTypes | PrivateKeyTypes
if "-----BEGIN PUBLIC" in key_str:
loaded_key = load_pem_public_key(key_bytes)
elif "-----BEGIN PRIVATE" in key_str:
loaded_key = load_pem_private_key(key_bytes, password=None)
elif key_str[0:4] == "ssh-":
loaded_key = load_ssh_public_key(key_bytes)
else:
raise InvalidKeyError("Not a public or private key")

# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
key,
(Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for EdDSA algorithms"
)

return key
self.check_crypto_key_type(loaded_key)
return cast("AllowedOKPKeys", loaded_key)

def sign(
self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
Expand Down
3 changes: 3 additions & 0 deletions tests/keys/testkey_ed25519.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEIJb2MBNIWqpJ2zwLlbw8JkHNPIBkFCv/g127aQI7dQ1Q
-----END PRIVATE KEY-----
3 changes: 3 additions & 0 deletions tests/keys/testkey_ed25519.pub.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEASmyuOjH4q3bPqsOwf61G4jBH5L2g9kWnCDOp/7IOHKg=
-----END PUBLIC KEY-----
68 changes: 56 additions & 12 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@


class TestAlgorithms:
def test_check_crypto_key_type_should_fail_when_not_using_crypto(self):
"""If has_crypto is False, or if _crypto_key_types is None, then this method should throw."""

algo = NoneAlgorithm()
with pytest.raises(ValueError):
algo.check_crypto_key_type("key") # type: ignore[arg-type]

def test_none_algorithm_should_throw_exception_if_key_is_not_none(self):
algo = NoneAlgorithm()

Expand Down Expand Up @@ -811,6 +818,7 @@ def test_ec_verify_should_return_true_for_test_vector(self):
@crypto_required
class TestOKPAlgorithms:
hello_world_sig = b"Qxa47mk/azzUgmY2StAOguAd4P7YBLpyCfU3JdbaiWnXM4o4WibXwmIHvNYgN3frtE2fcyd8OYEaOiD/KiwkCg=="
hello_world_sig_pem = b"9ueQE7PT8uudHIQb2zZZ7tB7k1X3jeTnIfOVvGCINZejrqQbru1EXPeuMlGcQEZrGkLVcfMmr99W/+byxfppAg=="
hello_world = b"Hello World!"

def test_okp_ed25519_should_reject_non_string_key(self):
Expand All @@ -825,58 +833,94 @@ def test_okp_ed25519_should_reject_non_string_key(self):
with open(key_path("testkey_ed25519.pub")) as keyfile:
algo.prepare_key(keyfile.read())

def test_okp_ed25519_sign_should_generate_correct_signature_value(self):
@pytest.mark.parametrize(
"private_key_file,public_key_file,sig_attr",
[
("testkey_ed25519", "testkey_ed25519.pub", "hello_world_sig"),
("testkey_ed25519.pem", "testkey_ed25519.pub.pem", "hello_world_sig_pem"),
],
)
def test_okp_ed25519_sign_should_generate_correct_signature_value(
self, private_key_file, public_key_file, sig_attr
):
algo = OKPAlgorithm()

jwt_message = self.hello_world

expected_sig = base64.b64decode(self.hello_world_sig)
expected_sig = base64.b64decode(getattr(self, sig_attr))

with open(key_path("testkey_ed25519")) as keyfile:
with open(key_path(private_key_file)) as keyfile:
jwt_key = cast(Ed25519PrivateKey, algo.prepare_key(keyfile.read()))

with open(key_path("testkey_ed25519.pub")) as keyfile:
with open(key_path(public_key_file)) as keyfile:
jwt_pub_key = cast(Ed25519PublicKey, algo.prepare_key(keyfile.read()))

algo.sign(jwt_message, jwt_key)
result = algo.verify(jwt_message, jwt_pub_key, expected_sig)
assert result

def test_okp_ed25519_verify_should_return_false_if_signature_invalid(self):
@pytest.mark.parametrize(
"public_key_file,sig_attr",
[
("testkey_ed25519.pub", "hello_world_sig"),
("testkey_ed25519.pub.pem", "hello_world_sig_pem"),
],
)
def test_okp_ed25519_verify_should_return_false_if_signature_invalid(
self, public_key_file, sig_attr
):
algo = OKPAlgorithm()

jwt_message = self.hello_world
jwt_sig = base64.b64decode(self.hello_world_sig)
jwt_sig = base64.b64decode(getattr(self, sig_attr))

jwt_sig += b"123" # Signature is now invalid

with open(key_path("testkey_ed25519.pub")) as keyfile:
with open(key_path(public_key_file)) as keyfile:
jwt_pub_key = algo.prepare_key(keyfile.read())

result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
assert not result

def test_okp_ed25519_verify_should_return_true_if_signature_valid(self):
@pytest.mark.parametrize(
"public_key_file,sig_attr",
[
("testkey_ed25519.pub", "hello_world_sig"),
("testkey_ed25519.pub.pem", "hello_world_sig_pem"),
],
)
def test_okp_ed25519_verify_should_return_true_if_signature_valid(
self, public_key_file, sig_attr
):
algo = OKPAlgorithm()

jwt_message = self.hello_world
jwt_sig = base64.b64decode(self.hello_world_sig)
jwt_sig = base64.b64decode(getattr(self, sig_attr))

with open(key_path("testkey_ed25519.pub")) as keyfile:
with open(key_path(public_key_file)) as keyfile:
jwt_pub_key = algo.prepare_key(keyfile.read())

result = algo.verify(jwt_message, jwt_pub_key, jwt_sig)
assert result

def test_okp_ed25519_prepare_key_should_be_idempotent(self):
@pytest.mark.parametrize(
"public_key_file", ("testkey_ed25519.pub", "testkey_ed25519.pub.pem")
)
def test_okp_ed25519_prepare_key_should_be_idempotent(self, public_key_file):
algo = OKPAlgorithm()

with open(key_path("testkey_ed25519.pub")) as keyfile:
with open(key_path(public_key_file)) as keyfile:
jwt_pub_key_first = algo.prepare_key(keyfile.read())
jwt_pub_key_second = algo.prepare_key(jwt_pub_key_first)

assert jwt_pub_key_first == jwt_pub_key_second

def test_okp_ed25519_prepare_key_should_reject_invalid_key(self):
algo = OKPAlgorithm()

with pytest.raises(InvalidKeyError):
algo.prepare_key("not a valid key")

def test_okp_ed25519_jwk_private_key_should_parse_and_verify(self):
algo = OKPAlgorithm()

Expand Down
Loading