-- | DCUtR (Direct Connection Upgrade through Relay) message encoding/decoding.
--
-- Protocol: /libp2p/dcutr
-- Wire format: varint-length-prefixed protobuf, max 4 KiB
--
-- HolePunch message:
--   field 1: type (required) - CONNECT(100) or SYNC(300)
--   field 2: ObsAddrs (repeated bytes) - observed multiaddr binary
module Network.LibP2P.NAT.DCUtR.Message
  ( -- * Types
    HolePunchType (..)
  , HolePunchMessage (..)
    -- * Type conversion
  , holePunchTypeToWord
  , wordToHolePunchType
    -- * Protobuf encode/decode (no framing)
  , encodeHolePunchMessage
  , decodeHolePunchMessage
    -- * Wire framing (uvarint length prefix)
  , encodeHolePunchFramed
  , decodeHolePunchFramed
    -- * Stream I/O helpers
  , writeHolePunchMessage
  , readHolePunchMessage
    -- * Constants
  , maxDCUtRMessageSize
  , dcutrProtocolId
  ) where

import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import Data.Text (Text)
import Data.Word (Word32)
import Proto3.Wire.Decode (Parser, RawMessage, ParseError, at, one, repeated, parse)
import qualified Proto3.Wire.Decode as Decode
import qualified Proto3.Wire.Encode as Encode
import Proto3.Wire.Types (FieldNumber (..))
import Network.LibP2P.Core.Varint (encodeUvarint, decodeUvarint)
import Network.LibP2P.MultistreamSelect.Negotiation (StreamIO (..))

-- | Protocol identifier negotiated (via multistream-select) for DCUtR streams.
dcutrProtocolId :: Text
dcutrProtocolId = "/libp2p/dcutr"

-- | Upper bound, in bytes, on a single DCUtR message payload (4 KiB, per spec).
maxDCUtRMessageSize :: Int
maxDCUtRMessageSize = 4 * 1024

-- | Kind of a HolePunch message. See 'holePunchTypeToWord' for the wire values.
data HolePunchType
  = HPConnect -- ^ CONNECT (wire value 100)
  | HPSync    -- ^ SYNC (wire value 300)
  deriving (Eq, Show)

-- | A DCUtR HolePunch protobuf message.
data HolePunchMessage = HolePunchMessage
  { hpType     :: !HolePunchType
    -- ^ Field 1 (required): message kind.
  , hpObsAddrs :: ![ByteString]
    -- ^ Field 2 (repeated): observed addresses as binary multiaddrs.
  }
  deriving (Eq, Show)

-- | Wire (protobuf enum) value for a 'HolePunchType':
-- CONNECT is 100, SYNC is 300.
holePunchTypeToWord :: HolePunchType -> Word32
holePunchTypeToWord kind = case kind of
  HPConnect -> 100
  HPSync    -> 300

-- | Interpret a wire value as a 'HolePunchType'.
--
-- Returns 'Nothing' for any value other than 100 (CONNECT) or 300 (SYNC).
wordToHolePunchType :: Word32 -> Maybe HolePunchType
wordToHolePunchType wire = lookup wire knownTypes
  where
    knownTypes = [(100, HPConnect), (300, HPSync)]

-- Encoding

-- | Serialize a 'HolePunchMessage' as a bare protobuf payload (no length prefix).
--
-- Writes field 1 (type, uint32), then one field-2 bytes entry per observed
-- address, in list order.
encodeHolePunchMessage :: HolePunchMessage -> ByteString
encodeHolePunchMessage msg = BL.toStrict (Encode.toLazyByteString fields)
  where
    typeField = Encode.uint32 (FieldNumber 1) (holePunchTypeToWord (hpType msg))
    addrFields = foldMap (Encode.byteString (FieldNumber 2)) (hpObsAddrs msg)
    fields = typeField <> addrFields

-- Decoding

-- | Parse a bare protobuf payload (no length prefix) into a 'HolePunchMessage'.
decodeHolePunchMessage :: ByteString -> Either ParseError HolePunchMessage
decodeHolePunchMessage bytes = parse holePunchParser bytes

-- | Protobuf parser for 'HolePunchMessage'.
--
-- Field 1 (type) is required by the DCUtR spec. If it is absent (read as the
-- default 0) or holds a value other than CONNECT (100) or SYNC (300), parsing
-- fails with a 'Decode.BinaryError'. It is not silently decoded as
-- 'HPConnect'. Field 2 collects every observed-address entry in order.
holePunchParser :: Parser RawMessage HolePunchMessage
holePunchParser = do
  rawType  <- at (one Decode.uint32 0) (FieldNumber 1)
  obsAddrs <- at (repeated Decode.byteString) (FieldNumber 2)
  case wordToHolePunchType rawType of
    Just kind -> pure (HolePunchMessage kind obsAddrs)
    Nothing   ->
      Decode.Parser $ \_ ->
        Left (Decode.BinaryError "DCUtR HolePunch: missing or unknown message type")

-- Wire framing

-- | Serialize a 'HolePunchMessage' for the wire: the protobuf payload preceded
-- by its byte length encoded as an unsigned varint.
encodeHolePunchFramed :: HolePunchMessage -> ByteString
encodeHolePunchFramed msg = encodeUvarint (fromIntegral (BS.length body)) <> body
  where
    body = encodeHolePunchMessage msg

-- | Decode a 'HolePunchMessage' from uvarint-length-prefixed bytes.
--
-- @maxSize@ is the largest declared payload length (in bytes, excluding the
-- prefix) that is accepted. Bytes after the declared payload are ignored.
--
-- Returns 'Left' with a description when:
--
-- * the length prefix cannot be decoded,
-- * the declared length exceeds @maxSize@,
-- * fewer payload bytes are present than declared, or
-- * the payload is not a valid HolePunch protobuf.
decodeHolePunchFramed :: Int -> ByteString -> Either String HolePunchMessage
decodeHolePunchFramed maxSize bs = do
  (len, rest) <- decodeUvarint bs
  -- Compare as Integer. Converting a huge Word64 length straight to Int can
  -- wrap to a negative value that would slip past both checks below.
  if toInteger len > toInteger maxSize
    then Left $ "DCUtR message too large: " ++ show len ++ " > " ++ show maxSize
    else do
      -- Safe: 0 <= len <= maxSize here, so the value fits in an Int.
      let msgLen = fromIntegral len :: Int
      if BS.length rest < msgLen
        then Left "DCUtR message truncated"
        else case decodeHolePunchMessage (BS.take msgLen rest) of
          Left err  -> Left $ "DCUtR protobuf decode error: " ++ show err
          Right msg -> Right msg

-- Stream I/O

-- | Send a 'HolePunchMessage' on a stream in its length-prefixed wire form.
writeHolePunchMessage :: StreamIO -> HolePunchMessage -> IO ()
writeHolePunchMessage stream = streamWrite stream . encodeHolePunchFramed

-- | Read one uvarint-length-prefixed 'HolePunchMessage' from a stream.
--
-- @maxSize@ bounds the declared payload length. A larger declaration is
-- rejected before any payload bytes are read.
--
-- Returns 'Left' with a description if the varint prefix is invalid, the
-- declared length exceeds @maxSize@, or the payload fails protobuf decoding.
readHolePunchMessage :: StreamIO -> Int -> IO (Either String HolePunchMessage)
readHolePunchMessage stream maxSize = do
  varintBytes <- readVarintBytes stream
  case decodeUvarint varintBytes of
    Left err -> pure (Left $ "DCUtR varint decode error: " ++ err)
    Right (len, _)
      -- Compare as Integer so an oversized Word64 length cannot wrap to a
      -- negative Int, bypass the limit, and result in an empty payload read.
      | toInteger len > toInteger maxSize ->
          pure (Left $ "DCUtR message too large: " ++ show len)
      | otherwise -> do
          -- Safe: 0 <= len <= maxSize here, so the conversion cannot wrap.
          payload <- readExact stream (fromIntegral len)
          pure $ case decodeHolePunchMessage payload of
            Left err  -> Left $ "DCUtR protobuf decode error: " ++ show err
            Right msg -> Right msg

-- | Read exactly @n@ bytes from the stream, one 'streamReadByte' call per
-- byte, and return them in the order read. A non-positive @n@ reads nothing
-- and yields an empty 'ByteString'.
readExact :: StreamIO -> Int -> IO ByteString
readExact stream n = collect n []
  where
    collect remaining acc
      | remaining <= 0 = pure (BS.pack (reverse acc))
      | otherwise = do
          byte <- streamReadByte stream
          collect (remaining - 1) (byte : acc)

-- | Read the raw bytes of one unsigned varint from the stream.
--
-- Bytes are consumed until one with its high bit clear (< 0x80) is read, or
-- until 10 bytes have been consumed. The collected bytes are returned in
-- stream order. No decoding or validation is done here.
readVarintBytes :: StreamIO -> IO ByteString
readVarintBytes stream = loop (0 :: Int) []
  where
    finish = pure . BS.pack . reverse
    loop count seen
      | count >= 10 = finish seen
      | otherwise = do
          byte <- streamReadByte stream
          let seen' = byte : seen
          -- High bit set means more varint bytes follow.
          if byte >= 0x80
            then loop (count + 1) seen'
            else finish seen'