-- | Identify protocol implementation (docs/07-protocols.md).
--
-- Protocol ID: @\/ipfs\/id\/1.0.0@
--
-- After a connection is established, both sides exchange IdentifyInfo
-- messages to learn about each other's capabilities, listen addresses,
-- and agent version. The message has no length prefix — the boundary
-- is determined by stream closure.
--
-- Also implements Identify Push (@\/ipfs\/id\/push\/1.0.0@) for proactive
-- updates when local state changes.
module Network.LibP2P.Protocol.Identify.Identify
  ( -- * Protocol IDs
    identifyProtocolId
  , identifyPushProtocolId
    -- * Protocol logic
  , handleIdentify
  , requestIdentify
  , handleIdentifyPush
    -- * Building local info
  , buildLocalIdentify
    -- * Registration
  , registerIdentifyHandlers
    -- * Helpers
  , readUntilEOF
  ) where

import Control.Concurrent.STM (atomically, modifyTVar', readTVar, writeTVar)
import Control.Exception
  ( SomeAsyncException
  , SomeException
  , bracket
  , catch
  , fromException
  , throwIO
  )
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map
import Network.LibP2P.Crypto.Key (kpPublic)
import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.Crypto.Protobuf (encodePublicKey)
import Network.LibP2P.Multiaddr.Codec (encodeProtocols)
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr (..))
import Network.LibP2P.MultistreamSelect.Negotiation
  ( ProtocolId
  , StreamIO (..)
  , negotiateInitiator
  , NegotiationResult (..)
  )
import Network.LibP2P.Protocol.Identify.Message
  ( IdentifyInfo (..)
  , decodeIdentify
  , encodeIdentify
  , maxIdentifySize
  )
import Network.LibP2P.Switch.Types
  ( ActiveListener (..)
  , Connection (..)
  , MuxerSession (..)
  , Switch (..)
  )

-- | Multistream-select protocol ID for Identify: @\/ipfs\/id\/1.0.0@.
--
-- Registered by 'registerIdentifyHandlers' (answered by 'handleIdentify')
-- and proposed by 'requestIdentify' when querying a remote peer.
identifyProtocolId :: ProtocolId
identifyProtocolId = "/ipfs/id/1.0.0"

-- | Multistream-select protocol ID for Identify Push:
-- @\/ipfs\/id\/push\/1.0.0@.
--
-- Registered by 'registerIdentifyHandlers' and handled by
-- 'handleIdentifyPush'.
identifyPushProtocolId :: ProtocolId
identifyPushProtocolId = "/ipfs/id/push/1.0.0"

-- | Handle an inbound Identify request (responder side).
--
-- Writes our local 'IdentifyInfo' to the stream as a single protobuf
-- message and then closes the stream. Identify has no length prefix, so
-- the close is what lets the requester's 'readUntilEOF' terminate.
--
-- No observed address is sent, because this handler is not given the
-- underlying 'Connection' ('buildLocalIdentify' is called with 'Nothing').
-- The remote 'PeerId' is currently unused.
handleIdentify :: Switch -> StreamIO -> PeerId -> IO ()
handleIdentify sw stream _remotePeerId = do
  info <- buildLocalIdentify sw Nothing
  streamWrite stream (encodeIdentify info)
  -- Signal EOF so the remote side's readUntilEOF terminates.
  streamClose stream

-- | Request Identify from a remote peer (initiator side).
--
-- Opens a new stream on the connection's muxer session and negotiates
-- 'identifyProtocolId'. It then reads until the remote closes the stream,
-- capped at 'maxIdentifySize' bytes, and decodes the protobuf payload.
--
-- Returns @Left@ with a human-readable reason in three cases:
--
--   * the remote rejects the protocol,
--   * the payload exceeds the size cap,
--   * the payload fails to decode.
--
-- Exceptions thrown by the muxer or by negotiation propagate to the caller.
--
-- The stream is closed on every exit path ('bracket'), including
-- 'NoProtocol' and exceptions, so it is never leaked.
requestIdentify :: Connection -> IO (Either String IdentifyInfo)
requestIdentify conn =
  bracket (muxOpenStream (connSession conn)) streamClose $ \stream -> do
    result <- negotiateInitiator stream [identifyProtocolId]
    case result of
      NoProtocol -> pure (Left "remote does not support identify")
      Accepted _ -> do
        bytesOrErr <- readUntilEOF stream maxIdentifySize
        pure (bytesOrErr >>= decodeInfo)
  where
    -- Render the decoder's ParseError so callers get a uniform String error.
    decodeInfo bs = either (Left . show) Right (decodeIdentify bs)

-- | Handle an inbound Identify Push (responder side).
--
-- Reads the pushed 'IdentifyInfo' until the remote closes the stream,
-- capped at 'maxIdentifySize' bytes. It is stored in the Switch's peer
-- store under the remote 'PeerId', replacing any previous entry for that
-- peer.
--
-- Push is best-effort: an oversized or undecodable message is dropped and
-- the peer store is left untouched.
--
-- NOTE(review): the pushed public key is not checked against
-- @remotePeerId@ before the info is stored.
handleIdentifyPush :: Switch -> StreamIO -> PeerId -> IO ()
handleIdentifyPush sw stream remotePeerId = do
  bytesOrErr <- readUntilEOF stream maxIdentifySize
  case bytesOrErr of
    Right bs
      | Right info <- decodeIdentify bs ->
          atomically $
            modifyTVar' (swPeerStore sw) (Map.insert remotePeerId info)
    -- Read failure (too large) or decode failure: deliberately ignored.
    _ -> pure ()

-- | Build our local 'IdentifyInfo' from the current Switch state.
--
-- The registered protocol IDs and the addresses of all active listeners
-- are read in a single STM transaction, so both come from one consistent
-- snapshot.
--
-- When a 'Connection' is supplied, its remote address is reported back as
-- the observed address. Otherwise that field is 'Nothing'.
buildLocalIdentify :: Switch -> Maybe Connection -> IO IdentifyInfo
buildLocalIdentify sw mConn = do
  (protocols, listenAddrs) <- atomically $
    (,) <$> (Map.keys <$> readTVar (swProtocols sw))
        <*> (map alAddress <$> readTVar (swListeners sw))
  pure IdentifyInfo
    { idProtocolVersion = Just "ipfs/0.1.0"
    , idAgentVersion    = Just "libp2p-hs/0.1.0"
    , idPublicKey       = Just (encodePublicKey (kpPublic (swIdentityKey sw)))
    , idListenAddrs     = map multiaddrBytes listenAddrs
    , idObservedAddr    = multiaddrBytes . connRemoteAddr <$> mConn
    , idProtocols       = protocols
    }
  where
    -- Binary encoding of a multiaddr, as carried in the Identify message.
    multiaddrBytes (Multiaddr ps) = encodeProtocols ps

-- | Register the Identify protocol handlers on the Switch.
--
-- Registers:
--
--   * @\/ipfs\/id\/1.0.0@: 'handleIdentify', which answers Identify requests.
--   * @\/ipfs\/id\/push\/1.0.0@: 'handleIdentifyPush', which accepts pushed
--     updates from the remote.
--
-- Any handler already registered under either ID is replaced. Both entries
-- are added in a single STM transaction.
registerIdentifyHandlers :: Switch -> IO ()
registerIdentifyHandlers sw =
  atomically $ modifyTVar' (swProtocols sw) addHandlers
  where
    addHandlers =
      Map.insert identifyProtocolId (handleIdentify sw)
        . Map.insert identifyPushProtocolId (handleIdentifyPush sw)

-- | Read bytes from a 'StreamIO' until EOF, accepting at most @maxSize@ bytes.
--
-- Identify uses stream closure as the message boundary (there is no length
-- prefix). Bytes are therefore accumulated one at a time until
-- 'streamReadByte' throws, which is treated as end-of-stream.
--
-- Size limit:
--
--   * A message of exactly @maxSize@ bytes is accepted.
--   * @Left "message exceeds maximum size"@ is returned as soon as a byte
--     beyond the first @maxSize@ arrives.
--   * If @maxSize <= 0@, only an empty message is accepted.
--
-- Asynchronous exceptions (for example from @timeout@ or @killThread@) are
-- re-thrown rather than mistaken for EOF, so callers can still bound or
-- cancel the read.
readUntilEOF :: StreamIO -> Int -> IO (Either String BS.ByteString)
readUntilEOF stream maxSize = go [] 0
  where
    -- acc holds the bytes read so far, in reverse order; n is its length.
    go acc n = do
      next <- readByteOrEOF
      case next of
        Nothing -> pure (Right (BS.pack (reverse acc)))
        Just b
          | n >= maxSize -> pure (Left "message exceeds maximum size")
          | otherwise    -> go (b : acc) (n + 1)

    -- Any synchronous exception from streamReadByte marks end-of-stream;
    -- asynchronous ones are propagated unchanged.
    readByteOrEOF =
      (Just <$> streamReadByte stream) `catch` \e ->
        case (fromException e :: Maybe SomeAsyncException) of
          Just _  -> throwIO e
          Nothing -> pure Nothing