-- | Circuit Relay v2 server: manage reservations and bridge streams.
--
-- Protocols:
--   /libp2p/circuit/relay/0.2.0/hop (client ↔ relay)
--
-- Provides:
--   - Reservation management (with expiration and limits)
--   - Stream bridging between source and target
--   - Resource limits (max reservations, max circuits, data/duration limits)
module Network.LibP2P.NAT.Relay.Relay
  ( -- * Types
    RelayConfig (..)
  , RelayState (..)
  , ActiveReservation (..)
    -- * Configuration
  , defaultRelayConfig
    -- * State management
  , newRelayState
    -- * Handlers
  , handleReserve
  , handleConnect
    -- * Stream bridging
  , bridgeStreams
    -- * Relay address helpers
  , buildRelayAddrBytes
  , isRelayedConnection
  ) where

import Control.Concurrent (threadDelay)
import Control.Exception (bracket)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef')
import Data.Word (Word32, Word64)

import Control.Concurrent.Async (race)
import Control.Concurrent.STM
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map

import Network.LibP2P.Core.Varint (encodeUvarint)
import Network.LibP2P.Crypto.PeerId (PeerId (..))
import Network.LibP2P.MultistreamSelect.Negotiation (StreamIO (..))
import Network.LibP2P.NAT.Relay.Message

-- | Tunable limits for the relay server.
--
-- Durations are expressed in seconds and data limits in bytes.
data RelayConfig = RelayConfig
  { rcMaxReservations      :: !Int
    -- ^ Upper bound on simultaneously held reservations.
  , rcMaxCircuits          :: !Int
    -- ^ Upper bound on simultaneously relayed circuits.
  , rcReservationDuration  :: !Word64
    -- ^ Lifetime granted to a reservation, in seconds.
  , rcDefaultDataLimit     :: !Word64
    -- ^ Default data limit for each relayed circuit, in bytes.
  , rcDefaultDurationLimit :: !Word32
    -- ^ Default duration limit for each relayed circuit, in seconds.
  }
  deriving (Eq, Show)

-- | Out-of-the-box relay settings: up to 128 reservations and 16 circuits,
-- one-hour reservations, and a 128 KiB / two-minute default budget for
-- each circuit.
defaultRelayConfig :: RelayConfig
defaultRelayConfig =
  RelayConfig
    { rcMaxReservations      = 128
    , rcMaxCircuits          = 16
    , rcReservationDuration  = 60 * 60     -- one hour
    , rcDefaultDataLimit     = 128 * 1024  -- 128 KiB
    , rcDefaultDurationLimit = 2 * 60      -- two minutes
    }

-- | A reservation currently held by a peer on this relay.
data ActiveReservation = ActiveReservation
  { arPeerId     :: !PeerId
    -- ^ Peer that owns the reservation.
  , arExpiration :: !Word64
    -- ^ Expiration time, intended as a Unix timestamp in seconds.
    -- NOTE(review): confirm how producers populate this value.
  }
  deriving (Eq, Show)

-- | Shared, mutable state of a running relay server.
--
-- The mutable parts are held in 'TVar's so handlers can read and update
-- them through STM.
data RelayState = RelayState
  { rsConfig       :: !RelayConfig
    -- ^ Configuration the state was created with.
  , rsReservations :: !(TVar (Map.Map PeerId ActiveReservation))
    -- ^ Active reservations keyed by the reserving peer.
  , rsCircuitCount :: !(TVar Int)
    -- ^ Circuit counter; 'newRelayState' initialises it to 0.
  }

-- | Build a fresh 'RelayState' for the given configuration, starting with
-- an empty reservation table and a circuit count of zero.
newRelayState :: RelayConfig -> IO RelayState
newRelayState config = do
  reservationsVar <- newTVarIO Map.empty
  circuitsVar     <- newTVarIO 0
  pure RelayState
    { rsConfig       = config
    , rsReservations = reservationsVar
    , rsCircuitCount = circuitsVar
    }

-- | Handle a RESERVE request from @peerId@ arriving on the hop stream
-- @stream@.
--
-- The capacity check and the insertion happen in a single STM transaction.
-- Concurrent requests therefore cannot push the table past
-- 'rcMaxReservations'. A peer that already holds a reservation may always
-- refresh it, even when the table is full, because a refresh does not grow
-- the table.
--
-- On success the reservation is stored. A STATUS=OK 'HopMessage' is then
-- written back, carrying the reservation and the default circuit limits.
-- Otherwise a STATUS=RESOURCE_LIMIT_EXCEEDED reply is sent.
handleReserve :: RelayState -> StreamIO -> PeerId -> IO ()
handleReserve state stream peerId = do
  accepted <- atomically $ do
    reservations <- readTVar reservationsVar
    let isRefresh = Map.member peerId reservations
    if not isRefresh && Map.size reservations >= rcMaxReservations config
      then pure False
      else do
        writeTVar reservationsVar $! Map.insert peerId reservation reservations
        pure True
  if accepted
    then writeHopMessage stream okResponse
    else sendHopStatus stream ResourceLimitExceeded
  where
    config = rsConfig state
    reservationsVar = rsReservations state

    -- NOTE(review): the Circuit Relay v2 spec defines 'expire' as an
    -- absolute UTC Unix time (seconds). This value is the configured
    -- duration, stored and sent as-is. Converting it requires a wall-clock
    -- source (e.g. Data.Time.Clock.POSIX), which this module does not
    -- currently import.
    expiration = rcReservationDuration config

    reservation = ActiveReservation
      { arPeerId     = peerId
      , arExpiration = expiration
      }

    okResponse = HopMessage
      { hopType        = Just HopStatus
      , hopPeer        = Nothing
      , hopReservation = Just Reservation
          { rsvExpire  = Just expiration
          , rsvAddrs   = []       -- relay would populate with own addresses
          , rsvVoucher = Nothing  -- voucher signing handled separately
          }
      , hopLimit       = Just RelayLimit
          { rlDuration = Just (rcDefaultDurationLimit config)
          , rlData     = Just (rcDefaultDataLimit config)
          }
      , hopStatus      = Just RelayOK
      }

-- | Handle a CONNECT request arriving on the hop stream @stream@ from
-- @sourcePeerId@. The @openStopStream@ callback opens a stop stream to the
-- target peer.
--
-- Steps:
--
--   1. The request must name a target peer, else MALFORMED_MESSAGE.
--   2. The target must hold a reservation, else NO_RESERVATION.
--   3. A circuit slot must be free (bounded by 'rcMaxCircuits'), else
--      RESOURCE_LIMIT_EXCEEDED. The slot is held for the rest of the
--      handler and released when it returns or throws.
--   4. If @openStopStream@ returns 'Nothing', the reply is
--      CONNECTION_FAILED.
--   5. A STOP CONNECT is sent to the target. Unless it answers with
--      STATUS=OK, the source gets CONNECTION_FAILED.
--   6. The source gets STATUS=OK, and both streams are bridged with the
--      default circuit limits.
handleConnect
  :: RelayState
  -> StreamIO
  -> PeerId
  -> HopMessage
  -> (PeerId -> IO (Maybe StreamIO))
  -> IO ()
handleConnect state stream sourcePeerId msg openStopStream =
  case hopPeer msg of
    Nothing -> sendHopStatus stream MalformedMessage
    Just peer -> do
      let targetId = PeerId (rpId peer)
      reservations <- readTVarIO (rsReservations state)
      case Map.lookup targetId reservations of
        Nothing -> sendHopStatus stream NoReservation
        Just _  -> withCircuitSlot (relayTo targetId)
  where
    config = rsConfig state
    circuitsVar = rsCircuitCount state

    -- Default limits applied to every relayed circuit.
    circuitLimit :: RelayLimit
    circuitLimit = RelayLimit
      { rlDuration = Just (rcDefaultDurationLimit config)
      , rlData     = Just (rcDefaultDataLimit config)
      }

    -- Run the action while holding a circuit slot, or reply
    -- RESOURCE_LIMIT_EXCEEDED if none is free. 'bracket' ensures the slot
    -- is returned even if the action (e.g. bridging) throws.
    withCircuitSlot :: IO () -> IO ()
    withCircuitSlot action =
      bracket acquireSlot releaseSlot $ \acquired ->
        if acquired
          then action
          else sendHopStatus stream ResourceLimitExceeded

    -- Atomically reserve one circuit slot if below the configured maximum.
    acquireSlot :: IO Bool
    acquireSlot = atomically $ do
      active <- readTVar circuitsVar
      if active >= rcMaxCircuits config
        then pure False
        else do
          writeTVar circuitsVar $! active + 1
          pure True

    releaseSlot :: Bool -> IO ()
    releaseSlot acquired
      | acquired  = atomically (modifyTVar' circuitsVar (subtract 1))
      | otherwise = pure ()

    -- Open the stop stream, perform the STOP handshake, then bridge.
    relayTo :: PeerId -> IO ()
    relayTo targetId = do
      mStopStream <- openStopStream targetId
      case mStopStream of
        Nothing -> sendHopStatus stream ConnectionFailed
        Just stopStream -> do
          writeStopMessage stopStream connectMsg
          -- Wait for the target's STATUS response.
          targetResp <- readStopMessage stopStream maxRelayMessageSize
          case targetResp of
            Right resp | stopStatus resp == Just RelayOK -> do
              writeHopMessage stream okResp
              bridgeStreams (Just circuitLimit) stream stopStream
            _ -> sendHopStatus stream ConnectionFailed

    -- STOP CONNECT telling the target who is dialing through the relay.
    connectMsg :: StopMessage
    connectMsg = StopMessage
      { stopType   = Just StopConnect
      , stopPeer   = Just RelayPeer
          { rpId    = let PeerId bs = sourcePeerId in bs
          , rpAddrs = []
          }
      , stopLimit  = Just circuitLimit
      , stopStatus = Nothing
      }

    -- STATUS=OK reply to the source once the target has accepted.
    okResp :: HopMessage
    okResp = HopMessage
      { hopType        = Just HopStatus
      , hopPeer        = Nothing
      , hopReservation = Nothing
      , hopLimit       = Just circuitLimit
      , hopStatus      = Just RelayOK
      }

-- | Reply on a hop stream with a bare STATUS message. Only the message type
-- and the given status are set; peer, reservation and limit are omitted.
sendHopStatus :: StreamIO -> RelayStatus -> IO ()
sendHopStatus stream status =
  writeHopMessage stream $
    HopMessage
      { hopStatus      = Just status
      , hopType        = Just HopStatus
      , hopLimit       = Nothing
      , hopReservation = Nothing
      , hopPeer        = Nothing
      }

-- | Bridge two streams bidirectionally, enforcing an optional 'RelayLimit'.
--
-- * @rlData@ caps the bytes forwarded in /each/ direction. Each direction
--   keeps its own counter. Values beyond the 'Int' range saturate instead
--   of wrapping.
-- * @rlDuration@ caps the lifetime of the bridge in seconds. When it
--   elapses, both forwarding loops are cancelled.
--
-- A missing limit, or a limit of 0, means "no limit", as in the Circuit
-- Relay v2 spec. Returns as soon as either direction finishes, a data limit
-- is reached, or the duration expires. Exceptions raised by a forwarding
-- loop (e.g. from the underlying stream read) propagate via 'race'.
bridgeStreams :: Maybe RelayLimit -> StreamIO -> StreamIO -> IO ()
bridgeStreams mLimit streamA streamB = do
  -- Track bytes transferred in each direction.
  countAtoB <- newIORef (0 :: Int)
  countBtoA <- newIORef (0 :: Int)
  let pump =
        raceUnit
          (forwardWithLimit streamA streamB countAtoB dataLimit)
          (forwardWithLimit streamB streamA countBtoA dataLimit)
  case durationMicros of
    Nothing     -> pump
    Just micros -> raceUnit pump (threadDelay micros)
  where
    -- Run both actions concurrently; return when the first one finishes
    -- and cancel the other.
    raceUnit :: IO () -> IO () -> IO ()
    raceUnit a b = either id id <$> race a b

    -- Per-direction byte budget, saturated to the 'Int' range.
    dataLimit :: Int
    dataLimit = case mLimit >>= rlData of
      Just n | n > 0 -> fromInteger (min (toInteger n) maxInt)
      _              -> maxBound

    -- Bridge lifetime in microseconds (the unit 'threadDelay' expects),
    -- saturated to the 'Int' range.
    durationMicros :: Maybe Int
    durationMicros = case mLimit >>= rlDuration of
      Just secs | secs > 0 ->
        Just (fromInteger (min (toInteger secs * 1000000) maxInt))
      _ -> Nothing

    maxInt :: Integer
    maxInt = toInteger (maxBound :: Int)

-- | Copy bytes from @src@ to @dst@, one at a time, until @countRef@ shows
-- that @limit@ bytes have been forwarded.
--
-- The limit is checked /before/ reading. Once the budget is spent, the loop
-- returns immediately instead of blocking on one more byte from @src@ and
-- then discarding it. The counter is incremented after each successful
-- write.
forwardWithLimit :: StreamIO -> StreamIO -> IORef Int -> Int -> IO ()
forwardWithLimit src dst countRef limit = loop
  where
    loop :: IO ()
    loop = do
      sent <- readIORef countRef
      if sent >= limit
        then pure ()  -- limit reached, stop forwarding
        else do
          byte <- streamReadByte src
          streamWrite dst (BS.singleton byte)
          modifyIORef' countRef (+ 1)
          loop

-- | Encode a circuit-relay multiaddr in binary form:
--
-- > <relayAddr>/p2p/<relayId>/p2p-circuit/p2p/<targetId>
--
-- @relayAddr@ is appended verbatim. The two peer IDs are raw bytes, each
-- wrapped in a length-prefixed @/p2p@ component.
buildRelayAddrBytes :: ByteString -> ByteString -> ByteString -> ByteString
buildRelayAddrBytes relayAddr relayIdBytes targetIdBytes =
  mconcat
    [ relayAddr
    , p2pComponent relayIdBytes
    , circuitComponent
    , p2pComponent targetIdBytes
    ]
  where
    -- /p2p component: varint code 421 (bytes 0xa5 0x03), varint length,
    -- then the peer ID bytes.
    p2pComponent :: ByteString -> ByteString
    p2pComponent peerIdBytes =
      mconcat
        [ encodeUvarint 421
        , encodeUvarint (fromIntegral (BS.length peerIdBytes))
        , peerIdBytes
        ]

    -- /p2p-circuit component: varint code 290 (bytes 0xa2 0x02), no payload.
    circuitComponent :: ByteString
    circuitComponent = encodeUvarint 290

-- | Heuristically decide whether binary multiaddr bytes contain a
-- @/p2p-circuit@ component (protocol code 290).
--
-- This only searches for the varint encoding of 290 anywhere in the input.
-- It does not parse the multiaddr, so the same two bytes appearing inside
-- another component's payload (e.g. a peer ID) yield a false positive.
isRelayedConnection :: ByteString -> Bool
isRelayedConnection = BS.isInfixOf (encodeUvarint 290)