-- | DCUtR (Direct Connection Upgrade through Relay) protocol.
--
-- Protocol: /libp2p/dcutr
-- Coordinates hole punching over a relayed connection using a 3-message exchange
-- with RTT-based timing synchronization.
--
-- Message flow:
--   B (initiator) sends CONNECT with B's observed addresses
--   A (handler) sends CONNECT with A's observed addresses
--   B sends SYNC
--   B waits RTT/2, then dials A's addresses
--   A receives SYNC, then dials B's addresses immediately
--   Both peers attempt direct connections at approximately the same time
module Network.LibP2P.NAT.DCUtR.DCUtR
  ( -- * Types
    DCUtRConfig (..)
  , DCUtRResult (..)
    -- * Protocol operations
  , initiateDCUtR
  , handleDCUtR
    -- * Variants for testing
  , initiateDCUtRWithRTT
  , initiateDCUtRCapture
  , handleDCUtRCapture
  ) where

import qualified Data.ByteString as BS
import Data.IORef (IORef, newIORef, writeIORef)
import Control.Concurrent (threadDelay)
import Data.Time.Clock (getCurrentTime, diffUTCTime, NominalDiffTime)
import Network.LibP2P.NAT.DCUtR.Message
import Network.LibP2P.MultistreamSelect.Negotiation (StreamIO (..))
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr, toBytes, fromBytes)

-- | DCUtR configuration shared by the initiator and handler entry points.
data DCUtRConfig = DCUtRConfig
  { dcMaxRetries :: !Int
    -- ^ Retry budget for the hole-punch attempt. The original note read
    -- "spec says 3 total = 1 initial + 2 retries", so it is unclear whether
    -- the value counts total attempts or retries after the first one.
    -- NOTE(review): none of the functions in this module read this field;
    -- presumably a caller enforces it by re-running the whole exchange —
    -- confirm against callers and pin down the counting convention.
  , dcDialer     :: !(Multiaddr -> IO (Either String ()))
    -- ^ Dials a single remote address, returning @Right ()@ on success and
    -- @Left reason@ on failure. Injectable so tests can stub out networking.
  }

-- | Outcome of one DCUtR exchange, as reported by either peer.
data DCUtRResult
  = DCUtRSuccess
    -- ^ A direct dial to one of the peer's advertised addresses succeeded.
  | DCUtRFailed String
    -- ^ The message exchange or every dial failed; the string says why.
  deriving (Eq, Show)

-- | Peer B (initiator): run the DCUtR exchange over a relayed stream.
--
-- Flow:
--   1. Send CONNECT with own observed addresses
--   2. Read A's CONNECT (measure RTT)
--   3. Send SYNC
--   4. Wait RTT/2, then dial A's addresses
--
-- Equivalent to 'initiateDCUtRWithRTT' with a fresh RTT cell whose contents
-- are discarded.
initiateDCUtR :: DCUtRConfig -> StreamIO -> [Multiaddr] -> IO DCUtRResult
initiateDCUtR config stream addrs =
  newIORef Nothing >>= initiateDCUtRWithRTT config stream addrs

-- | Initiator variant that records the measured round-trip time (for tests).
--
-- Runs the same exchange as 'initiateDCUtR'. When A's CONNECT arrives, the
-- elapsed time since B started sending its own CONNECT is stored in the
-- supplied 'IORef' as @Just rtt@. The ref is left untouched if the exchange
-- fails before that point.
--
-- The timer starts /before/ CONNECT is written, so the measurement covers the
-- whole send-then-receive round trip rather than only the wait for the reply.
initiateDCUtRWithRTT
  :: DCUtRConfig
  -> StreamIO
  -> [Multiaddr]
  -> IORef (Maybe NominalDiffTime)
  -> IO DCUtRResult
initiateDCUtRWithRTT config stream addrs rttRef = do
  let addrBytes = map toBytes addrs
  -- Step 1: Send CONNECT with our observed addresses. Start the RTT timer
  -- first so time spent writing is part of the measured round trip.
  let connectOut = HolePunchMessage { hpType = HPConnect, hpObsAddrs = addrBytes }
  t0 <- getCurrentTime
  writeHolePunchMessage stream connectOut
  -- Step 2: Read A's CONNECT response (this completes the RTT measurement)
  result <- readHolePunchMessage stream maxDCUtRMessageSize
  case result of
    Left err -> pure (DCUtRFailed $ "failed to read CONNECT: " ++ err)
    Right msg
      | hpType msg /= HPConnect -> pure (DCUtRFailed "expected CONNECT message")
      | otherwise -> do
          t1 <- getCurrentTime
          let rtt = diffUTCTime t1 t0
          writeIORef rttRef (Just rtt)
          -- Step 3: Send SYNC
          let syncOut = HolePunchMessage { hpType = HPSync, hpObsAddrs = [] }
          writeHolePunchMessage stream syncOut
          -- Step 4: Wait RTT/2, then dial A's addresses. The handler dials
          -- as soon as it reads SYNC, so delaying by half the round trip
          -- lines both dials up in time.
          -- NOTE(review): getCurrentTime is wall-clock, not monotonic; a clock
          -- step between t0 and t1 skews rtt. 'max 0' only guards the
          -- negative case. GHC.Clock.getMonotonicTime would avoid this.
          let halfRTTMicros = max 0 (round (rtt * 1000000 / 2) :: Int)
          threadDelay halfRTTMicros
          let remoteAddrs = parseAddrs (hpObsAddrs msg)
          dialResult <- tryDialAddrs config remoteAddrs
          case dialResult of
            Right () -> pure DCUtRSuccess
            Left err -> pure (DCUtRFailed $ "dial failed: " ++ err)

-- | Initiator variant that records the raw address bytes carried in A's
-- CONNECT (for tests).
--
-- Sends CONNECT, reads A's CONNECT, stores its 'hpObsAddrs' in the supplied
-- 'IORef', sends SYNC and then dials the decoded addresses. The ref is
-- written only when a well-formed CONNECT is received.
--
-- NOTE(review): unlike 'initiateDCUtRWithRTT', this variant neither measures
-- RTT nor waits RTT/2 before dialling — confirm that is intended for tests.
initiateDCUtRCapture :: DCUtRConfig -> StreamIO -> [Multiaddr] -> IORef [BS.ByteString] -> IO DCUtRResult
initiateDCUtRCapture config stream addrs receivedRef = do
  writeHolePunchMessage stream
    (HolePunchMessage { hpType = HPConnect, hpObsAddrs = map toBytes addrs })
  reply <- readHolePunchMessage stream maxDCUtRMessageSize
  case reply of
    Left readErr ->
      pure . DCUtRFailed $ "failed to read CONNECT: " ++ readErr
    Right peerConnect
      | hpType peerConnect /= HPConnect ->
          pure (DCUtRFailed "expected CONNECT message")
      | otherwise -> do
          let peerAddrBytes = hpObsAddrs peerConnect
          writeIORef receivedRef peerAddrBytes
          writeHolePunchMessage stream
            (HolePunchMessage { hpType = HPSync, hpObsAddrs = [] })
          outcome <- tryDialAddrs config (parseAddrs peerAddrBytes)
          pure $ either (\dialErr -> DCUtRFailed ("dial failed: " ++ dialErr))
                        (const DCUtRSuccess)
                        outcome

-- | Peer A (handler): handle the DCUtR exchange over a relayed stream.
--
-- Flow:
--   1. Read B's CONNECT
--   2. Send CONNECT with own observed addresses
--   3. Read SYNC
--   4. Dial B's addresses immediately
--
-- The protocol logic lives in 'handleDCUtRCapture'. This entry point supplies
-- a scratch ref for the captured addresses and then discards it, the same way
-- 'initiateDCUtR' delegates to 'initiateDCUtRWithRTT'. With a single
-- implementation, the production handler and the test variant cannot drift
-- apart.
handleDCUtR :: DCUtRConfig -> StreamIO -> [Multiaddr] -> IO DCUtRResult
handleDCUtR config stream addrs = do
  scratch <- newIORef []
  handleDCUtRCapture config stream addrs scratch

-- | Handler variant that records the raw address bytes carried in B's
-- CONNECT (for tests).
--
-- Reads B's CONNECT, stores its 'hpObsAddrs' in the supplied 'IORef', replies
-- with our own CONNECT, waits for SYNC and then dials B's decoded addresses
-- without further delay. The ref is written as soon as a well-formed CONNECT
-- arrives, even if the SYNC step later fails.
handleDCUtRCapture :: DCUtRConfig -> StreamIO -> [Multiaddr] -> IORef [BS.ByteString] -> IO DCUtRResult
handleDCUtRCapture config stream addrs receivedRef =
  readHolePunchMessage stream maxDCUtRMessageSize >>= onConnect
  where
    ourConnect = HolePunchMessage { hpType = HPConnect, hpObsAddrs = map toBytes addrs }

    -- Steps 1-2: validate B's CONNECT, remember its addresses, answer with ours.
    onConnect (Left readErr) =
      pure (DCUtRFailed ("failed to read CONNECT: " ++ readErr))
    onConnect (Right peerConnect)
      | hpType peerConnect /= HPConnect =
          pure (DCUtRFailed "expected CONNECT message")
      | otherwise = do
          let peerAddrBytes = hpObsAddrs peerConnect
          writeIORef receivedRef peerAddrBytes
          writeHolePunchMessage stream ourConnect
          syncReply <- readHolePunchMessage stream maxDCUtRMessageSize
          onSync peerAddrBytes syncReply

    -- Steps 3-4: once SYNC arrives, dial B's addresses immediately.
    onSync _ (Left readErr) =
      pure (DCUtRFailed ("failed to read SYNC: " ++ readErr))
    onSync peerAddrBytes (Right syncMsg)
      | hpType syncMsg /= HPSync =
          pure (DCUtRFailed "expected SYNC message")
      | otherwise = do
          outcome <- tryDialAddrs config (parseAddrs peerAddrBytes)
          pure $ case outcome of
            Right () -> DCUtRSuccess
            Left dialErr -> DCUtRFailed ("dial failed: " ++ dialErr)

-- Helpers

-- | Decode binary multiaddr encodings, keeping only the ones 'fromBytes'
-- accepts.
--
-- Entries that fail to decode are dropped silently. The survivors keep their
-- original relative order.
parseAddrs :: [BS.ByteString] -> [Multiaddr]
parseAddrs raw = [addr | Right addr <- map fromBytes raw]

-- | Dial the given addresses one at a time, in list order, stopping at the
-- first success.
--
-- Returns @Right ()@ as soon as 'dcDialer' succeeds. An empty list yields
-- @Left "no addresses to dial"@. If every dial fails, the result is
-- @Left "all dial attempts failed"@ and the individual dial errors are not
-- reported.
tryDialAddrs :: DCUtRConfig -> [Multiaddr] -> IO (Either String ())
tryDialAddrs config candidates
  | null candidates = pure (Left "no addresses to dial")
  | otherwise = go candidates
  where
    go [] = pure (Left "all dial attempts failed")
    go (addr : remaining) =
      dcDialer config addr >>= either (const (go remaining)) (pure . Right)