{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase, ViewPatterns #-}
module PureSAT.PartialAssignment where
#define ASSERTING(x)
import PureSAT.Base
import PureSAT.LBool
import PureSAT.LitVar
import PureSAT.Prim
-- | A mutable partial assignment stored as one byte per variable.
--
-- Slots are indexed by @lit_to_var@ of a literal. The functions in this
-- module write @0x0@ or @0x1@ (bit 0 of the inserted literal) for an
-- assigned variable and @0xff@ for an unassigned one; lookups treat any
-- byte other than @0x0@ / @0x1@ as undefined.
newtype PartialAssignment s = PA (MutableByteArray s)
-- | Allocate a fresh partial assignment in which every slot is unassigned.
--
-- The requested size is clamped from below by the view pattern, so the
-- backing array always holds at least 4096 bytes. Every byte is initialised
-- to @0xff@, the marker that lookups in this module report as 'LUndef'.
newPartialAssignment :: Int -> ST s (PartialAssignment s)
newPartialAssignment (max 4096 -> size) = do
    arr <- newByteArray size
    -- NOTE(review): shrinking to the size that was just allocated looks like
    -- a no-op; it is kept exactly as in the original -- confirm the intent.
    shrinkMutableByteArray arr size
    fillByteArray arr 0 size 0xff
    return (PA arr)
-- | Create an independent copy of a partial assignment.
--
-- A new array of the same byte size is allocated and the whole contents of
-- the source are copied into it; the source itself is not modified.
clonePartialAssignment :: PartialAssignment s -> ST s (PartialAssignment s)
clonePartialAssignment (PA old) = do
    n   <- getSizeofMutableByteArray old
    new <- newByteArray n
    -- Destination first, then source: copy all @n@ bytes starting at offset 0.
    copyMutableByteArray new 0 old 0 n
    return (PA new)
-- | Overwrite the target assignment with the contents of the source.
--
-- The first argument is the source, the second the target. Only the common
-- prefix is copied: the number of bytes transferred is the smaller of the
-- two array sizes, so any extra slots in a larger target are left untouched.
copyPartialAssignment :: PartialAssignment s -> PartialAssignment s -> ST s ()
copyPartialAssignment (PA src) (PA tgt) = do
    n <- getSizeofMutableByteArray src
    m <- getSizeofMutableByteArray tgt
    let size = min n m
    copyMutableByteArray tgt 0 src 0 size
-- | Grow a partial assignment by exactly one slot.
--
-- The backing array is resized to one byte more than its current size and
-- the new last slot is set to @0xff@ (unassigned). Callers should continue
-- with the returned assignment; whether the old array stays usable depends
-- on @resizeMutableByteArray@, which is defined outside this module.
extendPartialAssignment :: PartialAssignment s -> ST s (PartialAssignment s)
extendPartialAssignment (PA arr) = do
    size <- getSizeofMutableByteArray arr
    arr' <- resizeMutableByteArray arr (size + 1)
    -- The old size is the index of the freshly added slot.
    writeByteArray arr' size (0xff :: Word8)
    return (PA arr')
-- | Look up the value of a literal under the partial assignment.
--
-- The byte stored for the literal's variable is compared against bit 0 of
-- the literal: if they agree the literal is 'LTrue', if they disagree it is
-- 'LFalse'. Any stored byte other than @0x0@ or @0x1@ yields 'LUndef'.
lookupPartialAssignment :: Lit -> PartialAssignment s -> ST s LBool
lookupPartialAssignment (MkLit l) (PA arr) = do
    readByteArray arr (lit_to_var l) >>= \case
        0x0 -> return (if y then LFalse else LTrue)
        0x1 -> return (if y then LTrue else LFalse)
        _   -> return LUndef
  where
    -- Bit 0 of the queried literal, matched against the stored byte.
    y = testBit l 0
    {-# INLINE y #-}
-- | Record a literal as true in the partial assignment.
--
-- The slot of the literal's variable is overwritten with bit 0 of the
-- literal (@0x1@ if set, @0x0@ otherwise). The @ASSERTING@ line checks that
-- the slot was previously @0xff@, but the macro as defined in this file
-- expands to nothing, so the check is currently compiled out.
insertPartialAssignment :: Lit -> PartialAssignment s -> ST s ()
insertPartialAssignment (MkLit l) (PA arr) = do
    ASSERTING(readByteArray arr (lit_to_var l) >>= \x -> assertST "insert" (x == (0xff :: Word8)))
    writeByteArray arr (lit_to_var l) (if testBit l 0 then 0x1 else 0x0 :: Word8)
-- | Remove a literal's variable from the partial assignment.
--
-- The slot is reset to @0xff@ regardless of its previous contents, so the
-- polarity of the literal passed in does not matter here.
deletePartialAssignment :: Lit -> PartialAssignment s -> ST s ()
deletePartialAssignment (MkLit l) (PA arr) = do
    writeByteArray arr (lit_to_var l) (0xff :: Word8)
-- | Emit the assigned literals of a partial assignment via 'traceM'.
--
-- Every slot of the backing array is visited in index order. A slot holding
-- @0x0@ contributes the literal built from @var_to_lit@, a slot holding
-- @0x1@ contributes its negation, and any other byte is skipped. The
-- resulting list is printed after the prefix @"PartialAssignment "@.
tracePartialAssignment :: PartialAssignment s -> ST s ()
tracePartialAssignment (PA arr) = do
    n <- getSizeofMutableByteArray arr
    lits <- go n [] 0
    traceM $ "PartialAssignment " ++ show lits
  where
    -- Walk slots @i .. n - 1@, accumulating literals in reverse order.
    go n acc i
        | i < n
        , let l = MkLit (var_to_lit i)
        = readByteArray arr i >>= \case
            0x0 -> go n (    l : acc) (i + 1)
            0x1 -> go n (neg l : acc) (i + 1)
            _   -> go n          acc  (i + 1)

        | otherwise
        = return (reverse acc)
-- | Assert that a literal is currently true in the partial assignment.
--
-- Succeeds silently when the lookup yields 'LTrue'. For any other result
-- @assertST@ is invoked with a failing condition and a message that
-- includes the value actually found.
assertLiteralInPartialAssignment :: Lit -> PartialAssignment s -> ST s ()
assertLiteralInPartialAssignment l pa =
    lookupPartialAssignment l pa >>= \case
        LTrue -> return ()
        x     -> assertST ("lit in partial: " ++ show x) False
-- | Assert that a literal is currently unassigned.
--
-- The literal is looked up and @assertST@ is called with the condition that
-- the result equals 'LUndef'; the message includes the value actually found.
assertLiteralUndef :: Lit -> PartialAssignment s -> ST s ()
assertLiteralUndef l pa =
    lookupPartialAssignment l pa >>= \x ->
    assertST ("assertLiteralUndef: " ++ show x) (x == LUndef)