-- | DHT node state, RPC handler, and record/provider stores.
--
-- The DHTNode is the top-level coordinator for Kademlia DHT operations.
-- It owns the routing table, record store, provider store, and handles
-- both inbound RPC (as a stream handler) and outbound RPC (through the
-- 'dhtSendRequest' field).
--
-- For testability, the outbound sender is stored as the 'dhtSendRequest'
-- field of 'DHTNode', so tests can inject a mock without real network
-- connections.
module Network.LibP2P.DHT.DHT
  ( -- * Types
    DHTNode (..)
  , DHTMode (..)
  , ProviderEntry (..)
  , Validator (..)
    -- * Construction
  , newDHTNode
    -- * Handler registration
  , registerDHTHandler
    -- * Inbound RPC handler
  , handleDHTRequest
    -- * Store operations
  , storeRecord
  , lookupRecord
  , addProvider
  , getProviders
    -- * Constants
  , dhtProtocolId
  ) where

import Control.Concurrent.STM
import Data.ByteString (ByteString)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Text (Text)
import Data.Time (UTCTime, getCurrentTime)
import Network.LibP2P.Crypto.PeerId (PeerId (..), peerIdBytes)
import Network.LibP2P.DHT.Distance (peerIdToKey)
import Network.LibP2P.DHT.Message
import Network.LibP2P.DHT.RoutingTable (RoutingTable, closestPeers, newRoutingTable)
import Network.LibP2P.DHT.Types
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr)
import Network.LibP2P.MultistreamSelect.Negotiation (StreamIO (..))
import Network.LibP2P.Switch.Switch (setStreamHandler)
import Network.LibP2P.Switch.Types (Switch (..))

-- | Protocol identifier under which the DHT stream handler is registered
-- on the Switch (see 'registerDHTHandler').
dhtProtocolId :: Text
dhtProtocolId = "/ipfs/kad/1.0.0"

-- | Operating mode of a DHT node: server or client.
data DHTMode
  = DHTServer
  | DHTClient
  deriving (Eq, Show)

-- | One stored provider announcement for a content key.
data ProviderEntry = ProviderEntry
  { peProvider  :: !PeerId
    -- ^ Peer that announced itself as a provider
  , peAddrs     :: ![Multiaddr]
    -- ^ Addresses recorded for the provider (may be empty)
  , peTimestamp :: !UTCTime
    -- ^ Time attached to the entry when it was created
  } deriving (Eq, Show)

-- | Record of functions used to validate and choose between records.
--
-- NOTE(review): the two 'ByteString' arguments are presumably the record
-- key followed by the value(s) -- confirm against the call sites.
data Validator = Validator
  { valValidate :: ByteString -> ByteString -> Either String ()
    -- ^ Accept ('Right') or reject ('Left' with a message) a single value
  , valSelect   :: ByteString -> [ByteString] -> Either String Int
    -- ^ Choose among candidate values; the 'Int' is presumably the index
    -- of the chosen candidate -- TODO confirm
  }

-- | Shared state of one DHT node: the Switch it is attached to, its
-- mutable stores, its identity and the hook used for outbound requests.
data DHTNode = DHTNode
  { dhtSwitch        :: !Switch
    -- ^ Switch on which the DHT stream handler is registered
  , dhtRoutingTable  :: !(TVar RoutingTable)
    -- ^ Routing table consulted by the request handlers for closest peers
  , dhtRecordStore   :: !(TVar (Map ByteString DHTRecord))
    -- ^ Stored records, keyed by record key
  , dhtProviderStore :: !(TVar (Map ByteString [ProviderEntry]))
    -- ^ Provider entries, keyed by content key
  , dhtLocalKey      :: !DHTKey
    -- ^ DHT key derived from the local peer ID
  , dhtLocalPeerId   :: !PeerId
    -- ^ Peer ID of the local node
  , dhtMode          :: !DHTMode
    -- ^ Whether the node runs as server or client
  , dhtSendRequest   :: !(PeerId -> DHTMessage -> IO (Either String DHTMessage))
    -- ^ Injectable RPC sender for testability
  }

-- | Build a 'DHTNode' for the given Switch and mode.
--
-- The routing table is created for the Switch's local peer ID, and both
-- the record store and the provider store start empty. The outbound sender
-- is a stub that always returns a 'Left'; replace 'dhtSendRequest' to
-- enable real (or mocked) outbound RPC.
newDHTNode :: Switch -> DHTMode -> IO DHTNode
newDHTNode sw mode = do
  routingVar  <- newTVarIO (newRoutingTable self)
  recordVar   <- newTVarIO Map.empty
  providerVar <- newTVarIO Map.empty
  pure DHTNode
    { dhtSwitch        = sw
    , dhtRoutingTable  = routingVar
    , dhtRecordStore   = recordVar
    , dhtProviderStore = providerVar
    , dhtLocalKey      = peerIdToKey self
    , dhtLocalPeerId   = self
    , dhtMode          = mode
    , dhtSendRequest   = unconfiguredSender
    }
  where
    -- Identity of the local node, taken from the Switch.
    self :: PeerId
    self = swLocalPeerId sw

    -- Placeholder sender used until a real one is injected.
    unconfiguredSender :: PeerId -> DHTMessage -> IO (Either String DHTMessage)
    unconfiguredSender _ _ = pure (Left "sendDHTRequest not configured")

-- | Register the DHT stream handler on the node's Switch under
-- 'dhtProtocolId'.
--
-- Only a node in 'DHTServer' mode registers the handler. For a
-- 'DHTClient' node this is a no-op, so a client never serves inbound DHT
-- requests even if this function is called on it.
registerDHTHandler :: DHTNode -> IO ()
registerDHTHandler node =
  case dhtMode node of
    DHTServer ->
      setStreamHandler (dhtSwitch node) dhtProtocolId (handleDHTRequest node)
    -- Client mode: the protocol must not be served, so nothing is registered.
    DHTClient -> pure ()

-- | Serve one inbound DHT RPC on a stream.
--
-- Reads a single framed message (bounded by 'maxDHTMessageSize'), runs it
-- through 'processRequest' and writes the response back. If the read
-- fails, the handler returns without writing anything.
handleDHTRequest :: DHTNode -> StreamIO -> PeerId -> IO ()
handleDHTRequest node stream remotePeerId =
  readFramedMessage stream maxDHTMessageSize >>= either ignoreReadError respond
  where
    -- A failed read is deliberately best-effort: just end the handler.
    ignoreReadError _ = pure ()

    respond request =
      processRequest node request remotePeerId >>= writeFramedMessage stream

-- | Dispatch a request to the handler matching its 'msgType' and return
-- that handler's response.
--
-- The sender's peer ID is only forwarded to 'handleAddProvider'.
--
-- NOTE(review): if 'MessageType' has constructors beyond the five handled
-- here, this dispatch is partial -- confirm against the type's definition.
processRequest :: DHTNode -> DHTMessage -> PeerId -> IO DHTMessage
processRequest node msg remotePeerId = dispatch (msgType msg)
  where
    dispatch FindNode     = handleFindNode node msg
    dispatch GetValue     = handleGetValue node msg
    dispatch PutValue     = handlePutValue node msg
    dispatch AddProvider  = handleAddProvider node msg remotePeerId
    dispatch GetProviders = handleGetProviders node msg

-- | FIND_NODE: answer with the peers the routing table reports as closest
-- to the requested key (requested count: 'kValue').
handleFindNode :: DHTNode -> DHTMessage -> IO DHTMessage
handleFindNode node msg = respond <$> readTVarIO (dhtRoutingTable node)
  where
    respond table =
      emptyDHTMessage
        { msgType        = FindNode
        , msgCloserPeers = map entryToDHTPeer (nearestIn table)
        }

    -- The message key is used directly as the lookup target.
    nearestIn = closestPeers (DHTKey (msgKey msg)) kValue

-- | GET_VALUE: answer with the locally stored record for the key (if any)
-- together with the closest peers from the routing table (requested
-- count: 'kValue').
handleGetValue :: DHTNode -> DHTMessage -> IO DHTMessage
handleGetValue node msg = do
  table <- readTVarIO (dhtRoutingTable node)
  store <- readTVarIO (dhtRecordStore node)
  let wanted  = msgKey msg
      nearest = closestPeers (DHTKey wanted) kValue table
  pure emptyDHTMessage
    { msgType        = GetValue
    , msgRecord      = Map.lookup wanted store
    , msgCloserPeers = map entryToDHTPeer nearest
    }

-- | PUT_VALUE: store the record carried by the request and echo it back.
--
-- A request without a record is answered with a bare PUT_VALUE response
-- and nothing is stored. Otherwise the record is stored via 'storeRecord'
-- and the response repeats the request key and the record.
--
-- NOTE(review): the record is stored under its own key without checking
-- that it equals 'msgKey', and no validator is consulted -- confirm
-- whether that is intended.
handlePutValue :: DHTNode -> DHTMessage -> IO DHTMessage
handlePutValue node msg = maybe (pure ack) storeAndEcho (msgRecord msg)
  where
    ack = emptyDHTMessage { msgType = PutValue }

    storeAndEcho record = do
      storeRecord node record
      pure ack
        { msgKey    = msgKey msg
        , msgRecord = Just record
        }

-- | ADD_PROVIDER: record the sender as a provider for the message key.
--
-- Only provider peers whose ID bytes equal the sender's peer ID are
-- accepted; all others are silently dropped. Each accepted peer is stored
-- under 'msgKey' with the current time. The response is always a bare
-- ADD_PROVIDER message.
handleAddProvider :: DHTNode -> DHTMessage -> PeerId -> IO DHTMessage
handleAddProvider node msg remotePeerId = do
  now <- getCurrentTime
  let senderBytes     = peerIdBytes remotePeerId
      fromSender peer = dhtPeerId peer == senderBytes
      remember peer   = addProvider node (msgKey msg) (dhtPeerToProvider peer now)
  mapM_ remember (filter fromSender (msgProviderPeers msg))
  pure emptyDHTMessage { msgType = AddProvider }

-- | GET_PROVIDERS: answer with the providers stored for the key (an empty
-- list when none are known) together with the closest peers from the
-- routing table (requested count: 'kValue').
handleGetProviders :: DHTNode -> DHTMessage -> IO DHTMessage
handleGetProviders node msg = do
  table  <- readTVarIO (dhtRoutingTable node)
  stored <- readTVarIO (dhtProviderStore node)
  let contentKey = msgKey msg
      nearest    = closestPeers (DHTKey contentKey) kValue table
      known      = Map.findWithDefault [] contentKey stored
  pure emptyDHTMessage
    { msgType          = GetProviders
    , msgCloserPeers   = map entryToDHTPeer nearest
    , msgProviderPeers = map providerToDHTPeer known
    }

-- Store operations

-- | Insert a record into the node's record store under its own 'recKey',
-- replacing any record previously stored under that key.
storeRecord :: DHTNode -> DHTRecord -> IO ()
storeRecord node record =
  atomically (modifyTVar' (dhtRecordStore node) insertRecord)
  where
    insertRecord = Map.insert (recKey record) record

-- | Fetch the record stored under the given key, or 'Nothing' when the
-- record store has no entry for it.
lookupRecord :: DHTNode -> ByteString -> IO (Maybe DHTRecord)
lookupRecord node key = do
  store <- readTVarIO (dhtRecordStore node)
  pure (Map.lookup key store)

-- | Add a provider entry for a content key.
--
-- At most one entry is kept per provider and key. If the same provider
-- ('peProvider') is already recorded for the key, its old entry is
-- replaced by the new one, so re-announcements refresh the timestamp
-- instead of piling up duplicates. The new entry is placed at the front
-- of the list; entries of other providers keep their relative order.
addProvider :: DHTNode -> ByteString -> ProviderEntry -> IO ()
addProvider node key entry =
  atomically $
    modifyTVar' (dhtProviderStore node) (Map.insertWith merge key [entry])
  where
    sameProvider other = peProvider other == peProvider entry

    -- 'Map.insertWith' calls the combining function as @f new old@.
    merge fresh existing = fresh ++ filter (not . sameProvider) existing

-- | List the provider entries stored for a content key; the result is
-- empty when the key has no entry in the provider store.
getProviders :: DHTNode -> ByteString -> IO [ProviderEntry]
getProviders node key = do
  stored <- readTVarIO (dhtProviderStore node)
  pure (Map.findWithDefault [] key stored)

-- Helpers

-- | Turn a routing-table entry into the 'DHTPeer' form used in messages.
--
-- The peer ID bytes and connection type are copied from the entry; the
-- address list is left empty.
entryToDHTPeer :: BucketEntry -> DHTPeer
entryToDHTPeer entry =
  DHTPeer
    { dhtPeerId       = peerIdBytes (entryPeerId entry)
    , dhtPeerAddrs    = []  -- addresses would be encoded multiaddrs
    , dhtPeerConnType = entryConnType entry
    }

-- | Build a 'ProviderEntry' from a 'DHTPeer' received in ADD_PROVIDER,
-- stamped with the supplied time.
--
-- The peer's ID bytes are wrapped in 'PeerId' as-is. The peer's addresses
-- are not carried over; 'peAddrs' is always empty.
dhtPeerToProvider :: DHTPeer -> UTCTime -> ProviderEntry
dhtPeerToProvider peer now =
  ProviderEntry
    { peProvider  = PeerId (dhtPeerId peer)
    , peAddrs     = []  -- raw bytes in dhtPeerAddrs; address decoding is not yet wired
    , peTimestamp = now
    }

-- | Turn a stored 'ProviderEntry' into the 'DHTPeer' form used in
-- messages.
--
-- Only the provider's peer ID is carried over: the address list is left
-- empty and the connection type is always reported as 'Connected'.
providerToDHTPeer :: ProviderEntry -> DHTPeer
providerToDHTPeer pe =
  DHTPeer
    { dhtPeerId       = peerIdBytes (peProvider pe)
    , dhtPeerAddrs    = []  -- addresses would be encoded multiaddrs
    , dhtPeerConnType = Connected
    }