-- | Yamux stream operations: read, write, close, reset.
--
-- Implements per-stream data transfer with flow control
-- per HashiCorp yamux spec.md §Data/WindowUpdate/Stream Close.
module Network.LibP2P.Mux.Yamux.Stream
  ( streamWrite
  , streamRead
  , streamClose
  , streamReset
  ) where

import Control.Concurrent.STM
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Network.LibP2P.Mux.Yamux.Frame
import Network.LibP2P.Mux.Yamux.Types

-- | Send a payload on a stream, blocking while the send window is exhausted.
--
-- An empty payload succeeds immediately without touching the stream.
-- Otherwise the current state decides:
--
-- * 'StreamReset' -> @Left YamuxStreamReset@
-- * 'StreamClosed' or 'StreamLocalClose' -> @Left YamuxStreamClosed@
-- * any other state (SYNSent, SYNReceived, Established, RemoteClose) ->
--   the payload is handed to 'writeChunked', which splits it to fit the
--   peer's advertised window.
streamWrite :: YamuxStream -> ByteString -> IO (Either YamuxError ())
streamWrite stream payload
  | BS.null payload = pure (Right ())
  | otherwise = readTVarIO (ysState stream) >>= dispatch
  where
    dispatch StreamReset = pure (Left YamuxStreamReset)
    dispatch StreamClosed = pure (Left YamuxStreamClosed)
    dispatch StreamLocalClose = pure (Left YamuxStreamClosed)
    dispatch _ = writeChunked stream payload

-- | Send a payload as a sequence of Data frames, each no larger than the
-- send window available when it is taken.
--
-- Each iteration runs one STM transaction that:
--
-- 1. re-checks the stream state. Reset yields @Left YamuxStreamReset@.
--    Closed or LocalClose yields @Left YamuxStreamClosed@, because the
--    stream may have been closed or reset concurrently since the caller's
--    own check.
-- 2. blocks via 'retry' while the send window is 0. The transaction re-runs
--    when @ysSendWindow@ or @ysState@ is written (presumably by the
--    session's WindowUpdate handling, or by a close/reset).
-- 3. takes @min window remaining@ bytes, debits the window, and queues
--    the Data frame on the session send queue.
--
-- Because the check, the debit and the enqueue commit together, no Data
-- frame can be queued after a FIN ('streamClose') or RST ('streamReset')
-- that committed first. A chunk already queued stays sent if a later chunk
-- fails.
writeChunked :: YamuxStream -> ByteString -> IO (Either YamuxError ())
writeChunked stream payload
  | BS.null payload = pure (Right ())
  | otherwise = do
      step <- atomically $ do
        st <- readTVar (ysState stream)
        case st of
          StreamReset -> pure (Left YamuxStreamReset)
          StreamClosed -> pure (Left YamuxStreamClosed)
          -- A concurrent streamClose may already have queued our FIN;
          -- no further Data may follow a half-close.
          StreamLocalClose -> pure (Left YamuxStreamClosed)
          _ -> do
            window <- readTVar (ysSendWindow stream)
            if window == 0
              then retry -- wakes when ysSendWindow or ysState changes
              else do
                -- Compare in Integer so a large Word32 window cannot wrap
                -- a 32-bit Int into a negative chunk length.
                let chunkLen =
                      fromInteger
                        (min (toInteger window) (toInteger (BS.length payload))) :: Int
                    (chunk, rest) = BS.splitAt chunkLen payload
                writeTVar (ysSendWindow stream) (window - fromIntegral chunkLen)
                writeTQueue (ysSendCh (ysSession stream)) (dataHeader chunkLen, chunk)
                pure (Right rest)
      case step of
        Left err -> pure (Left err)
        Right rest -> writeChunked stream rest
  where
    -- Header for a plain Data frame carrying @len@ bytes on this stream.
    dataHeader len =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameData
        , yhFlags = defaultFlags
        , yhStreamId = ysStreamId stream
        , yhLength = fromIntegral len
        }

-- | Read the next buffered chunk from a stream, blocking while nothing is
-- buffered and more data can still arrive.
--
-- * 'StreamReset' -> @Left YamuxStreamReset@ immediately. Any data still
--   buffered is abandoned, and no window credit is granted for it.
-- * 'StreamClosed' or 'StreamRemoteClose' -> buffered data is drained first.
--   Once the buffer is empty the result is @Left YamuxStreamClosed@.
-- * SYNSent, SYNReceived, Established, LocalClose -> an empty buffer blocks
--   via STM 'retry' until data is queued or the state changes.
--
-- For every non-empty chunk returned, a WindowUpdate frame crediting its
-- length is queued on the session send queue, and @ysRecvWindow@ is raised
-- by the same amount. This happens in the same transaction as the dequeue,
-- so a stream reset before the read commits never receives credit.
streamRead :: YamuxStream -> IO (Either YamuxError ByteString)
streamRead stream = atomically $ do
  st <- readTVar (ysState stream)
  case st of
    StreamReset -> pure (Left YamuxStreamReset)
    _ -> do
      mData <- tryReadTQueue (ysRecvBuf stream)
      case mData of
        Just payload -> do
          creditWindow payload
          pure (Right payload)
        Nothing -> case st of
          StreamClosed -> pure (Left YamuxStreamClosed)
          StreamRemoteClose -> pure (Left YamuxStreamClosed)
          _ -> retry -- wakes when ysRecvBuf or ysState changes
  where
    -- Return the consumed bytes to the peer's send window. Skip empty
    -- payloads: a zero-delta WindowUpdate carries no information.
    creditWindow payload
      | BS.null payload = pure ()
      | otherwise = do
          let consumed = fromIntegral (BS.length payload)
          writeTQueue (ysSendCh (ysSession stream)) (windowUpdate consumed, BS.empty)
          modifyTVar' (ysRecvWindow stream) (+ consumed)

    -- WindowUpdate header granting @delta@ additional bytes on this stream.
    windowUpdate delta =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameWindowUpdate
        , yhFlags = defaultFlags
        , yhStreamId = ysStreamId stream
        , yhLength = delta
        }

-- | Half-close the stream: stop sending and notify the peer with a FIN flag
-- on a zero-length Data frame.
--
-- Accepted states and their transitions:
--
-- * SYNSent, SYNReceived, Established -> LocalClose
-- * RemoteClose -> Closed (both directions are now finished)
--
-- A stream that is already LocalClose or Closed yields
-- @Left YamuxStreamClosed@. A Reset stream yields @Left YamuxStreamReset@.
-- In those cases no frame is sent.
--
-- The state transition and the FIN enqueue commit in one STM transaction.
-- The FIN is therefore queued exactly when the transition happens, and a
-- concurrent 'streamReset' cannot commit in between (which would otherwise
-- queue the FIN after the RST).
streamClose :: YamuxStream -> IO (Either YamuxError ())
streamClose stream = atomically $ do
  st <- readTVar (ysState stream)
  case closeTransition st of
    Left err -> pure (Left err)
    Right next -> do
      writeTVar (ysState stream) next
      writeTQueue (ysSendCh (ysSession stream)) (finHeader, BS.empty)
      pure (Right ())
  where
    -- Next state after a local half-close, or the error to report.
    closeTransition :: StreamState -> Either YamuxError StreamState
    closeTransition StreamSYNSent = Right StreamLocalClose
    closeTransition StreamSYNReceived = Right StreamLocalClose
    closeTransition StreamEstablished = Right StreamLocalClose
    closeTransition StreamRemoteClose = Right StreamClosed
    closeTransition StreamLocalClose = Left YamuxStreamClosed
    closeTransition StreamClosed = Left YamuxStreamClosed
    closeTransition StreamReset = Left YamuxStreamReset

    -- Zero-length Data frame carrying the FIN flag.
    finHeader =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameData
        , yhFlags = defaultFlags {flagFIN = True}
        , yhStreamId = ysStreamId stream
        , yhLength = 0
        }

-- | Abort the stream in both directions by sending an RST flag on a
-- zero-length Data frame.
--
-- The transition to 'StreamReset' and the RST enqueue commit in one STM
-- transaction, so no other frame for this stream can be queued in between.
-- Any reader or writer on this stream that is blocked in an STM 'retry'
-- re-checks the state once this commits.
--
-- If the stream is already 'StreamReset', this is a no-op and no second RST
-- is sent. That covers an earlier local reset, or presumably the session
-- having recorded the peer's RST (confirm against the frame handler).
streamReset :: YamuxStream -> IO ()
streamReset stream = atomically $ do
  st <- readTVar (ysState stream)
  case st of
    StreamReset -> pure ()
    _ -> do
      writeTVar (ysState stream) StreamReset
      writeTQueue (ysSendCh (ysSession stream)) (rstHeader, BS.empty)
  where
    -- Zero-length Data frame carrying the RST flag.
    rstHeader =
      YamuxHeader
        { yhVersion = 0
        , yhType = FrameData
        , yhFlags = defaultFlags {flagRST = True}
        , yhStreamId = ysStreamId stream
        , yhLength = 0
        }