-- | Post-handshake Noise transport session.
--
-- After the XX handshake completes, cacophony's NoiseState transitions
-- to transport mode with two CipherStates (one for each direction).
-- This module wraps that state for encrypted message exchange.
module Network.LibP2P.Security.Noise.Session
  ( NoiseSession
  , mkNoiseSession
  , encryptMessage
  , decryptMessage
  ) where

import Crypto.Noise
  ( NoiseResult (..)
  , NoiseState
  , convert
  , readMessage
  , writeMessage
  )
import Crypto.Noise.Cipher.ChaChaPoly1305 (ChaChaPoly1305)
import Crypto.Noise.DH.Curve25519 (Curve25519)
import Crypto.Noise.Hash.SHA256 (SHA256)
import Data.ByteArray (ScrubbedBytes)
import Data.ByteString (ByteString)

-- | Cacophony's 'NoiseState' specialised to the one cipher suite this module
-- uses: ChaCha20-Poly1305 as the cipher, Curve25519 for Diffie-Hellman and
-- SHA-256 as the hash.
--
-- The alias is internal (it is not in the module's export list) and exists
-- only to keep the signatures below readable.
type CacophonyState = NoiseState ChaChaPoly1305 Curve25519 SHA256

-- | A post-handshake transport session for encrypted communication.
--
-- This is a thin wrapper around the underlying cacophony state. The data
-- constructor is not exported from this module, so outside code obtains a
-- session through 'mkNoiseSession' and advances it through
-- 'encryptMessage' / 'decryptMessage', each of which returns the updated
-- session alongside its result.
newtype NoiseSession = NoiseSession CacophonyState

-- | Create a 'NoiseSession' from a completed handshake state.
--
-- This only wraps the given state in the 'NoiseSession' constructor; it does
-- not inspect the state, so it performs no check that the handshake has
-- actually finished. Callers are expected to pass a state that is already in
-- transport mode.
mkNoiseSession :: CacophonyState -> NoiseSession
mkNoiseSession = NoiseSession

-- | Encrypt a plaintext message for sending.
--
-- The plaintext is converted to 'ScrubbedBytes' and handed to cacophony's
-- 'writeMessage' together with the state held by the session.
--
-- Returns:
--
-- * @Right (ciphertext, updatedSession)@ when 'writeMessage' yields a
--   'NoiseResultMessage'. The returned session wraps the new state produced
--   by 'writeMessage'.
-- * @Left msg@ when 'writeMessage' yields a 'NoiseResultException'; @msg@ is
--   @"encryptMessage: "@ followed by the shown exception.
-- * @Left msg@ when 'writeMessage' yields 'NoiseResultNeedPSK', which is
--   reported as an unexpected PSK request.
encryptMessage
  :: NoiseSession
  -> ByteString
  -> Either String (ByteString, NoiseSession)
encryptMessage (NoiseSession ns) plaintext =
  let sb = convert plaintext :: ScrubbedBytes
   in case writeMessage sb ns of
        NoiseResultMessage ct ns' ->
          -- Convert the ciphertext back to a ByteString and re-wrap the
          -- state that writeMessage returned.
          Right (convert ct, NoiseSession ns')
        NoiseResultException ex ->
          Left $ "encryptMessage: " <> show ex
        NoiseResultNeedPSK _ ->
          Left "encryptMessage: unexpected PSK request"

-- | Decrypt a received ciphertext message.
--
-- The ciphertext is converted to 'ScrubbedBytes' and handed to cacophony's
-- 'readMessage' together with the state held by the session.
--
-- Returns:
--
-- * @Right (plaintext, updatedSession)@ when 'readMessage' yields a
--   'NoiseResultMessage'. The returned session wraps the new state produced
--   by 'readMessage'.
-- * @Left msg@ when 'readMessage' yields a 'NoiseResultException'; @msg@ is
--   @"decryptMessage: "@ followed by the shown exception.
-- * @Left msg@ when 'readMessage' yields 'NoiseResultNeedPSK', which is
--   reported as an unexpected PSK request.
decryptMessage
  :: NoiseSession
  -> ByteString
  -> Either String (ByteString, NoiseSession)
decryptMessage (NoiseSession ns) ciphertext =
  let sb = convert ciphertext :: ScrubbedBytes
   in case readMessage sb ns of
        NoiseResultMessage pt ns' ->
          -- Convert the plaintext back to a ByteString and re-wrap the
          -- state that readMessage returned.
          Right (convert pt, NoiseSession ns')
        NoiseResultException ex ->
          Left $ "decryptMessage: " <> show ex
        NoiseResultNeedPSK _ ->
          Left "decryptMessage: unexpected PSK request"