-- | STM-based connection pool for the Switch.
--
-- Tracks connections per peer, in any 'ConnState'. Apart from 'newConnPool'
-- (which runs in IO), every operation is an STM action, so the dial, listen,
-- and cleanup threads can run them atomically and compose them.
--
-- Connection identity uses TVar pointer equality (connState field)
-- since Connection contains function fields that prevent deriving Eq.
module Network.LibP2P.Switch.ConnPool
  ( newConnPool
  , lookupConn
  , lookupAllConns
  , addConn
  , removeConn
  , allConns
  ) where

import Control.Concurrent.STM (STM, TVar, newTVarIO, readTVar, writeTVar)
import qualified Data.Map.Strict as Map
import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.Switch.Types (ConnState (..), Connection (..))

-- | Create a new, empty connection pool.
--
-- The pool maps each peer to the list of connections currently recorded
-- for it.
newConnPool :: IO (TVar (Map.Map PeerId [Connection]))
newConnPool = newTVarIO Map.empty

-- | Look up the first connection for a peer whose state is 'ConnOpen'.
--
-- The peer's connections are examined in list order, and the scan stops at
-- the first one whose 'connState' reads as 'ConnOpen'. Returns 'Nothing'
-- when the peer has no entry in the pool or none of its connections is
-- open. The pool and the state TVars are read in the same transaction.
lookupConn :: TVar (Map.Map PeerId [Connection]) -> PeerId -> STM (Maybe Connection)
lookupConn poolVar pid = do
  pool <- readTVar poolVar
  -- A missing peer behaves exactly like a peer with no connections.
  findOpen (Map.findWithDefault [] pid pool)
  where
    -- Return the first connection in the list that is currently open.
    findOpen :: [Connection] -> STM (Maybe Connection)
    findOpen [] = pure Nothing
    findOpen (c : rest) = do
      st <- readTVar (connState c)
      if st == ConnOpen
        then pure (Just c)
        else findOpen rest

-- | Look up every connection recorded for a peer, regardless of its state.
--
-- Returns an empty list when the peer has no entry in the pool.
lookupAllConns :: TVar (Map.Map PeerId [Connection]) -> PeerId -> STM [Connection]
lookupAllConns poolVar pid = Map.findWithDefault [] pid <$> readTVar poolVar

-- | Add a connection to the pool, keyed by its 'connPeerId'.
--
-- The connection is appended after any connections already recorded for
-- the same peer. No duplicate check is performed, so adding the same
-- connection twice stores it twice.
addConn :: TVar (Map.Map PeerId [Connection]) -> Connection -> STM ()
addConn poolVar conn = do
  pool <- readTVar poolVar
  let pid = connPeerId conn
      conns = Map.findWithDefault [] pid pool
  -- Force the updated map before storing it. A plain writeTVar would leave
  -- an unevaluated Map.insert thunk in the TVar, keeping the previous map
  -- alive until some later reader happens to force it.
  writeTVar poolVar $! Map.insert pid (conns ++ [conn]) pool

-- | Remove a specific connection from the pool.
--
-- Connections are matched by equality of their 'connState' TVar, because
-- 'Connection' contains function fields and has no 'Eq' instance. Every
-- entry sharing that TVar is removed. If nothing remains for the peer, the
-- peer's key is deleted so the map does not accumulate empty lists. When
-- the peer has no entry at all, the pool TVar is not written.
removeConn :: TVar (Map.Map PeerId [Connection]) -> Connection -> STM ()
removeConn poolVar conn = do
  pool <- readTVar poolVar
  let pid = connPeerId conn
      target = connState conn
      isOther c = connState c /= target
  case Map.lookup pid pool of
    Nothing -> pure ()
    Just conns -> do
      let remaining = filter isOther conns
          pool'
            | null remaining = Map.delete pid pool
            | otherwise = Map.insert pid remaining pool
      -- Force the new map before storing it so the TVar never holds an
      -- unevaluated update thunk that keeps the previous map alive.
      writeTVar poolVar $! pool'

-- | Collect every connection in the pool, across all peers and all states.
--
-- Peers are visited in ascending key order (the order 'Map.elems' returns).
-- Within a peer, connections keep the order in which they are stored.
allConns :: TVar (Map.Map PeerId [Connection]) -> STM [Connection]
allConns poolVar = concat . Map.elems <$> readTVar poolVar