diff --git a/cryptol-saw-core/cryptol-saw-core.cabal b/cryptol-saw-core/cryptol-saw-core.cabal index a0d14d15f3..6385b341ee 100644 --- a/cryptol-saw-core/cryptol-saw-core.cabal +++ b/cryptol-saw-core/cryptol-saw-core.cabal @@ -28,6 +28,7 @@ library data-inttrie >= 0.1.4, integer-gmp, modern-uri, + mtl, panic, saw-core, saw-core-aig, @@ -44,7 +45,9 @@ library Verifier.SAW.Cryptol Verifier.SAW.Cryptol.Panic Verifier.SAW.Cryptol.Prelude + Verifier.SAW.Cryptol.PreludeM Verifier.SAW.Cryptol.Simpset + Verifier.SAW.Cryptol.Monadify Verifier.SAW.CryptolEnv Verifier.SAW.TypedTerm GHC-options: -Wall -Werror diff --git a/cryptol-saw-core/saw/Cryptol.sawcore b/cryptol-saw-core/saw/Cryptol.sawcore index fb75134bbc..fe13215f39 100644 --- a/cryptol-saw-core/saw/Cryptol.sawcore +++ b/cryptol-saw-core/saw/Cryptol.sawcore @@ -1886,5 +1886,3 @@ axiom demote_add_distr (ecNumber (tcAdd x y) (TCNum w)) (bvAdd w (ecNumber x (TCNum w)) (ecNumber y (TCNum w))); -} - --------------------------------------------------------------------------------- diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore new file mode 100644 index 0000000000..452fb45cbc --- /dev/null +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -0,0 +1,395 @@ +------------------------------------------------------------------------------- +-- Cryptol primitives for SAWCore + +module CryptolM where + +-- import Prelude; +import Cryptol; + + +-------------------------------------------------------------------------------- +-- Monadic assertions + +primitive proveEqNum : (n m:Num) -> Maybe (Eq Num n m); + +-- A version of unsafeAssert specialized to the Num type +numAssertEqM : (n m:Num) -> CompM (Eq Num n m); +numAssertEqM n m = + maybe (Eq Num n m) (CompM (Eq Num n m)) + (errorM (Eq Num n m) "numAssertEqM: assertion failed") + (returnM (Eq Num n m)) + (proveEqNum n m); + +-- A proof that a Num is finite +isFinite : Num -> Prop; +isFinite = Num_rec (\ (_:Num) -> Prop) (\ (_:Nat) -> TrueProp) FalseProp; + +-- Assert that a Num is finite, or fail +assertFiniteM : (n:Num) -> CompM (isFinite n); +assertFiniteM = + Num_rec (\ (n:Num) -> CompM (isFinite n)) + (\ (_:Nat) -> returnM TrueProp TrueI) + (errorM FalseProp "assertFiniteM: Num not finite"); + +-- Recurse over a Num known to be finite +Num_rec_fin : (p: Num -> sort 1) -> ((n:Nat) -> p (TCNum n)) -> + (n:Num) -> isFinite n -> p n; +Num_rec_fin p f = + Num_rec (\ (n:Num) -> isFinite n -> p n) + (\ (n:Nat) (_:TrueProp) -> f n) + (efq1 (p TCInf)); + + +-------------------------------------------------------------------------------- +-- Monadic Sequences + +-- The type of monadified sequences, which are just vectors for finite length +-- but are sequences of computations for streams +mseq : Num -> sort 0 -> sort 0; +mseq num a = + Num_rec (\ (_:Num) -> sort 0) (\ (n:Nat) -> Vec n a) (Stream (CompM a)) num; + +vecMapM : (a b : isort 0) -> (n : Nat) -> (a -> CompM b) -> Vec n a -> + CompM (Vec n b); +vecMapM a b n_top f = + Nat__rec (\ (n:Nat) -> Vec n a -> CompM (Vec n b)) + (\ (_:Vec 0 a) -> returnM (Vec 0 b) (EmptyVec b)) + (\ (n:Nat) (rec:Vec n a -> CompM (Vec n b)) (v:Vec (Succ n) a) -> + fmapM2 b (Vec n b) (Vec (Succ n) b) + (\ (x:b) (xs:Vec n b) -> ConsVec b x n xs) + (f (head n a v)) (rec (tail n a v))) + n_top; + +-- Computational version of seqMap +seqMapM : (a b : sort 0) -> (n : Num) -> (a -> CompM b) -> mseq n a -> + CompM (mseq n b); +seqMapM a b n_top f = + Num_rec (\ (n:Num) -> mseq n a -> CompM (mseq n b)) + (\ (n:Nat) -> vecMapM a b n f) + (\ (s:Stream (CompM a)) -> + returnM (Stream (CompM b)) + (streamMap (CompM a) (CompM b) + (\ (m:CompM a) -> bindM a b m f) s)) + n_top; + +mseq_cong1 : (m : Num) -> (n : Num) -> (a : sort 0) -> + Eq Num m n -> Eq (sort 0) (mseq m a) (mseq n a); +mseq_cong1 m n a eq_mn = + eq_cong Num m n eq_mn (sort 0) (\ (x:Num) -> mseq x a); + +-- Convert a seq to an mseq +seqToMseq : (n:Num) -> (a:sort 0) -> seq n a -> mseq n a; +seqToMseq n_top a = + Num_rec (\ (n:Num) -> seq n a -> mseq n a) + (\ (n:Nat) (v:Vec n a) -> v) + (streamMap a (CompM a) (returnM a)) + n_top; + + +-------------------------------------------------------------------------------- +-- Auxiliary functions + +atM : (n : Nat) -> (a : sort 0) -> Vec n a -> Nat -> CompM a; +atM n_top a = + Nat__rec + (\ (n:Nat) -> Vec n a -> Nat -> CompM a) + (\ (_:Vec 0 a) (_:Nat) -> errorM a "atM: index out of bounds") + (\ (n:Nat) (rec_f: Vec n a -> Nat -> CompM a) (v:Vec (Succ n) a) (i:Nat) -> + Nat_cases (CompM a) + (returnM a (head n a v)) + (\ (i_prev:Nat) (_:CompM a) -> rec_f (tail n a v) i_prev) i) + n_top; + + +eListSelM : (a : isort 0) -> (n : Num) -> mseq n a -> Nat -> CompM a; +eListSelM a = + Num_rec (\ (n:Num) -> mseq n a -> Nat -> CompM a) + (\ (n:Nat) -> atM n a) + (eListSel (CompM a) TCInf); + + +-------------------------------------------------------------------------------- +-- List comprehensions + +-- FIXME +primitive +fromM : (a b : sort 0) -> (m n : Num) -> mseq m a -> (a -> CompM (mseq n b)) -> + CompM (seq (tcMul m n) (a * b)); + +-- FIXME +primitive +mletM : (a b : sort 0) -> (n : Num) -> a -> (a -> CompM (mseq n b)) -> + CompM (mseq n (a * b)); + +-- FIXME +primitive +seqZipM : (a b : sort 0) -> (m n : Num) -> mseq m a -> mseq n b -> + CompM (mseq (tcMin m n) (a * b)); +{- +seqZipM a b m n ms1 ms2 = + seqMap + (CompM a * CompM b) (CompM (a * b)) (tcMin m n) + (\ (p : CompM a * CompM b) -> + bindM2 a b (a*b) p.(1) p.(2) (\ (x:a) (y:b) -> returnM (a*b) (x,y))) + (seqZip (CompM a) (CompM b) m n ms1 ms2); +-} + + +-------------------------------------------------------------------------------- +-- Monadic versions of the Cryptol typeclass instances + +-- PEq +PEqMSeq : (n:Num) -> isFinite n -> (a:isort 0) -> PEq a -> PEq (mseq n a); +PEqMSeq = + Num_rec_fin (\ (n:Num) -> (a:isort 0) -> PEq a -> PEq (mseq n a)) + (\ (n:Nat) -> PEqVec n); + +PEqMSeqBool : (n : Num) -> isFinite n -> PEq (mseq n Bool); +PEqMSeqBool = + Num_rec_fin (\ (n:Num) -> PEq (mseq n Bool)) PEqWord; + +-- PCmp +PCmpMSeq : (n:Num) -> isFinite n -> (a:isort 0) -> PCmp a -> PCmp (mseq n a); +PCmpMSeq = + Num_rec_fin (\ (n:Num) -> (a:isort 0) -> PCmp a -> PCmp (mseq n a)) + (\ (n:Nat) -> PCmpVec n); + +PCmpMSeqBool : (n : Num) -> isFinite n -> PCmp (seq n Bool); +PCmpMSeqBool = + Num_rec_fin (\ (n:Num) -> PCmp (seq n Bool)) PCmpWord; + +-- PSignedCmp +PSignedCmpMSeq : (n:Num) -> isFinite n -> (a:isort 0) -> PSignedCmp a -> + PSignedCmp (mseq n a); +PSignedCmpMSeq = + Num_rec_fin (\ (n:Num) -> (a:isort 0) -> PSignedCmp a -> + PSignedCmp (mseq n a)) + (\ (n:Nat) -> PSignedCmpVec n); + +PSignedCmpMSeqBool : (n : Num) -> isFinite n -> PSignedCmp (seq n Bool); +PSignedCmpMSeqBool = + Num_rec_fin (\ (n:Num) -> PSignedCmp (seq n Bool)) PSignedCmpWord; + + +-- PZero +PZeroCompM : (a : sort 0) -> PZero a -> PZero (CompM a); +PZeroCompM = returnM; + +PZeroMSeq : (n : Num) -> (a : sort 0) -> PZero a -> PZero (mseq n a); +PZeroMSeq n_top a pa = + Num_rec (\ (n:Num) -> PZero (mseq n a)) + (\ (n:Nat) -> seqConst (TCNum n) a pa) + (seqConst TCInf (CompM a) (returnM a pa)) + n_top; + +-- PLogic +PLogicCompM : (a : sort 0) -> PLogic a -> PLogic (CompM a); +PLogicCompM a pa = + { logicZero = returnM a (pa.logicZero) + , and = fmapM2 a a a (pa.and) + , or = fmapM2 a a a (pa.or) + , xor = fmapM2 a a a (pa.xor) + , not = fmapM a a (pa.not) + }; + +PLogicMSeq : (n : Num) -> (a : isort 0) -> PLogic a -> PLogic (mseq n a); +PLogicMSeq n_top a pa = + Num_rec (\ (n:Num) -> PLogic (mseq n a)) + (\ (n:Nat) -> PLogicVec n a pa) + (PLogicStream (CompM a) (PLogicCompM a pa)) + n_top; + +PLogicMSeqBool : (n : Num) -> isFinite n -> PLogic (mseq n Bool); +PLogicMSeqBool = + Num_rec_fin (\ (n:Num) -> PLogic (mseq n Bool)) PLogicWord; + +-- PRing +PRingCompM : (a : sort 0) -> PRing a -> PRing (CompM a); +PRingCompM a pa = + { ringZero = returnM a (pa.ringZero) + , add = fmapM2 a a a (pa.add) + , sub = fmapM2 a a a (pa.sub) + , mul = fmapM2 a a a (pa.mul) + , neg = fmapM a a (pa.neg) + , int = \ (i : Integer) -> returnM a (pa.int i) + }; + +PRingMSeq : (n : Num) -> (a : isort 0) -> PRing a -> PRing (mseq n a); +PRingMSeq n_top a pa = + Num_rec (\ (n:Num) -> PRing (mseq n a)) + (\ (n:Nat) -> PRingVec n a pa) + (PRingStream (CompM a) (PRingCompM a pa)) + n_top; + +PRingMSeqBool : (n : Num) -> isFinite n -> PRing (mseq n Bool); +PRingMSeqBool = + Num_rec_fin (\ (n:Num) -> PRing (mseq n Bool)) PRingWord; + +-- Integral +PIntegralMSeqBool : (n : Num) -> isFinite n -> PIntegral (mseq n Bool); +PIntegralMSeqBool = + Num_rec_fin (\ (n:Num) -> PIntegral (mseq n Bool)) PIntegralWord; + +-- PLiteral +PLiteralSeqBoolM : (n : Num) -> isFinite n -> PLiteral (mseq n Bool); +PLiteralSeqBoolM = + Num_rec_fin (\ (n:Num) -> PLiteral (mseq n Bool)) bvNat; + + +-------------------------------------------------------------------------------- +-- Monadic versions of the Cryptol primitives + + +-- Sequences + +-- FIXME: a number of the non-monadic versions of these functions contain calls +-- to finNumRec, which calls error on non-finite numbers. The monadic versions +-- of these, below, should be reimplemented to not contain finNumRec, but to +-- just use Num_rec_fin directly, rather than using it and then calling out to +-- the non-monadic version with finNumRec. + +ecShiftLM : (m : Num) -> (ix a : sort 0) -> PIntegral ix -> PZero a -> + mseq m a -> ix -> mseq m a; +ecShiftLM = + Num_rec (\ (m:Num) -> (ix a : sort 0) -> PIntegral ix -> PZero a -> + mseq m a -> ix -> mseq m a) + (\ (m:Nat) -> ecShiftL (TCNum m)) + (\ (ix a : sort 0) (pix:PIntegral ix) (pa:PZero a) -> + ecShiftL TCInf ix (CompM a) pix (PZeroCompM a pa)); + +ecShiftRM : (m : Num) -> (ix a : sort 0) -> PIntegral ix -> PZero a -> + mseq m a -> ix -> mseq m a; +ecShiftRM = + Num_rec (\ (m:Num) -> (ix a : sort 0) -> PIntegral ix -> PZero a -> + mseq m a -> ix -> mseq m a) + (\ (m:Nat) -> ecShiftL (TCNum m)) + (\ (ix a : sort 0) (pix:PIntegral ix) (pa:PZero a) -> + ecShiftR TCInf ix (CompM a) pix (PZeroCompM a pa)); + +ecSShiftRM : (n : Num) -> isFinite n -> (ix : sort 0) -> PIntegral ix -> + mseq n Bool -> ix -> mseq n Bool; +ecSShiftRM = + Num_rec_fin + (\ (n:Num) -> (ix : sort 0) -> PIntegral ix -> mseq n Bool -> ix -> + mseq n Bool) + (\ (n:Nat) -> ecSShiftR (TCNum n)); + +ecRotLM : (m : Num) -> isFinite m -> (ix a : sort 0) -> PIntegral ix -> + mseq m a -> ix -> mseq m a; +ecRotLM = + Num_rec_fin + (\ (m:Num) -> (ix a : sort 0) -> PIntegral ix -> mseq m a -> ix -> mseq m a) + (\ (m:Nat) -> ecRotL (TCNum m)); + +ecRotRM : (m : Num) -> isFinite m -> (ix a : sort 0) -> PIntegral ix -> + mseq m a -> ix -> mseq m a; +ecRotRM = + Num_rec_fin + (\ (m:Num) -> (ix a : sort 0) -> PIntegral ix -> mseq m a -> ix -> mseq m a) + (\ (m:Nat) -> ecRotR (TCNum m)); + +ecCatM : (m : Num) -> isFinite m -> (n : Num) -> (a : sort 0) -> + mseq m a -> mseq n a -> mseq (tcAdd m n) a; +ecCatM = + Num_rec_fin + (\ (m:Num) -> (n:Num) -> (a:sort 0) -> mseq m a -> mseq n a -> + mseq (tcAdd m n) a) + (\ (m:Nat) -> + Num_rec + (\ (n:Num) -> (a:isort 0) -> Vec m a -> mseq n a -> + mseq (tcAdd (TCNum m) n) a) + -- Case for (TCNum m, TCNum n) + (\ (n:Nat) -> \ (a:isort 0) -> append m n a) + -- Case for (TCNum m, TCInf) + (\ (a:isort 0) (v:Vec m a) -> + streamAppend (CompM a) m (map a (CompM a) (returnM a) m v))); + +-- FIXME +primitive +ecTakeM : (m n : Num) -> (a : sort 0) -> mseq (tcAdd m n) a -> mseq m a; +{- +ecTakeM = + Num_rec (\ (m:Num) -> (n:Num) -> (a:sort 0) -> mseq (tcAdd m n) a -> mseq m a) + (\ (m:Nat) -> ecTake (TCNum m)) + (\ (n:Num) (a:sort 0) (s:Stream (CompM a)) -> + ecTake TCInf n (CompM a) s); +-} + +-- FIXME +primitive +ecDropM : (m : Num) -> isFinite m -> (n : Num) -> (a : sort 0) -> + mseq (tcAdd m n) a -> mseq n a; + +-- FIXME +primitive +ecJoinM : (m n : Num) -> (a : sort 0) -> mseq m (mseq n a) -> mseq (tcMul m n) a; + +-- FIXME +primitive +ecSplitM : (m n : Num) -> (a : sort 0) -> mseq (tcMul m n) a -> + mseq m (mseq n a); + +ecReverseM : (n : Num) -> isFinite n -> (a : sort 0) -> mseq n a -> mseq n a; +ecReverseM = + Num_rec_fin (\ (n:Num) -> (a : sort 0) -> mseq n a -> mseq n a) + (\ (n:Nat) -> ecReverse (TCNum n)); + +-- FIXME +primitive +ecTransposeM : (m n : Num) -> (a : sort 0) -> mseq m (mseq n a) -> + mseq n (mseq m a); + +ecAtM : (n : Num) -> (a ix: sort 0) -> PIntegral ix -> mseq n a -> ix -> CompM a; +ecAtM n_top a ix pix = + Num_rec + (\ (n:Num) -> mseq n a -> ix -> CompM a) + (\ (n:Nat) (v:Vec n a) -> + pix.posNegCases (CompM a) (atM n a v) (\ (_:Nat) -> atM n a v 0)) + (\ (s:Stream (CompM a)) -> + pix.posNegCases (CompM a) (streamGet (CompM a) s) + (\ (_:Nat) -> (streamGet (CompM a) s) 0)) + n_top; + +-- FIXME +primitive +ecAtBackM : (n : Num) -> isFinite n -> (a ix : sort 0) -> PIntegral ix -> + mseq n a -> ix -> CompM a; + +-- FIXME +primitive +ecFromToM : (first : Num) -> isFinite first -> (last : Num) -> isFinite last -> + (a : isort 0) -> PLiteral a -> + mseq (tcAdd (TCNum 1) (tcSub last first)) a; + +-- FIXME +primitive +ecFromToLessThanM : (first : Num) -> isFinite first -> (bound : Num) -> + (a : isort 0) -> PLiteralLessThan a -> + mseq (tcSub bound first) a; + +-- FIXME +primitive +ecFromThenToM : + (first next last : Num) -> (a : sort 0) -> (len : Num) -> isFinite len -> + PLiteral a -> PLiteral a -> PLiteral a -> mseq len a; + +ecInfFromM : (a : sort 0) -> PIntegral a -> a -> mseq TCInf a; +ecInfFromM a pa x = + MkStream (CompM a) + (\ (i : Nat) -> + returnM a (pa.integralRing.add x (pa.integralRing.int (natToInt i)))); + +ecInfFromThenM : (a : sort 0) -> PIntegral a -> a -> a -> mseq TCInf a; +ecInfFromThenM a pa x y = + MkStream (CompM a) + (\ (i : Nat) -> + returnM a (pa.integralRing.add x + (pa.integralRing.mul (pa.integralRing.sub y x) + (pa.integralRing.int (natToInt i))))); + +ecErrorM : (a : sort 0) -> (len : Num) -> mseq len (Vec 8 Bool) -> CompM a; +ecErrorM a len msg = + errorM a "encountered call to the Cryptol 'error' function"; + + +-------------------------------------------------------------------------------- diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs index fd901b4e2d..acf988b2c5 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs @@ -4,6 +4,8 @@ {-# LANGUAGE PatternGuards #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE BangPatterns #-} {- | Module : Verifier.SAW.Cryptol @@ -17,6 +19,7 @@ Portability : non-portable (language extensions) module Verifier.SAW.Cryptol where import Control.Monad (foldM, join, unless) +import Control.Exception (catch, SomeException) import Data.Bifunctor (first) import qualified Data.Foldable as Fold import Data.List @@ -1646,9 +1649,14 @@ asCryptolTypeValue v = scCryptolType :: SharedContext -> Term -> IO (Maybe (Either C.Kind C.Type)) scCryptolType sc t = do modmap <- scGetModuleMap sc - case SC.evalSharedTerm modmap Map.empty Map.empty t of - SC.TValue tv -> return (asCryptolTypeValue tv) - _ -> return Nothing + catch + (case SC.evalSharedTerm modmap Map.empty Map.empty t of + -- NOTE: we make sure that asCryptolTypeValue gets evaluated, to + -- ensure that any panics in the simulator get caught here + SC.TValue tv + | Just !ret <- asCryptolTypeValue tv -> return $ Just ret + _ -> return Nothing) + (\ (_::SomeException) -> return Nothing) -- | Convert from SAWCore's Value type to Cryptol's, guided by the -- Cryptol type schema. diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs new file mode 100644 index 0000000000..526990782e --- /dev/null +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -0,0 +1,1199 @@ +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternGuards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} + +{- | +Module : Verifier.SAW.Cryptol.Monadify +Copyright : Galois, Inc. 2021 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module implements a "monadification" transformation, which converts "pure" +SAW core terms that use inconsistent operations like @fix@ and convert them to +monadic SAW core terms that use monadic versions of these operations that are +consistent. The monad that is used is the @CompM@ monad that is axiomatized in +the SAW cxore prelude. This is only a partial transformation, meaning that it +will fail on some SAW core terms. Specifically, it requires that all +applications @f arg@ in a term either have a non-dependent function type for @f@ +(i.e., a function with type @'Pi' x a b@ where @x@ does not occur in @b@) or a +pure argument @arg@ that does not use any of the inconsistent operations. + +FIXME: explain this better + + +Type-level translation: + +MT(Pi x (sort 0) b) = Pi x (sort 0) CompMT(b) +MT(Pi x Num b) = Pi x Num CompMT(b) +MT(Pi _ a b) = MT(a) -> CompMT(b) +MT(#(a,b)) = #(MT(a),MT(b)) +MT(seq n a) = mseq n MT(a) +MT(f arg) = f MT(arg) -- NOTE: f must be a pure function! +MT(cnst) = cnst +MT(dt args) = dt MT(args) +MT(x) = x +MT(_) = error + +CompMT(tp = Pi _ _ _) = MT(tp) +CompMT(n : Num) = n +CompMT(tp) = CompM MT(tp) + + +Term-level translation: + +MonArg(t : tp) ==> MT(tp) +MonArg(t) = + case Mon(t) of + m : CompM MT(a) => shift \k -> m >>= \x -> k x + _ => t + +Mon(t : tp) ==> MT(tp) or CompMT(tp) (which are the same type for pis) +Mon((f : Pi x a b) arg) = Mon(f) MT(arg) +Mon((f : Pi _ a b) arg) = Mon(f) MonArg(arg) +Mon(Lambda x a t) = Lambda x MT(a) Mon(t) +Mon((t,u)) = (MonArg(t),MonArg(u)) +Mon(c args) = c MonArg(args) +Mon(x) = x +Mon(fix) = fixM (of some form...) +Mon(cnst) = cnstM if cnst is impure and monadifies to constM +Mon(cnst) = cnst otherwise +-} + +module Verifier.SAW.Cryptol.Monadify where + +import Data.Maybe +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import Data.IntMap.Strict (IntMap) +import qualified Data.IntMap.Strict as IntMap +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Cont +import qualified Control.Monad.Fail as Fail +-- import Control.Monad.IO.Class (MonadIO, liftIO) +import qualified Data.Text as T +import qualified Text.URI as URI + +import Verifier.SAW.Name +import Verifier.SAW.Term.Functor +import Verifier.SAW.SharedTerm +import Verifier.SAW.OpenTerm +-- import Verifier.SAW.SCTypeCheck +import Verifier.SAW.Recognizer +-- import Verifier.SAW.Position +import Verifier.SAW.Cryptol.PreludeM + +import Debug.Trace + + +-- Type-check the Prelude, Cryptol, and CryptolM modules at compile time +{- +import Language.Haskell.TH +import Verifier.SAW.Cryptol.Prelude + +$(runIO (mkSharedContext >>= \sc -> + scLoadPreludeModule sc >> scLoadCryptolModule sc >> + scLoadCryptolMModule sc >> return [])) +-} + + +---------------------------------------------------------------------- +-- * Typing All Subterms +---------------------------------------------------------------------- + +-- | A SAW core term where all of the subterms are typed +data TypedSubsTerm + = TypedSubsTerm { tpSubsIndex :: Maybe TermIndex, + tpSubsFreeVars :: BitSet, + tpSubsTermF :: TermF TypedSubsTerm, + tpSubsTypeF :: TermF TypedSubsTerm, + tpSubsSort :: Sort } + +-- | Convert a 'Term' to a 'TypedSubsTerm' +typeAllSubterms :: SharedContext -> Term -> IO TypedSubsTerm +typeAllSubterms = error "FIXME HERE" + +-- | Convert a 'TypedSubsTerm' back to a 'Term' +typedSubsTermTerm :: TypedSubsTerm -> Term +typedSubsTermTerm = error "FIXME HERE" + +-- | Get the type of a 'TypedSubsTerm' as a 'TypedSubsTerm' +typedSubsTermType :: TypedSubsTerm -> TypedSubsTerm +typedSubsTermType tst = + TypedSubsTerm { tpSubsIndex = Nothing, tpSubsFreeVars = tpSubsFreeVars tst, + tpSubsTermF = tpSubsTypeF tst, + tpSubsTypeF = FTermF (Sort (tpSubsSort tst) False), + tpSubsSort = sortOf (tpSubsSort tst) } + +-- | Count the number of right-nested pi-abstractions of a 'TypedSubsTerm' +typedSubsTermArity :: TypedSubsTerm -> Int +typedSubsTermArity (TypedSubsTerm { tpSubsTermF = Pi _ _ tst }) = + 1 + typedSubsTermArity tst +typedSubsTermArity _ = 0 + +-- | Count the number of right-nested pi abstractions in a term, which +-- represents a type. This assumes that the type is in WHNF. +typeArity :: Term -> Int +typeArity tp = length $ fst $ asPiList tp + +class ToTerm a where + toTerm :: a -> Term + +instance ToTerm Term where + toTerm = id + +instance ToTerm TypedSubsTerm where + toTerm = typedSubsTermTerm + +unsharedApply :: Term -> Term -> Term +unsharedApply f arg = Unshared $ App f arg + + +---------------------------------------------------------------------- +-- * Monadifying Types +---------------------------------------------------------------------- + +-- | Test if a 'Term' is a first-order function type +isFirstOrderType :: Term -> Bool +isFirstOrderType (asPi -> Just (_, asPi -> Just _, _)) = False +isFirstOrderType (asPi -> Just (_, _, tp_out)) = isFirstOrderType tp_out +isFirstOrderType _ = True + +-- | A global definition, which is either a primitive or a constant. As +-- described in the documentation for 'ExtCns', the names need not be unique, +-- but the 'VarIndex' is, and this is what is used to index 'GlobalDef's. +data GlobalDef = GlobalDef { globalDefName :: NameInfo, + globalDefIndex :: VarIndex, + globalDefType :: Term, + globalDefTerm :: Term, + globalDefBody :: Maybe Term } + +instance Eq GlobalDef where + gd1 == gd2 = globalDefIndex gd1 == globalDefIndex gd2 + +instance Ord GlobalDef where + compare gd1 gd2 = compare (globalDefIndex gd1) (globalDefIndex gd2) + +instance Show GlobalDef where + show = show . globalDefName + +-- | Get the 'String' name of a 'GlobalDef' +globalDefString :: GlobalDef -> String +globalDefString = T.unpack . toAbsoluteName . globalDefName + +-- | Build an 'OpenTerm' from a 'GlobalDef' +globalDefOpenTerm :: GlobalDef -> OpenTerm +globalDefOpenTerm = closedOpenTerm . globalDefTerm + +-- | Recognize a named global definition, including its type +asTypedGlobalDef :: Recognizer Term GlobalDef +asTypedGlobalDef t = + case unwrapTermF t of + FTermF (Primitive pn) -> + Just $ GlobalDef (ModuleIdentifier $ + primName pn) (primVarIndex pn) (primType pn) t Nothing + Constant ec body -> + Just $ GlobalDef (ecName ec) (ecVarIndex ec) (ecType ec) t body + FTermF (ExtCns ec) -> + Just $ GlobalDef (ecName ec) (ecVarIndex ec) (ecType ec) t Nothing + _ -> Nothing + + +data MonKind = MKType Sort | MKNum | MKFun MonKind MonKind deriving Eq + +-- | Convert a kind to a SAW core sort, if possible +monKindToSort :: MonKind -> Maybe Sort +monKindToSort (MKType s) = Just s +monKindToSort _ = Nothing + +-- | Convert a 'MonKind' to the term it represents +monKindOpenTerm :: MonKind -> OpenTerm +monKindOpenTerm (MKType s) = sortOpenTerm s +monKindOpenTerm MKNum = dataTypeOpenTerm "Cryptol.Num" [] +monKindOpenTerm (MKFun k1 k2) = + arrowOpenTerm "_" (monKindOpenTerm k1) (monKindOpenTerm k2) + +data MonType + = MTyForall LocalName MonKind (MonType -> MonType) + | MTyArrow MonType MonType + | MTySeq OpenTerm MonType + | MTyPair MonType MonType + | MTyRecord [(FieldName, MonType)] + | MTyBase MonKind OpenTerm -- A "base type" or type var of a given kind + | MTyNum OpenTerm + +-- | Make a base type of sort 0 from an 'OpenTerm' +mkMonType0 :: OpenTerm -> MonType +mkMonType0 = MTyBase (MKType $ mkSort 0) + +-- | Make a 'MonType' for the Boolean type +boolMonType :: MonType +boolMonType = mkMonType0 $ globalOpenTerm "Prelude.Bool" + +-- | Test that a monadification type is monomorphic, i.e., has no foralls +monTypeIsMono :: MonType -> Bool +monTypeIsMono (MTyForall _ _ _) = False +monTypeIsMono (MTyArrow tp1 tp2) = monTypeIsMono tp1 && monTypeIsMono tp2 +monTypeIsMono (MTyPair tp1 tp2) = monTypeIsMono tp1 && monTypeIsMono tp2 +monTypeIsMono (MTyRecord tps) = all (monTypeIsMono . snd) tps +monTypeIsMono (MTySeq _ tp) = monTypeIsMono tp +monTypeIsMono (MTyBase _ _) = True +monTypeIsMono (MTyNum _) = True + +-- | Test if a monadification type @tp@ is considered a base type, meaning that +-- @CompMT(tp) = CompM MT(tp)@ +isBaseType :: MonType -> Bool +isBaseType (MTyForall _ _ _) = False +isBaseType (MTyArrow _ _) = False +isBaseType (MTySeq _ _) = True +isBaseType (MTyPair _ _) = True +isBaseType (MTyRecord _) = True +isBaseType (MTyBase (MKType _) _) = True +isBaseType (MTyBase _ _) = True +isBaseType (MTyNum _) = False + +-- | If a 'MonType' is a type-level number, return its 'OpenTerm', otherwise +-- return 'Nothing' +monTypeNum :: MonType -> Maybe OpenTerm +monTypeNum (MTyNum t) = Just t +monTypeNum (MTyBase MKNum t) = Just t +monTypeNum _ = Nothing + +-- | Get the kind of a 'MonType', assuming it has one +monTypeKind :: MonType -> Maybe MonKind +monTypeKind (MTyForall _ _ _) = Nothing +monTypeKind (MTyArrow t1 t2) = + do s1 <- monTypeKind t1 >>= monKindToSort + s2 <- monTypeKind t2 >>= monKindToSort + return $ MKType $ maxSort [s1, s2] +monTypeKind (MTyPair tp1 tp2) = + do sort1 <- monTypeKind tp1 >>= monKindToSort + sort2 <- monTypeKind tp2 >>= monKindToSort + return $ MKType $ maxSort [sort1, sort2] +monTypeKind (MTyRecord tps) = + do sorts <- mapM (monTypeKind . snd >=> monKindToSort) tps + return $ MKType $ maxSort sorts +monTypeKind (MTySeq _ tp) = + do sort <- monTypeKind tp >>= monKindToSort + return $ MKType sort +monTypeKind (MTyBase k _) = Just k +monTypeKind (MTyNum _) = Just MKNum + +-- | Get the 'Sort' @s@ of a 'MonType' if it has kind @'MKType' s@ +monTypeSort :: MonType -> Maybe Sort +monTypeSort = monTypeKind >=> monKindToSort + +-- | Convert a SAW core 'Term' to a monadification kind, if possible +monadifyKind :: Term -> Maybe MonKind +monadifyKind (asDataType -> Just (num, [])) + | primName num == "Cryptol.Num" = return MKNum +monadifyKind (asSort -> Just s) = return $ MKType s +monadifyKind (asPi -> Just (_, tp_in, tp_out)) = + MKFun <$> monadifyKind tp_in <*> monadifyKind tp_out +monadifyKind _ = Nothing + +-- | Get the kind of a type constructor with kind @k@ applied to type @t@, or +-- return 'Nothing' if the kinds do not line up +applyKind :: MonKind -> MonType -> Maybe MonKind +applyKind (MKFun k1 k2) t + | Just kt <- monTypeKind t + , kt == k1 = Just k2 +applyKind _ _ = Nothing + +-- | Perform 'applyKind' for 0 or more argument types +applyKinds :: MonKind -> [MonType] -> Maybe MonKind +applyKinds = foldM applyKind + +-- | Convert a 'MonType' to the argument type @MT(tp)@ it represents +toArgType :: MonType -> OpenTerm +toArgType (MTyForall x k body) = + piOpenTerm x (monKindOpenTerm k) (\tp -> toCompType (body $ MTyBase k tp)) +toArgType (MTyArrow t1 t2) = + arrowOpenTerm "_" (toArgType t1) (toCompType t2) +toArgType (MTySeq n t) = + applyOpenTermMulti (globalOpenTerm "CryptolM.mseq") [n, toArgType t] +toArgType (MTyPair mtp1 mtp2) = + pairTypeOpenTerm (toArgType mtp1) (toArgType mtp2) +toArgType (MTyRecord tps) = + recordTypeOpenTerm $ map (\(f,tp) -> (f, toArgType tp)) tps +toArgType (MTyBase _ t) = t +toArgType (MTyNum n) = n + +-- | Convert a 'MonType' to the computation type @CompMT(tp)@ it represents +toCompType :: MonType -> OpenTerm +toCompType mtp@(MTyForall _ _ _) = toArgType mtp +toCompType mtp@(MTyArrow _ _) = toArgType mtp +toCompType mtp = applyOpenTerm (globalOpenTerm "Prelude.CompM") (toArgType mtp) + +-- | The mapping for monadifying Cryptol typeclasses +-- FIXME: this is no longer needed, as it is now the identity +typeclassMonMap :: [(Ident,Ident)] +typeclassMonMap = + [("Cryptol.PEq", "Cryptol.PEq"), + ("Cryptol.PCmp", "Cryptol.PCmp"), + ("Cryptol.PSignedCmp", "Cryptol.PSignedCmp"), + ("Cryptol.PZero", "Cryptol.PZero"), + ("Cryptol.PLogic", "Cryptol.PLogic"), + ("Cryptol.PRing", "Cryptol.PRing"), + ("Cryptol.PIntegral", "Cryptol.PIntegral"), + ("Cryptol.PLiteral", "Cryptol.PLiteral")] + +-- | A context of local variables used for monadifying types, which includes the +-- variable names, their original types (before monadification), and, if their +-- types corespond to 'MonKind's, a local 'MonType' that quantifies over them. +-- +-- NOTE: the reason this type is different from 'MonadifyCtx', the context type +-- for monadifying terms, is that monadifying arrow types does not introduce a +-- local 'MonTerm' argument, since they are not dependent functions and so do +-- not use a HOAS encoding. +type MonadifyTypeCtx = [(LocalName,Term,Maybe MonType)] + +-- | Pretty-print a 'Term' relative to a 'MonadifyTypeCtx' +ppTermInTypeCtx :: MonadifyTypeCtx -> Term -> String +ppTermInTypeCtx ctx t = + scPrettyTermInCtx defaultPPOpts (map (\(x,_,_) -> x) ctx) t + +-- | Extract the variables and their original types from a 'MonadifyTypeCtx' +typeCtxPureCtx :: MonadifyTypeCtx -> [(LocalName,Term)] +typeCtxPureCtx = map (\(x,tp,_) -> (x,tp)) + +-- | Make a monadification type that is to be considered a base type +mkTermBaseType :: MonadifyTypeCtx -> MonKind -> Term -> MonType +mkTermBaseType ctx k t = + MTyBase k $ openOpenTerm (typeCtxPureCtx ctx) t + +-- | Monadify a type and convert it to its corresponding argument type +monadifyTypeArgType :: MonadifyTypeCtx -> Term -> OpenTerm +monadifyTypeArgType ctx t = toArgType $ monadifyType ctx t + +-- | Apply a monadified type to a type or term argument in the sense of +-- 'applyPiOpenTerm', meaning give the type of applying @f@ of a type to a +-- particular argument @arg@ +applyMonType :: MonType -> Either MonType ArgMonTerm -> MonType +applyMonType (MTyArrow _ tp_ret) (Right _) = tp_ret +applyMonType (MTyForall _ _ f) (Left mtp) = f mtp +applyMonType _ _ = error "applyMonType: application at incorrect type" + +-- | Convert a SAW core 'Term' to a monadification type +monadifyType :: MonadifyTypeCtx -> Term -> MonType +{- +monadifyType ctx t + | trace ("\nmonadifyType:\n" ++ ppTermInTypeCtx ctx t) False = undefined +-} +monadifyType ctx (asPi -> Just (x, tp_in, tp_out)) + | Just k <- monadifyKind tp_in = + MTyForall x k (\tp' -> monadifyType ((x,tp_in,Just tp'):ctx) tp_out) +monadifyType ctx tp@(asPi -> Just (_, _, tp_out)) + | inBitSet 0 (looseVars tp_out) = + error ("monadifyType: " ++ + "dependent function type with non-kind argument type: " ++ + ppTermInTypeCtx ctx tp) +monadifyType ctx tp@(asPi -> Just (x, tp_in, tp_out)) = + MTyArrow (monadifyType ctx tp_in) + (monadifyType ((x,tp,Nothing):ctx) tp_out) +monadifyType ctx (asPairType -> Just (tp1, tp2)) = + MTyPair (monadifyType ctx tp1) (monadifyType ctx tp2) +monadifyType ctx (asRecordType -> Just tps) = + MTyRecord $ map (\(fld,tp) -> (fld, monadifyType ctx tp)) $ Map.toList tps +monadifyType ctx (asDataType -> Just (eq_pn, [k_trm, tp1, tp2])) + | primName eq_pn == "Prelude.Eq" + , isJust (monadifyKind k_trm) = + -- NOTE: technically this is a Prop and not a sort 0, but it doesn't matter + mkMonType0 $ dataTypeOpenTerm "Prelude.Eq" [monadifyTypeArgType ctx tp1, + monadifyTypeArgType ctx tp2] +monadifyType ctx (asDataType -> Just (pn, args)) + | Just pn_k <- monadifyKind (primType pn) + , margs <- map (monadifyType ctx) args + , Just k_out <- applyKinds pn_k margs = + -- NOTE: this case only recognizes data types whose arguments are all types + -- and/or Nums + MTyBase k_out $ dataTypeOpenTerm (primName pn) (map toArgType margs) +monadifyType ctx (asVectorType -> Just (len, tp)) = + let lenOT = openOpenTerm (typeCtxPureCtx ctx) len in + MTySeq (ctorOpenTerm "Cryptol.TCNum" [lenOT]) $ monadifyType ctx tp +monadifyType ctx tp@(asApplyAll -> ((asGlobalDef -> Just seq_id), [n, a])) + | seq_id == "Cryptol.seq" = + case monTypeNum (monadifyType ctx n) of + Just n_trm -> MTySeq n_trm (monadifyType ctx a) + Nothing -> + error ("Monadify type: not a number: " ++ ppTermInTypeCtx ctx n + ++ " in type: " ++ ppTermInTypeCtx ctx tp) +monadifyType ctx (asApp -> Just ((asGlobalDef -> Just f), arg)) + | Just f_trans <- lookup f typeclassMonMap = + MTyBase (MKType $ mkSort 1) $ + applyOpenTerm (globalOpenTerm f_trans) $ monadifyTypeArgType ctx arg +monadifyType _ (asGlobalDef -> Just bool_id) + | bool_id == "Prelude.Bool" = + mkMonType0 (globalOpenTerm "Prelude.Bool") +{- +monadifyType ctx (asApplyAll -> (f, args)) + | Just glob <- asTypedGlobalDef f + , Just ec_k <- monadifyKind $ globalDefType glob + , margs <- map (monadifyType ctx) args + , Just k_out <- applyKinds ec_k margs = + MTyBase k_out (applyOpenTermMulti (globalDefOpenTerm glob) $ + map toArgType margs) +-} +monadifyType ctx tp@(asCtor -> Just (pn, _)) + | primName pn == "Cryptol.TCNum" || primName pn == "Cryptol.TCInf" = + MTyNum $ openOpenTerm (typeCtxPureCtx ctx) tp +monadifyType ctx (asLocalVar -> Just i) + | i < length ctx + , (_,_,Just tp) <- ctx!!i = tp +monadifyType ctx tp = + error ("monadifyType: not a valid type for monadification: " + ++ ppTermInTypeCtx ctx tp) + + +---------------------------------------------------------------------- +-- * Monadified Terms +---------------------------------------------------------------------- + +-- | A representation of a term that has been translated to argument type +-- @MT(tp)@ +data ArgMonTerm + -- | A monadification term of a base type @MT(tp)@ + = BaseMonTerm MonType OpenTerm + -- | A monadification term of non-depedent function type + | FunMonTerm LocalName MonType MonType (ArgMonTerm -> MonTerm) + -- | A monadification term of polymorphic type + | ForallMonTerm LocalName MonKind (MonType -> MonTerm) + +-- | A representation of a term that has been translated to computational type +-- @CompMT(tp)@ +data MonTerm + = ArgMonTerm ArgMonTerm + | CompMonTerm MonType OpenTerm + +-- | Get the monadification type of a monadification term +class GetMonType a where + getMonType :: a -> MonType + +instance GetMonType ArgMonTerm where + getMonType (BaseMonTerm tp _) = tp + getMonType (ForallMonTerm x k body) = MTyForall x k (getMonType . body) + getMonType (FunMonTerm _ tp_in tp_out _) = MTyArrow tp_in tp_out + +instance GetMonType MonTerm where + getMonType (ArgMonTerm t) = getMonType t + getMonType (CompMonTerm tp _) = tp + + +-- | Convert a monadification term to a SAW core term of type @CompMT(tp)@ +class ToCompTerm a where + toCompTerm :: a -> OpenTerm + +instance ToCompTerm ArgMonTerm where + toCompTerm (BaseMonTerm mtp t) = + applyOpenTermMulti (globalOpenTerm "Prelude.returnM") [toArgType mtp, t] + toCompTerm (FunMonTerm x tp_in _ body) = + lambdaOpenTerm x (toArgType tp_in) (toCompTerm . body . fromArgTerm tp_in) + toCompTerm (ForallMonTerm x k body) = + lambdaOpenTerm x (monKindOpenTerm k) (toCompTerm . body . MTyBase k) + +instance ToCompTerm MonTerm where + toCompTerm (ArgMonTerm amtrm) = toCompTerm amtrm + toCompTerm (CompMonTerm _ trm) = trm + + +-- | Convert an 'ArgMonTerm' to a SAW core term of type @MT(tp)@ +toArgTerm :: ArgMonTerm -> OpenTerm +toArgTerm (BaseMonTerm _ t) = t +toArgTerm t = toCompTerm t + + +-- | Build a monadification term from a term of type @MT(tp)@ +class FromArgTerm a where + fromArgTerm :: MonType -> OpenTerm -> a + +instance FromArgTerm ArgMonTerm where + fromArgTerm (MTyForall x k body) t = + ForallMonTerm x k (\tp -> fromCompTerm (body tp) (applyOpenTerm t $ + toArgType tp)) + fromArgTerm (MTyArrow t1 t2) t = + FunMonTerm "_" t1 t2 (\x -> fromCompTerm t2 (applyOpenTerm t $ toArgTerm x)) + fromArgTerm tp t = BaseMonTerm tp t + +instance FromArgTerm MonTerm where + fromArgTerm mtp t = ArgMonTerm $ fromArgTerm mtp t + +-- | Build a monadification term from a computational term of type @CompMT(tp)@ +fromCompTerm :: MonType -> OpenTerm -> MonTerm +fromCompTerm mtp t | isBaseType mtp = CompMonTerm mtp t +fromCompTerm mtp t = ArgMonTerm $ fromArgTerm mtp t + +-- | Build a monadification term from a function on terms which, when viewed as +-- a lambda, is a "semi-pure" function of the given monadification type, meaning +-- it maps terms of argument type @MT(tp)@ to an output value of argument type; +-- i.e., it has type @SemiP(tp)@, defined as: +-- +-- > SemiP(Pi x (sort 0) b) = Pi x (sort 0) SemiP(b) +-- > SemiP(Pi x Num b) = Pi x Num SemiP(b) +-- > SemiP(Pi _ a b) = MT(a) -> SemiP(b) +-- > SemiP(a) = MT(a) +fromSemiPureTermFun :: MonType -> ([OpenTerm] -> OpenTerm) -> ArgMonTerm +fromSemiPureTermFun (MTyForall x k body) f = + ForallMonTerm x k $ \tp -> + ArgMonTerm $ fromSemiPureTermFun (body tp) (f . (toArgType tp:)) +fromSemiPureTermFun (MTyArrow t1 t2) f = + FunMonTerm "_" t1 t2 $ \x -> + ArgMonTerm $ fromSemiPureTermFun t2 (f . (toArgTerm x:)) +fromSemiPureTermFun tp f = BaseMonTerm tp (f []) + +-- | Like 'fromSemiPureTermFun' but use a term rather than a term function +fromSemiPureTerm :: MonType -> OpenTerm -> ArgMonTerm +fromSemiPureTerm mtp t = fromSemiPureTermFun mtp (applyOpenTermMulti t) + +-- | Build a 'MonTerm' that 'fail's when converted to a term +failMonTerm :: MonType -> String -> MonTerm +failMonTerm mtp str = fromArgTerm mtp (failOpenTerm str) + +-- | Build an 'ArgMonTerm' that 'fail's when converted to a term +failArgMonTerm :: MonType -> String -> ArgMonTerm +failArgMonTerm tp str = fromArgTerm tp (failOpenTerm str) + +-- | Apply a monadified term to a type or term argument +applyMonTerm :: MonTerm -> Either MonType ArgMonTerm -> MonTerm +applyMonTerm (ArgMonTerm (FunMonTerm _ _ _ f)) (Right arg) = f arg +applyMonTerm (ArgMonTerm (ForallMonTerm _ _ f)) (Left mtp) = f mtp +applyMonTerm _ _ = error "applyMonTerm: application at incorrect type" + +-- | Apply a monadified term to 0 or more arguments +applyMonTermMulti :: MonTerm -> [Either MonType ArgMonTerm] -> MonTerm +applyMonTermMulti = foldl applyMonTerm + +-- | Build a 'MonTerm' from a global of a given argument type +mkGlobalArgMonTerm :: MonType -> Ident -> ArgMonTerm +mkGlobalArgMonTerm tp ident = fromArgTerm tp (globalOpenTerm ident) + +-- | Build a 'MonTerm' from a 'GlobalDef' of semi-pure type +mkSemiPureGlobalDefTerm :: GlobalDef -> ArgMonTerm +mkSemiPureGlobalDefTerm glob = + fromSemiPureTerm (monadifyType [] $ + globalDefType glob) (globalDefOpenTerm glob) + +-- | Build a 'MonTerm' from a constructor with the given 'PrimName' +mkCtorArgMonTerm :: PrimName Term -> ArgMonTerm +mkCtorArgMonTerm pn + | not (isFirstOrderType (primType pn)) = + failArgMonTerm (monadifyType [] $ primType pn) + ("monadification failed: cannot handle constructor " + ++ show (primName pn) ++ " with higher-order type") +mkCtorArgMonTerm pn = + fromSemiPureTermFun (monadifyType [] $ primType pn) (ctorOpenTerm $ primName pn) + + +---------------------------------------------------------------------- +-- * Monadification Environments and Contexts +---------------------------------------------------------------------- + +-- | A monadification macro is a function that inspects its first @N@ arguments +-- before deciding how to monadify itself +data MonMacro = MonMacro { + macroNumArgs :: Int, + macroApply :: GlobalDef -> [Term] -> MonadifyM MonTerm } + +-- | Make a simple 'MonMacro' that inspects 0 arguments and just returns a term +monMacro0 :: MonTerm -> MonMacro +monMacro0 mtrm = MonMacro 0 (\_ _ -> return mtrm) + +-- | Make a 'MonMacro' that maps a named global to a global of semi-pure type. +-- (See 'fromSemiPureTermFun'.) Because we can't get access to the type of the +-- global until we apply the macro, we monadify its type at macro application +-- time. +semiPureGlobalMacro :: Ident -> Ident -> MonMacro +semiPureGlobalMacro from to = + MonMacro 0 $ \glob args -> + if globalDefName glob == ModuleIdentifier from && args == [] then + return $ ArgMonTerm $ + fromSemiPureTerm (monadifyType [] $ globalDefType glob) (globalOpenTerm to) + else + error ("Monadification macro for " ++ show from ++ " applied incorrectly") + +-- | Make a 'MonMacro' that maps a named global to a global of argument +-- type. Because we can't get access to the type of the global until we apply +-- the macro, we monadify its type at macro application time. +argGlobalMacro :: NameInfo -> Ident -> MonMacro +argGlobalMacro from to = + MonMacro 0 $ \glob args -> + if globalDefName glob == from && args == [] then + return $ ArgMonTerm $ + mkGlobalArgMonTerm (monadifyType [] $ globalDefType glob) to + else + error ("Monadification macro for " ++ show from ++ " applied incorrectly") + +-- | An environment of named primitives and how to monadify them +type MonadifyEnv = Map NameInfo MonMacro + +-- | A context for monadifying 'Term's which maintains, for each deBruijn index +-- in scope, both its original un-monadified type along with either a 'MonTerm' +-- or 'MonType' for the translation of the variable to a local variable of +-- monadified type or monadified kind +type MonadifyCtx = [(LocalName,Term,Either MonType MonTerm)] + +-- | Convert a 'MonadifyCtx' to a 'MonadifyTypeCtx' +ctxToTypeCtx :: MonadifyCtx -> MonadifyTypeCtx +ctxToTypeCtx = map (\(x,tp,arg) -> + (x,tp,case arg of + Left mtp -> Just mtp + Right _ -> Nothing)) + +-- | Pretty-print a 'Term' relative to a 'MonadifyCtx' +ppTermInMonCtx :: MonadifyCtx -> Term -> String +ppTermInMonCtx ctx t = + scPrettyTermInCtx defaultPPOpts (map (\(x,_,_) -> x) ctx) t + +-- | A memoization table for monadifying terms +type MonadifyMemoTable = IntMap MonTerm + +-- | The empty memoization table +emptyMemoTable :: MonadifyMemoTable +emptyMemoTable = IntMap.empty + + +---------------------------------------------------------------------- +-- * The Monadification Monad +---------------------------------------------------------------------- + +-- | The read-only state of a monadification computation +data MonadifyROState = MonadifyROState { + -- | The monadification environment + monStEnv :: MonadifyEnv, + -- | The monadification context + monStCtx :: MonadifyCtx, + -- | The monadified return type of the top-level term being monadified + monStTopRetType :: OpenTerm +} + +-- | The monad for monadifying SAW core terms +newtype MonadifyM a = + MonadifyM { unMonadifyM :: + ReaderT MonadifyROState (StateT MonadifyMemoTable + (Cont MonTerm)) a } + deriving (Functor, Applicative, Monad, + MonadReader MonadifyROState, MonadState MonadifyMemoTable) + +instance Fail.MonadFail MonadifyM where + fail str = + do ret_tp <- topRetType + shiftMonadifyM $ \_ -> failMonTerm (mkMonType0 ret_tp) str + +-- | Capture the current continuation and pass it to a function, which must +-- return the final computation result. Note that this is slightly differnet +-- from normal shift, and I think corresponds to the C operator, but my quick +-- googling couldn't find the right name... +shiftMonadifyM :: ((a -> MonTerm) -> MonTerm) -> MonadifyM a +shiftMonadifyM f = MonadifyM $ lift $ lift $ cont f + +-- | Locally run a 'MonadifyM' computation with an empty memoization table, +-- making all binds be local to that computation, and return the result +resetMonadifyM :: OpenTerm -> MonadifyM MonTerm -> MonadifyM MonTerm +resetMonadifyM ret_tp m = + do ro_st <- ask + return $ runMonadifyM (monStEnv ro_st) (monStCtx ro_st) ret_tp m + +-- | Get the monadified return type of the top-level term being monadified +topRetType :: MonadifyM OpenTerm +topRetType = monStTopRetType <$> ask + +-- | Run a monadification computation +-- +-- FIXME: document the arguments +runMonadifyM :: MonadifyEnv -> MonadifyCtx -> OpenTerm -> + MonadifyM MonTerm -> MonTerm +runMonadifyM env ctx top_ret_tp m = + let ro_st = MonadifyROState env ctx top_ret_tp in + runCont (evalStateT (runReaderT (unMonadifyM m) ro_st) emptyMemoTable) id + +-- | Run a monadification computation using a mapping for identifiers that have +-- already been monadified and generate a SAW core term +runCompleteMonadifyM :: MonadIO m => SharedContext -> MonadifyEnv -> + Term -> MonadifyM MonTerm -> m Term +runCompleteMonadifyM sc env top_ret_tp m = + liftIO $ completeOpenTerm sc $ toCompTerm $ + runMonadifyM env [] (toArgType $ monadifyType [] top_ret_tp) m + +-- | Memoize a computation of the monadified term associated with a 'TermIndex' +memoizingM :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm +memoizingM i m = + (IntMap.lookup i <$> get) >>= \case + Just ret -> + return ret + Nothing -> + do ret <- m + modify (IntMap.insert i ret) + return ret + +-- | Turn a 'MonTerm' of type @CompMT(tp)@ to a term of argument type @MT(tp)@ +-- by inserting a monadic bind if the 'MonTerm' is computational +argifyMonTerm :: MonTerm -> MonadifyM ArgMonTerm +argifyMonTerm (ArgMonTerm mtrm) = return mtrm +argifyMonTerm (CompMonTerm mtp trm) = + do let tp = toArgType mtp + top_ret_tp <- topRetType + shiftMonadifyM $ \k -> + CompMonTerm (mkMonType0 top_ret_tp) $ + applyOpenTermMulti (globalOpenTerm "Prelude.bindM") + [tp, top_ret_tp, trm, + lambdaOpenTerm "x" tp (toCompTerm . k . fromArgTerm mtp)] + +-- | Build a proof of @isFinite n@ by calling @assertFiniteM@ and binding the +-- result to an 'ArgMonTerm' +assertIsFinite :: MonType -> MonadifyM ArgMonTerm +assertIsFinite (MTyNum n) = + argifyMonTerm (CompMonTerm + (mkMonType0 (applyOpenTerm + (globalOpenTerm "CryptolM.isFinite") n)) + (applyOpenTerm (globalOpenTerm "CryptolM.assertFiniteM") n)) +assertIsFinite _ = + fail ("assertIsFinite applied to non-Num argument") + + +---------------------------------------------------------------------- +-- * Monadification +---------------------------------------------------------------------- + +-- | Monadify a type in the context of the 'MonadifyM' monad +monadifyTypeM :: Term -> MonadifyM MonType +monadifyTypeM tp = + do ctx <- monStCtx <$> ask + return $ monadifyType (ctxToTypeCtx ctx) tp + +-- | Monadify a term to a monadified term of argument type +monadifyArg :: Maybe MonType -> Term -> MonadifyM ArgMonTerm +monadifyArg mtp t = monadifyTerm mtp t >>= argifyMonTerm + +-- | Monadify a term to argument type and convert back to a term +monadifyArgTerm :: Maybe MonType -> Term -> MonadifyM OpenTerm +monadifyArgTerm mtp t = toArgTerm <$> monadifyArg mtp t + +-- | Monadify a term +monadifyTerm :: Maybe MonType -> Term -> MonadifyM MonTerm +{- +monadifyTerm _ t + | trace ("Monadifying term: " ++ showTerm t) False + = undefined +-} +monadifyTerm mtp t@(STApp { stAppIndex = ix }) = + memoizingM ix $ monadifyTerm' mtp t +monadifyTerm mtp t = + monadifyTerm' mtp t + +-- | The main implementation of 'monadifyTerm', which monadifies a term given an +-- optional monadification type. The type must be given for introduction forms +-- (i.e.,, lambdas, pairs, and records), but is optional for elimination forms +-- (i.e., applications, projections, and also in this case variables). Note that +-- this means monadification will fail on terms with beta or tuple redexes. +monadifyTerm' :: Maybe MonType -> Term -> MonadifyM MonTerm +monadifyTerm' (Just mtp) t@(asLambda -> Just _) = + ask >>= \(MonadifyROState { monStEnv = env, monStCtx = ctx }) -> + return $ monadifyLambdas env ctx mtp t +{- +monadifyTerm' (Just mtp@(MTyForall _ _ _)) t = + ask >>= \ro_st -> + get >>= \table -> + return $ monadifyLambdas (monStEnv ro_st) table (monStCtx ro_st) mtp t +monadifyTerm' (Just mtp@(MTyArrow _ _)) t = + ask >>= \ro_st -> + get >>= \table -> + return $ monadifyLambdas (monStEnv ro_st) table (monStCtx ro_st) mtp t +-} +monadifyTerm' (Just mtp@(MTyPair mtp1 mtp2)) (asPairValue -> + Just (trm1, trm2)) = + fromArgTerm mtp <$> (pairOpenTerm <$> + monadifyArgTerm (Just mtp1) trm1 <*> + monadifyArgTerm (Just mtp2) trm2) +monadifyTerm' (Just mtp@(MTyRecord fs_mtps)) (asRecordValue -> Just trm_map) + | length fs_mtps == Map.size trm_map + , (fs,mtps) <- unzip fs_mtps + , Just trms <- mapM (\f -> Map.lookup f trm_map) fs = + fromArgTerm mtp <$> recordOpenTerm <$> zip fs <$> + zipWithM monadifyArgTerm (map Just mtps) trms +monadifyTerm' _ (asPairSelector -> Just (trm, False)) = + do mtrm <- monadifyArg Nothing trm + mtp <- case getMonType mtrm of + MTyPair t _ -> return t + _ -> fail "Monadification failed: projection on term of non-pair type" + return $ fromArgTerm mtp $ + pairLeftOpenTerm $ toArgTerm mtrm +monadifyTerm' (Just mtp@(MTySeq n mtp_elem)) (asFTermF -> + Just (ArrayValue _ trms)) = + do trms' <- traverse (monadifyArgTerm $ Just mtp_elem) trms + return $ fromArgTerm mtp $ + applyOpenTermMulti (globalOpenTerm "CryptolM.seqToMseq") + [n, toArgType mtp_elem, + flatOpenTerm $ ArrayValue (toArgType mtp_elem) trms'] +monadifyTerm' _ (asPairSelector -> Just (trm, True)) = + do mtrm <- monadifyArg Nothing trm + mtp <- case getMonType mtrm of + MTyPair _ t -> return t + _ -> fail "Monadification failed: projection on term of non-pair type" + return $ fromArgTerm mtp $ + pairRightOpenTerm $ toArgTerm mtrm +monadifyTerm' _ (asRecordSelector -> Just (trm, fld)) = + do mtrm <- monadifyArg Nothing trm + mtp <- case getMonType mtrm of + MTyRecord mtps | Just mtp <- lookup fld mtps -> return mtp + _ -> fail ("Monadification failed: " ++ + "record projection on term of incorrect type") + return $ fromArgTerm mtp $ projRecordOpenTerm (toArgTerm mtrm) fld +monadifyTerm' _ (asLocalVar -> Just ix) = + (monStCtx <$> ask) >>= \case + ctx | ix >= length ctx -> fail "Monadification failed: vaiable out of scope!" + ctx | (_,_,Right mtrm) <- ctx !! ix -> return mtrm + _ -> fail "Monadification failed: type variable used in term position!" +monadifyTerm' _ (asCtor -> Just (pn, args)) = + monadifyApply (ArgMonTerm $ mkCtorArgMonTerm pn) args +monadifyTerm' _ (asApplyAll -> (asTypedGlobalDef -> Just glob, args)) = + (Map.lookup (globalDefName glob) <$> monStEnv <$> ask) >>= \case + Just macro -> + do let (macro_args, reg_args) = splitAt (macroNumArgs macro) args + mtrm_f <- macroApply macro glob macro_args + monadifyApply mtrm_f reg_args + Nothing -> error ("Monadification failed: unhandled constant: " + ++ globalDefString glob) +monadifyTerm' _ (asApp -> Just (f, arg)) = + do mtrm_f <- monadifyTerm Nothing f + monadifyApply mtrm_f [arg] +monadifyTerm' _ t = + (monStCtx <$> ask) >>= \ctx -> + fail ("Monadifiction failed: no case for term: " ++ ppTermInMonCtx ctx t) + + +-- | Monadify the application of a monadified term to a list of terms, using the +-- type of the already monadified to monadify the arguments +monadifyApply :: MonTerm -> [Term] -> MonadifyM MonTerm +monadifyApply f (t : ts) + | MTyArrow tp_in _ <- getMonType f = + do mtrm <- monadifyArg (Just tp_in) t + monadifyApply (applyMonTerm f (Right mtrm)) ts +monadifyApply f (t : ts) + | MTyForall _ _ _ <- getMonType f = + do mtp <- monadifyTypeM t + monadifyApply (applyMonTerm f (Left mtp)) ts +monadifyApply _ (_:_) = fail "monadifyApply: application at incorrect type" +monadifyApply f [] = return f + + +-- | FIXME: documentation; get our type down to a base type before going into +-- the MonadifyM monad +monadifyLambdas :: MonadifyEnv -> MonadifyCtx -> MonType -> Term -> MonTerm +monadifyLambdas env ctx (MTyForall _ k tp_f) (asLambda -> + Just (x, x_tp, body)) = + -- FIXME: check that monadifyKind x_tp == k + ArgMonTerm $ ForallMonTerm x k $ \mtp -> + monadifyLambdas env ((x,x_tp,Left mtp) : ctx) (tp_f mtp) body +monadifyLambdas env ctx (MTyArrow tp_in tp_out) (asLambda -> + Just (x, x_tp, body)) = + -- FIXME: check that monadifyType x_tp == tp_in + ArgMonTerm $ FunMonTerm x tp_in tp_out $ \arg -> + monadifyLambdas env ((x,x_tp,Right (ArgMonTerm arg)) : ctx) tp_out body +monadifyLambdas env ctx tp t = + monadifyEtaExpand env ctx tp tp t [] + +-- | FIXME: documentation +monadifyEtaExpand :: MonadifyEnv -> MonadifyCtx -> MonType -> MonType -> Term -> + [Either MonType ArgMonTerm] -> MonTerm +monadifyEtaExpand env ctx top_mtp (MTyForall x k tp_f) t args = + ArgMonTerm $ ForallMonTerm x k $ \mtp -> + monadifyEtaExpand env ctx top_mtp (tp_f mtp) t (args ++ [Left mtp]) +monadifyEtaExpand env ctx top_mtp (MTyArrow tp_in tp_out) t args = + ArgMonTerm $ FunMonTerm "_" tp_in tp_out $ \arg -> + monadifyEtaExpand env ctx top_mtp tp_out t (args ++ [Right arg]) +monadifyEtaExpand env ctx top_mtp mtp t args = + applyMonTermMulti + (runMonadifyM env ctx (toArgType mtp) (monadifyTerm (Just top_mtp) t)) + args + + +---------------------------------------------------------------------- +-- * Handling the Primitives +---------------------------------------------------------------------- + +-- | The macro for unsafeAssert, which checks the type of the objects being +-- compared and dispatches to the proper comparison function +unsafeAssertMacro :: MonMacro +unsafeAssertMacro = MonMacro 1 $ \_ ts -> + let numFunType = + MTyForall "n" (MKType $ mkSort 0) $ \n -> + MTyForall "m" (MKType $ mkSort 0) $ \m -> + MTyBase (MKType $ mkSort 0) $ + dataTypeOpenTerm "Prelude.Eq" + [dataTypeOpenTerm "Cryptol.Num" [], + toArgType n, toArgType m] in + case ts of + [(asDataType -> Just (num, []))] + | primName num == "Cryptol.Num" -> + return $ ArgMonTerm $ + mkGlobalArgMonTerm numFunType "CryptolM.numAssertEqM" + _ -> + fail "Monadification failed: unsafeAssert applied to non-Num type" + +-- | The macro for if-then-else, which contains any binds in a branch to that +-- branch +iteMacro :: MonMacro +iteMacro = MonMacro 4 $ \_ args -> + do let (tp, cond, branch1, branch2) = + case args of + [t1, t2, t3, t4] -> (t1, t2, t3, t4) + _ -> error "iteMacro: wrong number of arguments!" + atrm_cond <- monadifyArg (Just boolMonType) cond + mtp <- monadifyTypeM tp + mtrm1 <- resetMonadifyM (toArgType mtp) $ monadifyTerm (Just mtp) branch1 + mtrm2 <- resetMonadifyM (toArgType mtp) $ monadifyTerm (Just mtp) branch2 + case (mtrm1, mtrm2) of + (ArgMonTerm atrm1, ArgMonTerm atrm2) -> + return $ fromArgTerm mtp $ + applyOpenTermMulti (globalOpenTerm "Prelude.ite") + [toArgType mtp, toArgTerm atrm_cond, toArgTerm atrm1, toArgTerm atrm2] + _ -> + return $ fromCompTerm mtp $ + applyOpenTermMulti (globalOpenTerm "Prelude.ite") + [toCompType mtp, toArgTerm atrm_cond, + toCompTerm mtrm1, toCompTerm mtrm2] + + +-- | Make a 'MonMacro' that maps a named global whose first argument is @n:Num@ +-- to a global of semi-pure type that takes an additional argument of type +-- @isFinite n@ +fin1Macro :: Ident -> Ident -> MonMacro +fin1Macro from to = + MonMacro 1 $ \glob args -> + do if globalDefName glob == ModuleIdentifier from && length args == 1 then + return () + else error ("Monadification macro for " ++ show from ++ + " applied incorrectly") + -- Monadify the first arg, n, and build a proof it is finite + n_mtp <- monadifyTypeM (head args) + let n = toArgType n_mtp + fin_pf <- assertIsFinite n_mtp + -- Apply the type of @glob@ to n, and apply @to@ to n and fin_pf + let glob_tp = monadifyType [] $ globalDefType glob + let glob_tp_app = applyMonType glob_tp $ Left n_mtp + let to_app = applyOpenTermMulti (globalOpenTerm to) [n, toArgTerm fin_pf] + -- Finally, return @to n fin_pf@ as a MonTerm of monadified type @to_tp n@ + return $ ArgMonTerm $ fromSemiPureTerm glob_tp_app to_app + +-- | Helper function: build a @LetRecType@ for a nested pi type +lrtFromMonType :: MonType -> OpenTerm +lrtFromMonType (MTyForall x k body_f) = + ctorOpenTerm "Prelude.LRT_Fun" + [monKindOpenTerm k, + lambdaOpenTerm x (monKindOpenTerm k) (\tp -> lrtFromMonType $ + body_f $ MTyBase k tp)] +lrtFromMonType (MTyArrow mtp1 mtp2) = + ctorOpenTerm "Prelude.LRT_Fun" + [toArgType mtp1, + lambdaOpenTerm "_" (toArgType mtp1) (\_ -> lrtFromMonType mtp2)] +lrtFromMonType mtp = + ctorOpenTerm "Prelude.LRT_Ret" [toArgType mtp] + + +-- | The macro for fix +-- +-- FIXME: does not yet handle mutual recursion +fixMacro :: MonMacro +fixMacro = MonMacro 2 $ \_ args -> case args of + [tp@(asPi -> Just _), f] -> + do mtp <- monadifyTypeM tp + amtrm_f <- monadifyArg (Just $ MTyArrow mtp mtp) f + return $ fromCompTerm mtp $ + applyOpenTermMulti (globalOpenTerm "Prelude.multiArgFixM") + [lrtFromMonType mtp, toCompTerm amtrm_f] + [(asRecordType -> Just _), _] -> + fail "Monadification failed: cannot yet handle mutual recursion" + _ -> error "fixMacro: malformed arguments!" + +-- | A "macro mapping" maps a single pure identifier to a 'MonMacro' for it +type MacroMapping = (NameInfo, MonMacro) + +-- | Build a 'MacroMapping' for an identifier to a semi-pure named function +mmSemiPure :: Ident -> Ident -> MacroMapping +mmSemiPure from_id to_id = + (ModuleIdentifier from_id, semiPureGlobalMacro from_id to_id) + +-- | Build a 'MacroMapping' for an identifier to a semi-pure named function +-- whose first argument is a @Num@ that requires an @isFinite@ proof +mmSemiPureFin1 :: Ident -> Ident -> MacroMapping +mmSemiPureFin1 from_id to_id = + (ModuleIdentifier from_id, fin1Macro from_id to_id) + +-- | Build a 'MacroMapping' for an identifier to itself as a semi-pure function +mmSelf :: Ident -> MacroMapping +mmSelf self_id = + (ModuleIdentifier self_id, semiPureGlobalMacro self_id self_id) + +-- | Build a 'MacroMapping' from an identifier to a function of argument type +mmArg :: Ident -> Ident -> MacroMapping +mmArg from_id to_id = (ModuleIdentifier from_id, + argGlobalMacro (ModuleIdentifier from_id) to_id) + +-- | Build a 'MacroMapping' from an identifier and a custom 'MonMacro' +mmCustom :: Ident -> MonMacro -> MacroMapping +mmCustom from_id macro = (ModuleIdentifier from_id, macro) + +-- | The default monadification environment +defaultMonEnv :: MonadifyEnv +defaultMonEnv = + Map.fromList + [ + -- Prelude functions + mmCustom "Prelude.unsafeAssert" unsafeAssertMacro + , mmCustom "Prelude.ite" iteMacro + , mmCustom "Prelude.fix" fixMacro + + -- Top-level sequence functions + , mmArg "Cryptol.seqMap" "CryptolM.seqMapM" + , mmSemiPure "Cryptol.seq_cong1" "CryptolM.mseq_cong1" + , mmArg "Cryptol.eListSel" "CryptolM.eListSelM" + + -- List comprehensions + , mmArg "Cryptol.from" "CryptolM.fromM" + -- FIXME: continue here... + + -- PEq constraints + , mmSemiPureFin1 "Cryptol.PEqSeq" "CryptolM.PEqMSeq" + , mmSemiPureFin1 "Cryptol.PEqSeqBool" "CryptolM.PEqMSeqBool" + + -- PCmp constraints + , mmSemiPureFin1 "Cryptol.PCmpSeq" "CryptolM.PCmpMSeq" + , mmSemiPureFin1 "Cryptol.PCmpSeqBool" "CryptolM.PCmpMSeqBool" + + -- PSignedCmp constraints + , mmSemiPureFin1 "Cryptol.PSignedCmpSeq" "CryptolM.PSignedCmpMSeq" + , mmSemiPureFin1 "Cryptol.PSignedCmpSeqBool" "CryptolM.PSignedCmpMSeqBool" + + -- PZero constraints + , mmSemiPureFin1 "Cryptol.PZeroSeq" "CryptolM.PZeroMSeq" + + -- PLogic constraints + , mmSemiPure "Cryptol.PLogicSeq" "CryptolM.PLogicMSeq" + , mmSemiPureFin1 "Cryptol.PLogicSeqBool" "CryptolM.PLogicMSeqBool" + + -- PRing constraints + , mmSemiPure "Cryptol.PRingSeq" "CryptolM.PRingMSeq" + , mmSemiPureFin1 "Cryptol.PRingSeqBool" "CryptolM.PRingMSeqBool" + + -- PIntegral constraints + , mmSemiPureFin1 "Cryptol.PIntegeralSeqBool" "CryptolM.PIntegeralMSeqBool" + + -- PLiteral constraints + , mmSemiPureFin1 "Cryptol.PLiteralSeqBool" "CryptolM.PLiteralSeqBoolM" + + -- The Cryptol Literal primitives + , mmSelf "Cryptol.ecNumber" + , mmSelf "Cryptol.ecFromZ" + + -- The Ring primitives + , mmSelf "Cryptol.ecPlus" + , mmSelf "Cryptol.ecMinus" + , mmSelf "Cryptol.ecMul" + , mmSelf "Cryptol.ecNeg" + , mmSelf "Cryptol.ecToInteger" + + -- The comparison primitives + , mmSelf "Cryptol.ecEq" + , mmSelf "Cryptol.ecNotEq" + , mmSelf "Cryptol.ecLt" + , mmSelf "Cryptol.ecLtEq" + , mmSelf "Cryptol.ecGt" + , mmSelf "Cryptol.ecGtEq" + + -- Sequences + , mmSemiPure "Cryptol.ecShiftL" "CryptolM.ecShiftLM" + , mmSemiPure "Cryptol.ecShiftR" "CryptolM.ecShiftRM" + , mmSemiPure "Cryptol.ecSShiftR" "CryptolM.ecSShiftRM" + , mmSemiPureFin1 "Cryptol.ecRotL" "CryptolM.ecRotLM" + , mmSemiPureFin1 "Cryptol.ecRotR" "CryptolM.ecRotRM" + , mmSemiPureFin1 "Cryptol.ecCat" "CryptolM.ecCatM" + , mmSemiPure "Cryptol.ecTake" "CryptolM.ecTakeM" + , mmSemiPureFin1 "Cryptol.ecDrop" "CryptolM.ecDropM" + , mmSemiPure "Cryptol.ecJoin" "CryptolM.ecJoinM" + , mmSemiPure "Cryptol.ecSplit" "CryptolM.ecSplitM" + , mmSemiPureFin1 "Cryptol.ecReverse" "CryptolM.ecReverseM" + , mmSemiPure "Cryptol.ecTranspose" "CryptolM.ecTransposeM" + , mmArg "Cryptol.ecAt" "CryptolM.ecAtM" + -- , mmArgFin1 "Cryptol.ecAtBack" "CryptolM.ecAtBackM" + -- , mmSemiPureFin2 "Cryptol.ecFromTo" "CryptolM.ecFromToM" + , mmSemiPureFin1 "Cryptol.ecFromToLessThan" "CryptolM.ecFromToLessThanM" + -- , mmSemiPureNthFin 5 "Cryptol.ecFromThenTo" "CryptolM.ecFromThenToM" + , mmSemiPure "Cryptol.ecInfFrom" "CryptolM.ecInfFromM" + , mmSemiPure "Cryptol.ecInfFromThen" "CryptolM.ecInfFromThenM" + , mmArg "Cryptol.ecError" "CryptolM.ecErrorM" + ] + + +---------------------------------------------------------------------- +-- * Top-Level Entrypoints +---------------------------------------------------------------------- + +-- | Ensure that the @CryptolM@ module is loaded +ensureCryptolMLoaded :: SharedContext -> IO () +ensureCryptolMLoaded sc = + scModuleIsLoaded sc (mkModuleName ["CryptolM"]) >>= \is_loaded -> + if is_loaded then return () else + scLoadCryptolMModule sc + +-- | Monadify a type to its argument type and complete it to a 'Term' +monadifyCompleteArgType :: SharedContext -> Term -> IO Term +monadifyCompleteArgType sc tp = + completeOpenTerm sc $ monadifyTypeArgType [] tp + +-- | Monadify a term of the specified type to a 'MonTerm' and then complete that +-- 'MonTerm' to a SAW core 'Term', or 'fail' if this is not possible +monadifyCompleteTerm :: SharedContext -> MonadifyEnv -> Term -> Term -> IO Term +monadifyCompleteTerm sc env trm tp = + runCompleteMonadifyM sc env tp $ + monadifyTerm (Just $ monadifyType [] tp) trm + +-- | Convert a name of a definition to the name of its monadified version +monadifyName :: NameInfo -> IO NameInfo +monadifyName (ModuleIdentifier ident) = + return $ ModuleIdentifier $ mkIdent (identModule ident) $ + T.append (identBaseName ident) (T.pack "M") +monadifyName (ImportedName uri _) = + do frag <- URI.mkFragment (T.pack "M") + return $ ImportedName (uri { URI.uriFragment = Just frag }) [] + +-- | Monadify a 'Term' of the specified type with an optional body, bind the +-- result to a fresh SAW core constant generated from the supplied name, and +-- then convert that constant back to a 'MonTerm' +monadifyNamedTerm :: SharedContext -> MonadifyEnv -> + NameInfo -> Maybe Term -> Term -> IO MonTerm +monadifyNamedTerm sc env nmi maybe_trm tp = + trace ("Monadifying " ++ T.unpack (toAbsoluteName nmi)) $ + do let mtp = monadifyType [] tp + nmi' <- monadifyName nmi + comp_tp <- completeOpenTerm sc $ toCompType mtp + const_trm <- + case maybe_trm of + Just trm -> + do trm' <- monadifyCompleteTerm sc env trm tp + scConstant' sc nmi' trm' comp_tp + Nothing -> scOpaqueConstant sc nmi' tp + return $ fromCompTerm mtp $ closedOpenTerm const_trm + +-- | Monadify a term with the specified type along with all constants it +-- contains, adding the monadifications of those constants to the monadification +-- environment +monadifyTermInEnv :: SharedContext -> MonadifyEnv -> Term -> Term -> + IO (Term, MonadifyEnv) +monadifyTermInEnv sc top_env top_trm top_tp = + flip runStateT top_env $ + do lift $ ensureCryptolMLoaded sc + let const_infos = + map snd $ Map.toAscList $ getConstantSet top_trm + forM_ const_infos $ \(nmi,tp,maybe_body) -> + get >>= \env -> + if Map.member nmi env then return () else + do mtrm <- lift $ monadifyNamedTerm sc env nmi maybe_body tp + put $ Map.insert nmi (monMacro0 mtrm) env + env <- get + lift $ monadifyCompleteTerm sc env top_trm top_tp diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/PreludeM.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/PreludeM.hs new file mode 100644 index 0000000000..e984c25310 --- /dev/null +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/PreludeM.hs @@ -0,0 +1,22 @@ +{-# LANGUAGE TemplateHaskell #-} + +{- | +Module : Verifier.SAW.Cryptol.Prelude +Copyright : Galois, Inc. 2012-2015 +License : BSD3 +Maintainer : huffman@galois.com +Stability : experimental +Portability : non-portable (language extensions) +-} + +module Verifier.SAW.Cryptol.PreludeM + ( Module + , module Verifier.SAW.Cryptol.PreludeM + , scLoadPreludeModule + ) where + +import Verifier.SAW.Prelude +import Verifier.SAW.ParserUtils + +$(defineModuleFromFileWithFns + "cryptolMModule" "scLoadCryptolMModule" "saw/CryptolM.sawcore") diff --git a/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs b/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs index 66dd22db6b..1e9a9fc9ad 100644 --- a/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs +++ b/cryptol-saw-core/src/Verifier/SAW/CryptolEnv.hs @@ -33,6 +33,7 @@ module Verifier.SAW.CryptolEnv , lookupIn , resolveIdentifier , meSolverConfig + , mkCryEnv , C.ImportPrimitiveOptions(..) , C.defaultPrimitiveOptions ) diff --git a/examples/mr_solver/SpecPrims.cry b/examples/mr_solver/SpecPrims.cry new file mode 100644 index 0000000000..4d938e3ecd --- /dev/null +++ b/examples/mr_solver/SpecPrims.cry @@ -0,0 +1,35 @@ + +module SpecPrims where + +/* Specification primitives */ + +// The specification that holds for f x for some input x +exists : {a, b} (a -> b) -> b +exists f = error "Cannot run exists" + +// The specification that holds for f x for all inputs x +forall : {a, b} (a -> b) -> b +forall f = error "Cannot run forall" + +// The specification that a computation has no errors +noErrors : {a} a +noErrors = exists (\x -> x) + +// The specification that matches any computation. This calls exists at the +// function type () -> a, which is monadified to () -> CompM a. This means that +// the exists does not just quantify over all values of type a like noErrors, +// but it quantifies over all computations of type a, including those that +// contain errors. +anySpec : {a} a +anySpec = exists (\f -> f ()) + +// The specification which asserts that the first argument is True and then +// returns the second argument +asserting : {a} Bit -> a -> a +asserting b x = + if b then x else error "Assertion failed" + +// The specification which assumes that the first argument is True and then +// returns the second argument +assuming : {a} Bit -> a -> a +assuming b x = if b then x else anySpec diff --git a/examples/mr_solver/monadify.cry b/examples/mr_solver/monadify.cry new file mode 100644 index 0000000000..1d5659f5f7 --- /dev/null +++ b/examples/mr_solver/monadify.cry @@ -0,0 +1,25 @@ + +module Monadify where + +import SpecPrims + +my_abs : [64] -> [64] +my_abs x = if x < 0 then -x else x + +err_if_lt0 : [64] -> [64] +err_if_lt0 x = + if x < 0 then error "x < 0" else x + +sha1 : ([8], [32], [32], [32]) -> [32] +sha1 (t, x, y, z) = + if (0 <= t) && (t <= 19) then (x && y) ^ (~x && z) + | (20 <= t) && (t <= 39) then x ^ y ^ z + | (40 <= t) && (t <= 59) then (x && y) ^ (x && z) ^ (y && z) + | (60 <= t) && (t <= 79) then x ^ y ^ z + else error "sha1: t out of range" + +fib : [64] -> [64] +fib x = if x == 0 then 1 else x * fib (x - 1) + +fibSpecNoErrors : [64] -> [64] +fibSpecNoErrors _ = noErrors diff --git a/examples/mr_solver/monadify.saw b/examples/mr_solver/monadify.saw new file mode 100644 index 0000000000..5b5ba8974e --- /dev/null +++ b/examples/mr_solver/monadify.saw @@ -0,0 +1,51 @@ + +enable_experimental; +import "SpecPrims.cry" as SpecPrims; +import "monadify.cry"; +set_monadification "SpecPrims::exists" "Prelude.existsM"; +set_monadification "SpecPrims::forall" "Prelude.forallM"; + +my_abs <- unfold_term ["my_abs"] {{ my_abs }}; +print "[my_abs] original term:"; +print_term my_abs; +my_absM <- monadify_term my_abs; +print "[my_abs] monadified term:"; +print_term my_absM; + +/* +err_if_lt0 <- unfold_term ["err_if_lt0"] {{ err_if_lt0 }}; +print "[err_if_lt0] original term:"; +err_if_lt0M <- monadify_term err_if_lt0; +print "[err_if_lt0] monadified term:"; +print_term err_if_lt0M; +*/ + +/* +sha1 <- {{ sha1 }}; +print "[SHA1] original term:"; +print_term sha1; +mtrm <- monadify_term sha1; +print "[SHA1] monadified term:"; +print_term mtrm; +*/ + +fib <- unfold_term ["fib"] {{ fib }}; +print "[fib] original term:"; +print_term fib; +fibM <- monadify_term fib; +print "[fib] monadified term:"; +print_term fibM; + +noErrors <- unfold_term ["noErrors"] {{ SpecPrims::noErrors }}; +print "[noErrors] original term:"; +print_term noErrors; +noErrorsM <- monadify_term noErrors; +print "[noErrors] monadified term:"; +print_term noErrorsM; + +fibSpecNoErrors <- unfold_term ["fibSpecNoErrors"] {{ fibSpecNoErrors }}; +print "[fibSpecNoErrors] original term:"; +print_term fibSpecNoErrors; +fibSpecNoErrorsM <- monadify_term fibSpecNoErrors; +print "[fibSpecNoErrors] monadified term:"; +print_term fibSpecNoErrorsM; diff --git a/examples/mr_solver/mr_solver_unit_tests.saw b/examples/mr_solver/mr_solver_unit_tests.saw new file mode 100644 index 0000000000..5366cdbab2 --- /dev/null +++ b/examples/mr_solver/mr_solver_unit_tests.saw @@ -0,0 +1,77 @@ +enable_experimental; + +let eq_bool b1 b2 = + if b1 then + if b2 then true else false + else + if b2 then false else true; + +let fail = do { print "Test failed"; exit 1; }; +let run_test name test expected = + do { if expected then print (str_concat "Test: " name) else + print (str_concat (str_concat "Test: " name) " (expecting failure)"); + actual <- test; + if eq_bool actual expected then print "Success\n" else + do { print "Test failed\n"; exit 1; }; }; + +// The constant 0 function const0 x = 0 +const0 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 0)"; + +// The constant 1 function const1 x = 1 +const1 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 1)"; + +// const0 <= const0 +run_test "mr_solver const0 const0" (mr_solver const0 const0) true; + +// The function test_fun0 from the prelude = const0 +test_fun0 <- parse_core "test_fun0"; +run_test "mr_solver const0 test_fun0" (mr_solver const0 test_fun0) true; + +// not const0 <= const1 +run_test "mr_solver const0 const1" (mr_solver const0 const1) false; + +// The function test_fun1 from the prelude = const1 +test_fun1 <- parse_core "test_fun1"; +run_test "mr_solver const1 test_fun1" (mr_solver const1 test_fun1) true; +run_test "mr_solver const0 test_fun1" (mr_solver const0 test_fun1) false; + +// ifxEq0 x = If x == 0 then x else 0; should be equal to 0 +ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \ + \ ite (CompM (Vec 64 Bool)) (bvEq 64 x (bvNat 64 0)) \ + \ (returnM (Vec 64 Bool) x) \ + \ (returnM (Vec 64 Bool) (bvNat 64 0))"; + +// ifxEq0 <= const0 +run_test "mr_solver ifxEq0 const0" (mr_solver ifxEq0 const0) true; + +// not ifxEq0 <= const1 +run_test "mr_solver ifxEq0 const1" (mr_solver ifxEq0 const1) false; + +// noErrors1 x = exists x. returnM x +noErrors1 <- parse_core "\\ (x:Vec 64 Bool) -> \ + \ existsM (Vec 64 Bool) (Vec 64 Bool) \ + \ (\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool) x)"; + +// const0 <= noErrors +run_test "mr_solver noErrors1 noErrors1" (mr_solver noErrors1 noErrors1) true; + +// const1 <= noErrors +run_test "mr_solver const1 noErrors1" (mr_solver const1 noErrors1) true; + +// noErrorsRec1 x = orM (existsM x. returnM x) (noErrorsRec1 x) +// Intuitively, this specifies functions that either return a value or loop +noErrorsRec1 <- parse_core + "fixM (Vec 64 Bool) (\\ (_:Vec 64 Bool) -> Vec 64 Bool) \ + \ (\\ (f: Vec 64 Bool -> CompM (Vec 64 Bool)) (x:Vec 64 Bool) -> \ + \ orM (Vec 64 Bool) \ + \ (existsM (Vec 64 Bool) (Vec 64 Bool) \ + \ (\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool) x)) \ + \ (f x))"; + +// loop x = loop x +loop1 <- parse_core + "fixM (Vec 64 Bool) (\\ (_:Vec 64 Bool) -> Vec 64 Bool) \ + \ (\\ (f: Vec 64 Bool -> CompM (Vec 64 Bool)) (x:Vec 64 Bool) -> f x)"; + +// loop1 <= noErrorsRec1 +run_test "mr_solver loop1 noErrorsRec1" (mr_solver loop1 noErrorsRec1) true; diff --git a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs index b110ea3c69..c772b7d01d 100644 --- a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs +++ b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs @@ -451,6 +451,8 @@ sawCorePreludeSpecialTreatmentMap configuration = , ("errorM", replace (Coq.App (Coq.ExplVar "errorM") [Coq.Var "CompM", Coq.Var "_"])) , ("catchM", skip) + , ("existsM", mapsTo compMModule "existsM") + , ("forallM", mapsTo compMModule "forallM") , ("fixM", replace (Coq.App (Coq.ExplVar "fixM") [Coq.Var "CompM", Coq.Var "_"])) , ("existsM", mapsToExpl compMModule "existsM") diff --git a/saw-core/prelude/Prelude.sawcore b/saw-core/prelude/Prelude.sawcore index a8ef48ecef..7b0676ff67 100644 --- a/saw-core/prelude/Prelude.sawcore +++ b/saw-core/prelude/Prelude.sawcore @@ -645,12 +645,12 @@ and_triv2 (x : Bool) : Eq Bool (and (not x) x) False = -------------------------------------------------------------------------------- -- Converting Booleans to Propositions -FalseProp : Prop; -FalseProp = Eq Bool True False; - EqTrue : Bool -> Prop; EqTrue x = Eq Bool x True; +TrueProp : Prop; +TrueProp = EqTrue True; + TrueI : EqTrue True; TrueI = Refl Bool True; @@ -921,6 +921,20 @@ primitive gen : (n : Nat) -> (a : sort 0) -> (Nat -> a) -> Vec n a; primitive head : (n : Nat) -> (a : sort 0) -> Vec (Succ n) a -> a; primitive tail : (n : Nat) -> (a : sort 0) -> Vec (Succ n) a -> Vec n a; +-- An implementation for atWithDefault +-- +-- FIXME: can we replace atWithDefault with this implementation? Or does some +-- automation rely on atWithDefault being a primitive? +atWithDefault' : (n : Nat) -> (a : sort 0) -> a -> Vec n a -> Nat -> a; +atWithDefault' n_top a d = + Nat__rec + (\ (n:Nat) -> Vec n a -> Nat -> a) + (\ (_:Vec 0 a) (_:Nat) -> d) + (\ (n:Nat) (rec_f: Vec n a -> Nat -> a) (v:Vec (Succ n) a) (i:Nat) -> + Nat_cases a (head n a v) + (\ (i_prev:Nat) (_:a) -> rec_f (tail n a v) i_prev) i) + n_top; + primitive atWithDefault : (n : Nat) -> (a : sort 0) -> a -> Vec n a -> Nat -> a; at : (n : Nat) -> (a : isort 0) -> Vec n a -> Nat -> a; @@ -1680,8 +1694,12 @@ genBVVecFromVec m a v def n len = genBVVec n len a (\ (i:Vec n Bool) (_:is_bvult n i len) -> atWithDefault m a def v (bvToNat n i)); +-- The false proposition +FalseProp : Prop; +FalseProp = Eq Bool True False; + -- Ex Falso Quodlibet: if True = False then anything is possible -efq : (a : sort 0) -> Eq Bool True False -> a; +efq : (a : sort 0) -> FalseProp -> a; efq a contra = coerce #() a @@ -1697,6 +1715,13 @@ efq a contra = )) (); +-- Ex Falso Quodlibet at sort 1 +efq1 : (a : sort 1) -> Eq Bool True False -> a; +efq1 a contra = + Eq#rec Bit Bit1 + (\ (b:Bit) (_:Eq Bit Bit1 b) -> Bit#rec (\ (_:Bit) -> sort 1) #() a b) + () Bit0 (efq (Eq Bit Bit1 Bit0) contra); + -- Generate an empty BVVec emptyBVVec : (n : Nat) -> (a : sort 0) -> BVVec n (bvNat n 0) a; emptyBVVec n a = @@ -2019,7 +2044,6 @@ primitive CompM : sort 0 -> sort 0; primitive returnM : (a:sort 0) -> a -> CompM a; primitive bindM : (a b:sort 0) -> CompM a -> (a -> CompM b) -> CompM b; -primitive existsM : (a:sort 0) -> (b:sort 0) -> (a -> CompM b) -> CompM b; -- Raise an error in the computation monad primitive errorM : (a:sort 0) -> String -> CompM a; @@ -2042,6 +2066,28 @@ fmapM3 : (a b c d: sort 0) -> (a -> b -> c -> d) -> CompM a -> CompM b -> CompM c -> CompM d; fmapM3 a b c d f m1 m2 m3 = applyM c d (fmapM2 a b (c -> d) f m1 m2) m3; +-- Bind two values and pass them to a binary function +bindM2 : (a b c: sort 0) -> CompM a -> CompM b -> (a -> b -> CompM c) -> CompM c; +bindM2 a b c m1 m2 f = bindM a c m1 (\ (x:a) -> bindM b c m2 (f x)); + +-- Bind three values and pass them to a trinary function +bindM3 : (a b c d: sort 0) -> CompM a -> CompM b -> CompM c -> + (a -> b -> c -> CompM d) -> CompM d; +bindM3 a b c d m1 m2 m3 f = bindM a d m1 (\ (x:a) -> bindM2 b c d m2 m3 (f x)); + +-- A version of bind that takes the function first +bindApplyM : (a b : sort 0) -> (a -> CompM b) -> CompM a -> CompM b; +bindApplyM a b f m = bindM a b m f; + +-- A version of bindM2 that takes the function first +bindApplyM2 : (a b c: sort 0) -> (a -> b -> CompM c) -> CompM a -> CompM b -> CompM c; +bindApplyM2 a b c f m1 m2 = bindM a c m1 (\ (x:a) -> bindM b c m2 (f x)); + +-- A version of bindM3 that takes the function first +bindApplyM3 : (a b c d: sort 0) -> (a -> b -> c -> CompM d) -> + CompM a -> CompM b -> CompM c -> CompM d; +bindApplyM3 a b c d f m1 m2 m3 = bindM3 a b c d m1 m2 m3 f; + -- Compose two monadic functions composeM : (a b c: sort 0) -> (a -> CompM b) -> (b -> CompM c) -> a -> CompM c; composeM a b c f g x = bindM b c (f x) g; @@ -2096,11 +2142,34 @@ appendCastBVVecM n len1 len2 len3 a v1 v2 = -- run the second computation -- primitive catchM : (a:sort 0) -> CompM a -> CompM a -> CompM a; --- We can define fixM as let rec f x = ... in f +-- The computation that nondeterministically chooses a value of type a and +-- passes it to the supplied function f to get a computation of type b. As a +-- specification, this is the union of computations f x. +primitive existsM : (a b:sort 0) -> (a -> CompM b) -> CompM b; + +-- The computation that nondeterministically chooses one computation or another. +-- As a specification, represents the disjunction of two specifications. +orM : (a : sort 0) -> CompM a -> CompM a -> CompM a; +orM a m1 m2 = existsM Bool a (\ (b:Bool) -> ite (CompM a) b m1 m2); + +-- The specification formed from the intersection of all computations f x for +-- all possible inputs x. Computationally, this is sort of like running f for +-- all possible inputs x at the same time and then raising an error if any of +-- those computations diverge from each other. +primitive forallM : (a b:sort 0) -> (a -> CompM b) -> CompM b; + +-- NOTE: for the simplicity and efficiency of MR solver, we define all +-- fixed-point computations in CompM via a primitive multiFixM, defined below. +-- Thus, even though fixM is really the primitive operation, we write this file +-- as if multiFixM is, but I am leaving this version of fixM commented out here +-- to keep this decision explicitly documented and to make it easier to switch +-- back to having fixM be primitive if we decide to do so later. +-- +{- primitive fixM : (a:sort 0) -> (b:a -> sort 0) -> (((x:a) -> CompM (b x)) -> ((x:a) -> CompM (b x))) -> (x:a) -> CompM (b x); --- fixM a b fn x = letRecM1 a b b fn (\ (f:a -> CompM b) -> f x); +-} -- A representation of the type (x1:A1) -> ... -> (xn:An) -> CompM (B x1 ... xn) data LetRecType : sort 1 where { @@ -2193,6 +2262,27 @@ lrtPi lrts b = (\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> lrtToType lrt -> rest) lrts; +-- Apply a function the the body of a multi-arity lrtPi function +lrtPiMap : (a b : sort 0) -> (f : a -> b) -> (lrts : LetRecTypes) -> + lrtPi lrts a -> lrtPi lrts b; +lrtPiMap a b f lrts_top = + LetRecTypes#rec + (\ (lrts:LetRecTypes) -> lrtPi lrts a -> lrtPi lrts b) + (\ (x:a) -> f x) + (\ (lrt:LetRecType) (lrts:LetRecTypes) (rec:lrtPi lrts a -> lrtPi lrts b) + (f:lrtToType lrt -> lrtPi lrts a) (g:lrtToType lrt) -> + rec (f g)) + lrts_top; + +-- Convert a multi-arity lrtPi that returns a pair to a pair of lrtPi functions +-- that return the individual arguments +lrtPiPair : (a b:sort 0) -> (lrts : LetRecTypes) -> lrtPi lrts #(a,b) -> + #(lrtPi lrts a, lrtPi lrts b); +lrtPiPair a b lrts f = + (lrtPiMap #(a,b) a (\ (tup:#(a,b)) -> tup.(1)) lrts f, + lrtPiMap #(a,b) b (\ (tup:#(a,b)) -> tup.(2)) lrts f); + + -- Build the product type (lrtToType lrt1, ..., lrtToType lrtn) from the -- LetRecTypes list [lrt1, ..., lrtn] lrtTupleType : LetRecTypes -> sort 0; @@ -2203,9 +2293,20 @@ lrtTupleType lrts = (\ (lrt:LetRecType) (_:LetRecTypes) (rest:sort 0) -> #(lrtToType lrt, rest)) lrts; --- NOTE: the following are needed to define multiFixM instead of making it a +-- NOTE: the following are needed to define letRecM instead of making it a -- primitive, which we are keeping commented here in case that is needed {- +-- Apply a multi-arity function of type lrtPi lrts B to an lrtTupleType lrts +lrtApply : (lrts:LetRecTypes) -> (B:sort 0) -> lrtPi lrts B -> lrtTupleType lrts -> B; +lrtApply top_lrts B = + LetRecTypes#rec + (\ (lrts:LetRecTypes) -> lrtPi lrts B -> lrtTupleType lrts -> B) + (\ (F:B) (_:#()) -> F) + (\ (lrt:LetRecType) (lrts:LetRecTypes) (rest:lrtPi lrts B -> lrtTupleType lrts -> B) + (F:lrtPi (LRT_Cons lrt lrts) B) (fs:lrtTupleType (LRT_Cons lrt lrts)) -> + rest (F fs.(1)) fs.(2)) + top_lrts; + -- Construct a multi-arity function of type lrtPi lrts B from one of type -- lrtTupleType lrts -> B lrtLambda : (lrts:LetRecTypes) -> (B:sort 0) -> (lrtTupleType lrts -> B) -> lrtPi lrts B; @@ -2219,17 +2320,6 @@ lrtLambda top_lrts B = rest (\ (fs:lrtTupleType lrts) -> F (f, fs))) top_lrts; --- Apply a multi-arity function of type lrtPi lrts B to an lrtTupleType lrts -lrtApply : (lrts:LetRecTypes) -> (B:sort 0) -> lrtPi lrts B -> lrtTupleType lrts -> B; -lrtApply top_lrts B = - LetRecTypes#rec - (\ (lrts:LetRecTypes) -> lrtPi lrts B -> lrtTupleType lrts -> B) - (\ (F:B) (_:#()) -> F) - (\ (lrt:LetRecType) (lrts:LetRecTypes) (rest:lrtPi lrts B -> lrtTupleType lrts -> B) - (F:lrtPi (LRT_Cons lrt lrts) B) (fs:lrtTupleType (LRT_Cons lrt lrts)) -> - rest (F fs.(1)) fs.(2)) - top_lrts; - -- Build a multi-argument fixed-point of type A1 -> ... -> An -> CompM B multiArgFixM : (lrt:LetRecType) -> (lrtToType lrt -> lrtToType lrt) -> lrtToType lrt; @@ -2266,25 +2356,17 @@ multiFixM : (lrts:LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) -> lrtTupleType lrts; multiFixM lrts F = multiTupleFixM lrts (\ (fs:lrtTupleType lrts) -> lrtApply lrts (lrtTupleType lrts) F fs); - --- A letrec construct for binding 0 or more mutually recursive functions -letRecM : (lrts : LetRecTypes) -> (B:sort 0) -> lrtPi lrts (lrtTupleType lrts) -> - lrtPi lrts (CompM B) -> CompM B; -letRecM lrts B F body = lrtApply lrts (CompM B) body (multiFixM lrts F); -} --- Construct a fixed-point for a tuple of mutually-recursive functions -primitive multiFixM : (lrts:LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) -> - lrtTupleType lrts; - -- This is like let rec in ML: letRecM defs body defines N recursive functions -- in terms of themselves using defs, and then passes them to body. We use this -- instead of the more standard fixM because it offers a more compact -- representation, and because fixM messes with functional extensionality by -- introducing an irreducible term at function type. -primitive letRecM : (lrts : LetRecTypes) -> (b : sort 0) -> - (lrtPi lrts (lrtTupleType lrts)) -> - (lrtPi lrts (CompM b)) -> CompM b; +primitive letRecM : (lrts : LetRecTypes) -> (B:sort 0) -> + lrtPi lrts (lrtTupleType lrts) -> + lrtPi lrts (CompM B) -> CompM B; +-- letRecM lrts B F body = lrtApply lrts (CompM B) body (multiFixM lrts F); -- This is let rec with exactly one binding letRecM1 : (a b c : sort 0) -> ((a -> CompM b) -> (a -> CompM b)) -> @@ -2295,6 +2377,69 @@ letRecM1 a b c fn body = (\ (f:a -> CompM b) -> (fn f, ())) (\ (f:a -> CompM b) -> body f); +-- A single-argument fixed-point function +fixM : (a:sort 0) -> (b:a -> sort 0) -> + (((x:a) -> CompM (b x)) -> ((x:a) -> CompM (b x))) -> + (x:a) -> CompM (b x); +fixM a b f x = + letRecM (LRT_Cons (LRT_Fun a (\ (y:a) -> LRT_Ret (b y))) LRT_Nil) + (b x) + (\ (g: (y:a) -> CompM (b y)) -> (f g, ())) + (\ (g: (y:a) -> CompM (b y)) -> g x); + +-- Build a monadic function that takes in its arguments and then calls letRecM. +-- That is, build a function +-- +-- \x1 ... xn -> letRecM lrts F (\f1 ... fm -> body f1 ... fm x1 ... xn) +-- +-- where F recursively defines the fi functions and body defines the computation +-- for the function we are defining in terms of the fi and the xj arguments. +letRecFun : (lrts : LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) -> + (lrt : LetRecType) -> lrtPi lrts (lrtToType lrt) -> lrtToType lrt; +letRecFun lrts F lrt_top = + LetRecType#rec + (\ (lrt:LetRecType) -> lrtPi lrts (lrtToType lrt) -> lrtToType lrt) + (\ (b:sort 0) (body:lrtPi lrts (CompM b)) -> + letRecM lrts b F body) + (\ (a:sort 0) (lrtF: a -> LetRecType) + (rec: (x:a) -> lrtPi lrts (lrtToType (lrtF x)) -> lrtToType (lrtF x)) + (body:lrtPi lrts ((x:a) -> lrtToType (lrtF x))) + (x:a) -> + rec x (lrtPiMap ((y:a) -> lrtToType (lrtF y)) + (lrtToType (lrtF x)) + (\ (g:(y:a) -> lrtToType (lrtF y)) -> g x) + lrts + body)) + lrt_top; + +-- Build a multi-argument fixed-point of type A1 -> ... -> An -> CompM B +multiArgFixM : (lrt:LetRecType) -> (lrtToType lrt -> lrtToType lrt) -> + lrtToType lrt; +multiArgFixM lrt F = + letRecFun (LRT_Cons lrt LRT_Nil) + (\ (f:lrtToType lrt) -> (F f, ())) + lrt + (\ (f:lrtToType lrt) -> f); + +-- Construct a fixed-point for a tuple of mutually-recursive functions +multiFixM : (lrts:LetRecTypes) -> lrtPi lrts (lrtTupleType lrts) -> + lrtTupleType lrts; +multiFixM lrts_top F_top = + LetRecTypes#rec + (\ (lrts:LetRecTypes) -> lrtPi lrts_top (lrtTupleType lrts) -> + lrtTupleType lrts) + (\ (_:lrtPi lrts_top #()) -> ()) + (\ (lrt:LetRecType) (lrts:LetRecTypes) + (rec: lrtPi lrts_top (lrtTupleType lrts) -> lrtTupleType lrts) + (F: lrtPi lrts_top #(lrtToType lrt, lrtTupleType lrts)) -> + (letRecFun + lrts_top F_top lrt + (lrtPiPair (lrtToType lrt) (lrtTupleType lrts) lrts_top F).(1) + , + rec (lrtPiPair (lrtToType lrt) (lrtTupleType lrts) lrts_top F).(2))) + lrts_top + F_top; + -- Test computations test_fun0 : Vec 64 Bool -> CompM (Vec 64 Bool); diff --git a/saw-core/src/Verifier/SAW/OpenTerm.hs b/saw-core/src/Verifier/SAW/OpenTerm.hs index 3e04541712..d271153071 100644 --- a/saw-core/src/Verifier/SAW/OpenTerm.hs +++ b/saw-core/src/Verifier/SAW/OpenTerm.hs @@ -1,6 +1,7 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE OverloadedStrings #-} @@ -20,18 +21,21 @@ module Verifier.SAW.OpenTerm ( -- * Open terms and converting to closed terms OpenTerm(..), completeOpenTerm, completeNormOpenTerm, completeOpenTermType, -- * Basic operations for building open terms - closedOpenTerm, flatOpenTerm, sortOpenTerm, natOpenTerm, + closedOpenTerm, openOpenTerm, failOpenTerm, + bindTCMOpenTerm, bindPPOpenTerm, openTermType, + flatOpenTerm, sortOpenTerm, natOpenTerm, unitOpenTerm, unitTypeOpenTerm, stringLitOpenTerm, stringTypeOpenTerm, trueOpenTerm, falseOpenTerm, boolOpenTerm, boolTypeOpenTerm, arrayValueOpenTerm, bvLitOpenTerm, bvTypeOpenTerm, pairOpenTerm, pairTypeOpenTerm, pairLeftOpenTerm, pairRightOpenTerm, tupleOpenTerm, tupleTypeOpenTerm, projTupleOpenTerm, - ctorOpenTerm, dataTypeOpenTerm, globalOpenTerm, - applyOpenTerm, applyOpenTermMulti, + tupleOpenTerm', tupleTypeOpenTerm', + recordOpenTerm, recordTypeOpenTerm, projRecordOpenTerm, + ctorOpenTerm, dataTypeOpenTerm, globalOpenTerm, extCnsOpenTerm, + applyOpenTerm, applyOpenTermMulti, applyPiOpenTerm, piArgOpenTerm, lambdaOpenTerm, lambdaOpenTermMulti, piOpenTerm, piOpenTermMulti, - arrowOpenTerm, - letOpenTerm, sawLetOpenTerm, + arrowOpenTerm, letOpenTerm, sawLetOpenTerm, -- * Monadic operations for building terms with binders OpenTermM(..), completeOpenTermM, dedupOpenTermM, lambdaOpenTermM, piOpenTermM, @@ -49,6 +53,7 @@ import Data.IntMap.Strict (IntMap) import qualified Data.IntMap.Strict as IntMap import Verifier.SAW.Term.Functor +import Verifier.SAW.Term.Pretty import Verifier.SAW.SharedTerm import Verifier.SAW.SCTypeCheck import Verifier.SAW.Module @@ -79,6 +84,45 @@ completeOpenTermType sc (OpenTerm termM) = closedOpenTerm :: Term -> OpenTerm closedOpenTerm t = OpenTerm $ typeInferComplete t +-- | Embed a 'Term' in the given typing context into an 'OpenTerm' +openOpenTerm :: [(LocalName, Term)] -> Term -> OpenTerm +openOpenTerm ctx t = + -- Extend the local type-checking context, wherever this OpenTerm gets used, + -- by appending ctx to the end, so that variables 0..length ctx-1 all get + -- type-checked with ctx. If these are really the only free variables, then it + -- won't matter what the rest of the ambient context is. + -- + -- FIXME: we should check that the free variables of t are all < length ctx + OpenTerm $ withCtx ctx $ typeInferComplete t + +-- | Build an 'OpenTerm' that 'fail's in the underlying monad when completed +failOpenTerm :: String -> OpenTerm +failOpenTerm str = OpenTerm $ fail str + +-- | Bind the result of a type-checking computation in building an 'OpenTerm'. +-- NOTE: this operation should be considered "unsafe" because it can create +-- malformed 'OpenTerm's if the result of the 'TCM' computation is used as part +-- of the resulting 'OpenTerm'. For instance, @a@ should not be 'OpenTerm'. +bindTCMOpenTerm :: TCM a -> (a -> OpenTerm) -> OpenTerm +bindTCMOpenTerm m f = OpenTerm (m >>= unOpenTerm . f) + +-- | Bind the result of pretty-printing an 'OpenTerm' while building another +bindPPOpenTerm :: OpenTerm -> (String -> OpenTerm) -> OpenTerm +bindPPOpenTerm (OpenTerm m) f = + OpenTerm $ + do ctx <- askCtx + t <- typedVal <$> m + unOpenTerm $ f $ renderSawDoc defaultPPOpts $ + ppTermInCtx defaultPPOpts (map fst ctx) t + +-- | Return type type of an 'OpenTerm' as an 'OpenTerm +openTermType :: OpenTerm -> OpenTerm +openTermType (OpenTerm m) = + OpenTerm $ do TypedTerm _ tp <- m + ctx <- askCtx + tp_tp <- liftTCM scTypeOf' (map snd ctx) tp + return (TypedTerm tp tp_tp) + -- | Build an 'OpenTerm' from a 'FlatTermF' flatOpenTerm :: FlatTermF OpenTerm -> OpenTerm flatOpenTerm ftf = OpenTerm $ @@ -170,6 +214,37 @@ projTupleOpenTerm :: Integer -> OpenTerm -> OpenTerm projTupleOpenTerm 0 t = pairLeftOpenTerm t projTupleOpenTerm i t = projTupleOpenTerm (i-1) (pairRightOpenTerm t) +-- | Build a right-nested tuple as an 'OpenTerm' but without adding a final unit +-- as the right-most element +tupleOpenTerm' :: [OpenTerm] -> OpenTerm +tupleOpenTerm' [] = unitOpenTerm +tupleOpenTerm' ts = foldr1 pairTypeOpenTerm ts + +-- | Build a right-nested tuple type as an 'OpenTerm' +tupleTypeOpenTerm' :: [OpenTerm] -> OpenTerm +tupleTypeOpenTerm' [] = unitTypeOpenTerm +tupleTypeOpenTerm' ts = foldr1 pairTypeOpenTerm ts + +-- | Build a record value as an 'OpenTerm' +recordOpenTerm :: [(FieldName, OpenTerm)] -> OpenTerm +recordOpenTerm flds_ts = + OpenTerm $ do let (flds,ots) = unzip flds_ts + ts <- mapM unOpenTerm ots + typeInferComplete $ RecordValue $ zip flds ts + +-- | Build a record type as an 'OpenTerm' +recordTypeOpenTerm :: [(FieldName, OpenTerm)] -> OpenTerm +recordTypeOpenTerm flds_ts = + OpenTerm $ do let (flds,ots) = unzip flds_ts + ts <- mapM unOpenTerm ots + typeInferComplete $ RecordType $ zip flds ts + +-- | Project a field from a record +projRecordOpenTerm :: OpenTerm -> FieldName -> OpenTerm +projRecordOpenTerm (OpenTerm m) f = + OpenTerm $ do t <- m + typeInferComplete $ RecordProj t f + -- | Build an 'OpenTerm' for a constructor applied to its arguments ctorOpenTerm :: Ident -> [OpenTerm] -> OpenTerm ctorOpenTerm c all_args = OpenTerm $ do @@ -199,6 +274,10 @@ globalOpenTerm ident = tp <- liftTCM scTypeOfGlobal ident return $ TypedTerm trm tp) +-- | Build an 'OpenTerm' for an external constant +extCnsOpenTerm :: ExtCns Term -> OpenTerm +extCnsOpenTerm ec = OpenTerm (liftTCM scExtCns ec >>= typeInferComplete) + -- | Apply an 'OpenTerm' to another applyOpenTerm :: OpenTerm -> OpenTerm -> OpenTerm applyOpenTerm (OpenTerm f) (OpenTerm arg) = @@ -208,6 +287,30 @@ applyOpenTerm (OpenTerm f) (OpenTerm arg) = applyOpenTermMulti :: OpenTerm -> [OpenTerm] -> OpenTerm applyOpenTermMulti = foldl applyOpenTerm +-- | Compute the output type of applying a function of a given type to an +-- argument. That is, given @tp@ and @arg@, compute the type of applying any @f@ +-- of type @tp@ to @arg@. +applyPiOpenTerm :: OpenTerm -> OpenTerm -> OpenTerm +applyPiOpenTerm (OpenTerm m_f) (OpenTerm m_arg) = + OpenTerm $ + do f <- m_f + arg <- m_arg + ret <- applyPiTyped (NotFuncTypeInApp f arg) (typedVal f) arg + ctx <- askCtx + ret_tp <- liftTCM scTypeOf' (map snd ctx) ret + return (TypedTerm ret ret_tp) + +-- | Get the argument type of a function type, 'fail'ing if the input term is +-- not a function type +piArgOpenTerm :: OpenTerm -> OpenTerm +piArgOpenTerm (OpenTerm m) = + OpenTerm $ m >>= \case + (unwrapTermF . typedVal -> Pi _ tp _) -> typeInferComplete tp + t -> + do ctx <- askCtx + fail ("piArgOpenTerm: not a pi type: " ++ + scPrettyTermInCtx defaultPPOpts (map fst ctx) (typedVal t)) + -- | Build an 'OpenTerm' for the top variable in the current context, by -- building the 'TCM' computation which checks how much longer the context has -- gotten since the variable was created and uses this to compute its deBruijn diff --git a/saw-core/src/Verifier/SAW/Term/CtxTerm.hs b/saw-core/src/Verifier/SAW/Term/CtxTerm.hs index 95fe01fab9..ddb79fe7e0 100644 --- a/saw-core/src/Verifier/SAW/Term/CtxTerm.hs +++ b/saw-core/src/Verifier/SAW/Term/CtxTerm.hs @@ -75,6 +75,7 @@ import Data.Kind(Type) import Data.Proxy import Data.Type.Equality import Control.Monad +import Control.Monad.Trans import Data.Parameterized.Context @@ -356,6 +357,12 @@ class Monad m => MonadTerm m where -- ^ NOTE: the first term in the list is substituted for the most -- recently-bound variable, i.e., deBruijn index 0 +instance (MonadTerm m, MonadTrans t, Monad (t m)) => MonadTerm (t m) where + mkTermF = lift . mkTermF + liftTerm n i t = lift $ liftTerm n i t + whnfTerm = lift . whnfTerm + substTerm n s t = lift $ substTerm n s t + -- | Build a 'Term' from a 'FlatTermF' in a 'MonadTerm' mkFlatTermF :: MonadTerm m => FlatTermF Term -> m Term mkFlatTermF = mkTermF . FTermF diff --git a/saw-core/src/Verifier/SAW/Term/Functor.hs b/saw-core/src/Verifier/SAW/Term/Functor.hs index 1ee814b8e8..fb7ae57fef 100644 --- a/saw-core/src/Verifier/SAW/Term/Functor.hs +++ b/saw-core/src/Verifier/SAW/Term/Functor.hs @@ -57,7 +57,7 @@ module Verifier.SAW.Term.Functor , Sort, mkSort, propSort, sortOf, maxSort -- * Sets of free variables , BitSet, emptyBitSet, inBitSet, unionBitSets, intersectBitSets - , decrBitSet, completeBitSet, singletonBitSet + , decrBitSet, multiDecrBitSet, completeBitSet, singletonBitSet, bitSetElems , looseVars, smallestFreeVar ) where @@ -455,6 +455,12 @@ intersectBitSets (BitSet i1) (BitSet i2) = BitSet (i1 .&. i2) decrBitSet :: BitSet -> BitSet decrBitSet (BitSet i) = BitSet (shiftR i 1) +-- | Decrement all elements of a 'BitSet' by some non-negative amount @N@, +-- removing any value less than @N@. This is the same as calling 'decrBitSet' +-- @N@ times. +multiDecrBitSet :: Int -> BitSet -> BitSet +multiDecrBitSet n (BitSet i) = BitSet (shiftR i n) + -- | The 'BitSet' containing all elements less than a given index @i@ completeBitSet :: Int -> BitSet completeBitSet i = BitSet (bit i - 1) @@ -471,6 +477,16 @@ smallestBitSetElem (BitSet i) = Just $ go 0 i where where xw :: Word64 xw = fromInteger x +-- | Compute the list of all elements of a 'BitSet' +bitSetElems :: BitSet -> [Int] +bitSetElems = go 0 where + -- Return the addition of shft to all elements of a BitSet + go :: Int -> BitSet -> [Int] + go shft bs = case smallestBitSetElem bs of + Nothing -> [] + Just i -> + shft + i : go (shft + i + 1) (multiDecrBitSet (shft + i + 1) bs) + -- | Compute the free variables of a term given free variables for its immediate -- subterms freesTermF :: TermF BitSet -> BitSet diff --git a/saw-remote-api/src/SAWServer.hs b/saw-remote-api/src/SAWServer.hs index 655f3f461d..66e1269579 100644 --- a/saw-remote-api/src/SAWServer.hs +++ b/saw-remote-api/src/SAWServer.hs @@ -58,6 +58,7 @@ import SAWScript.Value (AIGProxy(..), BuiltinContext(..), JVMSetupM, LLVMCrucibl import qualified Verifier.SAW.Cryptol.Prelude as CryptolSAW import Verifier.SAW.CryptolEnv (initCryptolEnv, bindTypedTerm) import qualified Cryptol.Utils.Ident as Cryptol +import Verifier.SAW.Cryptol.Monadify (defaultMonEnv) import qualified Argo --import qualified CryptolServer (validateServerState, ServerState(..)) @@ -216,6 +217,7 @@ initialState readFileFn = , rwTypedef = mempty , rwDocs = mempty , rwCryptol = cenv + , rwMonadify = defaultMonEnv , rwPPOpts = defaultPPOpts , rwJVMTrans = jvmTrans , rwPrimsAvail = mempty diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 348f8e95b5..cb687716e5 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -54,8 +54,10 @@ import System.Process (callCommand, readProcessWithExitCode) import Text.Printf (printf) import Text.Read (readMaybe) +import qualified Cryptol.TypeCheck.AST as Cryptol import qualified Verifier.SAW.Cryptol as Cryptol import qualified Verifier.SAW.Cryptol.Simpset as Cryptol +import qualified Verifier.SAW.Cryptol.Monadify as Monadify -- saw-core import Verifier.SAW.Grammar (parseSAWTerm) @@ -1498,23 +1500,90 @@ cryptol_add_path path = let rw' = rw { rwCryptol = ce' } putTopLevelRW rw' -mr_solver_tests :: [SharedContext -> IO Term] -mr_solver_tests = - let helper nm = \sc -> scGlobalDef sc nm in - map helper - [ "Prelude.test_fun0", "Prelude.test_fun1", "Prelude.test_fun2" - , "Prelude.test_fun3", "Prelude.test_fun4", "Prelude.test_fun5" - , "Prelude.test_fun6"] - -testMRSolver :: Integer -> Integer -> TopLevel () -testMRSolver i1 i2 = - do sc <- getSharedContext - t1 <- liftIO $ (mr_solver_tests !! fromInteger i1) sc - t2 <- liftIO $ (mr_solver_tests !! fromInteger i2) sc - res <- liftIO $ Prover.askMRSolver sc SBV.z3 Nothing t1 t2 +-- | Call 'Cryptol.importSchema' using a 'CEnv.CryptolEnv' +importSchemaCEnv :: SharedContext -> CEnv.CryptolEnv -> Cryptol.Schema -> + IO Term +importSchemaCEnv sc cenv schema = + do cry_env <- let ?fileReader = StrictBS.readFile in CEnv.mkCryEnv cenv + Cryptol.importSchema sc cry_env schema + +monadifyTypedTerm :: SharedContext -> TypedTerm -> TopLevel TypedTerm +monadifyTypedTerm sc t = + do rw <- get + let menv = rwMonadify rw + (ret_t, menv') <- + liftIO $ + case ttType t of + TypedTermSchema schema -> + do tp <- importSchemaCEnv sc (rwCryptol rw) schema + Monadify.monadifyTermInEnv sc menv (ttTerm t) tp + TypedTermKind _ -> + fail "monadify_term applied to a type" + TypedTermOther tp -> + Monadify.monadifyTermInEnv sc menv (ttTerm t) tp + modify (\s -> s { rwMonadify = menv' }) + tp <- liftIO $ scTypeOf sc ret_t + return $ TypedTerm (TypedTermOther tp) ret_t + +-- | Ensure that a 'TypedTerm' has been monadified +ensureMonadicTerm :: SharedContext -> TypedTerm -> TopLevel TypedTerm +ensureMonadicTerm _ t + | TypedTermOther tp <- ttType t + , Prover.isCompFunType tp = return t +ensureMonadicTerm sc t = monadifyTypedTerm sc t + +mrSolver :: SharedContext -> Int -> TypedTerm -> TypedTerm -> TopLevel Bool +mrSolver sc dlvl t1 t2 = + do m1 <- ttTerm <$> ensureMonadicTerm sc t1 + m2 <- ttTerm <$> ensureMonadicTerm sc t2 + res <- liftIO $ Prover.askMRSolver sc dlvl SBV.z3 Nothing m1 m2 case res of - Just err -> io $ putStrLn $ Prover.showMRFailure err - Nothing -> io $ putStrLn "Success!" + Just err -> io (putStrLn $ Prover.showMRFailure err) >> return False + Nothing -> return True + +setMonadification :: SharedContext -> String -> String -> TopLevel () +setMonadification sc cry_str saw_str = + do rw <- get + + -- Step 1: convert the first string to a Cryptol name + cry_nm <- + let ?fileReader = StrictBS.readFile in + liftIO (CEnv.resolveIdentifier + (rwCryptol rw) (Text.pack cry_str)) >>= \case + Just n -> return n + Nothing -> fail ("No such Cryptol identifer: " ++ cry_str) + cry_nmi <- liftIO $ Cryptol.importName cry_nm + + -- Step 2: get the monadified type for this Cryptol name + -- + -- FIXME: not sure if this is the correct way to get the type of a Cryptol + -- name, so we are falling back on just translating the name to SAW core + -- and monadifying its type there + cry_saw_tp <- + liftIO $ + case Map.lookup cry_nm (CEnv.eExtraTypes $ rwCryptol rw) of + Just schema -> + -- putStrLn ("Found Cryptol type for name: " ++ show cry_str) >> + importSchemaCEnv sc (rwCryptol rw) schema + Nothing + | Just cry_nm_trans <- Map.lookup cry_nm (CEnv.eTermEnv $ + rwCryptol rw) -> + -- putStrLn ("No Cryptol type for name: " ++ cry_str) >> + scTypeOf sc cry_nm_trans + _ -> fail ("Could not find type for Cryptol name: " ++ cry_str) + cry_mon_tp <- liftIO $ Monadify.monadifyCompleteArgType sc cry_saw_tp + + -- Step 3: convert the second string to a typed SAW core term, and check + -- that it has the same type as the monadified type for the Cryptol name + let saw_ident = parseIdent saw_str + saw_trm <- liftIO $ scGlobalDef sc saw_ident + saw_tp <- liftIO $ scTypeOf sc saw_trm + liftIO $ scCheckSubtype sc Nothing (TC.TypedTerm saw_trm saw_tp) cry_mon_tp + + -- Step 4: Add a mapping from the Cryptol name to the SAW core term + put (rw { rwMonadify = + Map.insert cry_nmi (Monadify.argGlobalMacro + cry_nmi saw_ident) (rwMonadify rw) }) parseSharpSATResult :: String -> Maybe Integer parseSharpSATResult s = parse (lines s) diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 5621dd90f2..3553afac37 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -75,6 +75,7 @@ import Verifier.SAW.SharedTerm import Verifier.SAW.TypedAST hiding (FlatTermF(..)) import Verifier.SAW.TypedTerm import qualified Verifier.SAW.CryptolEnv as CEnv +import qualified Verifier.SAW.Cryptol.Monadify as Monadify import qualified Lang.JVM.Codebase as JCB @@ -472,6 +473,7 @@ buildTopLevelEnv proxy opts = , rwTypedef = Map.empty , rwDocs = primDocEnv primsAvail , rwCryptol = ce0 + , rwMonadify = Monadify.defaultMonEnv , rwProofs = [] , rwPPOpts = SAWScript.Value.defaultPPOpts , rwJVMTrans = jvmTrans @@ -3104,11 +3106,28 @@ primitives = Map.fromList --------------------------------------------------------------------- - , prim "test_mr_solver" "Int -> Int -> TopLevel Bool" - (pureVal testMRSolver) + , prim "mr_solver" "Term -> Term -> TopLevel Bool" + (scVal (\sc -> mrSolver sc 0)) Experimental [ "Call the monadic-recursive solver (that's MR. Solver to you)" - , " to ask if two monadic terms are equal" ] + , " to ask if one monadic term refines another" ] + + , prim "mr_solver_debug" "Int -> Term -> Term -> TopLevel Bool" + (scVal mrSolver) + Experimental + [ "Call the monadic-recursive solver at the supplied debug level" ] + + , prim "monadify_term" "Term -> TopLevel Term" + (scVal monadifyTypedTerm) + Experimental + [ "Monadify a Cryptol term, converting it to a form where all recursion" + , " and errors are represented as monadic operators"] + + , prim "set_monadification" "String -> String -> TopLevel Term" + (scVal setMonadification) + Experimental + [ "Set the monadification of a specific Cryptol identifer to a SAW core " + , "identifier of monadic type" ] , prim "heapster_init_env" "String -> String -> TopLevel HeapsterEnv" diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index 73b52762cf..d29bfa0617 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -2,246 +2,611 @@ {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} + +{- | +Module : SAWScript.Prover.MRSolver +Copyright : Galois, Inc. 2021 +License : BSD3 +Maintainer : westbrook@galois.com +Stability : experimental +Portability : non-portable (language extensions) + +This module implements a monadic-recursive solver, for proving that one monadic +term refines another. The algorithm works on the "monadic normal form" of +computations, which uses the following laws to simplify binds in computations, +where either is the sum elimination function defined in the SAW core prelude: + +returnM x >>= k = k x +errorM str >>= k = errorM +(m >>= k1) >>= k2 = m >>= \x -> k1 x >>= k2 +(existsM f) >>= k = existsM (\x -> f x >>= k) +(forallM f) >>= k = forallM (\x -> f x >>= k) +(orM m1 m2) >>= k = orM (m1 >>= k) (m2 >>= k) +(if b then m1 else m2) >>= k = if b then m1 >>= k else m2 >>1 k +(either f1 f2 e) >>= k = either (\x -> f1 x >= k) (\x -> f2 x >= k) e +(letrecM funs body) >>= k = letrecM funs (\F1 ... Fn -> body F1 ... Fn >>= k) + +The resulting computations of one of the following forms: + +returnM e | errorM str | existsM f | forallM f | orM m1 m2 | +if b then m1 else m2 | either f1 f2 e | F e1 ... en | F e1 ... en >>= k | +letrecM lrts B (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> m) + +The form F e1 ... en refers to a recursively-defined function or a function +variable that has been locally bound by a letrecM. Either way, monadic +normalization does not attempt to normalize these functions. + +The algorithm maintains a context of three sorts of variables: letrec-bound +variables, existential variables, and universal variables. Universal variables +are represented as free SAW core variables, while the other two forms of +variable are represented as SAW core 'ExtCns's terms, which are essentially +axioms that have been generated internally. These 'ExtCns's are Skolemized, +meaning that they take in as arguments all universal variables that were in +scope when they were created. The context also maintains a partial substitution +for the existential variables, as they become instantiated with values, and it +additionally remembers the bodies / unfoldings of the letrec-bound variables. + +The goal of the solver at any point is of the form C |- m1 |= m2, meaning that +we are trying to prove m1 refines m2 in context C. This proceed by cases: + +C |- returnM e1 |= returnM e2: prove C |- e1 = e2 + +C |- errorM str1 |= errorM str2: vacuously true + +C |- if b then m1' else m1'' |= m2: prove C,b=true |- m1' |= m2 and +C,b=false |- m1'' |= m2, skipping either case where C,b=X is unsatisfiable; + +C |- m1 |= if b then m2' else m2'': similar to the above + +C |- either T U (CompM V) f1 f2 e |= m: prove C,x:T,e=inl x |- f1 x |= m and +C,y:U,e=inl y |- f2 y |= m, again skippping any case with unsatisfiable context; + +C |- m |= either T U (CompM V) f1 f2 e: similar to previous + +C |- m |= forallM f: make a new universal variable x and recurse + +C |- existsM f |= m: make a new universal variable x and recurse (existential +elimination uses universal variables and vice-versa) + +C |- m |= existsM f: make a new existential variable x and recurse + +C |- forall f |= m: make a new existential variable x and recurse + +C |- m |= orM m1 m2: try to prove C |- m |= m1, and if that fails, backtrack and +prove C |- m |= m2 + +C |- orM m1 m2 |= m: prove both C |- m1 |= m and C |- m2 |= m + +C |- letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body) |= m: create +letrec-bound variables F1 through Fn in the context bound to their unfoldings f1 +through fn, respectively, and recurse on body |= m + +C |- m |= letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body): similar to +previous case + +C |- F e1 ... en >>= k |= F e1' ... en' >>= k': prove C |- ei = ei' for each i +and then prove k x |= k' x for new universal variable x + +C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': + +* If we have an assumption that forall x1 ... xj, F a1 ... an |= F' a1' .. am', + prove ei = ai and ei' = ai' and then that C |- k x |= k' x for fresh uvar x + +* If we have an assumption that forall x1, ..., xn, F e1'' ... en'' |= m' for + some ei'' and m', match the ei'' against the ei by instantiating the xj with + fresh evars, and if this succeeds then recursively prove C |- m' >>= k |= RHS + +(We don't do this one right now) +* If we have an assumption that forall x1', ..., xn', m |= F e1'' ... en'' for + some ei'' and m', match the ei'' against the ei by instantiating the xj with + fresh evars, and if this succeeds then recursively prove C |- LHS |= m' >>= k' + +* If either side is a definition whose unfolding does not contain letrecM, fixM, + or any related operations, unfold it + +* If F and F' have the same return type, add an assumption forall uvars in scope + that F e1 ... en |= F' e1' ... em' and unfold both sides, recursively proving + that F_body e1 ... en |= F_body' e1' ... em'. Then also prove k x |= k' x for + fresh uvar x. + +* Otherwise we don't know to "split" one of the sides into a bind whose + components relate to the two components on the other side, so just fail +-} module SAWScript.Prover.MRSolver - (askMRSolver, MRFailure(..), showMRFailure + (askMRSolver, MRFailure(..), showMRFailure, isCompFunType , SBV.SMTConfig , SBV.z3, SBV.cvc4, SBV.yices, SBV.mathSAT, SBV.boolector ) where +import Data.List (find, findIndex) +import qualified Data.Text as T +import Data.IORef +import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State import Control.Monad.Except -import Data.Semigroup +import Control.Monad.Trans.Maybe + +import qualified Data.IntMap as IntMap +import Data.Map (Map) +import qualified Data.Map as Map import Prettyprinter import Verifier.SAW.Term.Functor +import Verifier.SAW.Term.CtxTerm (MonadTerm(..)) +import Verifier.SAW.Term.Pretty +import Verifier.SAW.SCTypeCheck import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer +import Verifier.SAW.Cryptol.Monadify -import SAWScript.Proof (boolToProp) +import SAWScript.Proof (termToProp) import qualified SAWScript.Prover.SBV as SBV -import Prelude - - -newtype LocalFunName = LocalFunName { unLocalFunName :: ExtCns Term } deriving (Eq, Show) - --- | Names of functions to be used in computations, which are either local, --- letrec-bound names (represented with an 'ExtCns'), or global named constants -data FunName = LocalName LocalFunName | GlobalName Ident - deriving (Eq, Show) -funNameType :: FunName -> MRM Term -funNameType (LocalName (LocalFunName ec)) = return $ ecType ec -funNameType (GlobalName i) = liftSC1 scTypeOfGlobal i - --- | A "marking" consisting of a set of unfolded function names -newtype Mark = Mark [FunName] deriving (Semigroup, Monoid, Show) - -inMark :: FunName -> Mark -> Bool -inMark f (Mark fs) = elem f fs - -singleMark :: FunName -> Mark -singleMark f = Mark [f] +---------------------------------------------------------------------- +-- * Utility Functions for Transforming 'Term's +---------------------------------------------------------------------- + +-- | Transform the immediate subterms of a term using the supplied function +traverseSubterms :: MonadTerm m => (Term -> m Term) -> Term -> m Term +traverseSubterms f (unwrapTermF -> tf) = traverse f tf >>= mkTermF + +-- | Build a recursive memoized function for tranforming 'Term's. Take in a +-- function @f@ that intuitively performs one step of the transformation and +-- allow it to recursively call the memoized function being defined by passing +-- it as the first argument to @f@. +memoFixTermFun :: MonadIO m => ((Term -> m a) -> Term -> m a) -> Term -> m a +memoFixTermFun f term_top = + do table_ref <- liftIO $ newIORef IntMap.empty + let go t@(STApp { stAppIndex = ix }) = + liftIO (readIORef table_ref) >>= \table -> + case IntMap.lookup ix table of + Just ret -> return ret + Nothing -> + do ret <- f go t + liftIO $ modifyIORef' table_ref (IntMap.insert ix ret) + return ret + go t = f go t + go term_top + +-- | Recursively test if a 'Term' contains @letRecM@ +_containsLetRecM :: Term -> Bool +_containsLetRecM (asGlobalDef -> Just "Prelude.letRecM") = True +_containsLetRecM (unwrapTermF -> tf) = any _containsLetRecM tf + + +---------------------------------------------------------------------- +-- * MR Solver Term Representation +---------------------------------------------------------------------- + +-- | A variable used by the MR solver +newtype MRVar = MRVar { unMRVar :: ExtCns Term } deriving (Eq, Show, Ord) + +-- | Get the type of an 'MRVar' +mrVarType :: MRVar -> Term +mrVarType = ecType . unMRVar + +-- | Names of functions to be used in computations, which are either names bound +-- by letrec to for recursive calls to fixed-points, existential variables, or +-- global named constants +data FunName + = LetRecName MRVar | EVarFunName MRVar | GlobalName GlobalDef + deriving (Eq, Ord, Show) + +-- | Get the type of a 'FunName' +funNameType :: FunName -> Term +funNameType (LetRecName var) = mrVarType var +funNameType (EVarFunName var) = mrVarType var +funNameType (GlobalName gd) = globalDefType gd -- | A term specifically known to be of type @sort i@ for some @i@ newtype Type = Type Term deriving Show --- | A computation in WHNF -data WHNFComp - = Return Term -- ^ Terminates with a return - | Error -- ^ Raises an error - | If Term Comp Comp -- ^ If-then-else that returns @CompM a@ - | FunBind FunName [Term] Mark CompFun - -- ^ Bind a monadic function with @N@ arguments in an @a -> CompM b@ term, - -- marked with a set of function names +-- | A Haskell representation of a @CompM@ in "monadic normal form" +data NormComp + = ReturnM Term -- ^ A term @returnM a x@ + | ErrorM Term -- ^ A term @errorM a str@ + | Ite Term Comp Comp -- ^ If-then-else computation + | Either CompFun CompFun Term -- ^ A sum elimination + | OrM Comp Comp -- ^ an @orM@ computation + | ExistsM Type CompFun -- ^ an @existsM@ computation + | ForallM Type CompFun -- ^ a @forallM@ computation + | FunBind FunName [Term] CompFun + -- ^ Bind a monadic function with @N@ arguments in an @a -> CompM b@ term deriving Show -- | A computation function of type @a -> CompM b@ for some @a@ and @b@ data CompFun + -- | An arbitrary term = CompFunTerm Term + -- | A special case for the term @\ (x:a) -> returnM a x@ + | CompFunReturn + -- | The monadic composition @f >=> g@ | CompFunComp CompFun CompFun - -- ^ The monadic composition @f >=> g@ - | CompFunMark CompFun Mark - -- ^ A computation marked with function names deriving Show +-- | Compose two 'CompFun's, simplifying if one is a 'CompFunReturn' +compFunComp :: CompFun -> CompFun -> CompFun +compFunComp CompFunReturn f = f +compFunComp f CompFunReturn = f +compFunComp f g = CompFunComp f g + +-- | If a 'CompFun' contains an explicit lambda-abstraction, then return the +-- textual name bound by that lambda +compFunVarName :: CompFun -> Maybe LocalName +compFunVarName (CompFunTerm (asLambda -> Just (nm, _, _))) = Just nm +compFunVarName (CompFunComp f _) = compFunVarName f +compFunVarName _ = Nothing + +-- | If a 'CompFun' contains an explicit lambda-abstraction, then return the +-- input type for it +compFunInputType :: CompFun -> Maybe Type +compFunInputType (CompFunTerm (asLambda -> Just (_, tp, _))) = Just $ Type tp +compFunInputType (CompFunComp f _) = compFunInputType f +compFunInputType _ = Nothing + -- | A computation of type @CompM a@ for some @a@ -data Comp = CompTerm Term | CompBind Comp CompFun | CompMark Comp Mark +data Comp = CompTerm Term | CompBind Comp CompFun | CompReturn Term deriving Show --- | A universal type for all the different ways MR. Solver represents terms -data MRTerm - = MRTermTerm Term - | MRTermType Type - | MRTermComp Comp - | MRTermCompFun CompFun - | MRTermWHNFComp WHNFComp - | MRTermFunName FunName - deriving Show - --- | Typeclass for things that can be coerced to 'MRTerm' -class IsMRTerm a where - toMRTerm :: a -> MRTerm -instance IsMRTerm Term where toMRTerm = MRTermTerm -instance IsMRTerm Type where toMRTerm = MRTermType -instance IsMRTerm Comp where toMRTerm = MRTermComp -instance IsMRTerm CompFun where toMRTerm = MRTermCompFun -instance IsMRTerm WHNFComp where toMRTerm = MRTermWHNFComp -instance IsMRTerm FunName where toMRTerm = MRTermFunName +---------------------------------------------------------------------- +-- * Pretty-Printing MR Solver Terms +---------------------------------------------------------------------- + +-- | The monad for pretty-printing in a context of SAW core variables +type PPInCtxM = Reader [LocalName] + +-- | Pretty-print an object in a SAW core context and render to a 'String' +showInCtx :: PrettyInCtx a => [LocalName] -> a -> String +showInCtx ctx a = + renderSawDoc defaultPPOpts $ runReader (prettyInCtx a) ctx + +-- | A generic function for pretty-printing an object in a SAW core context of +-- locally-bound names +class PrettyInCtx a where + prettyInCtx :: a -> PPInCtxM SawDoc + +instance PrettyInCtx Term where + prettyInCtx t = flip (ppTermInCtx defaultPPOpts) t <$> ask + +-- | Combine a list of pretty-printed documents that represent an application +prettyAppList :: [PPInCtxM SawDoc] -> PPInCtxM SawDoc +prettyAppList = fmap (group . hang 2 . vsep) . sequence + +instance PrettyInCtx Type where + prettyInCtx (Type t) = prettyInCtx t + +instance PrettyInCtx MRVar where + prettyInCtx (MRVar ec) = return $ ppName $ ecName ec + +instance PrettyInCtx FunName where + prettyInCtx (LetRecName var) = prettyInCtx var + prettyInCtx (EVarFunName var) = prettyInCtx var + prettyInCtx (GlobalName i) = return $ viaShow i + +instance PrettyInCtx Comp where + prettyInCtx (CompTerm t) = prettyInCtx t + prettyInCtx (CompBind c f) = + prettyAppList [prettyInCtx c, return ">>=", prettyInCtx f] + prettyInCtx (CompReturn t) = + prettyAppList [ return "returnM", return "_", parens <$> prettyInCtx t] + +instance PrettyInCtx CompFun where + prettyInCtx (CompFunTerm t) = prettyInCtx t + prettyInCtx CompFunReturn = return "returnM" + prettyInCtx (CompFunComp f g) = + prettyAppList [prettyInCtx f, return ">=>", prettyInCtx g] + +instance PrettyInCtx NormComp where + prettyInCtx (ReturnM t) = + prettyAppList [return "returnM", return "_", parens <$> prettyInCtx t] + prettyInCtx (ErrorM str) = + prettyAppList [return "errorM", return "_", parens <$> prettyInCtx str] + prettyInCtx (Ite cond t1 t2) = + prettyAppList [return "ite", return "_", prettyInCtx cond, + parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] + prettyInCtx (Either f g eith) = + prettyAppList [return "either", return "_", return "_", return "_", + prettyInCtx f, prettyInCtx g, prettyInCtx eith] + prettyInCtx (OrM t1 t2) = + prettyAppList [return "orM", return "_", + parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] + prettyInCtx (ExistsM tp f) = + prettyAppList [return "existsM", prettyInCtx tp, return "_", + parens <$> prettyInCtx f] + prettyInCtx (ForallM tp f) = + prettyAppList [return "forallM", prettyInCtx tp, return "_", + parens <$> prettyInCtx f] + prettyInCtx (FunBind f args CompFunReturn) = + prettyAppList (prettyInCtx f : map prettyInCtx args) + prettyInCtx (FunBind f [] k) = + prettyAppList [prettyInCtx f, return ">>=", prettyInCtx k] + prettyInCtx (FunBind f args k) = + prettyAppList + [parens <$> prettyAppList (prettyInCtx f : map prettyInCtx args), + return ">>=", prettyInCtx k] + + +---------------------------------------------------------------------- +-- * Lifting MR Solver Terms +---------------------------------------------------------------------- + +-- | A term-like object is one that supports lifting and substitution +class TermLike a where + liftTermLike :: MonadTerm m => DeBruijnIndex -> DeBruijnIndex -> a -> m a + substTermLike :: MonadTerm m => DeBruijnIndex -> [Term] -> a -> m a + +instance (TermLike a, TermLike b) => TermLike (a,b) where + liftTermLike n i (a,b) = (,) <$> liftTermLike n i a <*> liftTermLike n i b + substTermLike n s (a,b) = (,) <$> substTermLike n s a <*> substTermLike n s b + +instance TermLike a => TermLike [a] where + liftTermLike n i l = mapM (liftTermLike n i) l + substTermLike n s l = mapM (substTermLike n s) l + +instance TermLike Term where + liftTermLike = liftTerm + substTermLike = substTerm + +instance TermLike Type where + liftTermLike n i (Type tp) = Type <$> liftTerm n i tp + substTermLike n s (Type tp) = Type <$> substTerm n s tp + +instance TermLike NormComp where + liftTermLike n i (ReturnM t) = ReturnM <$> liftTermLike n i t + liftTermLike n i (ErrorM str) = ErrorM <$> liftTermLike n i str + liftTermLike n i (Ite cond t1 t2) = + Ite <$> liftTermLike n i cond <*> liftTermLike n i t1 <*> liftTermLike n i t2 + liftTermLike n i (Either f g eith) = + Either <$> liftTermLike n i f <*> liftTermLike n i g <*> liftTermLike n i eith + liftTermLike n i (OrM t1 t2) = OrM <$> liftTermLike n i t1 <*> liftTermLike n i t2 + liftTermLike n i (ExistsM tp f) = + ExistsM <$> liftTermLike n i tp <*> liftTermLike n i f + liftTermLike n i (ForallM tp f) = + ForallM <$> liftTermLike n i tp <*> liftTermLike n i f + liftTermLike n i (FunBind nm args f) = + FunBind nm <$> mapM (liftTermLike n i) args <*> liftTermLike n i f + + substTermLike n s (ReturnM t) = ReturnM <$> substTermLike n s t + substTermLike n s (ErrorM str) = ErrorM <$> substTermLike n s str + substTermLike n s (Ite cond t1 t2) = + Ite <$> substTermLike n s cond <*> substTermLike n s t1 + <*> substTermLike n s t2 + substTermLike n s (Either f g eith) = + Either <$> substTermLike n s f <*> substTermLike n s g + <*> substTermLike n s eith + substTermLike n s (OrM t1 t2) = + OrM <$> substTermLike n s t1 <*> substTermLike n s t2 + substTermLike n s (ExistsM tp f) = + ExistsM <$> substTermLike n s tp <*> substTermLike n s f + substTermLike n s (ForallM tp f) = + ForallM <$> substTermLike n s tp <*> substTermLike n s f + substTermLike n s (FunBind nm args f) = + FunBind nm <$> mapM (substTermLike n s) args <*> substTermLike n s f + +instance TermLike CompFun where + liftTermLike n i (CompFunTerm t) = CompFunTerm <$> liftTermLike n i t + liftTermLike _ _ CompFunReturn = return CompFunReturn + liftTermLike n i (CompFunComp f g) = + CompFunComp <$> liftTermLike n i f <*> liftTermLike n i g + + substTermLike n s (CompFunTerm t) = CompFunTerm <$> substTermLike n s t + substTermLike _ _ CompFunReturn = return CompFunReturn + substTermLike n s (CompFunComp f g) = + CompFunComp <$> substTermLike n s f <*> substTermLike n s g + +instance TermLike Comp where + liftTermLike n i (CompTerm t) = CompTerm <$> liftTermLike n i t + liftTermLike n i (CompBind m f) = + CompBind <$> liftTermLike n i m <*> liftTermLike n i f + liftTermLike n i (CompReturn t) = CompReturn <$> liftTermLike n i t + substTermLike n s (CompTerm t) = CompTerm <$> substTermLike n s t + substTermLike n s (CompBind m f) = + CompBind <$> substTermLike n s m <*> substTermLike n s f + substTermLike n s (CompReturn t) = CompReturn <$> substTermLike n s t + + +---------------------------------------------------------------------- +-- * MR Solver Errors +---------------------------------------------------------------------- -- | The context in which a failure occurred data FailCtx - = FailCtxCmp MRTerm MRTerm - | FailCtxWHNF Term + = FailCtxRefines NormComp NormComp + | FailCtxMNF Term deriving Show -- | That's MR. Failure to you data MRFailure = TermsNotEq Term Term | TypesNotEq Type Type + | CompsDoNotRefine NormComp NormComp | ReturnNotError Term | FunsNotEq FunName FunName | CannotLookupFunDef FunName | RecursiveUnfold FunName - | MalformedInOutTypes Term + | MalformedLetRecTypes Term | MalformedDefsFun Term | MalformedComp Term | NotCompFunType Term + -- | A local variable binding + | MRFailureLocalVar LocalName MRFailure + -- | Information about the context of the failure | MRFailureCtx FailCtx MRFailure - -- ^ Records terms we were trying to compare when we got a failure + -- | Records a disjunctive branch we took, where both cases failed | MRFailureDisj MRFailure MRFailure - -- ^ Records a disjunctive branch we took, where both cases failed deriving Show -prettyTerm :: Term -> Doc ann -prettyTerm = unAnnotate . ppTerm defaultPPOpts - -prettyAppList :: [Doc ann] -> Doc ann -prettyAppList = group . hang 2 . vsep - -instance Pretty Type where - pretty (Type t) = prettyTerm t - -instance Pretty FunName where - pretty (LocalName (LocalFunName ec)) = unAnnotate $ ppName $ ecName ec - pretty (GlobalName i) = viaShow i - -instance Pretty Comp where - pretty (CompTerm t) = prettyTerm t - pretty (CompBind c f) = - prettyAppList [pretty c, ">>=", pretty f] - pretty (CompMark c _) = - -- FIXME: print the mark? - pretty c - -instance Pretty CompFun where - pretty (CompFunTerm t) = prettyTerm t - pretty (CompFunComp f g) = - prettyAppList [pretty f, ">=>", pretty g] - pretty (CompFunMark f _) = - -- FIXME: print the mark? - pretty f - -instance Pretty WHNFComp where - pretty (Return t) = - prettyAppList ["returnM", parens (prettyTerm t)] - pretty Error = "errorM" - pretty (If cond t1 t2) = - prettyAppList ["ite", prettyTerm cond, - parens (pretty t1), parens (pretty t2)] - pretty (FunBind f [] _ k) = - prettyAppList [pretty f, ">>=", pretty k] - pretty (FunBind f args _ k) = - prettyAppList - [parens (prettyAppList (pretty f : map prettyTerm args)), - ">>=" <+> pretty k] - -vsepIndent24 :: Doc ann -> Doc ann -> Doc ann -> Doc ann -> Doc ann -vsepIndent24 d1 d2 d3 d4 = - group (d1 <> nest 2 (line <> d2) <> line <> d3 <> nest 2 (line <> d4)) - -instance Pretty MRTerm where - pretty (MRTermTerm t) = prettyTerm t - pretty (MRTermType tp) = pretty tp - pretty (MRTermComp comp) = pretty comp - pretty (MRTermCompFun f) = pretty f - pretty (MRTermWHNFComp norm) = pretty norm - pretty (MRTermFunName nm) = "function" <+> pretty nm - -instance Pretty FailCtx where - pretty (FailCtxCmp t1 t2) = - group $ nest 2 $ vsep ["When comparing terms:", pretty t1, pretty t2] - pretty (FailCtxWHNF t) = - group $ nest 2 $ vsep ["When normalizing computation:", prettyTerm t] - -instance Pretty MRFailure where - pretty (TermsNotEq t1 t2) = - vsepIndent24 - "Terms not equal:" (prettyTerm t1) - "and" (prettyTerm t2) - pretty (TypesNotEq tp1 tp2) = - vsepIndent24 - "Types not equal:" (pretty tp1) - "and" (pretty tp2) - pretty (ReturnNotError t) = - nest 2 ("errorM not equal to" <+> - group (hang 2 $ vsep ["returnM", prettyTerm t])) - pretty (FunsNotEq nm1 nm2) = - vsep ["Named functions not equal:", pretty nm1, pretty nm2] - pretty (CannotLookupFunDef nm) = - vsep ["Could not find definition for function:", pretty nm] - pretty (RecursiveUnfold nm) = - vsep ["Recursive unfolding of function inside its own body:", - pretty nm] - pretty (MalformedInOutTypes t) = - "Not a ground InputOutputTypes list:" - <> nest 2 (line <> prettyTerm t) - pretty (MalformedDefsFun t) = - "Cannot handle letRecM recursive definitions term:" - <> nest 2 (line <> prettyTerm t) - pretty (MalformedComp t) = - "Could not handle computation:" - <> nest 2 (line <> prettyTerm t) - pretty (NotCompFunType tp) = - "Not a computation or computational function type:" - <> nest 2 (line <> prettyTerm tp) - pretty (MRFailureCtx ctx err) = - pretty ctx <> line <> pretty err - pretty (MRFailureDisj err1 err2) = - vsepIndent24 "Tried two comparisons:" (pretty err1) - "Backtracking..." (pretty err2) - +-- | Pretty-print an object prefixed with a 'String' that describes it +ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc +ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a + +-- | Pretty-print two objects, prefixed with a 'String' and with a separator +ppWithPrefixSep :: PrettyInCtx a => String -> a -> String -> a -> + PPInCtxM SawDoc +ppWithPrefixSep d1 t2 d3 t4 = + prettyInCtx t2 >>= \d2 -> prettyInCtx t4 >>= \d4 -> + return $ group (pretty d1 <> nest 2 (line <> d2) <> line <> + pretty d3 <> nest 2 (line <> d4)) + +-- | Apply 'vsep' to a list of pretty-printing computations +vsepM :: [PPInCtxM SawDoc] -> PPInCtxM SawDoc +vsepM = fmap vsep . sequence + +instance PrettyInCtx FailCtx where + prettyInCtx (FailCtxRefines m1 m2) = + group <$> nest 2 <$> + ppWithPrefixSep "When proving refinement:" m1 "|=" m2 + prettyInCtx (FailCtxMNF t) = + group <$> nest 2 <$> vsepM [return "When normalizing computation:", + prettyInCtx t] + +instance PrettyInCtx MRFailure where + prettyInCtx (TermsNotEq t1 t2) = + ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2 + prettyInCtx (TypesNotEq tp1 tp2) = + ppWithPrefixSep "Types not equal:" tp1 "and" tp2 + prettyInCtx (CompsDoNotRefine m1 m2) = + ppWithPrefixSep "Could not prove refinement: " m1 "|=" m2 + prettyInCtx (ReturnNotError t) = + ppWithPrefix "errorM computation not equal to:" (ReturnM t) + prettyInCtx (FunsNotEq nm1 nm2) = + vsepM [return "Named functions not equal:", + prettyInCtx nm1, prettyInCtx nm2] + prettyInCtx (CannotLookupFunDef nm) = + ppWithPrefix "Could not find definition for function:" nm + prettyInCtx (RecursiveUnfold nm) = + ppWithPrefix "Recursive unfolding of function inside its own body:" nm + prettyInCtx (MalformedLetRecTypes t) = + ppWithPrefix "Not a ground LetRecTypes list:" t + prettyInCtx (MalformedDefsFun t) = + ppWithPrefix "Cannot handle letRecM recursive definitions term:" t + prettyInCtx (MalformedComp t) = + ppWithPrefix "Could not handle computation:" t + prettyInCtx (NotCompFunType tp) = + ppWithPrefix "Not a computation or computational function type:" tp + prettyInCtx (MRFailureLocalVar x err) = + local (x:) $ prettyInCtx err + prettyInCtx (MRFailureCtx ctx err) = + do pp1 <- prettyInCtx ctx + pp2 <- prettyInCtx err + return (pp1 <> line <> pp2) + prettyInCtx (MRFailureDisj err1 err2) = + ppWithPrefixSep "Tried two comparisons:" err1 "Backtracking..." err2 + +-- | Render a 'MRFailure' to a 'String' showMRFailure :: MRFailure -> String -showMRFailure = show . pretty +showMRFailure = showInCtx [] + + +---------------------------------------------------------------------- +-- * MR Monad +---------------------------------------------------------------------- + +-- | Classification info for what sort of variable an 'MRVar' is +data MRVarInfo + -- | An existential variable, that might be instantiated + = EVarInfo (Maybe Term) + -- | A letrec-bound function, with its body + | FunVarInfo Term + +-- | A map from 'MRVar's to their info +type MRVarMap = Map MRVar MRVarInfo + +-- | Test if a 'Term' is an application of an 'ExtCns' to some arguments +asExtCnsApp :: Recognizer Term (ExtCns Term, [Term]) +asExtCnsApp (asApplyAll -> (asExtCns -> Just ec, args)) = + return (ec, args) +asExtCnsApp _ = Nothing + +-- | Recognize an evar applied to 0 or more arguments relative to a 'MRVarMap' +-- along with its instantiation, if any +asEVarApp :: MRVarMap -> Recognizer Term (MRVar, [Term], Maybe Term) +asEVarApp var_map (asExtCnsApp -> Just (ec, args)) + | Just (EVarInfo maybe_inst) <- Map.lookup (MRVar ec) var_map = + Just (MRVar ec, args, maybe_inst) +asEVarApp _ _ = Nothing + +-- | An assumption that a named function refines some specificaiton. This has +-- the form +-- +-- > forall x1, ..., xn. F e1 ... ek |= m +-- +-- for some universal context @x1:T1, .., xn:Tn@, some list of argument +-- expressions @ei@ over the universal @xj@ variables, and some right-hand side +-- computation expression @m@. +data FunAssump = FunAssump { + -- | The uvars that were in scope when this assmption was created, in order + -- from outermost to innermost; that is, the uvars as "seen from outside their + -- scope", which is the reverse of the order of 'mrUVars', below + fassumpCtx :: [(LocalName,Term)], + -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars + fassumpArgs :: [Term], + -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars + fassumpRHS :: NormComp } -- | State maintained by MR. Solver data MRState = MRState { + -- | Global shared context for building terms, etc. mrSC :: SharedContext, - -- ^ Global shared context for building terms, etc. + -- | Global SMT configuration for the duration of the MR. Solver call mrSMTConfig :: SBV.SMTConfig, - -- ^ Global SMT configuration for the duration of the MR. Solver call + -- | SMT timeout for SMT calls made by Mr. Solver mrSMTTimeout :: Maybe Integer, - -- ^ SMT timeout for SMT calls made by Mr. Solver - mrLocalFuns :: [(LocalFunName, Term)], - -- ^ Letrec-bound function names with their definitions as lambda-terms - mrFunEqs :: [((FunName, FunName), Bool)], - -- ^ Cache of which named functions are equal - mrPathCondition :: Term - -- ^ The conjunction of all Boolean if conditions along the current path + -- | The context of universal variables, which are free SAW core variables, in + -- order from innermost to outermost, i.e., where element @0@ corresponds to + -- deBruijn index @0@ + mrUVars :: [(LocalName,Type)], + -- | The existential and letrec-bound variables + mrVars :: MRVarMap, + -- | The current assumptions of function refinement + mrFunAssumps :: Map FunName FunAssump, + -- | The current assumptions, which are conjoined into a single Boolean term + mrAssumptions :: Term, + -- | The debug level, which controls debug printing + mrDebugLevel :: Int } --- | Monad used by the MR. Solver -type MRM = ExceptT MRFailure (StateT MRState IO) - --- | Run an 'MRM' computation, and apply a function to any failure thrown +-- | Build a default, empty state from SMT configuration parameters and a set of +-- function refinement assumptions +mkMRState :: SharedContext -> Map FunName FunAssump -> SBV.SMTConfig -> + Maybe Integer -> Int -> IO MRState +mkMRState sc fun_assumps smt_config timeout dlvl = + scBool sc True >>= \true_tm -> + return $ MRState { mrSC = sc, mrSMTConfig = smt_config, + mrSMTTimeout = timeout, mrUVars = [], mrVars = Map.empty, + mrFunAssumps = fun_assumps, mrAssumptions = true_tm, + mrDebugLevel = dlvl } + +-- | Mr. Monad, the monad used by MR. Solver, which is the state-exception monad +newtype MRM a = MRM { unMRM :: StateT MRState (ExceptT MRFailure IO) a } + deriving (Functor, Applicative, Monad, MonadIO, + MonadState MRState, MonadError MRFailure) + +instance MonadTerm MRM where + mkTermF = liftSC1 scTermF + liftTerm = liftSC3 incVars + whnfTerm = liftSC1 scWhnf + substTerm = liftSC3 instantiateVarList + +-- | Run an 'MRM' computation and return a result or an error +runMRM :: MRState -> MRM a -> IO (Either MRFailure a) +runMRM init_st m = runExceptT $ flip evalStateT init_st $ unMRM m + +-- | Apply a function to any failure thrown by an 'MRM' computation mapFailure :: (MRFailure -> MRFailure) -> MRM a -> MRM a mapFailure f m = catchError m (throwError . f) --- | Try two different 'MRM' computations, combining their failures if needed +-- | Try two different 'MRM' computations, combining their failures if needed. +-- Note that the 'MRState' will reset if the first computation fails. mrOr :: MRM a -> MRM a -> MRM a mrOr m1 m2 = catchError m1 $ \err1 -> @@ -252,13 +617,20 @@ mrOr m1 m2 = withFailureCtx :: FailCtx -> MRM a -> MRM a withFailureCtx ctx = mapFailure (MRFailureCtx ctx) +{- -- | Catch any errors thrown by a computation and coerce them to a 'Left' catchErrorEither :: MonadError e m => m a -> m (Either e a) catchErrorEither m = catchError (Right <$> m) (return . Left) +-} +-- FIXME: replace these individual lifting functions with a more general +-- typeclass like LiftTCM + +{- -- | Lift a nullary SharedTerm computation into 'MRM' liftSC0 :: (SharedContext -> IO a) -> MRM a liftSC0 f = (mrSC <$> get) >>= \sc -> liftIO (f sc) +-} -- | Lift a unary SharedTerm computation into 'MRM' liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM b @@ -272,75 +644,606 @@ liftSC2 f a b = (mrSC <$> get) >>= \sc -> liftIO (f sc a b) liftSC3 :: (SharedContext -> a -> b -> c -> IO d) -> a -> b -> c -> MRM d liftSC3 f a b c = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c) --- | Test if a Boolean term is "provable", i.e., its negation is unsatisfiable -mrProvable :: Term -> MRM Bool -mrProvable bool_prop = +-- | Lift a quaternary SharedTerm computation into 'MRM' +liftSC4 :: (SharedContext -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> + MRM e +liftSC4 f a b c d = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c d) + +-- | Apply a 'Term' to a list of arguments and beta-reduce in Mr. Monad +mrApplyAll :: Term -> [Term] -> MRM Term +mrApplyAll f args = liftSC2 scApplyAll f args >>= liftSC1 betaNormalize + +-- | Get the current context of uvars as a list of variable names and their +-- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in +-- the order as seen "from the outside" +mrUVarCtx :: MRM [(LocalName,Term)] +mrUVarCtx = reverse <$> map (\(nm,Type tp) -> (nm,tp)) <$> mrUVars <$> get + +-- | Get the type of a 'Term' in the current uvar context +mrTypeOf :: Term -> MRM Term +mrTypeOf t = mrUVarCtx >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t + +-- | Check if two 'Term's are convertible in the 'MRM' monad +mrConvertible :: Term -> Term -> MRM Bool +mrConvertible = liftSC4 scConvertibleEval scTypeCheckWHNF True + +-- | Take a 'FunName' @f@ for a monadic function of type @vars -> CompM a@ and +-- compute the type @CompM [args/vars]a@ of @f@ applied to @args@. Return the +-- type @[args/vars]a@ that @CompM@ is applied to. +mrFunOutType :: FunName -> [Term] -> MRM Term +mrFunOutType ((asPiList . funNameType) -> (vars, asCompM -> Just tp)) args + | length vars == length args = + substTermLike 0 args tp +mrFunOutType _ _ = + -- NOTE: this is an error because we should only ever call mrFunOutType with a + -- well-formed application at a CompM type + error "mrFunOutType" + +-- | Turn a 'LocalName' into one not in a list, adding a suffix if necessary +uniquifyName :: LocalName -> [LocalName] -> LocalName +uniquifyName nm nms | notElem nm nms = nm +uniquifyName nm nms = + case find (flip notElem nms) $ + map (T.append nm . T.pack . show) [(0::Int) ..] of + Just nm' -> nm' + Nothing -> error "uniquifyName" + +-- | Run a MR Solver computation in a context extended with a universal +-- variable, which is passed as a 'Term' to the sub-computation +withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a +withUVar nm tp m = + do st <- get + let nm' = uniquifyName nm (map fst $ mrUVars st) + put (st { mrUVars = (nm',tp) : mrUVars st }) + ret <- mapFailure (MRFailureLocalVar nm') (liftSC1 scLocalVar 0 >>= m) + modify (\st' -> st' { mrUVars = mrUVars st }) + return ret + +-- | Run a MR Solver computation in a context extended with a universal variable +-- and pass it the lifting (in the sense of 'incVars') of an MR Solver term +withUVarLift :: TermLike tm => LocalName -> Type -> tm -> + (Term -> tm -> MRM a) -> MRM a +withUVarLift nm tp t m = + withUVar nm tp (\x -> liftTermLike 0 1 t >>= m x) + +-- | Run a MR Solver computation in a context extended with a list of universal +-- variables, passing 'Term's for those variables to the supplied computation. +-- The variables are bound "outside in", meaning the first variable in the list +-- is bound outermost, and so will have the highest deBruijn index. +withUVars :: [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a +withUVars = helper [] where + -- The extra input list gives the variables that have already been bound, in + -- order from most to least recently bound + helper :: [Term] -> [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a + helper vars [] m = m $ reverse vars + helper vars ((nm,tp):ctx) m = + substTerm 0 vars tp >>= \tp' -> + withUVar nm (Type tp') $ \var -> helper (var:vars) ctx m + +-- | Build 'Term's for all the uvars currently in scope, ordered from least to +-- most recently bound +getAllUVarTerms :: MRM [Term] +getAllUVarTerms = + (length <$> mrUVars <$> get) >>= \len -> + mapM (liftSC1 scLocalVar) [len-1, len-2 .. 0] + +-- | Lambda-abstract all the current uvars out of a 'Term', with the least +-- recently bound variable being abstracted first +lambdaUVarsM :: Term -> MRM Term +lambdaUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scLambdaList ctx t + +-- | Pi-abstract all the current uvars out of a 'Term', with the least recently +-- bound variable being abstracted first +piUVarsM :: Term -> MRM Term +piUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scPiList ctx t + +-- | Convert an 'MRVar' to a 'Term', applying it to all the uvars in scope +mrVarTerm :: MRVar -> MRM Term +mrVarTerm (MRVar ec) = + do var_tm <- liftSC1 scExtCns ec + vars <- getAllUVarTerms + liftSC2 scApplyAll var_tm vars + +-- | Get the 'VarInfo' associated with a 'MRVar' +mrVarInfo :: MRVar -> MRM (Maybe MRVarInfo) +mrVarInfo var = Map.lookup var <$> mrVars <$> get + +-- | Convert an 'ExtCns' to a 'FunName' +extCnsToFunName :: ExtCns Term -> MRM FunName +extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case + Just (EVarInfo _) -> return $ EVarFunName var + Just (FunVarInfo _) -> return $ LetRecName var + Nothing + | Just glob <- asTypedGlobalDef (Unshared $ FTermF $ ExtCns ec) -> + return $ GlobalName glob + _ -> error "extCnsToFunName: unreachable" + +-- | Get the body of a function @f@ if it has one +mrFunNameBody :: FunName -> MRM (Maybe Term) +mrFunNameBody (LetRecName var) = + mrVarInfo var >>= \case + Just (FunVarInfo body) -> return $ Just body + _ -> error "mrFunBody: unknown letrec var" +mrFunNameBody (GlobalName glob) = return $ globalDefBody glob +mrFunNameBody (EVarFunName _) = return Nothing + +-- | Get the body of a function @f@ applied to some arguments, if possible +mrFunBody :: FunName -> [Term] -> MRM (Maybe Term) +mrFunBody f args = mrFunNameBody f >>= \case + Just body -> Just <$> mrApplyAll body args + Nothing -> return Nothing + +-- | Get the body of a function @f@ applied to some arguments, as per +-- 'mrFunBody', and also return whether its body recursively calls itself, as +-- per 'mrCallsFun' +mrFunBodyRecInfo :: FunName -> [Term] -> MRM (Maybe (Term, Bool)) +mrFunBodyRecInfo f args = + mrFunBody f args >>= \case + Just f_body -> Just <$> (f_body,) <$> mrCallsFun f f_body + Nothing -> return Nothing + +-- | Test if a 'Term' contains, after possibly unfolding some functions, a call +-- to a given function @f@ again +mrCallsFun :: FunName -> Term -> MRM Bool +mrCallsFun f = memoFixTermFun $ \recurse t -> case t of + (asExtCns -> Just ec) -> + do g <- extCnsToFunName ec + maybe_body <- mrFunNameBody g + case maybe_body of + _ | f == g -> return True + Just body -> recurse body + Nothing -> return False + (asTypedGlobalDef -> Just gdef) -> + case globalDefBody gdef of + _ | f == GlobalName gdef -> return True + Just body -> recurse body + Nothing -> return False + (unwrapTermF -> tf) -> + foldM (\b t' -> if b then return b else recurse t') False tf + +-- | Make a fresh 'MRVar' of a given type, which must be closed, i.e., have no +-- free uvars +mrFreshVar :: LocalName -> Term -> MRM MRVar +mrFreshVar nm tp = MRVar <$> liftSC2 scFreshEC nm tp + +-- | Set the info associated with an 'MRVar', assuming it has not been set +mrSetVarInfo :: MRVar -> MRVarInfo -> MRM () +mrSetVarInfo var info = + modify $ \st -> + st { mrVars = + Map.alter (\case + Just _ -> error "mrSetVarInfo" + Nothing -> Just info) + var (mrVars st) } + +-- | Make a fresh existential variable of the given type, abstracting out all +-- the current uvars and returning the new evar applied to all current uvars +mrFreshEVar :: LocalName -> Type -> MRM Term +mrFreshEVar nm (Type tp) = + do tp' <- piUVarsM tp + var <- mrFreshVar nm tp' + mrSetVarInfo var (EVarInfo Nothing) + mrVarTerm var + +-- | Return a fresh sequence of existential variables for a context of variable +-- names and types, assuming each variable is free in the types that occur after +-- it in the list. Return the new evars all applied to the current uvars. +mrFreshEVars :: [(LocalName,Term)] -> MRM [Term] +mrFreshEVars = helper [] where + -- Return fresh evars for the suffix of a context of variable names and types, + -- where the supplied Terms are evars that have already been generated for the + -- earlier part of the context, and so must be substituted into the remaining + -- types in the context + helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] + helper evars [] = return evars + helper evars ((nm,tp):ctx) = + do evar <- substTerm 0 evars tp >>= mrFreshEVar nm . Type + helper (evar:evars) ctx + +-- | Set the value of an evar to a closed term +mrSetEVarClosed :: MRVar -> Term -> MRM () +mrSetEVarClosed var val = + do val_tp <- mrTypeOf val + -- FIXME: catch subtyping errors and report them as being evar failures + liftSC3 scCheckSubtype Nothing (TypedTerm val val_tp) (mrVarType var) + modify $ \st -> + st { mrVars = + Map.alter + (\case + Just (EVarInfo Nothing) -> Just $ EVarInfo (Just val) + Just (EVarInfo (Just _)) -> + error "Setting existential variable: variable already set!" + _ -> error "Setting existential variable: not an evar!") + var (mrVars st) } + + +-- | Try to set the value of the application @X e1 .. en@ of evar @X@ to an +-- expression @e@ by trying to set @X@ to @\ x1 ... xn -> e@. This only works if +-- each free uvar @xi@ in @e@ is one of the arguments @ej@ to @X@ (though it +-- need not be the case that @i=j@). Return whether this succeeded. +mrTrySetAppliedEVar :: MRVar -> [Term] -> Term -> MRM Bool +mrTrySetAppliedEVar evar args t = + -- Get the complete list of argument variables of the type of evar + let (evar_vars, _) = asPiList (mrVarType evar) in + -- Get all the free variables of t + let free_vars = bitSetElems (looseVars t) in + -- For each free var of t, find an arg equal to it + case mapM (\i -> findIndex (\case + (asLocalVar -> Just j) -> i == j + _ -> False) args) free_vars of + Just fv_arg_ixs + -- Check to make sure we have the right number of args + | length args == length evar_vars -> do + -- Build a list of the input vars x1 ... xn as terms, noting that the + -- first variable is the least recently bound and so has the highest + -- deBruijn index + let arg_ixs = [length args - 1, length args - 2 .. 0] + arg_vars <- mapM (liftSC1 scLocalVar) arg_ixs + + -- For free variable of t, we substitute the corresponding variable + -- xi, substituting error terms for the variables that are not free + -- (since we have nothing else to substitute for them) + let var_map = zip free_vars fv_arg_ixs + let subst = flip map [0 .. length args - 1] $ \i -> + maybe (error "mrTrySetAppliedEVar: unexpected free variable") + (arg_vars !!) (lookup i var_map) + body <- substTerm 0 subst t + + -- Now instantiate evar to \x1 ... xn -> body + evar_inst <- liftSC2 scLambdaList evar_vars body + mrSetEVarClosed evar evar_inst + return True + + _ -> return False + + +-- | Replace all evars in a 'Term' with their instantiations when they have one +mrSubstEVars :: Term -> MRM Term +mrSubstEVars = memoFixTermFun $ \recurse t -> + do var_map <- mrVars <$> get + case t of + -- If t is an instantiated evar, recurse on its instantiation + (asEVarApp var_map -> Just (_, args, Just t')) -> + mrApplyAll t' args >>= recurse + -- If t is anything else, recurse on its immediate subterms + _ -> traverseSubterms recurse t + +-- | Replace all evars in a 'Term' with their instantiations, returning +-- 'Nothing' if we hit an uninstantiated evar +mrSubstEVarsStrict :: Term -> MRM (Maybe Term) +mrSubstEVarsStrict top_t = + runMaybeT $ flip memoFixTermFun top_t $ \recurse t -> + do var_map <- mrVars <$> get + case t of + -- If t is an instantiated evar, recurse on its instantiation + (asEVarApp var_map -> Just (_, args, Just t')) -> + lift (mrApplyAll t' args) >>= recurse + -- If t is an uninstantiated evar, return Nothing + (asEVarApp var_map -> Just (_, _, Nothing)) -> + mzero + -- If t is anything else, recurse on its immediate subterms + _ -> traverseSubterms recurse t + +-- | Makes 'mrSubstEVarsStrict' be marked as used +_mrSubstEVarsStrict :: Term -> MRM (Maybe Term) +_mrSubstEVarsStrict = mrSubstEVarsStrict + +-- | Look up the 'FunAssump' for a 'FunName', if there is one +mrGetFunAssump :: FunName -> MRM (Maybe FunAssump) +mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps <$> get + +-- | Run a computation under the additional assumption that a named function +-- applied to a list of arguments refines a given right-hand side, all of which +-- are 'Term's that can have the current uvars free +withFunAssump :: FunName -> [Term] -> NormComp -> MRM a -> MRM a +withFunAssump fname args rhs m = + do mrDebugPPPrefixSep 1 "withFunAssump" (FunBind + fname args CompFunReturn) "|=" rhs + ctx <- mrUVarCtx + assumps <- mrFunAssumps <$> get + let assumps' = Map.insert fname (FunAssump ctx args rhs) assumps + modify (\s -> s { mrFunAssumps = assumps' }) + ret <- m + modify (\s -> s { mrFunAssumps = assumps }) + return ret + +-- | Generate fresh evars for the context of a 'FunAssump' and substitute them +-- into its arguments and right-hand side +instantiateFunAssump :: FunAssump -> MRM ([Term], NormComp) +instantiateFunAssump fassump = + do evars <- mrFreshEVars $ fassumpCtx fassump + args <- substTermLike 0 evars $ fassumpArgs fassump + rhs <- substTermLike 0 evars $ fassumpRHS fassump + return (args, rhs) + +-- | Add an assumption of type @Bool@ to the current path condition while +-- executing a sub-computation +withAssumption :: Term -> MRM a -> MRM a +withAssumption phi m = + do assumps <- mrAssumptions <$> get + assumps' <- liftSC2 scAnd phi assumps + modify (\s -> s { mrAssumptions = assumps' }) + ret <- m + modify (\s -> s { mrAssumptions = assumps }) + return ret + +-- | Print a 'String' if the debug level is at least the supplied 'Int' +debugPrint :: Int -> String -> MRM () +debugPrint i str = + (mrDebugLevel <$> get) >>= \lvl -> + if lvl >= i then liftIO (hPutStrLn stderr str) else return () + +-- | Print a document if the debug level is at least the supplied 'Int' +debugPretty :: Int -> SawDoc -> MRM () +debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp + +-- | Pretty-print an object in the current context if the current debug level is +-- at least the supplied 'Int' +_debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () +_debugPrettyInCtx i a = + (mrUVars <$> get) >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) + +-- | Pretty-print an object relative to the current context +_mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc +_mrPPInCtx a = + runReader (prettyInCtx a) <$> map fst <$> mrUVars <$> get + +-- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar +-- context to 'stderr' if the debug level is at least the 'Int' provided +mrDebugPPPrefixSep :: PrettyInCtx a => Int -> String -> a -> String -> a -> + MRM () +mrDebugPPPrefixSep i pre a1 sp a2 = + (mrUVars <$> get) >>= \ctx -> + debugPretty i $ + flip runReader (map fst ctx) (group <$> nest 2 <$> + ppWithPrefixSep pre a1 sp a2) + + +---------------------------------------------------------------------- +-- * Calling Out to SMT +---------------------------------------------------------------------- + +-- | Test if a closed Boolean term is "provable", i.e., its negation is +-- unsatisfiable, using an SMT solver. By "closed" we mean that it contains no +-- uvars or 'MRVar's. +mrProvableRaw :: Term -> MRM Bool +mrProvableRaw prop_term = do smt_conf <- mrSMTConfig <$> get timeout <- mrSMTTimeout <$> get - path_prop <- mrPathCondition <$> get - bool_prop' <- liftSC2 scImplies path_prop bool_prop - sc <- mrSC <$> get - prop <- liftIO (boolToProp sc [] bool_prop') - (smt_res, _) <- liftIO (SBV.proveUnintSBVIO sc smt_conf mempty timeout prop) + prop <- liftSC1 termToProp prop_term + (smt_res, _) <- liftSC4 SBV.proveUnintSBVIO smt_conf mempty timeout prop case smt_res of Just _ -> return False Nothing -> return True --- | Test if a Boolean term is satisfiable -mrSatisfiable :: Term -> MRM Bool -mrSatisfiable prop = not <$> (liftSC1 scNot prop >>= mrProvable) - --- | Test if two terms are equal using an SMT solver -mrTermsEq :: Term -> Term -> MRM Bool -mrTermsEq t1 t2 = - do tp <- liftSC1 scTypeOf t1 - eq_fun_tm <- liftSC1 scGlobalDef "Prelude.eq" - prop <- liftSC2 scApplyAll eq_fun_tm [tp, t1, t2] - -- Remember, t1 == t2 is true iff t1 /= t2 is not satisfiable - -- not_prop <- liftSC1 scNot prop - -- not <$> mrSatisfiable not_prop - mrProvable prop - --- | Run an equality-testing computation under the assumption of an additional --- path condition. If the condition is unsatisfiable, the test is vacuously --- true, so need not be run. -withPathCondition :: Term -> MRM () -> MRM () -withPathCondition cond m = - do sat <- mrSatisfiable cond - if sat then - do old_cond <- mrPathCondition <$> get - new_cond <- liftSC2 scAnd old_cond cond - modify $ \st -> st { mrPathCondition = new_cond } - m - modify $ \st -> st { mrPathCondition = old_cond } - else return () - --- | Like 'withPathCondition' but for the negation of a condition -withNotPathCondition :: Term -> MRM () -> MRM () -withNotPathCondition cond m = - liftSC1 scNot cond >>= \cond' -> withPathCondition cond' m - --- | Get the input type of a computation function -compFunInputType :: CompFun -> MRM Term -compFunInputType (CompFunTerm t) = - do tp <- liftSC1 scTypeOf t - case asPi tp of - Just (_, tp_in, _) -> return tp_in - Nothing -> error "compFunInputType: Pi type expected!" -compFunInputType (CompFunComp f _) = compFunInputType f -compFunInputType (CompFunMark f _) = compFunInputType f +-- | Test if a Boolean term over the current uvars is provable given the current +-- assumptions +mrProvable :: Term -> MRM Bool +mrProvable bool_tm = + do assumps <- mrAssumptions <$> get + prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue + forall_prop <- piUVarsM prop + mrProvableRaw forall_prop + +-- | Build a Boolean 'Term' stating that two 'Term's are equal. This is like +-- 'scEq' except that it works on open terms. +mrEq :: Term -> Term -> MRM Term +mrEq t1 t2 = mrTypeOf t1 >>= \tp -> mrEq' tp t1 t2 + +-- | Build a Boolean 'Term' stating that the second and third 'Term' arguments +-- are equal, where the first 'Term' gives their type (which we assume is the +-- same for both). This is like 'scEq' except that it works on open terms. +mrEq' :: Term -> Term -> Term -> MRM Term +mrEq' (asDataType -> Just (pn, [])) t1 t2 + | primName pn == "Prelude.Nat" = liftSC2 scEqualNat t1 t2 +mrEq' (asBoolType -> Just _) t1 t2 = liftSC2 scBoolEq t1 t2 +mrEq' (asIntegerType -> Just _) t1 t2 = liftSC2 scIntEq t1 t2 +mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = + liftSC3 scBvEq n t1 t2 +mrEq' _ _ _ = error "mrEq': unsupported type" + +-- | A "simple" strategy for proving equality between two terms, which we assume +-- are of the same type. This strategy first checks if either side is an +-- uninstantiated evar, in which case it set that evar to the other side. If +-- not, it builds an equality proposition by applying the supplied function to +-- both sides, and passes this proposition to an SMT solver. +mrProveEqSimple :: (Term -> Term -> MRM Term) -> MRVarMap -> Term -> Term -> + MRM () + +-- If t1 is an instantiated evar, substitute and recurse +mrProveEqSimple eqf var_map (asEVarApp var_map -> Just (_, args, Just f)) t2 = + mrApplyAll f args >>= \t1' -> mrProveEqSimple eqf var_map t1' t2 + +-- If t1 is an uninstantiated evar, instantiate it with t2 +mrProveEqSimple _ var_map t1@(asEVarApp var_map -> + Just (evar, args, Nothing)) t2 = + do t2' <- mrSubstEVars t2 + success <- mrTrySetAppliedEVar evar args t2' + if success then return () else throwError (TermsNotEq t1 t2) + +-- If t2 is an instantiated evar, substitute and recurse +mrProveEqSimple eqf var_map t1 (asEVarApp var_map -> Just (_, args, Just f)) = + mrApplyAll f args >>= \t2' -> mrProveEqSimple eqf var_map t1 t2' + +-- If t2 is an uninstantiated evar, instantiate it with t1 +mrProveEqSimple _ var_map t1 t2@(asEVarApp var_map -> + Just (evar, args, Nothing)) = + do t1' <- mrSubstEVars t1 + success <- mrTrySetAppliedEVar evar args t1' + if success then return () else throwError (TermsNotEq t1 t2) + +-- Otherwise, try to prove both sides are equal. The use of mrSubstEVars instead +-- of mrSubstEVarsStrict means that we allow evars in the terms we send to the +-- SMT solver, but we treat them as uvars. +mrProveEqSimple eqf _ t1 t2 = + do t1' <- mrSubstEVars t1 + t2' <- mrSubstEVars t2 + prop <- eqf t1' t2' + success <- mrProvable prop + if success then return () else + throwError (TermsNotEq t1 t2) + + +-- | Prove that two terms are equal, instantiating evars if necessary, or +-- throwing an error if this is not possible +mrProveEq :: Term -> Term -> MRM () +mrProveEq t1_top t2_top = + (do mrDebugPPPrefixSep 1 "mrProveEq" t1_top "==" t2_top + tp <- mrTypeOf t1_top + varmap <- mrVars <$> get + proveEq varmap tp t1_top t2_top) + where + proveEq :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM () + proveEq var_map (asDataType -> Just (pn, [])) t1 t2 + | primName pn == "Prelude.Nat" = + mrProveEqSimple (liftSC2 scEqualNat) var_map t1 t2 + proveEq var_map (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = + -- FIXME: make a better solver for bitvector equalities + mrProveEqSimple (liftSC3 scBvEq n) var_map t1 t2 + proveEq var_map (asBoolType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scBoolEq) var_map t1 t2 + proveEq var_map (asIntegerType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scIntEq) var_map t1 t2 + proveEq _ _ t1 t2 = + -- As a fallback, for types we can't handle, just check convertibility + mrConvertible t1 t2 >>= \case + True -> return () + False -> throwError (TermsNotEq t1 t2) + + +---------------------------------------------------------------------- +-- * Normalizing and Matching on Terms +---------------------------------------------------------------------- + +-- | Match a type as being of the form @CompM a@ for some @a@ +asCompM :: Term -> Maybe Term +asCompM (asApp -> Just (isGlobalDef "Prelude.CompM" -> Just (), tp)) = + return tp +asCompM _ = fail "not a CompM type!" --- | Match a term as a function name -asFunName :: Term -> Maybe FunName -asFunName t = - (LocalName <$> LocalFunName <$> asExtCns t) - `mplus` (GlobalName <$> asGlobalDef t) +-- | Test if a type is a monadic function type of 0 or more arguments +isCompFunType :: Term -> Bool +isCompFunType (asPiList -> (_, asCompM -> Just _)) = True +isCompFunType _ = False + +-- | Pattern-match on a @LetRecTypes@ list in normal form and return a list of +-- the types it specifies, each in normal form and with uvars abstracted out +asLRTList :: Term -> MRM [Term] +asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Nil", [])) = + return [] +asLRTList (asCtor -> Just (primName -> "Prelude.LRT_Cons", [lrt, lrts])) = + do tp <- liftSC2 scGlobalApply "Prelude.lrtToType" [lrt] + tp_norm_closed <- liftSC1 scWhnf tp >>= piUVarsM + (tp_norm_closed :) <$> asLRTList lrts +asLRTList t = throwError (MalformedLetRecTypes t) + +-- | Match a right-nested series of pairs. This is similar to 'asTupleValue' +-- except that it expects a unit value to always be at the end. +asNestedPairs :: Recognizer Term [Term] +asNestedPairs (asPairValue -> Just (x, asNestedPairs -> Just xs)) = Just (x:xs) +asNestedPairs (asFTermF -> Just UnitValue) = Just [] +asNestedPairs _ = Nothing + +-- | Normalize a 'Term' of monadic type to monadic normal form +normCompTerm :: Term -> MRM NormComp +normCompTerm = normComp . CompTerm + +-- | Normalize a computation to monadic normal form, assuming any 'Term's it +-- contains have already been normalized with respect to beta and projections +-- (but constants need not be unfolded) +normComp :: Comp -> MRM NormComp +normComp (CompReturn t) = return $ ReturnM t +normComp (CompBind m f) = + do norm <- normComp m + normBind norm f +normComp (CompTerm t) = + withFailureCtx (FailCtxMNF t) $ + case asApplyAll t of + (isGlobalDef "Prelude.returnM" -> Just (), [_, x]) -> + return $ ReturnM x + (isGlobalDef "Prelude.bindM" -> Just (), [_, _, m, f]) -> + do norm <- normComp (CompTerm m) + normBind norm (CompFunTerm f) + (isGlobalDef "Prelude.errorM" -> Just (), [_, str]) -> + return (ErrorM str) + (isGlobalDef "Prelude.ite" -> Just (), [_, cond, then_tm, else_tm]) -> + return $ Ite cond (CompTerm then_tm) (CompTerm else_tm) + (isGlobalDef "Prelude.either" -> Just (), [_, _, _, f, g, eith]) -> + return $ Either (CompFunTerm f) (CompFunTerm g) eith + (isGlobalDef "Prelude.orM" -> Just (), [_, m1, m2]) -> + return $ OrM (CompTerm m1) (CompTerm m2) + (isGlobalDef "Prelude.existsM" -> Just (), [tp, _, body_tm]) -> + return $ ExistsM (Type tp) (CompFunTerm body_tm) + (isGlobalDef "Prelude.forallM" -> Just (), [tp, _, body_tm]) -> + return $ ForallM (Type tp) (CompFunTerm body_tm) + (isGlobalDef "Prelude.letRecM" -> Just (), [lrts, _, defs_f, body_f]) -> + do + -- First, make fresh function constants for all the bound functions, + -- using the names bound by body_f and just "F" if those run out + let fun_var_names = + map fst (fst $ asLambdaList body_f) ++ repeat "F" + fun_tps <- asLRTList lrts + funs <- zipWithM mrFreshVar fun_var_names fun_tps + fun_tms <- mapM mrVarTerm funs + + -- Next, apply the definition function defs_f to our function vars, + -- yielding the definitions of the individual letrec-bound functions in + -- terms of the new function constants + defs_tm <- mrApplyAll defs_f fun_tms + defs <- case asNestedPairs defs_tm of + Just defs -> return defs + Nothing -> throwError (MalformedDefsFun defs_f) + + -- Remember the body associated with each fresh function constant + zipWithM_ (\f body -> + lambdaUVarsM body >>= \cl_body -> + mrSetVarInfo f (FunVarInfo cl_body)) funs defs + + -- Finally, apply the body function to our function vars and recursively + -- normalize the resulting computation + body_tm <- mrApplyAll body_f fun_tms + normComp (CompTerm body_tm) + + -- Only unfold constants that are not recursive functions, i.e., whose + -- bodies do not contain letrecs + {- FIXME: this should be handled by mrRefines; we want it to be handled there + so that we use refinement assumptions before unfolding constants, to give + the user control over refinement proofs + ((asConstant -> Just (_, body)), args) + | not (containsLetRecM body) -> + mrApplyAll body args >>= normCompTerm + -} + + -- For an ExtCns, we have to check what sort of variable it is + -- FIXME: substitute for evars if they have been instantiated + ((asExtCns -> Just ec), args) -> + do fun_name <- extCnsToFunName ec + return $ FunBind fun_name args CompFunReturn + + ((asTypedGlobalDef -> Just gdef), args) -> + return $ FunBind (GlobalName gdef) args CompFunReturn + + _ -> throwError (MalformedComp t) --- | Match a term as being of the form @CompM a@ for some @a@ -asCompMApp :: Term -> Maybe Term -asCompMApp (asApp -> Just (isGlobalDef "Prelude.CompM" -> Just (), tp)) = - return tp -asCompMApp _ = fail "not CompM app" + +-- | Bind a computation in whnf with a function, and normalize +normBind :: NormComp -> CompFun -> MRM NormComp +normBind (ReturnM t) k = applyNormCompFun k t +normBind (ErrorM msg) _ = return (ErrorM msg) +normBind (Ite cond comp1 comp2) k = + return $ Ite cond (CompBind comp1 k) (CompBind comp2 k) +normBind (Either f g t) k = + return $ Either (compFunComp f k) (compFunComp g k) t +normBind (OrM comp1 comp2) k = + return $ OrM (CompBind comp1 k) (CompBind comp2 k) +normBind (ExistsM tp f) k = return $ ExistsM tp (compFunComp f k) +normBind (ForallM tp f) k = return $ ForallM tp (compFunComp f k) +normBind (FunBind f args k1) k2 = + return $ FunBind f args (compFunComp k1 k2) + +-- | Bind a 'Term' for a computation with a function and normalize +normBindTerm :: Term -> CompFun -> MRM NormComp +normBindTerm t f = normCompTerm t >>= \m -> normBind m f -- | Apply a computation function to a term argument to get a computation applyCompFun :: CompFun -> Term -> MRM Comp @@ -348,100 +1251,27 @@ applyCompFun (CompFunComp f g) t = -- (f >=> g) t == f t >>= g do comp <- applyCompFun f t return $ CompBind comp g -applyCompFun (CompFunTerm f) t = - CompTerm <$> liftSC2 scApply f t -applyCompFun (CompFunMark f mark) t = - do comp <- applyCompFun f t - return $ CompMark comp mark +applyCompFun CompFunReturn t = + return $ CompReturn t +applyCompFun (CompFunTerm f) t = CompTerm <$> mrApplyAll f [t] --- | Take in an @InputOutputTypes@ list (as a SAW core term) and build a fresh --- function variable for each pair of input and output types in it -mkFunVarsForTps :: Term -> MRM [LocalFunName] -mkFunVarsForTps (asCtor -> Just (primName -> "Prelude.TypesNil", [])) = - return [] -mkFunVarsForTps (asCtor -> Just (primName -> "Prelude.TypesCons", [a, b, tps])) = - do compM <- liftSC1 scGlobalDef "Prelude.CompM" - comp_b <- liftSC2 scApply compM b - tp <- liftSC3 scPi "x" a comp_b - rest <- mkFunVarsForTps tps - ec <- liftSC2 scFreshEC "f" tp - return (LocalFunName ec : rest) -mkFunVarsForTps t = throwError (MalformedInOutTypes t) - --- | Normalize a computation to weak head normal form -whnfComp :: Comp -> MRM WHNFComp -whnfComp (CompBind m f) = - do norm <- whnfComp m - whnfBind norm f -whnfComp (CompMark m mark) = - do norm <- whnfComp m - whnfMark norm mark -whnfComp (CompTerm t) = - withFailureCtx (FailCtxWHNF t) $ - do t' <- liftSC1 scWhnf t - case asApplyAll t' of - (isGlobalDef "Prelude.returnM" -> Just (), [_, x]) -> - return $ Return x - (isGlobalDef "Prelude.bindM" -> Just (), [_, _, m, f]) -> - do norm <- whnfComp (CompTerm m) - whnfBind norm (CompFunTerm f) - (isGlobalDef "Prelude.errorM" -> Just (), [_]) -> - return Error - (isGlobalDef "Prelude.ite" -> Just (), [_, cond, then_tm, else_tm]) -> - return $ If cond (CompTerm then_tm) (CompTerm else_tm) - (isGlobalDef "Prelude.letRecM" -> Just (), [tps, _, defs_f, body_f]) -> - do funs <- mkFunVarsForTps tps - fun_tms <- mapM (liftSC1 scFlatTermF . ExtCns . unLocalFunName) funs - funs_tm <- - foldr ((=<<) . liftSC2 scPairValue) (liftSC0 scUnitValue) fun_tms - defs_tm <- liftSC2 scApply defs_f funs_tm >>= liftSC1 scWhnf - defs <- case asTupleValue defs_tm of - Just defs -> return defs - Nothing -> throwError (MalformedDefsFun defs_f) - modify $ \st -> - st { mrLocalFuns = (zip funs defs) ++ mrLocalFuns st } - body_tm <- liftSC2 scApply body_f funs_tm - whnfComp (CompTerm body_tm) - ((asFunName -> Just f), args) -> - do comp_tp <- liftSC1 scTypeOf t >>= liftSC1 scWhnf - tp <- - case asCompMApp comp_tp of - Just tp -> return tp - _ -> error "Computation not of type CompM a for some a" - ret_fun <- liftSC1 scGlobalDef "Prelude.returnM" - g <- liftSC2 scApply ret_fun tp - return $ FunBind f args mempty (CompFunTerm g) - _ -> throwError (MalformedComp t') +-- | Apply a 'CompFun' to a term and normalize the resulting computation +applyNormCompFun :: CompFun -> Term -> MRM NormComp +applyNormCompFun f arg = applyCompFun f arg >>= normComp +-- | Apply a 'Comp --- | Bind a computation in whnf with a function, and normalize -whnfBind :: WHNFComp -> CompFun -> MRM WHNFComp -whnfBind (Return t) f = applyCompFun f t >>= whnfComp -whnfBind Error _ = return Error -whnfBind (If cond comp1 comp2) f = - return $ If cond (CompBind comp1 f) (CompBind comp2 f) -whnfBind (FunBind f args mark g) h = - return $ FunBind f args mark (CompFunComp g h) - --- | Mark a normalized computation -whnfMark :: WHNFComp -> Mark -> MRM WHNFComp -whnfMark (Return t) _ = return $ Return t -whnfMark Error _ = return Error -whnfMark (If cond comp1 comp2) mark = - return $ If cond (CompMark comp1 mark) (CompMark comp2 mark) -whnfMark (FunBind f args mark1 g) mark2 = - return $ FunBind f args (mark1 `mappend` mark2) (CompFunMark g mark2) - +{- FIXME: do these go away? -- | Lookup the definition of a function or throw a 'CannotLookupFunDef' if this is -- not allowed, either because it is a global function we are treating as opaque -- or because it is a locally-bound function variable mrLookupFunDef :: FunName -> MRM Term mrLookupFunDef f@(GlobalName _) = throwError (CannotLookupFunDef f) -mrLookupFunDef f@(LocalName nm) = - do fun_assoc <- mrLocalFuns <$> get - case lookup nm fun_assoc of - Just body -> return body - Nothing -> throwError (CannotLookupFunDef f) +mrLookupFunDef f@(LocalName var) = + mrVarInfo var >>= \case + Just (FunVarInfo body) -> return body + Just _ -> throwError (CannotLookupFunDef f) + Nothing -> error "mrLookupFunDef: unknown variable!" -- | Unfold a call to function @f@ in term @f args >>= g@ mrUnfoldFunBind :: FunName -> [Term] -> Mark -> CompFun -> MRM Comp @@ -452,174 +1282,267 @@ mrUnfoldFunBind f args mark g = (CompMark <$> (CompTerm <$> liftSC2 scApplyAll f_def args) <*> (return $ singleMark f `mappend` mark)) <*> return g - --- | Coinductively prove an equality between two named functions by assuming --- the names are equal and proving their bodies equal -mrSolveCoInd :: FunName -> FunName -> MRM () -mrSolveCoInd f1 f2 = - do def1 <- mrLookupFunDef f1 - def2 <- mrLookupFunDef f2 - saved <- get - put $ saved { mrFunEqs = ((f1,f2),True) : mrFunEqs saved } - catchError (mrSolveEq (CompFunMark (CompFunTerm def1) (singleMark f1)) - (CompFunMark (CompFunTerm def2) (singleMark f2))) $ \err -> - -- NOTE: any equalities proved under the assumption that f1 == f2 are - -- suspect, so we have to throw them out and revert to saved on error - (put saved >> throwError err) - - --- | Typeclass for proving that two (representations of) objects of the same SAW --- core type @a@ are equivalent, where the notion of equivalent depends on the --- type @a@. This assumes that the two objects have the same SAW core type. The --- 'MRM' computation returns @()@ on success and throws a 'MRFailure' on error. -class (IsMRTerm a, IsMRTerm b) => MRSolveEq a b where - mrSolveEq' :: a -> b -> MRM () - --- | The main function for solving equations, that calls @mrSovleEq'@ but with --- debugging support for errors, i.e., adding to the failure context -mrSolveEq :: MRSolveEq a b => a -> b -> MRM () -mrSolveEq a b = - withFailureCtx (FailCtxCmp (toMRTerm a) (toMRTerm b)) $ mrSolveEq' a b - --- NOTE: this instance is specifically for terms of non-computation type -instance MRSolveEq Term Term where - mrSolveEq' t1 t2 = - do eq <- mrTermsEq t1 t2 - if eq then return () else throwError (TermsNotEq t1 t2) - -instance MRSolveEq Type Type where - mrSolveEq' tp1@(Type t1) tp2@(Type t2) = - do eq <- liftSC3 scConvertible True t1 t2 - if eq then return () else - throwError (TypesNotEq tp1 tp2) - -instance MRSolveEq FunName FunName where - mrSolveEq' f1 f2 | f1 == f2 = return () - mrSolveEq' f1 f2 = - do eqs <- mrFunEqs <$> get - case lookup (f1,f2) eqs of - Just True -> return () - Just False -> throwError (FunsNotEq f1 f2) - Nothing -> mrSolveCoInd f1 f2 - -instance MRSolveEq Comp Comp where - mrSolveEq' comp1 comp2 = - do norm1 <- whnfComp comp1 - norm2 <- whnfComp comp2 - mrSolveEq norm1 norm2 - -instance MRSolveEq CompFun CompFun where - mrSolveEq' f1 f2 = - do tp <- compFunInputType f1 - var <- liftSC2 scFreshGlobal "x" tp - comp1 <- applyCompFun f1 var - comp2 <- applyCompFun f2 var - mrSolveEq comp1 comp2 - -instance MRSolveEq Comp WHNFComp where - mrSolveEq' comp1 norm2 = - do norm1 <- whnfComp comp1 - mrSolveEq norm1 norm2 - -instance MRSolveEq WHNFComp Comp where - mrSolveEq' norm1 comp2 = - do norm2 <- whnfComp comp2 - mrSolveEq norm1 norm2 - -instance MRSolveEq WHNFComp WHNFComp where - mrSolveEq' (Return t1) (Return t2) = - -- Returns are equal iff their returned values are - mrSolveEq t1 t2 - mrSolveEq' (Return t1) Error = - -- Return is never equal to error - throwError (ReturnNotError t1) - mrSolveEq' Error (Return t2) = - -- Return is never equal to error - throwError (ReturnNotError t2) - mrSolveEq' Error Error = - -- Error trivially equals itself - return () - mrSolveEq' (If cond1 then1 else1) norm2@(If cond2 then2 else2) = - -- Special case if the two conditions are equal: assert the one condition to - -- test the then branches and assert its negtion to test the elses - do eq <- mrTermsEq cond1 cond2 - if eq then - (withPathCondition cond1 $ mrSolveEq then1 then2) >> - (withNotPathCondition cond1 $ mrSolveEq else1 else2) - else - -- Otherwise, compare the first then and else, under their respective - -- path conditions, to the whole second computation - (withPathCondition cond1 $ mrSolveEq then1 norm2) >> - (withNotPathCondition cond1 $ mrSolveEq else1 norm2) - mrSolveEq' (If cond1 then1 else1) norm2 = - -- To compare an if to anything else, compare the then and else, under their - -- respective path conditions, to the other computation - (withPathCondition cond1 $ mrSolveEq then1 norm2) >> - (withNotPathCondition cond1 $ mrSolveEq else1 norm2) - mrSolveEq' norm1 (If cond2 then2 else2) = - -- To compare an if to anything else, compare the then and else, under their - -- respective path conditions, to the other computation - (withPathCondition cond2 $ mrSolveEq norm1 then2) >> - (withNotPathCondition cond2 $ mrSolveEq norm1 else2) - mrSolveEq' comp1@(FunBind f1 args1 mark1 k1) comp2@(FunBind f2 args2 mark2 k2) = - -- To compare two computations (f1 args1 >>= norm1) and (f2 args2 >>= norm2) - -- we first test if (f1 args1) and (f2 args2) are equal. If so, we recurse - -- and compare norm1 and norm2; otherwise, we try unfolding one or the other - -- of f1 and f2. - catchErrorEither cmp_funs >>= \ cmp_fun_res -> - case cmp_fun_res of - Right () -> mrSolveEq k1 k2 - Left err -> - mapFailure (MRFailureDisj err) $ - (mrUnfoldFunBind f1 args1 mark1 k1 >>= \c -> mrSolveEq c comp2) - `mrOr` - (mrUnfoldFunBind f2 args2 mark2 k2 >>= \c -> mrSolveEq comp1 c) - where - cmp_funs = - do tp1 <- funNameType f1 - tp2 <- funNameType f2 - mrSolveEq (Type tp1) (Type tp2) - mrSolveEq f1 f2 - zipWithM_ mrSolveEq args1 args2 - mrSolveEq' (FunBind f1 args1 mark1 k1) comp2 = - -- This case compares a function call to a Return or Error; the only thing - -- to do is unfold the function call and recurse - mrUnfoldFunBind f1 args1 mark1 k1 >>= \c -> mrSolveEq c comp2 - mrSolveEq' comp1 (FunBind f2 args2 mark2 k2) = - -- This case compares a function call to a Return or Error; the only thing - -- to do is unfold the function call and recurse - mrUnfoldFunBind f2 args2 mark2 k2 >>= \c -> mrSolveEq comp1 c +-} + +{- +FIXME HERE NOW: maybe each FunName should stipulate whether it is recursive or +not, so that mrRefines can unfold the non-recursive ones early but wait on +handling the recursive ones +-} + +---------------------------------------------------------------------- +-- * Mr Solver Himself (He Identifies as Male) +---------------------------------------------------------------------- + +-- | An object that can be converted to a normalized computation +class ToNormComp a where + toNormComp :: a -> MRM NormComp + +instance ToNormComp NormComp where + toNormComp = return +instance ToNormComp Comp where + toNormComp = normComp +instance ToNormComp Term where + toNormComp = normComp . CompTerm + +-- | Prove that the left-hand computation refines the right-hand one. See the +-- rules described at the beginning of this module. +mrRefines :: (ToNormComp a, ToNormComp b) => a -> b -> MRM () +mrRefines t1 t2 = + do m1 <- toNormComp t1 + m2 <- toNormComp t2 + mrDebugPPPrefixSep 1 "mrRefines" m1 "|=" m2 + withFailureCtx (FailCtxRefines m1 m2) $ mrRefines' m1 m2 + +-- | The main implementation of 'mrRefines' +mrRefines' :: NormComp -> NormComp -> MRM () +mrRefines' (ReturnM e1) (ReturnM e2) = mrProveEq e1 e2 +mrRefines' (ErrorM _) (ErrorM _) = return () +mrRefines' (ReturnM e) (ErrorM _) = throwError (ReturnNotError e) +mrRefines' (ErrorM _) (ReturnM e) = throwError (ReturnNotError e) +mrRefines' (Ite cond1 m1 m1') m2_all@(Ite cond2 m2 m2') = + liftSC1 scNot cond1 >>= \not_cond1 -> + (mrEq cond1 cond2 >>= mrProvable) >>= \case + True -> + -- If we can prove cond1 == cond2, then we just need to prove m1 |= m2 and + -- m1' |= m2'; further, we need only add assumptions about cond1, because it + -- is provably equal to cond2 + withAssumption cond1 (mrRefines m1 m2) >> + withAssumption not_cond1 (mrRefines m1' m2') + False -> + -- Otherwise, prove each branch of the LHS refines the whole RHS + withAssumption cond1 (mrRefines m1 m2_all) >> + withAssumption not_cond1 (mrRefines m1' m2_all) +mrRefines' (Ite cond1 m1 m1') m2 = + do not_cond1 <- liftSC1 scNot cond1 + withAssumption cond1 (mrRefines m1 m2) + withAssumption not_cond1 (mrRefines m1' m2) +mrRefines' m1 (Ite cond2 m2 m2') = + do not_cond2 <- liftSC1 scNot cond2 + withAssumption cond2 (mrRefines m1 m2) + withAssumption not_cond2 (mrRefines m1 m2') +-- FIXME: handle sum elimination +-- mrRefines (Either f1 g1 e1) (Either f2 g2 e2) = +mrRefines' m1 (ForallM tp f2) = + let nm = maybe "x" id (compFunVarName f2) in + withUVarLift nm tp (m1,f2) $ \x (m1',f2') -> + applyNormCompFun f2' x >>= \m2' -> + mrRefines m1' m2' +mrRefines' (ExistsM tp f1) m2 = + let nm = maybe "x" id (compFunVarName f1) in + withUVarLift nm tp (f1,m2) $ \x (f1',m2') -> + applyNormCompFun f1' x >>= \m1' -> + mrRefines m1' m2' +mrRefines' m1 (OrM m2 m2') = + mrOr (mrRefines m1 m2) (mrRefines m1 m2') +mrRefines' (OrM m1 m1') m2 = + mrRefines m1 m2 >> mrRefines m1' m2 + +-- FIXME: the following cases don't work unless we either allow evars to be set +-- to NormComps or we can turn NormComps back into terms +mrRefines' m1@(FunBind (EVarFunName _) _ _) m2 = + throwError (CompsDoNotRefine m1 m2) +mrRefines' m1 m2@(FunBind (EVarFunName _) _ _) = + throwError (CompsDoNotRefine m1 m2) +{- +mrRefines' (FunBind (EVarFunName evar) args CompFunReturn) m2 = + mrGetEVar evar >>= \case + Just f -> + (mrApplyAll f args >>= normCompTerm) >>= \m1' -> + mrRefines m1' m2 + Nothing -> mrTrySetAppliedEVar evar args m2 +-} + +mrRefines' (FunBind (LetRecName f) args1 k1) (FunBind (LetRecName f') args2 k2) + | f == f' && length args1 == length args2 = + zipWithM_ mrProveEq args1 args2 >> + mrRefinesFun k1 k2 + +mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = + mrFunOutType f1 args1 >>= \tp1 -> + mrFunOutType f2 args2 >>= \tp2 -> + mrConvertible tp1 tp2 >>= \tps_eq -> + mrFunBodyRecInfo f1 args1 >>= \maybe_f1_body -> + mrFunBodyRecInfo f2 args2 >>= \maybe_f2_body -> + mrGetFunAssump f1 >>= \case + + -- If we have an assumption that f1 args' refines some rhs, then prove that + -- args1 = args' and then that rhs refines m2 + Just fassump -> + do (assump_args, assump_rhs) <- instantiateFunAssump fassump + zipWithM_ mrProveEq assump_args args1 + m1' <- normBind assump_rhs k1 + mrRefines m1' m2 + + -- If f1 unfolds and is not recursive in itself, unfold it and recurse + _ | Just (f1_body, False) <- maybe_f1_body -> + normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 + + -- If f2 unfolds and is not recursive in itself, unfold it and recurse + _ | Just (f2_body, False) <- maybe_f2_body -> + normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + + -- If we do not already have an assumption that f1 refines some specification, + -- and both f1 and f2 are recursive but have the same return type, then try to + -- coinductively prove that f1 args1 |= f2 args2 under the assumption that f1 + -- args1 |= f2 args2, and then try to prove that k1 |= k2 + Nothing + | tps_eq + , Just (f1_body, _) <- maybe_f1_body + , Just (f2_body, _) <- maybe_f2_body -> + do withFunAssump f1 args1 (FunBind f2 args2 CompFunReturn) $ + mrRefines f1_body f2_body + mrRefinesFun k1 k2 + + -- If we cannot line up f1 and f2, then making progress here would require us + -- to somehow split either m1 or m2 into some bind m' >>= k' such that m' is + -- related to the function call on the other side and k' is related to the + -- continuation on the other side, but we don't know how to do that, so give + -- up + Nothing -> + throwError (CompsDoNotRefine m1 m2) + +{- FIXME: handle FunBind on just one side +mrRefines' m1@(FunBind f@(GlobalName _) args k1) m2 = + mrGetFunAssump f >>= \case + Just fassump -> + -- If we have an assumption that f args' refines some rhs, then prove that + -- args = args' and then that rhs refines m2 + do (assump_args, assump_rhs) <- instantiateFunAssump fassump + zipWithM_ mrProveEq assump_args args + m1' <- normBind assump_rhs k1 + mrRefines m1' m2 + Nothing -> + -- We don't want to do inter-procedural proofs, so if we don't know anything + -- about f already then give up + throwError (CompsDoNotRefine m1 m2) +-} + + +mrRefines' m1@(FunBind f1 args1 k1) m2 = + mrGetFunAssump f1 >>= \case + + -- If we have an assumption that f1 args' refines some rhs, then prove that + -- args1 = args' and then that rhs refines m2 + Just fassump -> + do (assump_args, assump_rhs) <- instantiateFunAssump fassump + zipWithM_ mrProveEq assump_args args1 + m1' <- normBind assump_rhs k1 + mrRefines m1' m2 + + -- Otherwise, see if we can unfold f1 + Nothing -> + mrFunBodyRecInfo f1 args1 >>= \case + + -- If f1 unfolds and is not recursive in itself, unfold it and recurse + Just (f1_body, False) -> + normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 + + -- Otherwise we would have to somehow split m2 into some computation of the + -- form m2' >>= k2 where f1 args1 |= m2' and k1 |= k2, but we don't know how + -- to do this splitting, so give up + _ -> + throwError (CompsDoNotRefine m1 m2) + + +mrRefines' m1 m2@(FunBind f2 args2 k2) = + mrFunBodyRecInfo f2 args2 >>= \case + + -- If f2 unfolds and is not recursive in itself, unfold it and recurse + Just (f2_body, False) -> + normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + + -- If f2 unfolds but is recursive, and k2 is the trivial continuation, meaning + -- m2 is just f2 args2, use the law of coinduction to prove m1 |= f2 args2 by + -- proving m1 |= f2_body under the assumption that m1 |= f2 args2 + {- FIXME: implement something like this + Just (f2_body, True) + | CompFunReturn <- k2 -> + withFunAssumpR m1 f2 args2 $ + -} + + -- Otherwise we would have to somehow split m1 into some computation of the + -- form m1' >>= k1 where m1' |= f2 args2 and k1 |= k2, but we don't know how + -- to do this splitting, so give up + _ -> + throwError (CompsDoNotRefine m1 m2) + + +-- NOTE: the rules that introduce existential variables need to go last, so that +-- they can quantify over as many universals as possible +mrRefines' m1 (ExistsM tp f2) = + do let nm = maybe "x" id (compFunVarName f2) + evar <- mrFreshEVar nm tp + m2' <- applyNormCompFun f2 evar + mrRefines m1 m2' +mrRefines' (ForallM tp f1) m2 = + do let nm = maybe "x" id (compFunVarName f1) + evar <- mrFreshEVar nm tp + m1' <- applyNormCompFun f1 evar + mrRefines m1' m2 + +-- If none of the above cases match, then fail +mrRefines' m1 m2 = throwError (CompsDoNotRefine m1 m2) + + +-- | Prove that one function refines another for all inputs +mrRefinesFun :: CompFun -> CompFun -> MRM () +mrRefinesFun CompFunReturn CompFunReturn = return () +mrRefinesFun f1 f2 + | Just nm <- compFunVarName f1 `mplus` compFunVarName f2 + , Just tp <- compFunInputType f1 `mplus` compFunInputType f2 = + withUVarLift nm tp (f1,f2) $ \x (f1', f2') -> + do m1' <- applyNormCompFun f1' x + m2' <- applyNormCompFun f2' x + mrRefines m1' m2' +mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" + + +---------------------------------------------------------------------- +-- * External Entrypoints +---------------------------------------------------------------------- -- | Test two monadic, recursive terms for equivalence askMRSolver :: SharedContext -> + Int {- ^ The debug level -} -> SBV.SMTConfig {- ^ SBV configuration -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> Term -> Term -> IO (Maybe MRFailure) -askMRSolver sc smt_conf timeout t1 t2 = +askMRSolver sc dlvl smt_conf timeout t1 t2 = do tp1 <- scTypeOf sc t1 tp2 <- scTypeOf sc t2 - true_tm <- scBool sc True - let init_st = MRState { - mrSC = sc, - mrSMTConfig = smt_conf, - mrSMTTimeout = timeout, - mrLocalFuns = [], - mrFunEqs = [], - mrPathCondition = true_tm - } - res <- - flip evalStateT init_st $ runExceptT $ - do mrSolveEq (Type tp1) (Type tp2) - let (pi_args, ret_tp) = asPiList tp1 - vars <- mapM (\(x, x_tp) -> liftSC2 scFreshGlobal x x_tp) pi_args - case asCompMApp ret_tp of - Just _ -> return () - Nothing -> throwError (NotCompFunType tp1) - t1_app <- liftSC2 scApplyAll t1 vars - t2_app <- liftSC2 scApplyAll t2 vars - mrSolveEq (CompTerm t1_app) (CompTerm t2_app) - case res of - Left err -> return $ Just err - Right () -> return Nothing + init_st <- mkMRState sc Map.empty smt_conf timeout dlvl + case asPiList tp1 of + (uvar_ctx, asCompM -> Just _) -> + fmap (either Just (const Nothing)) $ runMRM init_st $ + withUVars uvar_ctx $ \vars -> + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwError (TypesNotEq (Type tp1) (Type tp2)) + mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 + m1 <- mrApplyAll t1 vars >>= normCompTerm + m2 <- mrApplyAll t2 vars >>= normCompTerm + mrRefines m1 m2 + _ -> return $ Just $ NotCompFunType tp1 diff --git a/src/SAWScript/Value.hs b/src/SAWScript/Value.hs index 917d399452..ca1eac5f59 100644 --- a/src/SAWScript/Value.hs +++ b/src/SAWScript/Value.hs @@ -81,6 +81,7 @@ import SAWScript.X86 (X86Unsupported(..), X86Error(..)) import Verifier.SAW.Name (toShortName) import Verifier.SAW.CryptolEnv as CEnv +import Verifier.SAW.Cryptol.Monadify as Monadify import Verifier.SAW.FiniteValue (FirstOrderValue, ppFirstOrderValue) import Verifier.SAW.Rewriter (Simpset, lhsRewriteRule, rhsRewriteRule, listRules) import Verifier.SAW.SharedTerm hiding (PPOpts(..), defaultPPOpts, @@ -429,6 +430,7 @@ data TopLevelRW = , rwTypedef :: Map SS.Name SS.Type , rwDocs :: Map SS.Name String , rwCryptol :: CEnv.CryptolEnv + , rwMonadify :: Monadify.MonadifyEnv , rwProofs :: [Value] {- ^ Values, generated anywhere, that represent proofs. -} , rwPPOpts :: PPOpts -- , rwCrucibleLLVMCtx :: Crucible.LLVMContext