Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewriting SAW core vector folds #1811

Merged
merged 5 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions intTests/test_fold_rewrite_proof/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
$SAW tupletest.saw
28 changes: 28 additions & 0 deletions intTests/test_fold_rewrite_proof/tupletest.cry
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module tupletest where

foldFunction : [8] -> [16] -> [16] -> [8]
foldFunction x y z = output.0
where
output = foldl fnc (x, y, z) [0 .. 15]

fnc : ([8], [16], [16]) -> [4] -> ([8], [16], [16])
fnc (x, y, z) i = returnTup
where
returnTup = (x ^ take y' ^ take z', y', z')
y' = y <<< i
z' = z >>> i

foldFunction' : [8] -> [16] -> [16] -> [8]
foldFunction' x y z = output.0
where
output = foldl fnc' (x, y, z) [15, 14 .. 0]

fnc' : ([8], [16], [16]) -> [4] -> ([8], [16], [16])
fnc' (x, y, z) i = returnTup
where
returnTup = (x ^ take y ^ take z, y', z')
y' = y >>> i
z' = z <<< i

property foldFunctionInverse x y z =
foldFunction' (foldFunction x y z) y z == x
26 changes: 26 additions & 0 deletions intTests/test_fold_rewrite_proof/tupletest.saw
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
enable_experimental;

let use_lemmas lemmas =
simplify (addsimps lemmas
(add_prelude_eqs ["foldl_cons","foldl_nil","head_gen","tail_gen"] (cryptol_ss())));

let proveit p script =
do {
print (str_concat "Proving " (show_term p));
time (prove_print script p);
};

import "tupletest.cry";

fnc_lemma <- proveit {{ \x y z i -> (fnc' (fnc (x, y, z) i) i).0 == x }} z3;

proveit {{ foldFunctionInverse }} do {
unfolding [ "foldFunctionInverse"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is redundant, goal_normalize ["fnc", "fnc'"]; should do this as well.

, "foldFunction"
, "foldFunction'"
];
goal_normalize ["fnc", "fnc'"];
simplify (add_prelude_eqs ["foldl_cons","foldl_nil",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to use fnc_lemma here with the use_lemmas function you created above?

Also, this script would be more convincing if the trivial prover was used (rather than z3).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, that it would be a better example if we could figure out a way to use fnc_lemma, but unfortunately it doesn't match any term in the goal. I tried for a while to modify fnc_lemma to actually apply in this case, but couldn't figure out a way to make it work, so I just had Z3 unfold all the calls to fnc and fnc' to prove the result. After rewriting all the foldls away, the resulting term looks like

\x y z -> (fnc' (fnc' (... (fnc' ((fnc (fnc (... (fnc (x,y,z) 0)...) 14) 15).0, y, z) 15) ... 1) 0).0 == True

So fnc_lemma needs to somehow rewrite a term of that form, probably something of the form fnc' ((fnc tup i).0, y, z) i but I wasn't sure what to rewrite that to, since the final .0 doesn't apply until all 16 fnc' applications happen.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this example for now is to demonstrate rewriting folds, and it does that. So I'm going to merge this PR now. I'd love to figure out the "right" way to write this lemma, but let's not wait on that before merging this PR.

"head_gen","tail_gen"] (cryptol_ss()));
z3;
};
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ Fixpoint gen (n : nat) (a : Type) (f : nat -> a) {struct n} : Vec n a.
).
Defined.

Definition head (n : nat) (a : Type) (v : Vec (S n) a) : a := hd v.
Definition tail (n : nat) (a : Type) (v : Vec (S n) a) : Vec n a := tl v.

Lemma head_gen (n : nat) (a : Type) (f : nat -> a) :
head n a (gen (Succ n) a f) = f 0.
Proof.
reflexivity.
Qed.

Lemma tail_gen (n : nat) (a : Type) (f : nat -> a) :
tail n a (gen (Succ n) a f) = gen n a (fun (i:Nat) => f (Succ i)).
Proof.
reflexivity.
Qed.

Instance Inhabited_Vec (n:nat) (a:Type) {Ha:Inhabited a} : Inhabited (Vec n a) :=
MkInhabited (Vec n a) (gen n a (fun _ => inhabitant)).

Expand Down Expand Up @@ -156,12 +171,40 @@ Fixpoint foldr (a b : Type) (n : nat) (f : a -> b -> b) (base : b) (v : Vec n a)
| Vector.cons hd _ tl => f hd (foldr _ _ _ f base tl)
end.

Lemma foldr_nil (a b : Type) (f : a -> b -> b) (base : b) (v : Vec 0 a) :
foldr a b 0 f base v = base.
Proof.
rewrite (Vec_0_nil _ v). reflexivity.
Qed.

Lemma foldr_cons (a b : Type) (n : nat) (f : a -> b -> b) (base : b)
(v : Vec (S n) a) : foldr a b (S n) f base v = f (hd v) (foldr a b n f base (tl v)).
Proof.
destruct (Vec_S_cons _ _ v) as [ x [ xs pf ]].
rewrite pf. reflexivity.
Qed.


Fixpoint foldl (a b : Type) (n : nat) (f : b -> a -> b) (acc : b) (v : Vec n a) : b :=
match v with
| Vector.nil => acc
| Vector.cons hd _ tl => foldl _ _ _ f (f acc hd) tl
end.

Lemma foldl_nil (a b : Type) (f : b -> a -> b) (base : b) (v : Vec 0 a) :
foldl a b 0 f base v = base.
Proof.
rewrite (Vec_0_nil _ v). reflexivity.
Qed.

Lemma foldl_cons (a b : Type) (n : nat) (f : b -> a -> b) (base : b)
(v : Vec (S n) a) :
foldl a b (S n) f base v = foldl a b n f (f base (hd v)) (tl v).
Proof.
destruct (Vec_S_cons _ _ v) as [ x [ xs pf ]].
rewrite pf. reflexivity.
Qed.

Fixpoint scanl (a b : Type) (n : nat) (f : b -> a -> b) (acc : b) (v : Vec n a) : Vec (S n) b :=
match v in VectorDef.t _ n return Vec (S n) b with
| Vector.nil => [ acc ]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,11 @@ sawCorePreludeSpecialTreatmentMap configuration =
, ("coerceVec", mapsTo vectorsModule "coerceVec")
, ("eq_Vec", skip)
, ("foldr", mapsTo vectorsModule "foldr")
, ("foldr_nil", mapsTo vectorsModule "foldr_nil")
, ("foldr_cons", mapsTo vectorsModule "foldr_cons")
, ("foldl", mapsTo vectorsModule "foldl")
, ("foldl_nil", mapsTo vectorsModule "foldl_nil")
, ("foldl_cons", mapsTo vectorsModule "foldl_cons")
, ("gen_at_BVVec", mapsTo preludeExtraModule "gen_at_BVVec")
, ("genWithProof", mapsTo vectorsModule "genWithProof")
, ("scanl", mapsTo vectorsModule "scanl")
Expand All @@ -409,6 +413,10 @@ sawCorePreludeSpecialTreatmentMap configuration =
, ("zip", realize zipSnippet)
-- cannot map directly to Vector.t because arguments are in a different order
, ("Vec", mapsTo vectorsModule "Vec")
, ("head", mapsTo vectorsModule "head")
, ("tail", mapsTo vectorsModule "tail")
, ("head_gen", mapsTo vectorsModule "head_gen")
, ("tail_gen", mapsTo vectorsModule "tail_gen")
]

-- Streams
Expand Down
24 changes: 24 additions & 0 deletions saw-core/prelude/Prelude.sawcore
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,14 @@ primitive gen : (n : Nat) -> (a : sort 0) -> (Nat -> a) -> Vec n a;
primitive head : (n : Nat) -> (a : sort 0) -> Vec (Succ n) a -> a;
primitive tail : (n : Nat) -> (a : sort 0) -> Vec (Succ n) a -> Vec n a;

-- Axioms describing head and tail in terms of gen
axiom head_gen : (n : Nat) -> (a : sort 0) -> (f : Nat -> a) ->
Eq a (head n a (gen (Succ n) a f)) (f 0);

axiom tail_gen : (n : Nat) -> (a : sort 0) -> (f : Nat -> a) ->
Eq (Vec n a) (tail n a (gen (Succ n) a f))
(gen n a (\ (i:Nat) -> f (Succ i)));

-- An implementation for atWithDefault
--
-- FIXME: can we replace atWithDefault with this implementation? Or does some
Expand Down Expand Up @@ -1153,6 +1161,22 @@ primitive foldr : (a b : sort 0) -> (n : Nat) -> (a -> b -> b) -> b -> Vec n a -
primitive foldl : (a b : sort 0) -> (n : Nat) -> (b -> a -> b) -> b -> Vec n a -> b;
primitive scanl : (a b : sort 0) -> (n : Nat) -> (b -> a -> b) -> b -> Vec n a -> Vec (addNat 1 n) b;

-- Axioms defining foldr
axiom foldr_nil : (a b : sort 0) -> (f : a -> b -> b) -> (x : b) ->
(v : Vec 0 a) -> Eq b (foldr a b 0 f x v) x;
axiom foldr_cons : (a b : sort 0) -> (n : Nat) -> (f : a -> b -> b) -> (x : b) ->
(v : Vec (Succ n) a) ->
Eq b (foldr a b (Succ n) f x v)
(f (head n a v) (foldr a b n f x (tail n a v)));

-- Axioms defining foldl
axiom foldl_nil : (a b : sort 0) -> (f : b -> a -> b) -> (x : b) ->
(v : Vec 0 a) -> Eq b (foldl a b 0 f x v) x;
axiom foldl_cons : (a b : sort 0) -> (n : Nat) -> (f : b -> a -> b) -> (x : b) ->
(v : Vec (Succ n) a) ->
Eq b (foldl a b (Succ n) f x v)
(foldl a b n f (f x (head n a v)) (tail n a v));

reverse : (n : Nat) -> (a : isort 0) -> Vec n a -> Vec n a;
reverse n a xs = gen n a (\ (i : Nat) -> at n a xs (subNat (subNat n 1) i));

Expand Down
27 changes: 22 additions & 5 deletions saw-core/src/Verifier/SAW/Rewriter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ first_order_match pat term = match pat term Map.empty
-- occur as the 2nd argument of an @App@ constructor. This ensures
-- that instantiations are well-typed.

-- | Test if a term is a constant natural number
asConstantNat :: Term -> Maybe Natural
asConstantNat t =
case R.asCtor t of
Expand Down Expand Up @@ -202,11 +203,13 @@ scMatch ::
IO (Maybe (Map DeBruijnIndex Term))
scMatch sc pat term =
runMaybeT $
do --lift $ putStrLn $ "********** scMatch **********"
do -- lift $ putStrLn $ "********** scMatch **********"
MatchState inst cs <- match 0 [] pat term emptyMatchState
mapM_ (check inst) cs
return inst
where
-- Check that a constraint of the form pat = n for natural number literal n
-- is satisfied by the supplied substitution (aka instantiation) inst
check :: Map DeBruijnIndex Term -> (Term, Natural) -> MaybeT IO ()
check inst (t, n) = do
--lift $ putStrLn $ "checking: " ++ show (t, n)
Expand All @@ -219,6 +222,11 @@ scMatch sc pat term =
Just i | i == n -> return ()
_ -> mzero

-- Check if a term is a higher-order variable pattern, i.e., a free variable
-- (meaning one that can match anything) applied to 0 or more bound variable
-- arguments. Depth is the number of variables bound by lambdas or pis since
-- the top of the current pattern, so "free" means >= the current depth and
-- "bound" means less than the current depth
asVarPat :: Int -> Term -> Maybe (DeBruijnIndex, [DeBruijnIndex])
asVarPat depth = go []
where
Expand All @@ -231,13 +239,17 @@ scMatch sc pat term =
| j < depth -> go (j : js) t
_ -> Nothing

match :: Int -> [(LocalName, Term)] -> Term -> Term -> MatchState -> MaybeT IO MatchState
-- Test if term y matches pattern x, meaning whether there is a substitution
-- to the free variables of x to make it equal to y. Depth is the number of
-- bound variables, so a "free" variable is a deBruijn index >= depth. Env
-- saves the names associated with those bound variables.
match :: Int -> [(LocalName, Term)] -> Term -> Term -> MatchState ->
MaybeT IO MatchState
match _ _ (STApp i fv _) (STApp j _ _) s
| fv == emptyBitSet && i == j = return s
match depth env x y s@(MatchState m cs) =
--do
--lift $ putStrLn $ "matching (lhs): " ++ scPrettyTerm defaultPPOpts x
--lift $ putStrLn $ "matching (rhs): " ++ scPrettyTerm defaultPPOpts y
-- (lift $ putStrLn $ "matching (lhs): " ++ scPrettyTerm defaultPPOpts x) >>
-- (lift $ putStrLn $ "matching (rhs): " ++ scPrettyTerm defaultPPOpts y) >>
case asVarPat depth x of
Just (i, js) ->
do -- ensure parameter variables are distinct
Expand Down Expand Up @@ -268,6 +280,11 @@ scMatch sc pat term =
Just y3 -> if y2 == y3 then return (MatchState m' cs) else mzero
Nothing ->
case (unwrapTermF x, unwrapTermF y) of
(_, FTermF (NatLit n))
| Just (c, [x']) <- R.asCtor x
, primName c == preludeSuccIdent && n > 0 ->
do y' <- lift $ scNat sc (n-1)
match depth env x' y' s
-- check that neither x nor y contains bound variables less than `depth`
(FTermF xf, FTermF yf) ->
case zipWithFlatTermF (match depth env) xf yf of
Expand Down