-- | Deterministic protobuf encoding for libp2p PublicKey/PrivateKey messages.
--
-- Hand-rolled encoding to guarantee deterministic output:
-- - Minimal varint encoding
-- - Fields in field number order (Type=1, Data=2)
-- - All fields present, no extras
module Network.LibP2P.Crypto.Protobuf
  ( encodePublicKey
  , decodePublicKey
  ) where

import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Word (Word64, Word8)
import Numeric (showHex)
import Network.LibP2P.Core.Varint (decodeUvarint, encodeUvarint)
import Network.LibP2P.Crypto.Key (KeyType (..), PublicKey (..))

-- | Map a 'KeyType' to its numeric value in the protobuf KeyType enum
-- (written as the varint payload of field 1 by 'encodePublicKey').
keyTypeToProto :: KeyType -> Word64
keyTypeToProto kt = case kt of
  Ed25519 -> 1

-- | Map a numeric protobuf KeyType enum value back to a 'KeyType'.
--
-- Only the value @1@ ('Ed25519') is recognised; any other value yields
-- @Left "unknown KeyType: <n>"@ where @<n>@ is the decimal value.
keyTypeFromProto :: Word64 -> Either String KeyType
keyTypeFromProto n
  | n == 1 = Right Ed25519
  | otherwise = Left ("unknown KeyType: " ++ show n)

-- | Deterministic protobuf encoding of a PublicKey message.
--
-- Layout:
--   Field 1 (Type): tag=0x08 (field 1, wire type 0=varint), value=keytype
--   Field 2 (Data): tag=0x12 (field 2, wire type 2=length-delimited), length, bytes
--
-- Both fields are always emitted, in field-number order, so the same key
-- always produces the same bytes.
encodePublicKey :: PublicKey -> ByteString
encodePublicKey (PublicKey kt rawKey) = mconcat [typeField, dataField]
  where
    -- Field 1: one tag byte followed by the varint-encoded key type.
    typeField = BS.singleton 0x08 <> encodeUvarint (keyTypeToProto kt)
    -- Field 2: one tag byte, the varint byte length, then the raw key bytes.
    dataField =
      mconcat
        [ BS.singleton 0x12
        , encodeUvarint (fromIntegral (BS.length rawKey))
        , rawKey
        ]

-- | Decode a protobuf-encoded PublicKey message.
--
-- Expects exactly the layout produced by 'encodePublicKey': the single tag
-- byte @0x08@ followed by the varint key type, then the single tag byte
-- @0x12@ followed by a varint length and that many bytes of key data.
-- Any bytes following the key data are ignored.
--
-- Returns @Left@ with a message prefixed @"decodePublicKey: "@ when a tag
-- byte is missing or wrong, or when fewer bytes remain than the declared key
-- length; errors from 'decodeUvarint' and 'keyTypeFromProto' are passed
-- through unchanged.
decodePublicKey :: ByteString -> Either String PublicKey
decodePublicKey bs = do
  -- Field 1: key type
  rest1 <- expectTag 0x08 bs "expected tag 0x08 for field 1"
  (typeVal, rest2) <- decodeUvarint rest1
  kt <- keyTypeFromProto typeVal
  -- Field 2: key data
  rest3 <- expectTag 0x12 rest2 "expected tag 0x12 for field 2"
  (dataLen, rest4) <- decodeUvarint rest3
  -- Bounds-check in Word64 *before* converting to Int: a declared length
  -- >= 2^63 (or >= 2^31 on 32-bit) would otherwise wrap to a negative Int,
  -- slip past the check, and make BS.take return an empty key.
  if dataLen > fromIntegral (BS.length rest4)
    then Left "decodePublicKey: not enough bytes for key data"
    else Right (PublicKey kt (BS.take (fromIntegral dataLen) rest4))
  where
    -- Consume one byte that must equal the expected tag; return what follows.
    expectTag :: Word8 -> ByteString -> String -> Either String ByteString
    expectTag expected input msg =
      case BS.uncons input of
        Nothing ->
          Left $ "decodePublicKey: " <> msg <> " (empty input)"
        Just (b, rest)
          | b /= expected ->
              Left $ "decodePublicKey: " <> msg <> " (got 0x" <> showHex b ")"
          | otherwise -> Right rest