-- | Switch core operations.
--
-- The Switch is the central coordinator of the libp2p stack.
-- This module provides construction, transport management,
-- protocol handler registration, and shutdown.
module Network.LibP2P.Switch.Switch
  ( newSwitch
  , addTransport
  , selectTransport
  , setStreamHandler
  , removeStreamHandler
  , lookupStreamHandler
  , switchClose
  ) where

import Control.Concurrent.Async (cancel)
import Control.Concurrent.STM (atomically, modifyTVar', newBroadcastTChanIO, newTVarIO, readTVar, writeTVar)
import Control.Exception (SomeAsyncException, SomeException, catch, fromException, throwIO)
import Data.List (find)
import qualified Data.Map.Strict as Map
import Network.LibP2P.Crypto.Key (KeyPair)
import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr)
import Network.LibP2P.MultistreamSelect.Negotiation (ProtocolId)
import Network.LibP2P.Switch.ResourceManager (DefaultLimits (..), defaultPeerLimits, defaultSystemLimits, newResourceManager)
import Network.LibP2P.Switch.Types (ActiveListener (..), StreamHandler, Switch (..))
import Network.LibP2P.Transport.Transport (Listener (..), Transport (..))

-- | Build a fresh 'Switch' for the given local peer identity.
--
-- Every piece of mutable state starts out empty: no transports, an empty
-- connection pool, no protocol handlers, no dial backoffs or pending dials,
-- an empty peer store, no notifiers and no listeners. The closed flag
-- starts as 'False'. The resource manager is created from
-- 'defaultSystemLimits' and 'defaultPeerLimits'.
newSwitch :: PeerId -> KeyPair -> IO Switch
newSwitch pid kp = do
  let limits = DefaultLimits
        { dlSystemLimits = defaultSystemLimits
        , dlPeerLimits   = defaultPeerLimits
        }
  transportList <- newTVarIO []
  connPool      <- newTVarIO Map.empty
  handlerTable  <- newTVarIO Map.empty
  eventBus      <- newBroadcastTChanIO
  closedFlag    <- newTVarIO False
  backoffTable  <- newTVarIO Map.empty
  inflightDials <- newTVarIO Map.empty
  resourceMgr   <- newResourceManager limits
  peerStore     <- newTVarIO Map.empty
  notifierList  <- newTVarIO []
  listenerList  <- newTVarIO []
  pure Switch
    { swLocalPeerId  = pid
    , swIdentityKey  = kp
    , swTransports   = transportList
    , swConnPool     = connPool
    , swProtocols    = handlerTable
    , swEvents       = eventBus
    , swClosed       = closedFlag
    , swDialBackoffs = backoffTable
    , swPendingDials = inflightDials
    , swResourceMgr  = resourceMgr
    , swPeerStore    = peerStore
    , swNotifiers    = notifierList
    , swListeners    = listenerList
    }

-- | Register a transport with the switch.
--
-- The transport is appended to the end of the registered list. Because
-- 'selectTransport' returns the first transport that can dial an address,
-- transports registered earlier take priority.
addTransport :: Switch -> Transport -> IO ()
addTransport sw t =
  -- Strict modify: the updated list is forced to WHNF before being stored,
  -- instead of leaving an unevaluated append thunk inside the TVar.
  atomically $ modifyTVar' (swTransports sw) (++ [t])

-- | Pick the transport to use for dialing a multiaddr.
--
-- Scans the registered transports in registration order. Returns the first
-- one whose 'transportCanDial' accepts the address, or 'Nothing' when no
-- registered transport does.
selectTransport :: Switch -> Multiaddr -> IO (Maybe Transport)
selectTransport sw addr =
  atomically (firstCapable <$> readTVar (swTransports sw))
  where
    firstCapable :: [Transport] -> Maybe Transport
    firstCapable = find (`transportCanDial` addr)

-- | Register a protocol stream handler.
--
-- Any handler already registered for the same protocol ID is replaced.
setStreamHandler :: Switch -> ProtocolId -> StreamHandler -> IO ()
setStreamHandler sw proto handler =
  -- Strict modify so the updated map is stored evaluated, not as a thunk.
  atomically $ modifyTVar' (swProtocols sw) (Map.insert proto handler)

-- | Remove the stream handler registered for a protocol ID.
--
-- Removing a protocol that has no registered handler is a no-op.
removeStreamHandler :: Switch -> ProtocolId -> IO ()
removeStreamHandler sw proto =
  -- Strict modify so the updated map is stored evaluated, not as a thunk.
  atomically $ modifyTVar' (swProtocols sw) (Map.delete proto)

-- | Find the stream handler registered for a protocol ID.
--
-- Returns 'Nothing' when no handler is registered for that protocol.
lookupStreamHandler :: Switch -> ProtocolId -> IO (Maybe StreamHandler)
lookupStreamHandler sw proto = do
  handlers <- atomically (readTVar (swProtocols sw))
  pure (Map.lookup proto handlers)

-- | Shut down the switch.
--
-- A single STM transaction marks the switch as closed and takes (and clears)
-- the list of active listeners. As a result, the closed flag is set before
-- any listener is torn down, and a repeated call finds no listeners left to
-- close. Then, outside STM, each listener's accept-loop 'Async' is
-- cancelled and the listener itself is closed.
--
-- Teardown is best-effort. A synchronous exception from cancelling or
-- closing one listener is ignored, so the remaining listeners are still
-- processed. Asynchronous exceptions (e.g. the calling thread itself being
-- cancelled) are re-thrown rather than swallowed.
switchClose :: Switch -> IO ()
switchClose sw = do
  listeners <- atomically $ do
    ls <- readTVar (swListeners sw)
    writeTVar (swListeners sw) []
    writeTVar (swClosed sw) True
    pure ls
  -- Cancel accept loops and close listeners outside STM (these are IO).
  mapM_ closeListener listeners
  where
    closeListener :: ActiveListener -> IO ()
    closeListener al = do
      bestEffort (cancel (alAcceptLoop al))
      bestEffort (listenerClose (alListener al))

    -- Run an action, ignoring synchronous exceptions but re-throwing
    -- asynchronous ones, so that killing/cancelling the caller still works.
    bestEffort :: IO () -> IO ()
    bestEffort act = act `catch` handler
      where
        handler :: SomeException -> IO ()
        handler e = case (fromException e :: Maybe SomeAsyncException) of
          Just asyncEx -> throwIO asyncEx
          Nothing      -> pure ()