-- | Listen loop for the Switch (docs/08-switch.md §Listening).
--
-- Accepts inbound connections, applies connection gating policy,
-- upgrades to secure multiplexed connections, and dispatches
-- inbound streams to registered protocol handlers.
module Network.LibP2P.Switch.Listen
  ( -- * Connection gating
    ConnectionGater (..)
  , defaultConnectionGater
    -- * Inbound connection handling
  , handleInbound
    -- * Stream dispatch
  , streamAcceptLoop
  , dispatchStream
    -- * Listen orchestration
  , switchListen
  , acceptLoop
  , switchListenAddrs
  ) where

import Control.Concurrent.Async (async)
import Control.Concurrent.STM (atomically, readTVar, writeTVar)
import Control.Exception (SomeException, catch)
import Data.List (find)
import qualified Data.Map.Strict as Map
import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr)
import Network.LibP2P.MultistreamSelect.Negotiation
  ( NegotiationResult (..)
  , StreamIO
  , negotiateResponder
  )
import Network.LibP2P.Switch.ConnPool (addConn)
import Network.LibP2P.Switch.ResourceManager (Direction (..), reserveConnection)
import Network.LibP2P.Switch.Types
  ( ActiveListener (..)
  , Connection (..)
  , MuxerSession (..)
  , Switch (..)
  )
import Network.LibP2P.Switch.Upgrade (upgradeInbound)
import Network.LibP2P.Transport.Transport (Listener (..), RawConnection (..), Transport (..))

-- | Connection gater: policy-based admission control
-- (docs/08-switch.md §Connection Gating).
--
-- Called at multiple points during connection establishment to allow
-- or deny based on policy (IP blocklist, Peer ID allowlist, etc.).
-- Each callback answers 'True' to admit the connection and 'False' to
-- reject it; a rejected connection is closed by 'handleInbound'.
data ConnectionGater = ConnectionGater
  { gateAccept  :: !(Multiaddr -> IO Bool)
    -- ^ Check on the remote address, run right after the raw connection
    -- is accepted and before any upgrade work.
  , gateSecured :: !(PeerId -> IO Bool)
    -- ^ Check on the remote 'PeerId', run after the security handshake
    -- (i.e. once the remote identity is known).
  }

-- | Default gater that allows all connections.
--
-- Both 'gateAccept' and 'gateSecured' ignore their argument and
-- answer 'True'.
defaultConnectionGater :: ConnectionGater
defaultConnectionGater = ConnectionGater
  { gateAccept  = \_ -> pure True
  , gateSecured = \_ -> pure True
  }

-- | Handle a single inbound connection: gate → upgrade → pool → stream accept loop.
--
-- Steps, in order:
--
-- 1. 'gateAccept' on the remote address; on rejection the raw connection
--    is closed and nothing else happens.
-- 2. 'upgradeInbound' with the switch identity key.
-- 3. 'gateSecured' on the remote 'PeerId'; on rejection the muxer session
--    is closed.
-- 4. 'reserveConnection' for an 'Inbound' connection; on 'Left' the muxer
--    session is closed.
-- 5. The connection is added to the pool, every registered notifier is
--    started in its own async thread, and 'streamAcceptLoop' runs.
--
-- This function blocks until 'streamAcceptLoop' returns. Each accepted
-- connection should be spawned in its own async thread from the accept loop.
handleInbound :: Switch -> ConnectionGater -> RawConnection -> IO ()
handleInbound sw gater rawConn = do
  -- Gate 1: check remote address before any upgrade work
  allowed <- gateAccept gater (rcRemoteAddr rawConn)
  if not allowed
    then rcClose rawConn
    else do
      -- Upgrade: Noise XX handshake + Yamux session
      -- NOTE(review): if 'upgradeInbound' throws, this function does not
      -- close 'rawConn' itself — confirm the upgrader cleans up on failure.
      conn <- upgradeInbound (swIdentityKey sw) rawConn
      -- Gate 2: check remote PeerId after security handshake
      secured <- gateSecured gater (connPeerId conn)
      if not secured
        then muxClose (connSession conn)
        else do
          -- Gate 3: check resource limits (PeerId known after handshake)
          resCheck <-
            atomically $
              reserveConnection (swResourceMgr sw) (connPeerId conn) Inbound
          case resCheck of
            Left _ -> muxClose (connSession conn)
            Right () -> do
              -- Add to connection pool
              atomically $ addConn (swConnPool sw) conn
              -- Notify connection listeners (e.g. GossipSub auto-stream open).
              -- Each notifier runs in its own thread so a slow or failing
              -- notifier does not delay the stream accept loop.
              notifiers <- atomically $ readTVar (swNotifiers sw)
              mapM_ (\f -> async $ f conn) notifiers
              -- Block on stream accept loop until connection closes.
              -- NOTE(review): nothing here removes the connection from the
              -- pool or releases the reservation once the loop ends —
              -- confirm that cleanup happens elsewhere.
              streamAcceptLoop sw conn

-- | Accept inbound streams and dispatch to registered protocol handlers.
--
-- Repeatedly accepts a stream from the connection's muxer session and
-- spawns an async thread running 'dispatchStream' for it, then goes back
-- to accepting. The loop returns as soon as 'muxAcceptStream' throws any
-- exception (e.g. session shutdown); the exception itself is discarded.
streamAcceptLoop :: Switch -> Connection -> IO ()
streamAcceptLoop sw conn = loop
  where
    loop :: IO ()
    loop = do
      -- Accept the next inbound stream from the muxer
      result <- safeAccept
      case result of
        Nothing -> pure ()  -- Muxer closed or error, stop the loop
        Just stream -> do
          -- Spawn handler for this stream, continue accepting
          _ <- async $ dispatchStream sw conn stream
          loop

    -- Turn any exception from muxAcceptStream into 'Nothing' so the loop
    -- can terminate cleanly instead of propagating the failure.
    safeAccept :: IO (Maybe StreamIO)
    safeAccept =
      (Just <$> muxAcceptStream (connSession conn)) `catch` acceptFailed

    acceptFailed :: SomeException -> IO (Maybe StreamIO)
    acceptFailed _ = pure Nothing

-- | Dispatch a single inbound stream to the appropriate protocol handler.
--
-- Runs multistream-select as responder, offering the keys of the switch's
-- protocol map, then looks up the handler for the negotiated protocol and
-- invokes it with the stream and the remote 'PeerId'.
--
-- Does nothing further when negotiation yields 'NoProtocol', or when no
-- handler is found for the accepted protocol.
dispatchStream :: Switch -> Connection -> StreamIO -> IO ()
dispatchStream sw conn stream = do
  -- Get the list of supported protocols from the Switch
  supportedProtos <- atomically $ Map.keys <$> readProtos
  -- Run multistream-select responder
  result <- negotiateResponder stream supportedProtos
  case result of
    Accepted proto -> do
      -- Look up the handler for the negotiated protocol
      mHandler <- lookupHandler proto
      case mHandler of
        Just handler -> handler stream (connPeerId conn)
        -- The protocol was in the offered list, but the map is re-read in
        -- a separate transaction, so the entry may be gone by now.
        Nothing -> pure ()
    -- No common protocol: nothing is done with the stream here.
    NoProtocol -> pure ()
  where
    -- Snapshot of the protocol-id → handler map.
    readProtos = readTVar (swProtocols sw)

    -- Fresh lookup of a single handler in its own transaction.
    lookupHandler proto = atomically $ Map.lookup proto <$> readProtos

-- | Start listening on the given addresses.
--
-- For each address, selects the first registered transport whose
-- 'transportCanDial' accepts it, binds a listener via 'transportListen',
-- and spawns an 'acceptLoop' thread for that listener. The new listeners
-- are appended to the switch's listener list.
--
-- Returns the addresses reported by the new listeners ('listenerAddr'),
-- i.e. the actual bound addresses (port 0 resolved to actual port).
--
-- Fails (via 'fail' in 'IO') if the switch is already closed or if no
-- transport matches one of the addresses.
switchListen :: Switch -> ConnectionGater -> [Multiaddr] -> IO [Multiaddr]
switchListen sw gater addrs = do
  closed <- atomically $ readTVar (swClosed sw)
  if closed
    then fail "switchListen: switch is closed"
    else do
      transports <- atomically $ readTVar (swTransports sw)
      -- NOTE(review): if binding a later address fails, listeners already
      -- bound by this call are running but never recorded in swListeners —
      -- confirm whether partial failure needs cleanup.
      activeListeners <- mapM (bindAndListen transports gater sw) addrs
      let newListeners = concat activeListeners
      atomically $ do
        existing <- readTVar (swListeners sw)
        writeTVar (swListeners sw) (existing ++ newListeners)
      pure (map alAddress newListeners)
  where
    -- Find a transport for the address, bind, and spawn accept loop.
    -- NOTE(review): selection uses 'transportCanDial' as the predicate —
    -- confirm dialability is the intended test for listen support.
    bindAndListen transports gater' sw' addr =
      case find (\t -> transportCanDial t addr) transports of
        Nothing -> fail $ "switchListen: no transport for " ++ show addr
        Just transport -> do
          listener <- transportListen transport addr
          loopThread <- async $ acceptLoop sw' gater' listener
          pure
            [ ActiveListener
                { alListener   = listener
                , alAcceptLoop = loopThread
                , alAddress    = listenerAddr listener
                }
            ]

-- | Accept loop: accepts connections and spawns 'handleInbound' threads.
--
-- Each accepted raw connection is handled in its own async thread. Any
-- exception raised by 'handleInbound' is caught inside that thread and
-- discarded, so a failing connection never stops the loop.
--
-- The loop returns when 'listenerAccept' throws any exception
-- (e.g. the listener was closed).
acceptLoop :: Switch -> ConnectionGater -> Listener -> IO ()
acceptLoop sw gater listener = loop
  where
    loop :: IO ()
    loop = do
      accepted <- (Just <$> listenerAccept listener) `catch` acceptFailed
      case accepted of
        Nothing -> pure ()  -- Listener closed, stop
        Just rawConn -> do
          _ <- async (guardedInbound rawConn)
          loop

    -- Any failure of accept is treated as "listener is gone".
    acceptFailed :: SomeException -> IO (Maybe RawConnection)
    acceptFailed _ = pure Nothing

    -- Run the per-connection handler; failures are deliberately ignored
    -- (best-effort) so they stay confined to the connection's own thread.
    guardedInbound :: RawConnection -> IO ()
    guardedInbound rawConn =
      handleInbound sw gater rawConn `catch` ignoreFailure

    ignoreFailure :: SomeException -> IO ()
    ignoreFailure _ = pure ()

-- | Get the current listen addresses from all active listeners.
--
-- Reads the switch's listener list in a single transaction and returns
-- the 'alAddress' of each entry, in list order.
switchListenAddrs :: Switch -> IO [Multiaddr]
switchListenAddrs sw =
  atomically $ map alAddress <$> readTVar (swListeners sw)