module Network.LibP2P.Switch.Listen
(
ConnectionGater (..)
, defaultConnectionGater
, handleInbound
, streamAcceptLoop
, dispatchStream
, switchListen
, acceptLoop
, switchListenAddrs
) where
import Control.Concurrent.Async (async)
import Control.Concurrent.STM (atomically, readTVar, writeTVar)
import Control.Exception (SomeException, catch, onException)
import Data.List (find)
import qualified Data.Map.Strict as Map
import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr)
import Network.LibP2P.MultistreamSelect.Negotiation
  ( NegotiationResult (..)
  , StreamIO
  , negotiateResponder
  )
import Network.LibP2P.Switch.ConnPool (addConn)
import Network.LibP2P.Switch.ResourceManager (Direction (..), reserveConnection)
import Network.LibP2P.Switch.Types
  ( ActiveListener (..)
  , Connection (..)
  , MuxerSession (..)
  , Switch (..)
  )
import Network.LibP2P.Switch.Upgrade (upgradeInbound)
import Network.LibP2P.Transport.Transport (Listener (..), RawConnection (..), Transport (..))
-- | Policy hooks consulted by 'handleInbound' while admitting an inbound
-- connection. Each hook answers 'True' to let the connection proceed and
-- 'False' to reject it.
data ConnectionGater = ConnectionGater
  { gateAccept :: !(Multiaddr -> IO Bool)
    -- ^ Given the remote address of a raw connection; consulted before
    -- any security/muxer upgrade is attempted.
  , gateSecured :: !(PeerId -> IO Bool)
    -- ^ Given the remote peer ID of the upgraded connection; consulted
    -- after the upgrade has completed.
  }
-- | A permissive gater: both hooks ignore their argument and admit the
-- connection unconditionally.
defaultConnectionGater :: ConnectionGater
defaultConnectionGater =
  ConnectionGater
    { gateAccept = const allowAll
    , gateSecured = const allowAll
    }
  where
    allowAll :: IO Bool
    allowAll = pure True
-- | Admit (or reject) one freshly accepted raw connection.
--
-- Steps, in order:
--
--   1. Ask 'gateAccept' about the remote address; on 'False' the raw
--      connection is closed.
--   2. Upgrade the connection with 'upgradeInbound' using the switch's
--      identity key.
--   3. Ask 'gateSecured' about the remote peer ID; on 'False' the muxer
--      session is closed.
--   4. Reserve an 'Inbound' slot with the resource manager; if the
--      reservation is refused the muxer session is closed.
--   5. Add the connection to the pool, fire every registered notifier on
--      its own thread, then run 'streamAcceptLoop' on the current thread
--      (so this call returns only when that loop ends).
--
-- If the upgrade or the 'gateSecured' hook throws, the resource opened so
-- far (raw connection or muxer session) is closed before the exception is
-- re-thrown, so a failed handshake does not leak it.
handleInbound :: Switch -> ConnectionGater -> RawConnection -> IO ()
handleInbound sw gater rawConn = do
  allowed <- gateAccept gater (rcRemoteAddr rawConn)
  if not allowed
    then rcClose rawConn
    else do
      -- NOTE(review): if upgradeInbound already closes rawConn on failure,
      -- this extra close must be harmless — confirm rcClose is idempotent.
      conn <- upgradeInbound (swIdentityKey sw) rawConn `onException` rcClose rawConn
      let session = connSession conn
          peer = connPeerId conn
      secured <- gateSecured gater peer `onException` muxClose session
      if not secured
        then muxClose session
        else do
          resCheck <- atomically $ reserveConnection (swResourceMgr sw) peer Inbound
          case resCheck of
            Left _ -> muxClose session
            Right () -> do
              atomically $ addConn (swConnPool sw) conn
              notifiers <- atomically $ readTVar (swNotifiers sw)
              -- Notifiers run concurrently; their Async handles are dropped.
              mapM_ (\notify -> async (notify conn)) notifiers
              streamAcceptLoop sw conn
-- | Serve inbound streams on an established connection.
--
-- Repeatedly calls 'muxAcceptStream' on the connection's muxer session and
-- hands each accepted stream to 'dispatchStream' on a separate thread. The
-- loop ends (returning @()@) the first time accepting a stream throws any
-- exception; that exception is discarded.
streamAcceptLoop :: Switch -> Connection -> IO ()
streamAcceptLoop sw conn = serve
  where
    serve :: IO ()
    serve = nextStream >>= maybe (pure ()) spawnThenServe

    spawnThenServe :: StreamIO -> IO ()
    spawnThenServe s = async (dispatchStream sw conn s) >> serve

    nextStream :: IO (Maybe StreamIO)
    nextStream =
      fmap Just (muxAcceptStream (connSession conn))
        `catch` (\(_ :: SomeException) -> pure Nothing)
-- | Negotiate a protocol on an inbound stream and run its handler.
--
-- The protocol IDs currently registered in 'swProtocols' are offered to
-- 'negotiateResponder'. On 'Accepted', the handler is looked up again in a
-- separate STM transaction and invoked with the stream and the
-- connection's remote peer ID. If negotiation yields 'NoProtocol', or the
-- handler disappeared between the two reads, nothing else happens.
--
-- NOTE(review): the stream is not closed on those no-handler paths here;
-- confirm whether negotiateResponder or the muxer takes care of it.
dispatchStream :: Switch -> Connection -> StreamIO -> IO ()
dispatchStream sw conn stream = do
  offered <- atomically (Map.keys <$> readTVar registry)
  outcome <- negotiateResponder stream offered
  case outcome of
    NoProtocol -> pure ()
    Accepted chosen -> do
      found <- atomically (Map.lookup chosen <$> readTVar registry)
      maybe (pure ()) (\run -> run stream (connPeerId conn)) found
  where
    registry = swProtocols sw
-- | Start listening on each of the given addresses.
--
-- Fails (via 'fail') if the switch has been closed, or if no registered
-- transport accepts one of the addresses. For every address, a listener is
-- bound with 'transportListen', an 'acceptLoop' is started on its own
-- thread, and the resulting 'ActiveListener' is appended to 'swListeners'.
-- Returns the listeners' own addresses ('listenerAddr'), in input order.
--
-- Each listener is recorded in 'swListeners' as soon as it is bound, not
-- all at the end. If a later address fails, the listeners already started
-- stay reachable through the switch instead of being orphaned with running
-- accept loops that nothing can close.
switchListen :: Switch -> ConnectionGater -> [Multiaddr] -> IO [Multiaddr]
switchListen sw gater addrs = do
  closed <- atomically $ readTVar (swClosed sw)
  if closed
    then fail "switchListen: switch is closed"
    else do
      transports <- atomically $ readTVar (swTransports sw)
      started <- mapM (bindAndListen transports) addrs
      pure (map alAddress started)
  where
    -- Bind one address, start its accept loop, and register it.
    bindAndListen transports addr =
      -- NOTE(review): the listening transport is chosen with
      -- transportCanDial; confirm that is the intended predicate.
      case find (\tr -> transportCanDial tr addr) transports of
        Nothing -> fail $ "switchListen: no transport for " ++ show addr
        Just transport -> do
          listener <- transportListen transport addr
          loopThread <- async $ acceptLoop sw gater listener
          let active =
                ActiveListener
                  { alListener = listener
                  , alAcceptLoop = loopThread
                  , alAddress = listenerAddr listener
                  }
          atomically $ do
            existing <- readTVar (swListeners sw)
            writeTVar (swListeners sw) (existing ++ [active])
          pure active
-- | Accept raw connections from a listener until accepting fails.
--
-- Each accepted connection is handed to 'handleInbound' on its own thread.
-- Any exception escaping 'handleInbound' is swallowed on that thread. The
-- loop ends (returning @()@) the first time 'listenerAccept' throws any
-- exception; that exception is discarded.
acceptLoop :: Switch -> ConnectionGater -> Listener -> IO ()
acceptLoop sw gater listener = step
  where
    step :: IO ()
    step = nextConn >>= maybe (pure ()) spawnThenStep

    nextConn :: IO (Maybe RawConnection)
    nextConn =
      fmap Just (listenerAccept listener)
        `catch` (\(_ :: SomeException) -> pure Nothing)

    spawnThenStep :: RawConnection -> IO ()
    spawnThenStep raw = do
      _ <- async (handleInbound sw gater raw `catch` ignoreAll)
      step

    ignoreAll :: SomeException -> IO ()
    ignoreAll _ = pure ()
-- | Snapshot the addresses of all listeners currently recorded in
-- 'swListeners', read in a single STM transaction.
switchListenAddrs :: Switch -> IO [Multiaddr]
switchListenAddrs sw = map alAddress <$> atomically (readTVar (swListeners sw))