diff --git a/postgresql-simple.cabal b/postgresql-simple.cabal
index 8627ba6..3f7c52e 100644
--- a/postgresql-simple.cabal
+++ b/postgresql-simple.cabal
@@ -168,6 +168,7 @@ test-suite test
 
   build-depends:
       aeson
+    , async
     , base
     , base16-bytestring
     , bytestring
diff --git a/src/Database/PostgreSQL/Simple/Internal.hs b/src/Database/PostgreSQL/Simple/Internal.hs
index b1f937a..1ccc767 100644
--- a/src/Database/PostgreSQL/Simple/Internal.hs
+++ b/src/Database/PostgreSQL/Simple/Internal.hs
@@ -24,7 +24,7 @@ module Database.PostgreSQL.Simple.Internal where
 import Control.Applicative
 import Control.Exception
 import Control.Concurrent.MVar
-import Control.Monad(MonadPlus(..))
+import Control.Monad(MonadPlus(..), when)
 import Data.ByteString(ByteString)
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Char8 as B8
@@ -77,6 +77,10 @@ data Connection = Connection {
      connectionHandle :: {-# UNPACK #-} !(MVar PQ.Connection)
    , connectionObjects :: {-# UNPACK #-} !(MVar TypeInfoCache)
    , connectionTempNameCounter :: {-# UNPACK #-} !(IORef Int64)
+   , connectionMayHaveOrphanedStatement :: {-# UNPACK #-} !(IORef Bool)
+   -- ^ True if there could be a statement running in postgres in this connection, but
+   -- postgresql-simple is not waiting for results from it. This can happen when
+   -- postgresql-simple is interrupted by asynchronous exceptions.
    } deriving (Typeable)
 
 instance Eq Connection where
@@ -238,6 +242,7 @@ connectPostgreSQL connstr = do
             connectionHandle <- newMVar conn
             connectionObjects <- newMVar (IntMap.empty)
             connectionTempNameCounter <- newIORef 0
+            connectionMayHaveOrphanedStatement <- newIORef False
             let wconn = Connection{..}
             version <- PQ.serverVersion conn
             let settings
@@ -330,43 +335,90 @@ exec conn sql =
           Just res -> return res
 #else
 exec conn sql =
-    withConnection conn $ \h -> do
-        success <- PQ.sendQuery h sql
-        if success
-        then awaitResult h Nothing
-        else throwLibPQError h "PQsendQuery failed"
+    withConnection conn $ \h -> withSocket h $ \socket-> uninterruptibleMask $ \restore -> do
+        -- 1. If postgresql-simple was interrupted while waiting for query results
+        -- earlier, cancel that query (it may even have completed by now, but that's fine)
+        -- before issuing a new one. The flag is cleared only after the cancel succeeded.
+        restore $ do
+            needsToCancel <- readIORef (connectionMayHaveOrphanedStatement conn)
+            when needsToCancel $ do
+                cancelRunningQuery h socket
+                writeIORef (connectionMayHaveOrphanedStatement conn) False
+
+        -- 2. Ideally, the code that issues the query and waits for results
+        -- should not throw exceptions. That way we know an exception means
+        -- postgresql-simple was interrupted and the query might still be running.
+        -- Still, even if the code throws exceptions for other reasons, it means
+        -- we'll try to cancel a running query later once, which is fairly innocuous
+        -- as long as such exceptions are rare (which they should be).
+        restore (sendQueryAndWaitForResults h socket)
+            `onException` writeIORef (connectionMayHaveOrphanedStatement conn) True
+
   where
-    awaitResult h mres = do
-        mfd <- PQ.socket h
-        case mfd of
-          Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
-          Just fd -> do
-             threadWaitRead fd
-             _ <- PQ.consumeInput h -- FIXME?
-             getResult h mres
+    withSocket h f = do
+      mfd <- PQ.socket h
+      case mfd of
+        Nothing -> throwIO $! fdError "Database.PostgreSQL.Simple.Internal.exec"
+        Just socket -> f socket
+
+    sendQueryAndWaitForResults h socket = do
+      success <- PQ.sendQuery h sql
+      if success then do
+        consumeUntilNotBusy h socket
+        getResult h Nothing
+      else throwLibPQError h "PQsendQuery failed"
+
+    cancelRunningQuery h socket = do
+      mcncl <- PQ.getCancel h
+      case mcncl of
+        Nothing -> pure ()
+        Just cncl -> do
+          cancelStatus <- PQ.cancel cncl
+          case cancelStatus of
+            -- PQ.cancel describes a failed dispatch in its Left value, so report that text rather than the connection's last error.
+            Left cancelErr -> throwLibPQError h ("Database.PostgreSQL.Simple.Internal.cancelRunningQuery: " <> (if B.null cancelErr then "Unknown error" else cancelErr)
+              <> "\nIt looks like postgresql-simple was previously interrupted by an exception while waiting for query results." <> " Because of that, before issuing a new query, we tried to cancel that previous query that was interrupted, but failed to do so.")
+            Right () -> do
+              consumeUntilNotBusy h socket
+              waitForNullResult h
+
+    waitForNullResult h = do
+      mres <- PQ.getResult h
+      case mres of
+        Nothing -> pure ()
+        Just _ -> waitForNullResult h
+
+    -- | Waits until results are ready to be fetched.
+    consumeUntilNotBusy h socket = do
+      -- According to https://www.postgresql.org/docs/current/libpq-async.html :
+      -- 1. The isBusy status only changes by calling PQConsumeInput
+      -- 2. In case of errors, "PQgetResult should be called until it returns a null pointer, to allow libpq to process the error information completely"
+      -- 3. Also, "A typical application using these functions will have a main loop that uses select() or poll() ... When the main loop detects input ready, it should call PQconsumeInput to read the input. It can then call PQisBusy, followed by PQgetResult if PQisBusy returns false (0)"
+      busy <- PQ.isBusy h
+      when busy $ do
+        threadWaitRead socket
+        someError <- not <$> PQ.consumeInput h
+        when someError $ PQ.errorMessage h >>= \mmsg -> throwLibPQError h ("Database.PostgreSQL.Simple.Internal.consumeUntilNotBusy: " <> fromMaybe "Unknown error" mmsg)
+        consumeUntilNotBusy h socket
 
     getResult h mres = do
-        isBusy <- PQ.isBusy h
-        if isBusy
-        then awaitResult h mres
-        else do
-          mres' <- PQ.getResult h
-          case mres' of
-            Nothing -> case mres of
-                         Nothing -> throwLibPQError h "PQgetResult returned no results"
-                         Just res -> return res
-            Just res -> do
-                status <- PQ.resultStatus res
-                case status of
-                   -- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
-                   PQ.EmptyQuery -> getResult h mres'
-                   PQ.CommandOk -> getResult h mres'
-                   PQ.TuplesOk -> getResult h mres'
-                   PQ.CopyOut -> return res
-                   PQ.CopyIn -> return res
-                   PQ.BadResponse -> getResult h mres'
-                   PQ.NonfatalError -> getResult h mres'
-                   PQ.FatalError -> getResult h mres'
+        mres' <- PQ.getResult h
+        case mres' of
+          Nothing -> case mres of
+                       Nothing -> throwLibPQError h "PQgetResult returned no results"
+                       Just res -> return res
+          Just res -> do
+              status <- PQ.resultStatus res
+              case status of
+                 -- FIXME: handle PQ.CopyBoth and PQ.SingleTuple
+                 PQ.EmptyQuery -> getResult h mres'
+                 PQ.CommandOk -> getResult h mres'
+                 PQ.TuplesOk -> getResult h mres'
+                 PQ.CopyOut -> return res
+                 PQ.CopyIn -> return res
+                 PQ.BadResponse -> getResult h mres'
+                 PQ.NonfatalError -> getResult h mres'
+                 PQ.FatalError -> getResult h mres'
 #endif
 
 -- | A version of 'execute' that does not perform query substitution.
@@ -450,6 +502,7 @@ newNullConnection = do
     connectionHandle <- newMVar =<< PQ.newNullConnection
     connectionObjects <- newMVar IntMap.empty
     connectionTempNameCounter <- newIORef 0
+    connectionMayHaveOrphanedStatement <- newIORef False
     return Connection{..}
 
 data Row = Row {
diff --git a/test/Main.hs b/test/Main.hs
index 0cb7b35..dc88a7b 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -16,16 +16,18 @@ import Database.PostgreSQL.Simple.ToField (ToField)
 import Database.PostgreSQL.Simple.FromField (FromField)
 import Database.PostgreSQL.Simple.HStore
 import Database.PostgreSQL.Simple.Newtypes
-import Database.PostgreSQL.Simple.Internal (breakOnSingleQuestionMark)
+import Database.PostgreSQL.Simple.Internal (breakOnSingleQuestionMark, connectionMayHaveOrphanedStatement)
 import Database.PostgreSQL.Simple.Types(Query(..),Values(..), PGArray(..))
 import qualified Database.PostgreSQL.Simple.Transaction as ST
 
 import Control.Applicative
+import Control.Concurrent (threadDelay)
+import Control.Concurrent.Async (withAsync, wait)
 import Control.Exception as E
 import Control.Monad
 import Data.Char
 import Data.Foldable (toList)
-import Data.List (concat, sort)
+import Data.List (concat, sort, isInfixOf)
 import Data.IORef
 import Data.Monoid ((<>))
 import Data.String (fromString)
@@ -48,6 +50,7 @@ import System.FilePath
 import System.Timeout(timeout)
 import Data.Time.Compat (getCurrentTime, diffUTCTime)
 import System.Environment (getEnvironment)
+import qualified System.IO as IO
 
 import Test.Tasty
 import Test.Tasty.Golden
@@ -84,6 +87,10 @@ tests env = testGroup "tests"
     , testCase "2-ary generic" . testGeneric2
     , testCase "3-ary generic" . testGeneric3
     , testCase "Timeout" . testTimeout
+    , testCase "Expected user exceptions" . testExpectedExceptions
+    , testCase "Orphaned running query state mgmt" . testOrphanedRunningQueryStateMgmt
+    , testCase "Async exceptions" . testAsyncExceptionFailure
+    , testCase "Query canceled" . testCanceledQueryExceptions
     ]
 
 testBytea :: TestEnv -> TestTree
@@ -534,6 +541,128 @@ testDouble TestEnv{..} = do
     [Only (x :: Double)] <- query_ conn "SELECT '-Infinity'::float8"
     x @?= (-1 / 0)
 
+-- | Specifies exceptions thrown by postgresql-simple for certain user errors.
+testExpectedExceptions :: TestEnv -> Assertion
+testExpectedExceptions TestEnv{..} = do
+    withConn $ \c -> do
+        execute_ c "SELECT 1,2" `shouldThrow` (\(e :: QueryError) -> "2-column result" `isInfixOf` show e)
+        execute_ c "SELECT 1/0" `shouldThrow` (\(e :: SqlError) -> sqlState e == "22012")
+        (query_ c "SELECT 1, 2, 3" :: IO [(String, Int)]) `shouldThrow` (\(e :: ResultError) -> errSQLType e == "int4" && errHaskellType e == "Text")
+
+-- | Asserts that the action throws an exception of type @e@ that the predicate accepts.
+shouldThrow :: forall e a. Exception e => IO a -> (e -> Bool) -> IO ()
+shouldThrow action matches =
+    try action >>= \outcome -> case outcome of
+        Right _ -> assertBool "Expected an exception, but the action returned normally" False
+        Left (ex :: e) -> assertBool ("Exception did not match expectation: " <> show ex) (matches ex)
+
+-- | Ensures that the state associated with there being an orphaned
+-- running statement in a connection is updated accordingly.
+testOrphanedRunningQueryStateMgmt :: TestEnv -> Assertion
+testOrphanedRunningQueryStateMgmt TestEnv{..} = withConn $ \c -> do
+    -- 1. Connections are created with no orphaned running queries, naturally.
+    runState c `shouldReturn` False
+
+    -- 2. Interrupting a query that is still running should set the state
+    -- to True.
+    -- We need to give it enough time to start executing the query
+    -- before timing out. One second should be more than enough.
+    void $ timeout (1000 * 1000) (execute_ c "SELECT pg_sleep(100)")
+    runState c `shouldReturn` True
+
+    -- 3. Running a new query should clear the state again.
+    [ Only (num13 :: Int) ] <- query c "SELECT 13" ()
+    num13 @?= 13
+    runState c `shouldReturn` False
+
+    -- 4. Interrupting a query but letting it run until completion shouldn't
+    -- matter (postgresql-simple has no way of knowing that), but no errors
+    -- should come out of it.
+    void $ timeout (1000 * 1000) (execute_ c "SELECT pg_sleep(2)")
+    runState c `shouldReturn` True
+
+    -- One second has passed, wait 2 more to ensure the query finished.
+    -- The state is still True.
+    threadDelay (1000 * 1000 * 2)
+    runState c `shouldReturn` True
+
+    -- 5. Check that nothing wrong happens if we try to cancel a query
+    -- that is no longer running (this happens automatically by running another query).
+    [ Only (num17 :: Int) ] <- query c "SELECT 17" ()
+    num17 @?= 17
+    runState c `shouldReturn` False
+
+    -- 6. Other errors that are not interruptions don't change the connection's state.
+    execute_ c "SELECT 1/0" `shouldThrow` (\(_ :: SqlError) -> True)
+    runState c `shouldReturn` False
+
+  where
+    runState = readIORef . connectionMayHaveOrphanedStatement
+    shouldReturn :: (Eq a, Show a, HasCallStack) => IO a -> a -> IO ()
+    shouldReturn f expected = do
+        actual <- f
+        actual @?= expected
+
+
+-- | Ensures that asynchronous exceptions thrown while queries are executing
+-- are handled properly.
+testAsyncExceptionFailure :: TestEnv -> Assertion
+testAsyncExceptionFailure TestEnv{..} = withConn $ \c -> do
+    execute_ c "SET my.setting TO '42'"
+    -- We need to give each action enough time to start executing its query
+    -- before timing out. One second should be more than enough.
+    testAsyncException c (1000 * 1000) (execute_ c "SELECT pg_sleep(5)")
+    testAsyncException c (1000 * 1000) $
+        bracket_ (execute_ c "CREATE TABLE IF NOT EXISTS copy_cancel (v INT)") (execute_ c "DROP TABLE IF EXISTS copy_cancel") $
+            bracket_ (copy_ c "COPY copy_cancel FROM STDIN (FORMAT CSV)") (putCopyEnd c) $ do
+                putCopyData c "1\n"
+                threadDelay (1000 * 1000 * 60)
+
+  where
+    testAsyncException c timeLimit f = do
+        tmt <- timeout timeLimit f
+        tmt @?= Nothing
+        -- Any other query should work now without errors.
+        number42 <- query_ c "SELECT current_setting('my.setting')"
+        number42 @?= [ Only ("42" :: String) ]
+
+-- | Ensures that canceled queries don't invalidate the Connection and specifies how
+-- they can be detected.
+testCanceledQueryExceptions :: TestEnv -> Assertion
+testCanceledQueryExceptions TestEnv{..} = do
+    withConn $ \c1 -> withConn $ \c2 -> do
+        [ Only (c1Pid :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
+        execute_ c1 "SET my.setting TO '42'"
+
+        testCancelation c1 c2 c1Pid execPgSleep $ \(ex :: SqlError) -> sqlState ex == "57014"
+
+        -- What should we expect when COPY is canceled and putCopyEnd runs? The same SqlError as above, perhaps? Right now,
+        -- detecting if a query was canceled involves detecting two distinct types of exception.
+        testCancelation c1 c2 c1Pid execCopy $ \(ex :: IOException) -> "Database.PostgreSQL.Simple.Copy.putCopyEnd: failed to parse command status" `isInfixOf` show ex
+            && "ERROR: canceling statement due to user request" `isInfixOf` show ex
+
+        -- Any other query should work now without errors.
+        number42 <- query_ c1 "SELECT current_setting('my.setting')"
+        number42 @?= [ Only ("42" :: String) ]
+
+  where
+    execPgSleep c = execute_ c "SELECT pg_sleep(5)"
+    execCopy c =
+        bracket_ (execute_ c "CREATE TABLE IF NOT EXISTS copy_cancel (v INT)") (execute_ c "DROP TABLE IF EXISTS copy_cancel") $
+            bracket_ (copy_ c "COPY copy_cancel FROM STDIN (FORMAT CSV)") (putCopyEnd c) $ do
+                putCopyData c "1\n"
+                threadDelay (1000 * 1000 * 2)
+                -- putCopyEnd will run after pg_cancel_backend due to threadDelays
+    testCancelation c1 c2 cPid f exPred = withAsync (f c1) $ \longRunningAction -> do
+        -- We need to give it enough time to start executing the query
+        -- before canceling it. One second should be more than enough.
+        threadDelay (1000 * 1000)
+        cancelResult <- query c2 "SELECT pg_cancel_backend(?)" (Only cPid)
+        cancelResult @?= [ Only True ]
+        wait longRunningAction `shouldThrow` exPred
+        -- Connection is still usable after the query was canceled.
+        [ Only (cPidAgain :: Int) ] <- query_ c1 "SELECT pg_backend_pid()"
+        cPid @?= cPidAgain
 
 testGeneric1 :: TestEnv -> Assertion
 testGeneric1 TestEnv{..} = do
@@ -621,6 +750,8 @@ withTestEnv connstr cb =
 
 main :: IO ()
 main = withConnstring $ \connstring -> do
+    IO.hSetBuffering IO.stdout IO.NoBuffering
+    IO.hSetBuffering IO.stderr IO.NoBuffering
     withTestEnv connstring (defaultMain . tests)
 
 withConnstring :: (BS8.ByteString -> IO ()) -> IO ()