Skip to content

Commit

Permalink
Merge pull request #1349 from GaloisInc/toplevel-ioref
Browse files Browse the repository at this point in the history
Toplevel ioref
  • Loading branch information
brianhuffman authored Jun 22, 2021
2 parents e870373 + c9530c6 commit 6f52028
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 68 deletions.
9 changes: 6 additions & 3 deletions saw-remote-api/src/SAWServer/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module SAWServer.TopLevel (tl) where
import Control.Exception ( try, SomeException(..) )
import Control.Lens ( view, set )
import Control.Monad.State ( MonadIO(liftIO) )
import Data.IORef
import Data.Typeable (cast)

import SAWScript.Value ( TopLevel, runTopLevel )
Expand All @@ -20,13 +21,15 @@ tl act =
do st <- Argo.getState
let ro = view sawTopLevelRO st
rw = view sawTopLevelRW st
liftIO (try (runTopLevel act ro rw)) >>=
ref <- liftIO (newIORef rw)
liftIO (try (runTopLevel act ro ref)) >>=
\case
Left e@(SomeException e')
| Just (CryptolModuleException err warnings) <- cast e'
-> Argo.raise (cryptolError err warnings)
| otherwise
-> Argo.raise (verificationException e)
Right (res, rw') ->
do Argo.modifyState $ set sawTopLevelRW rw'
Right res ->
do rw' <- liftIO (readIORef ref)
Argo.modifyState $ set sawTopLevelRW rw'
return res
5 changes: 2 additions & 3 deletions saw/SAWScript/REPL/Command.hs
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,8 @@ sawScriptCmd str = do
Left err -> io $ print err
Right stmt ->
do ro <- getTopLevelRO
ie <- getEnvironment
((), ie') <- io $ runTopLevel (interpretStmt True stmt) ro ie
putEnvironment ie'
rwRef <- getEnvironmentRef
io $ runTopLevel (interpretStmt True stmt) ro rwRef

replFileName :: String
replFileName = "<stdin>"
Expand Down
99 changes: 54 additions & 45 deletions saw/SAWScript/REPL/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ module SAWScript.REPL.Monad (
, getTermEnv, modifyTermEnv, setTermEnv
, getExtraTypes, modifyExtraTypes, setExtraTypes
, getExtraNames, modifyExtraNames, setExtraNames
, getRW

-- ** SAWScript stuff
, getSharedContext
, getTopLevelRO
, getEnvironment, modifyEnvironment, putEnvironment
, getEnvironmentRef
, getSAWScriptNames
) where

Expand Down Expand Up @@ -96,50 +96,53 @@ deriving instance Typeable AIG.Proxy

-- REPL Environment ------------------------------------------------------------

-- REPL RW Environment.
data RW = RW
{ eContinue :: Bool
, eIsBatch :: Bool
, eTopLevelRO :: TopLevelRO
, environment :: TopLevelRW
-- REPL Environment.
data Refs = Refs
{ eContinue :: IORef Bool
, eIsBatch :: IORef Bool
, eTopLevelRO :: IORef TopLevelRO
, environment :: IORef TopLevelRW
}

-- | Initial, empty environment.
defaultRW :: Bool -> Options -> IO RW
defaultRW isBatch opts = do
defaultRefs :: Bool -> Options -> IO Refs
defaultRefs isBatch opts =
#ifdef USE_BUILTIN_ABC
(_biContext, ro, rw) <- buildTopLevelEnv (AIGProxy GIA.proxy) opts
do (_biContext, ro, rw) <- buildTopLevelEnv (AIGProxy GIA.proxy) opts
#else
(_biContext, ro, rw) <- buildTopLevelEnv (AIGProxy AIG.basicProxy) opts
do (_biContext, ro, rw) <- buildTopLevelEnv (AIGProxy AIG.basicProxy) opts
#endif

return RW
{ eContinue = True
, eIsBatch = isBatch
, eTopLevelRO = ro
, environment = rw
}
contRef <- newIORef True
batchRef <- newIORef isBatch
roRef <- newIORef ro
rwRef <- newIORef rw
return Refs
{ eContinue = contRef
, eIsBatch = batchRef
, eTopLevelRO = roRef
, environment = rwRef
}

-- | Build up the prompt for the REPL.
mkPrompt :: RW -> String
mkPrompt rw
| eIsBatch rw = ""
| otherwise = "sawscript> "
mkPrompt :: Bool {- ^ is batch -} -> String
mkPrompt batch
| batch = ""
| otherwise = "sawscript> "

mkTitle :: RW -> String
mkTitle _rw = "sawscript"
mkTitle :: Refs -> String
mkTitle _refs = "sawscript"


-- REPL Monad ------------------------------------------------------------------

-- | REPL_ context with InputT handling.
newtype REPL a = REPL { unREPL :: IORef RW -> IO a }
newtype REPL a = REPL { unREPL :: Refs -> IO a }

-- | Run a REPL action with a fresh environment.
runREPL :: Bool -> Options -> REPL a -> IO a
runREPL isBatch opts m = do
ref <- newIORef =<< defaultRW isBatch opts
unREPL m ref
runREPL isBatch opts m =
do refs <- defaultRefs isBatch opts
unREPL m refs

instance Functor REPL where
{-# INLINE fmap #-}
Expand Down Expand Up @@ -247,31 +250,35 @@ rethrowEvalError m = run `X.catch` rethrow
io :: IO a -> REPL a
io m = REPL (\ _ -> m)

getRW :: REPL RW
getRW = REPL readIORef
getRefs :: REPL Refs
getRefs = REPL pure

readRef :: (Refs -> IORef a) -> REPL a
readRef r = REPL (\refs -> readIORef (r refs))

modifyRW_ :: (RW -> RW) -> REPL ()
modifyRW_ f = REPL (\ ref -> modifyIORef ref f)
modifyRef :: (Refs -> IORef a) -> (a -> a) -> REPL ()
modifyRef r f = REPL (\refs -> modifyIORef (r refs) f)

-- | Construct the prompt for the current environment.
getPrompt :: REPL String
getPrompt = mkPrompt `fmap` getRW
getPrompt = mkPrompt <$> readRef eIsBatch

shouldContinue :: REPL Bool
shouldContinue = eContinue `fmap` getRW
shouldContinue = readRef eContinue

stop :: REPL ()
stop = modifyRW_ (\ rw -> rw { eContinue = False })
stop = modifyRef eContinue (const False)

unlessBatch :: REPL () -> REPL ()
unlessBatch body = do
rw <- getRW
unless (eIsBatch rw) body
unlessBatch body =
do batch <- readRef eIsBatch
unless batch body

setREPLTitle :: REPL ()
setREPLTitle = unlessBatch $ do
rw <- getRW
io (setTitle (mkTitle rw))
setREPLTitle =
unlessBatch $
do refs <- getRefs
io (setTitle (mkTitle refs))

getVars :: REPL (Map.Map T.Name M.IfaceDecl)
getVars = do
Expand Down Expand Up @@ -361,17 +368,19 @@ getSharedContext :: REPL SharedContext
getSharedContext = fmap roSharedContext getTopLevelRO

getTopLevelRO :: REPL TopLevelRO
getTopLevelRO = fmap eTopLevelRO getRW
getTopLevelRO = readRef eTopLevelRO

getEnvironmentRef :: REPL (IORef TopLevelRW)
getEnvironmentRef = environment <$> getRefs

getEnvironment :: REPL TopLevelRW
getEnvironment = fmap environment getRW
getEnvironment = readRef environment

putEnvironment :: TopLevelRW -> REPL ()
putEnvironment = modifyEnvironment . const

modifyEnvironment :: (TopLevelRW -> TopLevelRW) -> REPL ()
modifyEnvironment f = modifyRW_ $ \current ->
current { environment = f (environment current) }
modifyEnvironment = modifyRef environment

-- | Get visible variable names for Haskeline completion.
getSAWScriptNames :: REPL [String]
Expand Down
11 changes: 6 additions & 5 deletions src/SAWScript/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ import Data.Monoid
#endif
import Control.Monad.Except (MonadError(..))
import Control.Monad.State
import Control.Monad.Reader (ask)
import qualified Control.Exception as Ex
import qualified Data.ByteString as StrictBS
import qualified Data.ByteString.Lazy as BS
import qualified Data.ByteString.Lazy.UTF8 as B
import qualified Data.IntMap as IntMap
import Data.IORef
import Data.List (isPrefixOf, isInfixOf)
import qualified Data.Map as Map
import Data.Set (Set)
Expand Down Expand Up @@ -1145,10 +1145,11 @@ timePrim a = do
return r

failsPrim :: TopLevel SV.Value -> TopLevel ()
failsPrim m = TopLevel $ do
topRO <- ask
topRW <- Control.Monad.State.get
x <- liftIO $ Ex.try (runTopLevel m topRO topRW)
failsPrim m = do
topRO <- getTopLevelRO
topRW <- getTopLevelRW
ref <- liftIO $ newIORef topRW
x <- liftIO $ Ex.try (runTopLevel m topRO ref)
case x of
Left (ex :: Ex.SomeException) ->
do liftIO $ putStrLn "== Anticipated failure message =="
Expand Down
8 changes: 6 additions & 2 deletions src/SAWScript/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Control.Monad (unless, (>=>), when)
import Control.Monad.IO.Class (liftIO)
import qualified Data.ByteString as BS
import Data.Foldable (foldrM)
import Data.IORef
import qualified Data.Map as Map
import Data.Map ( Map )
import qualified Data.Set as Set
Expand Down Expand Up @@ -277,7 +278,9 @@ interpretStmts env stmts =
interpretStmts env' ss

stmtInterpreter :: StmtInterpreter
stmtInterpreter ro rw stmts = fmap fst $ runTopLevel (interpretStmts emptyLocal stmts) ro rw
stmtInterpreter ro rw stmts =
do ref <- newIORef rw
runTopLevel (interpretStmts emptyLocal stmts) ro ref

processStmtBind :: Bool -> SS.Pattern -> Maybe SS.Type -> SS.Expr -> TopLevel ()
processStmtBind printBinds pat _mc expr = do -- mx mt
Expand Down Expand Up @@ -491,7 +494,8 @@ processFile proxy opts file = do
oldpath <- getCurrentDirectory
file' <- canonicalizePath file
setCurrentDirectory (takeDirectory file')
_ <- runTopLevel (interpretFile file' True) ro rw
ref <- newIORef rw
_ <- runTopLevel (interpretFile file' True) ro ref
`X.catch` (handleException opts)
setCurrentDirectory oldpath
return ()
Expand Down
29 changes: 19 additions & 10 deletions src/SAWScript/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ import qualified Control.Exception as X
import qualified System.IO.Error as IOError
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT(..), ask, asks, local)
import Control.Monad.State (MonadState, StateT(..), get, gets, put)
import Control.Monad.State (MonadState(..), StateT(..), get, gets, put)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Data.IORef
import Data.List ( intersperse )
import qualified Data.Map as M
import Data.Map ( Map )
Expand Down Expand Up @@ -420,17 +421,25 @@ data TopLevelRW =
}

newtype TopLevel a =
TopLevel (ReaderT TopLevelRO (StateT TopLevelRW IO) a)
TopLevel (ReaderT TopLevelRO (ReaderT (IORef TopLevelRW) IO) a)
deriving (Applicative, Functor, Generic, Generic1, Monad, MonadIO, MonadThrow, MonadCatch, MonadMask)

deriving instance MonadReader TopLevelRO TopLevel
deriving instance MonadState TopLevelRW TopLevel

instance MonadState TopLevelRW TopLevel where
get = TopLevel (lift (ReaderT readIORef))
put s = TopLevel (lift (ReaderT (flip writeIORef s)))
state f = TopLevel (lift (ReaderT (flip atomicModifyIORef (swap . f))))
where swap (x, y) = (y, x)

instance Wrapped (TopLevel a) where

instance MonadFail TopLevel where
fail = throwTopLevel

runTopLevel :: TopLevel a -> TopLevelRO -> TopLevelRW -> IO (a, TopLevelRW)
runTopLevel (TopLevel m) ro rw = runStateT (runReaderT m ro) rw
runTopLevel :: TopLevel a -> TopLevelRO -> IORef TopLevelRW -> IO a
runTopLevel (TopLevel m) ro ref =
runReaderT (runReaderT m ro) ref

io :: IO a -> TopLevel a
io f = liftIO f
Expand Down Expand Up @@ -490,10 +499,10 @@ getTopLevelRO :: TopLevel TopLevelRO
getTopLevelRO = TopLevel ask

getTopLevelRW :: TopLevel TopLevelRW
getTopLevelRW = TopLevel get
getTopLevelRW = get

putTopLevelRW :: TopLevelRW -> TopLevel ()
putTopLevelRW rw = TopLevel (put rw)
putTopLevelRW rw = put rw

returnProof :: IsValue v => v -> TopLevel v
returnProof v = recordProof v >> return v
Expand All @@ -504,8 +513,8 @@ recordProof v =
putTopLevelRW rw { rwProofs = toValue v : rwProofs rw }

-- | Access the current state of Java Class translation
getJVMTrans :: TopLevel CJ.JVMContext
getJVMTrans = TopLevel (gets rwJVMTrans)
getJVMTrans :: TopLevel CJ.JVMContext
getJVMTrans = gets rwJVMTrans

-- | Access the current state of Java Class translation
putJVMTrans :: CJ.JVMContext -> TopLevel ()
Expand Down Expand Up @@ -1016,7 +1025,7 @@ addTraceReaderT str = underReaderT (addTraceTopLevel str)
-- | Similar to 'addTraceIO', but for the 'TopLevel' monad.
addTraceTopLevel :: String -> TopLevel a -> TopLevel a
addTraceTopLevel str action = action & _Wrapped' %~
underReaderT (underStateT (liftIO . addTraceIO str))
underReaderT (underReaderT (liftIO . addTraceIO str))

data SkeletonState = SkeletonState
{ _skelArgs :: [(Maybe TypedTerm, Maybe (CMSLLVM.AllLLVM CMS.SetupValue), Maybe Text)]
Expand Down

0 comments on commit 6f52028

Please sign in to comment.