-- | k-Bucket routing table for the Kademlia DHT.
--
-- Organizes peers into 256 buckets by XOR distance prefix length.
-- Each bucket holds up to k=20 peers, ordered by last-seen time
-- (head = least-recently-seen, tail = most-recently-seen).
module Network.LibP2P.DHT.RoutingTable
  ( KBucket (..)
  , RoutingTable (..)
  , newRoutingTable
  , emptyBucket
  , insertPeer
  , removePeer
  , closestPeers
  , bucketForPeer
  , bucketSize
  , allPeers
  ) where

import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IntMap
import Data.Sequence (Seq (..))
import qualified Data.Sequence as Seq
import Network.LibP2P.Crypto.PeerId (PeerId)
import Network.LibP2P.DHT.Distance (commonPrefixLength, peerIdToKey, sortByDistance)
import Network.LibP2P.DHT.Types
  ( BucketEntry (..)
  , DHTKey (..)
  , InsertResult (..)
  , kValue
  , numBuckets
  )

-- | A single k-bucket.
--
-- Entries are kept in last-seen order: the head of the sequence is the
-- least-recently-seen (LRS) peer and the tail is the most-recently-seen
-- (MRS) peer. The capacity limit is not enforced by this type; 'insertPeer'
-- checks it against 'rtK' before appending.
data KBucket = KBucket
  { bucketEntries :: !(Seq BucketEntry)
    -- ^ Peers in this bucket, LRS first and MRS last.
  } deriving (Show)

-- | The full routing table: k-buckets indexed by the length of the prefix a
-- peer's key shares with the local key.
--
-- Buckets are stored in a sparse 'IntMap'. The functions in this module only
-- ever store non-empty buckets, and a missing index is treated as an empty
-- bucket.
data RoutingTable = RoutingTable
  { rtSelfKey :: !DHTKey
    -- ^ Key of the local peer; bucket indices are computed relative to it.
  , rtBuckets :: !(IntMap KBucket)
    -- ^ Buckets keyed by bucket index.
  , rtK       :: !Int
    -- ^ Maximum number of entries a bucket may hold.
  } deriving (Show)

-- | Build an empty routing table for the given local peer.
--
-- The local key is derived with 'peerIdToKey', no buckets are stored yet, and
-- the per-bucket capacity is taken from 'kValue'.
newRoutingTable :: PeerId -> RoutingTable
newRoutingTable localPeer =
  RoutingTable
    { rtSelfKey = peerIdToKey localPeer
    , rtBuckets = IntMap.empty
    , rtK       = kValue
    }

-- | A k-bucket with no entries.
emptyBucket :: KBucket
emptyBucket = KBucket { bucketEntries = Seq.empty }

-- | Insert a peer into the routing table, or refresh it if already present.
--
-- The target bucket is chosen from the entry's key relative to the local key.
-- The returned table is paired with an 'InsertResult' describing what happened:
--
-- 1. The entry's key equals the local key: the table is unchanged and the
--    result is 'Updated' (self is never stored).
-- 2. A peer with the same 'PeerId' is already in the bucket: the old entry is
--    removed and the new one appended at the tail (MRS); result 'Updated'.
-- 3. The bucket holds fewer than 'rtK' entries: the entry is appended at the
--    tail; result 'Inserted'.
-- 4. Otherwise the bucket is full: the table is unchanged and the result is
--    'BucketFull' carrying the 'PeerId' of the head (LRS) entry, so the caller
--    can decide whether to evict it.
insertPeer :: BucketEntry -> RoutingTable -> (RoutingTable, InsertResult)
insertPeer entry rt
  | entryKey entry == selfKey = (rt, Updated)
  | otherwise =
      case findEntryIndex (entryPeerId entry) entries of
        -- Already known: drop the stale entry and re-append at the MRS end.
        Just i -> (withEntries (Seq.deleteAt i entries Seq.|> entry), Updated)
        Nothing
          | Seq.length entries < rtK rt ->
              (withEntries (entries Seq.|> entry), Inserted)
          | otherwise ->
              case entries of
                lrs :<| _ -> (rt, BucketFull (entryPeerId lrs))
                -- Only reachable when an empty bucket already counts as full,
                -- i.e. when rtK rt <= 0; report the rejected peer itself.
                _ -> (rt, BucketFull (entryPeerId entry))
  where
    selfKey = rtSelfKey rt
    idx = bucketIndex selfKey (entryKey entry)
    -- A bucket absent from the sparse map is treated as empty.
    entries = bucketEntries (IntMap.findWithDefault emptyBucket idx (rtBuckets rt))
    -- The table with the target bucket's contents replaced.
    withEntries es =
      rt { rtBuckets = IntMap.insert idx (KBucket es) (rtBuckets rt) }

-- | Remove the peer with the given 'PeerId' from the routing table.
--
-- Only the bucket selected by @peerIdToKey pid@ is searched. If the peer is
-- not found there, the table is returned unchanged. A bucket that becomes
-- empty is deleted from the map so that 'rtBuckets' stays sparse.
--
-- NOTE(review): 'insertPeer' chooses the bucket from 'entryKey', whereas this
-- function derives the key from the 'PeerId'. This presumably relies on
-- @entryKey e == peerIdToKey (entryPeerId e)@ for every stored entry; confirm
-- against the code that constructs 'BucketEntry' values.
removePeer :: PeerId -> RoutingTable -> RoutingTable
removePeer pid rt =
  case findEntryIndex pid entries of
    Nothing -> rt
    Just i -> rt { rtBuckets = bucketsWithout i }
  where
    buckets = rtBuckets rt
    idx = bucketIndex (rtSelfKey rt) (peerIdToKey pid)
    -- A bucket absent from the sparse map is treated as empty.
    entries = bucketEntries (IntMap.findWithDefault emptyBucket idx buckets)
    -- The bucket map after deleting position @i@ from the target bucket.
    bucketsWithout i
      | Seq.null remaining = IntMap.delete idx buckets
      | otherwise = IntMap.insert idx (KBucket remaining) buckets
      where
        remaining = Seq.deleteAt i entries

-- | Return at most @n@ known peers for a target key.
--
-- Every bucket is considered: all entries in the table are collected with
-- 'allPeers', ordered relative to @target@ by 'sortByDistance', and the first
-- @n@ of that ordering are kept. A non-positive @n@ yields the empty list.
closestPeers :: DHTKey -> Int -> RoutingTable -> [BucketEntry]
closestPeers target n rt = take n (sortByDistance target (allPeers rt))

-- | Bucket index that a peer key maps to in this routing table.
--
-- This is the common prefix length between the table's local key and @key@,
-- capped at @numBuckets - 1@ (see 'bucketIndex').
bucketForPeer :: DHTKey -> RoutingTable -> Int
bucketForPeer key rt = bucketIndex (rtSelfKey rt) key

-- | Number of entries currently held in the bucket at the given index.
--
-- Returns 0 when no bucket is stored at that index.
bucketSize :: Int -> RoutingTable -> Int
bucketSize idx rt =
  maybe 0 (Seq.length . bucketEntries) (IntMap.lookup idx (rtBuckets rt))

-- | List every entry stored in the routing table.
--
-- Buckets are visited in the order produced by 'IntMap.elems' (ascending
-- bucket index), and within a bucket entries keep their stored order
-- (LRS first, MRS last).
allPeers :: RoutingTable -> [BucketEntry]
allPeers rt =
  concatMap (foldr (:) [] . bucketEntries) (IntMap.elems (rtBuckets rt))

-- Internal helpers

-- | Bucket index of @peerKey@ relative to @selfKey@: their common prefix
-- length, clamped to @[0, numBuckets - 1]@.
--
-- The upper clamp keeps a prefix length of @numBuckets@ or more inside the
-- last bucket. The lower clamp guards against a negative result from
-- 'commonPrefixLength', which is presumably never produced in practice.
bucketIndex :: DHTKey -> DHTKey -> Int
bucketIndex selfKey peerKey =
  max 0 (min (numBuckets - 1) (commonPrefixLength selfKey peerKey))

-- | Position of the first entry whose 'entryPeerId' equals the given
-- 'PeerId', searching from the head (LRS end) of the sequence.
--
-- Returns 'Nothing' when no entry matches.
findEntryIndex :: PeerId -> Seq BucketEntry -> Maybe Int
findEntryIndex pid = Seq.findIndexL ((== pid) . entryPeerId)