From 33694b201955650aea1d93d196866f0066b92b8e Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Tue, 25 Apr 2023 19:46:46 -0400 Subject: [PATCH 01/10] add refines SAW command to build refinesS terms --- .../src/Verifier/SAW/TypedTerm.hs | 2 +- examples/mr_solver/mr_solver_unit_tests.saw | 71 ++++++++++++----- src/SAWScript/Builtins.hs | 51 +++++++++++- src/SAWScript/Interpreter.hs | 14 ++++ src/SAWScript/Prover/MRSolver.hs | 2 +- src/SAWScript/Prover/MRSolver/Monad.hs | 77 +++++++++++++++---- src/SAWScript/Prover/MRSolver/Solver.hs | 67 ++++++++++++++++ src/SAWScript/Prover/MRSolver/Term.hs | 29 +++++++ 8 files changed, 274 insertions(+), 39 deletions(-) diff --git a/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs b/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs index 02e9c1f413..568d92e3f9 100644 --- a/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs +++ b/cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs @@ -49,7 +49,7 @@ data TypedTermType deriving Show --- | Convert the 'ttTerm' field of a 'TypedTerm' to a SAW core term +-- | Convert the 'ttType' field of a 'TypedTerm' to a SAW core term ttTypeAsTerm :: SharedContext -> Env -> TypedTerm -> IO Term ttTypeAsTerm sc env (TypedTerm (TypedTermSchema schema) _) = importSchema sc env schema diff --git a/examples/mr_solver/mr_solver_unit_tests.saw b/examples/mr_solver/mr_solver_unit_tests.saw index 71eaad2858..3947f008b0 100644 --- a/examples/mr_solver/mr_solver_unit_tests.saw +++ b/examples/mr_solver/mr_solver_unit_tests.saw @@ -28,42 +28,57 @@ const1 <- parse_core const1_core; // const0 <= const0 run_test "const0 |= const0" (mr_solver_query const0 const0) true; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] const0 const0); +// (testing that "refines [] const0 const0" is actually "const0 <= const0") let const0_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", const0_core, ") x) ", "((", const0_core, ") x)"]; -prove_extcore mrsolver (parse_core const0_refines); +run_test "refines [] const0 const0" (is_convertible (parse_core const0_refines) + (refines [] const0 const0)) true; -// The function test_fun0 = const0 +// The function test_fun0 <= const0 test_fun0 <- parse_core_mod "test_funs" "test_fun0"; run_test "const0 |= test_fun0" (mr_solver_query const0 test_fun0) true; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] const0 test_fun0); +// (testing that "refines [] const0 test_fun0" is actually "const0 <= test_fun0") let const0_test_fun0_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", const0_core, ") x) ", "(test_fun0 x)"]; -prove_extcore mrsolver (parse_core_mod "test_funs" const0_test_fun0_refines); +run_test "refines [] const0 test_fun0" (is_convertible (parse_core_mod "test_funs" const0_test_fun0_refines) + (refines [] const0 test_fun0)) true; // not const0 <= const1 run_test "const0 |= const1" (mr_solver_query const0 const1) false; // (using mrsolver tactic - fails as expected) -// let const0_const1_refines = -// str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", -// "((", const0_core, ") x) ", "((", const1_core, ") x)"]; -// prove_extcore mrsolver (parse_core const0_const1_refines); +// prove_extcore mrsolver (refines [] const0 const1); +// (testing that "refines [] const0 const1" is actually "const0 <= const1") +let const0_const1_refines = + str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", + "((", const0_core, ") x) ", "((", const1_core, ") x)"]; +run_test "refines [] const0 const1" (is_convertible (parse_core const0_const1_refines) + (refines [] const0 const1)) true; // The function test_fun1 = const1 test_fun1 <- parse_core_mod "test_funs" "test_fun1"; run_test "const1 |= test_fun1" (mr_solver_query const1 test_fun1) true; run_test "const0 |= test_fun1" (mr_solver_query const0 test_fun1) false; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] const1 test_fun1); +// (testing that "refines [] const1 test_fun1" is actually "const1 <= test_fun1") let const1_test_fun1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", const1_core, ") x) ", "(test_fun1 x)"]; -prove_extcore mrsolver (parse_core_mod "test_funs" const1_test_fun1_refines); +run_test "refines [] const1 test_fun1" (is_convertible (parse_core_mod "test_funs" const1_test_fun1_refines) + (refines [] const1 test_fun1)) true; // (using mrsolver tactic - fails as expected) -// let const0_test_fun1_refines = -// str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", -// "((", const0_core, ") x) ", "(test_fun1 x)"]; -// prove_extcore mrsolver (parse_core_mod "test_funs" const0_test_fun1_refines); +// prove_extcore mrsolver (refines [] const0 test_fun1); +// (testing that "refines [] const0 test_fun1" is actually "const0 <= test_fun1") +let const0_test_fun1_refines = + str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", + "((", const0_core, ") x) ", "(test_fun1 x)"]; +run_test "refines [] const0 test_fun1" (is_convertible (parse_core_mod "test_funs" const0_test_fun1_refines) + (refines [] const0 test_fun1)) true; // ifxEq0 x = If x == 0 then x else 0; should be equal to 0 let ifxEq0_core = "\\ (x:Vec 64 Bool) -> \ @@ -76,18 +91,25 @@ ifxEq0 <- parse_core ifxEq0_core; // ifxEq0 <= const0 run_test "ifxEq0 |= const0" (mr_solver_query ifxEq0 const0) true; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] ifxEq0 const0); +// (testing that "refines [] ifxEq0 const0" is actually "ifxEq0 <= const0") let ifxEq0_const0_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", ifxEq0_core, ") x) ", "((", const0_core, ") x)"]; -prove_extcore mrsolver (parse_core ifxEq0_const0_refines); +run_test "refines [] ifxEq0 const0" (is_convertible (parse_core ifxEq0_const0_refines) + (refines [] ifxEq0 const0)) true; + // not ifxEq0 <= const1 run_test "ifxEq0 |= const1" (mr_solver_query ifxEq0 const1) false; // (using mrsolver tactic - fails as expected) -// let ifxEq0_const1_refines = -// str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", -// "((", ifxEq0_core, ") x) ", "((", const1_core, ") x)"]; -// prove_extcore mrsolver (parse_core ifxEq0_const1_refines); +// prove_extcore mrsolver (refines [] ifxEq0 const1); +// (testing that "refines [] ifxEq0 const1" is actually "ifxEq0 <= const1") +let ifxEq0_const1_refines = + str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", + "((", ifxEq0_core, ") x) ", "((", const1_core, ") x)"]; +run_test "refines [] ifxEq0 const1" (is_convertible (parse_core ifxEq0_const1_refines) + (refines [] ifxEq0 const1)) true; // noErrors1 x = existsS x. retS x let noErrors1_core = @@ -97,18 +119,24 @@ noErrors1 <- parse_core noErrors1_core; // const0 <= noErrors run_test "noErrors1 |= noErrors1" (mr_solver_query noErrors1 noErrors1) true; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] noErrors1 noErrors1); +// (testing that "refines [] noErrors1 noErrors1" is actually "noErrors1 <= noErrors1") let noErrors1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", noErrors1_core, ") x) ", "((", noErrors1_core, ") x)"]; -prove_extcore mrsolver (parse_core noErrors1_refines); +run_test "refines [] noErrors1 noErrors1" (is_convertible (parse_core noErrors1_refines) + (refines [] noErrors1 noErrors1)) true; // const1 <= noErrors run_test "const1 |= noErrors1" (mr_solver_query const1 noErrors1) true; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] const1 noErrors1); +// (testing that "refines [] const1 noErrors1" is actually "const1 <= noErrors1") let const1_noErrors1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", const1_core, ") x) ", "((", noErrors1_core, ") x)"]; -prove_extcore mrsolver (parse_core const1_noErrors1_refines); +run_test "refines [] const1 noErrors1" (is_convertible (parse_core const1_noErrors1_refines) + (refines [] const1 noErrors1)) true; // noErrorsRec1 _ = orS (existsM x. returnM x) (noErrorsRec1 x) // Intuitively, this specifies functions that either return a value or loop @@ -137,7 +165,10 @@ loop1 <- parse_core loop1_core; // loop1 <= noErrorsRec1 run_test "loop1 |= noErrorsRec1" (mr_solver_query loop1 noErrorsRec1) true; // (using mrsolver tactic) +prove_extcore mrsolver (refines [] loop1 noErrorsRec1); +// (testing that "refines [] loop1 noErrorsRec1" is actually "loop1 <= noErrorsRec1") let loop1_noErrorsRec1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", loop1_core, ") x) ", "((", noErrorsRec1_core, ") x)"]; -prove_extcore mrsolver (parse_core loop1_noErrorsRec1_refines); +run_test "refines [] loop1 noErrorsRec1" (is_convertible (parse_core loop1_noErrorsRec1_refines) + (refines [] loop1 noErrorsRec1)) true; diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 6849359be5..1bf5f0131f 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -2209,12 +2209,26 @@ mrSolverTactic sc = execTactic $ Tactic $ \goal -> lift $ do case sequentState (goalSequent goal) of Unfocused -> fail "mrsolver: focus required" HypFocus _ _ -> fail "mrsolver: cannot apply mrsolver in a hypothesis" + ConclFocus (asPiList . unProp -> (args, asApplyAll -> + (asGlobalDef -> Just "Prelude.refinesS", + [ev1, ev2, stack1, stack2, + asApplyAll -> (asGlobalDef -> Just "Prelude.eqPreRel", _), + asApplyAll -> (asGlobalDef -> Just "Prelude.eqPostRel", _), + rtp1, rtp2, + asApplyAll -> (asGlobalDef -> Just "Prelude.eqRR", _), + t1, t2]))) _ -> + on_refinesS dlvl goal args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2 ConclFocus (asPiList . unProp -> (args, asApplyAll -> (asGlobalDef -> Just "Prelude.refinesS_eq", [ev, stack, rtp, t1, t2]))) _ -> - do tp <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev, stack, rtp] - let tt1 = TypedTerm (TypedTermOther tp) t1 - let tt2 = TypedTerm (TypedTermOther tp) t2 + on_refinesS dlvl goal args ev ev stack stack rtp rtp t1 t2 + _ -> error "[MRSolver] cannot apply mrsolver tactic to a refinesS goal with non-trivial RPre/RPost/RR" + where + on_refinesS dlvl goal args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2 = + do tp1 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev1, stack1, rtp1] + tp2 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev2, stack2, rtp2] + let tt1 = TypedTerm (TypedTermOther tp1) t1 + let tt2 = TypedTerm (TypedTermOther tp2) t2 (diff, res) <- mrSolver Prover.askMRSolver (Just "mrsolver") sc args tt1 tt2 case res of Left err | dlvl == 0 -> @@ -2231,7 +2245,6 @@ mrSolverTactic sc = execTactic $ Tactic $ \goal -> lift $ do printOutLnTop Info (printf "[MRSolver] Success in %s" (show diff)) >> let stats = solverStats "MRSOLVER ADMITTED" (sequentSharedSize (goalSequent goal)) in return ((), stats, [], leafEvidence MrSolverEvidence) - _ -> error "mrsolver tactic not applied to a refinesS_eq goal" -- | Run Mr Solver to prove that the first term refines the second, adding -- any relevant 'Prover.FunAssump's to the 'Prover.MREnv' if the first argument @@ -2318,6 +2331,36 @@ mrSolverSetDebug dlvl = modify (\rw -> rw { rwMRSolverEnv = Prover.mrEnvSetDebugLevel dlvl (rwMRSolverEnv rw) }) +-- | Given a list of names and types representing variables over which to +-- quantify as as well as two terms containing those variables, which may be +-- terms or functions in the SpecM monad, construct the SAWCore term which is +-- the refinement (@Prelude.refinesS@) of the given terms, with the given +-- variables generalized with a Pi type. +refinesTerm :: [(Text, C.Schema)] -> TypedTerm -> TypedTerm -> TopLevel TypedTerm +refinesTerm args tt1 tt2 = + do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get + sc <- getSharedContext + env <- rwMRSolverEnv <$> get + args' <- io $ mapM (mapM (argType sc)) args + m1 <- ttTerm <$> ensureMonadicTerm sc tt1 + m2 <- ttTerm <$> ensureMonadicTerm sc tt2 + res <- io $ Prover.refinementTerm sc env Nothing args' m1 m2 + case res of + Left err | dlvl == 0 -> + io (putStrLn $ Prover.showMRFailure err) >> + printOutLnTop Info (printf "[MRSolver] Failed to build refinement term") >> + io (Exit.exitWith $ Exit.ExitFailure 1) + Left err -> + -- we ignore the MRFailure context here since it will have already + -- been printed by the debug trace + io (putStrLn $ Prover.showMRFailureNoCtx err) >> + printOutLnTop Info (printf "[MRSolver] Failed to build refinement term") >> + io (Exit.exitWith $ Exit.ExitFailure 1) + Right t -> + io (mkTypedTerm sc t) + where argType sc (C.Forall [] [] a) = Cryptol.importType sc Cryptol.emptyEnv a + argType _ _ = fail "refinesTerm: given a non-monomorphic type" + setMonadification :: SharedContext -> String -> String -> Bool -> TopLevel () setMonadification sc cry_str saw_str poly_p = do rw <- get diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index e63271db55..48a3f98ced 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3861,6 +3861,15 @@ primitives = Map.fromList [ "Use MRSolver to prove a current goal of the form:" , "(a1:A1) -> ... -> (an:A1) -> refinesS_eq ..." ] + , prim "refines" "[(String, Type)] -> Term -> Term -> Term" + (funVal3 refinesTerm) + Experimental + [ "Given a list of names and types representing variables over which" + , " to quantify as as well as two terms containing those variables," + , " which may be terms or functions in the SpecM monad, construct the" + , " SAWCore term which is the refinement (`Prelude.refinesS`) of the" + , " given terms, with the given variables generalized with a Pi type." ] + --------------------------------------------------------------------- , prim "monadify_term" "Term -> TopLevel Term" @@ -4315,6 +4324,11 @@ primitives = Map.fromList funVal2 f _ _ = VLambda $ \a -> return $ VLambda $ \b -> fmap toValue (f (fromValue a) (fromValue b)) + funVal3 :: forall a b c t. (FromValue a, FromValue b, FromValue c, IsValue t) => (a -> b -> c -> TopLevel t) + -> Options -> BuiltinContext -> Value + funVal3 f _ _ = VLambda $ \a -> return $ VLambda $ \b -> return $ VLambda $ \c -> + fmap toValue (f (fromValue a) (fromValue b) (fromValue c)) + scVal :: forall t. IsValue t => (SharedContext -> t) -> Options -> BuiltinContext -> Value scVal f _ bic = toValue (f (biSharedContext bic)) diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index c31cb1deb5..bc680bdccc 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -9,7 +9,7 @@ Portability : non-portable (language extensions) -} module SAWScript.Prover.MRSolver - (askMRSolver, assumeMRSolver, MRSolverResult, + (askMRSolver, assumeMRSolver, MRSolverResult, refinementTerm, MRFailure(..), showMRFailure, showMRFailureNoCtx, FunAssump(..), FunAssumpRHS(..), MREnv(..), emptyMREnv, mrEnvAddFunAssump, mrEnvSetDebugLevel, diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index f025f2b724..05390b1147 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -23,14 +23,15 @@ monadic combinators for operating on terms. module SAWScript.Prover.MRSolver.Monad where +import Data.Maybe (fromJust) import Data.List (find, findIndex, foldl') +import Data.Foldable (foldrM) import qualified Data.Text as T import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State import Control.Monad.Except import Control.Monad.Trans.Maybe -import GHC.Generics import Data.Map (Map) import qualified Data.Map as Map @@ -72,6 +73,7 @@ data MRFailure | CannotLookupFunDef FunName | RecursiveUnfold FunName | MalformedLetRecTypes Term + | MalformedDataTypeAssump Term | MalformedDefs Term | MalformedComp Term | NotCompFunType Term @@ -151,6 +153,9 @@ instance PrettyInCtx MRFailure where ppWithPrefix "Recursive unfolding of function inside its own body:" nm prettyInCtx (MalformedLetRecTypes t) = ppWithPrefix "Not a ground LetRecTypes list:" t + prettyInCtx (MalformedDataTypeAssump t) = + ppWithPrefix ("assertS/assumeS expects a Bool, Either, or TCNum equality" + ++ " with a constructor on one side, got:") t prettyInCtx (MalformedDefs t) = ppWithPrefix "Cannot handle multiFixS recursive definitions term:" t prettyInCtx (MalformedComp t) = @@ -285,18 +290,6 @@ instance PrettyInCtx CoIndHyp where return "|=", prettyTermApp (funNameTerm f2) args2] --- | An assumption that something is equal to one of the constructors of a --- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term' -data DataTypeAssump - = IsLeft Term | IsRight Term | IsNum Term | IsInf - deriving (Generic, Show, TermLike) - -instance PrettyInCtx DataTypeAssump where - prettyInCtx (IsLeft x) = prettyInCtx x >>= ppWithPrefix "Left _ _" - prettyInCtx (IsRight x) = prettyInCtx x >>= ppWithPrefix "Right _ _" - prettyInCtx (IsNum x) = prettyInCtx x >>= ppWithPrefix "TCNum" - prettyInCtx IsInf = return "TCInf" - -- | A map from 'Term's to 'DataTypeAssump's over that term type DataTypeAssumps = HashMap Term DataTypeAssump @@ -825,6 +818,64 @@ mrCallsFun f = memoFixTermFun $ \recurse t -> case t of (unwrapTermF -> tf) -> foldM (\b t' -> if b then return b else recurse t') False tf +-- | Given a 'DataTypeAssump' and a 'Term' to which it applies, return the +-- equality representing the proposition that the 'DataTypeAssump' holds. +-- For example, @mrDataTypeAssumpTerm x (IsLeft y)@ for @x : Either a b@ +-- would return @Eq (Either a b) x (Left a b y)@. +mrDataTypeAssumpTerm :: Term -> DataTypeAssump -> MRM Term +mrDataTypeAssumpTerm x dt = + do tp <- mrTypeOf x + y <- case dt of + IsLeft y + | Just (primName -> "Prelude.Either", [a, b]) <- asDataType tp -> + liftSC2 scCtorApp "Prelude.Left" [a, b, y] + | otherwise -> error $ "IsLeft expected Either, got: " ++ show tp + IsRight y + | Just (primName -> "Prelude.Either", [a, b]) <- asDataType tp -> + liftSC2 scCtorApp "Prelude.Right" [a, b, y] + | otherwise -> error $ "IsRight expected Either, got: " ++ show tp + IsNum y -> liftSC2 scCtorApp "Prelude.TCNum" [y] + IsInf -> liftSC2 scCtorApp "Prelude.TCInf" [] + liftSC2 scGlobalApply "Prelude.Eq" [tp, x, y] + +-- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of the +-- given 'Term's, after quantifying over all current 'mrUVars' with Pi types +-- and adding calls to @assertS@ on the right hand side for any current +-- 'mrAssumps' and/or 'mrDataTypeAssump's +mrRefinementGoal :: Term -> Term -> MRM Term +mrRefinementGoal t1 t2 = + do (SpecMParams ev1 stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1 + (SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2 + assumps <- mrAssumptions + assumpsAssert <- liftSC2 scGlobalApply "Prelude.assertBoolS" + [ev2, stack2, assumps] + t2' <- case asBool assumps of + Just True -> return t2 + _ -> bindConst ev2 stack2 tp2 assumpsAssert t2 + dtAssumps <- HashMap.toList <$> mrDataTypeAssumps + dtAssumpAsserts <- forM dtAssumps $ \(nm, assump) -> + do assump_tm <- mrDataTypeAssumpTerm nm assump + liftSC2 scGlobalApply "Prelude.assertS" + [ev2, stack2, assump_tm] + t2'' <- foldrM (bindConst ev2 stack2 tp2) t2' dtAssumpAsserts + coIndHyps <- mrCoIndHyps + (rpre, rpost, rr) <- + if Map.null coIndHyps + then (,,) <$> liftSC2 scGlobalApply "Prelude.eqPreRel" [ev2, stack2] + <*> liftSC2 scGlobalApply "Prelude.eqPostRel" [ev2, stack2] + <*> liftSC2 scGlobalApply "Prelude.eqRR" [tp2] + else error "FIXME: Handle CoIndHyps in mrRefinementGoal" + ref_tm <- liftSC2 scGlobalApply "Prelude.refinesS" + [ev1, ev2, stack1, stack2, rpre, rpost, + tp1, tp2, rr, t1, t2''] + uvars <- mrUVarsOuterToInner + liftSC2 scPiList uvars ref_tm + where bindConst ev stack tp x y = + do unit <- liftSC0 scUnitType + const_y <- liftSC3 incVars 0 1 y >>= liftSC3 scLambda "_" unit + liftSC2 scGlobalApply "Prelude.bindS" + [ev, stack, unit, tp, x, const_y] + ---------------------------------------------------------------------- -- * Monadic Operations on Mr. Solver State diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 2aa274586d..59b4ed7881 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -304,6 +304,16 @@ normComp (CompTerm t) = do unit_tp <- mrUnitType return $ AssumeBoolBind cond (CompFunReturn (SpecMParams ev stack) unit_tp) + (isGlobalDef "Prelude.assertS" -> Just (), [ev, stack, prop]) -> + do unit_tp <- mrUnitType + assert <- either AssertBoolBind (uncurry AssertDataTypeBind) + <$> normCompAssertAssumeBody prop + return $ assert (CompFunReturn (SpecMParams ev stack) unit_tp) + (isGlobalDef "Prelude.assumeS" -> Just (), [ev, stack, prop]) -> + do unit_tp <- mrUnitType + assume <- either AssumeBoolBind (uncurry AssumeDataTypeBind) + <$> normCompAssertAssumeBody prop + return $ assume (CompFunReturn (SpecMParams ev stack) unit_tp) (isGlobalDef "Prelude.existsS" -> Just (), [ev, stack, tp]) -> do unit_tp <- mrUnitType return $ ExistsBind (Type tp) (CompFunReturn @@ -447,6 +457,26 @@ normComp (CompTerm t) = _ -> throwMRFailure (MalformedComp t) +-- | Given the body of an @assertS@ or @assumeS@, return either the boolean +-- term @x@ if the body is of the form @Eq Bool x True@ or @Eq Bool True x@, +-- or a 'Term' @x@ and a 'DataTypeAssump' @c@ if the body is of the form +-- @Eq _ x (c ...)@ or @Eq _ (c ...) x@ +normCompAssertAssumeBody :: Term -> MRM (Either Term (Term, DataTypeAssump)) +normCompAssertAssumeBody (asEq -> Just (_, x1, asBool -> Just True)) = + return $ Left x1 +normCompAssertAssumeBody (asEq -> Just (_, asBool -> Just True, x2)) = + return $ Left x2 +normCompAssertAssumeBody (asEq -> Just (_, x1, asEither -> Just e2)) = + return $ Right (x1, either IsLeft IsRight e2) +normCompAssertAssumeBody (asEq -> Just (_, asEither -> Just e1, x2)) = + return $ Right (x2, either IsLeft IsRight e1) +normCompAssertAssumeBody (asEq -> Just (_, x1, asNum -> Just e2)) = + return $ Right (x1, either IsNum (const IsInf) e2) +normCompAssertAssumeBody (asEq -> Just (_, asNum -> Just e1, x2)) = + return $ Right (x2, either IsNum (const IsInf) e1) +normCompAssertAssumeBody prop = + throwMRFailure (MalformedDataTypeAssump prop) + -- | Bind a computation in whnf with a function, and normalize normBind :: NormComp -> CompFun -> MRM NormComp @@ -464,6 +494,10 @@ normBind (AssertBoolBind cond f) k = return $ AssertBoolBind cond (compFunComp f k) normBind (AssumeBoolBind cond f) k = return $ AssumeBoolBind cond (compFunComp f k) +normBind (AssertDataTypeBind x assump f) k = + return $ AssertDataTypeBind x assump (compFunComp f k) +normBind (AssumeDataTypeBind x assump f) k = + return $ AssumeDataTypeBind x assump (compFunComp f k) normBind (ExistsBind tp f) k = return $ ExistsBind tp (compFunComp f k) normBind (ForallBind tp f) k = return $ ForallBind tp (compFunComp f k) normBind (FunBind f args k1) k2 @@ -925,6 +959,13 @@ mrRefines' (AssertBoolBind cond1 k1) m2 = do m1 <- liftSC0 scUnitValue >>= applyCompFun k1 withAssumption cond1 $ mrRefines m1 m2 +mrRefines' m1 (AssumeDataTypeBind x2 assump2 k2) = + do m2 <- liftSC0 scUnitValue >>= applyCompFun k2 + withDataTypeAssump x2 assump2 $ mrRefines m1 m2 +mrRefines' (AssertDataTypeBind x1 assump1 k1) m2 = + do m1 <- liftSC0 scUnitValue >>= applyCompFun k1 + withDataTypeAssump x1 assump1 $ mrRefines m1 m2 + mrRefines' m1 (ForallBind tp f2) = let nm = maybe "x" id (compFunVarName f2) in withUVarLift nm tp (m1,f2) $ \x (m1',f2') -> @@ -1117,6 +1158,14 @@ mrRefines'' (AssumeBoolBind cond1 k1) m2 = if cond1_pv then mrRefines m1 m2 else throwMRFailure (AssumptionNotProvable cond1) +-- FIXME: Do something smarter here? +mrRefines'' _ (AssertDataTypeBind t2 assump2 _) = + do cond2 <- mrDataTypeAssumpTerm t2 assump2 + throwMRFailure (AssertionNotProvable cond2) +mrRefines'' (AssumeDataTypeBind t1 assump1 _) _ = + do cond1 <- mrDataTypeAssumpTerm t1 assump1 + throwMRFailure (AssertionNotProvable cond1) + mrRefines'' m1 (ExistsBind tp f2) = do let nm = maybe "x" id (compFunVarName f2) evar <- mrFreshEVar nm tp @@ -1299,3 +1348,21 @@ assumeMRSolver sc env timeout args t1 t2 = do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc mrRefinesFunH (askMRSolverH (\_ _ -> return ())) [] tp1 t1 tp2 t2 + +-- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of fully +-- applied versions of the given 'Term's, after quantifying over all the given +-- arguments as well as any additional arguments needed to fully apply the given +-- terms, and adding any calls to @assertS@ on the right hand side needed for +-- unifying the arguments generated when fully applying the given terms +refinementTerm :: + SharedContext -> + MREnv {- ^ The Mr Solver environment -} -> + Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> + Term -> Term -> IO (Either MRFailure Term) +refinementTerm sc env timeout args t1 t2 = + runMRM sc timeout env $ + withUVars (mrVarCtxFromOuterToInner args) $ \_ -> + do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc + tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc + mrRefinesFunH mrRefinementGoal [] tp1 t1 tp2 t2 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index 72974a5b36..cf3038a236 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -174,6 +174,12 @@ mrVarCtxFromOuterToInner = mrVarCtxFromInnerToOuter . reverse specMParamsArgs :: SpecMParams Term -> [Term] specMParamsArgs (SpecMParams ev stack) = [ev, stack] +-- | An assumption that something is equal to one of the constructors of a +-- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term' +data DataTypeAssump + = IsLeft Term | IsRight Term | IsNum Term | IsInf + deriving (Generic, Show, TermLike) + -- | A Haskell representation of a @SpecM@ in "monadic normal form" data NormComp = RetS Term -- ^ A term @retS _ _ a x@ @@ -184,6 +190,8 @@ data NormComp | OrS Comp Comp -- ^ an @orS@ computation | AssertBoolBind Term CompFun -- ^ the bind of an @assertBoolS@ computation | AssumeBoolBind Term CompFun -- ^ the bind of an @assumeBoolS@ computation + | AssertDataTypeBind Term DataTypeAssump CompFun -- ^ the bind of a datatype @assertS@ computation + | AssumeDataTypeBind Term DataTypeAssump CompFun -- ^ the bind of a datatype @assumeS@ computation | ExistsBind Type CompFun -- ^ the bind of an @existsS@ computation | ForallBind Type CompFun -- ^ the bind of a @forallS@ computation | FunBind FunName [Term] CompFun @@ -571,6 +579,15 @@ instance PrettyInCtx FunName where foldM (\pp proj -> (pp <>) <$> prettyInCtx proj) (ppName $ globalDefName g) projs +instance PrettyInCtx DataTypeAssump where + prettyInCtx (IsLeft x) = + prettyAppList [return "Left _ _", parens <$> prettyInCtx x] + prettyInCtx (IsRight x) = + prettyAppList [return "Right _ _", parens <$> prettyInCtx x] + prettyInCtx (IsNum x) = + prettyAppList [return "TCNum", parens <$> prettyInCtx x] + prettyInCtx IsInf = return "TCInf" + instance PrettyInCtx Comp where prettyInCtx (CompTerm t) = prettyInCtx t prettyInCtx (CompBind c f) = @@ -614,6 +631,18 @@ instance PrettyInCtx NormComp where prettyAppList [return "assumeBoolS", return "_", return "_", parens <$> prettyInCtx cond, return ">>=", parens <$> prettyInCtx k] + prettyInCtx (AssertDataTypeBind x y k) = + prettyAppList [return "assertS", return "_", return "_", + parens <$> prettyAppList [return "Eq", return "_", + parens <$> prettyInCtx x, + parens <$> prettyInCtx y], + return ">>=", parens <$> prettyInCtx k] + prettyInCtx (AssumeDataTypeBind x y k) = + prettyAppList [return "assumeS", return "_", return "_", + parens <$> prettyAppList [return "Eq", return "_", + parens <$> prettyInCtx x, + parens <$> prettyInCtx y], + return ">>=", parens <$> prettyInCtx k] prettyInCtx (ExistsBind tp k) = prettyAppList [return "existsS", return "_", return "_", prettyInCtx tp, return ">>=", parens <$> prettyInCtx k] From 3107ded513fc5d63aab23e9210c29bba7de7dcb0 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Thu, 27 Apr 2023 12:53:25 -0400 Subject: [PATCH 02/10] have refines command use fresh_symbolic variables --- examples/mr_solver/mr_solver_unit_tests.saw | 6 ++++++ src/SAWScript/Builtins.hs | 15 +++++++-------- src/SAWScript/Interpreter.hs | 8 ++++---- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/examples/mr_solver/mr_solver_unit_tests.saw b/examples/mr_solver/mr_solver_unit_tests.saw index 3947f008b0..93c06a993b 100644 --- a/examples/mr_solver/mr_solver_unit_tests.saw +++ b/examples/mr_solver/mr_solver_unit_tests.saw @@ -35,6 +35,12 @@ let const0_refines = "((", const0_core, ") x) ", "((", const0_core, ") x)"]; run_test "refines [] const0 const0" (is_convertible (parse_core const0_refines) (refines [] const0 const0)) true; +// (testing that "refines [x] ..." gives the same expression as "refines [] ...") +x <- fresh_symbolic "x" {| [64] |}; +run_test "refines [x] (const0 x) (const0 x)" + (is_convertible (refines [] const0 const0) + (refines [x] (term_apply const0 [x]) + (term_apply const0 [x]))) true; // The function test_fun0 <= const0 test_fun0 <- parse_core_mod "test_funs" "test_fun0"; diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 1bf5f0131f..91d8616bc7 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -2336,15 +2336,16 @@ mrSolverSetDebug dlvl = -- terms or functions in the SpecM monad, construct the SAWCore term which is -- the refinement (@Prelude.refinesS@) of the given terms, with the given -- variables generalized with a Pi type. -refinesTerm :: [(Text, C.Schema)] -> TypedTerm -> TypedTerm -> TopLevel TypedTerm -refinesTerm args tt1 tt2 = +refinesTerm :: [TypedTerm] -> TypedTerm -> TypedTerm -> TopLevel TypedTerm +refinesTerm vars tt1 tt2 = do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get sc <- getSharedContext env <- rwMRSolverEnv <$> get - args' <- io $ mapM (mapM (argType sc)) args - m1 <- ttTerm <$> ensureMonadicTerm sc tt1 - m2 <- ttTerm <$> ensureMonadicTerm sc tt2 - res <- io $ Prover.refinementTerm sc env Nothing args' m1 m2 + tt1' <- lambdas vars tt1 + tt2' <- lambdas vars tt2 + m1 <- ttTerm <$> ensureMonadicTerm sc tt1' + m2 <- ttTerm <$> ensureMonadicTerm sc tt2' + res <- io $ Prover.refinementTerm sc env Nothing [] m1 m2 case res of Left err | dlvl == 0 -> io (putStrLn $ Prover.showMRFailure err) >> @@ -2358,8 +2359,6 @@ refinesTerm args tt1 tt2 = io (Exit.exitWith $ Exit.ExitFailure 1) Right t -> io (mkTypedTerm sc t) - where argType sc (C.Forall [] [] a) = Cryptol.importType sc Cryptol.emptyEnv a - argType _ _ = fail "refinesTerm: given a non-monomorphic type" setMonadification :: SharedContext -> String -> String -> Bool -> TopLevel () setMonadification sc cry_str saw_str poly_p = diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 48a3f98ced..235518f4d0 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3861,12 +3861,12 @@ primitives = Map.fromList [ "Use MRSolver to prove a current goal of the form:" , "(a1:A1) -> ... -> (an:A1) -> refinesS_eq ..." ] - , prim "refines" "[(String, Type)] -> Term -> Term -> Term" + , prim "refines" "[Term] -> Term -> Term -> Term" (funVal3 refinesTerm) Experimental - [ "Given a list of names and types representing variables over which" - , " to quantify as as well as two terms containing those variables," - , " which may be terms or functions in the SpecM monad, construct the" + [ "Given a list of 'fresh_symbolic' variables over which to quantify" + , " as as well as two terms containing those variables, which may be" + , " either terms or functions in the SpecM monad, construct the" , " SAWCore term which is the refinement (`Prelude.refinesS`) of the" , " given terms, with the given variables generalized with a Pi type." ] From d6fde0c28b010e1df689caa0e78219a6ae4f7d29 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Thu, 4 May 2023 13:46:57 -0400 Subject: [PATCH 03/10] add MREvidence, keep track of Theorems and SolverStats, overhaul MRSolver commands --- examples/mr_solver/mr_solver_unit_tests.saw | 26 +- heapster-saw/examples/arrays_mr_solver.saw | 9 +- .../examples/exp_explosion_mr_solver.saw | 3 +- .../examples/linked_list_mr_solver.saw | 11 +- heapster-saw/examples/sha512_mr_solver.saw | 30 +- saw-script.cabal | 1 + src/SAWScript/Builtins.hs | 201 ++++-------- src/SAWScript/Interpreter.hs | 66 ++-- src/SAWScript/Proof.hs | 22 +- src/SAWScript/Prover/MRSolver.hs | 8 +- src/SAWScript/Prover/MRSolver/Evidence.hs | 189 +++++++++++ src/SAWScript/Prover/MRSolver/Monad.hs | 302 ++++++++++-------- src/SAWScript/Prover/MRSolver/SMT.hs | 59 ++-- src/SAWScript/Prover/MRSolver/Solver.hs | 151 ++++----- src/SAWScript/Prover/MRSolver/Term.hs | 70 +--- src/SAWScript/Value.hs | 30 +- 16 files changed, 630 insertions(+), 548 deletions(-) create mode 100644 src/SAWScript/Prover/MRSolver/Evidence.hs diff --git a/examples/mr_solver/mr_solver_unit_tests.saw b/examples/mr_solver/mr_solver_unit_tests.saw index 93c06a993b..d04512704f 100644 --- a/examples/mr_solver/mr_solver_unit_tests.saw +++ b/examples/mr_solver/mr_solver_unit_tests.saw @@ -26,8 +26,6 @@ let const1_core = "\\ (_:Vec 64 Bool) -> retS VoidEv emptyFunStack (Vec 64 Bool) const1 <- parse_core const1_core; // const0 <= const0 -run_test "const0 |= const0" (mr_solver_query const0 const0) true; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] const0 const0); // (testing that "refines [] const0 const0" is actually "const0 <= const0") let const0_refines = @@ -44,8 +42,6 @@ run_test "refines [x] (const0 x) (const0 x)" // The function test_fun0 <= const0 test_fun0 <- parse_core_mod "test_funs" "test_fun0"; -run_test "const0 |= test_fun0" (mr_solver_query const0 test_fun0) true; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] const0 test_fun0); // (testing that "refines [] const0 test_fun0" is actually "const0 <= test_fun0") let const0_test_fun0_refines = @@ -55,9 +51,7 @@ run_test "refines [] const0 test_fun0" (is_convertible (parse_core_mod "test_fun (refines [] const0 test_fun0)) true; // not const0 <= const1 -run_test "const0 |= const1" (mr_solver_query const0 const1) false; -// (using mrsolver tactic - fails as expected) -// prove_extcore mrsolver (refines [] const0 const1); +fails (prove_extcore mrsolver (refines [] const0 const1)); // (testing that "refines [] const0 const1" is actually "const0 <= const1") let const0_const1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", @@ -67,18 +61,14 @@ run_test "refines [] const0 const1" (is_convertible (parse_core const0_const1_re // The function test_fun1 = const1 test_fun1 <- parse_core_mod "test_funs" "test_fun1"; -run_test "const1 |= test_fun1" (mr_solver_query const1 test_fun1) true; -run_test "const0 |= test_fun1" (mr_solver_query const0 test_fun1) false; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] const1 test_fun1); +fails (prove_extcore mrsolver (refines [] const0 test_fun1)); // (testing that "refines [] const1 test_fun1" is actually "const1 <= test_fun1") let const1_test_fun1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", "((", const1_core, ") x) ", "(test_fun1 x)"]; run_test "refines [] const1 test_fun1" (is_convertible (parse_core_mod "test_funs" const1_test_fun1_refines) (refines [] const1 test_fun1)) true; -// (using mrsolver tactic - fails as expected) -// prove_extcore mrsolver (refines [] const0 test_fun1); // (testing that "refines [] const0 test_fun1" is actually "const0 <= test_fun1") let const0_test_fun1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", @@ -95,8 +85,6 @@ let ifxEq0_core = "\\ (x:Vec 64 Bool) -> \ ifxEq0 <- parse_core ifxEq0_core; // ifxEq0 <= const0 -run_test "ifxEq0 |= const0" (mr_solver_query ifxEq0 const0) true; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] ifxEq0 const0); // (testing that "refines [] ifxEq0 const0" is actually "ifxEq0 <= const0") let ifxEq0_const0_refines = @@ -107,9 +95,7 @@ run_test "refines [] ifxEq0 const0" (is_convertible (parse_core ifxEq0_const0_re // not ifxEq0 <= const1 -run_test "ifxEq0 |= const1" (mr_solver_query ifxEq0 const1) false; -// (using mrsolver tactic - fails as expected) -// prove_extcore mrsolver (refines [] ifxEq0 const1); +fails (prove_extcore mrsolver (refines [] ifxEq0 const1)); // (testing that "refines [] ifxEq0 const1" is actually "ifxEq0 <= const1") let ifxEq0_const1_refines = str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ", @@ -123,8 +109,6 @@ let noErrors1_core = noErrors1 <- parse_core noErrors1_core; // const0 <= noErrors -run_test "noErrors1 |= noErrors1" (mr_solver_query noErrors1 noErrors1) true; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] noErrors1 noErrors1); // (testing that "refines [] noErrors1 noErrors1" is actually "noErrors1 <= noErrors1") let noErrors1_refines = @@ -134,8 +118,6 @@ run_test "refines [] noErrors1 noErrors1" (is_convertible (parse_core noErrors1_ (refines [] noErrors1 noErrors1)) true; // const1 <= noErrors -run_test "const1 |= noErrors1" (mr_solver_query const1 noErrors1) true; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] const1 noErrors1); // (testing that "refines [] const1 noErrors1" is actually "const1 <= noErrors1") let const1_noErrors1_refines = @@ -169,8 +151,6 @@ let loop1_core = loop1 <- parse_core loop1_core; // loop1 <= noErrorsRec1 -run_test "loop1 |= noErrorsRec1" (mr_solver_query loop1 noErrorsRec1) true; -// (using mrsolver tactic) prove_extcore mrsolver (refines [] loop1 noErrorsRec1); // (testing that "refines [] loop1 noErrorsRec1" is actually "loop1 <= noErrorsRec1") let loop1_noErrorsRec1_refines = diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw index 3100b82983..2f24a7d9b0 100644 --- a/heapster-saw/examples/arrays_mr_solver.saw +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -2,14 +2,15 @@ include "arrays.saw"; // Test that contains0 |= contains0 contains0 <- parse_core_mod "arrays" "contains0"; -mr_solver_test contains0 contains0; +prove_extcore mrsolver (refines [] contains0 contains0); noErrorsContains0 <- parse_core_mod "arrays" "noErrorsContains0"; -mr_solver_prove contains0 noErrorsContains0; +prove_extcore mrsolver (refines [] contains0 noErrorsContains0); + include "specPrims.saw"; import "arrays.cry"; zero_array <- parse_core_mod "arrays" "zero_array"; -mr_solver_test zero_array {{ zero_array_loop_spec }}; -mr_solver_prove zero_array {{ zero_array_spec }}; +prove_extcore mrsolver (refines [] zero_array {{ zero_array_loop_spec }}); +prove_extcore mrsolver (refines [] zero_array {{ zero_array_spec }}); diff --git a/heapster-saw/examples/exp_explosion_mr_solver.saw b/heapster-saw/examples/exp_explosion_mr_solver.saw index 2bd71bb927..0cf92af63e 100644 --- a/heapster-saw/examples/exp_explosion_mr_solver.saw +++ b/heapster-saw/examples/exp_explosion_mr_solver.saw @@ -1,7 +1,6 @@ include "exp_explosion.saw"; import "exp_explosion.cry"; -monadify_term {{ op }}; exp_explosion <- parse_core_mod "exp_explosion" "exp_explosion"; -mr_solver_prove exp_explosion {{ exp_explosion_spec }}; +prove_extcore mrsolver (refines [] exp_explosion {{ exp_explosion_spec }}); diff --git a/heapster-saw/examples/linked_list_mr_solver.saw b/heapster-saw/examples/linked_list_mr_solver.saw index a80aab1a42..a64acdef73 100644 --- a/heapster-saw/examples/linked_list_mr_solver.saw +++ b/heapster-saw/examples/linked_list_mr_solver.saw @@ -27,11 +27,10 @@ heapster_typecheck_fun env "is_head" \ arg0:true, arg1:true, ret:int64<>"; is_head <- parse_core_mod "linked_list" "is_head"; -mr_solver_test is_head is_head; +prove_extcore mrsolver (refines [] is_head is_head); is_elem <- parse_core_mod "linked_list" "is_elem"; - -mr_solver_test is_elem is_elem; +prove_extcore mrsolver (refines [] is_elem is_elem); is_elem_noErrorsSpec <- parse_core "\\ (x:Vec 64 Bool) (y:List (Vec 64 Bool)) -> \ @@ -52,9 +51,9 @@ is_elem_noErrorsSpec <- parse_core \ Vec 64 Bool)) \ \ (Vec 64 Bool)) \ \ (f x)) (x, y)"; -mr_solver_test is_elem is_elem_noErrorsSpec; +prove_extcore mrsolver (refines [] is_elem is_elem_noErrorsSpec); -mr_solver_prove is_elem {{ is_elem_spec }}; +prove_extcore mrsolver (refines [] is_elem {{ is_elem_spec }}); monadify_term {{ Right }}; @@ -63,4 +62,4 @@ monadify_term {{ nil }}; monadify_term {{ cons }}; sorted_insert_no_malloc <- parse_core_mod "linked_list" "sorted_insert_no_malloc"; -mr_solver_prove sorted_insert_no_malloc {{ sorted_insert_spec }}; +prove_extcore mrsolver (refines [] sorted_insert_no_malloc {{ sorted_insert_spec }}); diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 030af51294..ac68a154fc 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -6,14 +6,28 @@ processBlock <- parse_core_mod "SHA512" "processBlock"; processBlocks <- parse_core_mod "SHA512" "processBlocks"; // Test that every function refines itself -// mr_solver_test processBlocks processBlocks; -// mr_solver_test processBlock processBlock; -// mr_solver_test round_16_80 round_16_80; -// mr_solver_test round_00_15 round_00_15; +// prove_extcore mrsolver (refines [] processBlocks processBlocks); +// prove_extcore mrsolver (refines [] processBlock processBlock); +// prove_extcore mrsolver (refines [] round_16_80 round_16_80); +// prove_extcore mrsolver (refines [] round_00_15 round_00_15); import "sha512.cry"; -mr_solver_prove round_00_15 {{ round_00_15_spec }}; -mr_solver_prove round_16_80 {{ round_16_80_spec }}; -mr_solver_prove processBlock {{ processBlock_spec }}; -// mr_solver_prove processBlocks {{ processBlocks_spec }}; +thm_round_00_15 <- + prove_extcore mrsolver (refines [] round_00_15 {{ round_00_15_spec }}); + +thm_round_16_80 <- + prove_extcore + mrsolver_with (addrefns [thm_round_00_15] empty_rs)) + (refines [] round_16_80 {{ round_16_80_spec }}); + +thm_processBlock <- + prove_extcore + (mrsolver_with (addrefns [thm_round_00_15, thm_round_16_80] empty_rs)) + (refines [] processBlock {{ processBlock_spec }}); + +// thm_processBlocks <- +// prove_extcore +// (mrsolver_with (addrefns [thm_processBlock] empty_rs)) +// (refines [] processBlocks {{ processBlocks_spec }}); + diff --git a/saw-script.cabal b/saw-script.cabal index cff9ea9dc1..13c2c4ab38 100644 --- a/saw-script.cabal +++ b/saw-script.cabal @@ -164,6 +164,7 @@ library SAWScript.Prover.MRSolver.Monad SAWScript.Prover.MRSolver.SMT SAWScript.Prover.MRSolver.Solver + SAWScript.Prover.MRSolver.Evidence SAWScript.Prover.MRSolver.Term SAWScript.Prover.RME SAWScript.Prover.ABC diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 91d8616bc7..4385f5140d 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -2159,14 +2159,21 @@ ensureMonadicTerm sc t False -> monadifyTypedTerm sc t ensureMonadicTerm sc t = monadifyTypedTerm sc t --- | A wrapper for either 'Prover.askMRSolver' or 'Prover.assumeMRSolver' from --- @MRSolver.hs@: if the first argument is @Just str@, prints out @str@ +-- | A wrapper for either 'Prover.askMRSolver' or 'Prover.refinementTerm' from +-- @MRSolver.hs@: if the second argument is @Just str@, prints out @str@ -- followed by an abridged version of the refinement being asked, then calls --- the given function, returning how long it took to execute -mrSolver :: (SharedContext -> Prover.MREnv -> Maybe Integer -> [(LocalName, Term)] -> Term -> Term -> IO a) -> - Maybe SawDoc -> SharedContext -> [(LocalName, Term)] -> TypedTerm -> TypedTerm -> - TopLevel (NominalDiffTime, a) -mrSolver f printStr sc top_args tt1 tt2 = +-- the given function. On failure, a string of how long the function took to +-- run is passed to the third argument and the result is used as the message +-- for 'fail'. On success, if the fourth argument is @Just strf@, a string of +-- how long the function took to run is passed to @strf@ and the result is +-- printed, then regardless, the last argument is called on the result. +mrSolverH :: SharedContext -> + Maybe SawDoc -> (String -> String) -> Maybe (String -> String) -> + (SharedContext -> Prover.MREnv -> Maybe Integer -> SV.SAWRefnset -> + [(LocalName, Term)] -> Term -> Term -> IO (Either Prover.MRFailure a)) -> + SV.SAWRefnset -> [(LocalName, Term)] -> TypedTerm -> TypedTerm -> + (a -> TopLevel b) -> TopLevel b +mrSolverH sc printStr errStrf succStr f rs top_args tt1 tt2 cont = do env <- rwMRSolverEnv <$> get m1 <- ttTerm <$> ensureMonadicTerm sc tt1 m2 <- ttTerm <$> ensureMonadicTerm sc tt2 @@ -2178,9 +2185,20 @@ mrSolver f printStr sc top_args tt1 tt2 = "[MRSolver] " <> str <> ": " <> ppTmHead m1' <> " |= " <> ppTmHead m2' time1 <- liftIO getCurrentTime - res <- io $ f sc env Nothing top_args m1' m2' + res <- io $ f sc env Nothing rs top_args m1' m2' time2 <- liftIO getCurrentTime - return (diffUTCTime time2 time1, res) + let diff = show $ diffUTCTime time2 time1 + case res of + Left err | Prover.mreDebugLevel env == 0 -> + fail (Prover.showMRFailure err ++ "\n[MRSolver] " ++ errStrf diff) + Left err -> + -- we ignore the MRFailure context here since it will have already + -- been printed by the debug trace + fail (Prover.showMRFailureNoCtx err ++ "\n[MRSolver] " ++ errStrf diff) + Right a | Just sf <- succStr -> + printOutLnTop Info (sf diff) >> cont a + Right a -> + cont a where -- Turn a term of the form @\x1 ... xn -> f x1 ... xn@ into @f@ collapseEta :: Term -> Term collapseEta (asLambdaList -> (lamVars, @@ -2200,130 +2218,37 @@ mrSolver f printStr sc top_args tt1 tt2 = ppTmHead _ = "..." -- | Invokes MRSolver to attempt to solve a focused goal of the form --- @(a1:A1) -> ... -> (an:A1) -> refinesS_eq ...@, printing an error message --- and exiting if this cannot be done. This function will not modify the --- 'Prover.MREnv'. -mrSolverTactic :: SharedContext -> ProofScript () -mrSolverTactic sc = execTactic $ Tactic $ \goal -> lift $ do - dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get +-- @(a1:A1) -> ... -> (an:An) -> refinesS_eq ...@, assuming the refinements +-- in the given 'Refnset', and printing an error message and exiting if +-- this cannot be done +mrSolver :: SV.SAWRefnset -> ProofScript () +mrSolver rs = execTactic $ Tactic $ \goal -> lift $ + getSharedContext >>= \sc -> case sequentState (goalSequent goal) of Unfocused -> fail "mrsolver: focus required" HypFocus _ _ -> fail "mrsolver: cannot apply mrsolver in a hypothesis" - ConclFocus (asPiList . unProp -> (args, asApplyAll -> - (asGlobalDef -> Just "Prelude.refinesS", - [ev1, ev2, stack1, stack2, - asApplyAll -> (asGlobalDef -> Just "Prelude.eqPreRel", _), - asApplyAll -> (asGlobalDef -> Just "Prelude.eqPostRel", _), - rtp1, rtp2, - asApplyAll -> (asGlobalDef -> Just "Prelude.eqRR", _), - t1, t2]))) _ -> - on_refinesS dlvl goal args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2 - ConclFocus (asPiList . unProp -> (args, asApplyAll -> - (asGlobalDef -> Just "Prelude.refinesS_eq", - [ev, stack, rtp, t1, t2]))) _ -> - on_refinesS dlvl goal args ev ev stack stack rtp rtp t1 t2 - _ -> error "[MRSolver] cannot apply mrsolver tactic to a refinesS goal with non-trivial RPre/RPost/RR" - where - on_refinesS dlvl goal args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2 = + ConclFocus (Prover.asRefinesS . unProp -> Just (args, ev1, ev2, stack1, stack2, + rtp1, rtp2, t1, t2)) _ -> do tp1 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev1, stack1, rtp1] tp2 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev2, stack2, rtp2] let tt1 = TypedTerm (TypedTermOther tp1) t1 let tt2 = TypedTerm (TypedTermOther tp2) t2 - (diff, res) <- mrSolver Prover.askMRSolver (Just "mrsolver") sc args tt1 tt2 - case res of - Left err | dlvl == 0 -> - io (putStrLn $ Prover.showMRFailure err) >> - printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Left err -> - -- we ignore the MRFailure context here since it will have already - -- been printed by the debug trace - io (putStrLn $ Prover.showMRFailureNoCtx err) >> - printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Right _ -> - printOutLnTop Info (printf "[MRSolver] Success in %s" (show diff)) >> - let stats = solverStats "MRSOLVER ADMITTED" (sequentSharedSize (goalSequent goal)) in - return ((), stats, [], leafEvidence MrSolverEvidence) - --- | Run Mr Solver to prove that the first term refines the second, adding --- any relevant 'Prover.FunAssump's to the 'Prover.MREnv' if the first argument --- is true and this can be done, or printing an error message and exiting if it --- cannot. -mrSolverProve :: Bool -> SharedContext -> TypedTerm -> TypedTerm -> TopLevel () -mrSolverProve addToEnv sc t1 t2 = - do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get - let printStr = if addToEnv then "Proving" else "Testing" - (diff, res) <- mrSolver Prover.askMRSolver (Just printStr) sc [] t1 t2 - case res of - Left err | dlvl == 0 -> - io (putStrLn $ Prover.showMRFailure err) >> - printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Left err -> - -- we ignore the MRFailure context here since it will have already - -- been printed by the debug trace - io (putStrLn $ Prover.showMRFailureNoCtx err) >> - printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Right (Just (fnm, fassump)) | addToEnv -> - let assump_str = case Prover.fassumpRHS fassump of - Prover.OpaqueFunAssump _ _ -> "an opaque" - Prover.RewriteFunAssump _ -> "a rewrite" in - printOutLnTop Info ( - printf "[MRSolver] Success in %s, added as %s assumption" - (show diff) (assump_str :: String)) >> - modify (\rw -> rw { rwMRSolverEnv = - Prover.mrEnvAddFunAssump fnm fassump (rwMRSolverEnv rw) }) - _ -> - printOutLnTop Info $ printf "[MRSolver] Success in %s" (show diff) - --- | Run Mr Solver to prove that the first term refines the second, returning --- true iff this can be done. This function will not modify the 'Prover.MREnv'. -mrSolverQuery :: SharedContext -> TypedTerm -> TypedTerm -> TopLevel Bool -mrSolverQuery sc t1 t2 = - do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get - (diff, res) <- mrSolver Prover.askMRSolver (Just "Querying") sc [] t1 t2 - case res of - Left _ | dlvl == 0 -> - printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> - return False - Left err -> - -- we ignore the MRFailure context here since it will have already - -- been printed by the debug trace - io (putStrLn $ Prover.showMRFailureNoCtx err) >> - printOutLnTop Info (printf "[MRSolver] Failure in %s" (show diff)) >> - return False - Right _ -> - printOutLnTop Info (printf "[MRSolver] Success in %s" (show diff)) >> - return True - --- | Generate the 'Prover.FunAssump' which corresponds to the given refinement --- and add it to the 'Prover.MREnv' -mrSolverAssume :: SharedContext -> TypedTerm -> TypedTerm -> TopLevel () -mrSolverAssume sc t1 t2 = - do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get - (_, res) <- mrSolver Prover.assumeMRSolver (Just "Assuming") sc [] t1 t2 - case res of - Left err | dlvl == 0 -> - io (putStrLn $ Prover.showMRFailure err) >> - printOutLnTop Info (printf "[MRSolver] Failure") >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Left err -> - -- we ignore the MRFailure context here since it will have already - -- been printed by the debug trace - io (putStrLn $ Prover.showMRFailureNoCtx err) >> - printOutLnTop Info (printf "[MRSolver] Failure") >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Right (Just (fnm, fassump)) -> - printOutLnTop Info ( - printf "[MRSolver] Success, added as an opaque assumption") >> - modify (\rw -> rw { rwMRSolverEnv = - Prover.mrEnvAddFunAssump fnm fassump (rwMRSolverEnv rw) }) - _ -> - printOutLnTop Info $ printf $ - "[MRSolver] Failure, given refinement cannot be interpreted as" ++ - " an assumption" + mrSolverH sc + (Just $ "Tactic call") (printf "Failure in %s") (Just $ printf "Success in %s") + Prover.askMRSolver rs args tt1 tt2 + (\(stats, mre) -> return ((), stats, [], leafEvidence $ MrSolverEvidence mre)) + _ -> error "mrsolver: cannot apply mrsolver to a non-refinement goal" + +-- | Add a proved refinement theorem to a given refinement set +addrefn :: Theorem -> SV.SAWRefnset -> TopLevel SV.SAWRefnset +addrefn thm rs = + case Prover.asFunAssump (Just (thmNonce thm)) (unProp $ thmProp thm) of + Nothing -> fail "addrefn: theorem is not a refinement" + Just fassump -> pure (Prover.addFunAssump fassump rs) + +-- | Add proved refinement theorems to a given refinement set +addrefns :: [Theorem] -> SV.SAWRefnset -> TopLevel SV.SAWRefnset +addrefns thms ss = foldM (flip addrefn) ss thms -- | Set the debug level of the 'Prover.MREnv' mrSolverSetDebug :: Int -> TopLevel () @@ -2338,27 +2263,13 @@ mrSolverSetDebug dlvl = -- variables generalized with a Pi type. refinesTerm :: [TypedTerm] -> TypedTerm -> TypedTerm -> TopLevel TypedTerm refinesTerm vars tt1 tt2 = - do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get - sc <- getSharedContext - env <- rwMRSolverEnv <$> get + do sc <- getSharedContext tt1' <- lambdas vars tt1 tt2' <- lambdas vars tt2 - m1 <- ttTerm <$> ensureMonadicTerm sc tt1' - m2 <- ttTerm <$> ensureMonadicTerm sc tt2' - res <- io $ Prover.refinementTerm sc env Nothing [] m1 m2 - case res of - Left err | dlvl == 0 -> - io (putStrLn $ Prover.showMRFailure err) >> - printOutLnTop Info (printf "[MRSolver] Failed to build refinement term") >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Left err -> - -- we ignore the MRFailure context here since it will have already - -- been printed by the debug trace - io (putStrLn $ Prover.showMRFailureNoCtx err) >> - printOutLnTop Info (printf "[MRSolver] Failed to build refinement term") >> - io (Exit.exitWith $ Exit.ExitFailure 1) - Right t -> - io (mkTypedTerm sc t) + mrSolverH sc + Nothing (printf "[MRSolver] Failed to build refinement term (%s)") Nothing + Prover.refinementTerm Prover.emptyRefnset [] tt1' tt2' + (io . mkTypedTerm sc) setMonadification :: SharedContext -> String -> String -> Bool -> TopLevel () setMonadification sc cry_str saw_str poly_p = diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 235518f4d0..b64c89b0b4 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -70,7 +70,7 @@ import SAWScript.Value import SAWScript.Proof (newTheoremDB) import SAWScript.Prover.Rewrite(basic_ss) import SAWScript.Prover.Exporter -import SAWScript.Prover.MRSolver (emptyMREnv) +import SAWScript.Prover.MRSolver (emptyMREnv, emptyRefnset) import SAWScript.Yosys import Verifier.SAW.Conversion --import Verifier.SAW.PrettySExp @@ -3814,41 +3814,7 @@ primitives = Map.fromList --------------------------------------------------------------------- - , prim "mr_solver_prove" "Term -> Term -> TopLevel ()" - (scVal (mrSolverProve True)) - Experimental - [ "Call the monadic-recursive solver (that's MR. Solver to you)" - , " to prove that one monadic term refines another. If this can" - , " be done, this refinement will be used in future calls to" - , " Mr. Solver, and if it cannot, the script will exit. See also:" - , " mr_solver_test, mr_solver_query." ] - - , prim "mr_solver_test" "Term -> Term -> TopLevel ()" - (scVal (mrSolverProve False)) - Experimental - [ "Call the monadic-recursive solver (that's MR. Solver to you)" - , " to prove that one monadic term refines another. If this cannot" - , " be done, the script will exit. See also: mr_solver_prove," - , " mr_solver_query - unlike the former, this refinement will not" - , " be used in future calls to Mr. Solver." ] - - , prim "mr_solver_query" "Term -> Term -> TopLevel Bool" - (scVal mrSolverQuery) - Experimental - [ "Call the monadic-recursive solver (that's MR. Solver to you)" - , " to prove that one monadic term refines another, returning" - , " true iff this can be done. See also: mr_solver_prove," - , " mr_solver_test - unlike the former, this refinement will not" - , " be considered in future calls to Mr. Solver, and unlike both," - , " this command will never fail." ] - - , prim "mr_solver_assume" "Term -> Term -> TopLevel Bool" - (scVal mrSolverAssume) - Experimental - [ "Add the refinement of the two given expressions as an assumption" - , " which will be used in future calls to Mr. Solver." ] - - , prim "mr_solver_set_debug_level" "Int -> TopLevel ()" + , prim "mrsolver_set_debug_level" "Int -> TopLevel ()" (pureVal mrSolverSetDebug) Experimental [ "Set the debug level for Mr. Solver; 0 = no debug output," @@ -3856,10 +3822,32 @@ primitives = Map.fromList , " 3 = all debug output" ] , prim "mrsolver" "ProofScript ()" - (scVal mrSolverTactic) + (pureVal (mrSolver emptyRefnset)) + Experimental + [ "Use MRSolver to prove a current refinement goal, i.e. a goal of" + , " the form `(a1:A1) -> ... -> (an:An) -> refinesS_eq ...`" ] + + , prim "empty_rs" "Refnset" + (pureVal (emptyRefnset :: SAWRefnset)) + Current + [ "The empty refinement set, containing no refinements." ] + + , prim "addrefn" "Theorem -> Refnset -> Refnset" + (funVal2 addrefn) + Current + [ "Add a proved refinement theorem to a given refinement set." ] + + , prim "addrefns" "[Theorem] -> Refnset -> Refnset" + (funVal2 addrefns) + Current + [ "Add proved refinement theorems to a given refinement set." ] + + , prim "mrsolver_with" "Renfset -> ProofScript ()" + (pureVal mrSolver) Experimental - [ "Use MRSolver to prove a current goal of the form:" - , "(a1:A1) -> ... -> (an:A1) -> refinesS_eq ..." ] + [ "Use MRSolver to prove a current refinement goal, i.e. a goal of" + , " the form `(a1:A1) -> ... -> (an:An) -> refinesS_eq ...`, with" + , " the given set of refinements taken as assumptions" ] , prim "refines" "[Term] -> Term -> Term -> Term" (funVal3 refinesTerm) diff --git a/src/SAWScript/Proof.hs b/src/SAWScript/Proof.hs index edccec690d..5500c8f417 100644 --- a/src/SAWScript/Proof.hs +++ b/src/SAWScript/Proof.hs @@ -167,6 +167,7 @@ import What4.ProgramLoc (ProgramLoc) import SAWScript.Position import SAWScript.Prover.SolverStats +import qualified SAWScript.Prover.MRSolver.Evidence as MRSolver import SAWScript.Crucible.Common as Common import qualified Verifier.SAW.Simulator.TermModel as TM import qualified Verifier.SAW.Simulator.What4 as W4Sim @@ -1076,10 +1077,10 @@ data Evidence -- sequent calculus axiom, which connects a hypothesis to a conclusion. | AxiomEvidence - -- | FIXME: This is a placeholder for evidence that will be generated by - -- MRSolver - currently this trivial evidence is given whenever MRSolver - -- completes without error (see 'proveRefinement' in @Builtins.hs@) - | MrSolverEvidence + -- | Evidence generated by running the @mrsolver@ tactic. + -- FIXME: Add a @[Evidence]@ here when MRSolver is updated to support + -- returning unsolved goals. + | MrSolverEvidence !(MRSolver.MREvidence TheoremNonce) -- | The the proposition proved by a given theorem. thmProp :: Theorem -> Prop @@ -1696,9 +1697,16 @@ checkEvidence sc = \e p -> do nenv <- scGetNamingEnv sc ] return (mempty, ProvedTheorem mempty) - MrSolverEvidence -> - -- TODO Fill this in when we have evidence for MrSolver - return (mempty, ProvedTheorem mempty) + MrSolverEvidence mre -> + case sequentState sqt of + ConclFocus _p _mkSqt -> + do (d, stats) <- MRSolver.checkMREvidence mre + -- FIXME: Check that p actually does match the MRSolverEvidence + return (d, ProvedTheorem stats) + _ -> fail $ unlines $ + [ "MRSolver evidence requires a conclusion-focused sequent" + , prettySequent defaultPPOpts nenv sqt + ] CutEvidence p ehyp egl -> do d1 <- check nenv ehyp (addHypothesis p sqt) diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index bc680bdccc..3fb3c94c25 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -9,12 +9,14 @@ Portability : non-portable (language extensions) -} module SAWScript.Prover.MRSolver - (askMRSolver, assumeMRSolver, MRSolverResult, refinementTerm, + (askMRSolver, refinementTerm, MRFailure(..), showMRFailure, showMRFailureNoCtx, - FunAssump(..), FunAssumpRHS(..), - MREnv(..), emptyMREnv, mrEnvAddFunAssump, mrEnvSetDebugLevel, + FunAssump(..), FunAssumpRHS(..), asRefinesS, asFunAssump, + Refnset, emptyRefnset, addFunAssump, + MREnv(..), emptyMREnv, mrEnvSetDebugLevel, asProjAll, isSpecFunType) where import SAWScript.Prover.MRSolver.Term +import SAWScript.Prover.MRSolver.Evidence import SAWScript.Prover.MRSolver.Monad import SAWScript.Prover.MRSolver.Solver diff --git a/src/SAWScript/Prover/MRSolver/Evidence.hs b/src/SAWScript/Prover/MRSolver/Evidence.hs new file mode 100644 index 0000000000..f781b324d1 --- /dev/null +++ b/src/SAWScript/Prover/MRSolver/Evidence.hs @@ -0,0 +1,189 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} + +{- | +Module : SAWScript.Prover.MRSolver.Evidence +Copyright : Galois, Inc. 2023 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module defines multiple outward facing components of MRSolver, most +notably the 'MREvidence' type which provides evidence for the truth of a +refinement proposition proved by MRSolver, and used in @Proof.hs@. This module +also defines the 'MREnv' type, the global MRSolver state. +-} + +module SAWScript.Prover.MRSolver.Evidence where + +import Data.Foldable (foldMap') + +import Data.Map (Map) +import qualified Data.Map as Map + +import Data.HashMap.Lazy (HashMap) +import qualified Data.HashMap.Lazy as HashMap + +import Data.Set (Set) +import qualified Data.Set as Set + +import Verifier.SAW.Term.Functor +import Verifier.SAW.Recognizer +import Verifier.SAW.Cryptol.Monadify +import SAWScript.Prover.SolverStats + +import SAWScript.Prover.MRSolver.Term + + +---------------------------------------------------------------------- +-- * Function Refinement Assumptions +---------------------------------------------------------------------- + +-- | The right-hand-side of a 'FunAssump': either a 'FunName' and arguments, if +-- it is an opaque 'FunAsump', or a 'NormComp', if it is a rewrite 'FunAssump' +data FunAssumpRHS = OpaqueFunAssump FunName [Term] + | RewriteFunAssump Term + +-- | An assumption that a named function refines some specification. This has +-- the form +-- +-- > forall x1, ..., xn. F e1 ... ek |= m +-- +-- for some universal context @x1:T1, .., xn:Tn@, some list of argument +-- expressions @ei@ over the universal @xj@ variables, and some right-hand side +-- computation expression @m@. +data FunAssump t = FunAssump { + -- | The uvars that were in scope when this assumption was created + fassumpCtx :: MRVarCtx, + -- | The function on the left-hand-side + fassumpFun :: FunName, + -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars + fassumpArgs :: [Term], + -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars + fassumpRHS :: FunAssumpRHS, + -- | An optional annotation, which in the case of SAWScript, is always a + -- 'TheoremNonce' indicating from which 'Theorem' this assumption comes + fassumpAnnotation :: Maybe t +} + +-- | Recognizes a term of the form: +-- @(a1:A1) -> ... -> (an:An) -> refinesS_eq ev stack rtp t1 t2@, +-- and returns a tuple: +-- @([(a1,A1), ..., (an,An)], ev, ev, stack, stack, rtp, rtp, t1, t2)@ +asRefinesS :: Recognizer Term ([(LocalName, Term)], Term, Term, Term, Term, Term, Term, Term, Term) +asRefinesS (asPiList -> (args, asApplyAll -> + (asGlobalDef -> Just "Prelude.refinesS", + [ev1, ev2, stack1, stack2, + asApplyAll -> (asGlobalDef -> Just "Prelude.eqPreRel", _), + asApplyAll -> (asGlobalDef -> Just "Prelude.eqPostRel", _), + rtp1, rtp2, + asApplyAll -> (asGlobalDef -> Just "Prelude.eqRR", _), + t1, t2]))) = + Just (args, ev1, ev2, stack1, stack2, rtp1, rtp2, t1, t2) +asRefinesS (asPiList -> (args, asApplyAll -> + (asGlobalDef -> Just "Prelude.refinesS_eq", + [ev, stack, rtp, t1, t2]))) = + Just (args, ev, ev, stack, stack, rtp, rtp, t1, t2) +asRefinesS (asPiList -> (_, asApplyAll -> (asGlobalDef -> Just "Prelude.refinesS", _))) = + error "FIXME: MRSolver does not yet accept refinesS goals with non-trivial RPre/RPost/RR" +asRefinesS _ = Nothing + +-- | Recognizes a term of the form: +-- @(a1:A1) -> ... -> (an:An) -> refinesS_eq ev stack rtp (f b1 ... bm) t2@, +-- and returns: @FunAssump f [a1,...,an] [b1,...,bm] rhs ann@, +-- where @ann@ is the given argument and @rhs@ is either +-- @OpaqueFunAssump g [c1,...,cl]@ if @t2@ is @g c1 ... cl@, +-- or @RewriteFunAssump t2@ otherwise +asFunAssump :: Maybe t -> Recognizer Term (FunAssump t) +asFunAssump ann (asRefinesS -> Just (args, + asGlobalDef -> Just "Prelude.VoidEv", + asGlobalDef -> Just "Prelude.VoidEv", + asGlobalDef -> Just "Prelude.emptyFunStack", + asGlobalDef -> Just "Prelude.emptyFunStack", + _, _, + asApplyAll -> (asGlobalFunName -> Just f1, args1), + t2@(asApplyAll -> (asGlobalFunName -> mb_f2, args2)))) = + let rhs = maybe (RewriteFunAssump t2) (\f2 -> OpaqueFunAssump f2 args2) mb_f2 + in Just $ FunAssump { fassumpCtx = mrVarCtxFromOuterToInner args, + fassumpFun = f1, fassumpArgs = args1, + fassumpRHS = rhs, + fassumpAnnotation = ann } +asFunAssump _ _ = Nothing + + +---------------------------------------------------------------------- +-- * Refinement Sets +---------------------------------------------------------------------- + +-- | A set of refinements whose left-hand-sides are function applications, +-- represented as 'FunAssump's. Internally, a map from the 'VarIndex'es of the +-- LHS functions to 'FunAssump's describing the complete refinement. +type Refnset t = HashMap VarIndex (Map [TermProj] (FunAssump t)) + +-- | The 'Refnset' with no refinements +emptyRefnset :: Refnset t +emptyRefnset = HashMap.empty + +-- | Given a 'FunName' and a 'Refnset', return the 'FunAssump' which has +-- the given 'FunName' as its LHS function, if possible +lookupFunAssump :: FunName -> Refnset t -> Maybe (FunAssump t) +lookupFunAssump (GlobalName (GlobalDef _ ix _ _ _) projs) refSet = + HashMap.lookup ix refSet >>= Map.lookup projs +lookupFunAssump _ _ = Nothing + +-- | Add a 'FunAssump' to a 'Refnset' +addFunAssump :: FunAssump t -> Refnset t -> Refnset t +addFunAssump fa@(fassumpFun -> GlobalName (GlobalDef _ ix _ _ _) projs) = + HashMap.insertWith (\_ -> Map.insert projs fa) ix + (Map.singleton projs fa) +addFunAssump _ = error "Cannot insert a non-global name into a Refnset" + +-- | Return the list of 'FunAssump's in a given 'Refnset' +listFunAssumps :: Refnset t -> [FunAssump t] +listFunAssumps = concatMap Map.elems . HashMap.elems + + +---------------------------------------------------------------------- +-- * Mr Solver Environments +---------------------------------------------------------------------- + +-- | A global MR Solver environment +data MREnv = MREnv { + -- | The debug level, which controls debug printing + mreDebugLevel :: Int +} + +-- | The empty 'MREnv' +emptyMREnv :: MREnv +emptyMREnv = MREnv { mreDebugLevel = 0 } + +-- | Set the debug level of a Mr Solver environment +mrEnvSetDebugLevel :: Int -> MREnv -> MREnv +mrEnvSetDebugLevel dlvl env = env { mreDebugLevel = dlvl } + + +---------------------------------------------------------------------- +-- * Mr Solver Evidence +---------------------------------------------------------------------- + +-- | An entry in 'MREvidence' indicating a usage of an SMT solver or a +-- 'FunAssump' +data MREvidenceEntry t = MREUsedSolver !SolverStats !Term + | MREUsedFunAssump !t + +-- | Records evidence for the truth of a refinement proposition proved by +-- MRSolver. Currently, this is just a list of 'MREvidenceEntry's indicating +-- each instance where MRSolver needed to use an SMT solver or a 'FunAssump'. +-- FIXME: Have this data type actually provide evidence! i.e. have it keep +-- track of each refinement theorem used by MRSolver along the way. +type MREvidence t = [MREvidenceEntry t] + +-- | Verify that the given evidence in fact supports the given refinement +-- proposition. Returns the identifiers of all the theorems depended on while +-- checking evidence. +-- FIXME: Actually take in a refinement to check against! +checkMREvidence :: Ord t => MREvidence t -> IO (Set t, SolverStats) +checkMREvidence = return . foldMap' checkEntry + where checkEntry (MREUsedSolver stats _) = (mempty, stats) + checkEntry (MREUsedFunAssump t) = (Set.singleton t, mempty) diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 05390b1147..fe9cf09dce 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -2,6 +2,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE PatternSynonyms #-} @@ -48,8 +49,10 @@ import Verifier.SAW.SCTypeCheck import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer import Verifier.SAW.Cryptol.Monadify +import SAWScript.Prover.SolverStats import SAWScript.Prover.MRSolver.Term +import SAWScript.Prover.MRSolver.Evidence ---------------------------------------------------------------------- @@ -264,7 +267,7 @@ coIndHypSetArg hyp@(CoIndHyp {..}) (Right i) x = -- | Add a variable to the context of a coinductive hypothesis, returning the -- updated coinductive hypothesis and a 'Term' which is the new variable -coIndHypWithVar :: CoIndHyp -> LocalName -> Type -> MRM (CoIndHyp, Term) +coIndHypWithVar :: CoIndHyp -> LocalName -> Type -> MRM t (CoIndHyp, Term) coIndHypWithVar (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) nm tp = do var <- liftSC1 scLocalVar 0 let ctx' = mrVarCtxAppend (singletonMRVarCtx nm tp) ctx @@ -294,15 +297,17 @@ instance PrettyInCtx CoIndHyp where type DataTypeAssumps = HashMap Term DataTypeAssump -- | Parameters and locals for MR. Solver -data MRInfo = MRInfo { +data MRInfo t = MRInfo { -- | Global shared context for building terms, etc. mriSC :: SharedContext, -- | SMT timeout for SMT calls made by Mr. Solver mriSMTTimeout :: Maybe Integer, - -- | The current context of universal variables - mriUVars :: MRVarCtx, -- | The top-level Mr Solver environment mriEnv :: MREnv, + -- | The set of function refinements to assume + mriRefnset :: Refnset t, + -- | The current context of universal variables + mriUVars :: MRVarCtx, -- | The current set of co-inductive hypotheses mriCoIndHyps :: CoIndHyps, -- | The current assumptions, which are conjoined into a single Boolean term; @@ -313,7 +318,12 @@ data MRInfo = MRInfo { } -- | State maintained by MR. Solver -data MRState = MRState { +data MRState t = MRState { + -- | Cumulative stats on all solver runs made so far + mrsSolverStats :: SolverStats, + -- | The evidence object, which includes information about which + -- 'FunAssump's in 'mriRefnset' have been used so far + mrsEvidence :: MREvidence t, -- | The existential and letrec-bound variables mrsVars :: MRVarMap } @@ -327,89 +337,111 @@ data MRExn = MRExnFailure MRFailure -- | Mr. Monad, the monad used by MR. Solver, which has 'MRInfo' as as a -- shared environment, 'MRState' as state, and 'MRFailure' as an exception -- type, all over an 'IO' monad -newtype MRM a = MRM { unMRM :: ReaderT MRInfo (StateT MRState - (ExceptT MRExn IO)) a } - deriving newtype (Functor, Applicative, Monad, MonadIO, - MonadReader MRInfo, MonadState MRState, - MonadError MRExn) +newtype MRM t a = MRM { unMRM :: ReaderT (MRInfo t) (StateT (MRState t) + (ExceptT MRExn IO)) a } + deriving newtype (Functor, Applicative, Monad, MonadIO, + MonadReader (MRInfo t), MonadState (MRState t), + MonadError MRExn) -instance MonadTerm MRM where +instance MonadTerm (MRM t) where mkTermF = liftSC1 scTermF liftTerm = liftSC3 incVars whnfTerm = liftSC1 scWhnf substTerm = liftSC3 instantiateVarList -- | Get the current value of 'mriSC' -mrSC :: MRM SharedContext +mrSC :: MRM t SharedContext mrSC = mriSC <$> ask -- | Get the current value of 'mriSMTTimeout' -mrSMTTimeout :: MRM (Maybe Integer) +mrSMTTimeout :: MRM t (Maybe Integer) mrSMTTimeout = mriSMTTimeout <$> ask -- | Get the current value of 'mriUVars' -mrUVars :: MRM MRVarCtx +mrUVars :: MRM t MRVarCtx mrUVars = mriUVars <$> ask -- | Get the current function assumptions -mrFunAssumps :: MRM FunAssumps -mrFunAssumps = mreFunAssumps <$> mriEnv <$> ask +mrRefnset :: MRM t (Refnset t) +mrRefnset = mriRefnset <$> ask -- | Get the current value of 'mriCoIndHyps' -mrCoIndHyps :: MRM CoIndHyps +mrCoIndHyps :: MRM t CoIndHyps mrCoIndHyps = mriCoIndHyps <$> ask -- | Get the current value of 'mriAssumptions' -mrAssumptions :: MRM Term +mrAssumptions :: MRM t Term mrAssumptions = mriAssumptions <$> ask -- | Get the current value of 'mriDataTypeAssumps' -mrDataTypeAssumps :: MRM DataTypeAssumps +mrDataTypeAssumps :: MRM t DataTypeAssumps mrDataTypeAssumps = mriDataTypeAssumps <$> ask -- | Get the current debug level -mrDebugLevel :: MRM Int +mrDebugLevel :: MRM t Int mrDebugLevel = mreDebugLevel <$> mriEnv <$> ask -- | Get the current value of 'mriEnv' -mrEnv :: MRM MREnv +mrEnv :: MRM t MREnv mrEnv = mriEnv <$> ask +-- | Get the current value of 'mrsSolverStats' +mrSolverStats :: MRM t SolverStats +mrSolverStats = mrsSolverStats <$> get + +-- | Get the current value of 'mrsEvidence' +mrEvidence :: MRM t (MREvidence t) +mrEvidence = mrsEvidence <$> get + -- | Get the current value of 'mrsVars' -mrVars :: MRM MRVarMap +mrVars :: MRM t MRVarMap mrVars = mrsVars <$> get --- | Run an 'MRM' computation and return a result or an error -runMRM :: SharedContext -> Maybe Integer -> MREnv -> - MRM a -> IO (Either MRFailure a) -runMRM sc timeout env m = +-- | Run an 'MRM' computation and return a result or an error, including the +-- final state of 'mrsSolverStats' and 'mrsEvidence' +runMRM :: SharedContext -> Maybe Integer -> MREnv -> Refnset t -> + MRM t a -> IO (Either MRFailure (a, (SolverStats, MREvidence t))) +runMRM sc timeout env rs m = do true_tm <- scBool sc True let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, - mriEnv = env, + mriEnv = env, mriRefnset = rs, mriUVars = emptyMRVarCtx, mriCoIndHyps = Map.empty, mriAssumptions = true_tm, mriDataTypeAssumps = HashMap.empty } - let init_st = MRState { mrsVars = Map.empty } - res <- runExceptT $ flip evalStateT init_st $ + let init_st = MRState { mrsSolverStats = mempty, mrsEvidence = mempty, + mrsVars = Map.empty } + res <- runExceptT $ flip runStateT init_st $ flip runReaderT init_info $ unMRM m case res of - Right a -> return $ Right a + Right (a, st) -> return $ Right (a, (mrsSolverStats st, mrsEvidence st)) Left (MRExnFailure failure) -> return $ Left failure Left exn -> fail ("runMRM: unexpected internal exception: " ++ show exn) +-- | Run an 'MRM' computation and return a result or an error, discarding the +-- final state +evalMRM :: SharedContext -> Maybe Integer -> MREnv -> Refnset t -> + MRM t a -> IO (Either MRFailure a) +evalMRM sc timeout env rs = fmap (fmap fst) . runMRM sc timeout env rs + +-- | Run an 'MRM' computation and return a final state or an error, discarding +-- the result +execMRM :: SharedContext -> Maybe Integer -> MREnv -> Refnset t -> + MRM t a -> IO (Either MRFailure (SolverStats, MREvidence t)) +execMRM sc timeout env rs = fmap (fmap snd) . runMRM sc timeout env rs + -- | Throw an 'MRFailure' -throwMRFailure :: MRFailure -> MRM a +throwMRFailure :: MRFailure -> MRM t a throwMRFailure = throwError . MRExnFailure -- | Apply a function to any failure thrown by an 'MRM' computation -mapMRFailure :: (MRFailure -> MRFailure) -> MRM a -> MRM a +mapMRFailure :: (MRFailure -> MRFailure) -> MRM t a -> MRM t a mapMRFailure f m = catchError m $ \case MRExnFailure failure -> throwError $ MRExnFailure $ f failure e -> throwError e -- | Catch any 'MRFailure' raised by a computation -catchFailure :: MRM a -> (MRFailure -> MRM a) -> MRM a +catchFailure :: MRM t a -> (MRFailure -> MRM t a) -> MRM t a catchFailure m f = m `catchError` \case MRExnFailure failure -> f failure @@ -417,14 +449,14 @@ catchFailure m f = -- | Try two different 'MRM' computations, combining their failures if needed. -- Note that the 'MRState' will reset if the first computation fails. -mrOr :: MRM a -> MRM a -> MRM a +mrOr :: MRM t a -> MRM t a -> MRM t a mrOr m1 m2 = catchFailure m1 $ \err1 -> catchFailure m2 $ \err2 -> throwMRFailure $ MRFailureDisj err1 err2 -- | Run an 'MRM' computation in an extended failure context -withFailureCtx :: FailCtx -> MRM a -> MRM a +withFailureCtx :: FailCtx -> MRM t a -> MRM t a withFailureCtx ctx = mapMRFailure (MRFailureCtx ctx) {- @@ -437,29 +469,29 @@ catchErrorEither m = catchError (Right <$> m) (return . Left) -- typeclass like LiftTCM -- | Lift a nullary SharedTerm computation into 'MRM' -liftSC0 :: (SharedContext -> IO a) -> MRM a +liftSC0 :: (SharedContext -> IO a) -> MRM t a liftSC0 f = mrSC >>= \sc -> liftIO (f sc) -- | Lift a unary SharedTerm computation into 'MRM' -liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM b +liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM t b liftSC1 f a = mrSC >>= \sc -> liftIO (f sc a) -- | Lift a binary SharedTerm computation into 'MRM' -liftSC2 :: (SharedContext -> a -> b -> IO c) -> a -> b -> MRM c +liftSC2 :: (SharedContext -> a -> b -> IO c) -> a -> b -> MRM t c liftSC2 f a b = mrSC >>= \sc -> liftIO (f sc a b) -- | Lift a ternary SharedTerm computation into 'MRM' -liftSC3 :: (SharedContext -> a -> b -> c -> IO d) -> a -> b -> c -> MRM d +liftSC3 :: (SharedContext -> a -> b -> c -> IO d) -> a -> b -> c -> MRM t d liftSC3 f a b c = mrSC >>= \sc -> liftIO (f sc a b c) -- | Lift a quaternary SharedTerm computation into 'MRM' liftSC4 :: (SharedContext -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> - MRM e + MRM t e liftSC4 f a b c d = mrSC >>= \sc -> liftIO (f sc a b c d) -- | Lift a quinary SharedTerm computation into 'MRM' liftSC5 :: (SharedContext -> a -> b -> c -> d -> e -> IO f) -> - a -> b -> c -> d -> e -> MRM f + a -> b -> c -> d -> e -> MRM t f liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) @@ -468,25 +500,25 @@ liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) ---------------------------------------------------------------------- -- | Create a term representing the type @IsFinite n@ -mrIsFinite :: Term -> MRM Term +mrIsFinite :: Term -> MRM t Term mrIsFinite n = liftSC2 scGlobalApply "CryptolM.isFinite" [n] -- | Create a term representing an application of @Prelude.error@ -mrErrorTerm :: Term -> T.Text -> MRM Term +mrErrorTerm :: Term -> T.Text -> MRM t Term mrErrorTerm a str = do err_str <- liftSC1 scString str liftSC2 scGlobalApply "Prelude.error" [a, err_str] -- | Create a term representing an application of @Prelude.genBVVecFromVec@, -- where the default value argument is @Prelude.error@ of the given 'T.Text' -mrGenBVVecFromVec :: Term -> Term -> Term -> T.Text -> Term -> Term -> MRM Term +mrGenBVVecFromVec :: Term -> Term -> Term -> T.Text -> Term -> Term -> MRM t Term mrGenBVVecFromVec m a v def_err_str n len = do err_tm <- mrErrorTerm a def_err_str liftSC2 scGlobalApply "Prelude.genBVVecFromVec" [m, a, v, err_tm, n, len] -- | Create a term representing an application of @Prelude.genFromBVVec@, -- where the default value argument is @Prelude.error@ of the given 'T.Text' -mrGenFromBVVec :: Term -> Term -> Term -> Term -> T.Text -> Term -> MRM Term +mrGenFromBVVec :: Term -> Term -> Term -> Term -> T.Text -> Term -> MRM t Term mrGenFromBVVec n len a v def_err_str m = do err_tm <- mrErrorTerm a def_err_str liftSC2 scGlobalApply "Prelude.genFromBVVec" [n, len, a, v, err_tm, m] @@ -497,7 +529,7 @@ mrGenFromBVVec n len a v def_err_str m = ---------------------------------------------------------------------- -- | Apply a 'TermProj' to perform a projection on a 'Term' -doTermProj :: Term -> TermProj -> MRM Term +doTermProj :: Term -> TermProj -> MRM t Term doTermProj (asPairValue -> Just (t, _)) TermProjLeft = return t doTermProj (asPairValue -> Just (_, t)) TermProjRight = return t doTermProj (asRecordValue -> Just t_map) (TermProjRecord fld) @@ -508,7 +540,7 @@ doTermProj t (TermProjRecord fld) = liftSC2 scRecordSelect t fld -- | Apply a 'TermProj' to a type to get the output type of the projection, -- assuming that the type is already normalized -doTypeProj :: Term -> TermProj -> MRM Term +doTypeProj :: Term -> TermProj -> MRM t Term doTypeProj (asPairType -> Just (tp1, _)) TermProjLeft = return tp1 doTypeProj (asPairType -> Just (_, tp2)) TermProjRight = return tp2 doTypeProj (asRecordType -> Just tp_map) (TermProjRecord fld) @@ -520,7 +552,7 @@ doTypeProj _ _ = error "doTypeProj" -- | Get and normalize the type of a 'FunName' -funNameType :: FunName -> MRM Term +funNameType :: FunName -> MRM t Term funNameType (CallSName var) = liftSC1 scWhnf $ mrVarType var funNameType (EVarFunName var) = liftSC1 scWhnf $ mrVarType var funNameType (GlobalName gd projs) = @@ -528,28 +560,28 @@ funNameType (GlobalName gd projs) = foldM doTypeProj gd_tp projs -- | Apply a 'Term' to a list of arguments and beta-reduce in Mr. Monad -mrApplyAll :: Term -> [Term] -> MRM Term +mrApplyAll :: Term -> [Term] -> MRM t Term mrApplyAll f args = liftSC2 scApplyAllBeta f args -- | Apply a 'Term' to a single argument and beta-reduce in Mr. Monad -mrApply :: Term -> Term -> MRM Term +mrApply :: Term -> Term -> MRM t Term mrApply f arg = mrApplyAll f [arg] -- | Return the unit type as a 'Type' -mrUnitType :: MRM Type +mrUnitType :: MRM t Type mrUnitType = Type <$> liftSC0 scUnitType -- | Build a constructor application in Mr. Monad -mrCtorApp :: Ident -> [Term] -> MRM Term +mrCtorApp :: Ident -> [Term] -> MRM t Term mrCtorApp = liftSC2 scCtorApp -- | Build a 'Term' for a global in Mr. Monad -mrGlobalTerm :: Ident -> MRM Term +mrGlobalTerm :: Ident -> MRM t Term mrGlobalTerm = liftSC1 scGlobalDef -- | Like 'scBvConst', but if given a bitvector literal it is converted to a -- natural number literal -mrBvToNat :: Term -> Term -> MRM Term +mrBvToNat :: Term -> Term -> MRM t Term mrBvToNat _ (asArrayValue -> Just (asBoolType -> Just _, mapM asBool -> Just bits)) = liftSC1 scNat $ foldl' (\n bit -> if bit then 2*n+1 else 2*n) 0 bits @@ -558,30 +590,30 @@ mrBvToNat n len = liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in -- the order as seen "from the outside" -mrUVarsOuterToInner :: MRM [(LocalName,Term)] +mrUVarsOuterToInner :: MRM t [(LocalName,Term)] mrUVarsOuterToInner = mrVarCtxOuterToInner <$> mrUVars -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the most recently bound uvar first, i.e., in -- the order as seen "from the inside" -mrUVarsInnerToOuter :: MRM [(LocalName,Term)] +mrUVarsInnerToOuter :: MRM t [(LocalName,Term)] mrUVarsInnerToOuter = mrVarCtxInnerToOuter <$> mrUVars -- | Get the type of a 'Term' in the current uvar context -mrTypeOf :: Term -> MRM Term +mrTypeOf :: Term -> MRM t Term mrTypeOf t = -- NOTE: scTypeOf' wants the type context in the most recently bound var first -- mrDebugPPPrefix 3 "mrTypeOf:" t >> mrUVarsInnerToOuter >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t -- | Check if two 'Term's are convertible in the 'MRM' monad -mrConvertible :: Term -> Term -> MRM Bool +mrConvertible :: Term -> Term -> MRM t Bool mrConvertible = liftSC4 scConvertibleEval scTypeCheckWHNF True -- | Take a 'FunName' @f@ for a monadic function of type @vars -> SpecM a@ and -- compute the type @SpecM [args/vars]a@ of @f@ applied to @args@. Return the -- type @[args/vars]a@ that @SpecM@ is applied to, along with its parameters. -mrFunOutType :: FunName -> [Term] -> MRM (SpecMParams Term, Term) +mrFunOutType :: FunName -> [Term] -> MRM t (SpecMParams Term, Term) mrFunOutType fname args = mrApplyAll (funNameTerm fname) args >>= mrTypeOf >>= \case (asSpecM -> Just (params, tp)) -> (params,) <$> liftSC1 scWhnf tp @@ -611,7 +643,7 @@ uniquifyNames (nm:nms) nms_other = -- | Build a lambda term with the lifting (in the sense of 'incVars') of an -- MR Solver term mrLambdaLift :: TermLike tm => [(LocalName,Term)] -> tm -> - ([Term] -> tm -> MRM Term) -> MRM Term + ([Term] -> tm -> MRM t Term) -> MRM t Term mrLambdaLift [] t f = f [] t mrLambdaLift ctx t f = do -- uniquifyNames doesn't care about the order of the names in its second, @@ -629,7 +661,7 @@ mrLambdaLift ctx t f = -- | Call 'mrLambdaLift' with exactly one 'Term' argument. mrLambdaLift1 :: TermLike tm => (LocalName,Term) -> tm -> - (Term -> tm -> MRM Term) -> MRM Term + (Term -> tm -> MRM t Term) -> MRM t Term mrLambdaLift1 ctx t f = mrLambdaLift [ctx] t $ \vars t' -> case vars of @@ -638,7 +670,7 @@ mrLambdaLift1 ctx t f = -- | Call 'mrLambdaLift' with exactly two 'Term' arguments. mrLambdaLift2 :: TermLike tm => (LocalName,Term) -> (LocalName,Term) -> tm -> - (Term -> Term -> tm -> MRM Term) -> MRM Term + (Term -> Term -> tm -> MRM t Term) -> MRM t Term mrLambdaLift2 ctx1 ctx2 t f = mrLambdaLift [ctx1, ctx2] t $ \vars t' -> case vars of @@ -648,7 +680,7 @@ mrLambdaLift2 ctx1 ctx2 t f = -- | Run a MR Solver computation in a context extended with a universal -- variable, which is passed as a 'Term' to the sub-computation. Note that any -- assumptions made in the sub-computation will be lost when it completes. -withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a +withUVar :: LocalName -> Type -> (Term -> MRM t a) -> MRM t a withUVar nm tp m = withUVars (singletonMRVarCtx nm tp) $ \case [v] -> m v _ -> error "withUVar: impossible" @@ -656,13 +688,13 @@ withUVar nm tp m = withUVars (singletonMRVarCtx nm tp) $ \case -- | Run a MR Solver computation in a context extended with a universal variable -- and pass it the lifting (in the sense of 'incVars') of an MR Solver term withUVarLift :: TermLike tm => LocalName -> Type -> tm -> - (Term -> tm -> MRM a) -> MRM a + (Term -> tm -> MRM t a) -> MRM t a withUVarLift nm tp t m = withUVar nm tp (\x -> liftTermLike 0 1 t >>= m x) -- | Run a MR Solver computation in a context extended with a list of universal -- variables, passing 'Term's for those variables to the supplied computation. -withUVars :: MRVarCtx -> ([Term] -> MRM a) -> MRM a +withUVars :: MRVarCtx -> ([Term] -> MRM t a) -> MRM t a withUVars (mrVarCtxLength -> 0) f = f [] withUVars ctx f = do -- for uniquifyNames, we want to consider the oldest names first, thus we @@ -687,7 +719,7 @@ withUVars ctx f = foldr (\nm m -> mapMRFailure (MRFailureLocalVar nm) m) (f vars) nms -- | Run a MR Solver in a top-level context, i.e., with no uvars or assumptions -withNoUVars :: MRM a -> MRM a +withNoUVars :: MRM t a -> MRM t a withNoUVars m = do true_tm <- liftSC1 scBool True local (\info -> info { mriUVars = emptyMRVarCtx, mriAssumptions = true_tm, @@ -695,35 +727,35 @@ withNoUVars m = -- | Run a MR Solver in a context of only the specified UVars, no others - -- note that this also clears all assumptions -withOnlyUVars :: MRVarCtx -> MRM a -> MRM a +withOnlyUVars :: MRVarCtx -> MRM t a -> MRM t a withOnlyUVars vars m = withNoUVars $ withUVars vars $ const m -- | Build 'Term's for all the uvars currently in scope, ordered from least to -- most recently bound -getAllUVarTerms :: MRM [Term] +getAllUVarTerms :: MRM t [Term] getAllUVarTerms = (mrVarCtxLength <$> mrUVars) >>= \len -> mapM (liftSC1 scLocalVar) [len-1, len-2 .. 0] -- | Lambda-abstract all the current uvars out of a 'Term', with the least -- recently bound variable being abstracted first -lambdaUVarsM :: Term -> MRM Term +lambdaUVarsM :: Term -> MRM t Term lambdaUVarsM t = mrUVarsOuterToInner >>= \ctx -> liftSC2 scLambdaList ctx t -- | Pi-abstract all the current uvars out of a 'Term', with the least recently -- bound variable being abstracted first -piUVarsM :: Term -> MRM Term +piUVarsM :: Term -> MRM t Term piUVarsM t = mrUVarsOuterToInner >>= \ctx -> liftSC2 scPiList ctx t -- | Instantiate all uvars in a term using the supplied function -instantiateUVarsM :: TermLike a => (LocalName -> Term -> MRM Term) -> a -> MRM a +instantiateUVarsM :: forall a t. TermLike a => (LocalName -> Term -> MRM t Term) -> a -> MRM t a instantiateUVarsM f a = do ctx <- mrUVarsOuterToInner -- Remember: the uvar context is outermost to innermost, so we bind -- variables from left to right, substituting earlier ones into the types -- of later ones, but all substitutions are in reverse order, since -- substTerm and friends like innermost bindings first - let helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] + let helper :: [Term] -> [(LocalName,Term)] -> MRM t [Term] helper tms [] = return tms helper tms ((nm,tp):vars) = do tp' <- substTerm 0 tms tp @@ -733,7 +765,7 @@ instantiateUVarsM f a = substTermLike 0 ecs a -- | Convert an 'MRVar' to a 'Term', applying it to all the uvars in scope -mrVarTerm :: MRVar -> MRM Term +mrVarTerm :: MRVar -> MRM t Term mrVarTerm (MRVar ec) = do var_tm <- liftSC1 scExtCns ec vars <- getAllUVarTerms @@ -743,15 +775,15 @@ mrVarTerm (MRVar ec) = -- should be of @Prop@ sort, by creating an 'ExtCns' axiom. This is sound as -- long as we only use the resulting term in computation branches where we know -- the proposition holds. -mrDummyProof :: Term -> MRM Term +mrDummyProof :: Term -> MRM t Term mrDummyProof tp = mrFreshVar "pf" tp >>= mrVarTerm -- | Get the 'VarInfo' associated with a 'MRVar' -mrVarInfo :: MRVar -> MRM (Maybe MRVarInfo) +mrVarInfo :: MRVar -> MRM t (Maybe MRVarInfo) mrVarInfo var = Map.lookup var <$> mrVars -- | Convert an 'ExtCns' to a 'FunName' -extCnsToFunName :: ExtCns Term -> MRM FunName +extCnsToFunName :: ExtCns Term -> MRM t FunName extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case Just (EVarInfo _) -> return $ EVarFunName var Just (CallVarInfo _) -> return $ CallSName var @@ -761,19 +793,19 @@ extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case _ -> error "extCnsToFunName: unreachable" -- | Get the 'FunName' of a global definition -mrGlobalDef :: Ident -> MRM FunName +mrGlobalDef :: Ident -> MRM t FunName mrGlobalDef ident = asTypedGlobalDef <$> liftSC1 scGlobalDef ident >>= \case Just glob -> return $ GlobalName glob [] _ -> error $ "mrGlobalDef: could not get GlobalDef of: " ++ show ident -- | Get the body of a global definition, raising an 'error' if none is found -mrGlobalDefBody :: Ident -> MRM Term +mrGlobalDefBody :: Ident -> MRM t Term mrGlobalDefBody ident = asConstant <$> liftSC1 scGlobalDef ident >>= \case Just (_, Just body) -> return body _ -> error $ "mrGlobalDefBody: global has no definition: " ++ show ident -- | Get the body of a function @f@ if it has one -mrFunNameBody :: FunName -> MRM (Maybe Term) +mrFunNameBody :: FunName -> MRM t (Maybe Term) mrFunNameBody (CallSName var) = mrVarInfo var >>= \case Just (CallVarInfo body) -> return $ Just body @@ -785,7 +817,7 @@ mrFunNameBody (GlobalName _ _) = return Nothing mrFunNameBody (EVarFunName _) = return Nothing -- | Get the body of a function @f@ applied to some arguments, if possible -mrFunBody :: FunName -> [Term] -> MRM (Maybe Term) +mrFunBody :: FunName -> [Term] -> MRM t (Maybe Term) mrFunBody f args = mrFunNameBody f >>= \case Just body -> Just <$> mrApplyAll body args Nothing -> return Nothing @@ -793,7 +825,7 @@ mrFunBody f args = mrFunNameBody f >>= \case -- | Get the body of a function @f@ applied to some arguments, as per -- 'mrFunBody', and also return whether its body recursively calls itself, as -- per 'mrCallsFun' -mrFunBodyRecInfo :: FunName -> [Term] -> MRM (Maybe (Term, Bool)) +mrFunBodyRecInfo :: FunName -> [Term] -> MRM t (Maybe (Term, Bool)) mrFunBodyRecInfo f args = mrFunBody f args >>= \case Just f_body -> Just <$> (f_body,) <$> mrCallsFun f f_body @@ -801,7 +833,7 @@ mrFunBodyRecInfo f args = -- | Test if a 'Term' contains, after possibly unfolding some functions, a call -- to a given function @f@ again -mrCallsFun :: FunName -> Term -> MRM Bool +mrCallsFun :: FunName -> Term -> MRM t Bool mrCallsFun f = memoFixTermFun $ \recurse t -> case t of (asExtCns -> Just ec) -> do g <- extCnsToFunName ec @@ -822,7 +854,7 @@ mrCallsFun f = memoFixTermFun $ \recurse t -> case t of -- equality representing the proposition that the 'DataTypeAssump' holds. -- For example, @mrDataTypeAssumpTerm x (IsLeft y)@ for @x : Either a b@ -- would return @Eq (Either a b) x (Left a b y)@. -mrDataTypeAssumpTerm :: Term -> DataTypeAssump -> MRM Term +mrDataTypeAssumpTerm :: Term -> DataTypeAssump -> MRM t Term mrDataTypeAssumpTerm x dt = do tp <- mrTypeOf x y <- case dt of @@ -841,23 +873,25 @@ mrDataTypeAssumpTerm x dt = -- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of the -- given 'Term's, after quantifying over all current 'mrUVars' with Pi types -- and adding calls to @assertS@ on the right hand side for any current --- 'mrAssumps' and/or 'mrDataTypeAssump's -mrRefinementGoal :: Term -> Term -> MRM Term -mrRefinementGoal t1 t2 = +-- 'mrAssumps' and/or 'mrDataTypeAssump's if the given 'Bool' is 'True' +mrRefinementTerm :: Bool -> Term -> Term -> MRM t Term +mrRefinementTerm includeAssumps t1 t2 = do (SpecMParams ev1 stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1 (SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2 assumps <- mrAssumptions assumpsAssert <- liftSC2 scGlobalApply "Prelude.assertBoolS" [ev2, stack2, assumps] - t2' <- case asBool assumps of - Just True -> return t2 - _ -> bindConst ev2 stack2 tp2 assumpsAssert t2 + t2' <- if includeAssumps && asBool assumps /= Just True + then bindConst ev2 stack2 tp2 assumpsAssert t2 + else return t2 dtAssumps <- HashMap.toList <$> mrDataTypeAssumps dtAssumpAsserts <- forM dtAssumps $ \(nm, assump) -> do assump_tm <- mrDataTypeAssumpTerm nm assump liftSC2 scGlobalApply "Prelude.assertS" [ev2, stack2, assump_tm] - t2'' <- foldrM (bindConst ev2 stack2 tp2) t2' dtAssumpAsserts + t2'' <- if includeAssumps + then foldrM (bindConst ev2 stack2 tp2) t2' dtAssumpAsserts + else return t2' coIndHyps <- mrCoIndHyps (rpre, rpost, rr) <- if Map.null coIndHyps @@ -883,16 +917,16 @@ mrRefinementGoal t1 t2 = -- | Make a fresh 'MRVar' of a given type, which must be closed, i.e., have no -- free uvars -mrFreshVarCl :: LocalName -> Term -> MRM MRVar +mrFreshVarCl :: LocalName -> Term -> MRM t MRVar mrFreshVarCl nm tp = MRVar <$> liftSC2 scFreshEC nm tp -- | Make a fresh 'MRVar' of type @(u1:tp1) -> ... (un:tpn) -> tp@, where the -- @ui@ are all the current uvars -mrFreshVar :: LocalName -> Term -> MRM MRVar +mrFreshVar :: LocalName -> Term -> MRM t MRVar mrFreshVar nm tp = piUVarsM tp >>= mrFreshVarCl nm -- | Set the info associated with an 'MRVar', assuming it has not been set -mrSetVarInfo :: MRVar -> MRVarInfo -> MRM () +mrSetVarInfo :: MRVar -> MRVarInfo -> MRM t () mrSetVarInfo var info = debugPretty 3 ("mrSetVarInfo" <+> ppInEmptyCtx var <+> "=" <+> ppInEmptyCtx info) >> (modify $ \st -> @@ -904,7 +938,7 @@ mrSetVarInfo var info = -- | Make a fresh existential variable of the given type, abstracting out all -- the current uvars and returning the new evar applied to all current uvars -mrFreshEVar :: LocalName -> Type -> MRM Term +mrFreshEVar :: LocalName -> Type -> MRM t Term mrFreshEVar nm (Type tp) = do var <- mrFreshVar nm tp mrSetVarInfo var (EVarInfo Nothing) @@ -912,21 +946,21 @@ mrFreshEVar nm (Type tp) = -- | Return a fresh sequence of existential variables from a 'MRVarCtx'. -- Return the new evars all applied to the current uvars. -mrFreshEVars :: MRVarCtx -> MRM [Term] +mrFreshEVars :: MRVarCtx -> MRM t [Term] mrFreshEVars = helper [] . mrVarCtxOuterToInner where -- Return fresh evars for the suffix of a context of variable names and types, -- where the supplied Terms are evars that have already been generated for the -- earlier part of the context, and so must be substituted into the remaining -- types in the context. Since we want to make fresh evars for the oldest -- variables first, the second argument must be in outer-to-inner order. - helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] + helper :: [Term] -> [(LocalName,Term)] -> MRM t [Term] helper evars [] = return evars helper evars ((nm,tp):ctx) = do evar <- substTerm 0 evars tp >>= mrFreshEVar nm . Type helper (evar:evars) ctx -- | Set the value of an evar to a closed term -mrSetEVarClosed :: MRVar -> Term -> MRM () +mrSetEVarClosed :: MRVar -> Term -> MRM t () mrSetEVarClosed var val = do val_tp <- mrTypeOf val -- NOTE: need to instantiate any evars in the type of var, to ensure the @@ -957,7 +991,7 @@ mrSetEVarClosed var val = -- expression @e@ by trying to set @X@ to @\ x1 ... xn -> e@. This only works if -- each free uvar @xi@ in @e@ is one of the arguments @ej@ to @X@ (though it -- need not be the case that @i=j@). Return whether this succeeded. -mrTrySetAppliedEVar :: MRVar -> [Term] -> Term -> MRM Bool +mrTrySetAppliedEVar :: MRVar -> [Term] -> Term -> MRM t Bool mrTrySetAppliedEVar evar args t = -- Get the complete list of argument variables of the type of evar let (evar_vars, _) = asPiList (mrVarType evar) in @@ -998,7 +1032,7 @@ mrTrySetAppliedEVar evar args t = -- | Replace all evars in a 'Term' with their instantiations when they have one -mrSubstEVars :: Term -> MRM Term +mrSubstEVars :: Term -> MRM t Term mrSubstEVars = memoFixTermFun $ \recurse t -> do var_map <- mrVars case t of @@ -1010,7 +1044,7 @@ mrSubstEVars = memoFixTermFun $ \recurse t -> -- | Replace all evars in a 'Term' with their instantiations, returning -- 'Nothing' if we hit an uninstantiated evar -mrSubstEVarsStrict :: Term -> MRM (Maybe Term) +mrSubstEVarsStrict :: Term -> MRM t (Maybe Term) mrSubstEVarsStrict top_t = runMaybeT $ flip memoFixTermFun top_t $ \recurse t -> do var_map <- lift mrVars @@ -1025,15 +1059,15 @@ mrSubstEVarsStrict top_t = _ -> traverseSubterms recurse t -- | Makes 'mrSubstEVarsStrict' be marked as used -_mrSubstEVarsStrict :: Term -> MRM (Maybe Term) +_mrSubstEVarsStrict :: Term -> MRM t (Maybe Term) _mrSubstEVarsStrict = mrSubstEVarsStrict -- | Get the 'CoIndHyp' for a pair of 'FunName's, if there is one -mrGetCoIndHyp :: FunName -> FunName -> MRM (Maybe CoIndHyp) +mrGetCoIndHyp :: FunName -> FunName -> MRM t (Maybe CoIndHyp) mrGetCoIndHyp nm1 nm2 = Map.lookup (nm1, nm2) <$> mrCoIndHyps -- | Run a compuation under an additional co-inductive assumption -withCoIndHyp :: CoIndHyp -> MRM a -> MRM a +withCoIndHyp :: CoIndHyp -> MRM t a -> MRM t a withCoIndHyp hyp m = do debugPretty 2 ("withCoIndHyp" <+> ppInEmptyCtx hyp) hyps' <- Map.insert (coIndHypLHSFun hyp, @@ -1042,7 +1076,7 @@ withCoIndHyp hyp m = -- | Generate fresh evars for the context of a 'CoIndHyp' and -- substitute them into its arguments and right-hand side -instantiateCoIndHyp :: CoIndHyp -> MRM ([Term], [Term]) +instantiateCoIndHyp :: CoIndHyp -> MRM t ([Term], [Term]) instantiateCoIndHyp (CoIndHyp {..}) = do evars <- mrFreshEVars coIndHypCtx lhs <- substTermLike 0 evars coIndHypLHS @@ -1052,9 +1086,9 @@ instantiateCoIndHyp (CoIndHyp {..}) = -- | Apply the invariants of a 'CoIndHyp' to their respective arguments, -- yielding @Bool@ conditions, using the constant @True@ value when an -- invariant is absent -applyCoIndHypInvariants :: CoIndHyp -> MRM (Term, Term) +applyCoIndHypInvariants :: CoIndHyp -> MRM t (Term, Term) applyCoIndHypInvariants hyp = - let apply_invariant :: Maybe Term -> [Term] -> MRM Term + let apply_invariant :: Maybe Term -> [Term] -> MRM t Term apply_invariant (Just (asLambdaList -> (vars, phi))) args | length vars == length args -- NOTE: applying to a list of arguments == substituting the reverse @@ -1069,23 +1103,21 @@ applyCoIndHypInvariants hyp = return (invar1, invar2) -- | Look up the 'FunAssump' for a 'FunName', if there is one -mrGetFunAssump :: FunName -> MRM (Maybe FunAssump) -mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps +mrGetFunAssump :: FunName -> MRM t (Maybe (FunAssump t)) +mrGetFunAssump nm = lookupFunAssump nm <$> mrRefnset -- | Run a computation under the additional assumption that a named function -- applied to a list of arguments refines a given right-hand side, all of which -- are 'Term's that can have the current uvars free -withFunAssump :: FunName -> [Term] -> NormComp -> MRM a -> MRM a +withFunAssump :: FunName -> [Term] -> Term -> MRM t a -> MRM t a withFunAssump fname args rhs m = do k <- mkCompFunReturn <$> mrFunOutType fname args mrDebugPPPrefixSep 1 "withFunAssump" (FunBind fname args k) "|=" rhs ctx <- mrUVars - assumps <- mrFunAssumps - let assump = FunAssump ctx args (RewriteFunAssump rhs) - let assumps' = Map.insert fname assump assumps - local (\info -> - let env' = (mriEnv info) { mreFunAssumps = assumps' } in - info { mriEnv = env' }) m + rs <- mrRefnset + let assump = FunAssump ctx fname args (RewriteFunAssump rhs) Nothing + let rs' = addFunAssump assump rs + local (\info -> info { mriRefnset = rs' }) m -- | Get the invariant hint associated with a function name, by unfolding the -- name and checking if its body has the form @@ -1095,14 +1127,14 @@ withFunAssump fname args rhs m = -- If so, return @\ x1 ... xn -> phi@ as a term with the @xi@ variables free. -- Otherwise, return 'Nothing'. Note that this function will also look past -- any initial @bindM ... (assertFiniteM ...)@ applications. -mrGetInvariant :: FunName -> MRM (Maybe Term) +mrGetInvariant :: FunName -> MRM t (Maybe Term) mrGetInvariant nm = mrFunNameBody nm >>= \case Just body -> mrGetInvariantBody body _ -> return Nothing -- | The main loop of 'mrGetInvariant', which operates on a function body -mrGetInvariantBody :: Term -> MRM (Maybe Term) +mrGetInvariantBody :: Term -> MRM t (Maybe Term) mrGetInvariantBody tm = case asApplyAll tm of -- go inside any top-level lambdas (asLambda -> Just (nm, tp, body), []) -> @@ -1129,7 +1161,7 @@ mrGetInvariantBody tm = case asApplyAll tm of -- | Add an assumption of type @Bool@ to the current path condition while -- executing a sub-computation -withAssumption :: Term -> MRM a -> MRM a +withAssumption :: Term -> MRM t a -> MRM t a withAssumption phi m = do mrDebugPPPrefix 1 "withAssumption" phi assumps <- mrAssumptions @@ -1137,28 +1169,34 @@ withAssumption phi m = local (\info -> info { mriAssumptions = assumps' }) m -- | Remove any existing assumptions and replace them with a Boolean term -withOnlyAssumption :: Term -> MRM a -> MRM a +withOnlyAssumption :: Term -> MRM t a -> MRM t a withOnlyAssumption phi m = do mrDebugPPPrefix 1 "withOnlyAssumption" phi local (\info -> info { mriAssumptions = phi }) m -- | Add a 'DataTypeAssump' to the current context while executing a -- sub-computations -withDataTypeAssump :: Term -> DataTypeAssump -> MRM a -> MRM a +withDataTypeAssump :: Term -> DataTypeAssump -> MRM t a -> MRM t a withDataTypeAssump x assump m = do mrDebugPPPrefixSep 1 "withDataTypeAssump" x "==" assump dataTypeAssumps' <- HashMap.insert x assump <$> mrDataTypeAssumps local (\info -> info { mriDataTypeAssumps = dataTypeAssumps' }) m -- | Get the 'DataTypeAssump' associated to the given term, if one exists -mrGetDataTypeAssump :: Term -> MRM (Maybe DataTypeAssump) +mrGetDataTypeAssump :: Term -> MRM t (Maybe DataTypeAssump) mrGetDataTypeAssump x = HashMap.lookup x <$> mrDataTypeAssumps --- | Convert a 'FunAssumpRHS' to a 'NormComp' -mrFunAssumpRHSAsNormComp :: FunAssumpRHS -> MRM NormComp -mrFunAssumpRHSAsNormComp (OpaqueFunAssump f args) = - FunBind f args <$> mkCompFunReturn <$> mrFunOutType f args -mrFunAssumpRHSAsNormComp (RewriteFunAssump rhs) = return rhs +-- | Record a use of an SMT solver (for tracking 'SolverStats' and 'MRSolverEvidence') +recordUsedSolver :: SolverStats -> Term -> MRM t () +recordUsedSolver stats prop = + modify $ \st -> st { mrsSolverStats = stats <> mrsSolverStats st, + mrsEvidence = MREUsedSolver stats prop : mrsEvidence st } + +-- | Record a use of a 'FunAssump' (for 'MRSolverEvidence') +recordUsedFunAssump :: FunAssump t -> MRM t () +recordUsedFunAssump (fassumpAnnotation -> Just t) = + modify $ \st -> st { mrsEvidence = MREUsedFunAssump t : mrsEvidence st } +recordUsedFunAssump _ = return () ---------------------------------------------------------------------- @@ -1166,27 +1204,27 @@ mrFunAssumpRHSAsNormComp (RewriteFunAssump rhs) = return rhs ---------------------------------------------------------------------- -- | Print a 'String' if the debug level is at least the supplied 'Int' -debugPrint :: Int -> String -> MRM () +debugPrint :: Int -> String -> MRM t () debugPrint i str = mrDebugLevel >>= \lvl -> if lvl >= i then liftIO (hPutStrLn stderr str) else return () -- | Print a document if the debug level is at least the supplied 'Int' -debugPretty :: Int -> SawDoc -> MRM () +debugPretty :: Int -> SawDoc -> MRM t () debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp -- | Pretty-print an object in the current context if the current debug level is -- at least the supplied 'Int' -debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () +debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM t () debugPrettyInCtx i a = mrUVars >>= \ctx -> debugPrint i (showInCtx ctx a) -- | Pretty-print an object relative to the current context -mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc +mrPPInCtx :: PrettyInCtx a => a -> MRM t SawDoc mrPPInCtx a = runPPInCtxM (prettyInCtx a) <$> mrUVars -- | Pretty-print the result of 'ppWithPrefix' relative to the current uvar -- context to 'stderr' if the debug level is at least the 'Int' provided -mrDebugPPPrefix :: PrettyInCtx a => Int -> String -> a -> MRM () +mrDebugPPPrefix :: PrettyInCtx a => Int -> String -> a -> MRM t () mrDebugPPPrefix i pre a = mrUVars >>= \ctx -> debugPretty i $ @@ -1195,7 +1233,7 @@ mrDebugPPPrefix i pre a = -- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar -- context to 'stderr' if the debug level is at least the 'Int' provided mrDebugPPPrefixSep :: (PrettyInCtx a, PrettyInCtx b) => - Int -> String -> a -> String -> b -> MRM () + Int -> String -> a -> String -> b -> MRM t () mrDebugPPPrefixSep i pre a1 sp a2 = mrUVars >>= \ctx -> debugPretty i $ diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index f5f8be763f..d519d0e15b 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -350,7 +350,7 @@ smtNorm sc t = normalizeSharedTerm sc modmap (smtNormPrims sc) Map.empty Set.empty t -- | Normalize a 'Term' using some Mr Solver specific primitives -mrNormTerm :: Term -> MRM Term +mrNormTerm :: Term -> MRM t Term mrNormTerm t = debugPrint 2 "Normalizing term:" >> debugPrettyInCtx 2 t >> @@ -358,7 +358,7 @@ mrNormTerm t = -- | Normalize an open term by wrapping it in lambdas, normalizing, and then -- removing those lambdas -mrNormOpenTerm :: Term -> MRM Term +mrNormOpenTerm :: Term -> MRM t Term mrNormOpenTerm body = do length_ctx <- mrVarCtxLength <$> mrUVars fun_term <- lambdaUVarsM body @@ -380,7 +380,7 @@ mrNormOpenTerm body = -- uvars or 'MRVar's. -- -- FIXME: use the timeout! -mrProvableRaw :: Term -> MRM Bool +mrProvableRaw :: Term -> MRM t Bool mrProvableRaw prop_term = do sc <- mrSC prop <- liftSC1 termToProp prop_term @@ -401,18 +401,19 @@ mrProvableRaw prop_term = Left msg -> debugPrint 2 ("SMT solver encountered a saw-core error term: " ++ msg) >> return False - Right (Just cex, _) -> + Right (Just cex, stats) -> debugPrint 2 "SMT solver response: not provable" >> debugPrint 3 ("Counterexample:" ++ concatMap (\(x,v) -> "\n - " ++ renderSawDoc defaultPPOpts (ppTerm defaultPPOpts (Unshared (FTermF (ExtCns x)))) ++ " = " ++ renderSawDoc defaultPPOpts (ppFirstOrderValue defaultPPOpts v)) cex) >> - return False - Right (Nothing, _) -> - debugPrint 2 "SMT solver response: provable" >> return True + recordUsedSolver stats prop_term >> return False + Right (Nothing, stats) -> + debugPrint 2 "SMT solver response: provable" >> + recordUsedSolver stats prop_term >> return True -- | Test if a Boolean term over the current uvars is provable given the current -- assumptions -mrProvable :: Term -> MRM Bool +mrProvable :: Term -> MRM t Bool mrProvable (asBool -> Just b) = return b mrProvable bool_tm = do mrUVars >>= mrDebugPPPrefix 3 "mrProvable uvars:" @@ -422,7 +423,7 @@ mrProvable bool_tm = mrNormTerm prop_inst >>= mrProvableRaw where -- | Given a UVar name and type, generate a 'Term' to be passed to -- SMT, with special cases for BVVec and pair types - instUVar :: LocalName -> Term -> MRM Term + instUVar :: LocalName -> Term -> MRM t Term instUVar nm tp = mrDebugPPPrefix 3 "instUVar" (nm, tp) >> liftSC1 scWhnf tp >>= \case (asNonBVVecVectorType -> Just (m, a)) -> @@ -532,13 +533,13 @@ nonTrivialConv (ConvComp cs) = not (null cs) -- | Return 'True' iff the given 'InjConversion's are convertible, i.e. if -- the two injective conversions are the compositions of the same constructors, -- and the arguments to those constructors are convertible via 'mrConvertible' -mrConvsConvertible :: InjConversion -> InjConversion -> MRM Bool +mrConvsConvertible :: InjConversion -> InjConversion -> MRM t Bool mrConvsConvertible (ConvComp cs1) (ConvComp cs2) = if length cs1 /= length cs2 then return False else and <$> zipWithM mrSingleConvsConvertible cs1 cs2 -- | Used in the definition of 'mrConvsConvertible' -mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM Bool +mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM t Bool mrSingleConvsConvertible SingleNatToNum SingleNatToNum = return True mrSingleConvsConvertible (SingleBVToNat n1) (SingleBVToNat n2) = return $ n1 == n2 mrSingleConvsConvertible (SingleBVVecToVec n1 len1 a1 m1) @@ -559,9 +560,9 @@ mrSingleConvsConvertible _ _ = return False -- @c1 <> c2 <> ... <> cn@ are applied from right to left as in function -- composition (i.e. @mrApplyConv (c1 <> c2 <> ... <> cn) t@ is equivalent to -- @mrApplyConv c1 (mrApplyConv c2 (... mrApplyConv cn t ...))@) -mrApplyConv :: InjConversion -> Term -> MRM Term +mrApplyConv :: InjConversion -> Term -> MRM t Term mrApplyConv (ConvComp cs) = flip (foldrM go) cs - where go :: SingleInjConversion -> Term -> MRM Term + where go :: SingleInjConversion -> Term -> MRM t Term go SingleNatToNum t = liftSC2 scCtorApp "Cryptol.TCNum" [t] go (SingleBVToNat n) t = liftSC2 scBvToNat n t go (SingleBVVecToVec n len a m) t = mrGenFromBVVec n len a t "mrApplyConv" m @@ -572,9 +573,9 @@ mrApplyConv (ConvComp cs) = flip (foldrM go) cs -- | Try to apply the inverse of the given the conversion to the given term, -- raising an error if this is not possible - see also 'mrApplyConv' -mrApplyInvConv :: InjConversion -> Term -> MRM Term +mrApplyInvConv :: InjConversion -> Term -> MRM t Term mrApplyInvConv (ConvComp cs) = flip (foldlM go) cs - where go :: Term -> SingleInjConversion -> MRM Term + where go :: Term -> SingleInjConversion -> MRM t Term go t SingleNatToNum = case asNum t of Just (Left t') -> return t' _ -> error "mrApplyInvConv: Num term does not normalize to TCNum constructor" @@ -628,7 +629,7 @@ mrConvOfTerm _ = NoConv -- types @tp1@ and @tp2@ are convertible, but the latter indicates that no -- 'InjConversion' could be found. findInjConvs :: Term -> Maybe Term -> Term -> Maybe Term -> - MRM (Maybe (Term, InjConversion, InjConversion)) + MRM t (Maybe (Term, InjConversion, InjConversion)) -- always add 'NatToNum' conversions findInjConvs (asDataType -> Just (primName -> "Cryptol.Num", _)) t1 tp2 t2 = do tp1' <- liftSC0 scNatType @@ -722,13 +723,13 @@ findInjConvs tp1 _ tp2 _ = -- | Build a Boolean 'Term' stating that two 'Term's are equal. This is like -- 'scEq' except that it works on open terms. -mrEq :: Term -> Term -> MRM Term +mrEq :: Term -> Term -> MRM t Term mrEq t1 t2 = mrTypeOf t1 >>= \tp -> mrEq' tp t1 t2 -- | Build a Boolean 'Term' stating that the second and third 'Term' arguments -- are equal, where the first 'Term' gives their type (which we assume is the -- same for both). This is like 'scEq' except that it works on open terms. -mrEq' :: Term -> Term -> Term -> MRM Term +mrEq' :: Term -> Term -> Term -> MRM t Term -- FIXME: For this Nat case, the definition of 'equalNat' in @Prims.hs@ means -- that if both sides do not have immediately clear bit-widths (e.g. either -- side is is an application of @mulNat@) this will 'error'... @@ -750,7 +751,7 @@ data TermInCtx = TermInCtx [(LocalName,Term)] Term -- | Lift a binary operation on 'Term's to one on 'TermInCtx's liftTermInCtx2 :: (SharedContext -> Term -> Term -> IO Term) -> - TermInCtx -> TermInCtx -> MRM TermInCtx + TermInCtx -> TermInCtx -> MRM t TermInCtx liftTermInCtx2 op (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = do -- Insert the variables in ctx2 into the context of t1 starting at index 0, @@ -767,9 +768,9 @@ liftTermInCtx2 op (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = extTermInCtx :: [(LocalName,Term)] -> TermInCtx -> TermInCtx extTermInCtx ctx (TermInCtx ctx' t) = TermInCtx (ctx++ctx') t --- | Run an 'MRM' computation in the context of a 'TermInCtx', passing in the +-- | Run an 'MRM t' computation in the context of a 'TermInCtx', passing in the -- 'Term' -withTermInCtx :: TermInCtx -> (Term -> MRM a) -> MRM a +withTermInCtx :: TermInCtx -> (Term -> MRM t a) -> MRM t a withTermInCtx (TermInCtx [] tm) f = f tm withTermInCtx (TermInCtx ((nm,tp):ctx) tm) f = withUVar nm (Type tp) $ const $ withTermInCtx (TermInCtx ctx tm) f @@ -777,8 +778,8 @@ withTermInCtx (TermInCtx ((nm,tp):ctx) tm) f = -- | A "simple" strategy for proving equality between two terms, which we assume -- are of the same type, which builds an equality proposition by applying the -- supplied function to both sides and passes this proposition to an SMT solver. -mrProveEqSimple :: (Term -> Term -> MRM Term) -> Term -> Term -> - MRM TermInCtx +mrProveEqSimple :: (Term -> Term -> MRM t Term) -> Term -> Term -> + MRM t TermInCtx -- NOTE: The use of mrSubstEVars instead of mrSubstEVarsStrict means that we -- allow evars in the terms we send to the SMT solver, but we treat them as -- uvars. @@ -789,18 +790,18 @@ mrProveEqSimple eqf t1 t2 = -- | Prove that two terms are equal, instantiating evars if necessary, -- returning true on success - the same as @mrProveRel False@ -mrProveEq :: Term -> Term -> MRM Bool +mrProveEq :: Term -> Term -> MRM t Bool mrProveEq = mrProveRel False -- | Prove that two terms are equal, instantiating evars if necessary, or -- throwing an error if this is not possible - the same as -- @mrAssertProveRel False@ -mrAssertProveEq :: Term -> Term -> MRM () +mrAssertProveEq :: Term -> Term -> MRM t () mrAssertProveEq = mrAssertProveRel False -- | Prove that two terms are related, heterogeneously iff the first argument -- is true, instantiating evars if necessary, returning true on success -mrProveRel :: Bool -> Term -> Term -> MRM Bool +mrProveRel :: Bool -> Term -> Term -> MRM t Bool mrProveRel het t1 t2 = do let nm = if het then "mrProveRel" else "mrProveEq" mrDebugPPPrefixSep 2 nm t1 (if het then "~=" else "==") t2 @@ -819,7 +820,7 @@ mrProveRel het t1 t2 = -- | Prove that two terms are related, heterogeneously iff the first argument, -- is true, instantiating evars if necessary, or throwing an error if this is -- not possible -mrAssertProveRel :: Bool -> Term -> Term -> MRM () +mrAssertProveRel :: Bool -> Term -> Term -> MRM t () mrAssertProveRel het t1 t2 = do success <- mrProveRel het t1 t2 if success then return () else @@ -829,7 +830,7 @@ mrAssertProveRel het t1 t2 = -- expressing that the fourth and fifth arguments are related, heterogeneously -- iff the first argument is true, whose types are given by the second and -- third arguments, respectively -mrProveRelH :: Bool -> Term -> Term -> Term -> Term -> MRM TermInCtx +mrProveRelH :: Bool -> Term -> Term -> Term -> Term -> MRM t TermInCtx mrProveRelH het tp1 tp2 t1 t2 = do varmap <- mrVars tp1' <- liftSC1 scWhnf tp1 @@ -839,7 +840,7 @@ mrProveRelH het tp1 tp2 t1 t2 = -- | The body of 'mrProveRelH' -- NOTE: Don't call this function recursively, call 'mrProveRelH' mrProveRelH' :: Map MRVar MRVarInfo -> Bool -> - Term -> Term -> Term -> Term -> MRM TermInCtx + Term -> Term -> Term -> Term -> MRM t TermInCtx -- If t1 is an instantiated evar, substitute and recurse mrProveRelH' var_map het tp1 tp2 (asEVarApp var_map -> Just (_, args, Just f)) t2 = diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 59b4ed7881..e12999258d 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -140,8 +140,10 @@ import Verifier.SAW.Term.Functor import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer import Verifier.SAW.Cryptol.Monadify +import SAWScript.Prover.SolverStats import SAWScript.Prover.MRSolver.Term +import SAWScript.Prover.MRSolver.Evidence import SAWScript.Prover.MRSolver.Monad import SAWScript.Prover.MRSolver.SMT @@ -193,7 +195,7 @@ asCallS _ = Nothing -- of our variable monadic operations (including, e.g., if-then-else and the -- either and maybe eliminators). But the implementation here should give the -- correct result for any code we are actually going to see... -mrReplaceCallsWithTerms :: [Term] -> Term -> MRM Term +mrReplaceCallsWithTerms :: [Term] -> Term -> MRM t Term mrReplaceCallsWithTerms top_tms top_t = flip runReaderT top_tms $ flip memoFixTermFun top_t $ \recurse t -> case t of @@ -219,7 +221,7 @@ mrReplaceCallsWithTerms top_tms top_t = -- | Bind fresh function variables for a @multiFixS@ with the given list of -- @LetRecType@s and tuple of definitions for the function bodies -mrFreshCallVars :: Term -> Term -> Term -> Term -> MRM [MRVar] +mrFreshCallVars :: Term -> Term -> Term -> Term -> MRM t [MRVar] mrFreshCallVars ev stack frame defs_tm = do -- First, make fresh function constants for all the recursive functions, @@ -251,13 +253,13 @@ mrFreshCallVars ev stack frame defs_tm = -- | Normalize a 'Term' of monadic type to monadic normal form -normCompTerm :: Term -> MRM NormComp +normCompTerm :: Term -> MRM t NormComp normCompTerm = normComp . CompTerm -- | Normalize a computation to monadic normal form, assuming any 'Term's it -- contains have already been normalized with respect to beta and projections -- (but constants need not be unfolded) -normComp :: Comp -> MRM NormComp +normComp :: Comp -> MRM t NormComp normComp (CompReturn t) = return $ RetS t normComp (CompBind m f) = do norm <- normComp m @@ -461,7 +463,7 @@ normComp (CompTerm t) = -- term @x@ if the body is of the form @Eq Bool x True@ or @Eq Bool True x@, -- or a 'Term' @x@ and a 'DataTypeAssump' @c@ if the body is of the form -- @Eq _ x (c ...)@ or @Eq _ (c ...) x@ -normCompAssertAssumeBody :: Term -> MRM (Either Term (Term, DataTypeAssump)) +normCompAssertAssumeBody :: Term -> MRM t (Either Term (Term, DataTypeAssump)) normCompAssertAssumeBody (asEq -> Just (_, x1, asBool -> Just True)) = return $ Left x1 normCompAssertAssumeBody (asEq -> Just (_, asBool -> Just True, x2)) = @@ -479,7 +481,7 @@ normCompAssertAssumeBody prop = -- | Bind a computation in whnf with a function, and normalize -normBind :: NormComp -> CompFun -> MRM NormComp +normBind :: NormComp -> CompFun -> MRM t NormComp normBind (RetS t) k = applyNormCompFun k t normBind (ErrorS msg) _ = return (ErrorS msg) normBind (Ite cond comp1 comp2) k = @@ -520,19 +522,19 @@ normBind (FunBind f args k1) k2 | otherwise -} = return $ FunBind f args (compFunComp k1 k2) -- | Bind a 'Term' for a computation with a function and normalize -normBindTerm :: Term -> CompFun -> MRM NormComp +normBindTerm :: Term -> CompFun -> MRM t NormComp normBindTerm t f = normCompTerm t >>= \m -> normBind m f {- -- | Get the return type of a 'CompFun' -compFunReturnType :: CompFun -> MRM Term +compFunReturnType :: CompFun -> MRM t Term compFunReturnType (CompFunTerm _ t) = mrTypeOf t compFunReturnType (CompFunComp _ g) = compFunReturnType g compFunReturnType (CompFunReturn (Type t)) = return t -} -- | Apply a computation function to a term argument to get a computation -applyCompFun :: CompFun -> Term -> MRM Comp +applyCompFun :: CompFun -> Term -> MRM t Comp applyCompFun (CompFunComp f g) t = -- (f >=> g) t == f t >>= g do comp <- applyCompFun f t @@ -542,7 +544,7 @@ applyCompFun (CompFunReturn _ _) t = applyCompFun (CompFunTerm _ f) t = CompTerm <$> mrApplyAll f [t] -- | Convert a 'CompFun' into a 'Term' -compFunToTerm :: CompFun -> MRM Term +compFunToTerm :: CompFun -> MRM t Term compFunToTerm (CompFunTerm _ t) = return t compFunToTerm (CompFunComp f g) = do f' <- compFunToTerm f @@ -566,7 +568,7 @@ compFunToTerm (CompFunReturn params (Type a)) = {- -- | Convert a 'Comp' into a 'Term' -compToTerm :: Comp -> MRM Term +compToTerm :: Comp -> MRM t Term compToTerm (CompTerm t) = return t compToTerm (CompReturn t) = do tp <- mrTypeOf t @@ -582,9 +584,17 @@ compToTerm (CompBind m f) = -} -- | Apply a 'CompFun' to a term and normalize the resulting computation -applyNormCompFun :: CompFun -> Term -> MRM NormComp +applyNormCompFun :: CompFun -> Term -> MRM t NormComp applyNormCompFun f arg = applyCompFun f arg >>= normComp + +-- | Convert a 'FunAssumpRHS' to a 'NormComp' +mrFunAssumpRHSAsNormComp :: FunAssumpRHS -> MRM t NormComp +mrFunAssumpRHSAsNormComp (OpaqueFunAssump f args) = + FunBind f args <$> mkCompFunReturn <$> mrFunOutType f args +mrFunAssumpRHSAsNormComp (RewriteFunAssump rhs) = normCompTerm rhs + + -- | Match a term as a static list of eliminators for an Eithers type matchEitherElims :: Term -> Maybe [EitherElim] matchEitherElims (asCtor -> @@ -596,7 +606,7 @@ matchEitherElims (asCtor -> Just (primName -> "Prelude.FunsTo_Cons", matchEitherElims _ = Nothing -- | Construct the type @Eithers tps@ eliminated by a list of 'EitherElim's -elimsEithersType :: [EitherElim] -> MRM Type +elimsEithersType :: [EitherElim] -> MRM t Type elimsEithersType elims = Type <$> (do f <- mrGlobalTerm "Prelude.Eithers" @@ -613,7 +623,7 @@ elimsEithersType elims = -- | Lookup the definition of a function or throw a 'CannotLookupFunDef' if this is -- not allowed, either because it is a global function we are treating as opaque -- or because it is a locally-bound function variable -mrLookupFunDef :: FunName -> MRM Term +mrLookupFunDef :: FunName -> MRM t Term mrLookupFunDef f@(GlobalName _) = throwMRFailure (CannotLookupFunDef f) mrLookupFunDef f@(LocalName var) = mrVarInfo var >>= \case @@ -622,7 +632,7 @@ mrLookupFunDef f@(LocalName var) = Nothing -> error "mrLookupFunDef: unknown variable!" -- | Unfold a call to function @f@ in term @f args >>= g@ -mrUnfoldFunBind :: FunName -> [Term] -> Mark -> CompFun -> MRM Comp +mrUnfoldFunBind :: FunName -> [Term] -> Mark -> CompFun -> MRM t Comp mrUnfoldFunBind f _ mark _ | inMark f mark = throwMRFailure (RecursiveUnfold f) mrUnfoldFunBind f args mark g = do f_def <- mrLookupFunDef f @@ -644,7 +654,7 @@ handling the recursive ones ---------------------------------------------------------------------- -- | Prove the invariant of a coinductive hypothesis -proveCoIndHypInvariant :: CoIndHyp -> MRM () +proveCoIndHypInvariant :: CoIndHyp -> MRM t () proveCoIndHypInvariant hyp = do (invar1, invar2) <- applyCoIndHypInvariants hyp invar <- liftSC2 scAnd invar1 invar2 @@ -668,7 +678,7 @@ proveCoIndHypInvariant hyp = -- assumptions are thrown away. If while running the refinement computation a -- 'CoIndHypMismatchWidened' error is reached with the given names, the state is -- restored and the computation is re-run with the widened hypothesis. -mrRefinesCoInd :: FunName -> [Term] -> FunName -> [Term] -> MRM () +mrRefinesCoInd :: FunName -> [Term] -> FunName -> [Term] -> MRM t () mrRefinesCoInd f1 args1 f2 args2 = do ctx <- mrUVars preF1 <- mrGetInvariant f1 @@ -679,7 +689,7 @@ mrRefinesCoInd f1 args1 f2 args2 = -- | Prove the refinement represented by a 'CoIndHyp' coinductively. This is the -- main loop implementing 'mrRefinesCoInd'. See that function for documentation. -proveCoIndHyp :: CoIndHyp -> MRM () +proveCoIndHyp :: CoIndHyp -> MRM t () proveCoIndHyp hyp = withFailureCtx (FailCtxCoIndHyp hyp) $ do let f1 = coIndHypLHSFun hyp f2 = coIndHypRHSFun hyp @@ -695,7 +705,7 @@ proveCoIndHyp hyp = withFailureCtx (FailCtxCoIndHyp hyp) $ MRExnWiden nm1' nm2' new_vars | f1 == nm1' && f2 == nm2' -> -- NOTE: the state automatically gets reset here because we defined - -- MRM with ExceptT at a lower level than StateT + -- MRM t with ExceptT at a lower level than StateT do mrDebugPPPrefixSep 1 "Widening recursive assumption for" nm1' "|=" nm2' hyp' <- generalizeCoIndHyp hyp new_vars proveCoIndHyp hyp' @@ -703,7 +713,7 @@ proveCoIndHyp hyp = withFailureCtx (FailCtxCoIndHyp hyp) $ -- | Test that a coinductive hypothesis for the given function names matches the -- given arguments, otherwise throw an exception saying that widening is needed -matchCoIndHyp :: CoIndHyp -> [Term] -> [Term] -> MRM () +matchCoIndHyp :: CoIndHyp -> [Term] -> [Term] -> MRM t () matchCoIndHyp hyp args1 args2 = do mrDebugPPPrefix 1 "matchCoIndHyp" hyp (args1', args2') <- instantiateCoIndHyp hyp @@ -717,7 +727,7 @@ matchCoIndHyp hyp args1 args2 = proveCoIndHypInvariant hyp -- | Generalize some of the arguments of a coinductive hypothesis -generalizeCoIndHyp :: CoIndHyp -> [Either Int Int] -> MRM CoIndHyp +generalizeCoIndHyp :: CoIndHyp -> [Either Int Int] -> MRM t CoIndHyp generalizeCoIndHyp hyp [] = return hyp generalizeCoIndHyp hyp all_specs@(arg_spec_0:arg_specs) = withOnlyUVars (coIndHypCtx hyp) $ do @@ -777,7 +787,7 @@ generalizeCoIndHyp hyp all_specs@(arg_spec_0:arg_specs) = -- and @c_0 <> c1@. let cbnConvs :: (Term, InjConversion, [(a, InjConversion)]) -> (a, (Term, InjConversion, InjConversion)) -> - MRM (Term, InjConversion, [(a, InjConversion)]) + MRM t (Term, InjConversion, [(a, InjConversion)]) cbnConvs (tp, c_0, cs) (arg_spec_i, (tp_i, _, c2_i)) = findInjConvs tp Nothing tp_i Nothing >>= \case Just (tp', c1, c2) -> @@ -802,7 +812,7 @@ generalizeCoIndHyp hyp all_specs@(arg_spec_0:arg_specs) = -- | An object that can be converted to a normalized computation class ToNormComp a where - toNormComp :: a -> MRM NormComp + toNormComp :: a -> MRM t NormComp instance ToNormComp NormComp where toNormComp = return @@ -813,7 +823,7 @@ instance ToNormComp Term where -- | Prove that the left-hand computation refines the right-hand one. See the -- rules described at the beginning of this module. -mrRefines :: (ToNormComp a, ToNormComp b) => a -> b -> MRM () +mrRefines :: (ToNormComp a, ToNormComp b) => a -> b -> MRM t () mrRefines t1 t2 = do m1 <- toNormComp t1 m2 <- toNormComp t2 @@ -823,7 +833,7 @@ mrRefines t1 t2 = withFailureCtx (FailCtxRefines m1 m2) $ mrRefines' m1 m2 -- | The main implementation of 'mrRefines' -mrRefines' :: NormComp -> NormComp -> MRM () +mrRefines' :: NormComp -> NormComp -> MRM t () mrRefines' (RetS e1) (RetS e2) = mrAssertProveRel True e1 e2 mrRefines' (ErrorS _) (ErrorS _) = return () @@ -1025,7 +1035,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- If we have an opaque FunAssump that f1 args1' refines f2 args2', then -- prove that args1 = args1', args2 = args2', and then that k1 refines k2 - (_, Just (FunAssump ctx args1' (OpaqueFunAssump f2' args2'))) | f2 == f2' -> + (_, Just fa@(FunAssump ctx _ args1' (OpaqueFunAssump f2' args2') _)) | f2 == f2' -> do debugPretty 2 $ flip runPPInCtxM ctx $ prettyAppList [return "mrRefines using opaque FunAssump:", prettyInCtx ctx, return ".", @@ -1036,19 +1046,20 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = (args1'', args2'') <- substTermLike 0 evars (args1', args2') zipWithM_ mrAssertProveEq args1'' args1 zipWithM_ mrAssertProveEq args2'' args2 - mrRefinesFun tp1 k1 tp2 k2 + recordUsedFunAssump fa >> mrRefinesFun tp1 k1 tp2 k2 -- If we have an opaque FunAssump that f1 refines some f /= f2, and f2 -- unfolds and is not recursive in itself, unfold f2 and recurse - (_, Just (FunAssump _ _ (OpaqueFunAssump _ _))) + (_, Just fa@(FunAssump _ _ _ (OpaqueFunAssump _ _) _)) | Just (f2_body, False) <- maybe_f2_body -> - normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + normBindTerm f2_body k2 >>= \m2' -> + recordUsedFunAssump fa >> mrRefines m1 m2' -- If we have a rewrite FunAssump, or we have an opaque FunAssump that -- f1 args1' refines some f args where f /= f2 and f2 does not match the -- case above, treat either case like we have a rewrite FunAssump and prove -- that args1 = args1' and then that f args refines m2 - (_, Just (FunAssump ctx args1' rhs)) -> + (_, Just fa@(FunAssump ctx _ args1' rhs _)) -> do debugPretty 2 $ flip runPPInCtxM ctx $ prettyAppList [return "mrRefines rewriting by FunAssump:", prettyInCtx ctx, return ".", @@ -1064,7 +1075,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 m1' <- normBind rhs'' k1 - mrRefines m1' m2 + recordUsedFunAssump fa >> mrRefines m1' m2 -- If f1 unfolds and is not recursive in itself, unfold it and recurse _ | Just (f1_body, False) <- maybe_f1_body -> @@ -1098,13 +1109,13 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = -- If we have an assumption that f1 args' refines some rhs, then prove that -- args1 = args' and then that rhs refines m2 - Just (FunAssump ctx args1' rhs) -> + Just fa@(FunAssump ctx _ args1' rhs _) -> do rhs' <- mrFunAssumpRHSAsNormComp rhs evars <- mrFreshEVars ctx (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 m1' <- normBind rhs'' k1 - mrRefines m1' m2 + recordUsedFunAssump fa >> mrRefines m1' m2 -- Otherwise, see if we can unfold f1 Nothing -> @@ -1145,7 +1156,7 @@ mrRefines' m1 m2 = mrRefines'' m1 m2 -- | The cases of 'mrRefines' which must occur after the ones in 'mrRefines''. -- For example, the rules that introduce existential variables need to go last, -- so that they can quantify over as many universals as possible -mrRefines'' :: NormComp -> NormComp -> MRM () +mrRefines'' :: NormComp -> NormComp -> MRM t () mrRefines'' m1 (AssertBoolBind cond2 k2) = do m2 <- liftSC0 scUnitValue >>= applyCompFun k2 @@ -1181,7 +1192,7 @@ mrRefines'' (ForallBind tp f1) m2 = mrRefines'' m1 m2 = throwMRFailure (CompsDoNotRefine m1 m2) -- | Prove that one function refines another for all inputs -mrRefinesFun :: Term -> CompFun -> Term -> CompFun -> MRM () +mrRefinesFun :: Term -> CompFun -> Term -> CompFun -> MRM t () mrRefinesFun tp1 f1 tp2 f2 = do mrDebugPPPrefixSep 1 "mrRefinesFun on types:" tp1 "," tp2 f1' <- compFunToTerm f1 >>= liftSC1 scWhnf @@ -1203,8 +1214,8 @@ mrRefinesFun tp1 f1 tp2 f2 = -- wrapper functions determined by how the types are heterogeneously related), -- and call the continuation on the resulting terms. The second argument is -- an accumulator of variables to introduce, innermost first. -mrRefinesFunH :: (Term -> Term -> MRM a) -> [Term] -> - Term -> Term -> Term -> Term -> MRM a +mrRefinesFunH :: (Term -> Term -> MRM t a) -> [Term] -> + Term -> Term -> Term -> Term -> MRM t a -- Introduce equalities on either side as assumptions mrRefinesFunH k vars (asPi -> Just (nm1, tp1@(asEq -> Just (asBoolType -> Just (), b1, b2)), _)) t1 piTp2 t2 = @@ -1284,71 +1295,32 @@ mrRefinesFunH k _ _ t1 _ t2 = k t1 t2 -- * External Entrypoints ---------------------------------------------------------------------- --- | The result of a successful call to Mr. Solver: either a 'FunAssump' to --- (optionally) add to the 'MREnv', or 'Nothing' if the left-hand-side was not --- a function name -type MRSolverResult = Maybe (FunName, FunAssump) - --- | The continuation passed to 'mrRefinesFunH' in 'askMRSolver' and --- 'assumeMRSolver': normalizes both resulting terms using 'normCompTerm', --- calls the given monadic function, then returns a 'FunAssump', if possible -askMRSolverH :: (NormComp -> NormComp -> MRM ()) -> - Term -> Term -> MRM MRSolverResult +-- | The continuation passed to 'mrRefinesFunH' in 'askMRSolver' - normalizes +-- both resulting terms using 'normCompTerm' then calls the given monadic +-- function +askMRSolverH :: (NormComp -> NormComp -> MRM t a) -> Term -> Term -> MRM t a askMRSolverH f t1 t2 = do mrUVars >>= mrDebugPPPrefix 1 "askMRSolverH uvars:" m1 <- normCompTerm t1 m2 <- normCompTerm t2 f m1 m2 - case (m1, m2) of - -- If t1 and t2 are both named functions, our result is the opaque - -- FunAssump that forall xs. f1 xs |= f2 xs' - (FunBind f1 args1 (CompFunReturn _ _), - FunBind f2 args2 (CompFunReturn _ _)) -> - mrUVars >>= \uvar_ctx -> - return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, - fassumpArgs = args1, - fassumpRHS = OpaqueFunAssump f2 args2 }) - -- If just t1 is a named function, our result is the rewrite FunAssump - -- that forall xs. f1 xs |= m2 - (FunBind f1 args1 (CompFunReturn _ _), _) -> - mrUVars >>= \uvar_ctx -> - return $ Just (f1, FunAssump { fassumpCtx = uvar_ctx, - fassumpArgs = args1, - fassumpRHS = RewriteFunAssump m2 }) - _ -> return Nothing - --- | Test two monadic, recursive terms for refinement. On success, if the --- left-hand term is a named function, returning a 'FunAssump' to add to the --- 'MREnv'. + +-- | Test two monadic, recursive terms for refinement askMRSolver :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> - Term -> Term -> IO (Either MRFailure MRSolverResult) -askMRSolver sc env timeout args t1 t2 = - runMRM sc timeout env $ + Term -> Term -> IO (Either MRFailure (SolverStats, MREvidence t)) +askMRSolver sc env timeout rs args t1 t2 = + execMRM sc timeout env rs $ withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 mrRefinesFunH (askMRSolverH mrRefines) [] tp1 t1 tp2 t2 --- | Return the 'FunAssump' to add to the 'MREnv' that would be generated if --- 'askMRSolver' succeeded on the given terms. -assumeMRSolver :: - SharedContext -> - MREnv {- ^ The Mr Solver environment -} -> - Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> - [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> - Term -> Term -> IO (Either MRFailure MRSolverResult) -assumeMRSolver sc env timeout args t1 t2 = - runMRM sc timeout env $ - withUVars (mrVarCtxFromOuterToInner args) $ \_ -> - do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc - tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc - mrRefinesFunH (askMRSolverH (\_ _ -> return ())) [] tp1 t1 tp2 t2 - -- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of fully -- applied versions of the given 'Term's, after quantifying over all the given -- arguments as well as any additional arguments needed to fully apply the given @@ -1358,11 +1330,12 @@ refinementTerm :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> Term -> Term -> IO (Either MRFailure Term) -refinementTerm sc env timeout args t1 t2 = - runMRM sc timeout env $ +refinementTerm sc env timeout rs args t1 t2 = + evalMRM sc timeout env rs $ withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc - mrRefinesFunH mrRefinementGoal [] tp1 t1 tp2 t2 + mrRefinesFunH (mrRefinementTerm True) [] tp1 t1 tp2 t2 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index cf3038a236..0c7178bc2a 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -41,9 +41,6 @@ import GHC.Generics import Prettyprinter import Data.Text (Text, unpack) -import Data.Map (Map) -import qualified Data.Map as Map - import Verifier.SAW.Term.Functor import Verifier.SAW.Term.CtxTerm (MonadTerm(..)) import Verifier.SAW.Term.Pretty @@ -316,61 +313,6 @@ asLambdaName (asLambda -> Just (nm, _, _)) = Just nm asLambdaName _ = Nothing ----------------------------------------------------------------------- --- * Mr Solver Environments ----------------------------------------------------------------------- - --- | The right-hand-side of a 'FunAssump': either a 'FunName' and arguments, if --- it is an opaque 'FunAsump', or a 'NormComp', if it is a rewrite 'FunAssump' -data FunAssumpRHS = OpaqueFunAssump FunName [Term] - | RewriteFunAssump NormComp - --- | An assumption that a named function refines some specification. This has --- the form --- --- > forall x1, ..., xn. F e1 ... ek |= m --- --- for some universal context @x1:T1, .., xn:Tn@, some list of argument --- expressions @ei@ over the universal @xj@ variables, and some right-hand side --- computation expression @m@. -data FunAssump = FunAssump { - -- | The uvars that were in scope when this assumption was created - fassumpCtx :: MRVarCtx, - -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars - fassumpArgs :: [Term], - -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars - fassumpRHS :: FunAssumpRHS -} - --- | A map from function names to function refinement assumptions over that --- name --- --- FIXME: this should probably be an 'IntMap' on the 'VarIndex' of globals -type FunAssumps = Map FunName FunAssump - --- | A global MR Solver environment -data MREnv = MREnv { - -- | The set of function refinements to be assumed by to Mr. Solver (which - -- have hopefully been proved previously...) - mreFunAssumps :: FunAssumps, - -- | The debug level, which controls debug printing - mreDebugLevel :: Int -} - --- | The empty 'MREnv' -emptyMREnv :: MREnv -emptyMREnv = MREnv { mreFunAssumps = Map.empty, mreDebugLevel = 0 } - --- | Add a 'FunAssump' to a Mr Solver environment -mrEnvAddFunAssump :: FunName -> FunAssump -> MREnv -> MREnv -mrEnvAddFunAssump f fassump env = - env { mreFunAssumps = Map.insert f fassump (mreFunAssumps env) } - --- | Set the debug level of a Mr Solver environment -mrEnvSetDebugLevel :: Int -> MREnv -> MREnv -mrEnvSetDebugLevel dlvl env = env { mreDebugLevel = dlvl } - - ---------------------------------------------------------------------- -- * Utility Functions for Transforming 'Term's ---------------------------------------------------------------------- @@ -498,13 +440,17 @@ newtype PPInCtxM a = PPInCtxM (Reader [LocalName] a) runPPInCtxM :: PPInCtxM a -> MRVarCtx -> a runPPInCtxM (PPInCtxM m) = runReader m . map fst . mrVarCtxInnerToOuter +-- | Pretty-print an object in a SAW core context +ppInCtx :: PrettyInCtx a => MRVarCtx -> a -> SawDoc +ppInCtx ctx a = runPPInCtxM (prettyInCtx a) ctx + -- | Pretty-print an object in a SAW core context and render to a 'String' showInCtx :: PrettyInCtx a => MRVarCtx -> a -> String -showInCtx ctx a = renderSawDoc defaultPPOpts $ runPPInCtxM (prettyInCtx a) ctx +showInCtx ctx a = renderSawDoc defaultPPOpts $ ppInCtx ctx a -- | Pretty-print an object in the empty SAW core context ppInEmptyCtx :: PrettyInCtx a => a -> SawDoc -ppInEmptyCtx a = runPPInCtxM (prettyInCtx a) emptyMRVarCtx +ppInEmptyCtx = ppInCtx emptyMRVarCtx -- | A generic function for pretty-printing an object in a SAW core context of -- locally-bound names @@ -523,6 +469,10 @@ prettyTermApp :: Term -> [Term] -> PPInCtxM SawDoc prettyTermApp f_top args = prettyInCtx $ foldl (\f arg -> Unshared $ App f arg) f_top args +-- | Pretty-print the application of a 'Term' in a SAW core context +ppTermAppInCtx :: MRVarCtx -> Term -> [Term] -> SawDoc +ppTermAppInCtx ctx f_top args = runPPInCtxM (prettyTermApp f_top args) ctx + instance PrettyInCtx MRVarCtx where prettyInCtx = return . align . sep . helper [] . mrVarCtxOuterToInner where helper :: [LocalName] -> [(LocalName,Term)] -> [SawDoc] diff --git a/src/SAWScript/Value.hs b/src/SAWScript/Value.hs index d7c018d980..c00b307f5b 100644 --- a/src/SAWScript/Value.hs +++ b/src/SAWScript/Value.hs @@ -77,7 +77,8 @@ import SAWScript.JavaPretty (prettyClass) import SAWScript.Options (Options(printOutFn),printOutLn,Verbosity(..)) import SAWScript.Proof import SAWScript.Prover.SolverStats -import SAWScript.Prover.MRSolver.Term as MRSolver +import SAWScript.Prover.MRSolver.Term (funNameTerm, mrVarCtxInnerToOuter, ppTermAppInCtx) +import SAWScript.Prover.MRSolver.Evidence as MRSolver import SAWScript.Crucible.LLVM.Skeleton import SAWScript.X86 (X86Unsupported(..), X86Error(..)) import SAWScript.Yosys.IR @@ -138,6 +139,7 @@ data Value | VTopLevel (TopLevel Value) | VProofScript (ProofScript Value) | VSimpset SAWSimpset + | VRefnset SAWRefnset | VTheorem Theorem ----- | VLLVMCrucibleSetup !(LLVMCrucibleSetupM Value) @@ -171,6 +173,7 @@ data Value | VYosysTheorem YosysTheorem type SAWSimpset = Simpset TheoremNonce +type SAWRefnset = MRSolver.Refnset TheoremNonce data AIGNetwork where AIGNetwork :: (Typeable l, Typeable g, AIG.IsAIG l g) => AIG.Network l g -> AIGNetwork @@ -300,6 +303,23 @@ showSimpset opts ss = ppTerm t = SAWCorePP.ppTerm opts' t opts' = sawPPOpts opts +-- | Pretty-print a 'Refnset' to a 'String' +showRefnset :: PPOpts -> MRSolver.Refnset a -> String +showRefnset opts ss = + unlines ("Refinements" : "=============" : map (show . ppFunAssump) + (MRSolver.listFunAssumps ss)) + where + ppFunAssump (MRSolver.FunAssump ctx f args rhs _) = + PP.pretty '*' PP.<+> + (PP.nest 2 $ PP.fillSep + [ ppTermAppInCtx ctx (funNameTerm f) args + , PP.pretty ("|=" :: String) PP.<+> ppFunAssumpRHS ctx rhs ]) + ppFunAssumpRHS ctx (OpaqueFunAssump f args) = + ppTermAppInCtx ctx (funNameTerm f) args + ppFunAssumpRHS ctx (RewriteFunAssump rhs) = + SAWCorePP.ppTermInCtx opts' (map fst $ mrVarCtxInnerToOuter ctx) rhs + opts' = sawPPOpts opts + showsPrecValue :: PPOpts -> SAWNamingEnv -> Int -> Value -> ShowS showsPrecValue opts nenv p v = case v of @@ -323,6 +343,7 @@ showsPrecValue opts nenv p v = VBind {} -> showString "<>" VTopLevel {} -> showString "<>" VSimpset ss -> showString (showSimpset opts ss) + VRefnset ss -> showString (showRefnset opts ss) VProofScript {} -> showString "<>" VTheorem thm -> showString "Theorem " . @@ -1145,6 +1166,13 @@ instance FromValue SAWSimpset where fromValue (VSimpset ss) = ss fromValue _ = error "fromValue Simpset" +instance IsValue SAWRefnset where + toValue rs = VRefnset rs + +instance FromValue SAWRefnset where + fromValue (VRefnset rs) = rs + fromValue _ = error "fromValue Refnset" + instance IsValue Theorem where toValue t = VTheorem t From 3c46deda3d9843e5984a3b2b335ec830abdb0210 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Thu, 4 May 2023 14:12:52 -0400 Subject: [PATCH 04/10] revert unnecessary DataTypeAssump bits of 33694b2 --- src/SAWScript/Prover/MRSolver/Monad.hs | 79 ++++--------------------- src/SAWScript/Prover/MRSolver/Solver.hs | 67 ++++++--------------- src/SAWScript/Prover/MRSolver/Term.hs | 29 --------- 3 files changed, 30 insertions(+), 145 deletions(-) diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index fe9cf09dce..e5bac61326 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -24,15 +24,14 @@ monadic combinators for operating on terms. module SAWScript.Prover.MRSolver.Monad where -import Data.Maybe (fromJust) import Data.List (find, findIndex, foldl') -import Data.Foldable (foldrM) import qualified Data.Text as T import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State import Control.Monad.Except import Control.Monad.Trans.Maybe +import GHC.Generics import Data.Map (Map) import qualified Data.Map as Map @@ -76,7 +75,6 @@ data MRFailure | CannotLookupFunDef FunName | RecursiveUnfold FunName | MalformedLetRecTypes Term - | MalformedDataTypeAssump Term | MalformedDefs Term | MalformedComp Term | NotCompFunType Term @@ -156,9 +154,6 @@ instance PrettyInCtx MRFailure where ppWithPrefix "Recursive unfolding of function inside its own body:" nm prettyInCtx (MalformedLetRecTypes t) = ppWithPrefix "Not a ground LetRecTypes list:" t - prettyInCtx (MalformedDataTypeAssump t) = - ppWithPrefix ("assertS/assumeS expects a Bool, Either, or TCNum equality" - ++ " with a constructor on one side, got:") t prettyInCtx (MalformedDefs t) = ppWithPrefix "Cannot handle multiFixS recursive definitions term:" t prettyInCtx (MalformedComp t) = @@ -293,6 +288,18 @@ instance PrettyInCtx CoIndHyp where return "|=", prettyTermApp (funNameTerm f2) args2] +-- | An assumption that something is equal to one of the constructors of a +-- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term' +data DataTypeAssump + = IsLeft Term | IsRight Term | IsNum Term | IsInf + deriving (Generic, Show, TermLike) + +instance PrettyInCtx DataTypeAssump where + prettyInCtx (IsLeft x) = prettyInCtx x >>= ppWithPrefix "Left _ _" + prettyInCtx (IsRight x) = prettyInCtx x >>= ppWithPrefix "Right _ _" + prettyInCtx (IsNum x) = prettyInCtx x >>= ppWithPrefix "TCNum" + prettyInCtx IsInf = return "TCInf" + -- | A map from 'Term's to 'DataTypeAssump's over that term type DataTypeAssumps = HashMap Term DataTypeAssump @@ -850,66 +857,6 @@ mrCallsFun f = memoFixTermFun $ \recurse t -> case t of (unwrapTermF -> tf) -> foldM (\b t' -> if b then return b else recurse t') False tf --- | Given a 'DataTypeAssump' and a 'Term' to which it applies, return the --- equality representing the proposition that the 'DataTypeAssump' holds. --- For example, @mrDataTypeAssumpTerm x (IsLeft y)@ for @x : Either a b@ --- would return @Eq (Either a b) x (Left a b y)@. -mrDataTypeAssumpTerm :: Term -> DataTypeAssump -> MRM t Term -mrDataTypeAssumpTerm x dt = - do tp <- mrTypeOf x - y <- case dt of - IsLeft y - | Just (primName -> "Prelude.Either", [a, b]) <- asDataType tp -> - liftSC2 scCtorApp "Prelude.Left" [a, b, y] - | otherwise -> error $ "IsLeft expected Either, got: " ++ show tp - IsRight y - | Just (primName -> "Prelude.Either", [a, b]) <- asDataType tp -> - liftSC2 scCtorApp "Prelude.Right" [a, b, y] - | otherwise -> error $ "IsRight expected Either, got: " ++ show tp - IsNum y -> liftSC2 scCtorApp "Prelude.TCNum" [y] - IsInf -> liftSC2 scCtorApp "Prelude.TCInf" [] - liftSC2 scGlobalApply "Prelude.Eq" [tp, x, y] - --- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of the --- given 'Term's, after quantifying over all current 'mrUVars' with Pi types --- and adding calls to @assertS@ on the right hand side for any current --- 'mrAssumps' and/or 'mrDataTypeAssump's if the given 'Bool' is 'True' -mrRefinementTerm :: Bool -> Term -> Term -> MRM t Term -mrRefinementTerm includeAssumps t1 t2 = - do (SpecMParams ev1 stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1 - (SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2 - assumps <- mrAssumptions - assumpsAssert <- liftSC2 scGlobalApply "Prelude.assertBoolS" - [ev2, stack2, assumps] - t2' <- if includeAssumps && asBool assumps /= Just True - then bindConst ev2 stack2 tp2 assumpsAssert t2 - else return t2 - dtAssumps <- HashMap.toList <$> mrDataTypeAssumps - dtAssumpAsserts <- forM dtAssumps $ \(nm, assump) -> - do assump_tm <- mrDataTypeAssumpTerm nm assump - liftSC2 scGlobalApply "Prelude.assertS" - [ev2, stack2, assump_tm] - t2'' <- if includeAssumps - then foldrM (bindConst ev2 stack2 tp2) t2' dtAssumpAsserts - else return t2' - coIndHyps <- mrCoIndHyps - (rpre, rpost, rr) <- - if Map.null coIndHyps - then (,,) <$> liftSC2 scGlobalApply "Prelude.eqPreRel" [ev2, stack2] - <*> liftSC2 scGlobalApply "Prelude.eqPostRel" [ev2, stack2] - <*> liftSC2 scGlobalApply "Prelude.eqRR" [tp2] - else error "FIXME: Handle CoIndHyps in mrRefinementGoal" - ref_tm <- liftSC2 scGlobalApply "Prelude.refinesS" - [ev1, ev2, stack1, stack2, rpre, rpost, - tp1, tp2, rr, t1, t2''] - uvars <- mrUVarsOuterToInner - liftSC2 scPiList uvars ref_tm - where bindConst ev stack tp x y = - do unit <- liftSC0 scUnitType - const_y <- liftSC3 incVars 0 1 y >>= liftSC3 scLambda "_" unit - liftSC2 scGlobalApply "Prelude.bindS" - [ev, stack, unit, tp, x, const_y] - ---------------------------------------------------------------------- -- * Monadic Operations on Mr. Solver State diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index e12999258d..902b2d61a1 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -306,16 +306,6 @@ normComp (CompTerm t) = do unit_tp <- mrUnitType return $ AssumeBoolBind cond (CompFunReturn (SpecMParams ev stack) unit_tp) - (isGlobalDef "Prelude.assertS" -> Just (), [ev, stack, prop]) -> - do unit_tp <- mrUnitType - assert <- either AssertBoolBind (uncurry AssertDataTypeBind) - <$> normCompAssertAssumeBody prop - return $ assert (CompFunReturn (SpecMParams ev stack) unit_tp) - (isGlobalDef "Prelude.assumeS" -> Just (), [ev, stack, prop]) -> - do unit_tp <- mrUnitType - assume <- either AssumeBoolBind (uncurry AssumeDataTypeBind) - <$> normCompAssertAssumeBody prop - return $ assume (CompFunReturn (SpecMParams ev stack) unit_tp) (isGlobalDef "Prelude.existsS" -> Just (), [ev, stack, tp]) -> do unit_tp <- mrUnitType return $ ExistsBind (Type tp) (CompFunReturn @@ -459,26 +449,6 @@ normComp (CompTerm t) = _ -> throwMRFailure (MalformedComp t) --- | Given the body of an @assertS@ or @assumeS@, return either the boolean --- term @x@ if the body is of the form @Eq Bool x True@ or @Eq Bool True x@, --- or a 'Term' @x@ and a 'DataTypeAssump' @c@ if the body is of the form --- @Eq _ x (c ...)@ or @Eq _ (c ...) x@ -normCompAssertAssumeBody :: Term -> MRM t (Either Term (Term, DataTypeAssump)) -normCompAssertAssumeBody (asEq -> Just (_, x1, asBool -> Just True)) = - return $ Left x1 -normCompAssertAssumeBody (asEq -> Just (_, asBool -> Just True, x2)) = - return $ Left x2 -normCompAssertAssumeBody (asEq -> Just (_, x1, asEither -> Just e2)) = - return $ Right (x1, either IsLeft IsRight e2) -normCompAssertAssumeBody (asEq -> Just (_, asEither -> Just e1, x2)) = - return $ Right (x2, either IsLeft IsRight e1) -normCompAssertAssumeBody (asEq -> Just (_, x1, asNum -> Just e2)) = - return $ Right (x1, either IsNum (const IsInf) e2) -normCompAssertAssumeBody (asEq -> Just (_, asNum -> Just e1, x2)) = - return $ Right (x2, either IsNum (const IsInf) e1) -normCompAssertAssumeBody prop = - throwMRFailure (MalformedDataTypeAssump prop) - -- | Bind a computation in whnf with a function, and normalize normBind :: NormComp -> CompFun -> MRM t NormComp @@ -496,10 +466,6 @@ normBind (AssertBoolBind cond f) k = return $ AssertBoolBind cond (compFunComp f k) normBind (AssumeBoolBind cond f) k = return $ AssumeBoolBind cond (compFunComp f k) -normBind (AssertDataTypeBind x assump f) k = - return $ AssertDataTypeBind x assump (compFunComp f k) -normBind (AssumeDataTypeBind x assump f) k = - return $ AssumeDataTypeBind x assump (compFunComp f k) normBind (ExistsBind tp f) k = return $ ExistsBind tp (compFunComp f k) normBind (ForallBind tp f) k = return $ ForallBind tp (compFunComp f k) normBind (FunBind f args k1) k2 @@ -969,13 +935,6 @@ mrRefines' (AssertBoolBind cond1 k1) m2 = do m1 <- liftSC0 scUnitValue >>= applyCompFun k1 withAssumption cond1 $ mrRefines m1 m2 -mrRefines' m1 (AssumeDataTypeBind x2 assump2 k2) = - do m2 <- liftSC0 scUnitValue >>= applyCompFun k2 - withDataTypeAssump x2 assump2 $ mrRefines m1 m2 -mrRefines' (AssertDataTypeBind x1 assump1 k1) m2 = - do m1 <- liftSC0 scUnitValue >>= applyCompFun k1 - withDataTypeAssump x1 assump1 $ mrRefines m1 m2 - mrRefines' m1 (ForallBind tp f2) = let nm = maybe "x" id (compFunVarName f2) in withUVarLift nm tp (m1,f2) $ \x (m1',f2') -> @@ -1169,14 +1128,6 @@ mrRefines'' (AssumeBoolBind cond1 k1) m2 = if cond1_pv then mrRefines m1 m2 else throwMRFailure (AssumptionNotProvable cond1) --- FIXME: Do something smarter here? -mrRefines'' _ (AssertDataTypeBind t2 assump2 _) = - do cond2 <- mrDataTypeAssumpTerm t2 assump2 - throwMRFailure (AssertionNotProvable cond2) -mrRefines'' (AssumeDataTypeBind t1 assump1 _) _ = - do cond1 <- mrDataTypeAssumpTerm t1 assump1 - throwMRFailure (AssertionNotProvable cond1) - mrRefines'' m1 (ExistsBind tp f2) = do let nm = maybe "x" id (compFunVarName f2) evar <- mrFreshEVar nm tp @@ -1321,6 +1272,22 @@ askMRSolver sc env timeout rs args t1 t2 = mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 mrRefinesFunH (askMRSolverH mrRefines) [] tp1 t1 tp2 t2 +-- | The continuation passed to 'mrRefinesFunH' in 'refinementTerm' - returns +-- the 'Term' which is the refinement (@Prelude.refinesS@) of the given +-- 'Term's, after quantifying over all current 'mrUVars' with Pi types +refinementTermH :: Term -> Term -> MRM t Term +refinementTermH t1 t2 = + do (SpecMParams ev1 stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1 + (SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2 + rpre <- liftSC2 scGlobalApply "Prelude.eqPreRel" [ev2, stack2] + rpost <- liftSC2 scGlobalApply "Prelude.eqPostRel" [ev2, stack2] + rr <- liftSC2 scGlobalApply "Prelude.eqRR" [tp2] + ref_tm <- liftSC2 scGlobalApply "Prelude.refinesS" + [ev1, ev2, stack1, stack2, rpre, rpost, + tp1, tp2, rr, t1, t2] + uvars <- mrUVarsOuterToInner + liftSC2 scPiList uvars ref_tm + -- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of fully -- applied versions of the given 'Term's, after quantifying over all the given -- arguments as well as any additional arguments needed to fully apply the given @@ -1338,4 +1305,4 @@ refinementTerm sc env timeout rs args t1 t2 = withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc - mrRefinesFunH (mrRefinementTerm True) [] tp1 t1 tp2 t2 + mrRefinesFunH refinementTermH [] tp1 t1 tp2 t2 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index 0c7178bc2a..de910afccc 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -171,12 +171,6 @@ mrVarCtxFromOuterToInner = mrVarCtxFromInnerToOuter . reverse specMParamsArgs :: SpecMParams Term -> [Term] specMParamsArgs (SpecMParams ev stack) = [ev, stack] --- | An assumption that something is equal to one of the constructors of a --- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term' -data DataTypeAssump - = IsLeft Term | IsRight Term | IsNum Term | IsInf - deriving (Generic, Show, TermLike) - -- | A Haskell representation of a @SpecM@ in "monadic normal form" data NormComp = RetS Term -- ^ A term @retS _ _ a x@ @@ -187,8 +181,6 @@ data NormComp | OrS Comp Comp -- ^ an @orS@ computation | AssertBoolBind Term CompFun -- ^ the bind of an @assertBoolS@ computation | AssumeBoolBind Term CompFun -- ^ the bind of an @assumeBoolS@ computation - | AssertDataTypeBind Term DataTypeAssump CompFun -- ^ the bind of a datatype @assertS@ computation - | AssumeDataTypeBind Term DataTypeAssump CompFun -- ^ the bind of a datatype @assumeS@ computation | ExistsBind Type CompFun -- ^ the bind of an @existsS@ computation | ForallBind Type CompFun -- ^ the bind of a @forallS@ computation | FunBind FunName [Term] CompFun @@ -529,15 +521,6 @@ instance PrettyInCtx FunName where foldM (\pp proj -> (pp <>) <$> prettyInCtx proj) (ppName $ globalDefName g) projs -instance PrettyInCtx DataTypeAssump where - prettyInCtx (IsLeft x) = - prettyAppList [return "Left _ _", parens <$> prettyInCtx x] - prettyInCtx (IsRight x) = - prettyAppList [return "Right _ _", parens <$> prettyInCtx x] - prettyInCtx (IsNum x) = - prettyAppList [return "TCNum", parens <$> prettyInCtx x] - prettyInCtx IsInf = return "TCInf" - instance PrettyInCtx Comp where prettyInCtx (CompTerm t) = prettyInCtx t prettyInCtx (CompBind c f) = @@ -581,18 +564,6 @@ instance PrettyInCtx NormComp where prettyAppList [return "assumeBoolS", return "_", return "_", parens <$> prettyInCtx cond, return ">>=", parens <$> prettyInCtx k] - prettyInCtx (AssertDataTypeBind x y k) = - prettyAppList [return "assertS", return "_", return "_", - parens <$> prettyAppList [return "Eq", return "_", - parens <$> prettyInCtx x, - parens <$> prettyInCtx y], - return ">>=", parens <$> prettyInCtx k] - prettyInCtx (AssumeDataTypeBind x y k) = - prettyAppList [return "assumeS", return "_", return "_", - parens <$> prettyAppList [return "Eq", return "_", - parens <$> prettyInCtx x, - parens <$> prettyInCtx y], - return ">>=", parens <$> prettyInCtx k] prettyInCtx (ExistsBind tp k) = prettyAppList [return "existsS", return "_", return "_", prettyInCtx tp, return ">>=", parens <$> prettyInCtx k] From 9635b656e4eeb57dffe91896caa0495e3ef8b1a0 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 8 May 2023 14:10:26 -0400 Subject: [PATCH 05/10] fix infinite loop in `mrCallsFun` on recursive sub-functions --- heapster-saw/examples/sha512_mr_solver.saw | 2 +- src/SAWScript/Prover/MRSolver/Monad.hs | 25 +++++++++------------ src/SAWScript/Prover/MRSolver/Term.hs | 26 +++++++++++++--------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index ac68a154fc..27d38a002d 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -18,7 +18,7 @@ thm_round_00_15 <- thm_round_16_80 <- prove_extcore - mrsolver_with (addrefns [thm_round_00_15] empty_rs)) + (mrsolver_with (addrefns [thm_round_00_15] empty_rs)) (refines [] round_16_80 {{ round_16_80_spec }}); thm_processBlock <- diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index e5bac61326..7122484fdc 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -39,6 +39,8 @@ import qualified Data.Map as Map import Data.HashMap.Lazy (HashMap) import qualified Data.HashMap.Lazy as HashMap +import qualified Data.Set as Set + import Prettyprinter import Verifier.SAW.Term.Functor @@ -841,21 +843,16 @@ mrFunBodyRecInfo f args = -- | Test if a 'Term' contains, after possibly unfolding some functions, a call -- to a given function @f@ again mrCallsFun :: FunName -> Term -> MRM t Bool -mrCallsFun f = memoFixTermFun $ \recurse t -> case t of - (asExtCns -> Just ec) -> - do g <- extCnsToFunName ec - maybe_body <- mrFunNameBody g - case maybe_body of - _ | f == g -> return True - Just body -> recurse body - Nothing -> return False - (asTypedGlobalProj -> Just (gdef, projs)) -> - case globalDefBody gdef of - _ | f == GlobalName gdef projs -> return True - Just body -> recurse body - Nothing -> return False +mrCallsFun f = flip memoFixTermFunAccum Set.empty $ \recurse seen t -> + let onFunName g = mrFunNameBody g >>= \case + _ | f == g -> return True + Just body | Set.notMember g seen -> recurse (Set.insert g seen) body + _ -> return False + in case t of + (asExtCns -> Just ec) -> extCnsToFunName ec >>= onFunName + (asGlobalFunName -> Just g) -> onFunName g (unwrapTermF -> tf) -> - foldM (\b t' -> if b then return b else recurse t') False tf + foldM (\b t' -> if b then return b else recurse seen t') False tf ---------------------------------------------------------------------- diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index de910afccc..f6d533d4f1 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -313,23 +313,29 @@ asLambdaName _ = Nothing traverseSubterms :: MonadTerm m => (Term -> m Term) -> Term -> m Term traverseSubterms f (unwrapTermF -> tf) = traverse f tf >>= mkTermF --- | Build a recursive memoized function for tranforming 'Term's. Take in a --- function @f@ that intuitively performs one step of the transformation and --- allow it to recursively call the memoized function being defined by passing --- it as the first argument to @f@. -memoFixTermFun :: MonadIO m => ((Term -> m a) -> Term -> m a) -> Term -> m a -memoFixTermFun f term_top = +-- | Like 'memoFixTermFun', but threads through an accumulating argument +memoFixTermFunAccum :: MonadIO m => + ((b -> Term -> m a) -> b -> Term -> m a) -> + b -> Term -> m a +memoFixTermFunAccum f acc_top term_top = do table_ref <- liftIO $ newIORef IntMap.empty - let go t@(STApp { stAppIndex = ix }) = + let go acc t@(STApp { stAppIndex = ix }) = liftIO (readIORef table_ref) >>= \table -> case IntMap.lookup ix table of Just ret -> return ret Nothing -> - do ret <- f go t + do ret <- f go acc t liftIO $ modifyIORef' table_ref (IntMap.insert ix ret) return ret - go t = f go t - go term_top + go acc t = f go acc t + go acc_top term_top + +-- | Build a recursive memoized function for tranforming 'Term's. Take in a +-- function @f@ that intuitively performs one step of the transformation and +-- allow it to recursively call the memoized function being defined by passing +-- it as the first argument to @f@. +memoFixTermFun :: MonadIO m => ((Term -> m a) -> Term -> m a) -> Term -> m a +memoFixTermFun f = memoFixTermFunAccum (f .) () ---------------------------------------------------------------------- From d4678121cf51a7a8dabc2a4236eab31c19505039 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 14 Aug 2023 16:48:19 -0400 Subject: [PATCH 06/10] use solver caching (i.e. use applyProverToGoal) in MRSolver --- heapster-saw/examples/Makefile | 2 +- src/SAWScript/Builtins.hs | 30 ++++++++------- src/SAWScript/Prover/MRSolver/Monad.hs | 49 ++++++++++++++++++------- src/SAWScript/Prover/MRSolver/SMT.hs | 24 ++++++------ src/SAWScript/Prover/MRSolver/Solver.hs | 17 ++++++--- 5 files changed, 75 insertions(+), 47 deletions(-) diff --git a/heapster-saw/examples/Makefile b/heapster-saw/examples/Makefile index 25980f49fe..0a8354f7fc 100644 --- a/heapster-saw/examples/Makefile +++ b/heapster-saw/examples/Makefile @@ -41,7 +41,7 @@ endif $(SAW) $< # Lists all the Mr Solver tests, without their ".saw" suffix -MR_SOLVER_TESTS = # arrays_mr_solver linked_list_mr_solver sha512_mr_solver +MR_SOLVER_TESTS = exp_explosion_mr_solver.saw # arrays_mr_solver linked_list_mr_solver sha512_mr_solver .PHONY: mr-solver-tests $(MR_SOLVER_TESTS) mr-solver-tests: $(MR_SOLVER_TESTS) diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 8988721a6a..69df87a9f5 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -973,19 +973,18 @@ proveUnintSBV conf unints = (Prover.proveUnintSBV conf timeout) unintSet -- | Given a continuation which calls a prover, call the continuation on the --- 'goalSequent' of the given 'ProofGoal' and return a 'SolveResult'. If there --- is a 'SolverCache', do not call the continuation if the goal has an already --- cached result, and otherwise save the result of the call to the cache. +-- given 'Sequent' and return a 'SolveResult'. If there is a 'SolverCache', +-- do not call the continuation if the goal has an already cached result, +-- and otherwise save the result of the call to the cache. applyProverToGoal :: [SolverBackend] -> [SolverBackendOption] -> (SATQuery -> TopLevel (Maybe CEX, String)) - -> Set VarIndex - -> ProofGoal + -> Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult) -applyProverToGoal backends opts f unintSet g = do +applyProverToGoal backends opts f unintSet sqt = do sc <- getSharedContext let opt_backends = concatMap optionBackends opts vs <- io $ getSolverBackendVersions (backends ++ opt_backends) - satq <- io $ sequentToSATQuery sc unintSet (goalSequent g) + satq <- io $ sequentToSATQuery sc unintSet sqt k <- io $ mkSolverCacheKey sc vs opts satq (mb, solver_name) <- SV.onSolverCache (lookupInSolverCache k) >>= \case -- Use a cached result if one exists (and it's valid w.r.t our query) @@ -995,9 +994,9 @@ applyProverToGoal backends opts f unintSet g = do Just v -> SV.onSolverCache (insertInSolverCache k v) >> return res Nothing -> return res - let stats = solverStats solver_name (sequentSharedSize (goalSequent g)) + let stats = solverStats solver_name (sequentSharedSize sqt) case mb of - Nothing -> return (stats, SolveSuccess (SolverEvidence stats (goalSequent g))) + Nothing -> return (stats, SolveSuccess (SolverEvidence stats sqt)) Just a -> return (stats, SolveCounterexample a) wrapProver :: @@ -1005,8 +1004,8 @@ wrapProver :: (SATQuery -> TopLevel (Maybe CEX, String)) -> Set VarIndex -> ProofScript () -wrapProver backends opts f = - execTactic . tacticSolve . applyProverToGoal backends opts f +wrapProver backends opts f unints = + execTactic $ tacticSolve $ applyProverToGoal backends opts f unints . goalSequent wrapW4Prover :: SolverBackend -> [SolverBackendOption] -> @@ -2215,12 +2214,15 @@ ensureMonadicTerm sc t = monadifyTypedTerm sc t -- printed, then regardless, the last argument is called on the result. mrSolverH :: SharedContext -> Maybe SawDoc -> (String -> String) -> Maybe (String -> String) -> - (SharedContext -> Prover.MREnv -> Maybe Integer -> SV.SAWRefnset -> - [(LocalName, Term)] -> Term -> Term -> IO (Either Prover.MRFailure a)) -> + (SharedContext -> Prover.MREnv -> Maybe Integer -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> + SV.SAWRefnset -> [(LocalName, Term)] -> Term -> Term -> + TopLevel (Either Prover.MRFailure a)) -> SV.SAWRefnset -> [(LocalName, Term)] -> TypedTerm -> TypedTerm -> (a -> TopLevel b) -> TopLevel b mrSolverH sc printStr errStrf succStr f rs top_args tt1 tt2 cont = do env <- rwMRSolverEnv <$> get + let askSMT = applyProverToGoal [What4, Z3] [] (Prover.proveWhat4_z3 True) m1 <- ttTerm <$> ensureMonadicTerm sc tt1 m2 <- ttTerm <$> ensureMonadicTerm sc tt2 m1' <- io $ collapseEta <$> betaNormalize sc m1 @@ -2231,7 +2233,7 @@ mrSolverH sc printStr errStrf succStr f rs top_args tt1 tt2 cont = "[MRSolver] " <> str <> ": " <> ppTmHead m1' <> " |= " <> ppTmHead m2' time1 <- liftIO getCurrentTime - res <- io $ f sc env Nothing rs top_args m1' m2' + res <- f sc env Nothing askSMT rs top_args m1' m2' time2 <- liftIO getCurrentTime let diff = show $ diffUTCTime time2 time1 case res of diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 7122484fdc..fdebc87035 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -30,6 +30,7 @@ import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State import Control.Monad.Except +import Control.Monad.Catch (MonadThrow, MonadCatch) import Control.Monad.Trans.Maybe import GHC.Generics @@ -39,6 +40,7 @@ import qualified Data.Map as Map import Data.HashMap.Lazy (HashMap) import qualified Data.HashMap.Lazy as HashMap +import Data.Set (Set) import qualified Data.Set as Set import Prettyprinter @@ -51,6 +53,8 @@ import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer import Verifier.SAW.Cryptol.Monadify import SAWScript.Prover.SolverStats +import SAWScript.Proof (Sequent, SolveResult) +import SAWScript.Value (TopLevel) import SAWScript.Prover.MRSolver.Term import SAWScript.Prover.MRSolver.Evidence @@ -313,6 +317,9 @@ data MRInfo t = MRInfo { mriSMTTimeout :: Maybe Integer, -- | The top-level Mr Solver environment mriEnv :: MREnv, + -- | The function to be used as the SMT backend for Mr. Solver, taking a set + -- of uninterpreted variables and a proposition to prove + mriAskSMT :: Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult), -- | The set of function refinements to assume mriRefnset :: Refnset t, -- | The current context of universal variables @@ -347,10 +354,10 @@ data MRExn = MRExnFailure MRFailure -- shared environment, 'MRState' as state, and 'MRFailure' as an exception -- type, all over an 'IO' monad newtype MRM t a = MRM { unMRM :: ReaderT (MRInfo t) (StateT (MRState t) - (ExceptT MRExn IO)) a } + (ExceptT MRExn TopLevel)) a } deriving newtype (Functor, Applicative, Monad, MonadIO, MonadReader (MRInfo t), MonadState (MRState t), - MonadError MRExn) + MonadError MRExn, MonadThrow, MonadCatch) instance MonadTerm (MRM t) where mkTermF = liftSC1 scTermF @@ -386,6 +393,13 @@ mrAssumptions = mriAssumptions <$> ask mrDataTypeAssumps :: MRM t DataTypeAssumps mrDataTypeAssumps = mriDataTypeAssumps <$> ask +-- | Call the SMT backend given by 'mriAskSMT' on a set of uninterpreted +-- variables and a proposition to prove +mrAskSMT :: Set VarIndex -> Sequent -> MRM t (SolverStats, SolveResult) +mrAskSMT unints goal = do + askSMT <- mriAskSMT <$> ask + MRM $ lift $ lift $ lift $ askSMT unints goal + -- | Get the current debug level mrDebugLevel :: MRM t Int mrDebugLevel = mreDebugLevel <$> mriEnv <$> ask @@ -408,12 +422,15 @@ mrVars = mrsVars <$> get -- | Run an 'MRM' computation and return a result or an error, including the -- final state of 'mrsSolverStats' and 'mrsEvidence' -runMRM :: SharedContext -> Maybe Integer -> MREnv -> Refnset t -> - MRM t a -> IO (Either MRFailure (a, (SolverStats, MREvidence t))) -runMRM sc timeout env rs m = - do true_tm <- scBool sc True +runMRM :: SharedContext -> Maybe Integer -> MREnv -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> + Refnset t -> MRM t a -> + TopLevel (Either MRFailure (a, (SolverStats, MREvidence t))) +runMRM sc timeout env askSMT rs m = + do true_tm <- liftIO $ scBool sc True let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, - mriEnv = env, mriRefnset = rs, + mriEnv = env, mriAskSMT = askSMT, + mriRefnset = rs, mriUVars = emptyMRVarCtx, mriCoIndHyps = Map.empty, mriAssumptions = true_tm, @@ -429,15 +446,21 @@ runMRM sc timeout env rs m = -- | Run an 'MRM' computation and return a result or an error, discarding the -- final state -evalMRM :: SharedContext -> Maybe Integer -> MREnv -> Refnset t -> - MRM t a -> IO (Either MRFailure a) -evalMRM sc timeout env rs = fmap (fmap fst) . runMRM sc timeout env rs +evalMRM :: SharedContext -> Maybe Integer -> MREnv -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> + Refnset t -> MRM t a -> + TopLevel (Either MRFailure a) +evalMRM sc timeout env askSMT rs = + fmap (fmap fst) . runMRM sc timeout env askSMT rs -- | Run an 'MRM' computation and return a final state or an error, discarding -- the result -execMRM :: SharedContext -> Maybe Integer -> MREnv -> Refnset t -> - MRM t a -> IO (Either MRFailure (SolverStats, MREvidence t)) -execMRM sc timeout env rs = fmap (fmap snd) . runMRM sc timeout env rs +execMRM :: SharedContext -> Maybe Integer -> MREnv -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> + Refnset t -> MRM t a -> + TopLevel (Either MRFailure (SolverStats, MREvidence t)) +execMRM sc timeout env askSMT rs = + fmap (fmap snd) . runMRM sc timeout env askSMT rs -- | Throw an 'MRFailure' throwMRFailure :: MRFailure -> MRM t a diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 057fff2c45..5856d252d3 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -25,7 +25,7 @@ module SAWScript.Prover.MRSolver.SMT where import qualified Data.Vector as V import Numeric.Natural (Natural) import Control.Monad.Except -import qualified Control.Exception as X +import Control.Monad.Catch (throwM, catch) import Control.Monad.Trans.Maybe import Data.Foldable (foldrM, foldlM) import GHC.Generics @@ -48,10 +48,7 @@ import Verifier.SAW.Simulator.Prims import Verifier.SAW.Module import Verifier.SAW.Prelude.Constants import Verifier.SAW.FiniteValue - -import SAWScript.Proof (termToProp, propToTerm, prettyProp, propToSequent, sequentToSATQuery) -import What4.Solver -import SAWScript.Prover.What4 +import SAWScript.Proof (termToProp, propToTerm, prettyProp, propToSequent, SolveResult(..)) import SAWScript.Prover.MRSolver.Term import SAWScript.Prover.MRSolver.Monad @@ -388,27 +385,28 @@ mrProvableRaw prop_term = nenv <- liftIO (scGetNamingEnv sc) debugPrint 2 ("Calling SMT solver with proposition: " ++ prettyProp defaultPPOpts nenv prop) - satq <- liftIO $ sequentToSATQuery sc unints (propToSequent prop) - sym <- liftIO $ setupWhat4_sym True -- If there are any saw-core `error`s in the term, this will throw a -- Haskell error - in this case we want to just return False, not stop -- execution - smt_res <- liftIO $ - (Right <$> proveWhat4_solver z3Adapter sym sc satq (return ())) - `X.catch` \case + smt_res <- + (Right <$> mrAskSMT unints (propToSequent prop)) + `catch` \case UserError msg -> return $ Left msg - e -> X.throw e + e -> throwM e case smt_res of Left msg -> debugPrint 2 ("SMT solver encountered a saw-core error term: " ++ msg) >> return False - Right (Just cex, stats) -> + Right (stats, SolveUnknown) -> + debugPrint 2 "SMT solver response: unknown" >> + recordUsedSolver stats prop_term >> return False + Right (stats, SolveCounterexample cex) -> debugPrint 2 "SMT solver response: not provable" >> debugPrint 3 ("Counterexample:" ++ concatMap (\(x,v) -> "\n - " ++ renderSawDoc defaultPPOpts (ppTerm defaultPPOpts (Unshared (FTermF (ExtCns x)))) ++ " = " ++ renderSawDoc defaultPPOpts (ppFirstOrderValue defaultPPOpts v)) cex) >> recordUsedSolver stats prop_term >> return False - Right (Nothing, stats) -> + Right (stats, SolveSuccess _) -> debugPrint 2 "SMT solver response: provable" >> recordUsedSolver stats prop_term >> return True diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 902b2d61a1..b8acda7d3b 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -133,6 +133,7 @@ import Control.Monad.Reader import Control.Monad.Except import qualified Data.Map as Map import qualified Data.Text as Text +import Data.Set (Set) import Prettyprinter @@ -141,6 +142,8 @@ import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer import Verifier.SAW.Cryptol.Monadify import SAWScript.Prover.SolverStats +import SAWScript.Proof (Sequent, SolveResult) +import SAWScript.Value (TopLevel) import SAWScript.Prover.MRSolver.Term import SAWScript.Prover.MRSolver.Evidence @@ -1261,11 +1264,12 @@ askMRSolver :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) {- ^ ... -} -> Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> - Term -> Term -> IO (Either MRFailure (SolverStats, MREvidence t)) -askMRSolver sc env timeout rs args t1 t2 = - execMRM sc timeout env rs $ + Term -> Term -> TopLevel (Either MRFailure (SolverStats, MREvidence t)) +askMRSolver sc env timeout askSMT rs args t1 t2 = + execMRM sc timeout env askSMT rs $ withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc @@ -1297,11 +1301,12 @@ refinementTerm :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) {- ^ ... -} -> Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> - Term -> Term -> IO (Either MRFailure Term) -refinementTerm sc env timeout rs args t1 t2 = - evalMRM sc timeout env rs $ + Term -> Term -> TopLevel (Either MRFailure Term) +refinementTerm sc env timeout askSMT rs args t1 t2 = + evalMRM sc timeout env askSMT rs $ withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc From 930ee39b39b2c7ce4d11b553638b39ffd5bf9ea4 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 14 Aug 2023 17:10:26 -0400 Subject: [PATCH 07/10] fix typo in docs of mrsolver_with --- src/SAWScript/Interpreter.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 9f8be70011..b1fa38cd1d 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3936,7 +3936,7 @@ primitives = Map.fromList Current [ "Add proved refinement theorems to a given refinement set." ] - , prim "mrsolver_with" "Renfset -> ProofScript ()" + , prim "mrsolver_with" "Refnset -> ProofScript ()" (pureVal mrSolver) Experimental [ "Use MRSolver to prove a current refinement goal, i.e. a goal of" From 3ec6cfaa49a8434c95d9a8c10d71664fcd444a36 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Wed, 16 Aug 2023 12:43:53 -0400 Subject: [PATCH 08/10] remove extraneous .saw from heapster-saw Makefile --- heapster-saw/examples/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/heapster-saw/examples/Makefile b/heapster-saw/examples/Makefile index 0a8354f7fc..ae44946fef 100644 --- a/heapster-saw/examples/Makefile +++ b/heapster-saw/examples/Makefile @@ -41,7 +41,7 @@ endif $(SAW) $< # Lists all the Mr Solver tests, without their ".saw" suffix -MR_SOLVER_TESTS = exp_explosion_mr_solver.saw # arrays_mr_solver linked_list_mr_solver sha512_mr_solver +MR_SOLVER_TESTS = exp_explosion_mr_solver # arrays_mr_solver linked_list_mr_solver sha512_mr_solver .PHONY: mr-solver-tests $(MR_SOLVER_TESTS) mr-solver-tests: $(MR_SOLVER_TESTS) From e41e00a526d9f62db78fd73a2e31b7724560d891 Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Thu, 17 Aug 2023 23:13:17 +0200 Subject: [PATCH 09/10] incorporate comments from @bboston7 and @eddywestbrook --- src/SAWScript/Builtins.hs | 102 ++++++++++++---------- src/SAWScript/Interpreter.hs | 6 +- src/SAWScript/Prover/MRSolver/Evidence.hs | 7 ++ src/SAWScript/Prover/MRSolver/Monad.hs | 49 +++++++---- src/SAWScript/Prover/MRSolver/Solver.hs | 21 +++-- 5 files changed, 113 insertions(+), 72 deletions(-) diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 69df87a9f5..d08e7f4469 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -2204,26 +2204,16 @@ ensureMonadicTerm sc t False -> monadifyTypedTerm sc t ensureMonadicTerm sc t = monadifyTypedTerm sc t --- | A wrapper for either 'Prover.askMRSolver' or 'Prover.refinementTerm' from --- @MRSolver.hs@: if the second argument is @Just str@, prints out @str@ --- followed by an abridged version of the refinement being asked, then calls --- the given function. On failure, a string of how long the function took to --- run is passed to the third argument and the result is used as the message --- for 'fail'. On success, if the fourth argument is @Just strf@, a string of --- how long the function took to run is passed to @strf@ and the result is --- printed, then regardless, the last argument is called on the result. -mrSolverH :: SharedContext -> - Maybe SawDoc -> (String -> String) -> Maybe (String -> String) -> - (SharedContext -> Prover.MREnv -> Maybe Integer -> - (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> - SV.SAWRefnset -> [(LocalName, Term)] -> Term -> Term -> - TopLevel (Either Prover.MRFailure a)) -> - SV.SAWRefnset -> [(LocalName, Term)] -> TypedTerm -> TypedTerm -> - (a -> TopLevel b) -> TopLevel b -mrSolverH sc printStr errStrf succStr f rs top_args tt1 tt2 cont = - do env <- rwMRSolverEnv <$> get - let askSMT = applyProverToGoal [What4, Z3] [] (Prover.proveWhat4_z3 True) - m1 <- ttTerm <$> ensureMonadicTerm sc tt1 +-- | Normalizes the given 'TypedTerm's for calling 'Prover.askMRSolver' or +-- 'Prover.refinementTerm' and ensures they are of the expected form. +-- Additionally, if the second argument is @Just str@, prints out @str@ +-- followed by an abridged version of the refinement represented by the two +-- terms. +mrSolverNormalizeAndPrintArgs :: + SharedContext -> Maybe SawDoc -> + TypedTerm -> TypedTerm -> TopLevel (Term, Term) +mrSolverNormalizeAndPrintArgs sc printStr tt1 tt2 = + do m1 <- ttTerm <$> ensureMonadicTerm sc tt1 m2 <- ttTerm <$> ensureMonadicTerm sc tt2 m1' <- io $ collapseEta <$> betaNormalize sc m1 m2' <- io $ collapseEta <$> betaNormalize sc m2 @@ -2232,21 +2222,7 @@ mrSolverH sc printStr errStrf succStr f rs top_args tt1 tt2 cont = Just str -> printOutLnTop Info $ renderSawDoc defaultPPOpts $ "[MRSolver] " <> str <> ": " <> ppTmHead m1' <> " |= " <> ppTmHead m2' - time1 <- liftIO getCurrentTime - res <- f sc env Nothing askSMT rs top_args m1' m2' - time2 <- liftIO getCurrentTime - let diff = show $ diffUTCTime time2 time1 - case res of - Left err | Prover.mreDebugLevel env == 0 -> - fail (Prover.showMRFailure err ++ "\n[MRSolver] " ++ errStrf diff) - Left err -> - -- we ignore the MRFailure context here since it will have already - -- been printed by the debug trace - fail (Prover.showMRFailureNoCtx err ++ "\n[MRSolver] " ++ errStrf diff) - Right a | Just sf <- succStr -> - printOutLnTop Info (sf diff) >> cont a - Right a -> - cont a + return (m1', m2') where -- Turn a term of the form @\x1 ... xn -> f x1 ... xn@ into @f@ collapseEta :: Term -> Term collapseEta (asLambdaList -> (lamVars, @@ -2265,6 +2241,32 @@ mrSolverH sc printStr errStrf succStr f rs top_args tt1 tt2 cont = ppTerm defaultPPOpts t <> if length args > 0 then " ..." else "" ppTmHead _ = "..." +-- | The calback to be used by MRSolver for making SMT queries +mrSolverAskSMT :: Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult) +mrSolverAskSMT = applyProverToGoal [What4, Z3] [] (Prover.proveWhat4_z3 True) + +-- | Given the result of calling 'Prover.askMRSolver' or +-- 'Prover.refinementTerm', fails and prints out `err` followed by the second +-- argument if the given result is `Left err` for some `err`, or otherwise +-- returns `a` if the result is `Right a` for some `a`. Additionally, if the +-- third argument is @Just str@, prints out @str@ on success (i.e. 'Right'). +mrSolverGetResultOrFail :: + Prover.MREnv -> + String {- The string to print out on failure -} -> + Maybe String {- The string to print out on success, if any -} -> + Either Prover.MRFailure a {- The result, printed out on error -} -> + TopLevel a +mrSolverGetResultOrFail env errStr succStr res = case res of + Left err | Prover.mreDebugLevel env == 0 -> + fail (Prover.showMRFailure err ++ "\n[MRSolver] " ++ errStr) + Left err -> + -- we ignore the MRFailure context here since it will have already + -- been printed by the debug trace + fail (Prover.showMRFailureNoCtx err ++ "\n[MRSolver] " ++ errStr) + Right a | Just s <- succStr -> + printOutLnTop Info s >> return a + Right a -> return a + -- | Invokes MRSolver to attempt to solve a focused goal of the form -- @(a1:A1) -> ... -> (an:An) -> refinesS_eq ...@, assuming the refinements -- in the given 'Refnset', and printing an error message and exiting if @@ -2280,11 +2282,17 @@ mrSolver rs = execTactic $ Tactic $ \goal -> lift $ do tp1 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev1, stack1, rtp1] tp2 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev2, stack2, rtp2] let tt1 = TypedTerm (TypedTermOther tp1) t1 - let tt2 = TypedTerm (TypedTermOther tp2) t2 - mrSolverH sc - (Just $ "Tactic call") (printf "Failure in %s") (Just $ printf "Success in %s") - Prover.askMRSolver rs args tt1 tt2 - (\(stats, mre) -> return ((), stats, [], leafEvidence $ MrSolverEvidence mre)) + tt2 = TypedTerm (TypedTermOther tp2) t2 + (m1, m2) <- mrSolverNormalizeAndPrintArgs sc (Just $ "Tactic call") tt1 tt2 + env <- rwMRSolverEnv <$> get + time1 <- liftIO getCurrentTime + res <- Prover.askMRSolver sc env Nothing mrSolverAskSMT rs args m1 m2 + time2 <- liftIO getCurrentTime + let diff = show $ diffUTCTime time2 time1 + errStr = printf "Failure in %s" diff + succStr = printf "Success in %s" diff + (stats, mre) <- mrSolverGetResultOrFail env errStr (Just succStr) res + return ((), stats, [], leafEvidence $ MrSolverEvidence mre) _ -> error "mrsolver: cannot apply mrsolver to a non-refinement goal" -- | Add a proved refinement theorem to a given refinement set @@ -2314,10 +2322,16 @@ refinesTerm vars tt1 tt2 = do sc <- getSharedContext tt1' <- lambdas vars tt1 tt2' <- lambdas vars tt2 - mrSolverH sc - Nothing (printf "[MRSolver] Failed to build refinement term (%s)") Nothing - Prover.refinementTerm Prover.emptyRefnset [] tt1' tt2' - (io . mkTypedTerm sc) + (m1, m2) <- mrSolverNormalizeAndPrintArgs sc Nothing tt1' tt2' + env <- rwMRSolverEnv <$> get + time1 <- liftIO getCurrentTime + res <- Prover.refinementTerm sc env Nothing mrSolverAskSMT + Prover.emptyRefnset [] m1 m2 + time2 <- liftIO getCurrentTime + let diff = show $ diffUTCTime time2 time1 + errStr = printf "[MRSolver] Failed to build refinement term (%s)" diff + ttRes <- mrSolverGetResultOrFail env errStr Nothing res + io $ mkTypedTerm sc ttRes setMonadification :: SharedContext -> String -> String -> Bool -> TopLevel () setMonadification sc cry_str saw_str poly_p = diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index b1fa38cd1d..ee27ab7470 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -3923,17 +3923,17 @@ primitives = Map.fromList , prim "empty_rs" "Refnset" (pureVal (emptyRefnset :: SAWRefnset)) - Current + Experimental [ "The empty refinement set, containing no refinements." ] , prim "addrefn" "Theorem -> Refnset -> Refnset" (funVal2 addrefn) - Current + Experimental [ "Add a proved refinement theorem to a given refinement set." ] , prim "addrefns" "[Theorem] -> Refnset -> Refnset" (funVal2 addrefns) - Current + Experimental [ "Add proved refinement theorems to a given refinement set." ] , prim "mrsolver_with" "Refnset -> ProofScript ()" diff --git a/src/SAWScript/Prover/MRSolver/Evidence.hs b/src/SAWScript/Prover/MRSolver/Evidence.hs index f781b324d1..005e68bdbb 100644 --- a/src/SAWScript/Prover/MRSolver/Evidence.hs +++ b/src/SAWScript/Prover/MRSolver/Evidence.hs @@ -13,6 +13,13 @@ This module defines multiple outward facing components of MRSolver, most notably the 'MREvidence' type which provides evidence for the truth of a refinement proposition proved by MRSolver, and used in @Proof.hs@. This module also defines the 'MREnv' type, the global MRSolver state. + +Note: In order to avoid circular dependencies, the 'FunAssump' type and its +dependents in this file ('Refnset' and 'MREvidence') are given a type +parameter `t` which in practice always be 'TheoremNonce' from `Value.hs`. +The reason we cannot just import `Value.hs` here directly is because the +'Refnset' type is used in `Value.hs` - specifically, in the 'VRefnset' +constructor of the 'Value' datatype. -} module SAWScript.Prover.MRSolver.Evidence where diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index fdebc87035..aff5b94ed9 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -422,11 +422,16 @@ mrVars = mrsVars <$> get -- | Run an 'MRM' computation and return a result or an error, including the -- final state of 'mrsSolverStats' and 'mrsEvidence' -runMRM :: SharedContext -> Maybe Integer -> MREnv -> - (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> - Refnset t -> MRM t a -> - TopLevel (Either MRFailure (a, (SolverStats, MREvidence t))) -runMRM sc timeout env askSMT rs m = +runMRM :: + SharedContext -> + MREnv {- ^ The Mr Solver environment -} -> + Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) + {- ^ The callback to use for making SMT queries -} -> + Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> + MRM t a {- ^ The monadic computation to run -} -> + TopLevel (Either MRFailure (a, (SolverStats, MREvidence t))) +runMRM sc env timeout askSMT rs m = do true_tm <- liftIO $ scBool sc True let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, mriEnv = env, mriAskSMT = askSMT, @@ -446,21 +451,31 @@ runMRM sc timeout env askSMT rs m = -- | Run an 'MRM' computation and return a result or an error, discarding the -- final state -evalMRM :: SharedContext -> Maybe Integer -> MREnv -> - (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> - Refnset t -> MRM t a -> - TopLevel (Either MRFailure a) -evalMRM sc timeout env askSMT rs = - fmap (fmap fst) . runMRM sc timeout env askSMT rs +evalMRM :: + SharedContext -> + MREnv {- ^ The Mr Solver environment -} -> + Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) + {- ^ The callback to use for making SMT queries -} -> + Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> + MRM t a {- ^ The monadic computation to eval -} -> + TopLevel (Either MRFailure a) +evalMRM sc env timeout askSMT rs = + fmap (fmap fst) . runMRM sc env timeout askSMT rs -- | Run an 'MRM' computation and return a final state or an error, discarding -- the result -execMRM :: SharedContext -> Maybe Integer -> MREnv -> - (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) -> - Refnset t -> MRM t a -> - TopLevel (Either MRFailure (SolverStats, MREvidence t)) -execMRM sc timeout env askSMT rs = - fmap (fmap snd) . runMRM sc timeout env askSMT rs +execMRM :: + SharedContext -> + MREnv {- ^ The Mr Solver environment -} -> + Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) + {- ^ The callback to use for making SMT queries -} -> + Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> + MRM t a {- ^ The monadic computation to exec -} -> + TopLevel (Either MRFailure (SolverStats, MREvidence t)) +execMRM sc env timeout askSMT rs = + fmap (fmap snd) . runMRM sc env timeout askSMT rs -- | Throw an 'MRFailure' throwMRFailure :: MRFailure -> MRM t a diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index b8acda7d3b..65a233abee 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -1264,12 +1264,13 @@ askMRSolver :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> - (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) {- ^ ... -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) + {- ^ The callback to use for making SMT queries -} -> Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> Term -> Term -> TopLevel (Either MRFailure (SolverStats, MREvidence t)) askMRSolver sc env timeout askSMT rs args t1 t2 = - execMRM sc timeout env askSMT rs $ + execMRM sc env timeout askSMT rs $ withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc @@ -1278,16 +1279,19 @@ askMRSolver sc env timeout askSMT rs args t1 t2 = -- | The continuation passed to 'mrRefinesFunH' in 'refinementTerm' - returns -- the 'Term' which is the refinement (@Prelude.refinesS@) of the given --- 'Term's, after quantifying over all current 'mrUVars' with Pi types +-- 'Term's, after quantifying over all current 'mrUVars' with Pi types. Note +-- that this assumes both terms have the same event and stack types - if they +-- do not a saw-core typechecking error will be raised. refinementTermH :: Term -> Term -> MRM t Term refinementTermH t1 t2 = - do (SpecMParams ev1 stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1 - (SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2 + do (SpecMParams _ev1 _stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1 + (SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2 rpre <- liftSC2 scGlobalApply "Prelude.eqPreRel" [ev2, stack2] rpost <- liftSC2 scGlobalApply "Prelude.eqPostRel" [ev2, stack2] rr <- liftSC2 scGlobalApply "Prelude.eqRR" [tp2] + -- NB: This will throw a type error if _ev1 /= ev2 or _stack1 /= stack2 ref_tm <- liftSC2 scGlobalApply "Prelude.refinesS" - [ev1, ev2, stack1, stack2, rpre, rpost, + [ev2, ev2, stack2, stack2, rpre, rpost, tp1, tp2, rr, t1, t2] uvars <- mrUVarsOuterToInner liftSC2 scPiList uvars ref_tm @@ -1301,12 +1305,13 @@ refinementTerm :: SharedContext -> MREnv {- ^ The Mr Solver environment -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> - (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) {- ^ ... -} -> + (Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult)) + {- ^ The callback to use for making SMT queries -} -> Refnset t {- ^ Any additional refinements to be assumed by Mr Solver -} -> [(LocalName, Term)] {- ^ Any universally quantified variables in scope -} -> Term -> Term -> TopLevel (Either MRFailure Term) refinementTerm sc env timeout askSMT rs args t1 t2 = - evalMRM sc timeout env askSMT rs $ + evalMRM sc env timeout askSMT rs $ withUVars (mrVarCtxFromOuterToInner args) $ \_ -> do tp1 <- liftIO $ scTypeOf sc t1 >>= scWhnf sc tp2 <- liftIO $ scTypeOf sc t2 >>= scWhnf sc From 82e58618e67a9e67c044f4555853cb6ef359cbaf Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Thu, 17 Aug 2023 23:39:07 +0200 Subject: [PATCH 10/10] add RefinesS type for asRefinesS, fix haddock --- src/SAWScript/Builtins.hs | 11 +-- src/SAWScript/Prover/MRSolver.hs | 3 +- src/SAWScript/Prover/MRSolver/Evidence.hs | 87 ++++++++++++++--------- 3 files changed, 63 insertions(+), 38 deletions(-) diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index d08e7f4469..3ca9813a26 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -2246,9 +2246,9 @@ mrSolverAskSMT :: Set VarIndex -> Sequent -> TopLevel (SolverStats, SolveResult) mrSolverAskSMT = applyProverToGoal [What4, Z3] [] (Prover.proveWhat4_z3 True) -- | Given the result of calling 'Prover.askMRSolver' or --- 'Prover.refinementTerm', fails and prints out `err` followed by the second --- argument if the given result is `Left err` for some `err`, or otherwise --- returns `a` if the result is `Right a` for some `a`. Additionally, if the +-- 'Prover.refinementTerm', fails and prints out@`err@ followed by the second +-- argument if the given result is @Left err@ for some @err@, or otherwise +-- returns @a@ if the result is@`Right a@ for some @a@. Additionally, if the -- third argument is @Just str@, prints out @str@ on success (i.e. 'Right'). mrSolverGetResultOrFail :: Prover.MREnv -> @@ -2277,8 +2277,9 @@ mrSolver rs = execTactic $ Tactic $ \goal -> lift $ case sequentState (goalSequent goal) of Unfocused -> fail "mrsolver: focus required" HypFocus _ _ -> fail "mrsolver: cannot apply mrsolver in a hypothesis" - ConclFocus (Prover.asRefinesS . unProp -> Just (args, ev1, ev2, stack1, stack2, - rtp1, rtp2, t1, t2)) _ -> + ConclFocus (Prover.asRefinesS . unProp -> Just (Prover.RefinesS args ev1 ev2 + stack1 stack2 rtp1 rtp2 + t1 t2)) _ -> do tp1 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev1, stack1, rtp1] tp2 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev2, stack2, rtp2] let tt1 = TypedTerm (TypedTermOther tp1) t1 diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index 3fb3c94c25..4aa02c6eb0 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -11,7 +11,8 @@ Portability : non-portable (language extensions) module SAWScript.Prover.MRSolver (askMRSolver, refinementTerm, MRFailure(..), showMRFailure, showMRFailureNoCtx, - FunAssump(..), FunAssumpRHS(..), asRefinesS, asFunAssump, + RefinesS(..), asRefinesS, + FunAssump(..), FunAssumpRHS(..), asFunAssump, Refnset, emptyRefnset, addFunAssump, MREnv(..), emptyMREnv, mrEnvSetDebugLevel, asProjAll, isSpecFunType) where diff --git a/src/SAWScript/Prover/MRSolver/Evidence.hs b/src/SAWScript/Prover/MRSolver/Evidence.hs index 005e68bdbb..bc627954f5 100644 --- a/src/SAWScript/Prover/MRSolver/Evidence.hs +++ b/src/SAWScript/Prover/MRSolver/Evidence.hs @@ -16,9 +16,9 @@ also defines the 'MREnv' type, the global MRSolver state. Note: In order to avoid circular dependencies, the 'FunAssump' type and its dependents in this file ('Refnset' and 'MREvidence') are given a type -parameter `t` which in practice always be 'TheoremNonce' from `Value.hs`. -The reason we cannot just import `Value.hs` here directly is because the -'Refnset' type is used in `Value.hs` - specifically, in the 'VRefnset' +parameter @t@ which in practice always be 'TheoremNonce' from @Value.hs@. +The reason we cannot just import @Value.hs@ here directly is because the +'Refnset' type is used in @Value.hs@ - specifically, in the 'VRefnset' constructor of the 'Value' datatype. -} @@ -47,6 +47,52 @@ import SAWScript.Prover.MRSolver.Term -- * Function Refinement Assumptions ---------------------------------------------------------------------- +-- | A representation of a term of the form: +-- @(a1:A1) -> ... -> (an:An) -> refinesS ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2@ +data RefinesS = RefinesS { + -- | The context of the refinement, i.e. @[(a1,A1), ..., (an,An)]@ + -- from the term above + refnCtx :: [(LocalName, Term)], + -- | The LHS event type of the refinement, i.e. @ev1@ above + refnEv1 :: Term, + -- | The RHS event type of the refinement, i.e. @ev2@ above + refnEv2 :: Term, + -- | The LHS stack type of the refinement, i.e. @stack1@ above + refnStack1 :: Term, + -- | The RHS stack type of the refinement, i.e. @stack2@ above + refnStack2 :: Term, + -- | The LHS return type of the refinement, i.e. @rtp1@ above + refnRType1 :: Term, + -- | The RHS return type of the refinement, i.e. @rtp2@ above + refnRType2 :: Term, + -- | The LHS term of the refinement, i.e. @t1@ above + refnLHS :: Term, + -- | The RHS term of the refinement, i.e. @t2@ above + refnRHS :: Term +} + +-- | Recognizes a term of the form: +-- @(a1:A1) -> ... -> (an:An) -> refinesS ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2@ +-- and returns: +-- @RefinesS [(a1,A1), ..., (an,An)] ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2@ +asRefinesS :: Recognizer Term RefinesS +asRefinesS (asPiList -> (args, asApplyAll -> + (asGlobalDef -> Just "Prelude.refinesS", + [ev1, ev2, stack1, stack2, + asApplyAll -> (asGlobalDef -> Just "Prelude.eqPreRel", _), + asApplyAll -> (asGlobalDef -> Just "Prelude.eqPostRel", _), + rtp1, rtp2, + asApplyAll -> (asGlobalDef -> Just "Prelude.eqRR", _), + t1, t2]))) = + Just $ RefinesS args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2 +asRefinesS (asPiList -> (args, asApplyAll -> + (asGlobalDef -> Just "Prelude.refinesS_eq", + [ev, stack, rtp, t1, t2]))) = + Just $ RefinesS args ev ev stack stack rtp rtp t1 t2 +asRefinesS (asPiList -> (_, asApplyAll -> (asGlobalDef -> Just "Prelude.refinesS", _))) = + error "FIXME: MRSolver does not yet accept refinesS goals with non-trivial RPre/RPost/RR" +asRefinesS _ = Nothing + -- | The right-hand-side of a 'FunAssump': either a 'FunName' and arguments, if -- it is an opaque 'FunAsump', or a 'NormComp', if it is a rewrite 'FunAssump' data FunAssumpRHS = OpaqueFunAssump FunName [Term] @@ -74,28 +120,6 @@ data FunAssump t = FunAssump { fassumpAnnotation :: Maybe t } --- | Recognizes a term of the form: --- @(a1:A1) -> ... -> (an:An) -> refinesS_eq ev stack rtp t1 t2@, --- and returns a tuple: --- @([(a1,A1), ..., (an,An)], ev, ev, stack, stack, rtp, rtp, t1, t2)@ -asRefinesS :: Recognizer Term ([(LocalName, Term)], Term, Term, Term, Term, Term, Term, Term, Term) -asRefinesS (asPiList -> (args, asApplyAll -> - (asGlobalDef -> Just "Prelude.refinesS", - [ev1, ev2, stack1, stack2, - asApplyAll -> (asGlobalDef -> Just "Prelude.eqPreRel", _), - asApplyAll -> (asGlobalDef -> Just "Prelude.eqPostRel", _), - rtp1, rtp2, - asApplyAll -> (asGlobalDef -> Just "Prelude.eqRR", _), - t1, t2]))) = - Just (args, ev1, ev2, stack1, stack2, rtp1, rtp2, t1, t2) -asRefinesS (asPiList -> (args, asApplyAll -> - (asGlobalDef -> Just "Prelude.refinesS_eq", - [ev, stack, rtp, t1, t2]))) = - Just (args, ev, ev, stack, stack, rtp, rtp, t1, t2) -asRefinesS (asPiList -> (_, asApplyAll -> (asGlobalDef -> Just "Prelude.refinesS", _))) = - error "FIXME: MRSolver does not yet accept refinesS goals with non-trivial RPre/RPost/RR" -asRefinesS _ = Nothing - -- | Recognizes a term of the form: -- @(a1:A1) -> ... -> (an:An) -> refinesS_eq ev stack rtp (f b1 ... bm) t2@, -- and returns: @FunAssump f [a1,...,an] [b1,...,bm] rhs ann@, @@ -103,13 +127,12 @@ asRefinesS _ = Nothing -- @OpaqueFunAssump g [c1,...,cl]@ if @t2@ is @g c1 ... cl@, -- or @RewriteFunAssump t2@ otherwise asFunAssump :: Maybe t -> Recognizer Term (FunAssump t) -asFunAssump ann (asRefinesS -> Just (args, - asGlobalDef -> Just "Prelude.VoidEv", - asGlobalDef -> Just "Prelude.VoidEv", - asGlobalDef -> Just "Prelude.emptyFunStack", - asGlobalDef -> Just "Prelude.emptyFunStack", - _, _, - asApplyAll -> (asGlobalFunName -> Just f1, args1), +asFunAssump ann (asRefinesS -> Just (RefinesS args + (asGlobalDef -> Just "Prelude.VoidEv") + (asGlobalDef -> Just "Prelude.VoidEv") + (asGlobalDef -> Just "Prelude.emptyFunStack") + (asGlobalDef -> Just "Prelude.emptyFunStack") + _ _ (asApplyAll -> (asGlobalFunName -> Just f1, args1)) t2@(asApplyAll -> (asGlobalFunName -> mb_f2, args2)))) = let rhs = maybe (RewriteFunAssump t2) (\f2 -> OpaqueFunAssump f2 args2) mb_f2 in Just $ FunAssump { fassumpCtx = mrVarCtxFromOuterToInner args,