-- | Ping protocol implementation (docs/07-protocols.md).
--
-- Protocol ID: /ipfs/ping/1.0.0
--
-- Wire format: 32 bytes random → 32 bytes echo. No framing, no protobuf.
-- The handler runs an echo loop: reads 32 bytes, writes them back,
-- until the stream closes. The initiator sends 32 random bytes,
-- measures round-trip time, and verifies the echo matches.
module Network.LibP2P.Protocol.Ping.Ping
  ( -- * Protocol ID
    pingProtocolId
    -- * Types
  , PingError (..)
  , PingResult (..)
    -- * Protocol logic
  , handlePing
  , sendPing
    -- * Registration
  , registerPingHandler
    -- * Constants
  , pingSize
  ) where

import Control.Exception
  ( SomeAsyncException
  , SomeException
  , catch
  , fromException
  , throwIO
  )
import Control.Monad (replicateM)
import GHC.Clock (getMonotonicTime)

import Control.Concurrent.STM (atomically, readTVar, writeTVar)
import Crypto.Random (getRandomBytes)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map
import Data.Text (Text)
import qualified Data.Text as T
import Data.Time.Clock (NominalDiffTime, diffUTCTime, getCurrentTime)

import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.MultistreamSelect.Negotiation
  ( StreamIO (..)
  , negotiateInitiator
  , NegotiationResult (..)
  )
import Network.LibP2P.Switch.Types
  ( Connection (..)
  , MuxerSession (..)
  , Switch (..)
  )

-- | Ping protocol ID, as negotiated via multistream-select.
--
-- Built with 'T.pack' so the definition does not depend on the
-- @OverloadedStrings@ extension being enabled.
pingProtocolId :: Text
pingProtocolId = T.pack "/ipfs/ping/1.0.0"

-- | Ping payload size in bytes. The initiator sends exactly this many random
-- bytes and the responder echoes exactly this many back.
pingSize :: Int
pingSize = 32

-- | Ways a ping exchange can fail.
--
-- NOTE(review): 'PingTimeout' is not produced anywhere in this module;
-- presumably a caller wrapping 'sendPing' in a timeout reports it — confirm.
data PingError
  = PingTimeout              -- ^ No response within timeout
  | PingMismatch             -- ^ Echoed bytes differ from the bytes sent
  | PingStreamError !String  -- ^ Stream I/O or negotiation failure, with a description
  deriving (Show, Eq)

-- | Result of a successful ping exchange.
data PingResult = PingResult
  { pingRTT :: !NominalDiffTime  -- ^ Round-trip time
  } deriving (Show, Eq)

-- | Handle an inbound Ping request (responder side / echo loop).
--
-- Repeatedly reads exactly 'pingSize' bytes from @stream@ and writes the same
-- bytes back. The loop ends normally when a read throws a synchronous
-- exception, which is treated as "stream closed".
--
-- Asynchronous exceptions (e.g. 'Control.Exception.ThreadKilled', or the
-- exception raised by 'System.Timeout.timeout') are re-thrown. They are not
-- mistaken for end-of-stream, so whoever runs this handler can still cancel it.
-- Exceptions thrown by 'streamWrite' are not caught and propagate.
--
-- The remote 'PeerId' is unused.
handlePing :: StreamIO -> PeerId -> IO ()
handlePing stream _remotePeerId = echoLoop
  where
    echoLoop :: IO ()
    echoLoop = do
      result <- (Just <$> readExact stream pingSize) `catch` onReadError
      case result of
        Nothing      -> pure ()  -- Stream closed (or read failed): exit loop
        Just payload -> do
          streamWrite stream payload
          echoLoop

    -- A synchronous read failure ends the loop quietly; asynchronous
    -- exceptions must never be swallowed.
    onReadError :: SomeException -> IO (Maybe ByteString)
    onReadError e
      | isAsyncException e = throwIO e
      | otherwise          = pure Nothing

    isAsyncException :: SomeException -> Bool
    isAsyncException e =
      case (fromException e :: Maybe SomeAsyncException) of
        Just _  -> True
        Nothing -> False

-- | Send a Ping to a remote peer (initiator side).
--
-- Opens a new stream on the connection's muxer session and negotiates
-- 'pingProtocolId'. It then writes 'pingSize' random bytes, reads
-- 'pingSize' bytes back and checks that they match.
--
-- Results:
--
-- * @Left (PingStreamError "remote does not support ping")@ when negotiation
--   returns 'NoProtocol';
-- * @Left (PingStreamError "read failed")@ when reading the echo throws a
--   synchronous exception;
-- * @Left PingMismatch@ when the echoed bytes differ from the payload;
-- * @Right (PingResult rtt)@ otherwise.
--
-- The RTT is taken from the monotonic clock. It spans from just before the
-- write until the echo has been fully read, so wall-clock adjustments cannot
-- distort or negate it.
--
-- Asynchronous exceptions are re-thrown rather than reported as
-- 'PingStreamError'. Exceptions from opening the stream, from negotiation, or
-- from the write are not caught and propagate to the caller.
--
-- NOTE(review): the opened stream is never closed here; confirm whether the
-- muxer / 'StreamIO' expects the initiator to close or reset it.
sendPing :: Connection -> IO (Either PingError PingResult)
sendPing conn = do
  stream <- muxOpenStream (connSession conn)
  negotiated <- negotiateInitiator stream [pingProtocolId]
  case negotiated of
    NoProtocol ->
      pure (Left (PingStreamError "remote does not support ping"))
    Accepted _ -> do
      payload <- getRandomBytes pingSize :: IO ByteString
      t0 <- getMonotonicTime
      streamWrite stream payload
      response <- (Right <$> readExact stream pingSize) `catch` onReadError
      -- Stop the clock as soon as the echo is in, before comparing bytes.
      t1 <- getMonotonicTime
      pure $ case response of
        Left err -> Left err
        Right echo
          | echo /= payload -> Left PingMismatch
          | otherwise       -> Right (PingResult (realToFrac (t1 - t0)))
  where
    -- Synchronous read failures become a PingStreamError; asynchronous
    -- exceptions (cancellation, timeouts) are re-thrown untouched.
    onReadError :: SomeException -> IO (Either PingError ByteString)
    onReadError e
      | isAsyncException e = throwIO e
      | otherwise          = pure (Left (PingStreamError "read failed"))

    isAsyncException :: SomeException -> Bool
    isAsyncException e =
      case (fromException e :: Maybe SomeAsyncException) of
        Just _  -> True
        Nothing -> False

-- | Register 'handlePing' on the Switch under 'pingProtocolId'.
--
-- Updates the switch's protocol table ('swProtocols') in a single STM
-- transaction. Any handler already registered for that ID is replaced
-- ('Map.insert' semantics). The new map is forced before it is stored, so the
-- 'TVar' never holds an unevaluated insert thunk.
registerPingHandler :: Switch -> IO ()
registerPingHandler sw = atomically $ do
  let protocolsVar = swProtocols sw
  protos <- readTVar protocolsVar
  writeTVar protocolsVar $! Map.insert pingProtocolId handlePing protos

-- | Read exactly @n@ bytes from a 'StreamIO', one byte at a time.
--
-- Calls 'streamReadByte' @n@ times and packs the bytes in the order they
-- were read. Any exception thrown by 'streamReadByte' propagates. Presumably
-- it throws on end of stream, which is what callers rely on to detect closure
-- — TODO confirm against 'StreamIO'. For @n <= 0@ nothing is read and the
-- result is empty.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n = BS.pack <$> replicateM n (streamReadByte stream)