-- | Yamux frame header encoding/decoding.
--
-- Every Yamux frame has a fixed 12-byte header:
-- Version (1) | Type (1) | Flags (2 BE) | StreamID (4 BE) | Length (4 BE)
module Network.LibP2P.Mux.Yamux.Frame
  ( FrameType (..)
  , Flags (..)
  , YamuxHeader (..)
  , GoAwayCode (..)
  , encodeHeader
  , decodeHeader
  , defaultFlags
  , headerSize
  , initialWindowSize
  ) where

import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Word (Word16, Word32, Word8)

-- | Size in bytes of a Yamux frame header. Every header is exactly this long:
-- version (1) + type (1) + flags (2) + stream ID (4) + length (4).
headerSize :: Int
headerSize = 12

-- | Default initial window size: 256 KiB (262144 bytes).
initialWindowSize :: Word32
initialWindowSize = 256 * 1024

-- | Yamux frame types, as carried in the header's Type byte.
data FrameType
  = FrameData         -- ^ 0x00: Data frame with payload
  | FrameWindowUpdate -- ^ 0x01: Window size increment
  | FramePing         -- ^ 0x02: Keepalive/latency measurement
  | FrameGoAway       -- ^ 0x03: Session termination
  deriving (Show, Eq)

-- | Go Away error codes.
data GoAwayCode
  = GoAwayNormal    -- ^ 0x00: Normal termination
  | GoAwayProtocol  -- ^ 0x01: Protocol error
  | GoAwayInternal  -- ^ 0x02: Internal error
  deriving (Show, Eq)

-- | Frame flags, one 'Bool' per bit of the header's 16-bit flags field.
data Flags = Flags
  { flagSYN :: !Bool -- ^ 0x0001: Open a new stream
  , flagACK :: !Bool -- ^ 0x0002: Acknowledge a new stream
  , flagFIN :: !Bool -- ^ 0x0004: Half-close the stream
  , flagRST :: !Bool -- ^ 0x0008: Reset the stream
  }
  deriving (Show, Eq)

-- | No flags set: all four flag bits are clear.
defaultFlags :: Flags
defaultFlags =
  Flags
    { flagSYN = False
    , flagACK = False
    , flagFIN = False
    , flagRST = False
    }

-- | Complete Yamux frame header (12 bytes on the wire).
data YamuxHeader = YamuxHeader
  { yhVersion  :: !Word8     -- ^ Version byte (wire byte 0)
  , yhType     :: !FrameType -- ^ Frame type (wire byte 1)
  , yhFlags    :: !Flags     -- ^ Flag bits (wire bytes 2-3, big-endian)
  , yhStreamId :: !Word32    -- ^ Stream identifier (wire bytes 4-7, big-endian)
  , yhLength   :: !Word32    -- ^ Length field (wire bytes 8-11, big-endian)
  }
  deriving (Show, Eq)

-- | Encode a frame type to its one-byte wire value.
-- Inverse of 'word8ToFrameType' on the four known codes.
frameTypeToWord8 :: FrameType -> Word8
frameTypeToWord8 FrameData         = 0x00
frameTypeToWord8 FrameWindowUpdate = 0x01
frameTypeToWord8 FramePing         = 0x02
frameTypeToWord8 FrameGoAway       = 0x03

-- | Decode a frame type from its one-byte wire value.
--
-- Returns 'Left' with a descriptive message for any byte outside 0x00-0x03.
word8ToFrameType :: Word8 -> Either String FrameType
word8ToFrameType 0x00 = Right FrameData
word8ToFrameType 0x01 = Right FrameWindowUpdate
word8ToFrameType 0x02 = Right FramePing
word8ToFrameType 0x03 = Right FrameGoAway
word8ToFrameType n    = Left ("unknown frame type: " <> show n)

-- | Pack flags into the 16-bit flags field value.
--
-- Bit assignments: SYN = 0x0001, ACK = 0x0002, FIN = 0x0004, RST = 0x0008.
-- All other bits are zero. Byte order is handled by 'encodeHeader', which
-- writes this value big-endian.
flagsToWord16 :: Flags -> Word16
flagsToWord16 (Flags syn ack fin rst) =
  bitIf syn 0x0001 .|. bitIf ack 0x0002 .|. bitIf fin 0x0004 .|. bitIf rst 0x0008
  where
    -- The mask when the flag is set, otherwise zero.
    bitIf :: Bool -> Word16 -> Word16
    bitIf isOn mask = if isOn then mask else 0

-- | Unpack the 16-bit flags field value into 'Flags'.
--
-- Reads bits 0x0001 (SYN), 0x0002 (ACK), 0x0004 (FIN) and 0x0008 (RST).
-- Any other set bits are ignored. The caller is responsible for assembling
-- the value from the big-endian wire bytes (see 'decodeHeader').
word16ToFlags :: Word16 -> Flags
word16ToFlags w =
  Flags
    { flagSYN = isSet 0x0001
    , flagACK = isSet 0x0002
    , flagFIN = isSet 0x0004
    , flagRST = isSet 0x0008
    }
  where
    isSet :: Word16 -> Bool
    isSet mask = w .&. mask /= 0

-- | Encode a Yamux header to its 12-byte wire form.
--
-- Layout: version (1) | type (1) | flags (2, big-endian) |
-- stream ID (4, big-endian) | length (4, big-endian).
encodeHeader :: YamuxHeader -> ByteString
encodeHeader (YamuxHeader ver typ flags sid len) =
  BS.pack $
    [ver, frameTypeToWord8 typ]
      <> be16 (flagsToWord16 flags)
      <> be32 sid
      <> be32 len
  where
    -- Most-significant byte first; 'fromIntegral' truncates to the low 8 bits.
    be16 :: Word16 -> [Word8]
    be16 w = [fromIntegral (w `shiftR` 8), fromIntegral w]

    be32 :: Word32 -> [Word8]
    be32 w = [fromIntegral (w `shiftR` s) | s <- [24, 16, 8, 0]]

-- | Decode a Yamux header from the first 12 bytes of the input.
--
-- Fails with 'Left' if fewer than 'headerSize' bytes are supplied or if the
-- type byte is not a known frame type. Bytes beyond the first 12 are ignored.
-- Multi-byte fields are read big-endian.
decodeHeader :: ByteString -> Either String YamuxHeader
decodeHeader bs
  | BS.length bs < headerSize = Left "decodeHeader: need 12 bytes"
  | otherwise = do
      typ <- word8ToFrameType (byte 1)
      Right
        ( YamuxHeader
            { yhVersion = byte 0
            , yhType = typ
            , yhFlags = word16ToFlags (be16 2)
            , yhStreamId = be32 4
            , yhLength = be32 8
            }
        )
  where
    -- 'BS.index' is partial, but these helpers are only evaluated from the
    -- 'otherwise' branch, after the length guard has ensured that indices
    -- 0..11 are in range.
    byte :: Int -> Word8
    byte = BS.index bs

    be16 :: Int -> Word16
    be16 i = (fromIntegral (byte i) `shiftL` 8) .|. fromIntegral (byte (i + 1))

    be32 :: Int -> Word32
    be32 i = (fromIntegral (be16 i) `shiftL` 16) .|. fromIntegral (be16 (i + 2))