-- | AutoNAT v1 service: detect NAT status by asking remote peers to dial back.
--
-- Protocol: /libp2p/autonat/1.0.0
-- Flow: Client sends DIAL with addresses, server dials back, responds with result.
--
-- Security rules (from docs/10-nat-traversal.md):
--   - Server MUST NOT dial addresses unless they match the requester's observed IP
--   - Server MUST NOT accept dial requests over relayed connections
module Network.LibP2P.NAT.AutoNAT.AutoNAT
  ( -- * Types
    NATStatus (..)
  , AutoNATConfig (..)
    -- * Server
  , handleAutoNAT
    -- * Client
  , requestAutoNAT
    -- * NAT status aggregation
  , probeNATStatusPure
  ) where

import Data.ByteString (ByteString)
import qualified Data.Text as T
import Data.Word (Word32)
import Network.LibP2P.NAT.AutoNAT.Message
import Network.LibP2P.MultistreamSelect.Negotiation (StreamIO (..))
import Network.LibP2P.Multiaddr.Multiaddr (Multiaddr (..), toBytes, fromBytes)
import Network.LibP2P.Multiaddr.Protocol (Protocol (..))
import Network.LibP2P.Crypto.PeerId (PeerId (..))

-- | Detected NAT status of the local node, as aggregated by
-- 'probeNATStatusPure'.
data NATStatus
  = NATPublic
    -- ^ At least the threshold number of probes reported a successful dial-back.
  | NATPrivate
    -- ^ At least the threshold number of probes did not succeed.
  | NATUnknown
    -- ^ No probes were made, or neither count reached the threshold.
  deriving (Show, Eq)

-- | AutoNAT configuration.
data AutoNATConfig = AutoNATConfig
  { natThreshold :: !Int
    -- ^ Number of peers that must agree for a definitive result
  , natDialBack  :: !(PeerId -> [Multiaddr] -> IO (Either String ()))
    -- ^ Injectable dial-back function (for testing). Given the peer to dial
    -- and the candidate addresses, it returns @Right ()@ on success or
    -- @Left reason@ on failure.
  }

-- | Server handler: receive DIAL, validate, dial back, respond.
--
-- Reads one AutoNAT message (bounded by 'maxAutoNATMessageSize') from the
-- stream, hands it to 'processDialRequest', and writes the resulting
-- response back on the same stream. If the request cannot be read, nothing
-- is written.
--
-- Security (enforced in 'processDialRequest'):
--   - Rejects requests from relayed connections (P2PCircuit in observed addr)
--   - Filters dial-back addresses to match observed IP
handleAutoNAT :: AutoNATConfig -> StreamIO -> PeerId -> Multiaddr -> IO ()
handleAutoNAT config stream remotePeerId remoteObservedAddr = do
  result <- readAutoNATMessage stream maxAutoNATMessageSize
  case result of
    -- Unreadable request: deliberately best-effort, no reply is sent.
    Left _err -> pure ()
    Right msg ->
      processDialRequest config msg remotePeerId remoteObservedAddr
        >>= writeAutoNATMessage stream

-- | Process a DIAL request and produce a response.
--
-- Checks are applied in this order; the first failure decides the response:
--
--   1. Observed address contains P2PCircuit  -> @EDialRefused "relayed connection"@
--   2. Message has no dial field             -> @EBadRequest "missing dial field"@
--   3. Dial has no peer info                 -> @EBadRequest "missing peer info"@
--   4. Request names a different peer ID     -> @EBadRequest "peer id mismatch"@
--   5. No address survives decoding and
--      'filterByObservedIP'                  -> @EBadRequest "no valid addresses"@
--
-- Otherwise 'natDialBack' is invoked and the response is @StatusOK@ (with
-- the first filtered address attached) or @EDialError "dial failed"@.
--
-- The dial-back always targets the authenticated remote peer ID, never the
-- ID claimed inside the request: a requester must not be able to make this
-- node dial a third party. An empty claimed ID is treated as "not supplied".
processDialRequest :: AutoNATConfig -> AutoNATMessage -> PeerId -> Multiaddr -> IO AutoNATMessage
processDialRequest config msg remotePeerId remoteObservedAddr
  -- Reject requests from relayed connections
  | isRelayedAddr remoteObservedAddr =
      pure $ mkDialResponse EDialRefused (Just "relayed connection") Nothing
  | otherwise = case anMsgDial msg of
      Nothing -> badRequest "missing dial field"
      Just dial -> case anDialPeer dial of
        Nothing -> badRequest "missing peer info"
        Just peerInfo
          | claimsOtherPeer (anPeerId peerInfo) -> badRequest "peer id mismatch"
          | otherwise -> dialBack (anAddrs peerInfo)
  where
    badRequest :: String -> IO AutoNATMessage
    badRequest reason = pure $ mkDialResponse EBadRequest (Just reason) Nothing

    -- Raw bytes of the authenticated peer ID, for comparison with the claim.
    PeerId remotePeerBytes = remotePeerId

    -- True when the request carries a non-empty peer ID that is not the
    -- ID of the peer actually connected to us.
    claimsOtherPeer :: ByteString -> Bool
    claimsOtherPeer claimed = claimed /= mempty && claimed /= remotePeerBytes

    -- Decode the requested addresses (undecodable ones are dropped), keep
    -- only those allowed by the observed-IP filter, and attempt the dial.
    dialBack :: [ByteString] -> IO AutoNATMessage
    dialBack rawAddrs =
      case filterByObservedIP remoteObservedAddr (mapMaybe' fromBytes rawAddrs) of
        [] -> badRequest "no valid addresses"
        filteredAddrs@(firstAddr : _) -> do
          dialResult <- natDialBack config remotePeerId filteredAddrs
          pure $ case dialResult of
            -- natDialBack does not say which address succeeded, so the
            -- first candidate is reported.
            Right () -> mkDialResponse StatusOK Nothing (Just (toBytes firstAddr))
            Left _err -> mkDialResponse EDialError (Just "dial failed") Nothing

-- | Build a DIAL_RESPONSE message.
--
-- The message type is set to @DIAL_RESPONSE@, the dial field is left empty,
-- and the dial-response field carries the given status, the optional status
-- text (converted from 'String' to 'T.Text'), and the optional address bytes.
mkDialResponse :: ResponseStatus -> Maybe String -> Maybe ByteString -> AutoNATMessage
mkDialResponse status mText mAddr =
  AutoNATMessage
    { anMsgType         = Just DIAL_RESPONSE
    , anMsgDial         = Nothing
    , anMsgDialResponse = Just dialResponse
    }
  where
    dialResponse :: AutoNATDialResponse
    dialResponse =
      AutoNATDialResponse
        { anRespStatus     = Just status
        , anRespStatusText = T.pack <$> mText
        , anRespAddr       = mAddr
        }

-- | Client: send DIAL with local addresses, receive response.
--
-- Writes a DIAL message carrying the local peer ID bytes and the serialized
-- local addresses, then reads one message (bounded by
-- 'maxAutoNATMessageSize') from the same stream.
--
-- Returns @Left@ with the read error if the reply cannot be read, @Left@
-- with a fixed message if the reply has no dial-response field, and
-- @Right@ with the dial response otherwise. The response status is not
-- inspected here.
requestAutoNAT :: StreamIO -> PeerId -> [Multiaddr] -> IO (Either String AutoNATDialResponse)
requestAutoNAT stream localPeerId localAddrs = do
  writeAutoNATMessage stream dialMsg
  result <- readAutoNATMessage stream maxAutoNATMessageSize
  pure (result >>= extractDialResponse)
  where
    PeerId pidBytes = localPeerId

    dialMsg :: AutoNATMessage
    dialMsg =
      AutoNATMessage
        { anMsgType         = Just DIAL
        , anMsgDial         = Just AutoNATDial { anDialPeer = Just localPeerInfo }
        , anMsgDialResponse = Nothing
        }

    localPeerInfo :: AutoNATPeerInfo
    localPeerInfo =
      AutoNATPeerInfo
        { anPeerId = pidBytes
        , anAddrs  = map toBytes localAddrs
        }

    extractDialResponse :: AutoNATMessage -> Either String AutoNATDialResponse
    extractDialResponse resp =
      maybe (Left "response missing dialResponse field") Right (anMsgDialResponse resp)

-- | Pure aggregation of AutoNAT results into a NAT status.
-- Counts OK responses as "public" votes, all other results as "private" votes.
--
-- The first argument is the number of votes required for a verdict. An empty
-- result list is always 'NATUnknown'. Otherwise the public count is checked
-- against the threshold first, then the private count; if neither reaches it
-- the result is 'NATUnknown'.
--
-- A result is a "private" vote when it is a @Left@ (the probe itself
-- failed) or a response whose status is anything other than @StatusOK@,
-- including a missing status.
--
-- NOTE(review): refused or malformed-request responses are counted as
-- "private" votes here; confirm that is intended rather than treating only
-- dial errors as evidence of being behind a NAT.
probeNATStatusPure :: Int -> [Either String AutoNATDialResponse] -> NATStatus
probeNATStatusPure _threshold [] = NATUnknown
probeNATStatusPure threshold results
  | okCount >= threshold   = NATPublic
  | failCount >= threshold = NATPrivate
  | otherwise              = NATUnknown
  where
    okCount :: Int
    okCount = length (filter isOk results)

    -- Every result is exactly one vote, so the rest are failures.
    failCount :: Int
    failCount = length results - okCount

    isOk :: Either String AutoNATDialResponse -> Bool
    isOk (Right dr)
      | Just StatusOK <- anRespStatus dr = True
    isOk _ = False

-- Helpers

-- | Check if a multiaddr contains P2PCircuit (indicating a relayed connection).
--
-- The component may appear at any position in the address.
isRelayedAddr :: Multiaddr -> Bool
isRelayedAddr (Multiaddr ps) = or [True | P2PCircuit <- ps]

-- | Extract IP address (as Word32 for IPv4) from a multiaddr.
--
-- Returns the first 'IP4' component found when scanning the address from
-- the front, or 'Nothing' if the address has no 'IP4' component.
extractIP4 :: Multiaddr -> Maybe Word32
extractIP4 (Multiaddr ps) =
  case [addr | IP4 addr <- ps] of
    addr : _ -> Just addr
    []       -> Nothing

-- | Filter addresses to only those matching the observed IP.
--
-- The first argument is the address the requester was observed at; the
-- second is the list of candidate dial-back addresses. A candidate is kept
-- only when its IPv4 component (see 'extractIP4') equals the observed one.
-- Candidates without an IPv4 component are dropped.
--
-- Fails closed: if the observed address has no IPv4 component, no candidate
-- can be verified against it, so the result is empty. Passing candidates
-- through unverified would let a requester make this node dial arbitrary
-- third-party addresses.
filterByObservedIP :: Multiaddr -> [Multiaddr] -> [Multiaddr]
filterByObservedIP observed addrs =
  case extractIP4 observed of
    Nothing    -> []
    Just obsIP -> filter (matchesIP obsIP) addrs
  where
    matchesIP :: Word32 -> Multiaddr -> Bool
    matchesIP obsIP addr = extractIP4 addr == Just obsIP

-- | mapMaybe for Either (keeping only Right values).
--
-- Applies the function to every element and collects the payloads of the
-- @Right@ results in input order; @Left@ results are discarded along with
-- their error values.
mapMaybe' :: (a -> Either e b) -> [a] -> [b]
mapMaybe' f xs = [v | Right v <- map f xs]