Skip to content

Commit f2df5b6

Browse files
authored
Encoding EC keys with a fixed bit length (#990)
1 parent 4ceee5a commit f2df5b6

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ Changed
3232
Fixed
3333
~~~~~
3434

35+
- Encode EC keys with a fixed bit length by @etianen in `#990 <https://github.com/jpadilla/pyjwt/pull/990>`__
36+
3537
Added
3638
~~~~~
3739

jwt/algorithms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,13 +583,20 @@ def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
583583
obj: dict[str, Any] = {
584584
"kty": "EC",
585585
"crv": crv,
586-
"x": to_base64url_uint(public_numbers.x).decode(),
587-
"y": to_base64url_uint(public_numbers.y).decode(),
586+
"x": to_base64url_uint(
587+
public_numbers.x,
588+
bit_length=key_obj.curve.key_size,
589+
).decode(),
590+
"y": to_base64url_uint(
591+
public_numbers.y,
592+
bit_length=key_obj.curve.key_size,
593+
).decode(),
588594
}
589595

590596
if isinstance(key_obj, EllipticCurvePrivateKey):
591597
obj["d"] = to_base64url_uint(
592-
key_obj.private_numbers().private_value
598+
key_obj.private_numbers().private_value,
599+
bit_length=key_obj.curve.key_size,
593600
).decode()
594601

595602
if as_dict:

jwt/utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
import binascii
33
import re
4-
from typing import Union
4+
from typing import Optional, Union
55

66
try:
77
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
@@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes:
3737
return base64.urlsafe_b64encode(input).replace(b"=", b"")
3838

3939

40-
def to_base64url_uint(val: int) -> bytes:
40+
def to_base64url_uint(val: int, *, bit_length: Optional[int] = None) -> bytes:
4141
if val < 0:
4242
raise ValueError("Must be a positive integer")
4343

44-
int_bytes = bytes_from_int(val)
44+
int_bytes = bytes_from_int(val, bit_length=bit_length)
4545

4646
if len(int_bytes) == 0:
4747
int_bytes = b"\x00"
@@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int:
6363
return int(binascii.b2a_hex(string), 16)
6464

6565

66-
def bytes_from_int(val: int) -> bytes:
67-
remaining = val
68-
byte_length = 0
69-
70-
while remaining != 0:
71-
remaining >>= 8
72-
byte_length += 1
66+
def bytes_from_int(val: int, *, bit_length: Optional[int] = None) -> bytes:
67+
if bit_length is None:
68+
bit_length = val.bit_length()
69+
byte_length = (bit_length + 7) // 8
7370

7471
return val.to_bytes(byte_length, "big", signed=False)
7572

0 commit comments

Comments
 (0)