Skip to content

Commit af9172f

Browse files
authored
Merge pull request #1675 from GaloisInc/mr-solver/clean-up-interface
Improve MRSolver interface
2 parents 7f1a012 + b42baa2 commit af9172f

File tree

12 files changed

+295
-158
lines changed

12 files changed

+295
-158
lines changed

cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs

+41-11
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,11 @@ ppTermInMonCtx :: MonadifyCtx -> Term -> String
681681
ppTermInMonCtx ctx t =
682682
scPrettyTermInCtx defaultPPOpts (map (\(x,_,_) -> x) ctx) t
683683

684-
-- | A memoization table for monadifying terms
685-
type MonadifyMemoTable = IntMap MonTerm
684+
-- | A memoization table for monadifying terms: a map from 'TermIndex'es to
685+
-- 'MonTerm's and, possibly, corresponding 'ArgMonTerm's. The latter are simply
686+
-- the result of calling 'argifyMonTerm' on the former, but are only added when
687+
-- needed (i.e. when 'memoArgMonTerm' is called, e.g. in 'monadifyArg').
688+
type MonadifyMemoTable = IntMap (MonTerm, Maybe ArgMonTerm)
686689

687690
-- | The empty memoization table
688691
emptyMemoTable :: MonadifyMemoTable
@@ -752,15 +755,34 @@ runCompleteMonadifyM sc env top_ret_tp m =
752755
runMonadifyM env [] (toArgType $ monadifyType [] top_ret_tp) m
753756

754757
-- | Memoize a computation of the monadified term associated with a 'TermIndex'
755-
memoizingM :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm
756-
memoizingM i m =
758+
memoMonTerm :: TermIndex -> MonadifyM MonTerm -> MonadifyM MonTerm
759+
memoMonTerm i m =
757760
(IntMap.lookup i <$> get) >>= \case
758-
Just ret ->
759-
return ret
761+
Just (mtm, _) ->
762+
return mtm
760763
Nothing ->
761-
do ret <- m
762-
modify (IntMap.insert i ret)
763-
return ret
764+
do mtm <- m
765+
modify (IntMap.insert i (mtm, Nothing))
766+
return mtm
767+
768+
-- | Memoize a computation of the monadified term of argument type associated
769+
-- with a 'TermIndex', using a memoized 'ArgTerm' directly if it exists or
770+
-- applying 'argifyMonTerm' to a memoized 'MonTerm' (and memoizing the result)
771+
-- if it exists
772+
memoArgMonTerm :: TermIndex -> MonadifyM MonTerm -> MonadifyM ArgMonTerm
773+
memoArgMonTerm i m =
774+
(IntMap.lookup i <$> get) >>= \case
775+
Just (_, Just argmtm) ->
776+
return argmtm
777+
Just (mtm, Nothing) ->
778+
do argmtm <- argifyMonTerm mtm
779+
modify (IntMap.insert i (mtm, Just argmtm))
780+
return argmtm
781+
Nothing ->
782+
do mtm <- m
783+
argmtm <- argifyMonTerm mtm
784+
modify (IntMap.insert i (mtm, Just argmtm))
785+
return argmtm
764786

765787
-- | Turn a 'MonTerm' of type @CompMT(tp)@ to a term of argument type @MT(tp)@
766788
-- by inserting a monadic bind if the 'MonTerm' is computational
@@ -799,7 +821,15 @@ monadifyTypeM tp =
799821

800822
-- | Monadify a term to a monadified term of argument type
801823
monadifyArg :: Maybe MonType -> Term -> MonadifyM ArgMonTerm
802-
monadifyArg mtp t = monadifyTerm mtp t >>= argifyMonTerm
824+
{-
825+
monadifyArg _ t
826+
| trace ("Monadifying term of argument type: " ++ showTerm t) False
827+
= undefined
828+
-}
829+
monadifyArg mtp t@(STApp { stAppIndex = ix }) =
830+
memoArgMonTerm ix $ monadifyTerm' mtp t
831+
monadifyArg mtp t =
832+
monadifyTerm' mtp t >>= argifyMonTerm
803833

804834
-- | Monadify a term to argument type and convert back to a term
805835
monadifyArgTerm :: Maybe MonType -> Term -> MonadifyM OpenTerm
@@ -813,7 +843,7 @@ monadifyTerm _ t
813843
= undefined
814844
-}
815845
monadifyTerm mtp t@(STApp { stAppIndex = ix }) =
816-
memoizingM ix $ monadifyTerm' mtp t
846+
memoMonTerm ix $ monadifyTerm' mtp t
817847
monadifyTerm mtp t =
818848
monadifyTerm' mtp t
819849

examples/mr_solver/mr_solver_unit_tests.saw

+11-11
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ let run_test name test expected =
1111
do { if expected then print (str_concat "Test: " name) else
1212
print (str_concat (str_concat "Test: " name) " (expecting failure)");
1313
actual <- test;
14-
if eq_bool actual expected then print "Success\n" else
14+
if eq_bool actual expected then print "Test passed\n" else
1515
do { print "Test failed\n"; exit 1; }; };
1616

1717
// The constant 0 function const0 x = 0
@@ -21,19 +21,19 @@ const0 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 0)";
2121
const1 <- parse_core "\\ (_:Vec 64 Bool) -> returnM (Vec 64 Bool) (bvNat 64 1)";
2222

2323
// const0 <= const0
24-
run_test "mr_solver const0 const0" (mr_solver const0 const0) true;
24+
run_test "const0 |= const0" (mr_solver_query const0 const0) true;
2525

2626
// The function test_fun0 from the prelude = const0
2727
test_fun0 <- parse_core "test_fun0";
28-
run_test "mr_solver const0 test_fun0" (mr_solver const0 test_fun0) true;
28+
run_test "const0 |= test_fun0" (mr_solver_query const0 test_fun0) true;
2929

3030
// not const0 <= const1
31-
run_test "mr_solver const0 const1" (mr_solver const0 const1) false;
31+
run_test "const0 |= const1" (mr_solver_query const0 const1) false;
3232

3333
// The function test_fun1 from the prelude = const1
3434
test_fun1 <- parse_core "test_fun1";
35-
run_test "mr_solver const1 test_fun1" (mr_solver const1 test_fun1) true;
36-
run_test "mr_solver const0 test_fun1" (mr_solver const0 test_fun1) false;
35+
run_test "const1 |= test_fun1" (mr_solver_query const1 test_fun1) true;
36+
run_test "const0 |= test_fun1" (mr_solver_query const0 test_fun1) false;
3737

3838
// ifxEq0 x = If x == 0 then x else 0; should be equal to 0
3939
ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \
@@ -42,21 +42,21 @@ ifxEq0 <- parse_core "\\ (x:Vec 64 Bool) -> \
4242
\ (returnM (Vec 64 Bool) (bvNat 64 0))";
4343

4444
// ifxEq0 <= const0
45-
run_test "mr_solver ifxEq0 const0" (mr_solver ifxEq0 const0) true;
45+
run_test "ifxEq0 |= const0" (mr_solver_query ifxEq0 const0) true;
4646

4747
// not ifxEq0 <= const1
48-
run_test "mr_solver ifxEq0 const1" (mr_solver ifxEq0 const1) false;
48+
run_test "ifxEq0 |= const1" (mr_solver_query ifxEq0 const1) false;
4949

5050
// noErrors1 x = exists x. returnM x
5151
noErrors1 <- parse_core "\\ (x:Vec 64 Bool) -> \
5252
\ existsM (Vec 64 Bool) (Vec 64 Bool) \
5353
\ (\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool) x)";
5454

5555
// const0 <= noErrors
56-
run_test "mr_solver noErrors1 noErrors1" (mr_solver noErrors1 noErrors1) true;
56+
run_test "noErrors1 |= noErrors1" (mr_solver_query noErrors1 noErrors1) true;
5757

5858
// const1 <= noErrors
59-
run_test "mr_solver const1 noErrors1" (mr_solver const1 noErrors1) true;
59+
run_test "const1 |= noErrors1" (mr_solver_query const1 noErrors1) true;
6060

6161
// noErrorsRec1 x = orM (existsM x. returnM x) (noErrorsRec1 x)
6262
// Intuitively, this specifies functions that either return a value or loop
@@ -74,4 +74,4 @@ loop1 <- parse_core
7474
\ (\\ (f: Vec 64 Bool -> CompM (Vec 64 Bool)) (x:Vec 64 Bool) -> f x)";
7575

7676
// loop1 <= noErrorsRec1
77-
run_test "mr_solver loop1 noErrorsRec1" (mr_solver loop1 noErrorsRec1) true;
77+
run_test "loop1 |= noErrorsRec1" (mr_solver_query loop1 noErrorsRec1) true;
+4-20
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,15 @@
11
include "arrays.saw";
22

3-
let eq_bool b1 b2 =
4-
if b1 then
5-
if b2 then true else false
6-
else
7-
if b2 then false else true;
8-
9-
let fail = do { print "Test failed"; exit 1; };
10-
let run_test name test expected =
11-
do { if expected then print (str_concat "Test: " name) else
12-
print (str_concat (str_concat "Test: " name) " (expecting failure)");
13-
actual <- test;
14-
if eq_bool actual expected then print "Success\n" else
15-
do { print "Test failed\n"; exit 1; }; };
16-
173
// Test that contains0 |= contains0
184
contains0 <- parse_core_mod "arrays" "contains0";
19-
// run_test "contains0 |= contains0" (mr_solver contains0 contains0) true;
5+
mr_solver_test contains0 contains0;
206

217
noErrorsContains0 <- parse_core_mod "arrays" "noErrorsContains0";
22-
run_test "contains0 |= noErrorsContains0"
23-
(mr_solver_debug 0 contains0 noErrorsContains0) true;
8+
mr_solver_prove contains0 noErrorsContains0;
249

2510
include "specPrims.saw";
2611
import "arrays.cry";
2712

2813
zero_array <- parse_core_mod "arrays" "zero_array";
29-
run_test "zero_array |= zero_array_spec"
30-
// (mr_solver_debug 0 zero_array {{ zero_array_loop_spec }}) true;
31-
(mr_solver_debug 0 zero_array {{ zero_array_spec }}) true;
14+
// mr_solver_prove zero_array {{ zero_array_loop_spec }};
15+
mr_solver_prove zero_array {{ zero_array_spec }};
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,7 @@
11
include "exp_explosion.saw";
22

3-
let eq_bool b1 b2 =
4-
if b1 then
5-
if b2 then true else false
6-
else
7-
if b2 then false else true;
8-
9-
let fail = do { print "Test failed"; exit 1; };
10-
let run_test name test expected =
11-
do { if expected then print (str_concat "Test: " name) else
12-
print (str_concat (str_concat "Test: " name) " (expecting failure)");
13-
actual <- test;
14-
if eq_bool actual expected then print "Success\n" else
15-
do { print "Test failed\n"; exit 1; }; };
16-
17-
18-
193
import "exp_explosion.cry";
204
monadify_term {{ op }};
215

226
exp_explosion <- parse_core_mod "exp_explosion" "exp_explosion";
23-
run_test "exp_explosion |= exp_explosion_spec" (mr_solver exp_explosion {{ exp_explosion_spec }}) true;
7+
mr_solver_prove exp_explosion {{ exp_explosion_spec }};
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,5 @@
11
include "linked_list.saw";
22

3-
/***
4-
*** Testing infrastructure
5-
***/
6-
7-
let eq_bool b1 b2 =
8-
if b1 then
9-
if b2 then true else false
10-
else
11-
if b2 then false else true;
12-
13-
let fail = do { print "Test failed"; exit 1; };
14-
let run_test name test expected =
15-
do { if expected then print (str_concat "Test: " name) else
16-
print (str_concat (str_concat "Test: " name) " (expecting failure)");
17-
actual <- test;
18-
if eq_bool actual expected then print "Success\n" else
19-
do { print "Test failed\n"; exit 1; }; };
20-
21-
223
/***
234
*** Setup Cryptol environment
245
***/
@@ -45,15 +26,13 @@ heapster_typecheck_fun env "is_head"
4526
"(). arg0:int64<>, arg1:List<int64<>,always,R> -o \
4627
\ arg0:true, arg1:true, ret:int64<>";
4728

48-
/*
4929
is_head <- parse_core_mod "linked_list" "is_head";
50-
run_test "is_head |= is_head" (mr_solver is_head is_head) true;
51-
*/
30+
mr_solver_test is_head is_head;
5231

5332
is_elem <- parse_core_mod "linked_list" "is_elem";
54-
// run_test "is_elem |= is_elem" (mr_solver_debug 0 is_elem is_elem) true;
5533

56-
/*
34+
mr_solver_test is_elem is_elem;
35+
5736
is_elem_noErrorsSpec <- parse_core
5837
"\\ (x:Vec 64 Bool) (y:List (Vec 64 Bool)) -> \
5938
\ fixM (Vec 64 Bool * List (Vec 64 Bool)) \
@@ -63,10 +42,9 @@ is_elem_noErrorsSpec <- parse_core
6342
\ orM (Vec 64 Bool) \
6443
\ (existsM (Vec 64 Bool) (Vec 64 Bool) (returnM (Vec 64 Bool))) \
6544
\ (rec x)) (x, y)";
66-
run_test "is_elem |= noErrorsSpec" (mr_solver is_elem is_elem_noErrorsSpec) true;
67-
*/
45+
mr_solver_test is_elem is_elem_noErrorsSpec;
6846

69-
run_test "is_elem |= is_elem_spec" (mr_solver is_elem {{ is_elem_spec }}) true;
47+
mr_solver_prove is_elem {{ is_elem_spec }};
7048

7149

7250
monadify_term {{ Right }};
@@ -75,5 +53,4 @@ monadify_term {{ nil }};
7553
monadify_term {{ cons }};
7654

7755
sorted_insert_no_malloc <- parse_core_mod "linked_list" "sorted_insert_no_malloc";
78-
run_test "sorted_insert_no_malloc |= sorted_insert_spec"
79-
(mr_solver sorted_insert_no_malloc {{ sorted_insert_spec }}) true;
56+
mr_solver_prove sorted_insert_no_malloc {{ sorted_insert_spec }};

heapster-saw/examples/sha512_mr_solver.saw

+6-6
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ processBlock <- parse_core_mod "SHA512" "processBlock";
8686
processBlocks <- parse_core_mod "SHA512" "processBlocks";
8787

8888
// Test that every function refines itself
89-
// run_test "processBlocks |= processBlocks" (mr_solver processBlocks processBlocks) true;
90-
// run_test "processBlock |= processBlock" (mr_solver processBlock processBlock) true;
91-
// run_test "round_16_80 |= round_16_80" (mr_solver round_16_80 round_16_80) true;
92-
// run_test "round_00_15 |= round_00_15" (mr_solver round_00_15 round_00_15) true;
89+
// mr_solver_test processBlocks processBlocks;
90+
// mr_solver_test processBlock processBlock;
91+
// mr_solver_test round_16_80 round_16_80;
92+
// mr_solver_test round_00_15 round_00_15;
9393

9494
import "sha512.cry";
9595
// FIXME: Why aren't we monadifying these automatically when they're used?
@@ -105,5 +105,5 @@ monadify_term {{ Maj }};
105105
// "round_16_80 |= round_16_80_spec"?
106106
monadify_term {{ round_00_15_spec }};
107107

108-
run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true;
109-
run_test "round_16_80 |= round_16_80_spec" (mr_solver round_16_80 {{ round_16_80_spec }}) true;
108+
mr_solver_prove round_00_15 {{ round_00_15_spec }};
109+
mr_solver_prove round_16_80 {{ round_16_80_spec }};

0 commit comments

Comments
 (0)