{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.Client.TLS13 (
    recvServerSecondFlight13,
    sendClientSecondFlight13,
    asyncServerHello13,
    postHandshakeAuthClientWith,
) where

import Control.Exception (bracket)
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Client.Common
import Network.TLS.Handshake.Client.ServerHello
import Network.TLS.Handshake.Common hiding (expectFinished)
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types
import Network.TLS.X509

----------------------------------------------------------------
----------------------------------------------------------------

-- | Receive and process the server's second flight of a TLS 1.3 handshake.
--
-- First runs 'prepareSecondFlight13' (which reports whether the session is
-- being resumed), then expects, in order:
--
-- 1. EncryptedExtensions,
-- 2. unless resuming, an optional CertificateRequest followed by the
--    server Certificate / CertificateVerify (see 'expectCertRequest'),
-- 3. the server Finished message.
recvServerSecondFlight13 :: ClientParams -> Context -> Maybe Group -> IO ()
recvServerSecondFlight13 cparams ctx groupSent = do
    resuming <- prepareSecondFlight13 ctx groupSent
    runRecvHandshake13 $ do
        recvHandshake13 ctx $ expectEncryptedExtensions ctx
        -- A resumed handshake carries no server certificate messages.
        unless resuming $ recvHandshake13 ctx $ expectCertRequest cparams ctx
        recvHandshake13hash ctx $ expectFinished cparams ctx

----------------------------------------------------------------

-- | Build the 'CipherChoice' for TLS 1.3 from the pending cipher stored in
-- the handshake state and delegate to 'prepareSecondFlight13''.
--
-- Returns the resumption flag produced by 'prepareSecondFlight13''.
prepareSecondFlight13
    :: Context -> Maybe Group -> IO Bool
prepareSecondFlight13 ctx groupSent = do
    choice <- makeCipherChoice TLS13 <$> usingHState ctx getPendingCipher
    prepareSecondFlight13' ctx groupSent choice

-- | Switch the receiving side to the handshake traffic secret and record
-- the negotiated parameters in the TLS 1.3 state.
--
-- Steps, in order:
--
-- * derive the handshake secret from the (EC)DHE shared key and the early
--   secret, and install the server handshake traffic secret as the Rx
--   record state ('switchToHandshakeSecret');
-- * publish the client/server handshake secrets through 'contextSync'
--   with 'RecvServerHello';
-- * store the cipher choice and the handshake key triple in 'TLS13State'.
--
-- The result is 'True' when the server selected our PSK (identity 0),
-- i.e. the handshake is a resumption; 'False' otherwise.
--
-- Throws (via 'throwCore') an 'Error_Protocol' when the server key share
-- is missing, malformed, or for a group other than @groupSent@, or when
-- the selected PSK identity / cipher is not acceptable.
prepareSecondFlight13'
    :: Context
    -> Maybe Group
    -> CipherChoice
    -> IO Bool
prepareSecondFlight13' ctx groupSent choice = do
    (_, hkey, resuming) <- switchToHandshakeSecret
    let clientHandshakeSecret = triClient hkey
        serverHandshakeSecret = triServer hkey
        handSecInfo =
            HandshakeSecretInfo
                usedCipher
                (clientHandshakeSecret, serverHandshakeSecret)
    contextSync ctx $ RecvServerHello handSecInfo
    modifyTLS13State ctx $ \st ->
        st
            { tls13stChoice = choice
            , tls13stHsKey = Just hkey
            }
    return resuming
  where
    usedCipher :: Cipher
    usedCipher = cCipher choice

    usedHash :: Hash
    usedHash = cHash choice

    -- Digest size of the negotiated hash; a usable PSK secret must have
    -- exactly this length (checked in 'makeEarlySecret').
    hashSize :: Int
    hashSize = hashDigestSize usedHash

    -- Derive the handshake secret triple and start decrypting with the
    -- server handshake traffic secret.
    switchToHandshakeSecret :: IO (Cipher, SecretTriple HandshakeSecret, Bool)
    switchToHandshakeSecret = do
        ensureRecvComplete ctx
        ecdhe <- calcSharedKey
        (earlySecret, resuming) <- makeEarlySecret
        handKey <- calculateHandshakeSecret ctx choice earlySecret ecdhe
        let serverHandshakeSecret = triServer handKey
        setRxRecordState ctx usedHash usedCipher serverHandshakeSecret
        return (usedCipher, handKey, resuming)

    -- Compute the (EC)DHE shared key from the server's key_share entry.
    calcSharedKey :: IO ByteString
    calcSharedKey = do
        serverKeyShare <- do
            mks <- usingState_ ctx getTLS13KeyShare
            case mks of
                Just (KeyShareServerHello ks) -> return ks
                Just _ ->
                    throwCore $ Error_Protocol "invalid key_share value" IllegalParameter
                Nothing ->
                    throwCore $
                        Error_Protocol
                            "key exchange not implemented, expected key_share extension"
                            HandshakeFailure
        let grp = keyShareEntryGroup serverKeyShare
        unless (checkKeyShareKeyLength serverKeyShare) $
            throwCore $
                Error_Protocol "broken key_share" IllegalParameter
        -- The server's group must be exactly the one passed in as groupSent.
        unless (groupSent == Just grp) $
            throwCore $
                Error_Protocol "received incompatible group for (EC)DHE" IllegalParameter
        usingHState ctx $ setSupportedGroup grp
        usingHState ctx getGroupPrivate >>= fromServerKeyShare serverKeyShare

    -- Pick the early secret: the PSK-derived one when the server selected
    -- identity 0, otherwise a fresh one from 'initEarlySecret'.  The Bool
    -- tells whether the PSK was accepted.
    makeEarlySecret :: IO (BaseSecret EarlySecret, Bool)
    makeEarlySecret = do
        mEarlySecretPSK <- usingHState ctx getTLS13EarlySecret
        case mEarlySecretPSK of
            Nothing -> return (initEarlySecret choice Nothing, False)
            Just earlySecretPSK@(BaseSecret sec) -> do
                mSelectedIdentity <- usingState_ ctx getTLS13PreSharedKey
                case mSelectedIdentity of
                    Nothing ->
                        return (initEarlySecret choice Nothing, False)
                    Just (PreSharedKeyServerHello 0) -> do
                        unless (B.length sec == hashSize) $
                            throwCore $
                                Error_Protocol
                                    "selected cipher is incompatible with selected PSK"
                                    IllegalParameter
                        usingHState ctx $ setTLS13HandshakeMode PreSharedKey
                        return (earlySecretPSK, True)
                    Just _ ->
                        throwCore $ Error_Protocol "selected identity out of range" IllegalParameter

----------------------------------------------------------------

-- | Handle the server's EncryptedExtensions message.
--
-- Applies ALPN from the extensions ('setALPN') and stores the raw
-- extension list in 'tls13stClientExtensions'.  If early data had been
-- sent ('RTT0Sent'), the presence of the @early_data@ extension decides
-- between acceptance (mode 'RTT0', status 'RTT0Accepted',
-- 'tls13st0RTTAccepted' set) and rejection (mode 'PreSharedKey', status
-- 'RTT0Rejected').
--
-- Any other handshake message is reported through 'unexpected'.
expectEncryptedExtensions
    :: MonadIO m => Context -> Handshake13 -> m ()
expectEncryptedExtensions ctx (EncryptedExtensions13 eexts) = do
    liftIO $ do
        setALPN ctx MsgTEncryptedExtensions eexts
        modifyTLS13State ctx $ \st -> st{tls13stClientExtensions = eexts}
    st13 <- usingHState ctx getTLS13RTT0Status
    when (st13 == RTT0Sent) $
        case extensionLookup EID_EarlyData eexts of
            Just _ -> do
                usingHState ctx $ setTLS13HandshakeMode RTT0
                usingHState ctx $ setTLS13RTT0Status RTT0Accepted
                liftIO $ modifyTLS13State ctx $ \st -> st{tls13st0RTTAccepted = True}
            Nothing -> do
                usingHState ctx $ setTLS13HandshakeMode PreSharedKey
                usingHState ctx $ setTLS13RTT0Status RTT0Rejected
expectEncryptedExtensions _ p = unexpected (show p) (Just "encrypted extensions")

----------------------------------------------------------------
-- not used in 0-RTT
-- | Handle an optional CertificateRequest from the server.
--
-- * On 'CertRequest13': record the request via 'processCertRequest13',
--   then receive the next message with 'expectCertAndVerify'.
-- * On any other message: clear the stored certificate-request token and
--   callback data, and hand that same message to 'expectCertAndVerify'.
expectCertRequest
    :: MonadIO m => ClientParams -> Context -> Handshake13 -> RecvHandshake13M m ()
expectCertRequest cparams ctx (CertRequest13 token exts) = do
    processCertRequest13 ctx token exts
    recvHandshake13 ctx $ expectCertAndVerify cparams ctx
expectCertRequest cparams ctx other = do
    usingHState ctx $ do
        setCertReqToken Nothing
        setCertReqCBdata Nothing
    -- setCertReqSigAlgsCert Nothing
    expectCertAndVerify cparams ctx other

-- | Decode a TLS 1.3 CertificateRequest and store what the client needs
-- to answer it.
--
-- Reads the @certificate_authorities@ extension (optional; defaults to an
-- empty list) and the @signature_algorithms@ extension (mandatory), then
-- saves the request token and the callback data
-- @(certificate types, signature algorithms, distinguished names)@ in the
-- handshake state.
--
-- Throws @Error_Protocol "invalid certificate request" HandshakeFailure@
-- when an extension fails to decode or @signature_algorithms@ is absent.
processCertRequest13
    :: MonadIO m => Context -> CertReqContext -> [ExtensionRaw] -> m ()
processCertRequest13 ctx token exts = do
    let hsextID = EID_SignatureAlgorithms
    -- caextID = EID_SignatureAlgorithmsCert
    dNames <- canames
    -- The @signature_algorithms@ extension is mandatory.
    hsAlgs <- extalgs hsextID unsighash
    cTypes <- case hsAlgs of
        Just as ->
            -- Only algorithms valid for TLS 1.3 contribute certificate types.
            let validAs = filter isHashSignatureValid13 as
             in return $ sigAlgsToCertTypes ctx validAs
        Nothing -> throwCore $ Error_Protocol "invalid certificate request" HandshakeFailure
    -- Unused:
    -- caAlgs <- extalgs caextID uncertsig
    usingHState ctx $ do
        setCertReqToken $ Just token
        setCertReqCBdata $ Just (cTypes, hsAlgs, dNames)
  where
    -- setCertReqSigAlgsCert caAlgs

    -- Distinguished names from certificate_authorities, [] when absent.
    canames = case extensionLookup EID_CertificateAuthorities exts of
        Nothing -> return []
        Just ext -> case extensionDecode MsgTCertificateRequest ext of
            Just (CertificateAuthorities names) -> return names
            _ -> throwCore $ Error_Protocol "invalid certificate request" HandshakeFailure

    -- Look up extension @extID@, decode it and project it with @decons@;
    -- Nothing when the extension is absent.
    extalgs extID decons = case extensionLookup extID exts of
        Nothing -> return Nothing
        Just ext -> case extensionDecode MsgTCertificateRequest ext of
            Just e ->
                return $ decons e
            _ -> throwCore $ Error_Protocol "invalid certificate request" HandshakeFailure

    unsighash
        :: SignatureAlgorithms
        -> Maybe [HashAndSignatureAlgorithm]
    unsighash (SignatureAlgorithms a) = Just a

----------------------------------------------------------------
-- not used in 0-RTT
-- | Handle the server Certificate message and then expect
-- CertificateVerify.
--
-- Stores the chain in the TLS state, runs 'doCertificate' on it, checks
-- with 'checkDigitalSignatureKey' that the leaf public key is acceptable
-- for the negotiated version, records the public key in the handshake
-- state, and finally receives CertificateVerify via 'expectCertVerify'.
--
-- Any other handshake message is reported through 'unexpected'.
expectCertAndVerify
    :: MonadIO m => ClientParams -> Context -> Handshake13 -> RecvHandshake13M m ()
expectCertAndVerify cparams ctx (Certificate13 _ cc _) = do
    liftIO $ usingState_ ctx $ setServerCertificateChain cc
    liftIO $ doCertificate cparams ctx cc
    let pubkey = certPubKey $ getCertificate $ getCertificateChainLeaf cc
    ver <- liftIO $ usingState_ ctx getVersion
    checkDigitalSignatureKey ver pubkey
    usingHState ctx $ setPublicKey pubkey
    recvHandshake13hash ctx $ expectCertVerify ctx pubkey
expectCertAndVerify _ _ p = unexpected (show p) (Just "server certificate")

----------------------------------------------------------------

-- | Handle the server CertificateVerify message.
--
-- Verifies the signature @sig@ (algorithm @sigAlg@) over the transcript
-- hash @hChSc@ with the server public key using 'checkCertVerify'; a
-- failed check raises 'decryptError'.
--
-- Any other handshake message is reported through 'unexpected'.
expectCertVerify
    :: MonadIO m => Context -> PubKey -> ByteString -> Handshake13 -> m ()
expectCertVerify ctx pubkey hChSc (CertVerify13 sigAlg sig) = do
    ok <- checkCertVerify ctx pubkey sigAlg sig hChSc
    unless ok $ decryptError "cannot verify CertificateVerify"
expectCertVerify _ _ _ p = unexpected (show p) (Just "certificate verify")

----------------------------------------------------------------

-- | Handle the server Finished message.
--
-- Checks @verifyData@ against @hashValue@ with the server handshake
-- traffic secret ('checkFinished'), invokes the 'onServerFinished' client
-- hook when connection information is available, and marks the server
-- Finished as received ('tls13stRecvSF').
--
-- Any other handshake message is reported through 'unexpected'.
expectFinished
    :: MonadIO m
    => ClientParams
    -> Context
    -> ByteString
    -> Handshake13
    -> m ()
expectFinished cparams ctx hashValue (Finished13 verifyData) = do
    st <- liftIO $ getTLS13State ctx
    let usedHash = cHash $ tls13stChoice st
        -- 'fromJust' relies on 'tls13stHsKey' having been set earlier in the
        -- handshake ('prepareSecondFlight13'' stores @Just hkey@).
        ServerTrafficSecret baseKey = triServer $ fromJust $ tls13stHsKey st
    checkFinished ctx usedHash baseKey hashValue verifyData
    liftIO $ do
        minfo <- contextGetInformation ctx
        case minfo of
            Nothing -> return ()
            Just info -> onServerFinished (clientHooks cparams) info
    liftIO $ modifyTLS13State ctx $ \s -> s{tls13stRecvSF = True}
expectFinished _ _ _ p = unexpected (show p) (Just "server finished")

----------------------------------------------------------------
----------------------------------------------------------------

-- | Send the client's second flight of a TLS 1.3 handshake.
--
-- Reads the cipher choice, handshake key triple, 0-RTT acceptance flag and
-- stored extensions from 'TLS13State', delegates to
-- 'sendClientSecondFlight13'', and then marks the client Finished as sent
-- ('tls13stSentCF').
sendClientSecondFlight13 :: ClientParams -> Context -> IO ()
sendClientSecondFlight13 cparams ctx = do
    st <- getTLS13State ctx
    let choice = tls13stChoice st
        -- 'fromJust' relies on 'tls13stHsKey' having been set earlier in the
        -- handshake ('prepareSecondFlight13'' stores @Just hkey@).
        hkey = fromJust $ tls13stHsKey st
        rtt0accepted = tls13st0RTTAccepted st
        eexts = tls13stClientExtensions st
    sendClientSecondFlight13' cparams ctx choice hkey rtt0accepted eexts
    modifyTLS13State ctx $ \s -> s{tls13stSentCF = True}

-- | Worker for 'sendClientSecondFlight13'.
--
-- In order:
--
-- * capture the transcript hash (bound as @hChSf@) before anything is sent;
-- * outside QUIC mode, send ChangeCipherSpec and, when 0-RTT was accepted,
--   EndOfEarlyData;
-- * switch the Tx record state to the client handshake traffic secret and
--   send the client flight ('sendClientFlight13');
-- * derive the application secrets, install them for Tx and Rx, and derive
--   and store the resumption secret;
-- * publish the application secrets through 'contextSync' with
--   'SendClientFinished', drop the handshake key from the state and call
--   'handshakeDone13';
-- * if 0-RTT data had been attempted but was not accepted, resend the
--   pending application data.
sendClientSecondFlight13'
    :: ClientParams
    -> Context
    -> CipherChoice
    -> SecretTriple HandshakeSecret
    -> Bool
    -> [ExtensionRaw]
    -> IO ()
sendClientSecondFlight13' cparams ctx choice hkey rtt0accepted eexts = do
    hChSf <- transcriptHash ctx
    unless (ctxQUICMode ctx) $
        runPacketFlight ctx $
            sendChangeCipherSpec13 ctx
    when (rtt0accepted && not (ctxQUICMode ctx)) $
        sendPacket13 ctx (Handshake13 [EndOfEarlyData13])
    let clientHandshakeSecret = triClient hkey
    setTxRecordState ctx usedHash usedCipher clientHandshakeSecret
    sendClientFlight13 cparams ctx usedHash clientHandshakeSecret
    appKey <- switchToApplicationSecret hChSf
    let applicationSecret = triBase appKey
    setResumptionSecret applicationSecret
    let appSecInfo = ApplicationSecretInfo (triClient appKey, triServer appKey)
    contextSync ctx $ SendClientFinished eexts appSecInfo
    modifyTLS13State ctx $ \st -> st{tls13stHsKey = Nothing}
    handshakeDone13 ctx
    rtt0 <- tls13st0RTT <$> getTLS13State ctx
    when rtt0 $ do
        builder <- tls13stPendingSentData <$> getTLS13State ctx
        -- Reset the pending-data builder before (possibly) resending.
        modifyTLS13State ctx $ \st -> st{tls13stPendingSentData = id}
        unless rtt0accepted $
            mapM_ (sendPacket13 ctx . AppData13) $
                builder []
  where
    usedCipher :: Cipher
    usedCipher = cCipher choice

    usedHash :: Hash
    usedHash = cHash choice

    -- Derive the application secret triple from the handshake secret and
    -- the given transcript hash, then install it for both directions.
    switchToApplicationSecret :: ByteString -> IO (SecretTriple ApplicationSecret)
    switchToApplicationSecret hChSf = do
        ensureRecvComplete ctx
        let handshakeSecret = triBase hkey
        appKey <- calculateApplicationSecret ctx choice handshakeSecret hChSf
        let serverApplicationSecret0 = triServer appKey
        let clientApplicationSecret0 = triClient appKey
        setTxRecordState ctx usedHash usedCipher clientApplicationSecret0
        setRxRecordState ctx usedHash usedCipher serverApplicationSecret0
        return appKey

    -- Derive the resumption secret and keep it in the handshake state.
    setResumptionSecret :: BaseSecret ApplicationSecret -> IO ()
    setResumptionSecret applicationSecret = do
        resumptionSecret <- calculateResumptionSecret ctx choice applicationSecret
        usingHState ctx $ setTLS13ResumptionSecret resumptionSecret

{- Unused for now
uncertsig :: SignatureAlgorithmsCert
          -> Maybe [HashAndSignatureAlgorithm]
uncertsig (SignatureAlgorithmsCert a) = Just a
-}

-- | Send the client's optional authentication messages followed by its
-- Finished message inside a single 'runPacketFlight'.
--
-- When 'clientChain' yields a certificate chain, the certificate request
-- context token is read from the handshake state and a 'Certificate13'
-- message is loaded with it; if the chain holds at least one certificate,
-- the message produced by 'makeCertVerify' is loaded right after. The
-- Finished message is then built by 'makeFinished' from @usedHash@ and the
-- key wrapped in the supplied 'ClientTrafficSecret'.
--
-- Once the flight has been run, 'tls13stSentClientCert' is set to 'True'
-- whenever a chain was available.
--
-- Throws 'Error_Protocol' (with 'InternalError') when a chain is available
-- but no certificate request context token is stored.
sendClientFlight13
    :: ClientParams -> Context -> Hash -> ClientTrafficSecret a -> IO ()
sendClientFlight13 cparams ctx usedHash (ClientTrafficSecret baseKey) = do
    mcc <- clientChain cparams ctx
    runPacketFlight ctx $ do
        case mcc of
            Nothing -> return ()
            Just cc -> usingHState ctx getCertReqToken >>= loadClientData13 cc
        rawFinished <- makeFinished ctx usedHash baseKey
        loadPacket13 ctx $ Handshake13 [rawFinished]
    -- Record that a client certificate message was part of this flight.
    when (isJust mcc) $
        modifyTLS13State ctx $
            \st -> st{tls13stSentClientCert = True}
  where
    -- Load the Certificate message and, for a non-empty chain, the message
    -- built by 'makeCertVerify'. Requires the certificate request token.
    loadClientData13 chain (Just token) = do
        let (CertificateChain certs) = chain
            -- One empty extension list per certificate in the chain.
            certExts = replicate (length certs) []
            -- Locally supported algorithms that pass 'isHashSignatureValid13'.
            cHashSigs =
                filter isHashSignatureValid13 $
                    supportedHashSignatures $
                        ctxSupported ctx
        loadPacket13 ctx $ Handshake13 [Certificate13 token chain certExts]
        case certs of
            -- An empty chain is sent without a following verify message.
            [] -> return ()
            _ -> do
                hChSc <- transcriptHash ctx
                pubKey <- getLocalPublicKey ctx
                sigAlg <-
                    liftIO $ getLocalHashSigAlg ctx signatureCompatible13 cHashSigs pubKey
                vfy <- makeCertVerify ctx pubKey sigAlg hChSc
                loadPacket13 ctx $ Handshake13 [vfy]
    --
    -- A chain is available but no request context token was recorded.
    loadClientData13 _ _ =
        throwCore $
            Error_Protocol "missing TLS 1.3 certificate request context token" InternalError

----------------------------------------------------------------
----------------------------------------------------------------

-- | React to a 'CertRequest13' handshake message by sending the client
-- authentication flight.
--
-- The work runs inside 'bracket', with 'saveHState' as the acquire action
-- and 'restoreHState' as the release action, so the saved handshake state
-- is handed back to 'restoreHState' whether the body finishes normally or
-- throws.
--
-- The request is first passed to 'processHandshake13' and
-- 'processCertRequest13'. The current transmit record state is then read;
-- its hash and secret are forwarded to 'sendClientFlight13'.
--
-- Throws 'Error_Protocol' with 'UnexpectedMessage' when the transmit level
-- is not 'CryptApplicationSecret', or when the message is anything other
-- than a 'CertRequest13'.
postHandshakeAuthClientWith :: ClientParams -> Context -> Handshake13 -> IO ()
postHandshakeAuthClientWith cparams ctx h@(CertRequest13 certReqCtx exts) =
    bracket (saveHState ctx) (restoreHState ctx) $ \_ -> do
        processHandshake13 ctx h
        processCertRequest13 ctx certReqCtx exts
        (usedHash, _, level, applicationSecretN) <- getTxRecordState ctx
        -- Only answer when the transmit side is at the application level.
        unless (level == CryptApplicationSecret) $
            throwCore $
                Error_Protocol
                    "unexpected post-handshake authentication request"
                    UnexpectedMessage
        sendClientFlight13 cparams ctx usedHash (ClientTrafficSecret applicationSecretN)
postHandshakeAuthClientWith _ _ _ =
    throwCore $
        Error_Protocol
            "unexpected handshake message received in postHandshakeAuthClientWith"
            UnexpectedMessage

----------------------------------------------------------------
----------------------------------------------------------------

-- | Register the receive actions that process the server's handshake
-- messages, instead of processing them inline.
--
-- Three actions are installed on the context with 'setPendingRecvActions',
-- in this order:
--
-- 1. @expectServerHello@: records the RTT from @chSentTime@ via 'setRTT',
--    runs 'processServerHello13', then runs 'prepareSecondFlight13' and
--    discards its result.
-- 2. @'expectEncryptedExtensions' ctx@.
-- 3. @expectFinishedAndSet@: runs 'expectFinished' and then stores
--    @'sendClientSecondFlight13' cparams@ in the context's
--    'ctxPendingSendAction' reference.
asyncServerHello13
    :: ClientParams -> Context -> Maybe Group -> Millisecond -> IO ()
asyncServerHello13 cparams ctx groupSent chSentTime = do
    setPendingRecvActions
        ctx
        [ PendingRecvAction True expectServerHello
        , PendingRecvAction True (expectEncryptedExtensions ctx)
        , PendingRecvActionHash True expectFinishedAndSet
        ]
  where
    -- Handle the ServerHello and prepare for the server's second flight.
    expectServerHello sh = do
        setRTT ctx chSentTime
        processServerHello13 cparams ctx sh
        -- The Bool result of 'prepareSecondFlight13' is not needed here.
        void $ prepareSecondFlight13 ctx groupSent
    -- Check the server Finished, then store the client's second flight as
    -- the pending send action.
    expectFinishedAndSet h sf = do
        expectFinished cparams ctx h sf
        liftIO $
            writeIORef (ctxPendingSendAction ctx) $
                Just $
                    sendClientSecondFlight13 cparams