Skip to content

Commit

Permalink
Reimplement TopLevel monad using ReaderT (IORef TopLevelRW) IO.
Browse files Browse the repository at this point in the history
Previously it used `StateT TopLevelRW IO`. This change will make it
possible to preserve state changes that have been made before an
exception is thrown.
  • Loading branch information
Brian Huffman committed Jun 22, 2021
1 parent 4254dd0 commit 9c2ee9d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
7 changes: 3 additions & 4 deletions src/SAWScript/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import Data.Monoid
#endif
import Control.Monad.Except (MonadError(..))
import Control.Monad.State
import Control.Monad.Reader (ask)
import qualified Control.Exception as Ex
import qualified Data.ByteString as StrictBS
import qualified Data.ByteString.Lazy as BS
Expand Down Expand Up @@ -1145,9 +1144,9 @@ timePrim a = do
return r

failsPrim :: TopLevel SV.Value -> TopLevel ()
failsPrim m = TopLevel $ do
topRO <- ask
topRW <- Control.Monad.State.get
failsPrim m = do
topRO <- getTopLevelRO
topRW <- getTopLevelRW
x <- liftIO $ Ex.try (runTopLevel m topRO topRW)
case x of
Left (ex :: Ex.SomeException) ->
Expand Down
30 changes: 21 additions & 9 deletions src/SAWScript/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ import qualified Control.Exception as X
import qualified System.IO.Error as IOError
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT(..), ask, asks, local)
import Control.Monad.State (MonadState, StateT(..), get, gets, put)
import Control.Monad.State (MonadState(..), StateT(..), get, gets, put)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.IORef
import Data.List ( intersperse )
import qualified Data.Map as M
import Data.Map ( Map )
Expand Down Expand Up @@ -419,17 +420,28 @@ data TopLevelRW =
}

newtype TopLevel a =
TopLevel (ReaderT TopLevelRO (StateT TopLevelRW IO) a)
TopLevel (ReaderT TopLevelRO (ReaderT (IORef TopLevelRW) IO) a)
deriving (Applicative, Functor, Generic, Generic1, Monad, MonadIO, MonadThrow, MonadCatch, MonadMask)

deriving instance MonadReader TopLevelRO TopLevel
deriving instance MonadState TopLevelRW TopLevel

instance MonadState TopLevelRW TopLevel where
get = TopLevel (lift (ReaderT readIORef))
put s = TopLevel (lift (ReaderT (flip writeIORef s)))
state f = TopLevel (lift (ReaderT (flip atomicModifyIORef (swap . f))))
where swap (x, y) = (y, x)

instance Wrapped (TopLevel a) where

instance MonadFail TopLevel where
fail = throwTopLevel

runTopLevel :: TopLevel a -> TopLevelRO -> TopLevelRW -> IO (a, TopLevelRW)
runTopLevel (TopLevel m) ro rw = runStateT (runReaderT m ro) rw
runTopLevel (TopLevel m) ro rw =
do ref <- newIORef rw
x <- runReaderT (runReaderT m ro) ref
rw' <- readIORef ref
pure (x, rw')

io :: IO a -> TopLevel a
io f = liftIO f
Expand Down Expand Up @@ -489,10 +501,10 @@ getTopLevelRO :: TopLevel TopLevelRO
getTopLevelRO = TopLevel ask

getTopLevelRW :: TopLevel TopLevelRW
getTopLevelRW = TopLevel get
getTopLevelRW = get

putTopLevelRW :: TopLevelRW -> TopLevel ()
putTopLevelRW rw = TopLevel (put rw)
putTopLevelRW rw = put rw

returnProof :: IsValue v => v -> TopLevel v
returnProof v = recordProof v >> return v
Expand All @@ -503,8 +515,8 @@ recordProof v =
putTopLevelRW rw { rwProofs = toValue v : rwProofs rw }

-- | Access the current state of Java Class translation
getJVMTrans :: TopLevel CJ.JVMContext
getJVMTrans = TopLevel (gets rwJVMTrans)
getJVMTrans :: TopLevel CJ.JVMContext
getJVMTrans = gets rwJVMTrans

-- | Access the current state of Java Class translation
putJVMTrans :: CJ.JVMContext -> TopLevel ()
Expand Down Expand Up @@ -1015,7 +1027,7 @@ addTraceReaderT str = underReaderT (addTraceTopLevel str)
-- | Similar to 'addTraceIO', but for the 'TopLevel' monad.
addTraceTopLevel :: String -> TopLevel a -> TopLevel a
addTraceTopLevel str action = action & _Wrapped' %~
underReaderT (underStateT (liftIO . addTraceIO str))
underReaderT (underReaderT (liftIO . addTraceIO str))

data SkeletonState = SkeletonState
{ _skelArgs :: [(Maybe TypedTerm, Maybe (CMSLLVM.AllLLVM CMS.SetupValue), Maybe Text)]
Expand Down

0 comments on commit 9c2ee9d

Please sign in to comment.