-- | Noise XX handshake for libp2p secure channels.
--
-- Implements the Noise_XX_25519_ChaChaPoly_SHA256 handshake pattern
-- with libp2p-specific payload injection (identity key + signature).
--
-- Uses cacophony for the core Noise protocol state machine.
module Network.LibP2P.Security.Noise.Handshake
  ( -- * Handshake types
    HandshakeResult (..)
  , NoisePayload (..)
  , HandshakeState (..)
    -- * Payload encoding
  , encodeNoisePayload
  , decodeNoisePayload
  , buildHandshakePayload
  , validateHandshakePayload
    -- * Static key signing
  , signStaticKey
  , verifyStaticKey
    -- * Handshake lifecycle
  , initHandshakeInitiator
  , initHandshakeResponder
  , writeHandshakeMsg
  , readHandshakeMsg
  , sessionComplete
    -- * Remote static key extraction
  , getRemoteNoiseStaticKey
    -- * Convenience
  , performFullHandshake
  , performFullHandshakeWithSessions
    -- * Re-exports for payload decoding
  , decodePublicKey
  ) where

import Crypto.Noise
  ( HandshakeRole (..)
  , NoiseResult (..)
  , NoiseState
  , convert
  , defaultHandshakeOpts
  , handshakeComplete
  , noiseState
  , readMessage
  , remoteStaticKey
  , setLocalEphemeral
  , setLocalStatic
  , writeMessage
  )
import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305)
import Crypto.Noise.DH (dhGenKey, dhPubToBytes)
import qualified Crypto.Noise.DH as DH
import Crypto.Noise.DH.Curve25519 (Curve25519)
import Crypto.Noise.HandshakePatterns (noiseXX)
import Crypto.Noise.Hash.SHA256 (SHA256)
import Data.ByteArray (ScrubbedBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Word (Word8)
import Network.LibP2P.Core.Varint (decodeUvarint, encodeUvarint)
import Network.LibP2P.Crypto.Key
  ( KeyPair (..)
  , PrivateKey (..)
  , PublicKey (..)
  , verify
  )
import qualified Network.LibP2P.Crypto.Key as Key
import Network.LibP2P.Crypto.PeerId (PeerId, fromPublicKey)
import Network.LibP2P.Crypto.Protobuf (decodePublicKey, encodePublicKey)
import Network.LibP2P.Security.Noise.Session (NoiseSession, mkNoiseSession)

-- | Type alias for the Noise state with our fixed cipher suite, matching the
-- Noise_XX_25519_ChaChaPoly_SHA256 protocol name: ChaCha20-Poly1305 cipher,
-- Curve25519 Diffie-Hellman, SHA-256 hash.
type CacophonyState = NoiseState ChaChaPoly1305 Curve25519 SHA256

-- | Handshake state wrapping cacophony's 'NoiseState' for our cipher suite.
--
-- Values are threaded through 'writeHandshakeMsg' / 'readHandshakeMsg', each
-- of which returns the updated state.
-- NOTE(review): the constructor and 'hsNoiseState' are exported via
-- @HandshakeState (..)@, so this is not actually opaque to callers.
newtype HandshakeState = HandshakeState
  { hsNoiseState :: CacophonyState
  -- ^ The underlying cacophony state machine.
  }

-- | Result of a successful Noise handshake.
data HandshakeResult = HandshakeResult
  { hrRemotePeerId :: !PeerId
  -- ^ Peer ID derived from the remote identity key.
  , hrRemotePublicKey :: !PublicKey
  -- ^ The remote libp2p identity public key.
  }
  deriving (Show, Eq)

-- | Noise handshake payload (protobuf-encoded in messages 2 and 3).
data NoisePayload = NoisePayload
  { npIdentityKey :: !ByteString
  -- ^ Serialized PublicKey protobuf
  , npIdentitySig :: !ByteString
  -- ^ Signature over "noise-libp2p-static-key:" || static_pubkey
  }
  deriving (Show, Eq)

-- | The prefix for the signed data in the Noise handshake.
--
-- The identity signature is computed over this prefix concatenated with the
-- raw Noise static public key (see 'signStaticKey' / 'verifyStaticKey').
noiseStaticKeyPrefix :: ByteString
noiseStaticKeyPrefix = "noise-libp2p-static-key:"

-- | Sign the Noise static public key with the identity private key.
--
-- The signed message is 'noiseStaticKeyPrefix' followed by the static key
-- bytes. Signing errors from 'Key.sign' are passed through as 'Left'.
signStaticKey :: PrivateKey -> ByteString -> Either String ByteString
signStaticKey sk noiseStaticPubKey =
  Key.sign sk (noiseStaticKeyPrefix <> noiseStaticPubKey)

-- | Verify a signature over the Noise static public key.
--
-- Arguments: the remote identity public key, the remote Noise static public
-- key, and the signature taken from the handshake payload. The signed message
-- is reconstructed exactly as in 'signStaticKey'.
verifyStaticKey :: PublicKey -> ByteString -> ByteString -> Bool
verifyStaticKey pk noiseStaticPubKey sig =
  verify pk (noiseStaticKeyPrefix <> noiseStaticPubKey) sig

-- | Build a handshake payload from an identity key pair and Noise static pubkey.
--
-- The payload carries the protobuf-encoded identity public key and a signature
-- (by the identity private key) over the Noise static public key.
--
-- Calls 'error' if signing fails, since this function's interface cannot report
-- failure. Because 'NoisePayload' has strict fields, the error surfaces as soon
-- as the payload is forced (e.g. by 'encodeNoisePayload').
buildHandshakePayload :: Key.KeyPair -> ByteString -> NoisePayload
buildHandshakePayload identityKP noiseStaticPub =
  NoisePayload identKey identSig
  where
    identKey :: ByteString
    identKey = encodePublicKey (kpPublic identityKP)

    identSig :: ByteString
    identSig =
      either
        (\err -> error ("buildHandshakePayload: " <> err))
        id
        (signStaticKey (kpPrivate identityKP) noiseStaticPub)

-- | Validate a handshake payload by decoding its identity key.
--
-- Does NOT verify the signature; the caller must check 'npIdentitySig' with
-- 'verifyStaticKey' against the remote Noise static key.
validateHandshakePayload :: NoisePayload -> Either String PublicKey
validateHandshakePayload = decodePublicKey . npIdentityKey

-- | Encode a NoisePayload as a minimal protobuf message.
--
-- Emits field 1 (identity key) then field 2 (identity signature), both as
-- length-delimited (wire type 2) fields with a varint length prefix.
encodeNoisePayload :: NoisePayload -> ByteString
encodeNoisePayload (NoisePayload identKey identSig) =
  lengthDelimited 0x0a identKey -- field 1, wire type 2
    <> lengthDelimited 0x12 identSig -- field 2, wire type 2
  where
    -- One length-delimited protobuf field: tag byte, varint length, bytes.
    lengthDelimited :: Word8 -> ByteString -> ByteString
    lengthDelimited tag bytes =
      BS.singleton tag
        <> encodeUvarint (fromIntegral (BS.length bytes))
        <> bytes

-- | Decode a NoisePayload from protobuf bytes.
--
-- Expects field 1 (tag 0x0a) followed by field 2 (tag 0x12), each
-- length-delimited. Any bytes after field 2 are ignored.
decodeNoisePayload :: ByteString -> Either String NoisePayload
decodeNoisePayload bs = do
  (identKey, rest1) <- decodeField 0x0a bs
  (identSig, _rest2) <- decodeField 0x12 rest1
  Right (NoisePayload identKey identSig)
  where
    -- Parse one length-delimited field with the given tag byte, returning the
    -- field contents and the remaining input.
    decodeField :: Word8 -> ByteString -> Either String (ByteString, ByteString)
    decodeField expectedTag input = case BS.uncons input of
      Nothing -> Left "decodeNoisePayload: unexpected end of input"
      Just (tag, rest)
        | tag /= expectedTag ->
            Left $
              "decodeNoisePayload: expected tag "
                <> show expectedTag
                <> " got "
                <> show tag
        | otherwise -> do
            (len, body) <- decodeUvarint rest
            -- Compare in Word64 *before* narrowing to Int: a huge length from
            -- untrusted input would otherwise wrap (possibly to a negative
            -- Int) and slip past the bounds check.
            if len > fromIntegral (BS.length body)
              then Left "decodeNoisePayload: not enough bytes for field"
              else Right (BS.splitAt (fromIntegral len) body)

-- | Initialize a handshake state for the initiator role.
-- Returns (HandshakeState, noiseStaticPublicKey).
--
-- The identity key pair is not used here. It is bound to the returned Noise
-- static key later via 'buildHandshakePayload'.
initHandshakeInitiator :: Key.KeyPair -> IO (HandshakeState, ByteString)
initHandshakeInitiator _identityKP = initHandshakeForRole InitiatorRole

-- | Initialize a handshake state for the responder role.
-- Returns (HandshakeState, noiseStaticPublicKey).
--
-- The identity key pair is not used here. It is bound to the returned Noise
-- static key later via 'buildHandshakePayload'.
initHandshakeResponder :: Key.KeyPair -> IO (HandshakeState, ByteString)
initHandshakeResponder _identityKP = initHandshakeForRole ResponderRole

-- | Shared setup for both roles: generate fresh Curve25519 static and
-- ephemeral key pairs (in that order), and build an XX-pattern 'NoiseState'
-- with an empty prologue. Returns the state and the raw static public key.
initHandshakeForRole :: HandshakeRole -> IO (HandshakeState, ByteString)
initHandshakeForRole role = do
  noiseStaticKP <- dhGenKey :: IO (DH.KeyPair Curve25519)
  noiseEphemeralKP <- dhGenKey :: IO (DH.KeyPair Curve25519)
  let noiseStaticPub = convert (dhPubToBytes (snd noiseStaticKP)) :: ByteString
      opts =
        setLocalStatic (Just noiseStaticKP)
          . setLocalEphemeral (Just noiseEphemeralKP)
          $ defaultHandshakeOpts role ""
      ns = noiseState opts noiseXX :: CacophonyState
  pure (HandshakeState ns, noiseStaticPub)

-- | Write a handshake message with the given payload.
-- Returns (ciphertext, updatedState).
writeHandshakeMsg :: HandshakeState -> ByteString -> Either String (ByteString, HandshakeState)
writeHandshakeMsg hs payload =
  fromNoiseResult "writeHandshakeMsg" $
    writeMessage (convert payload :: ScrubbedBytes) (hsNoiseState hs)

-- | Read a handshake message and extract the decrypted payload.
-- Returns (plaintext, updatedState).
readHandshakeMsg :: HandshakeState -> ByteString -> Either String (ByteString, HandshakeState)
readHandshakeMsg hs ciphertext =
  fromNoiseResult "readHandshakeMsg" $
    readMessage (convert ciphertext :: ScrubbedBytes) (hsNoiseState hs)

-- | Translate a cacophony 'NoiseResult' into our 'Either' form.
--
-- The first argument is the caller's name, used as the error-message prefix.
-- A PSK request is reported as an error because the XX pattern used here
-- carries no pre-shared key.
fromNoiseResult
  :: String
  -> NoiseResult ChaChaPoly1305 Curve25519 SHA256
  -> Either String (ByteString, HandshakeState)
fromNoiseResult ctx result = case result of
  NoiseResultMessage out ns' -> Right (convert out, HandshakeState ns')
  NoiseResultException ex -> Left (ctx <> ": " <> show ex)
  NoiseResultNeedPSK _ -> Left (ctx <> ": unexpected PSK request")

-- | Extract the remote party's Noise static public key from the handshake state.
-- Returns Just after the remote static key has been transmitted (msg2 for
-- initiator, msg3 for responder in the XX pattern), otherwise Nothing.
getRemoteNoiseStaticKey :: HandshakeState -> Maybe ByteString
getRemoteNoiseStaticKey hs =
  convert . dhPubToBytes <$> remoteStaticKey (hsNoiseState hs)

-- | Check whether the handshake is complete, as reported by cacophony's
-- 'handshakeComplete' on the underlying state.
sessionComplete :: HandshakeState -> Bool
sessionComplete = handshakeComplete . hsNoiseState

-- | Perform a full 3-message XX handshake between two peers (in-process).
-- Returns the remote PeerId as seen by each side: (Bob's PeerId as seen by
-- Alice, Alice's PeerId as seen by Bob).
performFullHandshake :: Key.KeyPair -> Key.KeyPair -> IO (Either String (PeerId, PeerId))
performFullHandshake aliceIdentity bobIdentity =
  fmap toPeerIds <$> runXXHandshake "performFullHandshake" aliceIdentity bobIdentity
  where
    toPeerIds (_, _, bobPubKey, alicePubKey) =
      (fromPublicKey bobPubKey, fromPublicKey alicePubKey)

-- | Perform a full handshake and return transport sessions for both sides,
-- as (Alice's session, Bob's session).
performFullHandshakeWithSessions :: Key.KeyPair -> Key.KeyPair -> IO (Either String (NoiseSession, NoiseSession))
performFullHandshakeWithSessions aliceIdentity bobIdentity =
  fmap toSessions
    <$> runXXHandshake "performFullHandshakeWithSessions" aliceIdentity bobIdentity
  where
    toSessions (aliceFinal, bobFinal, _, _) =
      (mkNoiseSession (hsNoiseState aliceFinal), mkNoiseSession (hsNoiseState bobFinal))

-- | Drive the three XX messages between Alice (initiator) and Bob (responder),
-- verifying each side's identity signature.
--
-- The first argument is the public caller's name, used as the error-message
-- prefix. On success returns (Alice's final state, Bob's final state, Bob's
-- identity key as seen by Alice, Alice's identity key as seen by Bob).
runXXHandshake
  :: String
  -> Key.KeyPair
  -> Key.KeyPair
  -> IO (Either String (HandshakeState, HandshakeState, PublicKey, PublicKey))
runXXHandshake ctx aliceIdentity bobIdentity = do
  (aliceInit, aliceNoiseStaticPub) <- initHandshakeInitiator aliceIdentity
  (bobInit, bobNoiseStaticPub) <- initHandshakeResponder bobIdentity
  pure $ do
    -- Message 1: Alice → Bob (empty payload)
    (msg1, aliceState1) <- writeHandshakeMsg aliceInit BS.empty
    (_payload1, bobState1) <- readHandshakeMsg bobInit msg1

    -- Message 2: Bob → Alice (Bob's identity payload)
    let bobPayload = encodeNoisePayload (buildHandshakePayload bobIdentity bobNoiseStaticPub)
    (msg2, bobState2) <- writeHandshakeMsg bobState1 bobPayload
    (payload2, aliceState2) <- readHandshakeMsg aliceState1 msg2
    bobPubKey <- verifyRemotePayload ctx "Bob" "msg2" aliceState2 payload2

    -- Message 3: Alice → Bob (Alice's identity payload)
    let alicePayload = encodeNoisePayload (buildHandshakePayload aliceIdentity aliceNoiseStaticPub)
    (msg3, aliceFinal) <- writeHandshakeMsg aliceState2 alicePayload
    (payload3, bobFinal) <- readHandshakeMsg bobState2 msg3
    alicePubKey <- verifyRemotePayload ctx "Alice" "msg3" bobFinal payload3

    Right (aliceFinal, bobFinal, bobPubKey, alicePubKey)

-- | Decode a peer's handshake payload and check that its identity signature
-- binds the identity key to the Noise static key the receiver actually got.
--
-- Arguments: caller name (error prefix), peer label, message label, the
-- receiver's state after reading the message, and the decrypted payload.
-- Returns the peer's identity public key on success.
verifyRemotePayload
  :: String
  -> String
  -> String
  -> HandshakeState
  -> ByteString
  -> Either String PublicKey
verifyRemotePayload ctx peer msgLabel st payload = do
  np <- decodeNoisePayload payload
  pubKey <- decodePublicKey (npIdentityKey np)
  remoteNoisePub <-
    maybe
      (Left (ctx <> ": remote Noise static key unavailable after " <> msgLabel))
      Right
      (getRemoteNoiseStaticKey st)
  if verifyStaticKey pubKey remoteNoisePub (npIdentitySig np)
    then Right pubKey
    else Left (ctx <> ": " <> peer <> "'s identity signature verification failed")