Skip to content

MPP-3852: Use cryptography for SNS signature validation, remove pyopenssl and pem #5235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/dependabot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ updates:
patterns:
- "django-stubs"
- "djangorestframework-stubs"
- "types-pyOpenSSL"
- "types-requests"
- package-ecosystem: "github-actions"
directory: "/"
Expand Down
5 changes: 2 additions & 3 deletions emails/management/commands/process_emails_from_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from django.http import HttpResponse

import boto3
import OpenSSL
from botocore.exceptions import ClientError
from codetiming import Timer
from markus.utils import generate_tag
Expand All @@ -37,7 +36,7 @@
CommandFromDjangoSettings,
SettingToLocal,
)
from emails.sns import verify_from_sns
from emails.sns import VerificationFailed, verify_from_sns
from emails.utils import gauge_if_enabled, incr_if_enabled
from emails.views import _sns_inbound_logic, validate_sns_arn_and_type

Expand Down Expand Up @@ -395,7 +394,7 @@ def process_message(self, message: SQSMessage) -> dict[str, Any]:
return results
try:
verified_json_body = verify_from_sns(json_body)
except (KeyError, OpenSSL.crypto.Error) as e:
except (KeyError, VerificationFailed) as e:
logger.error("Failed SNS verification", extra={"error": str(e)})
results["success"] = False
results["error"] = f"Failed SNS verification: {e}"
Expand Down
94 changes: 65 additions & 29 deletions emails/sns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@

import base64
import logging
from typing import Any
from urllib.request import urlopen

from django.conf import settings
from django.core.cache import caches
from django.core.exceptions import SuspiciousOperation
from django.utils.encoding import smart_bytes

import pem
from OpenSSL import crypto
from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa

logger = logging.getLogger("events")

NOTIFICATION_HASH_FORMAT = """Message
NOTIFICATION_HASH_FORMAT = """\
Message
Comment on lines +20 to +21
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: nice little clean-up.

{Message}
MessageId
{MessageId}
Expand All @@ -29,7 +32,8 @@
{Type}
"""

NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT = """Message
NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT = """\
Message
{Message}
MessageId
{MessageId}
Expand All @@ -41,7 +45,8 @@
{Type}
"""

SUBSCRIPTION_HASH_FORMAT = """Message
SUBSCRIPTION_HASH_FORMAT = """\
Message
{Message}
MessageId
{MessageId}
Expand All @@ -63,30 +68,40 @@
]


def verify_from_sns(json_body):
class VerificationFailed(ValueError):
pass


def verify_from_sns(json_body: dict[str, Any]) -> dict[str, Any]:
"""
Check that the SNS message was signed by the cetificate.
Raise an exception if SNS signature verification fails.

https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html

Only supports SignatureVersion 1. SignatureVersion 2 (SHA256) was added in
September 2022, and requires opt-in.

TODO MPP-3852: Stop using OpenSSL.crypto
"""
pemfile = _grab_keyfile(json_body["SigningCertURL"])
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pemfile)
signature = base64.decodebytes(json_body["Signature"].encode("utf-8"))

signing_cert_url = json_body["SigningCertURL"]
cert_pubkey = _get_signing_public_key(signing_cert_url)
signature = base64.decodebytes(json_body["Signature"].encode())
hash_format = _get_hash_format(json_body)

crypto.verify(
cert, signature, hash_format.format(**json_body).encode("utf-8"), "sha1"
)
try:
cert_pubkey.verify(
signature,
hash_format.format(**json_body).encode(),
padding.PKCS1v15(),
hashes.SHA1(), # noqa: S303 # Use of insecure hash SHA1
)
except InvalidSignature as e:
raise VerificationFailed(
f"Invalid signature with SigningCertURL {signing_cert_url}"
) from e

return json_body


def _get_hash_format(json_body):
def _get_hash_format(json_body: dict[str, Any]) -> str:
message_type = json_body["Type"]
if message_type == "Notification":
if "Subject" in json_body.keys():
Expand All @@ -96,27 +111,48 @@ def _get_hash_format(json_body):
return SUBSCRIPTION_HASH_FORMAT


def _grab_keyfile(cert_url):
def _get_signing_public_key(cert_url: str) -> rsa.RSAPublicKey:
"""
Download the signing certificate and return the public key.

Or, return the cached public key from a previous call.
"""
cert_url_origin = f"https://sns.{settings.AWS_REGION}.amazonaws.com/"
if not (cert_url.startswith(cert_url_origin)):
raise SuspiciousOperation(
f'SNS SigningCertURL "{cert_url}" did not start with "{cert_url_origin}"'
)

key_cache = caches[getattr(settings, "AWS_SNS_KEY_CACHE", "default")]

pemfile = key_cache.get(cert_url)
if not pemfile:
cache_key = f"{cert_url}:public_key"
public_pem = key_cache.get(cache_key)

set_cache = False
if public_pem:
cert_pubkey = serialization.load_pem_public_key(public_pem)
else:
set_cache = True
response = urlopen(cert_url) # noqa: S310 (check for custom scheme)
pemfile = response.read()
cert_pem = response.read()

# Extract the first certificate in the file and confirm it's a valid
# PEM certificate
certificates = pem.parse(smart_bytes(pemfile))
certs = x509.load_pem_x509_certificates(cert_pem)

# A proper certificate file will contain 1 certificate
if len(certificates) != 1:
logger.error("Invalid Certificate File: URL %s", cert_url)
raise ValueError("Invalid Certificate File")
if len(certs) != 1:
raise VerificationFailed(
f"SigningCertURL {cert_url} has {len(certs)} certificates."
)
cert_pubkey = certs[0].public_key()
public_pem = cert_pubkey.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)

if not isinstance(cert_pubkey, rsa.RSAPublicKey):
raise VerificationFailed(f"SigningCertURL {cert_url} is not an RSA key")

key_cache.set(cert_url, pemfile)
return pemfile
if set_cache:
key_cache.set(cache_key, public_pem)
return cert_pubkey
5 changes: 2 additions & 3 deletions emails/tests/mgmt_process_emails_from_sqs_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from django.core.management.base import CommandError
from django.http import HttpResponse

import OpenSSL
import pytest
from botocore.exceptions import ClientError
from markus.testing import MetricsMock
from pytest import LogCaptureFixture
from pytest_django.fixtures import SettingsWrapper

from emails.sns import VerificationFailed
from emails.tests.views_tests import EMAIL_SNS_BODIES
from privaterelay.tests.utils import log_extra, omit_markus_logs

Expand Down Expand Up @@ -168,7 +168,6 @@ def mock_apply_async(
callback: Callable[[Any], None] | None = None,
error_callback: Callable[[BaseException], None] | None = None,
) -> Mock:

def call_wait(timeout: float) -> None:
mock_future._timeouts.append(timeout)
if not mock_future._is_stalled():
Expand Down Expand Up @@ -520,7 +519,7 @@ def test_verify_from_sns_raises_openssl_error(
mock_verify_from_sns: Mock, mock_sqs_client: Mock, caplog: LogCaptureFixture
) -> None:
"""If verify_from_sns raises an exception, the message is deleted."""
mock_verify_from_sns.side_effect = OpenSSL.crypto.Error("failed")
mock_verify_from_sns.side_effect = VerificationFailed("failed")
msg = fake_sqs_message(json.dumps(TEST_SNS_MESSAGE))
mock_sqs_client.return_value = fake_queue([msg], [])
call_command(COMMAND_NAME)
Expand Down
52 changes: 32 additions & 20 deletions emails/tests/sns_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from cryptography.x509.oid import NameOID
from OpenSSL.crypto import Error
from pytest_django.fixtures import SettingsWrapper

from ..sns import (
NOTIFICATION_HASH_FORMAT,
NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT,
SUBSCRIPTION_HASH_FORMAT,
_grab_keyfile,
VerificationFailed,
_get_signing_public_key,
verify_from_sns,
)

Expand Down Expand Up @@ -64,6 +64,17 @@ def key_and_cert() -> tuple[rsa.RSAPrivateKey, x509.Certificate]:
return key, cert


def _cache_key(cert_url: str) -> str:
return f"{cert_url}:public_key"


def _public_pem(cert_or_private_key: rsa.RSAPrivateKey | x509.Certificate) -> bytes:
return cert_or_private_key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)


@pytest.fixture
def signing_cert_url_and_private_key(
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
Expand All @@ -73,8 +84,7 @@ def signing_cert_url_and_private_key(
"""Return the URL and private key for a cached signing certificate."""
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
key, cert = key_and_cert
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
key_cache.set(cert_url, cert_pem)
key_cache.set(_cache_key(cert_url), _public_pem(cert))
return cert_url, key


Expand All @@ -84,43 +94,44 @@ def mock_urlopen() -> Iterator[Mock]:
yield mock_urlopen


def test_grab_keyfile_checks_cert_url_origin(mock_urlopen: Mock) -> None:
def test_get_signing_public_key_suspicious_url(mock_urlopen: Mock) -> None:
cert_url = "https://attacker.com/cert.pem"
with pytest.raises(SuspiciousOperation):
_grab_keyfile(cert_url)
_get_signing_public_key(cert_url)
mock_urlopen.assert_not_called()


def test_grab_keyfile_downloads_valid_certificate(
def test_get_signing_public_key_downloads_valid_certificate(
mock_urlopen: Mock,
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
key_cache: BaseCache,
settings: SettingsWrapper,
) -> None:
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
key, cert = key_and_cert
_, cert = key_and_cert
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
mock_urlopen.return_value = BytesIO(cert_pem)
ret_value = _grab_keyfile(cert_url)
ret_value = _get_signing_public_key(cert_url)
mock_urlopen.assert_called_once_with(cert_url)
assert ret_value == cert_pem
assert key_cache.get(cert_url) == cert_pem
assert ret_value == cert.public_key()
assert key_cache.get(_cache_key(cert_url)) == _public_pem(cert)


def test_grab_keyfile_reads_from_cache(
def test_get_signing_public_key_reads_from_cache(
mock_urlopen: Mock,
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
key_cache: BaseCache,
settings: SettingsWrapper,
) -> None:
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
fake_pem = b"I am fake"
key_cache.set(cert_url, fake_pem)
ret_value = _grab_keyfile(cert_url)
assert ret_value == fake_pem
_, cert = key_and_cert
key_cache.set(_cache_key(cert_url), _public_pem(cert))
ret_value = _get_signing_public_key(cert_url)
assert ret_value == cert.public_key()
mock_urlopen.assert_not_called()


def test_grab_keyfile_cert_chain_fails(
def test_get_signing_public_key_cert_chain_fails(
mock_urlopen: Mock,
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
key_cache: BaseCache,
Expand All @@ -131,8 +142,9 @@ def test_grab_keyfile_cert_chain_fails(
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
two_cert_pem = b"\n".join((cert_pem, cert_pem))
mock_urlopen.return_value = BytesIO(two_cert_pem)
with pytest.raises(ValueError, match="Invalid Certificate File"):
_grab_keyfile(cert_url)
expected = f"SigningCertURL {cert_url} has 2 certificates."
with pytest.raises(VerificationFailed, match=expected):
_get_signing_public_key(cert_url)


def test_verify_from_sns_notification_with_subject_ver1(
Expand Down Expand Up @@ -174,7 +186,7 @@ def test_verify_from_sns_notification_with_subject_ver1_fails(
signature = key.sign(text_to_sign.encode(), padding.PKCS1v15(), hashes.SHA1())
json_body["Signature"] = b64encode(signature).decode()
json_body["Message"] = "different message"
with pytest.raises(Error):
with pytest.raises(VerificationFailed):
verify_from_sns(json_body)


Expand Down
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ google-cloud-profiler==4.1.0; python_version < '3.13'
gunicorn==23.0.0
jwcrypto==1.5.6
markus[datadog]==5.1.0
pem==23.1.0
psycopg[c]==3.2.3
PyJWT==2.10.1
python-decouple==3.8
pyOpenSSL==24.2.1
requests==2.32.3
sentry-sdk==2.19.0
whitenoise==6.8.2
Expand Down Expand Up @@ -58,5 +56,4 @@ mypy-boto3-ses==1.35.68
mypy-boto3-sns==1.35.68
mypy-boto3-sqs==1.35.0
mypy==1.13.0
types-pyOpenSSL==24.1.0.20240722
types-requests==2.32.0.20241016
Loading