-- | Connection upgrade pipeline for the Switch.
--
-- Transforms a raw transport connection into a fully upgraded
-- (secure + multiplexed) Connection by executing a 4-step pipeline:
--   1. multistream-select: negotiate security protocol ("/noise")
--   2. Noise XX handshake: encrypted channel + remote PeerId
--   3. multistream-select: negotiate muxer ("/yamux/1.0.0")
--   4. Yamux session init: multiplexed streams
--
-- See docs/08-switch.md §Connection Upgrading Pipeline.
module Network.LibP2P.Switch.Upgrade
  ( -- * Streaming handshake
    performStreamHandshake
    -- * Encrypted StreamIO
  , noiseSessionToStreamIO
    -- * Yamux → MuxerSession adapter
  , yamuxToMuxerSession
    -- * Full upgrade pipeline
  , upgradeOutbound
  , upgradeInbound
    -- * Helpers (exported for testing)
  , readExact
  , readFramedMessage
  , writeFramedMessage
  ) where

import Control.Concurrent.Async (async)
import Control.Concurrent.STM (newTVarIO)
import Control.Monad (replicateM, unless, void, when)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Word (Word8)

import Network.LibP2P.Core.Binary (readWord16BE)
import Network.LibP2P.Crypto.Key (KeyPair (..))
import Network.LibP2P.Crypto.PeerId (fromPublicKey)
import qualified Network.LibP2P.Crypto.Protobuf as Proto
import Network.LibP2P.Mux.Yamux.Session (closeSession, newSession, recvLoop, sendLoop)
import qualified Network.LibP2P.Mux.Yamux.Session as Yamux
import Network.LibP2P.Mux.Yamux.Stream (streamRead)
import qualified Network.LibP2P.Mux.Yamux.Stream as YS
import Network.LibP2P.Mux.Yamux.Types (SessionRole (..), YamuxSession, YamuxStream)
import Network.LibP2P.MultistreamSelect.Negotiation
  ( NegotiationResult (..)
  , StreamIO (..)
  , negotiateInitiator
  , negotiateResponder
  )
import Network.LibP2P.Security.Noise.Framing (encodeFrame)
import Network.LibP2P.Security.Noise.Handshake
  ( HandshakeResult (..)
  , buildHandshakePayload
  , decodeNoisePayload
  , encodeNoisePayload
  , getRemoteNoiseStaticKey
  , initHandshakeInitiator
  , initHandshakeResponder
  , readHandshakeMsg
  , verifyStaticKey
  , writeHandshakeMsg
  )
import qualified Network.LibP2P.Security.Noise.Handshake as HS
import Network.LibP2P.Security.Noise.Session
  ( NoiseSession
  , decryptMessage
  , encryptMessage
  , mkNoiseSession
  )
import Network.LibP2P.Switch.Types
  ( ConnState (..)
  , Connection (..)
  , Direction (..)
  , MuxerSession (..)
  )
import Network.LibP2P.Transport.Transport (RawConnection (..))

-- | Read exactly @n@ bytes from a 'StreamIO'.
--
-- Bytes are pulled one at a time with 'streamReadByte'. The call does not
-- return until @n@ such reads have succeeded; any failure of the underlying
-- read propagates. A non-positive @n@ yields an empty 'ByteString' without
-- reading anything.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n = BS.pack <$> replicateM n (streamReadByte stream)

-- | Read a 2-byte-BE-length-prefixed Noise frame from a 'StreamIO'.
--
-- First reads the 2-byte big-endian length, then exactly that many payload
-- bytes. A zero length prefix yields an empty 'ByteString'.
readFramedMessage :: StreamIO -> IO ByteString
readFramedMessage stream = do
  lenBytes <- readExact stream 2
  -- readExact of length 0 already returns BS.empty, so no special case needed.
  readExact stream (fromIntegral (readWord16BE lenBytes))

-- | Write a 2-byte-BE-length-prefixed Noise frame to a 'StreamIO'.
--
-- Fails (via 'fail' in 'IO') if @msg@ is longer than 65535 bytes. That is the
-- largest length a 2-byte prefix can encode, and a larger message would
-- desynchronise the peer's frame reader.
writeFramedMessage :: StreamIO -> ByteString -> IO ()
writeFramedMessage stream msg = do
  when (msgLen > maxFrameLen) $
    fail $ "writeFramedMessage: message of " <> show msgLen
        <> " bytes exceeds the 65535-byte frame limit"
  streamWrite stream (encodeFrame msg)
  where
    msgLen = BS.length msg
    maxFrameLen = 0xFFFF :: Int

-- | Perform a Noise XX handshake over a 'StreamIO' using framed messages.
--
-- 'Outbound' connections run the initiator side and 'Inbound' connections the
-- responder side. Returns the established 'NoiseSession' together with a
-- 'HandshakeResult' carrying the remote PeerId and public key.
performStreamHandshake
  :: KeyPair -> Direction -> StreamIO -> IO (NoiseSession, HandshakeResult)
performStreamHandshake identityKP dir stream = case dir of
  Outbound -> performInitiatorHandshake identityKP stream
  Inbound  -> performResponderHandshake identityKP stream

-- | Initiator (dialer) side of the Noise XX handshake.
--
-- Message flow:
--   1. Initiator → Responder: e (empty payload)
--   2. Responder → Initiator: e, ee, s, es (responder identity payload)
--   3. Initiator → Responder: s, se (initiator identity payload)
--
-- The responder's identity is verified before message 3 is written. All
-- failures are raised with 'fail' in 'IO'.
performInitiatorHandshake :: KeyPair -> StreamIO -> IO (NoiseSession, HandshakeResult)
performInitiatorHandshake identityKP stream = do
  (hsState0, noiseStaticPub) <- initHandshakeInitiator identityKP

  -- Message 1: → (empty payload)
  (msg1, hsState1) <- either (fail . ("initiator msg1 write: " <>)) pure $
    writeHandshakeMsg hsState0 BS.empty
  writeFramedMessage stream msg1

  -- Message 2: ← (responder's identity payload)
  msg2 <- readFramedMessage stream
  (payload2, hsState2) <- either (fail . ("initiator msg2 read: " <>)) pure $
    readHandshakeMsg hsState1 msg2

  -- The responder's Noise static key is expected after msg2: authenticate now.
  result <-
    verifyRemoteIdentity "initiator" "performInitiatorHandshake" "msg2"
      (getRemoteNoiseStaticKey hsState2) payload2

  -- Message 3: → (initiator's identity payload)
  let identPayload = encodeNoisePayload (buildHandshakePayload identityKP noiseStaticPub)
  (msg3, hsStateFinal) <- either (fail . ("initiator msg3 write: " <>)) pure $
    writeHandshakeMsg hsState2 identPayload
  writeFramedMessage stream msg3

  pure (mkNoiseSession (HS.hsNoiseState hsStateFinal), result)

-- | Responder (listener) side of the Noise XX handshake.
--
-- Reads message 1 and writes message 2 carrying the local identity payload.
-- It then reads message 3 and verifies the initiator's identity from it. All
-- failures are raised with 'fail' in 'IO'.
performResponderHandshake :: KeyPair -> StreamIO -> IO (NoiseSession, HandshakeResult)
performResponderHandshake identityKP stream = do
  (hsState0, noiseStaticPub) <- initHandshakeResponder identityKP

  -- Message 1: ← (empty payload)
  msg1 <- readFramedMessage stream
  (_payload1, hsState1) <- either (fail . ("responder msg1 read: " <>)) pure $
    readHandshakeMsg hsState0 msg1

  -- Message 2: → (responder's identity payload)
  let identPayload = encodeNoisePayload (buildHandshakePayload identityKP noiseStaticPub)
  (msg2, hsState2) <- either (fail . ("responder msg2 write: " <>)) pure $
    writeHandshakeMsg hsState1 identPayload
  writeFramedMessage stream msg2

  -- Message 3: ← (initiator's identity payload)
  msg3 <- readFramedMessage stream
  (payload3, hsStateFinal) <- either (fail . ("responder msg3 read: " <>)) pure $
    readHandshakeMsg hsState2 msg3

  result <-
    verifyRemoteIdentity "responder" "performResponderHandshake" "msg3"
      (getRemoteNoiseStaticKey hsStateFinal) payload3

  pure (mkNoiseSession (HS.hsNoiseState hsStateFinal), result)

-- | Decode the remote peer's Noise handshake payload and check its
-- @identity_sig@. The signature binds the remote identity key to the remote
-- Noise static key.
--
-- Shared by both handshake sides. The string arguments only shape the error
-- messages passed to 'fail'.
verifyRemoteIdentity
  :: String            -- ^ Role label for decode errors ("initiator"/"responder")
  -> String            -- ^ Calling function's name, prefixed to verification errors
  -> String            -- ^ Handshake message after which the static key must be known
  -> Maybe ByteString  -- ^ Remote Noise static key, if the handshake revealed it
  -> ByteString        -- ^ Decrypted remote identity payload
  -> IO HandshakeResult
verifyRemoteIdentity role caller msgLabel mRemoteNoisePub payload = do
  remoteNP <- either (fail . ((role <> " decode payload: ") <>)) pure $
    decodeNoisePayload payload
  remotePubKey <- either (fail . ((role <> " decode pubkey: ") <>)) pure $
    Proto.decodePublicKey (HS.npIdentityKey remoteNP)
  remoteNoisePub <-
    maybe
      (fail (caller <> ": remote Noise static key unavailable after " <> msgLabel))
      pure
      mRemoteNoisePub
  unless (verifyStaticKey remotePubKey remoteNoisePub (HS.npIdentitySig remoteNP)) $
    fail (caller <> ": identity signature verification failed")
  pure (HandshakeResult (fromPublicKey remotePubKey) remotePubKey)

-- | Create an encrypted 'StreamIO' from a 'NoiseSession' and a raw 'StreamIO'.
--
-- Send and receive session state live in separate IORefs, because each
-- direction's CipherState is independent in Noise. A read buffer
-- (@IORef ByteString@) bridges Noise's message-boundary decryption with
-- 'StreamIO''s byte-level reads.
noiseSessionToStreamIO
  :: IORef NoiseSession    -- ^ Send session state
  -> IORef NoiseSession    -- ^ Recv session state
  -> IORef ByteString      -- ^ Read buffer (decrypted but unconsumed bytes)
  -> StreamIO              -- ^ Raw (unencrypted) StreamIO
  -> StreamIO
noiseSessionToStreamIO sendRef recvRef bufRef rawIO = StreamIO
  { streamWrite    = encryptAndWrite sendRef rawIO
  , streamReadByte = decryptAndReadByte recvRef bufRef rawIO
  , streamClose    = pure ()  -- Encryption layer does not own the connection
  }

-- | Encrypt plaintext and write it as one or more framed Noise messages.
--
-- A Noise transport message is limited to 65535 bytes, including the 16-byte
-- AEAD tag, so plaintext is split into chunks of at most 65519 bytes. Each
-- chunk is encrypted, the updated send session is stored back into @sendRef@,
-- and the ciphertext is written as its own frame.
--
-- Empty plaintext writes nothing. An empty frame would decrypt to empty
-- plaintext on the other side, and carries no data for a byte-stream reader.
-- Encryption errors are raised with 'fail'.
encryptAndWrite :: IORef NoiseSession -> StreamIO -> ByteString -> IO ()
encryptAndWrite sendRef rawIO = go
  where
    maxNoisePlaintext :: Int
    maxNoisePlaintext = 65535 - 16

    go plaintext
      | BS.null plaintext = pure ()
      | otherwise = do
          let (chunk, rest) = BS.splitAt maxNoisePlaintext plaintext
          sess <- readIORef sendRef
          case encryptMessage sess chunk of
            Left err -> fail $ "encryptAndWrite: " <> err
            Right (ct, sess') -> do
              writeIORef sendRef sess'
              writeFramedMessage rawIO ct
              go rest

-- | Read one decrypted byte from the Noise channel.
--
-- If the read buffer is non-empty, its first byte is returned and the rest
-- is kept. Otherwise one full Noise frame is read from the raw stream and
-- decrypted. The updated receive session is stored back into @recvRef@, and
-- the plaintext beyond the first byte is buffered.
--
-- A frame that decrypts to zero bytes (a legal, empty Noise transport
-- message) carries no data, so the next frame is read instead. Decryption
-- errors are raised with 'fail'.
decryptAndReadByte :: IORef NoiseSession -> IORef ByteString -> StreamIO -> IO Word8
decryptAndReadByte recvRef bufRef rawIO = do
  buf <- readIORef bufRef
  case BS.uncons buf of
    Just (b, rest) -> do
      writeIORef bufRef rest
      pure b
    Nothing -> refill
  where
    -- Pull and decrypt frames until one yields at least one byte.
    refill = do
      ct <- readFramedMessage rawIO
      sess <- readIORef recvRef
      case decryptMessage sess ct of
        Left err -> fail $ "decryptAndReadByte: " <> err
        Right (pt, sess') -> do
          writeIORef recvRef sess'
          case BS.uncons pt of
            Nothing -> refill
            Just (b, rest) -> do
              writeIORef bufRef rest
              pure b

-- | Wrap a 'YamuxSession' as a 'MuxerSession'.
--
-- Starts 'sendLoop' and 'recvLoop' as background threads ('async'). Streams
-- from 'Yamux.openStream' and 'Yamux.acceptStream' are adapted to 'StreamIO'
-- with 'yamuxStreamToStreamIO'. Yamux errors are raised with 'fail', prefixed
-- by the operation name. 'muxClose' delegates to 'closeSession'.
yamuxToMuxerSession :: YamuxSession -> IO MuxerSession
yamuxToMuxerSession yamuxSess = do
  -- NOTE(review): the Async handles are discarded. An exception escaping
  -- either loop is never observed here, and muxClose does not cancel the
  -- threads. Confirm that closeSession makes both loops terminate.
  _ <- async (sendLoop yamuxSess)
  _ <- async (recvLoop yamuxSess)
  pure MuxerSession
    { muxOpenStream =
        Yamux.openStream yamuxSess
          >>= either (failWith "muxOpenStream: ") yamuxStreamToStreamIO
    , muxAcceptStream =
        Yamux.acceptStream yamuxSess
          >>= either (failWith "muxAcceptStream: ") yamuxStreamToStreamIO
    , muxClose = closeSession yamuxSess
    }
  where
    failWith prefix err = fail (prefix <> show err)

-- | Convert a 'YamuxStream' to a 'StreamIO' with a read buffer.
--
-- Yamux delivers data in chunks via 'streamRead', but 'StreamIO' reads one
-- byte at a time. An @IORef ByteString@ holds the unconsumed tail of the
-- last chunk.
--
-- Yamux errors on write or read are raised with 'fail'. A read that returns
-- an empty chunk also fails. 'streamClose' calls 'YS.streamClose' and
-- discards its result.
yamuxStreamToStreamIO :: YamuxStream -> IO StreamIO
yamuxStreamToStreamIO yamuxStream = do
  readBuf <- newIORef BS.empty
  let -- Hand out byte @b@ and keep @rest@ for later reads.
      emit b rest = writeIORef readBuf rest >> pure b

      readByte = do
        buf <- readIORef readBuf
        case BS.uncons buf of
          Just (b, rest) -> emit b rest
          Nothing -> do
            result <- streamRead yamuxStream
            case result of
              Left err -> fail $ "yamuxStreamRead: " <> show err
              Right chunk -> case BS.uncons chunk of
                Nothing -> fail "yamuxStreamRead: empty chunk"
                Just (b, rest) -> emit b rest

      writeBytes bs = do
        result <- YS.streamWrite yamuxStream bs
        case result of
          Right () -> pure ()
          Left err -> fail $ "yamuxStreamWrite: " <> show err

  pure StreamIO
    { streamWrite    = writeBytes
    , streamReadByte = readByte
    , streamClose    = void (YS.streamClose yamuxStream)  -- Sends FIN flag
    }

-- | Upgrade an outbound (dialer) raw connection.
-- Pipeline: mss(/noise) → Noise XX → mss(/yamux/1.0.0) → Yamux client
upgradeOutbound :: KeyPair -> RawConnection -> IO Connection
upgradeOutbound = upgradeConnection "upgradeOutbound" Outbound

-- | Upgrade an inbound (listener) raw connection.
-- Pipeline: mss(/noise) → Noise XX → mss(/yamux/1.0.0) → Yamux server
upgradeInbound :: KeyPair -> RawConnection -> IO Connection
upgradeInbound = upgradeConnection "upgradeInbound" Inbound

-- | Shared upgrade pipeline for both directions.
--
-- The 'Direction' selects three things:
--
--   * which multistream-select side is played (initiator for 'Outbound',
--     responder for 'Inbound');
--   * which Noise handshake side is run;
--   * the Yamux role (client for 'Outbound', server for 'Inbound').
--
-- Negotiation failures are raised with 'fail', prefixed by @caller@.
--
-- NOTE(review): on failure the raw connection is not closed here; presumably
-- the caller owns and closes it. Verify against call sites.
upgradeConnection :: String -> Direction -> KeyPair -> RawConnection -> IO Connection
upgradeConnection caller dir identityKP rawConn = do
  let rawIO = rcStreamIO rawConn

      negotiate = case dir of
        Outbound -> negotiateInitiator
        Inbound  -> negotiateResponder

      -- Yamux: client = odd stream IDs, server = even stream IDs.
      yamuxRole = case dir of
        Outbound -> RoleClient
        Inbound  -> RoleServer

      requireAccepted what result = case result of
        Accepted _ -> pure ()
        NoProtocol -> fail (caller <> ": " <> what <> " negotiation failed")

  -- Step 1: multistream-select → "/noise"
  negotiate rawIO ["/noise"] >>= requireAccepted "/noise"

  -- Step 2: Noise XX handshake
  (noiseSess, HandshakeResult remotePeerId _remotePK) <-
    performStreamHandshake identityKP dir rawIO

  -- Step 3: Create encrypted StreamIO (independent send/recv session state)
  sendRef <- newIORef noiseSess
  recvRef <- newIORef noiseSess
  bufRef  <- newIORef BS.empty
  let encryptedIO = noiseSessionToStreamIO sendRef recvRef bufRef rawIO

  -- Step 4: multistream-select → "/yamux/1.0.0" (over encrypted channel)
  negotiate encryptedIO ["/yamux/1.0.0"] >>= requireAccepted "/yamux/1.0.0"

  -- Step 5: Initialize Yamux session on top of the encrypted channel
  yamuxSess <- newSession yamuxRole (streamWrite encryptedIO) (readExact encryptedIO)
  muxer <- yamuxToMuxerSession yamuxSess

  -- Build Connection
  stateVar <- newTVarIO ConnOpen
  pure Connection
    { connPeerId     = remotePeerId
    , connDirection  = dir
    , connLocalAddr  = rcLocalAddr rawConn
    , connRemoteAddr = rcRemoteAddr rawConn
    , connSecurity   = "/noise"
    , connMuxer      = "/yamux/1.0.0"
    , connSession    = muxer
    , connState      = stateVar
    }