Skip to content

Commit 76020d3

Browse files
committed
Parse and validate extensions
1 parent 1a0ffa0 commit 76020d3

File tree

7 files changed

+212
-22
lines changed

7 files changed

+212
-22
lines changed

libs/wire-api/src/Wire/API/MLS/CipherSuite.hs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
module Wire.API.MLS.CipherSuite where
2121

22-
import Data.Binary
22+
import Data.Word
23+
import Wire.API.MLS.Serialisation
2324

2425
newtype CipherSuite = CipherSuite {cipherSuiteNumber :: Word16}
25-
deriving newtype (Binary)
26+
deriving newtype (ParseMLS)

libs/wire-api/src/Wire/API/MLS/Credential.hs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,13 @@ data Credential = BasicCredential
3434
bcSignatureScheme :: SignatureScheme,
3535
bcSignatureKey :: ByteString
3636
}
37-
deriving (Generic)
3837

39-
instance Binary Credential
38+
instance ParseMLS Credential where
39+
parseMLS =
40+
BasicCredential
41+
<$> parseMLSBytes @Word16
42+
<*> parseMLS
43+
<*> parseMLSBytes @Word16
4044

4145
data CredentialType = BasicCredentialType
4246

@@ -47,7 +51,7 @@ credentialType (BasicCredential _ _ _) = BasicCredentialType
4751
--
4852
-- See <https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-signaturescheme>.
4953
newtype SignatureScheme = SignatureScheme {signatureSchemeNumber :: Word16}
50-
deriving newtype (Binary)
54+
deriving newtype (ParseMLS)
5155

5256
data ClientIdentity = ClientIdentity
5357
{ ciDomain :: Domain,

libs/wire-api/src/Wire/API/MLS/KeyPackage.hs

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
2+
{-# LANGUAGE StandaloneKindSignatures #-}
3+
14
-- This file is part of the Wire Server implementation.
25
--
36
-- Copyright (C) 2022 Wire Swiss GmbH <[email protected]>
@@ -17,14 +20,19 @@
1720

1821
module Wire.API.MLS.KeyPackage where
1922

23+
import Control.Applicative
24+
import Control.Error.Util
2025
import Data.Aeson (FromJSON)
2126
import Data.Binary
2227
import Data.Json.Util
2328
import Data.Schema
29+
import Data.Singletons
30+
import Data.Singletons.TH
2431
import qualified Data.Swagger as S
2532
import Imports
2633
import Wire.API.MLS.CipherSuite
2734
import Wire.API.MLS.Credential
35+
import Wire.API.MLS.Proposal
2836
import Wire.API.MLS.Serialisation
2937

3038
data KeyPackageUpload = KeyPackageUpload
@@ -47,16 +55,77 @@ instance ToSchema KeyPackageData where
4755
--------------------------------------------------------------------------------
4856

4957
data ProtocolVersion = ProtocolReserved | ProtocolMLS
50-
deriving stock (Enum)
51-
deriving (Binary) via EnumBinary Word8 ProtocolVersion
58+
deriving stock (Bounded, Enum)
59+
deriving (ParseMLS) via EnumMLS Word8 ProtocolVersion
5260

5361
data Extension = Extension
5462
{ extType :: Word16,
5563
extData :: ByteString
5664
}
57-
deriving (Generic)
5865

59-
instance Binary Extension
66+
instance ParseMLS Extension where
67+
parseMLS = Extension <$> parseMLS <*> parseMLSBytes @Word32
68+
69+
data ExtensionTag
70+
= ReservedExtensionTag
71+
| CapabilitiesExtensionTag
72+
| LifetimeExtensionTag
73+
deriving (Bounded, Enum)
74+
75+
$(genSingletons [''ExtensionTag])
76+
77+
type family ExtensionType (t :: ExtensionTag) :: * where
78+
ExtensionType 'ReservedExtensionTag = ()
79+
ExtensionType 'CapabilitiesExtensionTag = Capabilities
80+
ExtensionType 'LifetimeExtensionTag = Lifetime
81+
82+
parseExtension :: Sing t -> Get (ExtensionType t)
83+
parseExtension SReservedExtensionTag = pure ()
84+
parseExtension SCapabilitiesExtensionTag = parseMLS
85+
parseExtension SLifetimeExtensionTag = parseMLS
86+
87+
data SomeExtension where
88+
SomeExtension :: Sing t -> ExtensionType t -> SomeExtension
89+
90+
decodeExtension :: Extension -> Maybe SomeExtension
91+
decodeExtension e = do
92+
t <- safeToEnum (fromIntegral (extType e))
93+
hush $
94+
withSomeSing t $ \st ->
95+
decodeMLSWith' (SomeExtension st <$> parseExtension st) (extData e)
96+
97+
-- t <- parse
98+
-- parseMLS = do
99+
-- t <- parseMLS
100+
-- case toSing t of
101+
-- SomeSing st -> SomeExtension st <$> parseExtension st
102+
103+
data Capabilities = Capabilities
104+
{ capVersions :: [ProtocolVersion],
105+
capCiphersuites :: [CipherSuite],
106+
capExtensions :: [Word16],
107+
capProposals :: [ProposalType]
108+
}
109+
110+
instance ParseMLS Capabilities where
111+
parseMLS =
112+
Capabilities
113+
<$> parseMLSVector @Word8 parseMLS
114+
<*> parseMLSVector @Word8 parseMLS
115+
<*> parseMLSVector @Word8 parseMLS
116+
<*> parseMLSVector @Word8 parseMLS
117+
118+
-- | Seconds since the UNIX epoch.
119+
newtype Timestamp = Timestamp {timestampSeconds :: Word64}
120+
deriving newtype (ParseMLS)
121+
122+
data Lifetime = Lifetime
123+
{ ltNotBefore :: Timestamp,
124+
ltNotAfter :: Timestamp
125+
}
126+
127+
instance ParseMLS Lifetime where
128+
parseMLS = Lifetime <$> parseMLS <*> parseMLS
60129

61130
data KeyPackageTBS = KeyPackageTBS
62131
{ kpProtocolVersion :: ProtocolVersion,
@@ -65,15 +134,20 @@ data KeyPackageTBS = KeyPackageTBS
65134
kpCredential :: Credential,
66135
kpExtensions :: [Extension]
67136
}
68-
deriving (Generic)
69137

70-
instance Binary KeyPackageTBS
138+
instance ParseMLS KeyPackageTBS where
139+
parseMLS =
140+
KeyPackageTBS
141+
<$> parseMLS
142+
<*> parseMLS
143+
<*> parseMLSBytes @Word16
144+
<*> parseMLS
145+
<*> parseMLSVector @Word32 parseMLS
71146

72147
data KeyPackage = KeyPackage
73148
{ kpTBS :: KeyPackageTBS,
74149
kpSignature :: ByteString
75150
}
76-
deriving (Generic)
77-
deriving (ParseMLS) via BinaryMLS KeyPackage
78151

79-
instance Binary KeyPackage
152+
instance ParseMLS KeyPackage where
153+
parseMLS = KeyPackage <$> parseMLS <*> parseMLSBytes @Word16
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
-- This file is part of the Wire Server implementation.
2+
--
3+
-- Copyright (C) 2022 Wire Swiss GmbH <[email protected]>
4+
--
5+
-- This program is free software: you can redistribute it and/or modify it under
6+
-- the terms of the GNU Affero General Public License as published by the Free
7+
-- Software Foundation, either version 3 of the License, or (at your option) any
8+
-- later version.
9+
--
10+
-- This program is distributed in the hope that it will be useful, but WITHOUT
11+
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12+
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
13+
-- details.
14+
--
15+
-- You should have received a copy of the GNU Affero General Public License along
16+
-- with this program. If not, see <https://www.gnu.org/licenses/>.
17+
18+
module Wire.API.MLS.Proposal where
19+
20+
import Data.Binary
21+
import Imports
22+
import Wire.API.MLS.Serialisation
23+
24+
data ProposalType
25+
= AddProposal
26+
| UpdateProposal
27+
| RemoveProposal
28+
| PreSharedKeyProposal
29+
| ReInitProposal
30+
| ExternalInitProposal
31+
| AppAckProposal
32+
| GroupContextExtensionsProposal
33+
| ExternalProposal
34+
deriving stock (Bounded, Enum)
35+
deriving (ParseMLS) via (EnumMLS Word16 ProposalType)

libs/wire-api/src/Wire/API/MLS/Serialisation.hs

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,19 @@
1717

1818
module Wire.API.MLS.Serialisation
1919
( ParseMLS (..),
20+
parseMLSVector,
21+
parseMLSBytes,
2022
BinaryMLS (..),
21-
EnumBinary (..),
23+
EnumMLS (..),
24+
safeToEnum,
2225
decodeMLS,
2326
decodeMLS',
27+
decodeMLSWith,
28+
decodeMLSWith',
2429
)
2530
where
2631

32+
import Control.Applicative
2733
import Data.Binary
2834
import Data.Binary.Get
2935
import qualified Data.ByteString.Lazy as LBS
@@ -34,8 +40,30 @@ import Imports
3440
class ParseMLS a where
3541
parseMLS :: Get a
3642

37-
instance ParseMLS ByteString where
38-
parseMLS = get
43+
parseMLSVector :: forall w a. (Binary w, Integral w) => Get a -> Get [a]
44+
parseMLSVector getItem = do
45+
len <- get @w
46+
pos <- bytesRead
47+
isolate (fromIntegral len) $ go (pos + fromIntegral len)
48+
where
49+
go :: Int64 -> Get [a]
50+
go endPos = do
51+
x <- getItem
52+
pos <- bytesRead
53+
(:) <$> pure x <*> if pos < endPos then go endPos else pure []
54+
55+
parseMLSBytes :: forall w. (Binary w, Integral w) => Get ByteString
56+
parseMLSBytes = do
57+
len <- fromIntegral <$> get @w
58+
getByteString len
59+
60+
instance ParseMLS Word8 where parseMLS = get
61+
62+
instance ParseMLS Word16 where parseMLS = get
63+
64+
instance ParseMLS Word32 where parseMLS = get
65+
66+
instance ParseMLS Word64 where parseMLS = get
3967

4068
-- | A wrapper to generate a 'ParseMLS' instance given a 'Binary' instance.
4169
newtype BinaryMLS a = BinaryMLS a
@@ -44,11 +72,15 @@ instance Binary a => ParseMLS (BinaryMLS a) where
4472
parseMLS = BinaryMLS <$> get
4573

4674
-- | A wrapper to generate a 'Binary' instance for an enumerated type.
47-
newtype EnumBinary w a = EnumBinary {unEnumBinary :: a}
75+
newtype EnumMLS w a = EnumMLS {unEnumMLS :: a}
76+
77+
safeToEnum :: forall a f. (Bounded a, Enum a, Alternative f) => Int -> f a
78+
safeToEnum n = guard (n >= fromEnum @a minBound && n <= fromEnum @a maxBound) $> toEnum n
4879

49-
instance (Binary w, Integral w, Enum a) => Binary (EnumBinary w a) where
50-
get = EnumBinary . toEnum . fromIntegral <$> get @w
51-
put = put @w . fromIntegral . fromEnum . unEnumBinary
80+
instance (Binary w, Integral w, Bounded a, Enum a) => ParseMLS (EnumMLS w a) where
81+
parseMLS = do
82+
n <- fromIntegral <$> get @w
83+
EnumMLS <$> safeToEnum n
5284

5385
-- | Decode an MLS value from a lazy bytestring. Return an error message in case of failure.
5486
decodeMLS :: ParseMLS a => LByteString -> Either Text a
@@ -65,3 +97,6 @@ decodeMLSWith p b = case runGetOrFail p b of
6597
Right (remainder, pos, x)
6698
| LBS.null remainder -> Right x
6799
| otherwise -> Left $ "Trailing data at position " <> T.pack (show pos)
100+
101+
decodeMLSWith' :: Get a -> ByteString -> Either Text a
102+
decodeMLSWith' p = decodeMLSWith p . LBS.fromStrict

libs/wire-api/wire-api.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ library
3939
Wire.API.MLS.CipherSuite
4040
Wire.API.MLS.Credential
4141
Wire.API.MLS.KeyPackage
42+
Wire.API.MLS.Proposal
4243
Wire.API.MLS.Serialisation
4344
Wire.API.Notification
4445
Wire.API.Properties

services/brig/src/Brig/API/MLS/KeyPackages.hs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@
1515
-- You should have received a copy of the GNU Affero General Public License along
1616
-- with this program. If not, see <https://www.gnu.org/licenses/>.
1717

18-
module Brig.API.MLS.KeyPackages where
18+
module Brig.API.MLS.KeyPackages
19+
( uploadKeyPackages,
20+
)
21+
where
1922

2023
import Brig.API.Error
2124
import Brig.API.Handler
2225
import qualified Brig.Data.MLS.KeyPackage as Data
26+
import Control.Applicative
2327
import Control.Monad.Trans.Except
2428
import Data.Id
2529
import Data.Json.Util
@@ -45,6 +49,7 @@ parseKeyPackage (fromBase64ByteString . kpData -> kpd) =
4549
validateKeyPackage :: ClientIdentity -> KeyPackage -> Handler ()
4650
validateKeyPackage identity (kpTBS -> kp) = do
4751
validateCredential identity (kpCredential kp)
52+
validateExtensions (kpExtensions kp)
4853

4954
validateCredential :: ClientIdentity -> Credential -> Handler ()
5055
validateCredential identity cred = do
@@ -54,3 +59,38 @@ validateCredential identity cred = do
5459
decodeMLS' (bcIdentity cred)
5560
when (identity /= identity') $
5661
throwE mlsIdentityMismatch
62+
63+
data RequiredExtensions f = RequiredExtensions
64+
{ reLifetime :: f Lifetime,
65+
reCapabilities :: f ()
66+
}
67+
68+
instance Alternative f => Semigroup (RequiredExtensions f) where
69+
RequiredExtensions lt1 cap1 <> RequiredExtensions lt2 cap2 =
70+
RequiredExtensions (lt1 <|> lt2) (cap1 <|> cap2)
71+
72+
instance Alternative f => Monoid (RequiredExtensions f) where
73+
mempty = RequiredExtensions empty empty
74+
75+
checkRequiredExtensions :: Applicative f => RequiredExtensions f -> f (RequiredExtensions Identity)
76+
checkRequiredExtensions re =
77+
RequiredExtensions
78+
<$> (Identity <$> reLifetime re)
79+
<*> (Identity <$> reCapabilities re)
80+
81+
findExtensions :: [Extension] -> Maybe (RequiredExtensions Identity)
82+
findExtensions = checkRequiredExtensions . foldMap findExtension
83+
84+
findExtension :: Extension -> RequiredExtensions Maybe
85+
findExtension e = case decodeExtension e of
86+
Just (SomeExtension SLifetimeExtensionTag lt) -> RequiredExtensions (pure lt) Nothing
87+
Just (SomeExtension SCapabilitiesExtensionTag _) -> RequiredExtensions Nothing (pure ())
88+
_ -> mempty
89+
90+
validateExtensions :: [Extension] -> Handler ()
91+
validateExtensions exts = do
92+
_re <-
93+
maybe (throwE (mlsProtocolError "Missing required extensions")) pure $
94+
findExtensions exts
95+
-- TODO: validate lifetime
96+
pure ()

0 commit comments

Comments
 (0)