-- | multistream-select protocol negotiation.
--
-- Implements Initiator and Responder roles for negotiating
-- which protocol to use over a connection or stream.
module Network.LibP2P.MultistreamSelect.Negotiation
  ( NegotiationResult (..)
  , ProtocolId
  , StreamIO (..)
  , negotiateInitiator
  , negotiateResponder
  , mkMemoryStreamPair
  ) where

import Control.Concurrent.STM
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Text (Text)
import Data.Word (Word8)
import Network.LibP2P.Core.Varint (decodeUvarint)
import Network.LibP2P.MultistreamSelect.Wire

-- | Name of a protocol as exchanged during negotiation, for example
-- @"/noise"@ or @"/yamux/1.0.0"@.
type ProtocolId = Text

-- | Outcome of running 'negotiateInitiator' or 'negotiateResponder'.
data NegotiationResult
  = -- | Both sides agreed on this protocol.
    Accepted !ProtocolId
  | -- | Negotiation ended without agreeing on a protocol (for instance the
    -- initiator ran out of candidates, or an unexpected or undecodable
    -- message was received).
    NoProtocol
  deriving (Eq, Show)

-- | Byte-level stream operations used by negotiation, bundled as a record of
-- actions so that tests can substitute in-memory buffers for a real stream.
data StreamIO = StreamIO
  { -- | Write the given bytes to the stream.
    streamWrite :: ByteString -> IO ()
  , -- | Read exactly one byte, blocking until one is available.
    streamReadByte :: IO Word8
  , -- | Close or half-close the stream (intended to signal EOF to the remote).
    streamClose :: IO ()
  }

-- | Create a connected pair of in-memory streams for testing.
--
-- Bytes written to the first stream become readable on the second, and vice
-- versa. Each direction is backed by its own STM 'TQueue' of bytes, and reads
-- block (STM retry) until a byte is available.
--
-- Each 'streamWrite' enqueues its whole chunk in a single STM transaction.
-- A chunk is therefore never interleaved with bytes from another thread
-- writing to the same stream, so message framing stays intact.
--
-- 'streamClose' is a no-op for these streams: it does not signal EOF to the
-- peer.
mkMemoryStreamPair :: IO (StreamIO, StreamIO)
mkMemoryStreamPair = do
  queueAtoB <- newTQueueIO :: IO (TQueue Word8)
  queueBtoA <- newTQueueIO :: IO (TQueue Word8)
  let -- One transaction per chunk, not per byte (see above).
      writeToQueue q bs = atomically (mapM_ (writeTQueue q) (BS.unpack bs))
      readFromQueue q = atomically (readTQueue q)
  pure
    ( StreamIO (writeToQueue queueAtoB) (readFromQueue queueBtoA) (pure ())
    , StreamIO (writeToQueue queueBtoA) (readFromQueue queueAtoB) (pure ())
    )

-- | Read exactly @n@ bytes from a stream, blocking until all have arrived.
--
-- Bytes are pulled one at a time with 'streamReadByte' and returned in the
-- order they were read. A non-positive @n@ reads nothing and yields the empty
-- string.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n = loop n []
  where
    -- Accumulate bytes in reverse (cheap cons), then restore order at the end.
    loop :: Int -> [Word8] -> IO ByteString
    loop remaining acc
      | remaining <= 0 = pure (BS.pack (reverse acc))
      | otherwise = do
          byte <- streamReadByte stream
          loop (remaining - 1) (byte : acc)

-- | Read one complete multistream-select message from a stream.
--
-- The varint length prefix is read byte-by-byte with 'readVarint' and decoded
-- with 'decodeUvarint'. Exactly that many payload bytes are then read with
-- 'readExact'. Prefix and payload are passed together to 'decodeMessage',
-- and its decoded text is returned.
--
-- The length comes from the remote peer, so it is checked against
-- 'maxMessageLen' (64 KiB) before any payload is read. Without this cap, a
-- single forged prefix could make 'readExact' try to read and buffer an
-- enormous payload, or overflow the conversion to 'Int'.
--
-- Returns @Left@ with a description when the prefix or the message fails to
-- decode, or when the announced length exceeds the cap.
readMessage :: StreamIO -> IO (Either String Text)
readMessage stream = do
  varintBytes <- readVarint stream
  case decodeUvarint varintBytes of
    Left err -> pure (Left err)
    Right (len, _)
      | len > fromIntegral maxMessageLen ->
          pure
            ( Left
                ( "multistream-select message length "
                    ++ show len
                    ++ " exceeds limit of "
                    ++ show maxMessageLen
                    ++ " bytes"
                )
            )
      | otherwise -> do
          -- Safe: len <= maxMessageLen, so it fits in an Int.
          payload <- readExact stream (fromIntegral len)
          -- decodeMessage yields (message, leftover); only the message is kept.
          pure (fst <$> decodeMessage (varintBytes <> payload))
  where
    -- Generous upper bound on a single message; protocol ids are short.
    maxMessageLen :: Int
    maxMessageLen = 64 * 1024

-- | Read the raw bytes of an unsigned varint from the stream, one byte at a
-- time.
--
-- Bytes are accumulated until one with the high (continuation) bit clear is
-- seen. That terminating byte is included in the result, and the bytes are
-- returned undecoded.
--
-- At most 'maxVarintBytes' bytes are read, so a peer that never sends a
-- terminating byte cannot make this loop and its accumulator grow without
-- bound. When the limit is hit, the bytes read so far are returned as-is.
-- Their last byte still has the continuation bit set, so the caller's decoder
-- presumably reports them as malformed (NOTE(review): confirm that
-- 'decodeUvarint' rejects an unterminated varint).
readVarint :: StreamIO -> IO ByteString
readVarint stream = go BS.empty
  where
    go :: ByteString -> IO ByteString
    go acc
      | BS.length acc >= maxVarintBytes = pure acc
      | otherwise = do
          b <- streamReadByte stream
          let acc' = acc <> BS.singleton b
          if b < 0x80
            then pure acc'
            else go acc'

    -- A 64-bit value needs at most ceil(64 / 7) = 10 bytes in LEB128 form.
    maxVarintBytes :: Int
    maxVarintBytes = 10

-- | Encode a message with 'encodeMessage' and send the resulting bytes to the
-- stream in a single 'streamWrite' call.
writeMessage :: StreamIO -> Text -> IO ()
writeMessage stream = streamWrite stream . encodeMessage

-- | Negotiate a protocol as the Initiator.
--
-- Sends the multistream-select header and expects the same header back.
-- It then proposes each protocol from the list in order:
--
-- * if the responder echoes the proposal, it is 'Accepted';
-- * if the reply is 'naMessage', the next candidate is tried;
-- * any other reply, a read/decode failure, or running out of candidates
--   ends the negotiation with 'NoProtocol'.
negotiateInitiator :: StreamIO -> [ProtocolId] -> IO NegotiationResult
negotiateInitiator stream protocols = do
  writeMessage stream multistreamHeader
  headerReply <- readMessage stream
  either (const (pure NoProtocol)) onHeader headerReply
  where
    -- Only start proposing once the responder has echoed the header.
    onHeader :: Text -> IO NegotiationResult
    onHeader reply
      | reply == multistreamHeader = propose protocols
      | otherwise = pure NoProtocol

    propose :: [ProtocolId] -> IO NegotiationResult
    propose [] = pure NoProtocol
    propose (candidate : remaining) = do
      writeMessage stream candidate
      reply <- readMessage stream
      case reply of
        Right answer
          | answer == candidate -> pure (Accepted candidate)
          | answer == naMessage -> propose remaining
        -- Decode failures and unexpected replies both abort.
        _ -> pure NoProtocol

-- | Negotiate a protocol as the Responder.
--
-- Waits for the initiator's multistream-select header and echoes it back.
-- It then answers proposals one at a time:
--
-- * a proposal found in @supported@ is echoed and returned as 'Accepted';
-- * any other proposal is answered with 'naMessage', and the next proposal
--   is awaited.
--
-- A wrong header, or a message that cannot be read/decoded, yields
-- 'NoProtocol'. There is no limit on the number of rejected proposals.
negotiateResponder :: StreamIO -> [ProtocolId] -> IO NegotiationResult
negotiateResponder stream supported =
  readMessage stream >>= \headerMsg -> case headerMsg of
    Right h | h == multistreamHeader -> do
      writeMessage stream multistreamHeader
      awaitProposal
    _ -> pure NoProtocol
  where
    awaitProposal :: IO NegotiationResult
    awaitProposal = readMessage stream >>= \msg -> case msg of
      Left _ -> pure NoProtocol
      Right proposal
        | proposal `elem` supported ->
            writeMessage stream proposal >> pure (Accepted proposal)
        | otherwise ->
            writeMessage stream naMessage >> awaitProposal