module Network.LibP2P.Switch.Upgrade
(
performStreamHandshake
, noiseSessionToStreamIO
, yamuxToMuxerSession
, upgradeOutbound
, upgradeInbound
, readExact
, readFramedMessage
, writeFramedMessage
) where
import Control.Concurrent.Async (async)
import Control.Concurrent.STM (newTVarIO)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Word (Word8)
import Network.LibP2P.Core.Binary (readWord16BE)
import Network.LibP2P.Crypto.Key (KeyPair (..))
import Network.LibP2P.Crypto.PeerId (fromPublicKey)
import Network.LibP2P.Mux.Yamux.Session (closeSession, newSession, recvLoop, sendLoop)
import qualified Network.LibP2P.Mux.Yamux.Session as Yamux
import Network.LibP2P.Mux.Yamux.Stream (streamRead)
import qualified Network.LibP2P.Mux.Yamux.Stream as YS
import Network.LibP2P.Mux.Yamux.Types (SessionRole (..), YamuxSession, YamuxStream)
import Network.LibP2P.MultistreamSelect.Negotiation
( NegotiationResult (..)
, StreamIO (..)
, negotiateInitiator
, negotiateResponder
)
import Network.LibP2P.Security.Noise.Framing (encodeFrame)
import Network.LibP2P.Security.Noise.Handshake
( HandshakeResult (..)
, buildHandshakePayload
, decodeNoisePayload
, encodeNoisePayload
, getRemoteNoiseStaticKey
, initHandshakeInitiator
, initHandshakeResponder
, readHandshakeMsg
, verifyStaticKey
, writeHandshakeMsg
)
import Network.LibP2P.Security.Noise.Session
( NoiseSession
, decryptMessage
, encryptMessage
, mkNoiseSession
)
import Network.LibP2P.Switch.Types
( ConnState (..)
, Connection (..)
, Direction (..)
, MuxerSession (..)
)
import Network.LibP2P.Transport.Transport (RawConnection (..))
import qualified Network.LibP2P.Crypto.Protobuf as Proto
import qualified Network.LibP2P.Security.Noise.Handshake as HS
-- | Read exactly @n@ bytes from @stream@, issuing one 'streamReadByte' call
-- per byte.
--
-- Blocks until all @n@ bytes have been read. A non-positive @n@ yields an
-- empty 'ByteString' without touching the stream. Any exception raised by
-- 'streamReadByte' propagates unchanged.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n = BS.pack <$> collect n
  where
    collect remaining
      | remaining <= 0 = pure []
      | otherwise = (:) <$> streamReadByte stream <*> collect (remaining - 1)
-- | Read one length-prefixed frame from @stream@.
--
-- The frame is a 2-byte length header, decoded with 'readWord16BE', followed
-- by that many payload bytes. A zero length yields an empty 'ByteString'.
readFramedMessage :: StreamIO -> IO ByteString
readFramedMessage stream = do
  header <- readExact stream 2
  let payloadLen = fromIntegral (readWord16BE header) :: Int
  readExact stream payloadLen
-- | Frame @msg@ with 'encodeFrame' and write the result to @stream@ in a
-- single 'streamWrite' call.
--
-- NOTE(review): this relies on 'encodeFrame' producing the same layout that
-- 'readFramedMessage' parses (2-byte length, then payload). Confirm against
-- "Network.LibP2P.Security.Noise.Framing".
writeFramedMessage :: StreamIO -> ByteString -> IO ()
writeFramedMessage stream = streamWrite stream . encodeFrame
-- | Run the Noise handshake over @stream@ in the role implied by @dir@.
--
-- 'Outbound' connections act as the handshake initiator and 'Inbound'
-- connections as the responder. Returns the transport session together with
-- the authenticated remote identity.
performStreamHandshake
  :: KeyPair -> Direction -> StreamIO -> IO (NoiseSession, HandshakeResult)
performStreamHandshake identityKP dir stream = runRole identityKP stream
  where
    runRole = case dir of
      Outbound -> performInitiatorHandshake
      Inbound -> performResponderHandshake
-- | Initiator side of the three-message Noise handshake over @stream@.
--
-- Every handshake message is sent and received as a length-prefixed frame.
--
--   1. Send msg1 with an empty payload.
--   2. Read msg2. Decode the responder's identity payload and public key, then
--      check with 'verifyStaticKey' that the payload's signature matches the
--      responder's Noise static key.
--   3. Send msg3 carrying our own identity payload, built from @identityKP@
--      and our Noise static key.
--
-- Every failure (bad message, undecodable payload or key, missing static key,
-- bad signature) is raised with 'fail'.
performInitiatorHandshake :: KeyPair -> StreamIO -> IO (NoiseSession, HandshakeResult)
performInitiatorHandshake identityKP stream = do
  (st0, localNoisePub) <- initHandshakeInitiator identityKP

  -- msg1: initiator -> responder, no payload
  (out1, st1) <- orFail "initiator msg1 write: " (writeHandshakeMsg st0 BS.empty)
  writeFramedMessage stream out1

  -- msg2: responder -> initiator, carries the responder's identity
  in2 <- readFramedMessage stream
  (peerPayload, st2) <- orFail "initiator msg2 read: " (readHandshakeMsg st1 in2)
  peerNP <- orFail "initiator decode payload: " (decodeNoisePayload peerPayload)
  peerKey <-
    orFail "initiator decode pubkey: " (Proto.decodePublicKey (HS.npIdentityKey peerNP))
  peerNoisePub <-
    maybe
      (fail "performInitiatorHandshake: remote Noise static key unavailable after msg2")
      pure
      (getRemoteNoiseStaticKey st2)
  if verifyStaticKey peerKey peerNoisePub (HS.npIdentitySig peerNP)
    then pure ()
    else fail "performInitiatorHandshake: identity signature verification failed"

  -- msg3: initiator -> responder, carries our identity
  let ourPayload = encodeNoisePayload (buildHandshakePayload identityKP localNoisePub)
  (out3, stFinal) <- orFail "initiator msg3 write: " (writeHandshakeMsg st2 ourPayload)
  writeFramedMessage stream out3

  pure
    ( mkNoiseSession (HS.hsNoiseState stFinal)
    , HandshakeResult (fromPublicKey peerKey) peerKey
    )
  where
    -- Turn a 'Left' into an IO failure whose message is prefixed by @context@.
    orFail :: String -> Either String a -> IO a
    orFail context = either (fail . (context <>)) pure
-- | Responder side of the three-message Noise handshake over @stream@.
--
-- Every handshake message is sent and received as a length-prefixed frame.
--
--   1. Read msg1. Its payload is ignored.
--   2. Send msg2 carrying our identity payload, built from @identityKP@ and
--      our Noise static key.
--   3. Read msg3. Decode the initiator's identity payload and public key, then
--      check with 'verifyStaticKey' that the payload's signature matches the
--      initiator's Noise static key.
--
-- Every failure (bad message, undecodable payload or key, missing static key,
-- bad signature) is raised with 'fail'.
performResponderHandshake :: KeyPair -> StreamIO -> IO (NoiseSession, HandshakeResult)
performResponderHandshake identityKP stream = do
  (st0, localNoisePub) <- initHandshakeResponder identityKP

  -- msg1: initiator -> responder
  in1 <- readFramedMessage stream
  (_ignoredPayload, st1) <- orFail "responder msg1 read: " (readHandshakeMsg st0 in1)

  -- msg2: responder -> initiator, carries our identity
  let ourPayload = encodeNoisePayload (buildHandshakePayload identityKP localNoisePub)
  (out2, st2) <- orFail "responder msg2 write: " (writeHandshakeMsg st1 ourPayload)
  writeFramedMessage stream out2

  -- msg3: initiator -> responder, carries the initiator's identity
  in3 <- readFramedMessage stream
  (peerPayload, stFinal) <- orFail "responder msg3 read: " (readHandshakeMsg st2 in3)
  peerNP <- orFail "responder decode payload: " (decodeNoisePayload peerPayload)
  peerKey <-
    orFail "responder decode pubkey: " (Proto.decodePublicKey (HS.npIdentityKey peerNP))
  peerNoisePub <-
    maybe
      (fail "performResponderHandshake: remote Noise static key unavailable after msg3")
      pure
      (getRemoteNoiseStaticKey stFinal)
  if verifyStaticKey peerKey peerNoisePub (HS.npIdentitySig peerNP)
    then pure ()
    else fail "performResponderHandshake: identity signature verification failed"

  pure
    ( mkNoiseSession (HS.hsNoiseState stFinal)
    , HandshakeResult (fromPublicKey peerKey) peerKey
    )
  where
    -- Turn a 'Left' into an IO failure whose message is prefixed by @context@.
    orFail :: String -> Either String a -> IO a
    orFail context = either (fail . (context <>)) pure
-- | Build an encrypted 'StreamIO' view on top of @rawIO@.
--
-- Writes are encrypted with the session held in @sendRef@ ('encryptAndWrite').
-- Reads are decrypted with the session held in @recvRef@, and leftover
-- decrypted bytes are kept in @bufRef@ ('decryptAndReadByte').
-- Closing this view is a no-op: it does not close @rawIO@.
noiseSessionToStreamIO
  :: IORef NoiseSession
  -> IORef NoiseSession
  -> IORef ByteString
  -> StreamIO
  -> StreamIO
noiseSessionToStreamIO sendRef recvRef bufRef rawIO =
  StreamIO
    { streamReadByte = decryptAndReadByte recvRef bufRef rawIO
    , streamWrite = encryptAndWrite sendRef rawIO
    , streamClose = pure ()
    }
-- | Encrypt @plaintext@ with the outbound Noise session in @sendRef@ and write
-- the ciphertext to @rawIO@ as length-prefixed frames.
--
-- A Noise transport message may be at most 65535 bytes, and the frame length
-- prefix is only 16 bits wide. The plaintext is therefore split into chunks of
-- at most @65535 - 16@ bytes, leaving room for the authentication tag, and
-- each chunk is encrypted and framed separately.
--
-- An empty @plaintext@ writes nothing. It carries no data, and an empty
-- transport message may be rejected by the reader.
--
-- The session in @sendRef@ is advanced after every chunk, so its state always
-- matches what has actually been sent. Encryption failure is raised with
-- 'fail'; chunks written before the failure stay written.
encryptAndWrite :: IORef NoiseSession -> StreamIO -> ByteString -> IO ()
encryptAndWrite sendRef rawIO plaintext = go plaintext
  where
    -- Largest plaintext that still fits one Noise message once the AEAD tag
    -- is appended. NOTE(review): assumes a 16-byte tag, as with ChaChaPoly
    -- and AESGCM (the ciphers libp2p Noise uses) — confirm for this session.
    maxChunk :: Int
    maxChunk = 65535 - 16

    go :: ByteString -> IO ()
    go remaining
      | BS.null remaining = pure ()
      | otherwise = do
          let (chunk, rest) = BS.splitAt maxChunk remaining
          sess <- readIORef sendRef
          case encryptMessage sess chunk of
            Left err -> fail $ "encryptAndWrite: " <> err
            Right (ct, sess') -> do
              writeIORef sendRef sess'
              writeFramedMessage rawIO ct
              go rest
-- | Return the next plaintext byte of the inbound encrypted stream.
--
-- Bytes left over from the previously decrypted message live in @bufRef@ and
-- are served first. When that buffer is empty, the next length-prefixed frame
-- is read from @rawIO@ and decrypted with the session in @recvRef@ (which is
-- then advanced). The first plaintext byte is returned and the rest is
-- buffered.
--
-- An authenticated message whose plaintext is empty carries no data. It is
-- skipped and the next frame is read, instead of aborting the whole stream.
-- Decryption failure is raised with 'fail'.
decryptAndReadByte :: IORef NoiseSession -> IORef ByteString -> StreamIO -> IO Word8
decryptAndReadByte recvRef bufRef rawIO = do
  buf <- readIORef bufRef
  case BS.uncons buf of
    Just (b, rest) -> do
      writeIORef bufRef rest
      pure b
    Nothing -> refill
  where
    -- Pull frames until one decrypts to a non-empty plaintext.
    refill :: IO Word8
    refill = do
      ct <- readFramedMessage rawIO
      sess <- readIORef recvRef
      case decryptMessage sess ct of
        Left err -> fail $ "decryptAndReadByte: " <> err
        Right (pt, sess') -> do
          -- Advance the session even for an empty message, so the receive
          -- state stays in step with the sender.
          writeIORef recvRef sess'
          case BS.uncons pt of
            Nothing -> refill
            Just (b, rest) -> do
              writeIORef bufRef rest
              pure b
-- | Start the Yamux session's send and receive loops on background threads
-- and expose the session through the generic 'MuxerSession' interface.
--
-- Opening or accepting a stream raises (via 'fail') when Yamux reports an
-- error. Otherwise the stream is adapted with 'yamuxStreamToStreamIO'.
-- 'muxClose' delegates to 'closeSession'.
--
-- NOTE(review): the 'async' handles for both loops are discarded. An
-- exception in either loop is never observed here, and 'muxClose' does not
-- cancel the threads. Confirm that 'closeSession' makes both loops exit.
yamuxToMuxerSession :: YamuxSession -> IO MuxerSession
yamuxToMuxerSession yamuxSess = do
  _sendThread <- async (sendLoop yamuxSess)
  _recvThread <- async (recvLoop yamuxSess)
  pure
    MuxerSession
      { muxOpenStream = Yamux.openStream yamuxSess >>= wrapStream "muxOpenStream: "
      , muxAcceptStream = Yamux.acceptStream yamuxSess >>= wrapStream "muxAcceptStream: "
      , muxClose = closeSession yamuxSess
      }
  where
    -- Adapt a successful Yamux stream, or fail with @label@ plus the shown error.
    wrapStream label = either (fail . (label <>) . show) yamuxStreamToStreamIO
-- | Adapt a Yamux stream to the byte-oriented 'StreamIO' interface.
--
-- Reads are served from a per-stream buffer holding the unread tail of the
-- last chunk returned by 'streamRead'. A new chunk is fetched only when that
-- buffer is empty.
--
-- A 'Left' from the Yamux layer is raised with 'fail', and so is an empty
-- chunk. 'streamClose' discards the result of closing the Yamux stream.
yamuxStreamToStreamIO :: YamuxStream -> IO StreamIO
yamuxStreamToStreamIO yamuxStream = do
  pending <- newIORef BS.empty
  let writeChunk bs =
        YS.streamWrite yamuxStream bs
          >>= either (\err -> fail ("yamuxStreamWrite: " <> show err)) pure

      -- Return the first byte of a freshly received chunk and buffer the
      -- rest. Nothing is stored when the chunk held a single byte.
      takeFirst chunk = case BS.uncons chunk of
        Nothing -> fail "yamuxStreamRead: empty chunk"
        Just (b, rest)
          | BS.null rest -> pure b
          | otherwise -> writeIORef pending rest >> pure b

      nextByte = do
        buffered <- readIORef pending
        case BS.uncons buffered of
          Just (b, rest) -> writeIORef pending rest >> pure b
          Nothing ->
            streamRead yamuxStream
              >>= either (\err -> fail ("yamuxStreamRead: " <> show err)) takeFirst
  pure
    StreamIO
      { streamWrite = writeChunk
      , streamReadByte = nextByte
      , streamClose = YS.streamClose yamuxStream >> pure ()
      }
-- | Upgrade an outbound raw connection to a secured, multiplexed 'Connection'.
--
--   1. Negotiate @/noise@ as multistream-select initiator on the raw stream.
--   2. Run the Noise handshake as initiator, learning the remote peer ID.
--   3. Negotiate @/yamux/1.0.0@ over the encrypted stream.
--   4. Start a Yamux session in the client role on that encrypted stream.
--
-- A failed negotiation is raised with 'fail'; other steps raise their own
-- errors.
--
-- NOTE(review): @rawConn@ is not closed here on failure. Presumably the
-- caller owns that cleanup — confirm.
upgradeOutbound :: KeyPair -> RawConnection -> IO Connection
upgradeOutbound identityKP rawConn = do
  let rawIO = rcStreamIO rawConn

  expectAccepted "upgradeOutbound: /noise negotiation failed"
    =<< negotiateInitiator rawIO ["/noise"]
  (noiseSess, HandshakeResult remotePeerId _) <-
    performStreamHandshake identityKP Outbound rawIO

  -- Send and receive each get their own copy of the session. A third ref
  -- buffers decrypted-but-unread bytes.
  encryptedIO <-
    noiseSessionToStreamIO
      <$> newIORef noiseSess
      <*> newIORef noiseSess
      <*> newIORef BS.empty
      <*> pure rawIO

  expectAccepted "upgradeOutbound: /yamux/1.0.0 negotiation failed"
    =<< negotiateInitiator encryptedIO ["/yamux/1.0.0"]
  yamuxSess <- newSession RoleClient (streamWrite encryptedIO) (readExact encryptedIO)
  muxer <- yamuxToMuxerSession yamuxSess
  stateVar <- newTVarIO ConnOpen
  pure
    Connection
      { connPeerId = remotePeerId
      , connDirection = Outbound
      , connLocalAddr = rcLocalAddr rawConn
      , connRemoteAddr = rcRemoteAddr rawConn
      , connSecurity = "/noise"
      , connMuxer = "/yamux/1.0.0"
      , connSession = muxer
      , connState = stateVar
      }
  where
    -- Continue on 'Accepted'; fail with @failure@ on 'NoProtocol'.
    expectAccepted :: String -> NegotiationResult -> IO ()
    expectAccepted _ (Accepted _) = pure ()
    expectAccepted failure NoProtocol = fail failure
-- | Upgrade an inbound raw connection to a secured, multiplexed 'Connection'.
--
--   1. Negotiate @/noise@ as multistream-select responder on the raw stream.
--   2. Run the Noise handshake as responder, learning the remote peer ID.
--   3. Negotiate @/yamux/1.0.0@ (as responder) over the encrypted stream.
--   4. Start a Yamux session in the server role on that encrypted stream.
--
-- A failed negotiation is raised with 'fail'; other steps raise their own
-- errors.
--
-- NOTE(review): @rawConn@ is not closed here on failure. Presumably the
-- caller owns that cleanup — confirm.
upgradeInbound :: KeyPair -> RawConnection -> IO Connection
upgradeInbound identityKP rawConn = do
  let rawIO = rcStreamIO rawConn

  expectAccepted "upgradeInbound: /noise negotiation failed"
    =<< negotiateResponder rawIO ["/noise"]
  (noiseSess, HandshakeResult remotePeerId _) <-
    performStreamHandshake identityKP Inbound rawIO

  -- Send and receive each get their own copy of the session. A third ref
  -- buffers decrypted-but-unread bytes.
  encryptedIO <-
    noiseSessionToStreamIO
      <$> newIORef noiseSess
      <*> newIORef noiseSess
      <*> newIORef BS.empty
      <*> pure rawIO

  expectAccepted "upgradeInbound: /yamux/1.0.0 negotiation failed"
    =<< negotiateResponder encryptedIO ["/yamux/1.0.0"]
  yamuxSess <- newSession RoleServer (streamWrite encryptedIO) (readExact encryptedIO)
  muxer <- yamuxToMuxerSession yamuxSess
  stateVar <- newTVarIO ConnOpen
  pure
    Connection
      { connPeerId = remotePeerId
      , connDirection = Inbound
      , connLocalAddr = rcLocalAddr rawConn
      , connRemoteAddr = rcRemoteAddr rawConn
      , connSecurity = "/noise"
      , connMuxer = "/yamux/1.0.0"
      , connSession = muxer
      , connState = stateVar
      }
  where
    -- Continue on 'Accepted'; fail with @failure@ on 'NoProtocol'.
    expectAccepted :: String -> NegotiationResult -> IO ()
    expectAccepted _ (Accepted _) = pure ()
    expectAccepted failure NoProtocol = fail failure