module Database.Persist.Sql.Raw where

import Control.Exception (throwIO)
import Control.Monad (liftM, when)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Logger (logDebugNS, runLoggingT)
import Control.Monad.Reader (MonadReader, ReaderT, ask)
import Control.Monad.Trans.Resource (MonadResource, release)
import Data.Acquire (Acquire, allocateAcquire, mkAcquire, with)
import Data.Conduit
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Int (Int64)
import Data.Text (Text, pack)
import qualified Data.Text as T

import Database.Persist
import Database.Persist.Sql.Class
import Database.Persist.Sql.Types
import Database.Persist.Sql.Types.Internal
import Database.Persist.SqlBackend.Internal.StatementCache

-- | Run a raw SQL query and stream its result rows as a conduit source.
--
-- The statement is set up by 'rawQueryRes' and registered with the
-- surrounding 'MonadResource' through 'allocateAcquire'.  After the inner
-- source has finished, the associated release key is released explicitly.
rawQuery :: (MonadResource m, MonadReader env m, BackendCompatible SqlBackend env)
         => Text            -- ^ SQL statement, possibly with placeholders.
         -> [PersistValue]  -- ^ Values to fill the placeholders.
         -> ConduitM () [PersistValue] m ()
rawQuery sql vals = do
    srcRes <- liftPersist $ rawQueryRes sql vals
    (releaseKey, src) <- allocateAcquire srcRes
    src
    release releaseKey

-- | Build an 'Acquire' that yields a conduit source over the rows of a raw
-- SQL query.
--
-- Nothing is executed when this action itself runs; it only reads the
-- backend from the environment.  When the returned 'Acquire' is acquired,
-- the statement text and its parameters are logged at debug level under the
-- @SQL@ source using the connection's log function, and the statement is
-- fetched with 'getStmtConn'.  'stmtReset' is registered as the release
-- action for that statement.
rawQueryRes
    :: (MonadIO m1, MonadIO m2, BackendCompatible SqlBackend env)
    => Text            -- ^ SQL statement, possibly with placeholders.
    -> [PersistValue]  -- ^ Values to fill the placeholders.
    -> ReaderT env m1 (Acquire (ConduitM () [PersistValue] m2 ()))
rawQueryRes sql vals = do
    conn <- projectBackend `liftM` ask
    let make = do
            runLoggingT (logDebugNS (pack "SQL") $ T.append sql $ pack $ "; " ++ show vals)
                (connLogFunc conn)
            getStmtConn conn sql
    return $ do
        -- The statement is reset (not finalized) on release.
        stmt <- mkAcquire make stmtReset
        stmtQuery stmt vals

-- | Execute a raw SQL statement.
--
-- This is 'rawExecuteCount' with the returned count discarded.
rawExecute :: (MonadIO m, BackendCompatible SqlBackend backend)
           => Text            -- ^ SQL statement, possibly with placeholders.
           -> [PersistValue]  -- ^ Values to fill the placeholders.
           -> ReaderT backend m ()
rawExecute x y = liftM (const ()) $ rawExecuteCount x y

-- | Execute a raw SQL statement and return the number of
-- rows it has modified.
--
-- The statement text and its parameters are logged at debug level under
-- the @SQL@ source before execution.
rawExecuteCount :: (MonadIO m, BackendCompatible SqlBackend backend)
                => Text            -- ^ SQL statement, possibly with placeholders.
                -> [PersistValue]  -- ^ Values to fill the placeholders.
                -> ReaderT backend m Int64
rawExecuteCount sql vals = do
    conn <- projectBackend `liftM` ask
    runLoggingT (logDebugNS (pack "SQL") $ T.append sql $ pack $ "; " ++ show vals)
        (connLogFunc conn)
    stmt <- getStmt sql
    res <- liftIO $ stmtExecute stmt vals
    -- Reset the statement after a successful execution; the result of
    -- 'stmtExecute' is what gets returned.
    liftIO $ stmtReset stmt
    return res

-- | Fetch the 'Statement' for the given SQL text, using the backend found
-- in the reader environment.
--
-- This is 'getStmtConn' applied to the projected 'SqlBackend'.
getStmt
  :: (MonadIO m, MonadReader backend m, BackendCompatible SqlBackend backend)
  => Text -> m Statement
getStmt sql = do
    conn <- projectBackend `liftM` ask
    liftIO $ getStmtConn conn sql

-- | Fetch the 'Statement' for the given SQL text on a specific connection.
--
-- The connection's statement cache is consulted first, keyed by
-- 'mkCacheKeyFromQuery'.  On a miss the statement is prepared with
-- 'connPrepare', wrapped so that it tracks whether it has been finalized,
-- and inserted into the cache.  Either way, the resulting statement is
-- passed through the connection's 'hookGetStatement' hook before being
-- returned.
--
-- The wrapper behaves as follows once the statement has been finalized:
--
-- * 'stmtFinalize' and 'stmtReset' do nothing;
-- * 'stmtExecute' and 'stmtQuery' throw 'StatementAlreadyFinalized'.
getStmtConn :: SqlBackend -> Text -> IO Statement
getStmtConn conn sql = do
    let cacheK = mkCacheKeyFromQuery sql
    mstmt <- statementCacheLookup (connStmtMap conn) cacheK
    stmt <- case mstmt of
        Just cached -> pure cached
        Nothing -> do
            stmt' <- connPrepare conn sql
            -- True until the underlying statement has been finalized.
            iactive <- newIORef True
            let wrapped = Statement
                    { stmtFinalize = do
                        active <- readIORef iactive
                        -- Finalize the underlying statement at most once.
                        when active $ do stmtFinalize stmt'
                                         writeIORef iactive False
                    , stmtReset = do
                        active <- readIORef iactive
                        when active $ stmtReset stmt'
                    , stmtExecute = \x -> do
                        active <- readIORef iactive
                        if active
                            then stmtExecute stmt' x
                            else throwIO $ StatementAlreadyFinalized sql
                    , stmtQuery = \x -> do
                        active <- liftIO $ readIORef iactive
                        if active
                            then stmtQuery stmt' x
                            else liftIO $ throwIO $ StatementAlreadyFinalized sql
                    }

            statementCacheInsert (connStmtMap conn) cacheK wrapped
            pure wrapped
    (hookGetStatement $ connHooks conn) conn sql stmt

-- | Execute a raw SQL statement and return its results as a
-- list. If you do not expect a return value, use of
-- `rawExecute` is recommended.
--
-- If you're using 'Entity'@s@ (which is quite likely), then you
-- /must/ use entity selection placeholders (double question
-- mark, @??@).  These @??@ placeholders are then replaced with
-- the names of the columns that we need for your entities.
-- You'll receive an error if you don't use the placeholders.
-- Please see the 'Entity'@s@ documentation for more details.
--
-- You may put value placeholders (question marks, @?@) in your
-- SQL query.  These placeholders are then replaced by the values
-- you pass on the second parameter, already correctly escaped.
-- You may want to use 'toPersistValue' to help you construct
-- the placeholder values.
--
-- Since you're giving a raw SQL statement, you don't get any
-- guarantees regarding safety.  If 'rawSql' is not able to parse
-- the results of your query back, then an exception is raised.
-- However, most common problems are mitigated by using the
-- entity selection placeholder @??@, and you shouldn't see any
-- error at all if you're not using 'Single'.
--
-- The 'rawSql' examples below are based on this schema:
--
-- @
-- share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
-- Person
--     name String
--     age Int Maybe
--     deriving Show
-- BlogPost
--     title String
--     authorId PersonId
--     deriving Show
-- |]
-- @
--
-- Examples based on the above schema:
--
-- @
-- getPerson :: MonadIO m => ReaderT SqlBackend m [Entity Person]
-- getPerson = rawSql "select ?? from person where name=?" [PersistText "john"]
--
-- getAge :: MonadIO m => ReaderT SqlBackend m [Single Int]
-- getAge = rawSql "select person.age from person where name=?" [PersistText "john"]
--
-- getAgeName :: MonadIO m => ReaderT SqlBackend m [(Single Int, Single Text)]
-- getAgeName = rawSql "select person.age, person.name from person where name=?" [PersistText "john"]
--
-- getPersonBlog :: MonadIO m => ReaderT SqlBackend m [(Entity Person, Entity BlogPost)]
-- getPersonBlog = rawSql "select ??,?? from person,blog_post where person.id = blog_post.author_id" []
-- @
--
-- Minimal working program for PostgreSQL backend based on the above concepts:
--
-- > {-# LANGUAGE EmptyDataDecls             #-}
-- > {-# LANGUAGE FlexibleContexts           #-}
-- > {-# LANGUAGE GADTs                      #-}
-- > {-# LANGUAGE GeneralizedNewtypeDeriving #-}
-- > {-# LANGUAGE MultiParamTypeClasses      #-}
-- > {-# LANGUAGE OverloadedStrings          #-}
-- > {-# LANGUAGE QuasiQuotes                #-}
-- > {-# LANGUAGE TemplateHaskell            #-}
-- > {-# LANGUAGE TypeFamilies               #-}
-- >
-- > import           Control.Monad.IO.Class  (liftIO)
-- > import           Control.Monad.Logger    (runStderrLoggingT)
-- > import           Database.Persist
-- > import           Control.Monad.Reader
-- > import           Data.Text
-- > import           Database.Persist.Sql
-- > import           Database.Persist.Postgresql
-- > import           Database.Persist.TH
-- >
-- > share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase|
-- > Person
-- >     name String
-- >     age Int Maybe
-- >     deriving Show
-- > |]
-- >
-- > conn = "host=localhost dbname=new_db user=postgres password=postgres port=5432"
-- >
-- > getPerson :: MonadIO m => ReaderT SqlBackend m [Entity Person]
-- > getPerson = rawSql "select ?? from person where name=?" [PersistText "sibi"]
-- >
-- > liftSqlPersistMPool y x = liftIO (runSqlPersistMPool y x)
-- >
-- > main :: IO ()
-- > main = runStderrLoggingT $ withPostgresqlPool conn 10 $ liftSqlPersistMPool $ do
-- >          runMigration migrateAll
-- >          xs <- getPerson
-- >          liftIO (print xs)
-- >

rawSql :: (RawSql a, MonadIO m, BackendCompatible SqlBackend backend)
       => Text             -- ^ SQL statement, possibly with placeholders.
       -> [PersistValue]   -- ^ Values to fill the placeholders.
       -> ReaderT backend m [a]
rawSql stmt = run
    where
      -- 'getType' only exists to tie the type of 'x' to the element type of
      -- the list returned by 'run'.  Forcing its result raises an error.
      getType :: (x -> m [a]) -> a
      getType = error "rawSql.getType"

      -- Type-level proxy for the result row type; it is handed to
      -- 'rawSqlCols' and 'rawSqlColCountReason', which presumably only
      -- inspect its type.
      x = getType run
      process = rawSqlProcessRow

      -- Expand the @??@ placeholders in 'stmt' with the given column
      -- substitutions, run the query, and feed every row into the sink.
      withStmt' colSubsts params sink = do
            srcRes <- rawQueryRes sql params
            liftIO $ with srcRes (\src -> runConduit $ src .| sink)
          where
            sql = T.concat $ makeSubsts colSubsts $ T.splitOn placeholder stmt
            placeholder = "??"

            -- First argument: remaining column substitutions.  Second
            -- argument: remaining pieces of the statement, as split on the
            -- placeholder.  A substitution may only be emitted between two
            -- pieces, so at least one further piece must follow the current
            -- one; otherwise a surplus substitution would be appended
            -- silently to the end of the statement.
            makeSubsts (s:ss) (t:ts@(_:_)) = t : s : makeSubsts ss ts
            -- No substitutions left: any remaining placeholders are put
            -- back verbatim.
            makeSubsts []     ts           = [T.intercalate placeholder ts]
            -- Substitutions left but no placeholder to put them in.
            makeSubsts ss     _            = error (concat err)
                where
                  err = [ "rawsql: there are still ", show (length ss)
                        , "'??' placeholder substitutions to be made "
                        , "but all '??' placeholders have already been "
                        , "consumed.  Please read 'rawSql's documentation "
                        , "on how '??' placeholders work."
                        ]

      run params = do
        conn <- projectBackend `liftM` ask
        let (colCount, colSubsts) = rawSqlCols (connEscapeRawName conn) x
        withStmt' colSubsts params $ firstRow colCount

      -- Check the column count on the first row only, then hand over to
      -- 'getter' for the actual decoding.  An empty result set yields [].
      firstRow colCount = do
        mrow <- await
        case mrow of
          Nothing -> return []
          Just row
              | colCount == length row -> getter mrow
              | otherwise              -> fail $ concat
                  [ "rawSql: wrong number of columns, got "
                  , show (length row), " but expected ", show colCount
                  , " (", rawSqlColCountReason x, ")." ]

      -- Decode rows one by one, accumulating results in a difference list
      -- so that the final list is in query order.  The first row that
      -- fails to decode aborts the whole query via 'fail'.
      getter = go id
          where
            go acc Nothing = return (acc [])
            go acc (Just row) =
              case process row of
                Left err -> fail (T.unpack err)
                Right r  -> await >>= go (acc . (r:))