Skip to content

Commit 31beac9

Browse files
authored
Merge pull request #5234 from mozilla/test-verify-from-sns-mpp-3852
MPP-3852: Test `verify_from_sns`
2 parents b7395db + bed9d00 commit 31beac9

File tree

2 files changed

+225
-10
lines changed

2 files changed

+225
-10
lines changed

emails/sns.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,16 @@
6464

6565

6666
def verify_from_sns(json_body):
67+
"""
68+
Check that the SNS message was signed by the cetificate.
69+
70+
https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
71+
72+
Only supports SignatureVersion 1. SignatureVersion 2 (SHA256) was added in
73+
September 2022, and requires opt-in.
74+
75+
TODO MPP-3852: Stop using OpenSSL.crypto
76+
"""
6777
pemfile = _grab_keyfile(json_body["SigningCertURL"])
6878
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pemfile)
6979
signature = base64.decodebytes(json_body["Signature"].encode("utf-8"))

emails/tests/sns_tests.py

Lines changed: 215 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,220 @@
1-
from unittest.mock import patch
1+
# ruff: noqa: S303 # Use of insecure SHA1 hash function
22

3+
from base64 import b64encode
4+
from collections.abc import Iterator
5+
from datetime import UTC, datetime, timedelta
6+
from io import BytesIO
7+
from unittest.mock import Mock, patch
8+
9+
from django.core.cache import BaseCache, caches
310
from django.core.exceptions import SuspiciousOperation
4-
from django.test import TestCase
511

6-
from ..sns import _grab_keyfile
12+
import pytest
13+
from cryptography import x509
14+
from cryptography.hazmat.primitives import hashes, serialization
15+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
16+
from cryptography.x509.oid import NameOID
17+
from OpenSSL.crypto import Error
18+
from pytest_django.fixtures import SettingsWrapper
19+
20+
from ..sns import (
21+
NOTIFICATION_HASH_FORMAT,
22+
NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT,
23+
SUBSCRIPTION_HASH_FORMAT,
24+
_grab_keyfile,
25+
verify_from_sns,
26+
)
27+
28+
29+
@pytest.fixture(autouse=True)
30+
def key_cache(settings: SettingsWrapper) -> Iterator[BaseCache]:
31+
"""
32+
Return the cache used for signing certificates.
33+
34+
Clear the cache before and after tests.
35+
"""
36+
key_cache = caches[getattr(settings, "AWS_SNS_KEY_CACHE", "default")]
37+
key_cache.clear()
38+
yield key_cache
39+
key_cache.clear()
40+
41+
42+
@pytest.fixture
43+
def key_and_cert() -> tuple[rsa.RSAPrivateKey, x509.Certificate]:
44+
"""Generate an RSA key and a signing certificate"""
45+
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
46+
name_attributes = [
47+
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
48+
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Oklahoma"),
49+
x509.NameAttribute(NameOID.LOCALITY_NAME, "Tulsa"),
50+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Firefox Private Relay Test"),
51+
x509.NameAttribute(NameOID.COMMON_NAME, "github.com/mozilla/fx-private-relay/"),
52+
]
53+
subject = issuer = x509.Name(name_attributes)
54+
cert = (
55+
x509.CertificateBuilder()
56+
.subject_name(subject)
57+
.issuer_name(issuer)
58+
.public_key(key.public_key())
59+
.serial_number(x509.random_serial_number())
60+
.not_valid_before(datetime.now(UTC) - timedelta(seconds=1))
61+
.not_valid_after(datetime.now(UTC) + timedelta(seconds=60))
62+
.sign(key, hashes.SHA256())
63+
)
64+
return key, cert
65+
66+
67+
@pytest.fixture
68+
def signing_cert_url_and_private_key(
69+
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
70+
key_cache: BaseCache,
71+
settings: SettingsWrapper,
72+
) -> tuple[str, rsa.RSAPrivateKey]:
73+
"""Return the URL and private key for a cached signing certificate."""
74+
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
75+
key, cert = key_and_cert
76+
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
77+
key_cache.set(cert_url, cert_pem)
78+
return cert_url, key
79+
80+
81+
@pytest.fixture
82+
def mock_urlopen() -> Iterator[Mock]:
83+
with patch("emails.sns.urlopen") as mock_urlopen:
84+
yield mock_urlopen
85+
86+
87+
def test_grab_keyfile_checks_cert_url_origin(mock_urlopen: Mock) -> None:
88+
cert_url = "https://attacker.com/cert.pem"
89+
with pytest.raises(SuspiciousOperation):
90+
_grab_keyfile(cert_url)
91+
mock_urlopen.assert_not_called()
92+
93+
94+
def test_grab_keyfile_downloads_valid_certificate(
95+
mock_urlopen: Mock,
96+
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
97+
key_cache: BaseCache,
98+
settings: SettingsWrapper,
99+
) -> None:
100+
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
101+
key, cert = key_and_cert
102+
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
103+
mock_urlopen.return_value = BytesIO(cert_pem)
104+
ret_value = _grab_keyfile(cert_url)
105+
mock_urlopen.assert_called_once_with(cert_url)
106+
assert ret_value == cert_pem
107+
assert key_cache.get(cert_url) == cert_pem
108+
109+
110+
def test_grab_keyfile_reads_from_cache(
111+
mock_urlopen: Mock,
112+
key_cache: BaseCache,
113+
settings: SettingsWrapper,
114+
) -> None:
115+
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
116+
fake_pem = b"I am fake"
117+
key_cache.set(cert_url, fake_pem)
118+
ret_value = _grab_keyfile(cert_url)
119+
assert ret_value == fake_pem
120+
mock_urlopen.assert_not_called()
121+
122+
123+
def test_grab_keyfile_cert_chain_fails(
124+
mock_urlopen: Mock,
125+
key_and_cert: tuple[rsa.RSAPrivateKey, x509.Certificate],
126+
key_cache: BaseCache,
127+
settings: SettingsWrapper,
128+
) -> None:
129+
cert_url = f"https://sns.{settings.AWS_REGION}.amazonaws.com/cert.pem"
130+
key, cert = key_and_cert
131+
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
132+
two_cert_pem = b"\n".join((cert_pem, cert_pem))
133+
mock_urlopen.return_value = BytesIO(two_cert_pem)
134+
with pytest.raises(ValueError, match="Invalid Certificate File"):
135+
_grab_keyfile(cert_url)
136+
137+
138+
def test_verify_from_sns_notification_with_subject_ver1(
139+
signing_cert_url_and_private_key: tuple[str, rsa.RSAPrivateKey],
140+
) -> None:
141+
cert_url, key = signing_cert_url_and_private_key
142+
json_body = {
143+
"Type": "Notification",
144+
"Message": "message",
145+
"MessageId": "message_id",
146+
"Subject": "subject",
147+
"Timestamp": "timestamp",
148+
"TopicArn": "topic_arn",
149+
"SigningCertURL": cert_url,
150+
"SignatureVersion": 1,
151+
}
152+
text_to_sign = NOTIFICATION_HASH_FORMAT.format(**json_body)
153+
signature = key.sign(text_to_sign.encode(), padding.PKCS1v15(), hashes.SHA1())
154+
json_body["Signature"] = b64encode(signature).decode()
155+
ret = verify_from_sns(json_body)
156+
assert ret == json_body
157+
158+
159+
def test_verify_from_sns_notification_with_subject_ver1_fails(
160+
signing_cert_url_and_private_key: tuple[str, rsa.RSAPrivateKey],
161+
) -> None:
162+
cert_url, key = signing_cert_url_and_private_key
163+
json_body = {
164+
"Type": "Notification",
165+
"Message": "message",
166+
"MessageId": "message_id",
167+
"Subject": "subject",
168+
"Timestamp": "timestamp",
169+
"TopicArn": "topic_arn",
170+
"SigningCertURL": cert_url,
171+
"SignatureVersion": 1,
172+
}
173+
text_to_sign = NOTIFICATION_HASH_FORMAT.format(**json_body)
174+
signature = key.sign(text_to_sign.encode(), padding.PKCS1v15(), hashes.SHA1())
175+
json_body["Signature"] = b64encode(signature).decode()
176+
json_body["Message"] = "different message"
177+
with pytest.raises(Error):
178+
verify_from_sns(json_body)
179+
180+
181+
def test_verify_from_sns_notification_without_subject_ver1(
182+
signing_cert_url_and_private_key: tuple[str, rsa.RSAPrivateKey],
183+
) -> None:
184+
cert_url, key = signing_cert_url_and_private_key
185+
json_body = {
186+
"Type": "Notification",
187+
"Message": "message",
188+
"MessageId": "message_id",
189+
"Timestamp": "timestamp",
190+
"TopicArn": "topic_arn",
191+
"SigningCertURL": cert_url,
192+
"SignatureVersion": 1,
193+
}
194+
text_to_sign = NOTIFICATION_WITHOUT_SUBJECT_HASH_FORMAT.format(**json_body)
195+
signature = key.sign(text_to_sign.encode(), padding.PKCS1v15(), hashes.SHA1())
196+
json_body["Signature"] = b64encode(signature).decode()
197+
ret = verify_from_sns(json_body)
198+
assert ret == json_body
7199

8200

9-
class GrabKeyfileTest(TestCase):
10-
@patch("emails.sns.urlopen")
11-
def test_grab_keyfile_checks_cert_url_origin(self, mock_urlopen):
12-
cert_url = "https://attacker.com/cert.pem"
13-
with self.assertRaises(SuspiciousOperation):
14-
_grab_keyfile(cert_url)
15-
mock_urlopen.assert_not_called()
201+
def test_verify_from_sns_subscription_ver1(
202+
signing_cert_url_and_private_key: tuple[str, rsa.RSAPrivateKey],
203+
) -> None:
204+
cert_url, key = signing_cert_url_and_private_key
205+
json_body = {
206+
"Type": "Subscription",
207+
"Message": "message",
208+
"MessageId": "message_id",
209+
"SubscribeURL": "subscribe_url",
210+
"Timestamp": "timestamp",
211+
"Token": "token",
212+
"TopicArn": "topic_arn",
213+
"SigningCertURL": cert_url,
214+
"SignatureVersion": 1,
215+
}
216+
text_to_sign = SUBSCRIPTION_HASH_FORMAT.format(**json_body)
217+
signature = key.sign(text_to_sign.encode(), padding.PKCS1v15(), hashes.SHA1())
218+
json_body["Signature"] = b64encode(signature).decode()
219+
ret = verify_from_sns(json_body)
220+
assert ret == json_body

0 commit comments

Comments
 (0)