module PureSAT.Vec (
    Vec,
    newVec,
    sizeofVec,
    insertVec,
    readVec,
    writeVec,
    shrinkVec,
) where

import Data.Primitive.PrimVar
import Unsafe.Coerce          (unsafeCoerce)

import PureSAT.Base
import PureSAT.Utils
import PureSAT.Prim

-- | A growable mutable vector living in 'ST'.
--
-- The first field counts the elements currently in use; the second is
-- the backing array, whose length is the capacity. Array slots at or past
-- the current size are filled with the 'unused' placeholder by this module.
data Vec s a = Vec
    {-# UNPACK #-} !(PrimVar s Int)     -- number of elements in use
    {-# UNPACK #-} !(MutableArray s a)  -- backing storage (length = capacity)

-- | Allocate an empty vector.
--
-- The requested capacity is raised to at least 64 and then passed through
-- 'nextPowerOf2' to size the backing array. Every slot starts out holding
-- the 'unused' placeholder, and the element count starts at zero.
newVec
    :: Int             -- ^ capacity
    -> ST s (Vec s a)
newVec capacity = do
    let initialCapacity = nextPowerOf2 (max 64 capacity)
    storage <- newArray initialCapacity unused
    counter <- newPrimVar 0
    return (Vec counter storage)

-- | Placeholder stored in array slots that hold no element: the spare
-- capacity of a freshly allocated or grown backing array, and slots
-- cleared by 'shrinkVec'.
--
-- It is never supposed to be forced. If it is, the message names this
-- module, so an out-of-range read is easier to trace than a bare
-- @Prelude.undefined@.
unused :: a
unused = error "PureSAT.Vec: read of an unused vector slot"

-- | Number of elements currently stored (not the capacity of the
-- backing array).
sizeofVec :: Vec s a -> ST s Int
sizeofVec vec = case vec of
    Vec counter _ -> readPrimVar counter
{-# INLINE sizeofVec #-}

-- | Append an element at the end (like C++ @push_back@).
--
-- If the backing array has spare room, the element is written in place
-- and the very same vector is returned. Otherwise an array of twice the
-- capacity is allocated and the existing elements are copied into it.
-- The returned vector then shares the old size counter but points at the
-- new array, so callers must continue with the returned vector.
--
-- The vector is handed back rather than updated through an extra 'STRef'
-- because callers already keep it in a mutable context.
insertVec :: Vec s a -> a -> ST s (Vec s a)
insertVec vec@(Vec sizeRef arr) x = do
    used <- readPrimVar sizeRef
    let !room = sizeofMutableArray arr
        hasRoom = used < room
    storage <- if hasRoom
        then return arr
        else do
            -- Out of room: double the capacity and move the elements over.
            grown <- newArray (room * 2) unused
            copyMutableArray grown 0 arr 0 used
            return grown
    writeArray storage used x
    writePrimVar sizeRef (used + 1)
    return (if hasRoom then vec else Vec sizeRef storage)

-- | Read the element at index @i@ of the backing array.
--
-- The index is not compared against the current size here; slots between
-- the size and the capacity normally hold the 'unused' placeholder.
readVec :: Vec s a -> Int -> ST s a
readVec (Vec _ storage) = readArray storage

-- | Overwrite the element at index @i@ of the backing array.
--
-- The element count is left unchanged and the index is not compared
-- against it here; use 'insertVec' to append.
writeVec :: Vec s a -> Int -> a -> ST s ()
writeVec (Vec _ storage) = writeArray storage

-- | Shrink the vector to @newSize@ elements.
--
-- Slots from the new size up to the old size are overwritten with the
-- 'unused' placeholder, which drops the array's references to the removed
-- elements.
--
-- @newSize@ is clamped to @[0, size]@. Without the clamp, a larger value
-- would let the recorded size run past the written slots, and possibly
-- past the capacity, breaking the @size <= capacity@ assumption that
-- 'insertVec' relies on when it copies @size@ elements out of the backing
-- array. A negative value would write to negative indices.
shrinkVec :: Vec s a -> Int -> ST s ()
shrinkVec (Vec sizeRef arr) newSize = do
    size <- readPrimVar sizeRef
    -- Never grow and never go below zero; see the note above.
    let target = max 0 (min size newSize)
    forM_ [target .. size - 1] $ \i -> writeArray arr i unused
    writePrimVar sizeRef target