-- |
-- Module: M.IO.Internal.Socket
-- Description: Socket connection handling for Minecraft protocol
-- Copyright: (c) axionbuster, 2025
-- License: BSD-3-Clause
--
-- Implements reliable duplex stream connections for the Minecraft: Java
-- Edition protocol, including optional encryption and compression.
module M.IO.Internal.Socket (Connection (..), withcxfromsocket) where

import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Monad
import Data.ByteString
import M.Crypto
import M.IO.Internal.Datagram
import Network.Socket
import System.IO.Streams (InputStream, OutputStream)
import System.IO.Streams.Network (socketToStreams)

-- | A duplex connection to a peer, which may be either a server or a client.
--
-- The two 'TVar' fields are mutable settings shared with the stream
-- machinery; the two stream fields carry packets in each direction.
data Connection = Connection
  { cxkey :: TVar (Maybe ByteString)
  -- ^ encryption key, AES-128-CFB8; the same bytes double as the IV.
  -- 'Nothing' means encryption is off.
  , cxcompth :: TVar Int
  -- ^ compression threshold: a negative value turns compression off,
  -- a non-negative value turns it on with that threshold
  , cxinput :: InputStream Uninterpreted
  -- ^ stream of packets received from the peer
  , cxoutput :: OutputStream Uninterpreted
  -- ^ stream of packets sent to the peer
  }

-- | Create a 'Connection' from a connected 'Socket' and run the continuation
-- with it.
--
-- The streams are layered as: raw socket streams, then the
-- encryption/decryption transform, then packet framing/compression
-- (driven by 'cxcompth').
--
-- The connection starts in this state:
--
-- * compression off ('cxcompth' holds @-1@);
-- * encryption off ('cxkey' holds 'Nothing'; both crypto transforms are
--   the identity).
--
-- Writing @'Just' key@ to 'cxkey' makes a background watcher build fresh
-- AES encrypt/decrypt contexts for that key and install them. Writing
-- 'Nothing' reinstalls the identity transforms.
--
-- The watcher is 'link'ed to the calling thread. If it fails (for example,
-- 'aesnew' rejects a key), the exception is rethrown here rather than being
-- lost in an unobserved 'Async', which would leave the connection silently
-- unencrypted. The watcher is cancelled when the continuation returns or
-- throws.
--
-- NOTE(review): the switch happens on the watcher thread, after the write
-- to 'cxkey' commits. Data that passes through the crypto layer in that
-- window presumably still uses the previous transform. Callers that need a
-- hard cut-over must not rely on the write to 'cxkey' alone. Confirm
-- against how "M.IO.Internal.Datagram" reads the transform TVars.
withcxfromsocket :: Socket -> (Connection -> IO a) -> IO a
withcxfromsocket sk cont = do
  th <- newTVarIO (-1) -- compression off by default
  (i0, o0) <- socketToStreams sk
  -- current encrypt/decrypt transforms; identity until a key is installed
  (ef, df) <- liftA2 (,) (newTVarIO pure) (newTVarIO pure)
  (i1, o1) <- liftA2 (,) (makedecrypting df i0) (makeencrypting ef o0)
  (i2, o2) <- liftA2 (,) (makepacketstreami th i1) (makepacketstreamo th o1)
  k <- newTVarIO Nothing
  -- Datagram.hs takes crypto *functions*, while callers set a *key*.
  -- This watcher converts each change of the key into fresh cipher
  -- functions.
  let watchk :: IO ()
      watchk = do
        -- last key acted upon; starts equal to k's initial value
        kold <- newTVarIO Nothing
        forever do
          -- block until the key differs from the one last acted upon
          k' <- atomically do
            kold' <- readTVar kold
            knew <- readTVar k
            when (knew == kold') retry
            writeTVar kold knew
            pure knew
          case k' of
            Nothing -> atomically do
              writeTVar ef pure
              writeTVar df pure
            Just key -> do
              -- build both contexts before publishing, so the two
              -- directions switch together in one transaction
              aese <- aesnew @'Encrypt key
              aesd <- aesnew @'Decrypt key
              atomically do
                writeTVar ef (aesupdate aese)
                writeTVar df (aesupdate aesd)
  withAsync watchk \watcher -> do
    -- rethrow watcher failures in this thread instead of dropping them
    link watcher
    cont
      Connection
        { cxkey = k,
          cxcompth = th,
          cxinput = i2,
          cxoutput = o2
        }