diff --git a/saw-core/src/Verifier/SAW/Simulator/TermModel.hs b/saw-core/src/Verifier/SAW/Simulator/TermModel.hs index 3b75df3f01..d6bbc03fb9 100644 --- a/saw-core/src/Verifier/SAW/Simulator/TermModel.hs +++ b/saw-core/src/Verifier/SAW/Simulator/TermModel.hs @@ -25,6 +25,9 @@ module Verifier.SAW.Simulator.TermModel import Control.Monad import Control.Monad.Fix import Control.Monad.IO.Class +import Control.Monad.Trans +import Control.Monad.Trans.Except + import Data.IORef import Data.Maybe (fromMaybe) import qualified Data.Vector as V @@ -34,8 +37,10 @@ import Data.Set (Set) import qualified Data.Set as Set import Numeric.Natural + import Verifier.SAW.Prim (BitVector(..)) import qualified Verifier.SAW.Prim as Prim +import Verifier.SAW.Prelude.Constants import qualified Verifier.SAW.Simulator as Sim import Verifier.SAW.Simulator.Value import qualified Verifier.SAW.Simulator.Prims as Prims @@ -188,11 +193,17 @@ type instance Extra TermModel = VExtra data VExtra = VExtraTerm - (TValue TermModel) -- type of the term - Term -- term value (closed term!) + !(TValue TermModel) -- type of the term + !Term -- term value (closed term!) + | VExtraStream + !(TValue TermModel) -- type of the stream elements + !(Thunk TermModel -> MValue TermModel) -- function to compute stream values + !(IORef (Map Natural (Value TermModel))) -- cache of concrete values + !(Lazy IO Term) -- stream value as a term (closed term!) instance Show VExtra where show (VExtraTerm ty tm) = " " ++ showTerm tm ++ " : " ++ show ty + show (VExtraStream ty _ _ _) = "" data TermModelArray = TMArray @@ -376,6 +387,7 @@ readBackValue sc cfg = loop loop _ (TValue tv) = readBackTValue sc cfg tv loop _ (VExtra (VExtraTerm _tp tm)) = return tm + loop _ (VExtra (VExtraStream _tp _fn _ref tm)) = liftIO (force tm) loop tv@VPiType{} v@VFun{} = do (ecs, tm) <- readBackFuns tv v @@ -460,9 +472,6 @@ intTerm :: SharedContext -> VInt TermModel -> IO Term intTerm _ (Left tm) = pure tm intTerm sc (Right i) = scIntegerConst sc i -extraTerm :: VExtra -> IO Term -extraTerm (VExtraTerm _ tm) = pure tm - unOp :: SharedContext -> (SharedContext -> t -> IO t') -> @@ -663,8 +672,8 @@ prims sc cfg = case c of Right b -> if b then pure x else pure y Left tm -> - do x' <- extraTerm x - y' <- extraTerm y + do x' <- readBackValue sc cfg tp (VExtra x) + y' <- readBackValue sc cfg tp (VExtra y) a <- readBackTValue sc cfg tp VExtraTerm tp <$> scIte sc a tm x' y' @@ -892,6 +901,10 @@ constMap sc cfg = Map.union (Map.fromList localPrims) (Prims.constMap pms) , ("Prelude.intModMul" , intModMulOp sc) , ("Prelude.intModNeg" , intModNegOp sc) + -- Streams + , ("Prelude.MkStream", mkStreamOp sc cfg) + , ("Prelude.streamGet", streamGetOp) + -- Miscellaneous , ("Prelude.expByNat", Prims.expByNatOp pms) ] @@ -1008,7 +1021,7 @@ bvShiftOp sc cfg szf tmOp bvOp = do let n = szf n0 n0' <- scNat sc n0 w' <- readBackValue sc cfg (VVecType n VBoolType) w - dt <- scRequireDataType sc "Prelude.Nat" + dt <- scRequireDataType sc preludeNatIdent pn <- traverse (evalType cfg) (dtPrimName dt) amt' <- readBackValue sc cfg (VDataType pn [] []) amt tm <- tmOp sc n0' w' amt' @@ -1104,3 +1117,41 @@ intModBinOp sc termOp valOp n = binOp sc toTerm termOp' valOp termOp' _ x y = do n' <- scNat sc n termOp sc n' x y + +-- MkStream :: (a :: sort 0) -> (Nat -> a) -> Stream a; +mkStreamOp :: (?recordEC :: BoundECRecorder) => + SharedContext -> Sim.SimulatorConfig TermModel -> TmPrim +mkStreamOp sc cfg = + Prims.tvalFun $ \ty -> + Prims.strictFun $ \f -> + Prims.PrimExcept $ + case f of + VFun nm fn -> + do ref <- liftIO (newIORef mempty) + stm <- liftIO $ delay $ do + natDT <- scRequireDataType sc preludeNatIdent + natPN <- traverse (evalType cfg) (dtPrimName natDT) + ty' <- readBackTValue sc cfg ty + ftm <- readBackValue sc cfg (VPiType nm (VDataType natPN [] []) (VNondependentPi ty)) f + scCtorApp sc (mkIdent preludeModuleName "MkStream") [ty',ftm] + return (VExtra (VExtraStream ty fn ref stm)) + + _ -> throwE "expected function value" + +-- streamGet :: (a :: sort 0) -> Stream a -> Nat -> a; +streamGetOp :: TmPrim +streamGetOp = + Prims.tvalFun $ \_ty -> + Prims.strictFun $ \xs -> + Prims.natFun $ \ix -> + Prims.PrimExcept $ + case xs of + VExtra (VExtraStream _ fn ref _tm) -> + liftIO (Map.lookup ix <$> readIORef ref) >>= \case + Just v -> return v + Nothing -> lift $ + do v <- fn (ready (VNat ix)) + liftIO (atomicModifyIORef' ref (\m' -> (Map.insert ix v m', ()))) + return v + + _ -> throwE "expected stream value"