-- | XOR distance metric for the Kademlia DHT.
--
-- All distance computations operate on 256-bit SHA-256 keys.
-- Distance is the XOR of two keys interpreted as a 256-bit unsigned integer.
module Network.LibP2P.DHT.Distance
  ( peerIdToKey
  , xorDistance
  , commonPrefixLength
  , compareDistance
  , sortByDistance
  ) where

import Crypto.Hash (Digest, SHA256, hash)
import Data.ByteArray (convert)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Bits (xor, testBit)
import Data.List (sortBy)
import Data.Word (Word8)
import Network.LibP2P.Crypto.PeerId (PeerId, peerIdBytes)
import Network.LibP2P.DHT.Types (BucketEntry (..), DHTKey (..))

-- | Derive the DHT key for a peer.
--
-- The raw bytes of the Peer ID (as returned by 'peerIdBytes') are hashed
-- with SHA-256 and the resulting digest bytes are wrapped in 'DHTKey'.
peerIdToKey :: PeerId -> DHTKey
peerIdToKey pid = DHTKey (convert digest)
  where
    -- The annotation pins the hash algorithm; 'hash' is polymorphic in it.
    digest :: Digest SHA256
    digest = hash (peerIdBytes pid)

-- | Compute the XOR distance between two DHT keys.
--
-- The result is the byte-wise XOR of the two keys, wrapped in 'DHTKey'.
-- Bytes are paired positionally with 'BS.zipWith', so if the keys have
-- different lengths the result is only as long as the shorter one.
xorDistance :: DHTKey -> DHTKey -> DHTKey
xorDistance (DHTKey a) (DHTKey b) =
  DHTKey (BS.pack (BS.zipWith xor a b))

-- | Length, in bits, of the common prefix shared by two keys.
--
-- This is the number of leading zero bits of their XOR distance.
-- Two identical 32-byte keys give 256; keys whose first bit differs give 0.
commonPrefixLength :: DHTKey -> DHTKey -> Int
commonPrefixLength a b = countLeadingZeros distanceBytes
  where
    DHTKey distanceBytes = xorDistance a b

-- | Count the leading zero bits of a 'ByteString', reading it big-endian
-- (the first byte is the most significant).
--
-- Returns @8 * length@ when every byte is zero, which is 0 for the empty
-- string.
countLeadingZeros :: ByteString -> Int
countLeadingZeros bs =
  case BS.uncons rest of
    -- Every byte was zero (or the input was empty).
    Nothing -> 8 * zeroBytes
    -- First non-zero byte: add its own leading zero bits.
    Just (byte, _) -> 8 * zeroBytes + clzByte byte
  where
    -- Split off the run of all-zero bytes at the front.
    (zeroPrefix, rest) = BS.span (== 0) bs
    zeroBytes = BS.length zeroPrefix

-- | Count the leading zero bits of a single byte.
--
-- The result is in the range 0..8: 0 when the most significant bit is set,
-- 8 when the byte is zero.
clzByte :: Word8 -> Int
clzByte w =
  -- Walk bit positions from most significant (7) to least significant (0)
  -- and count how many are clear before the first set bit.
  length (takeWhile (not . testBit w) [7, 6 .. 0])

-- | Compare two keys by their XOR distance to a target.
--
-- Arguments are the target key, then the two candidates @a@ and @b@.
-- The result is the 'compare' of @xorDistance target a@ against
-- @xorDistance target b@: 'LT' when @a@ is closer, 'GT' when it is farther,
-- 'EQ' when both are equally distant.
--
-- NOTE(review): the ordering comes from the 'Ord' instance of 'DHTKey',
-- which is defined outside this module — confirm it orders keys as
-- big-endian unsigned integers.
compareDistance :: DHTKey -> DHTKey -> DHTKey -> Ordering
compareDistance target a b =
  compare (xorDistance target a) (xorDistance target b)

-- | Sort bucket entries by ascending XOR distance from a target key.
--
-- Each entry is ranked by the distance between its 'entryKey' and the
-- target, using 'compareDistance'.
sortByDistance :: DHTKey -> [BucketEntry] -> [BucketEntry]
sortByDistance target = sortBy closerToTarget
  where
    -- Order two entries by how far their keys are from the target.
    closerToTarget a b = compareDistance target (entryKey a) (entryKey b)