Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ updates:
patterns:
- "django-stubs"
- "djangorestframework-stubs"
- "types-pyOpenSSL"
- "types-requests"
- package-ecosystem: "github-actions"
directory: "/"
Expand Down
5 changes: 2 additions & 3 deletions emails/management/commands/process_emails_from_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from django.http import HttpResponse

import boto3
import OpenSSL
from botocore.exceptions import ClientError
from codetiming import Timer
from markus.utils import generate_tag
Expand All @@ -37,7 +36,7 @@
CommandFromDjangoSettings,
SettingToLocal,
)
from emails.sns import verify_from_sns
from emails.sns import VerificationFailed, verify_from_sns
from emails.utils import gauge_if_enabled, incr_if_enabled
from emails.views import _sns_inbound_logic, validate_sns_arn_and_type

Expand Down Expand Up @@ -395,7 +394,7 @@ def process_message(self, message: SQSMessage) -> dict[str, Any]:
return results
try:
verified_json_body = verify_from_sns(json_body)
except (KeyError, OpenSSL.crypto.Error) as e:
except (KeyError, VerificationFailed) as e:
logger.error("Failed SNS verification", extra={"error": str(e)})
results["success"] = False
results["error"] = f"Failed SNS verification: {e}"
Expand Down
61 changes: 42 additions & 19 deletions emails/sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
from django.conf import settings
from django.core.cache import caches
from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import smart_bytes

import pem
from OpenSSL import crypto
from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

logger = logging.getLogger("events")

NOTIFICATION_HASH_FORMAT = """Message
NOTIFICATION_HASH_FORMAT = """\
Message
Comment on lines +20 to +21
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: nice little clean-up.

{Message}
MessageId
{MessageId}
Expand All @@ -29,7 +31,8 @@
{Type}
"""

NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT = """Message
NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT = """\
Message
{Message}
MessageId
{MessageId}
Expand All @@ -41,7 +44,8 @@
{Type}
"""

SUBSCRIPTION_HASH_FORMAT = """Message
SUBSCRIPTION_HASH_FORMAT = """\
Message
{Message}
MessageId
{MessageId}
Expand All @@ -63,6 +67,10 @@
]


class VerificationFailed(ValueError):
pass


def verify_from_sns(json_body):
"""
Check that the SNS message was signed by the cetificate.
Expand All @@ -71,18 +79,29 @@ def verify_from_sns(json_body):

Only supports SignatureVersion 1. SignatureVersion 2 (SHA256) was added in
September 2022, and requires opt-in.

TODO MPP-3852: Stop using OpenSSL.crypto
"""
pemfile = _grab_keyfile(json_body["SigningCertURL"])
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pemfile)
signature = base64.decodebytes(json_body["Signature"].encode("utf-8"))
signing_cert_url = json_body["SigningCertURL"]
pemfile = _grab_keyfile(signing_cert_url)
cert = x509.load_pem_x509_certificate(pemfile)
signature = base64.decodebytes(json_body["Signature"].encode())

hash_format = _get_hash_format(json_body)
cert_pubkey = cert.public_key()
if not isinstance(cert_pubkey, rsa.RSAPublicKey):
raise VerificationFailed(f"SigningCertURL {signing_cert_url} is not an RSA key")

try:
cert_pubkey.verify(
signature,
hash_format.format(**json_body).encode(),
padding.PKCS1v15(),
hashes.SHA1(), # noqa: S303 # Use of insecure hash SHA1
)
except InvalidSignature as e:
raise VerificationFailed(
f"Invalid signature with SigningCertURL {signing_cert_url}"
) from e

crypto.verify(
cert, signature, hash_format.format(**json_body).encode("utf-8"), "sha1"
)
return json_body


Expand All @@ -109,14 +128,18 @@ def _grab_keyfile(cert_url):
if not pemfile:
response = urlopen(cert_url) # noqa: S310 (check for custom scheme)
pemfile = response.read()

# Extract the first certificate in the file and confirm it's a valid
# PEM certificate
certificates = pem.parse(smart_bytes(pemfile))

certs = x509.load_pem_x509_certificates(pemfile)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quibble (non-blocking): this is 2x calls to x509.load_pem_x509_certificate[s] when we could maybe change the code to only call load_pem_x509_certificate 1x and return the single valid cert that it loaded?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took it one step futher and changed _grab_keyfile, returning PEM bytes, to _get_signing_public_key, returning an RSAPublicKey. This way the certificate is loaded and validated the first time it is read, not every time it is fetched from the cache.

# A proper certificate file will contain 1 certificate
if len(certificates) != 1:
logger.error("Invalid Certificate File: URL %s", cert_url)
raise ValueError("Invalid Certificate File")
if len(certs) != 1:
raise VerificationFailed(
f"SigningCertURL {cert_url} has {len(certs)} certificates."
)
cert_pubkey = certs[0].public_key()
if not isinstance(cert_pubkey, rsa.RSAPublicKey):
raise VerificationFailed(f"SigningCertURL {cert_url} is not an RSA key")

key_cache.set(cert_url, pemfile)
return pemfile
5 changes: 2 additions & 3 deletions emails/tests/mgmt_process_emails_from_sqs_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from django.core.management.base import CommandError
from django.http import HttpResponse

import OpenSSL
import pytest
from botocore.exceptions import ClientError
from markus.testing import MetricsMock
from pytest import LogCaptureFixture
from pytest_django.fixtures import SettingsWrapper

from emails.sns import VerificationFailed
from emails.tests.views_tests import EMAIL_SNS_BODIES
from privaterelay.tests.utils import log_extra, omit_markus_logs

Expand Down Expand Up @@ -168,7 +168,6 @@ def mock_apply_async(
callback: Callable[[Any], None] | None = None,
error_callback: Callable[[BaseException], None] | None = None,
) -> Mock:

def call_wait(timeout: float) -> None:
mock_future._timeouts.append(timeout)
if not mock_future._is_stalled():
Expand Down Expand Up @@ -520,7 +519,7 @@ def test_verify_from_sns_raises_openssl_error(
mock_verify_from_sns: Mock, mock_sqs_client: Mock, caplog: LogCaptureFixture
) -> None:
"""If verify_from_sns raises an exception, the message is deleted."""
mock_verify_from_sns.side_effect = OpenSSL.crypto.Error("failed")
mock_verify_from_sns.side_effect = VerificationFailed("failed")
msg = fake_sqs_message(json.dumps(TEST_SNS_MESSAGE))
mock_sqs_client.return_value = fake_queue([msg], [])
call_command(COMMAND_NAME)
Expand Down
7 changes: 4 additions & 3 deletions emails/tests/sns_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.x509.oid import NameOID
from OpenSSL.crypto import Error
from pytest_django.fixtures import SettingsWrapper

from ..sns import (
NOTIFICATION_HASH_FORMAT,
NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT,
SUBSCRIPTION_HASH_FORMAT,
VerificationFailed,
_grab_keyfile,
verify_from_sns,
)
Expand Down Expand Up @@ -131,7 +131,8 @@ def test_grab_keyfile_cert_chain_fails(
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
two_cert_pem = b"\n".join((cert_pem, cert_pem))
mock_urlopen.return_value = BytesIO(two_cert_pem)
with pytest.raises(ValueError, match="Invalid Certificate File"):
expected = f"SigningCertURL {cert_url} has 2 certificates."
with pytest.raises(VerificationFailed, match=expected):
_grab_keyfile(cert_url)


Expand Down Expand Up @@ -174,7 +175,7 @@ def test_verify_from_sns_notification_with_subject_ver1_fails(
signature = key.sign(text_to_sign.encode(), padding.PKCS1v15(), hashes.SHA1())
json_body["Signature"] = b64encode(signature).decode()
json_body["Message"] = "different message"
with pytest.raises(Error):
with pytest.raises(VerificationFailed):
verify_from_sns(json_body)


Expand Down
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ google-cloud-profiler==4.1.0; python_version < '3.13'
gunicorn==23.0.0
jwcrypto==1.5.6
markus[datadog]==5.1.0
pem==23.1.0
psycopg[c]==3.2.3
PyJWT==2.10.1
python-decouple==3.8
pyOpenSSL==24.2.1
requests==2.32.3
sentry-sdk==2.19.0
whitenoise==6.8.2
Expand Down Expand Up @@ -58,5 +56,4 @@ mypy-boto3-ses==1.35.68
mypy-boto3-sns==1.35.68
mypy-boto3-sqs==1.35.0
mypy==1.13.0
types-pyOpenSSL==24.1.0.20240722
types-requests==2.32.0.20241016
Loading