Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MrSolver] Add support for 'either's #1602

Merged
merged 9 commits into from
Mar 3, 2022
2 changes: 1 addition & 1 deletion heapster-saw/examples/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ rust_lifetimes.bc: rust_lifetimes.rs
rustc --crate-type=lib --emit=llvm-bc rust_lifetimes.rs

# Lists all the Mr Solver tests, without their ".saw" suffix
MR_SOLVER_TESTS = arrays_mr_solver
MR_SOLVER_TESTS = arrays_mr_solver linked_list_mr_solver

.PHONY: mr-solver-tests $(MR_SOLVER_TESTS)
mr-solver-tests: $(MR_SOLVER_TESTS)
Expand Down
Binary file modified heapster-saw/examples/linked_list.bc
Binary file not shown.
11 changes: 11 additions & 0 deletions heapster-saw/examples/linked_list.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ typedef struct list64_t {
struct list64_t *next;
} list64_t;

/* Test if a value is the head of a list, returning 1 if so and 0 otherwiese */
int64_t is_head (int64_t x, list64_t *l) {
if (l == NULL) {
return 0;
} else if (l->data == x) {
return 1;
} else {
return 0;
}
}

/* Test if a specific value is in a list, returning 1 if so and 0 otherwise */
int64_t is_elem (int64_t x, list64_t *l) {
if (l == NULL) {
Expand Down
26 changes: 26 additions & 0 deletions heapster-saw/examples/linked_list_mr_solver.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
include "linked_list.saw";

let eq_bool b1 b2 =
if b1 then
if b2 then true else false
else
if b2 then false else true;

let fail = do { print "Test failed"; exit 1; };
let run_test name test expected =
do { if expected then print (str_concat "Test: " name) else
print (str_concat (str_concat "Test: " name) " (expecting failure)");
actual <- test;
if eq_bool actual expected then print "Success\n" else
do { print "Test failed\n"; exit 1; }; };


heapster_typecheck_fun env "is_head"
"(). arg0:int64<>, arg1:List<int64<>,always,R> -o \
\ arg0:true, arg1:true, ret:int64<>";

is_head <- parse_core_mod "linked_list" "is_head";
run_test "is_head |= is_head" (mr_solver is_head is_head) true;

is_elem <- parse_core_mod "linked_list" "is_elem";
run_test "is_elem |= is_elem" (mr_solver is_elem is_elem) true;
87 changes: 72 additions & 15 deletions src/SAWScript/Prover/MRSolver/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}

{- |
Module : SAWScript.Prover.MRSolver.Monad
Expand All @@ -26,10 +29,14 @@ import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Except
import Control.Monad.Trans.Maybe
import GHC.Generics

import Data.Map (Map)
import qualified Data.Map as Map

import Data.HashMap.Lazy (HashMap)
import qualified Data.HashMap.Lazy as HashMap

import Prettyprinter

import Verifier.SAW.Term.Functor
Expand Down Expand Up @@ -226,6 +233,25 @@ data FunAssump = FunAssump {
-- name
type FunAssumps = Map FunName FunAssump

-- | An assumption that something is equal to one of the constructors of a
-- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term'
data DataTypeAssump = IsLeft Term | IsRight Term
deriving (Generic, Show, TermLike)

instance PrettyInCtx DataTypeAssump where
prettyInCtx (IsLeft x) = prettyInCtx x >>= ppWithPrefix "Left _ _"
prettyInCtx (IsRight x) = prettyInCtx x >>= ppWithPrefix "Right _ _"

-- | Recognize a term as a @Left@ or @Right@
asEither :: Recognizer Term (Either Term Term)
asEither (asCtor -> Just (c, [_, _, x]))
| primName c == "Prelude.Left" = return $ Left x
| primName c == "Prelude.Right" = return $ Right x
asEither _ = Nothing

-- | A map from 'Term's to 'DataTypeAssump's over that term
type DataTypeAssumps = HashMap Term DataTypeAssump

-- | Parameters and locals for MR. Solver
data MRInfo = MRInfo {
-- | Global shared context for building terms, etc.
Expand All @@ -243,6 +269,8 @@ data MRInfo = MRInfo {
-- | The current assumptions, which are conjoined into a single Boolean term;
-- note that these have the current UVars free
mriAssumptions :: Term,
-- | The current set of 'DataTypeAssump's
mriDataTypeAssumps :: DataTypeAssumps,
-- | The debug level, which controls debug printing
mriDebugLevel :: Int
}
Expand All @@ -264,9 +292,9 @@ data MRExn = MRExnFailure MRFailure
-- type, all over an 'IO' monad
newtype MRM a = MRM { unMRM :: ReaderT MRInfo (StateT MRState
(ExceptT MRExn IO)) a }
deriving (Functor, Applicative, Monad, MonadIO,
MonadReader MRInfo, MonadState MRState,
MonadError MRExn)
deriving newtype (Functor, Applicative, Monad, MonadIO,
MonadReader MRInfo, MonadState MRState,
MonadError MRExn)

instance MonadTerm MRM where
mkTermF = liftSC1 scTermF
Expand All @@ -278,31 +306,35 @@ instance MonadTerm MRM where
mrSC :: MRM SharedContext
mrSC = mriSC <$> ask

-- | Get the current value of 'mrSMTTimeout'
-- | Get the current value of 'mriSMTTimeout'
mrSMTTimeout :: MRM (Maybe Integer)
mrSMTTimeout = mriSMTTimeout <$> ask

-- | Get the current value of 'mrUVars'
-- | Get the current value of 'mriUVars'
mrUVars :: MRM [(LocalName,Type)]
mrUVars = mriUVars <$> ask

-- | Get the current value of 'mrFunAssumps'
-- | Get the current value of 'mriFunAssumps'
mrFunAssumps :: MRM FunAssumps
mrFunAssumps = mriFunAssumps <$> ask

-- | Get the current value of 'mrCoIndHyps'
-- | Get the current value of 'mriCoIndHyps'
mrCoIndHyps :: MRM CoIndHyps
mrCoIndHyps = mriCoIndHyps <$> ask

-- | Get the current value of 'mrAssumptions'
-- | Get the current value of 'mriAssumptions'
mrAssumptions :: MRM Term
mrAssumptions = mriAssumptions <$> ask

-- | Get the current value of 'mrDebugLevel'
-- | Get the current value of 'mriDataTypeAssumps'
mrDataTypeAssumps :: MRM DataTypeAssumps
mrDataTypeAssumps = mriDataTypeAssumps <$> ask

-- | Get the current value of 'mriDebugLevel'
mrDebugLevel :: MRM Int
mrDebugLevel = mriDebugLevel <$> ask

-- | Get the current value of 'mrVars'
-- | Get the current value of 'mrsVars'
mrVars :: MRM MRVarMap
mrVars = mrsVars <$> get

Expand All @@ -314,7 +346,8 @@ runMRM sc timeout debug assumps m =
let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout,
mriDebugLevel = debug, mriFunAssumps = assumps,
mriUVars = [], mriCoIndHyps = Map.empty,
mriAssumptions = true_tm }
mriAssumptions = true_tm,
mriDataTypeAssumps = HashMap.empty }
let init_st = MRState { mrsVars = Map.empty }
res <- runExceptT $ flip evalStateT init_st $
flip runReaderT init_info $ unMRM m
Expand Down Expand Up @@ -453,7 +486,7 @@ mrFunOutType :: FunName -> [Term] -> MRM Term
mrFunOutType fname args =
funNameType fname >>= \case
(asPiList -> (vars, asCompM -> Just tp))
| length vars == length args -> substTermLike 0 args tp
| length vars == length args -> substTermLike 0 (reverse args) tp
ftype@(asPiList -> (vars, _)) ->
do pp_ftype <- mrPPInCtx ftype
pp_fname <- mrPPInCtx fname
Expand Down Expand Up @@ -503,16 +536,19 @@ withUVars ctx f =
do nms <- uniquifyNames (map fst ctx) <$> map fst <$> mrUVars
let ctx_u = zip nms $ map (Type . snd) ctx
assumps' <- mrAssumptions >>= liftTerm 0 (length ctx)
dataTypeAssumps' <- mrDataTypeAssumps >>= mapM (liftTermLike 0 (length ctx))
vars <- reverse <$> mapM (liftSC1 scLocalVar) [0 .. length ctx - 1]
local (\info -> info { mriUVars = reverse ctx_u ++ mriUVars info,
mriAssumptions = assumps' }) $
mriAssumptions = assumps',
mriDataTypeAssumps = dataTypeAssumps' }) $
foldr (\nm m -> mapMRFailure (MRFailureLocalVar nm) m) (f vars) nms

-- | Run a MR Solver in a top-level context, i.e., with no uvars or assumptions
withNoUVars :: MRM a -> MRM a
withNoUVars m =
do true_tm <- liftSC1 scBool True
local (\info -> info { mriUVars = [], mriAssumptions = true_tm }) m
local (\info -> info { mriUVars = [], mriAssumptions = true_tm,
mriDataTypeAssumps = HashMap.empty }) m

-- | Run a MR Solver in a context of only the specified UVars, no others
withOnlyUVars :: [(LocalName,Term)] -> MRM a -> MRM a
Expand Down Expand Up @@ -799,10 +835,23 @@ instantiateFunAssump fassump =
-- executing a sub-computation
withAssumption :: Term -> MRM a -> MRM a
withAssumption phi m =
do assumps <- mrAssumptions
do mrDebugPPPrefix 1 "withAssumption" phi
assumps <- mrAssumptions
assumps' <- liftSC2 scAnd phi assumps
local (\info -> info { mriAssumptions = assumps' }) m

-- | Add a 'DataTypeAssump' to the current context while executing a
-- sub-computations
withDataTypeAssump :: Term -> DataTypeAssump -> MRM a -> MRM a
withDataTypeAssump x assump m =
do mrDebugPPPrefixSep 1 "withDataTypeAssump" x "==" assump
dataTypeAssumps' <- HashMap.insert x assump <$> mrDataTypeAssumps
local (\info -> info { mriDataTypeAssumps = dataTypeAssumps' }) m

-- | Get the 'DataTypeAssump' associated to the given term, if one exists
mrGetDataTypeAssump :: Term -> MRM (Maybe DataTypeAssump)
mrGetDataTypeAssump x = HashMap.lookup x <$> mrDataTypeAssumps

-- | Print a 'String' if the debug level is at least the supplied 'Int'
debugPrint :: Int -> String -> MRM ()
debugPrint i str =
Expand All @@ -824,6 +873,14 @@ mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc
mrPPInCtx a =
runReader (prettyInCtx a) <$> map fst <$> mrUVars

-- | Pretty-print the result of 'ppWithPrefix' relative to the current uvar
-- context to 'stderr' if the debug level is at least the 'Int' provided
mrDebugPPPrefix :: PrettyInCtx a => Int -> String -> a -> MRM ()
mrDebugPPPrefix i pre a =
mrUVars >>= \ctx ->
debugPretty i $
flip runReader (map fst ctx) (group <$> nest 2 <$> ppWithPrefix pre a)

-- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar
-- context to 'stderr' if the debug level is at least the 'Int' provided
mrDebugPPPrefixSep :: (PrettyInCtx a, PrettyInCtx b) =>
Expand Down
40 changes: 25 additions & 15 deletions src/SAWScript/Prover/MRSolver/SMT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -181,21 +181,29 @@ mrProvable (asBool -> Just b) = return b
mrProvable bool_tm =
do assumps <- mrAssumptions
prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue
prop_inst <- flip instantiateUVarsM prop $ \nm tp ->
liftSC1 scWhnf tp >>= \case
(asBVVecType -> Just (n, len, a)) ->
-- For variables of type BVVec, create a Vec n Bool -> a function as an
-- ExtCns and apply genBVVec to it
do
ec_tp <-
liftSC1 completeOpenTerm $
arrowOpenTerm "_" (applyOpenTermMulti (globalOpenTerm "Prelude.Vec")
[closedOpenTerm n, boolTypeOpenTerm])
(closedOpenTerm a)
ec <- liftSC2 scFreshEC nm ec_tp >>= liftSC1 scExtCns
liftSC4 genBVVecTerm n len a ec
tp' -> liftSC2 scFreshEC nm tp' >>= liftSC1 scExtCns
prop_inst <- instantiateUVarsM instUVar prop
normSMTProp prop_inst >>= mrProvableRaw
where -- | Given a UVar name and type, generate a 'Term' to be passed to
-- SMT, with special cases for BVVec and pair types
instUVar :: LocalName -> Term -> MRM Term
instUVar nm tp = liftSC1 scWhnf tp >>= \case
-- For variables of type BVVec, create a @Vec n Bool -> a@ function
-- as an ExtCns and apply genBVVec to it
(asBVVecType -> Just (n, len, a)) -> do
ec_tp <-
liftSC1 completeOpenTerm $
arrowOpenTerm "_" (applyOpenTermMulti (globalOpenTerm "Prelude.Vec")
[closedOpenTerm n, boolTypeOpenTerm])
(closedOpenTerm a)
ec <- instUVar nm ec_tp
liftSC4 genBVVecTerm n len a ec
-- For pairs, recurse on both sides and combine the result as a pair
(asPairType -> Just (tp1, tp2)) -> do
e1 <- instUVar nm tp1
e2 <- instUVar nm tp2
liftSC2 scPairValue e1 e2
-- Otherwise, create a global variable with the given name and type
tp' -> liftSC2 scFreshEC nm tp' >>= liftSC1 scExtCns


----------------------------------------------------------------------
Expand Down Expand Up @@ -269,7 +277,9 @@ mrProveEq t1 t2 =
tp <- mrTypeOf t1
varmap <- mrVars
cond_in_ctx <- mrProveEqH varmap tp t1 t2
withTermInCtx cond_in_ctx mrProvable
res <- withTermInCtx cond_in_ctx mrProvable
debugPrint 1 $ "mrProveEq: " ++ if res then "Success" else "Failure"
return res

-- | Prove that two terms are equal, instantiating evars if necessary, or
-- throwing an error if this is not possible
Expand Down
Loading