-- |
-- Module: M.IO.Internal.Socket
-- Description: Socket connection handling for Minecraft protocol
-- Copyright: (c) axionbuster, 2025
-- License: BSD-3-Clause
--
-- Implements reliable duplex stream connections for the Java Minecraft protocol,
-- handling encryption and compression.
module M.IO.Internal.Socket (Connection (..), withcxfromsocket) where

import Data.ByteString qualified as B
import M.Crypto
import M.IO.Internal.Datagram
import M.IO.Obs
import Network.SocketA
import System.IO.Streams
import UnliftIO

-- | One end of a Java Minecraft protocol connection, usable on either the
-- server or the client side.
--
-- The two 'TVar' fields are shared, mutable settings; writing to them
-- reconfigures the connection while it is in use. The two stream fields are
-- the connection's input and output of 'Uninterpreted' values.
data Connection = Connection
  { cxkey :: TVar (Maybe ByteString)
  -- ^ encryption key, AES-128-CFB8; the key doubles as the IV
  , cxcompth :: TVar Int
  -- ^ compression threshold: negative means compression is off,
  -- non-negative means it is on with that threshold
  , cxinput :: InputStream Uninterpreted
  -- ^ input stream
  , cxoutput :: OutputStream Uninterpreted
  -- ^ output stream
  }

-- | Wrap a socket in a 'Connection' and run a continuation with it.
--
-- The streams are layered as follows, from the socket outward:
--
--   1. raw byte streams over the socket ('socketToStreams');
--   2. a cipher layer ('makedecrypting' / 'makeencrypting') driven by the
--      functions held in two internal 'TVar's; these start as 'pure', which
--      passes bytes through unchanged;
--   3. the 'Uninterpreted' stream layer ('makepacketstreami' /
--      'makepacketstreamo'), parameterised by the compression threshold.
--
-- Initial state: compression threshold @-1@ (off) and key 'Nothing'.
--
-- A background thread watches 'cxkey' with 'obs'. On @'Just' key@ it
-- builds an AES encryptor and decryptor from @key@ and installs their
-- 'aesupdate' functions. On 'Nothing' it restores the pass-through 'pure'
-- functions. The watcher is 'link'ed, so an exception in it is rethrown to
-- the calling thread. It is cancelled when the continuation finishes
-- ('withAsync').
--
-- NOTE(review): the cipher swap runs on the watcher thread, not on the
-- thread that writes 'cxkey'. Data sent or received right after the write
-- may still go through the previous cipher functions. Confirm callers
-- account for this.
--
-- The socket itself is not closed here; the caller owns it.
withcxfromsocket :: (MonadUnliftIO m) => Socket -> (Connection -> m a) -> m a
withcxfromsocket sk cont = do
  -- negative threshold = compression off
  th <- newTVarIO (-1)
  (i0, o0) <- liftIO (socketToStreams sk)
  -- current encrypt/decrypt functions; 'pure' is the identity (no cipher)
  (ef, df) <- liftA2 (,) (newTVarIO pure) (newTVarIO pure)
  (i1, o1) <-
    liftA2
      (,)
      (liftIO $ makedecrypting df i0)
      (liftIO $ makeencrypting ef o0)
  (i2, o2) <-
    liftA2
      (,)
      (liftIO $ makepacketstreami th i1)
      (liftIO $ makepacketstreamo th o1)
  k <- newTVarIO Nothing
  let -- install the cipher functions matching the current key value
      installkey = \case
        Nothing -> atomically do
          writeTVar ef pure
          writeTVar df pure
        Just key -> do
          aese <- aesnew @'Encrypt key
          aesd <- aesnew @'Decrypt key
          -- swap both directions in one transaction
          atomically do
            writeTVar ef (aesupdate aese)
            writeTVar df (aesupdate aesd)
      -- transform: accept the new value as-is; action: react to the value
      -- passed as the second argument (presumably the new key; confirm
      -- against 'obs')
      watchk = liftIO $ obs k (const pure) (const installkey)
  withAsync watchk \s -> do
    link s
    cont
      Connection
        { cxkey = k,
          cxcompth = th,
          cxinput = i2,
          cxoutput = o2
        }

-- | Stand-in for @socketToStreams@ from "System.IO.Streams.Network". That
-- function is written against the \"network\" package, while the 'Socket'
-- here comes from "Network.SocketA" (\"winasyncsocket\").
--
-- The input stream reads up to 2048 bytes per 'recv' and treats an empty
-- read as end of stream. The output stream sends each chunk with
-- 'sendall'. On end of stream ('Nothing') it does nothing and leaves the
-- socket open, as is conventional in io-streams.
socketToStreams ::
  Socket ->
  IO (InputStream ByteString, OutputStream ByteString)
socketToStreams sock =
  (,) <$> makeInputStream pullChunk <*> makeOutputStream pushChunk
  where
    -- an empty chunk signals that the peer has closed its side
    pullChunk = do
      chunk <- recv sock 2048
      pure (if B.null chunk then Nothing else Just chunk)
    -- end of stream is ignored on purpose: the socket stays open
    pushChunk = maybe (pure ()) (sendall sock)