module Network.LibP2P.MultistreamSelect.Negotiation
( NegotiationResult (..)
, ProtocolId
, StreamIO (..)
, negotiateInitiator
, negotiateResponder
, mkMemoryStreamPair
) where
import Control.Concurrent.STM
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Text (Text)
import Data.Word (Word8)
import Network.LibP2P.Core.Varint (decodeUvarint)
import Network.LibP2P.MultistreamSelect.Wire
-- | Identifier of a protocol that is proposed, echoed, or accepted during
-- multistream-select negotiation.
type ProtocolId = Text
-- | Outcome of a multistream-select negotiation, from either side.
data NegotiationResult
  = -- | Both peers agreed on this protocol.
    Accepted !ProtocolId
  | -- | No protocol was agreed: the candidate list was exhausted, or the
    -- exchange failed (header mismatch, read/decode error, unexpected reply).
    NoProtocol
  deriving (Eq, Show)
-- | A minimal byte-oriented duplex stream, supplied by the caller, over
-- which negotiation messages are exchanged.
data StreamIO = StreamIO
  { streamWrite :: ByteString -> IO ()
  -- ^ Send the given bytes to the peer.
  , streamReadByte :: IO Word8
  -- ^ Receive a single byte from the peer.
  , streamClose :: IO ()
  -- ^ Close the stream.
  }
-- | Create two in-memory 'StreamIO' endpoints connected back to back.
--
-- Bytes written on the first endpoint are read from the second and vice
-- versa; each direction is an unbounded 'TQueue' of bytes.
-- 'streamReadByte' blocks until a byte is available, and 'streamClose' is a
-- no-op on both ends.
mkMemoryStreamPair :: IO (StreamIO, StreamIO)
mkMemoryStreamPair = do
  queueAtoB <- newTQueueIO
  queueBtoA <- newTQueueIO
  let endpoint :: TQueue Word8 -> TQueue Word8 -> StreamIO
      endpoint outbound inbound =
        StreamIO
          { -- Enqueue the whole chunk in a single transaction, so a write
            -- is never interleaved byte-by-byte with another thread's write
            -- on the same end, and costs one commit instead of one per byte.
            streamWrite = atomically . mapM_ (writeTQueue outbound) . BS.unpack
          , streamReadByte = atomically (readTQueue inbound)
          , streamClose = pure ()
          }
  pure (endpoint queueAtoB queueBtoA, endpoint queueBtoA queueAtoB)
-- | Read exactly @n@ bytes from the stream, one byte at a time.
--
-- Blocks until all @n@ bytes have been read. A non-positive @n@ reads
-- nothing and yields an empty 'ByteString'.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n =
  fmap BS.pack (sequence (replicate n (streamReadByte stream)))
-- | Read one length-prefixed multistream message from the stream.
--
-- The frame is an unsigned-varint length followed by that many payload
-- bytes. The varint and payload are passed together to 'decodeMessage',
-- and only the decoded text is kept.
--
-- Returns 'Left' with an error string when the varint cannot be decoded,
-- when the announced length exceeds @maxMessageLength@, or when
-- 'decodeMessage' fails.
readMessage :: StreamIO -> IO (Either String Text)
readMessage stream = do
  varintBytes <- readVarint stream
  case decodeUvarint varintBytes of
    Left err -> pure (Left err)
    Right (len, _)
      -- The length is chosen by the peer: reject it before allocating.
      -- This also stops a Word64 above maxBound :: Int from wrapping
      -- negative in the conversion below.
      | len > fromIntegral maxMessageLength ->
          pure (Left ("multistream message too large: " <> show len <> " bytes"))
      | otherwise -> do
          payload <- readExact stream (fromIntegral len)
          -- decodeMessage re-parses the whole frame; keep only the text.
          pure (fst <$> decodeMessage (varintBytes <> payload))
  where
    -- NOTE(review): bound chosen here, not taken from the spec. It is far
    -- above any realistic protocol id; raise it if larger messages are needed.
    maxMessageLength :: Int
    maxMessageLength = 64 * 1024
-- | Read the raw bytes of one unsigned varint from the stream.
--
-- Bytes are consumed until one with the high (continuation) bit clear is
-- seen; that terminating byte is included in the result. Reading stops
-- after @maxVarintBytes@ bytes even if no terminator has arrived, so a peer
-- cannot make us read and buffer without bound. In that case the
-- unterminated bytes are returned as-is.
-- NOTE(review): assumes 'decodeUvarint' rejects such unterminated input
-- with 'Left' — confirm.
readVarint :: StreamIO -> IO ByteString
readVarint stream = go BS.empty
  where
    -- A 64-bit value needs at most ceil(64 / 7) = 10 varint bytes.
    maxVarintBytes :: Int
    maxVarintBytes = 10

    go :: ByteString -> IO ByteString
    go acc
      | BS.length acc >= maxVarintBytes = pure acc
      | otherwise = do
          b <- streamReadByte stream
          let acc' = BS.snoc acc b
          if b < 0x80
            then pure acc'
            else go acc'
-- | Encode a message with 'encodeMessage' and send it with a single
-- 'streamWrite' call.
writeMessage :: StreamIO -> Text -> IO ()
writeMessage stream = streamWrite stream . encodeMessage
-- | Run the initiator side of multistream-select.
--
-- Sends 'multistreamHeader' and waits for the peer to reply with the same
-- header. It then proposes each protocol in order. A proposal echoed back
-- is accepted; an 'naMessage' reply moves on to the next candidate.
--
-- Yields 'NoProtocol' on any read/decode failure, a header mismatch, an
-- unexpected reply, or when every candidate has been refused.
negotiateInitiator :: StreamIO -> [ProtocolId] -> IO NegotiationResult
negotiateInitiator stream protocols = do
  writeMessage stream multistreamHeader
  headerReply <- readMessage stream
  case headerReply of
    Right hdr | hdr == multistreamHeader -> propose protocols
    _ -> pure NoProtocol
  where
    propose :: [ProtocolId] -> IO NegotiationResult
    propose [] = pure NoProtocol
    propose (candidate : remaining) = do
      writeMessage stream candidate
      reply <- readMessage stream
      case reply of
        -- If neither guard matches, control falls through to the
        -- catch-all alternative below.
        Right answer
          | answer == candidate -> pure (Accepted candidate)
          | answer == naMessage -> propose remaining
        _ -> pure NoProtocol
-- | Run the responder side of multistream-select.
--
-- Waits for the peer's 'multistreamHeader' and replies with the same header.
-- It then answers proposals: a proposal found in @supported@ is echoed back
-- and accepted; any other proposal is answered with 'naMessage', and the
-- responder waits for the next one.
--
-- Yields 'NoProtocol' on a header mismatch or any read/decode failure.
-- There is no cap on the number of refused proposals; the loop ends only
-- on acceptance or a failed read.
negotiateResponder :: StreamIO -> [ProtocolId] -> IO NegotiationResult
negotiateResponder stream supported = do
  opening <- readMessage stream
  case opening of
    Right hdr | hdr == multistreamHeader -> do
      writeMessage stream multistreamHeader
      awaitProposal
    _ -> pure NoProtocol
  where
    awaitProposal :: IO NegotiationResult
    awaitProposal =
      readMessage stream >>= either (const (pure NoProtocol)) respond

    respond :: ProtocolId -> IO NegotiationResult
    respond proposal
      | proposal `elem` supported =
          writeMessage stream proposal >> pure (Accepted proposal)
      | otherwise =
          writeMessage stream naMessage >> awaitProposal