module Network.LibP2P.Protocol.Ping.Ping
(
pingProtocolId
, PingError (..)
, PingResult (..)
, handlePing
, sendPing
, registerPingHandler
, pingSize
) where
import Control.Exception
  ( SomeAsyncException
  , SomeException
  , catch
  , fromException
  , throwIO
  , try
  )
import Control.Monad (replicateM)
import GHC.Clock (getMonotonicTime)

import Control.Concurrent.STM (atomically, readTVar, writeTVar)
import Crypto.Random (getRandomBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map
import Data.Text (Text)
import Data.Time.Clock (NominalDiffTime, diffUTCTime, getCurrentTime)

import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.MultistreamSelect.Negotiation
  ( StreamIO (..)
  , negotiateInitiator
  , NegotiationResult (..)
  )
import Network.LibP2P.Switch.Types
  ( Connection (..)
  , MuxerSession (..)
  , Switch (..)
  )
-- | Multistream-select protocol id negotiated for ping streams.
--
-- Used both as the key 'registerPingHandler' installs 'handlePing' under
-- and as the single candidate offered by 'sendPing'.
pingProtocolId :: Text
pingProtocolId = "/ipfs/ping/1.0.0"
-- | Size in bytes of each ping payload.
--
-- The sender writes this many random bytes, and both sides read exactly
-- this many bytes per echo.
pingSize :: Int
pingSize = 32
-- | Reasons an outbound ping can fail.
data PingError
  = PingTimeout
    -- ^ The ping took too long. No code shown in this module constructs it;
    -- presumably it is produced by callers that add their own timeout.
  | PingMismatch
    -- ^ The echoed bytes differed from the payload that was sent.
  | PingStreamError !String
    -- ^ Negotiation or stream failure, with a human-readable reason.
  deriving (Show, Eq)
-- | Outcome of a successful ping.
data PingResult = PingResult
  { pingRTT :: !NominalDiffTime
    -- ^ Round-trip time: from sending the payload until its echo was read.
  }
  deriving (Show, Eq)
-- | Inbound handler for 'pingProtocolId'.
--
-- Reads 'pingSize'-byte payloads from @stream@ and writes each one straight
-- back, looping until a read fails. A synchronous read failure ends the loop
-- normally with @()@; presumably this is how a closed stream is observed.
--
-- Asynchronous exceptions (e.g. 'Control.Exception.ThreadKilled', or the one
-- delivered by 'System.Timeout.timeout') are re-thrown rather than treated as
-- end of stream, so the handler thread can still be cancelled. Exceptions
-- from 'streamWrite' are not caught. The remote peer id is unused.
handlePing :: StreamIO -> PeerId -> IO ()
handlePing stream _remotePeerId = echoLoop
  where
    echoLoop :: IO ()
    echoLoop = do
      result <- (Just <$> readExact stream pingSize) `catch` onReadFailure
      case result of
        Nothing -> pure ()
        Just payload -> do
          streamWrite stream payload
          echoLoop

    -- Synchronous read failures stop the echo loop. Asynchronous exceptions
    -- must propagate; swallowing them would silently defeat cancellation.
    onReadFailure :: SomeException -> IO (Maybe ByteString)
    onReadFailure e
      | isAsync e = throwIO e
      | otherwise = pure Nothing

    isAsync :: SomeException -> Bool
    isAsync e = case (fromException e :: Maybe SomeAsyncException) of
      Just _ -> True
      Nothing -> False
-- | Outbound ping over an existing connection.
--
-- Opens a new stream on the connection's muxer session and negotiates
-- 'pingProtocolId'. It then writes 'pingSize' random bytes and reads the
-- same number of bytes back.
--
-- Results:
--
-- * @Left (PingStreamError "remote does not support ping")@ when negotiation
--   yields 'NoProtocol';
-- * @Left (PingStreamError "read failed")@ when reading the echo throws a
--   synchronous exception;
-- * @Left PingMismatch@ when the echo differs from the payload;
-- * @Right (PingResult rtt)@ otherwise.
--
-- @rtt@ is measured on the monotonic clock. It runs from just before the
-- payload is written until the full echo has been read, so wall-clock
-- adjustments cannot inflate or negate it.
--
-- Exceptions from opening the stream, negotiating or writing the payload
-- are not caught. Asynchronous exceptions during the read are re-thrown, so
-- wrapping this call in 'System.Timeout.timeout' behaves as expected. This
-- function never returns 'PingTimeout' itself.
sendPing :: Connection -> IO (Either PingError PingResult)
sendPing conn = do
  stream <- muxOpenStream (connSession conn)
  negotiated <- negotiateInitiator stream [pingProtocolId]
  case negotiated of
    NoProtocol ->
      pure (Left (PingStreamError "remote does not support ping"))
    Accepted _ -> do
      payload <- getRandomBytes pingSize :: IO ByteString
      t0 <- getMonotonicTime
      streamWrite stream payload
      response <- try (readExact stream pingSize)
      -- Stop the clock as soon as the echo is in, before any comparison.
      t1 <- getMonotonicTime
      case response of
        Left e
          | isAsync e -> throwIO e
          | otherwise -> pure (Left (PingStreamError "read failed"))
        Right echo
          | echo /= payload -> pure (Left PingMismatch)
          | otherwise -> pure (Right (PingResult (realToFrac (t1 - t0))))
  where
    -- Also fixes the exception type caught by 'try' above to SomeException.
    isAsync :: SomeException -> Bool
    isAsync e = case (fromException e :: Maybe SomeAsyncException) of
      Just _ -> True
      Nothing -> False
-- | Install 'handlePing' in the switch's protocol table under
-- 'pingProtocolId'.
--
-- The read and the write happen in one STM transaction. 'Map.insert'
-- replaces any handler previously registered for the same id.
registerPingHandler :: Switch -> IO ()
registerPingHandler sw =
  atomically (readTVar table >>= writeTVar table . Map.insert pingProtocolId handlePing)
  where
    table = swProtocols sw
-- | Read exactly @n@ bytes from @stream@ using one 'streamReadByte' call per
-- byte, in order, and pack them into a strict 'ByteString'.
--
-- Returns the empty string when @n <= 0@. It does not return until @n@ bytes
-- have been read or a read throws. Any exception from 'streamReadByte'
-- propagates, and the bytes read so far are discarded. Presumably that
-- exception signals end of stream; confirm against the 'StreamIO'
-- implementation.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n = BS.pack <$> replicateM n (streamReadByte stream)