-- | Yamux session management: create, openStream, acceptStream, ping, goaway.
--
-- Implements the session-level Yamux protocol per HashiCorp yamux spec.md.
-- The session manages a collection of multiplexed streams over a single
-- underlying transport connection.
--
-- Two background loops run per session:
--   recvLoop: reads 12-byte headers from transport, dispatches to streams
--   sendLoop: dequeues from ysSendCh, writes to transport
module Network.LibP2P.Mux.Yamux.Session
  ( newSession
  , closeSession
  , openStream
  , acceptStream
  , ping
  , sendGoAway
  , recvLoop
  , sendLoop
  ) where

import Control.Concurrent.STM
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map
import Data.Word (Word32)
import Network.LibP2P.Mux.Yamux.Frame
import Network.LibP2P.Mux.Yamux.Types

-- | Create a new Yamux session over a transport connection.
--
-- The session starts with no registered streams, no pending pings and both
-- GoAway flags cleared. Stream IDs are allocated by role: the client uses
-- odd IDs starting at 1, the server uses even IDs starting at 2. Ping IDs
-- start at 1.
--
-- This only builds the session state; it does not start 'recvLoop' or
-- 'sendLoop'.
newSession
  :: SessionRole              -- ^ local role; selects the first stream ID
  -> (ByteString -> IO ())    -- ^ transport write function
  -> (Int -> IO ByteString)   -- ^ transport read function, given a byte count
  -> IO YamuxSession
newSession role writeFn readFn = do
  nextId <- newTVarIO firstStreamId
  streams <- newTVarIO Map.empty
  acceptCh <- newTQueueIO
  sendCh <- newTQueueIO
  shutdown <- newTVarIO False
  remoteGoAway <- newTVarIO False
  pings <- newTVarIO Map.empty
  nextPingId <- newTVarIO 1
  pure
    YamuxSession
      { ysRole = role
      , ysNextStreamId = nextId
      , ysStreams = streams
      , ysAcceptCh = acceptCh
      , ysSendCh = sendCh
      , ysShutdown = shutdown
      , ysRemoteGoAway = remoteGoAway
      , ysPings = pings
      , ysNextPingId = nextPingId
      , ysWrite = writeFn
      , ysRead = readFn
      }
  where
    -- First locally-allocated stream ID; 'openStream' advances it by 2 so
    -- the parity never changes.
    firstStreamId :: Word32
    firstStreamId = case role of
      RoleClient -> 1
      RoleServer -> 2

-- | Gracefully close the session by sending a GoAway frame with the Normal
-- code.
--
-- This is exactly @'sendGoAway' sess 'GoAwayNormal'@: it marks the session
-- as shut down and queues the frame. It does not itself stop the background
-- loops or touch the transport.
closeSession :: YamuxSession -> IO ()
closeSession sess = sendGoAway sess GoAwayNormal

-- | Open a new outbound stream. Allocates the next stream ID and sends SYN.
--
-- Returns @Left YamuxSessionShutdown@ if the session has sent GoAway
-- ('ysShutdown') or received one ('ysRemoteGoAway'). Otherwise the new
-- stream is created in 'StreamSYNSent' state, registered in 'ysStreams',
-- and a SYN frame is queued on 'ysSendCh'.
--
-- The shutdown check and the ID allocation happen in a single STM
-- transaction, so a GoAway that lands between the two cannot be missed and
-- no ID is consumed for a stream that is never opened.
openStream :: YamuxSession -> IO (Either YamuxError YamuxStream)
openStream sess = do
  allocated <- atomically $ do
    shut <- readTVar (ysShutdown sess)
    remote <- readTVar (ysRemoteGoAway sess)
    if shut || remote
      then pure Nothing
      else do
        -- Step by 2 so locally-opened streams keep this side's parity.
        sid <- readTVar (ysNextStreamId sess)
        writeTVar (ysNextStreamId sess) (sid + 2)
        pure (Just sid)
  case allocated of
    Nothing -> pure (Left YamuxSessionShutdown)
    Just sid -> do
      stream <- newStream sess sid StreamSYNSent
      -- Register the stream and queue its SYN together so the stream is
      -- already known by the time the frame can be sent.
      atomically $ do
        modifyTVar' (ysStreams sess) (Map.insert sid stream)
        writeTQueue (ysSendCh sess) (synHeader sid, BS.empty)
      pure (Right stream)
  where
    -- SYN is carried on a Data frame with the SYN flag and no payload.
    synHeader sid =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameData
        , yhFlags = defaultFlags {flagSYN = True}
        , yhStreamId = sid
        , yhLength = 0
        }

-- | Accept an inbound stream. Blocks until a remote SYN arrives.
--
-- On success the stream is moved to 'StreamEstablished' and an ACK (a
-- WindowUpdate frame with the ACK flag and zero length) is queued on
-- 'ysSendCh'; both happen in the same STM transaction that dequeues the
-- stream.
--
-- Returns @Left YamuxSessionShutdown@ when nothing is waiting in
-- 'ysAcceptCh' and 'ysShutdown' is set, instead of blocking forever.
-- Streams that were already queued before shutdown are still handed out.
acceptStream :: YamuxSession -> IO (Either YamuxError YamuxStream)
acceptStream sess = atomically $ do
  queued <- tryReadTQueue (ysAcceptCh sess)
  case queued of
    Just stream -> do
      writeTQueue (ysSendCh sess) (ackHeader stream, BS.empty)
      -- NOTE(review): the state is overwritten unconditionally, as before;
      -- a stream that was reset while still queued becomes Established.
      writeTVar (ysState stream) StreamEstablished
      pure (Right stream)
    Nothing -> do
      shut <- readTVar (ysShutdown sess)
      -- 'retry' blocks until the accept queue or the shutdown flag changes.
      if shut then pure (Left YamuxSessionShutdown) else retry
  where
    ackHeader stream =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameWindowUpdate
        , yhFlags = defaultFlags {flagACK = True}
        , yhStreamId = ysStreamId stream
        , yhLength = 0
        }

-- | Send a Ping and wait for the ACK response.
--
-- Ping uses StreamID 0 and the Length field carries an opaque value; here
-- that value is a per-session counter taken from 'ysNextPingId'. A waiter
-- is registered in 'ysPings' under that value and is filled by the receive
-- side when the matching ACK arrives.
--
-- This call blocks until the ACK is seen; it applies no timeout of its own.
-- As written it only ever returns @Right ()@.
ping :: YamuxSession -> IO (Either YamuxError ())
ping sess = do
  -- Allocate the ID, register the waiter and queue the frame in one
  -- transaction, so the waiter always exists before the frame can be sent.
  (pingId, waiter) <- atomically $ do
    pid <- readTVar (ysNextPingId sess)
    writeTVar (ysNextPingId sess) (pid + 1)
    w <- newEmptyTMVar
    modifyTVar' (ysPings sess) (Map.insert pid w)
    writeTQueue (ysSendCh sess) (pingHeader pid, BS.empty)
    pure (pid, w)
  -- Wait for the ACK and drop the registration in the same transaction, so
  -- a resolved waiter never lingers in the map.
  atomically $ do
    takeTMVar waiter
    modifyTVar' (ysPings sess) (Map.delete pingId)
  pure (Right ())
  where
    pingHeader pid =
      YamuxHeader
        { yhVersion = 0
        , yhType = FramePing
        , yhFlags = defaultFlags {flagSYN = True}
        , yhStreamId = 0
        , yhLength = pid
        }

-- | Send a GoAway frame with the specified error code.
--
-- Sets 'ysShutdown' to True so no new streams can be opened, and queues a
-- GoAway frame (StreamID 0, error code in the Length field) on 'ysSendCh'.
-- Both effects are applied in one STM transaction. The frame is only
-- queued here; writing it to the transport is left to 'sendLoop'.
sendGoAway :: YamuxSession -> GoAwayCode -> IO ()
sendGoAway sess code = atomically $ do
  writeTVar (ysShutdown sess) True
  writeTQueue (ysSendCh sess) (goAwayHeader, BS.empty)
  where
    -- Wire value of the GoAway reason.
    errCode :: Word32
    errCode = case code of
      GoAwayNormal -> 0x00
      GoAwayProtocol -> 0x01
      GoAwayInternal -> 0x02

    goAwayHeader :: YamuxHeader
    goAwayHeader =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameGoAway
        , yhFlags = defaultFlags
        , yhStreamId = 0
        , yhLength = errCode
        }

-- | Receive loop: reads headers from the transport and dispatches frames.
--
-- Each iteration asks 'ysRead' for 'headerSize' bytes, decodes them and
-- hands the header to 'dispatchFrame'. The loop returns (without sending
-- anything) when the header fails to decode or carries a version other
-- than 0. Exceptions thrown by the transport read are not caught here.
recvLoop :: YamuxSession -> IO ()
recvLoop sess = go
  where
    go :: IO ()
    go = do
      headerBytes <- ysRead sess headerSize
      case decodeHeader headerBytes of
        -- Undecodable header: protocol error, stop.
        Left _err -> pure ()
        Right hdr
          -- Only protocol version 0 is accepted; anything else stops.
          | yhVersion hdr /= 0 -> pure ()
          | otherwise -> do
              dispatchFrame sess hdr
              go

-- | Dispatch a decoded frame to the handler for its frame type.
dispatchFrame :: YamuxSession -> YamuxHeader -> IO ()
dispatchFrame sess hdr = handler sess hdr
  where
    handler :: YamuxSession -> YamuxHeader -> IO ()
    handler = case yhType hdr of
      FrameData -> handleDataFrame
      FrameWindowUpdate -> handleWindowUpdate
      FramePing -> handlePing
      FrameGoAway -> handleGoAway

-- | Handle a Data frame: read payload, manage stream state, deliver data.
--
-- Processing order:
--
-- 1. The payload ('yhLength' bytes) is always read first so the transport
--    stays aligned on frame boundaries, whatever happens next.
-- 2. SYN: the stream ID is checked with 'validateInboundSYN'. A valid SYN
--    creates a 'StreamSYNReceived' stream, registers it and queues it on
--    'ysAcceptCh'. An invalid SYN sends GoAway Protocol and the rest of
--    the frame is discarded rather than applied to an existing stream.
-- 3. ACK: 'StreamSYNSent' becomes 'StreamEstablished'.
-- 4. Payload: delivered to the stream's receive buffer and subtracted from
--    its receive window; a payload larger than the window sends GoAway
--    Protocol instead.
-- 5. FIN: Established or SYNSent becomes RemoteClose; LocalClose becomes
--    Closed.
-- 6. RST: the stream becomes 'StreamReset'.
--
-- Flags and payload addressed to an unknown stream ID are ignored.
handleDataFrame :: YamuxSession -> YamuxHeader -> IO ()
handleDataFrame sess hdr = do
  payload <-
    if yhLength hdr > 0
      then ysRead sess (fromIntegral (yhLength hdr))
      else pure BS.empty
  synOk <- if flagSYN flags then admitInbound else pure True
  if synOk
    then do
      when (flagACK flags) $ withStream markEstablished
      when (not (BS.null payload)) $ withStream (deliver payload)
      when (flagFIN flags) $ withStream remoteClosed
      when (flagRST flags) $ withStream markReset
    else sendGoAway sess GoAwayProtocol
  where
    sid = yhStreamId hdr
    flags = yhFlags hdr

    -- Run an action on the stream this frame addresses, if it is registered.
    withStream :: (YamuxStream -> IO ()) -> IO ()
    withStream act = do
      found <- atomically $ Map.lookup sid <$> readTVar (ysStreams sess)
      maybe (pure ()) act found

    -- Validate the SYN and, when valid, create and publish the new stream.
    -- Returns False on a protocol error.
    admitInbound :: IO Bool
    admitInbound = do
      valid <- atomically $ validateInboundSYN sess sid
      if valid
        then do
          stream <- newStream sess sid StreamSYNReceived
          atomically $ do
            modifyTVar' (ysStreams sess) (Map.insert sid stream)
            writeTQueue (ysAcceptCh sess) stream
          pure True
        else pure False

    -- ACK only matters for a stream we opened that is still awaiting it.
    markEstablished :: YamuxStream -> IO ()
    markEstablished stream = atomically $ do
      st <- readTVar (ysState stream)
      case st of
        StreamSYNSent -> writeTVar (ysState stream) StreamEstablished
        _ -> pure ()

    -- Enforce flow control: the window check and the enqueue are one
    -- transaction, so the window can never be overdrawn.
    deliver :: ByteString -> YamuxStream -> IO ()
    deliver payload stream = do
      let payloadLen :: Word32
          payloadLen = fromIntegral (BS.length payload)
      overWindow <- atomically $ do
        window <- readTVar (ysRecvWindow stream)
        if window < payloadLen
          then pure True
          else do
            writeTQueue (ysRecvBuf stream) payload
            writeTVar (ysRecvWindow stream) (window - payloadLen)
            pure False
      when overWindow $ sendGoAway sess GoAwayProtocol

    -- The remote half-closed the stream.
    remoteClosed :: YamuxStream -> IO ()
    remoteClosed stream = atomically $ do
      st <- readTVar (ysState stream)
      case st of
        StreamEstablished -> writeTVar (ysState stream) StreamRemoteClose
        StreamLocalClose -> writeTVar (ysState stream) StreamClosed
        StreamSYNSent -> writeTVar (ysState stream) StreamRemoteClose
        _ -> pure ()

    markReset :: YamuxStream -> IO ()
    markReset stream = atomically $ writeTVar (ysState stream) StreamReset

-- | Handle a WindowUpdate frame: update send window, manage stream lifecycle.
--
-- Processing order:
--
-- 1. SYN: the stream ID is checked with 'validateInboundSYN'. A valid SYN
--    creates a 'StreamSYNReceived' stream, registers it and queues it on
--    'ysAcceptCh'. An invalid SYN sends GoAway Protocol and the rest of
--    the frame is discarded rather than applied to an existing stream.
-- 2. ACK: 'StreamSYNSent' becomes 'StreamEstablished'.
-- 3. A non-zero Length is added to the stream's send window. The addition
--    saturates at @maxBound@ so an oversized delta cannot wrap the window
--    around to a small value.
-- 4. FIN: Established or SYNSent becomes RemoteClose; LocalClose becomes
--    Closed (the same transitions 'handleDataFrame' applies).
-- 5. RST: the stream becomes 'StreamReset'.
--
-- Flags and deltas addressed to an unknown stream ID are ignored.
handleWindowUpdate :: YamuxSession -> YamuxHeader -> IO ()
handleWindowUpdate sess hdr = do
  synOk <- if flagSYN flags then admitInbound else pure True
  if synOk
    then do
      when (flagACK flags) $ withStream markEstablished
      when (delta > 0) $ withStream growSendWindow
      when (flagFIN flags) $ withStream remoteClosed
      when (flagRST flags) $ withStream markReset
    else sendGoAway sess GoAwayProtocol
  where
    sid :: Word32
    sid = yhStreamId hdr

    flags = yhFlags hdr

    -- For WindowUpdate the Length field is the window increment.
    delta :: Word32
    delta = yhLength hdr

    -- Run an action on the stream this frame addresses, if it is registered.
    withStream :: (YamuxStream -> IO ()) -> IO ()
    withStream act = do
      found <- atomically $ Map.lookup sid <$> readTVar (ysStreams sess)
      maybe (pure ()) act found

    -- Validate the SYN and, when valid, create and publish the new stream.
    -- Returns False on a protocol error.
    admitInbound :: IO Bool
    admitInbound = do
      valid <- atomically $ validateInboundSYN sess sid
      if valid
        then do
          stream <- newStream sess sid StreamSYNReceived
          atomically $ do
            modifyTVar' (ysStreams sess) (Map.insert sid stream)
            writeTQueue (ysAcceptCh sess) stream
          pure True
        else pure False

    -- ACK only matters for a stream we opened that is still awaiting it.
    markEstablished :: YamuxStream -> IO ()
    markEstablished stream = atomically $ do
      st <- readTVar (ysState stream)
      case st of
        StreamSYNSent -> writeTVar (ysState stream) StreamEstablished
        _ -> pure ()

    growSendWindow :: YamuxStream -> IO ()
    growSendWindow stream =
      atomically $ modifyTVar' (ysSendWindow stream) addSaturating

    -- Word32 addition wraps silently; clamp at maxBound instead.
    addSaturating :: Word32 -> Word32
    addSaturating window
      | delta > maxBound - window = maxBound
      | otherwise = window + delta

    -- The remote half-closed the stream.
    remoteClosed :: YamuxStream -> IO ()
    remoteClosed stream = atomically $ do
      st <- readTVar (ysState stream)
      case st of
        StreamEstablished -> writeTVar (ysState stream) StreamRemoteClose
        StreamLocalClose -> writeTVar (ysState stream) StreamClosed
        StreamSYNSent -> writeTVar (ysState stream) StreamRemoteClose
        _ -> pure ()

    markReset :: YamuxStream -> IO ()
    markReset stream = atomically $ writeTVar (ysState stream) StreamReset

-- | Handle a Ping frame.
--
-- * SYN: queue a Ping reply with the ACK flag, StreamID 0 and the same
--   opaque value in the Length field.
-- * ACK: look the opaque value up in 'ysPings' and wake the matching
--   waiter, if any.
-- * Neither flag: ignored.
--
-- SYN takes precedence when both flags are set. The frame's StreamID is not
-- checked here.
--
-- The waiter is filled with 'tryPutTMVar', not 'putTMVar': a duplicate or
-- late ACK for a waiter that is already full is dropped instead of blocking
-- the receive loop until someone drains it.
handlePing :: YamuxSession -> YamuxHeader -> IO ()
handlePing sess hdr
  | flagSYN flags =
      atomically $ writeTQueue (ysSendCh sess) (replyHeader, BS.empty)
  | flagACK flags = atomically $ do
      pending <- readTVar (ysPings sess)
      case Map.lookup opaque pending of
        Just waiter -> do
          _alreadyResolved <- tryPutTMVar waiter ()
          pure ()
        Nothing -> pure ()
  | otherwise = pure ()
  where
    flags = yhFlags hdr

    -- Opaque ping value, carried in the Length field.
    opaque :: Word32
    opaque = yhLength hdr

    replyHeader :: YamuxHeader
    replyHeader =
      YamuxHeader
        { yhVersion = 0
        , yhType = FramePing
        , yhFlags = defaultFlags {flagACK = True}
        , yhStreamId = 0
        , yhLength = opaque -- echo opaque value
        }

-- | Handle a GoAway frame by setting 'ysRemoteGoAway'.
--
-- After this, 'openStream' refuses to open new streams. The header is not
-- inspected: neither its StreamID nor the error code in its Length field
-- is read.
handleGoAway :: YamuxSession -> YamuxHeader -> IO ()
handleGoAway sess _hdr = atomically $ writeTVar (ysRemoteGoAway sess) True

-- | Send loop: dequeues frames from 'ysSendCh' and writes them to the
-- transport.
--
-- Each frame is written as the encoded header followed, when the payload is
-- non-empty, by a second write with the payload. The loop blocks while the
-- queue is empty and has no exit condition of its own; it ends only when a
-- write throws or the thread is killed.
sendLoop :: YamuxSession -> IO ()
sendLoop sess = go
  where
    go :: IO ()
    go = do
      (hdr, payload) <- atomically $ readTQueue (ysSendCh sess)
      ysWrite sess (encodeHeader hdr)
      when (not (BS.null payload)) $ ysWrite sess payload
      go

-- | Create a new 'YamuxStream' with the given ID and initial state.
--
-- Both the send and the receive window start at 'initialWindowSize'; the
-- receive buffer and the send-notify slot start empty. The stream is NOT
-- registered in the session's stream map here; callers do that.
newStream :: YamuxSession -> Word32 -> StreamState -> IO YamuxStream
newStream sess sid initState = do
  stateVar <- newTVarIO initState
  sendWin <- newTVarIO initialWindowSize
  recvWin <- newTVarIO initialWindowSize
  recvBuf <- newTQueueIO
  sendNotify <- newEmptyTMVarIO
  pure
    YamuxStream
      { ysStreamId = sid
      , ysState = stateVar
      , ysSendWindow = sendWin
      , ysRecvWindow = recvWin
      , ysRecvBuf = recvBuf
      , ysSendNotify = sendNotify
      , ysSession = sess
      }

-- | Validate an inbound SYN stream ID for parity and uniqueness.
--
-- Returns True if valid, False on a protocol error (the caller must send
-- GoAway). An ID is rejected when it is 0, when it has the wrong parity
-- for the remote side (a server expects odd IDs from the client, a client
-- expects even IDs from the server), or when a stream with that ID is
-- already registered in 'ysStreams'.
validateInboundSYN :: YamuxSession -> Word32 -> STM Bool
validateInboundSYN sess sid
  | sid == 0 || not remoteParity = pure False
  | otherwise = not . Map.member sid <$> readTVar (ysStreams sess)
  where
    -- Parity the peer must use, i.e. the opposite of our own.
    remoteParity :: Bool
    remoteParity = case ysRole sess of
      RoleServer -> odd sid
      RoleClient -> even sid

-- | Helper: run the action when the condition is True, otherwise do
-- nothing.
--
-- Local, IO-only stand-in for @Control.Monad.when@, which this module does
-- not import.
when :: Bool -> IO () -> IO ()
when cond action = if cond then action else pure ()