diff --git a/heapster-saw/examples/Makefile b/heapster-saw/examples/Makefile index d10f19fc5d..a7cecd96a7 100644 --- a/heapster-saw/examples/Makefile +++ b/heapster-saw/examples/Makefile @@ -34,7 +34,7 @@ rust_lifetimes.bc: rust_lifetimes.rs rustc --crate-type=lib --emit=llvm-bc rust_lifetimes.rs # Lists all the Mr Solver tests, without their ".saw" suffix -MR_SOLVER_TESTS = arrays_mr_solver linked_list_mr_solver +MR_SOLVER_TESTS = arrays_mr_solver linked_list_mr_solver sha512_mr_solver .PHONY: mr-solver-tests $(MR_SOLVER_TESTS) mr-solver-tests: $(MR_SOLVER_TESTS) diff --git a/heapster-saw/examples/sha512.bc b/heapster-saw/examples/sha512.bc index 0e466fa1d8..711867222c 100644 Binary files a/heapster-saw/examples/sha512.bc and b/heapster-saw/examples/sha512.bc differ diff --git a/heapster-saw/examples/sha512.c b/heapster-saw/examples/sha512.c index 5aec308132..a467b2ffaa 100644 --- a/heapster-saw/examples/sha512.c +++ b/heapster-saw/examples/sha512.c @@ -1,7 +1,17 @@ +// ============================================================================ +// The code in this file is based off of that in: +// https://github.com/awslabs/aws-lc/ +// (commit: d84d2f329dccbc7f3866eab54951bd012e317041) +// ============================================================================ + #include #include #include +// ============================================================================ +// Helper functions from crypto/internal.h +// ============================================================================ + static inline void *OPENSSL_memcpy(void *dst, const void *src, size_t n) { if (n == 0) { return dst; @@ -26,6 +36,14 @@ static inline uint64_t CRYPTO_load_u64_be(const void *ptr) { return CRYPTO_bswap8(ret); } +// ============================================================================ +// The defintion of sha512_block_data_order from crypto/fipsmodule/sha/sha512.c +// with only one addition (return_state), needed for Heapster typechecking +// ============================================================================ + +// Used in sha512_block_data_order below, needed for Heapster typechecking +void return_state(uint64_t *state) { } + static const uint64_t K512[80] = { UINT64_C(0x428a2f98d728ae22), UINT64_C(0x7137449123ef65cd), UINT64_C(0xb5c0fbcfec4d3b2f), UINT64_C(0xe9b5dba58189dbbc), @@ -69,9 +87,7 @@ static const uint64_t K512[80] = { UINT64_C(0x5fcb6fab3ad6faec), UINT64_C(0x6c44198c4a475817), }; -#ifndef ROTR #define ROTR(x, s) (((x) >> s) | (x) << (64 - s)) -#endif #define Sigma0(x) (ROTR((x), 28) ^ ROTR((x), 34) ^ ROTR((x), 39)) #define Sigma1(x) (ROTR((x), 14) ^ ROTR((x), 18) ^ ROTR((x), 41)) @@ -99,8 +115,6 @@ static const uint64_t K512[80] = { ROUND_00_15(i + j, a, b, c, d, e, f, g, h); \ } while (0) -void return_state(uint64_t *state) { } - static void sha512_block_data_order(uint64_t *state, const uint8_t *in, size_t num) { uint64_t a, b, c, d, e, f, g, h, s0, s1, T1; @@ -184,7 +198,132 @@ static void sha512_block_data_order(uint64_t *state, const uint8_t *in, } } -// needed for Heapster to be able to see the static function above + +// ============================================================================ +// A definition equivalent to sha512_block_data_order which uses multiple +// functions, for use with Mr. Solver +// ============================================================================ + +static void round_00_15(uint64_t i, + uint64_t *a, uint64_t *b, uint64_t *c, uint64_t *d, + uint64_t *e, uint64_t *f, uint64_t *g, uint64_t *h, + uint64_t *T1) { + *T1 += *h + Sigma1(*e) + Ch(*e, *f, *g) + K512[i]; + *h = Sigma0(*a) + Maj(*a, *b, *c); + *d += *T1; + *h += *T1; +} + +static void round_16_80(uint64_t i, uint64_t j, + uint64_t *a, uint64_t *b, uint64_t *c, uint64_t *d, + uint64_t *e, uint64_t *f, uint64_t *g, uint64_t *h, + uint64_t *X, + uint64_t* s0, uint64_t *s1, uint64_t *T1) { + *s0 = X[(j + 1) & 0x0f]; + *s0 = sigma0(*s0); + *s1 = X[(j + 14) & 0x0f]; + *s1 = sigma1(*s1); + *T1 = X[(j) & 0x0f] += *s0 + *s1 + X[(j + 9) & 0x0f]; + round_00_15(i + j, a, b, c, d, e, f, g, h, T1); +} + +// Used in processBlock below, needed for Heapster typechecking +void return_X(uint64_t *X) { } + +static void processBlock(uint64_t *a, uint64_t *b, uint64_t *c, uint64_t *d, + uint64_t *e, uint64_t *f, uint64_t *g, uint64_t *h, + const uint8_t *in) { + uint64_t s0, s1, T1; + uint64_t X[16]; + int i; + + T1 = X[0] = CRYPTO_load_u64_be(in); + round_00_15(0, a, b, c, d, e, f, g, h, &T1); + T1 = X[1] = CRYPTO_load_u64_be(in + 8); + round_00_15(1, h, a, b, c, d, e, f, g, &T1); + T1 = X[2] = CRYPTO_load_u64_be(in + 2 * 8); + round_00_15(2, g, h, a, b, c, d, e, f, &T1); + T1 = X[3] = CRYPTO_load_u64_be(in + 3 * 8); + round_00_15(3, f, g, h, a, b, c, d, e, &T1); + T1 = X[4] = CRYPTO_load_u64_be(in + 4 * 8); + round_00_15(4, e, f, g, h, a, b, c, d, &T1); + T1 = X[5] = CRYPTO_load_u64_be(in + 5 * 8); + round_00_15(5, d, e, f, g, h, a, b, c, &T1); + T1 = X[6] = CRYPTO_load_u64_be(in + 6 * 8); + round_00_15(6, c, d, e, f, g, h, a, b, &T1); + T1 = X[7] = CRYPTO_load_u64_be(in + 7 * 8); + round_00_15(7, b, c, d, e, f, g, h, a, &T1); + T1 = X[8] = CRYPTO_load_u64_be(in + 8 * 8); + round_00_15(8, a, b, c, d, e, f, g, h, &T1); + T1 = X[9] = CRYPTO_load_u64_be(in + 9 * 8); + round_00_15(9, h, a, b, c, d, e, f, g, &T1); + T1 = X[10] = CRYPTO_load_u64_be(in + 10 * 8); + round_00_15(10, g, h, a, b, c, d, e, f, &T1); + T1 = X[11] = CRYPTO_load_u64_be(in + 11 * 8); + round_00_15(11, f, g, h, a, b, c, d, e, &T1); + T1 = X[12] = CRYPTO_load_u64_be(in + 12 * 8); + round_00_15(12, e, f, g, h, a, b, c, d, &T1); + T1 = X[13] = CRYPTO_load_u64_be(in + 13 * 8); + round_00_15(13, d, e, f, g, h, a, b, c, &T1); + T1 = X[14] = CRYPTO_load_u64_be(in + 14 * 8); + round_00_15(14, c, d, e, f, g, h, a, b, &T1); + T1 = X[15] = CRYPTO_load_u64_be(in + 15 * 8); + round_00_15(15, b, c, d, e, f, g, h, a, &T1); + + return_X(X); // for Heapster + + for (i = 16; i < 80; i += 16) { + round_16_80(i, 0, a, b, c, d, e, f, g, h, X, &s0, &s1, &T1); + round_16_80(i, 1, h, a, b, c, d, e, f, g, X, &s0, &s1, &T1); + round_16_80(i, 2, g, h, a, b, c, d, e, f, X, &s0, &s1, &T1); + round_16_80(i, 3, f, g, h, a, b, c, d, e, X, &s0, &s1, &T1); + round_16_80(i, 4, e, f, g, h, a, b, c, d, X, &s0, &s1, &T1); + round_16_80(i, 5, d, e, f, g, h, a, b, c, X, &s0, &s1, &T1); + round_16_80(i, 6, c, d, e, f, g, h, a, b, X, &s0, &s1, &T1); + round_16_80(i, 7, b, c, d, e, f, g, h, a, X, &s0, &s1, &T1); + round_16_80(i, 8, a, b, c, d, e, f, g, h, X, &s0, &s1, &T1); + round_16_80(i, 9, h, a, b, c, d, e, f, g, X, &s0, &s1, &T1); + round_16_80(i, 10, g, h, a, b, c, d, e, f, X, &s0, &s1, &T1); + round_16_80(i, 11, f, g, h, a, b, c, d, e, X, &s0, &s1, &T1); + round_16_80(i, 12, e, f, g, h, a, b, c, d, X, &s0, &s1, &T1); + round_16_80(i, 13, d, e, f, g, h, a, b, c, X, &s0, &s1, &T1); + round_16_80(i, 14, c, d, e, f, g, h, a, b, X, &s0, &s1, &T1); + round_16_80(i, 15, b, c, d, e, f, g, h, a, X, &s0, &s1, &T1); + } +} + +static void processBlocks(uint64_t *state, const uint8_t *in, size_t num) { + uint64_t a, b, c, d, e, f, g, h; + + while (num--) { + + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + f = state[5]; + g = state[6]; + h = state[7]; + + processBlock(&a, &b, &c, &d, &e, &f, &g, &h, in); + + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + state[5] += f; + state[6] += g; + state[7] += h; + + in += 16 * 8; + } +} + + +// Needed for Heapster to be able to see the static functions above void dummy(uint64_t *state, const uint8_t *in, size_t num) { sha512_block_data_order(state, in, num); + processBlocks(state, in, num); } diff --git a/heapster-saw/examples/sha512.cry b/heapster-saw/examples/sha512.cry new file mode 100644 index 0000000000..da3704db33 --- /dev/null +++ b/heapster-saw/examples/sha512.cry @@ -0,0 +1,86 @@ + +module SHA512 where + +// ============================================================================ +// Definitions from cryptol-specs/Primitive/Keyless/Hash/SHA512.cry, with some +// type annotations added to SIGMA_0, SIGMA_1, sigma_0, and sigma_1 to get +// monadification to go through +// ============================================================================ + +type w = 64 + +type j = 80 + +K : [j][w] +K = [ 0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f, 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, 0x59f111f1b605d019, 0x923f82a4af194f9b, 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, 0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235, 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, 0xefbe4786384f25e3, 0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, 0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5, + 0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f, 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, 0xd5a79147930aa725, 0x06ca6351e003826f, 0x142929670a0e6e70, + 0x27b70a8546d22ffc, 0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df, + 0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6, 0x92722c851482353b, + 0xa2bfe8a14cf10364, 0xa81a664bbc423001, 0xc24b8b70d0f89791, 0xc76c51a30654be30, + 0xd192e819d6ef5218, 0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99, 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb, 0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, 0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec, + 0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915, 0xc67178f2e372532b, + 0xca273eceea26619c, 0xd186b8c721c0c207, 0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, 0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b, + 0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc, 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a, 0x5fcb6fab3ad6faec, 0x6c44198c4a475817] + +SIGMA_0 : [w] -> [w] +SIGMA_0 x = (x >>> (28 : [w])) ^ (x >>> (34 : [w])) ^ (x >>> (39 : [w])) + +SIGMA_1 : [w] -> [w] +SIGMA_1 x = (x >>> (14 : [w])) ^ (x >>> (18 : [w])) ^ (x >>> (41 : [w])) + +sigma_0 : [w] -> [w] +sigma_0 x = (x >>> (1 : [w])) ^ (x >>> (8 : [w])) ^ (x >> (7 : [w])) + +sigma_1 : [w] -> [w] +sigma_1 x = (x >>> (19 : [w])) ^ (x >>> (61 : [w])) ^ (x >> (6 : [w])) + + +// ============================================================================ +// Definitions from cryptol-specs/Primitive/Keyless/Hash/SHA.cry +// ============================================================================ + +Ch : [w] -> [w] -> [w] -> [w] +Ch x y z = (x && y) ^ (~x && z) + +Maj : [w] -> [w] -> [w] -> [w] +Maj x y z = (x && y) ^ (x && z) ^ (y && z) + + +// ============================================================================ +// Cryptol functions which closely match the definitions in sha512.c +// ============================================================================ + +round_00_15_spec : [w] -> + [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> + [w] -> + ([w], [w], [w], [w], [w], [w], [w], [w], [w]) +round_00_15_spec i a b c d e f g h T1 = + (a, b, c, d', e, f, g, h', T1') + where T1' = T1 + h + SIGMA_1 e + Ch e f g + K @ i + d' = d + T1' + h' = SIGMA_0 a + Maj a b c + T1' + +round_16_80_spec : [w] -> [w] -> + [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> + [16][w] -> + [w] -> [w] -> [w] -> + ([w], [w], [w], [w], [w], [w], [w], [w], [16][w], [w], [w], [w]) +round_16_80_spec i j a b c d e f g h X s0 s1 T1 = + (a', b', c', d', e', f', g', h', X', s0', s1', T1'') + where s0' = sigma_0 (X @ ((j + 1) && 15)) + s1' = sigma_1 (X @ ((j + 4) && 15)) + T1' = X @ (j && 15) + s0' + s1' + X @ ((j + 9) && 15) + X' = update X (j && 15) T1' + (a', b', c', d', e', f', g', h', T1'') = + round_00_15_spec (i + j) a b c d e f g h T1' diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw new file mode 100644 index 0000000000..69169abb2e --- /dev/null +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -0,0 +1,108 @@ +enable_experimental; +env <- heapster_init_env "SHA512" "sha512.bc"; + +// Heapster + +heapster_define_perm env "int64" " " "llvmptr 64" "exists x:bv 64.eq(llvmword(x))"; +heapster_define_perm env "int32" " " "llvmptr 32" "exists x:bv 32.eq(llvmword(x))"; +heapster_define_perm env "int8" " " "llvmptr 8" "exists x:bv 8.eq(llvmword(x))"; + +heapster_define_perm env "int64_ptr" " " "llvmptr 64" "ptr((W,0) |-> int64<>)"; + +heapster_assume_fun env "CRYPTO_load_u64_be" + "(). arg0:ptr((R,0) |-> int64<>) -o \ + \ arg0:ptr((R,0) |-> int64<>), ret:int64<>" + "\\ (x:Vec 64 Bool) -> returnM (Vec 64 Bool * Vec 64 Bool) (x, x)"; + +heapster_typecheck_fun env "round_00_15" + "(). arg0:int64<>, \ + \ arg1:int64_ptr<>, arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, \ + \ arg5:int64_ptr<>, arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, \ + \ arg9:int64_ptr<> -o \ + \ arg1:int64_ptr<>, arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, \ + \ arg5:int64_ptr<>, arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, \ + \ arg9:int64_ptr<>, ret:true"; + +heapster_typecheck_fun env "round_16_80" + "(). arg0:int64<>, arg1:int64<>, \ + \ arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ + \ arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, arg9:int64_ptr<>, \ + \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ + \ arg11:ptr((W,0) |-> true), arg12:ptr((W,0) |-> true), arg13:int64_ptr<> -o \ + \ arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ + \ arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, arg9:int64_ptr<>, \ + \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ + \ arg11:int64_ptr<>, arg12:int64_ptr<>, arg13:int64_ptr<>, ret:true"; + +heapster_typecheck_fun env "return_X" + "(). arg0:array(W,0,<16,*8,fieldsh(int64<>)) -o \ + \ arg0:array(W,0,<16,*8,fieldsh(int64<>))"; + +heapster_set_translation_checks env false; +heapster_typecheck_fun env "processBlock" + "(). arg0:int64_ptr<>, arg1:int64_ptr<>, arg2:int64_ptr<>, \ + \ arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ + \ arg6:int64_ptr<>, arg7:int64_ptr<>, \ + \ arg8:array(R,0,<16,*8,fieldsh(int64<>)) -o \ + \ arg0:int64_ptr<>, arg1:int64_ptr<>, arg2:int64_ptr<>, \ + \ arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ + \ arg6:int64_ptr<>, arg7:int64_ptr<>, \ + \ arg8:array(R,0,<16,*8,fieldsh(int64<>)), ret:true"; + +// FIXME: This translation contains errors +heapster_set_translation_checks env false; +heapster_typecheck_fun env "processBlocks" + "(num:bv 64). arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ + \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ + \ arg2:eq(llvmword(num)) -o \ + \ arg0:array(W,0,<8,*8,fieldsh(int64<>)), \ + \ arg1:array(R,0,<16*num,*8,fieldsh(int64<>)), \ + \ arg2:true, ret:true"; + +heapster_export_coq env "sha512_mr_solver_gen.v"; + +// Mr. Solver + +let eq_bool b1 b2 = + if b1 then + if b2 then true else false + else + if b2 then false else true; + +let fail = do { print "Test failed"; exit 1; }; +let run_test name test expected = + do { if expected then print (str_concat "Test: " name) else + print (str_concat (str_concat "Test: " name) " (expecting failure)"); + actual <- test; + if eq_bool actual expected then print "Success\n" else + do { print "Test failed\n"; exit 1; }; }; + +round_00_15 <- parse_core_mod "SHA512" "round_00_15"; +round_16_80 <- parse_core_mod "SHA512" "round_16_80"; +processBlock <- parse_core_mod "SHA512" "processBlock"; +processBlocks <- parse_core_mod "SHA512" "processBlocks"; + +// Test that every function refines itself +// run_test "processBlocks |= processBlocks" (mr_solver processBlocks processBlocks) true; +// run_test "processBlock |= processBlock" (mr_solver processBlock processBlock) true; +// run_test "round_16_80 |= round_16_80" (mr_solver round_16_80 round_16_80) true; +// run_test "round_00_15 |= round_00_15" (mr_solver round_00_15 round_00_15) true; + +import "sha512.cry"; +// FIXME: Why aren't we monadifying these automatically when they're used? +monadify_term {{ K }}; +monadify_term {{ SIGMA_0 }}; +monadify_term {{ SIGMA_1 }}; +monadify_term {{ sigma_0 }}; +monadify_term {{ sigma_1 }}; +monadify_term {{ Ch }}; +monadify_term {{ Maj }}; + +// FIXME: Why does monadification fail without this line while running +// "round_16_80 |= round_16_80_spec"? +monadify_term {{ round_00_15_spec }}; + +run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true; + +// FIXME: Need to add heterogenous equality on output types for this to work +// run_test "round_16_80 |= round_16_80_spec" (mr_solver_debug 0 round_16_80 {{ round_16_80_spec }}) true; diff --git a/heapster-saw/proverTests/Main.hs b/heapster-saw/proverTests/Main.hs index 3f000087ba..94364b3c0b 100644 --- a/heapster-saw/proverTests/Main.hs +++ b/heapster-saw/proverTests/Main.hs @@ -243,6 +243,10 @@ arrayTests = , testCase "sum of borrows" $ passes $ [ int64ArrayPerm 0 3 \\\ (1,2) , int64ArrayPerm 24 4 \\\ (1,2) ] ===> int64ArrayPerm 0 7 \\\ (1,2) \\\ (3,3) + + , testCase "fully-borrowed refl" $ passes $ + int64ArrayPerm 0 4 \\\ (0, 2) \\\ (2, 2) + ===> int64ArrayPerm 0 4 \\\ (0, 2) \\\ (2, 2) ] ] diff --git a/heapster-saw/src/Verifier/SAW/Heapster/Implication.hs b/heapster-saw/src/Verifier/SAW/Heapster/Implication.hs index b7a82cb4e0..e494b071f2 100644 --- a/heapster-saw/src/Verifier/SAW/Heapster/Implication.hs +++ b/heapster-saw/src/Verifier/SAW/Heapster/Implication.hs @@ -6734,7 +6734,8 @@ proveVarLLVMArrayH x psubst ps mb_ap , Just len <- partialSubst psubst $ mbLLVMArrayLen mb_ap , Just lenBytes <- partialSubst psubst $ mbLLVMArrayLenBytes mb_ap , stride <- mbLLVMArrayStride mb_ap - , Just i <- findIndex (suitableAP off lenBytes stride) ps + , Just bs <- partialSubst psubst $ mbLLVMArrayBorrows mb_ap + , Just i <- findIndex (suitableAP off lenBytes stride bs) ps , Perm_LLVMArray ap_lhs <- ps!!i = implVerbTraceM (\info -> pretty "proveVarLLVMArrayH case 1: using" <+> permPretty info ap_lhs) >>> @@ -6742,7 +6743,7 @@ proveVarLLVMArrayH x psubst ps mb_ap recombinePerm x (ValPerm_Conj ps') >>> partialSubstForceM (mbLLVMArrayBorrows mb_ap) - "proveVarLLVMArrayH: incomplete array borrows" >>>= \bs -> + "proveVarLLVMArrayH: incomplete array borrows" >>> if bvEq off (llvmArrayOffset ap_lhs) && bvEq len (llvmArrayLen ap_lhs) then proveVarLLVMArray_FromArray x ap_lhs len bs mb_ap @@ -6752,22 +6753,26 @@ proveVarLLVMArrayH x psubst ps mb_ap proveVarLLVMArray_FromArray x (llvmMakeSubArray ap_lhs off len) len bs mb_ap where -- Test if an atomic permission is a "suitable" array permission for the - -- given offset, length, and stride, meaning that it has the required - -- stride, could contain the offset and length, and does not have all of the - -- offset and length borrowed + -- given offset, length, stride, and borrows, meaning that it has the + -- given stride, could contain the given offset and length, and either + -- has exactly the given borrows or at least does not have all of the + -- given offset and length borrowed suitableAP :: (1 <= w, KnownNat w) => PermExpr (BVType w) -> PermExpr (BVType w) -> Bytes -> - AtomicPerm (LLVMPointerType w) -> Bool - suitableAP off len stride (Perm_LLVMArray ap) = + [LLVMArrayBorrow w] -> AtomicPerm (LLVMPointerType w) -> Bool + suitableAP off len stride bs (Perm_LLVMArray ap) = -- Test that the strides are equal llvmArrayStride ap == stride && - -- Make sure the range [off,len) is not fully borrowed - not (llvmArrayRangeIsBorrowed ap (BVRange off len)) && -- Test if this permission *could* cover the desired off/len all bvPropCouldHold (bvPropRangeSubset (BVRange off len) - (llvmArrayAbsOffsets ap)) - suitableAP _ _ _ _ = False + (llvmArrayAbsOffsets ap)) && + -- Test that either the sets of borrows are equal ... + ((all (flip elem bs) (llvmArrayBorrows ap) && + all (flip elem (llvmArrayBorrows ap)) bs) || + -- ...or the range [off,len) is not fully borrowed + not (llvmArrayRangeIsBorrowed ap (BVRange off len))) + suitableAP _ _ _ _ _ = False -- Check if there is a block that contains the required offset and length, in -- which case eliminate it, allowing us to either satisfy way 4 (eliminate a @@ -6809,8 +6814,8 @@ proveVarLLVMArrayH x psubst ps mb_ap , len <- llvmArrayLen ap , lhs_cells@(lhs_cell_rng:_) <- concatMap (permCells ap) ps , rhs_cells <- map llvmArrayBorrowCells (llvmArrayBorrows ap) - , Just cells <- gatherCoveringRanges (llvmArrayCells ap) (lhs_cells ++ - rhs_cells) + , Just cells <- gatherCoveringRanges (llvmArrayCells ap) (rhs_cells ++ + lhs_cells) , bs <- map cellRangeToBorrow cells , ap_borrowed <- ap { llvmArrayBorrows = bs } , cell_bp <- blockForCell ap (bvRangeOffset lhs_cell_rng) = diff --git a/heapster-saw/src/Verifier/SAW/Heapster/Permissions.hs b/heapster-saw/src/Verifier/SAW/Heapster/Permissions.hs index e86d937293..75461c688d 100644 --- a/heapster-saw/src/Verifier/SAW/Heapster/Permissions.hs +++ b/heapster-saw/src/Verifier/SAW/Heapster/Permissions.hs @@ -8205,10 +8205,14 @@ detVarsClauseAddLHSVar :: ExprVar a -> DetVarsClause -> DetVarsClause detVarsClauseAddLHSVar n (DetVarsClause lhs rhs) = DetVarsClause (NameSet.insert n lhs) rhs +newtype SeenDetVarsClauses :: CrucibleType -> * where + SeenDetVarsClauses :: [DetVarsClause] -> SeenDetVarsClauses tp + -- | Generic function to compute the 'DetVarsClause's for a permission class GetDetVarsClauses a where getDetVarsClauses :: - a -> ReaderT (PermSet ps) (State (NameSet CrucibleType)) [DetVarsClause] + a -> ReaderT (PermSet ps) (State (NameMap SeenDetVarsClauses)) + [DetVarsClause] instance GetDetVarsClauses a => GetDetVarsClauses [a] where getDetVarsClauses l = concat <$> mapM getDetVarsClauses l @@ -8220,11 +8224,13 @@ instance GetDetVarsClauses (ExprVar a) where getDetVarsClauses x = do seen_vars <- get perms <- ask - if NameSet.member x seen_vars then return [] else - do modify (NameSet.insert x) - perm_clauses <- getDetVarsClauses (perms ^. varPerm x) - return (DetVarsClause NameSet.empty (SomeName x) : - map (detVarsClauseAddLHSVar x) perm_clauses) + perm_clauses <- case NameMap.lookup x seen_vars of + Just (SeenDetVarsClauses perm_clauses) -> return perm_clauses + Nothing -> do perm_clauses <- getDetVarsClauses (perms ^. varPerm x) + modify (NameMap.insert x (SeenDetVarsClauses perm_clauses)) + return perm_clauses + return (DetVarsClause NameSet.empty (SomeName x) : + map (detVarsClauseAddLHSVar x) perm_clauses) instance GetDetVarsClauses (PermExpr a) where getDetVarsClauses e @@ -8296,7 +8302,7 @@ instance GetDetVarsClauses (LLVMFieldShape w) where -- | Compute the 'DetVarsClause's for a block permission with the given shape getShapeDetVarsClauses :: (1 <= w, KnownNat w) => PermExpr (LLVMShapeType w) -> - ReaderT (PermSet ps) (State (NameSet CrucibleType)) [DetVarsClause] + ReaderT (PermSet ps) (State (NameMap SeenDetVarsClauses)) [DetVarsClause] getShapeDetVarsClauses (PExpr_Var x) = getDetVarsClauses x getShapeDetVarsClauses (PExpr_NamedShape _ _ _ args) = @@ -8322,20 +8328,23 @@ getShapeDetVarsClauses _ = return [] -- is always a uniquely determined value of @y@ for any proof of @exists y.x:p@. determinedVars :: PermSet ps -> RAssign ExprVar ns -> [SomeName CrucibleType] determinedVars top_perms vars = - let vars_set = NameSet.fromList $ mapToList SomeName vars + let vars_map = NameMap.fromList $ + mapToList (\v -> NameAndElem v (SeenDetVarsClauses [])) vars + vars_set = NameSet.fromList $ mapToList SomeName vars multigraph = evalState (runReaderT (getDetVarsClauses (distPermsToValuePerms $ varPermsMulti vars top_perms)) top_perms) - vars_set in + vars_map in evalState (determinedVarsForGraph multigraph) vars_set where -- Find all variables that are not already marked as determined in our -- NameSet state but that are determined given the current determined - -- variables, mark these variables as determiend, and then repeat, returning + -- variables, mark these variables as determined, and then repeat, returning -- all variables that are found in order determinedVarsForGraph :: [DetVarsClause] -> - State (NameSet CrucibleType) [SomeName CrucibleType] + State (NameSet CrucibleType) + [SomeName CrucibleType] determinedVarsForGraph graph = do det_vars <- concat <$> mapM determinedVarsForClause graph if det_vars == [] then return [] else @@ -8344,7 +8353,8 @@ determinedVars top_perms vars = -- If the LHS of a clause has become determined but its RHS is not, return -- its RHS, otherwise return nothing determinedVarsForClause :: DetVarsClause -> - State (NameSet CrucibleType) [SomeName CrucibleType] + State (NameSet CrucibleType) + [SomeName CrucibleType] determinedVarsForClause (DetVarsClause lhs_vars (SomeName rhs_var)) = do det_vars <- get if not (NameSet.member rhs_var det_vars) && diff --git a/saw-core/prelude/Prelude.sawcore b/saw-core/prelude/Prelude.sawcore index d8c4950a78..f559e19926 100644 --- a/saw-core/prelude/Prelude.sawcore +++ b/saw-core/prelude/Prelude.sawcore @@ -1694,6 +1694,13 @@ genBVVecFromVec m a v def n len = genBVVec n len a (\ (i:Vec n Bool) (_:is_bvult n i len) -> atWithDefault m a def v (bvToNat n i)); +-- Generate a vector from the elements of an existing BVVec, using a default +-- value when we run out of the existing BVVec - the inverse of genBVVecFromVec +genFromBVVec : (n : Nat) -> (len : Vec n Bool) -> (a : sort 0) -> + BVVec n len a -> a -> (m : Nat) -> Vec m a; +genFromBVVec n len a v def m = + gen m a (\ (i:Nat) -> atWithDefault (bvToNat n len) a def v i); + -- The false proposition FalseProp : Prop; FalseProp = Eq Bool True False; diff --git a/saw-core/src/Verifier/SAW/Recognizer.hs b/saw-core/src/Verifier/SAW/Recognizer.hs index ba3d81ead7..ad951c573e 100644 --- a/saw-core/src/Verifier/SAW/Recognizer.hs +++ b/saw-core/src/Verifier/SAW/Recognizer.hs @@ -46,6 +46,7 @@ module Verifier.SAW.Recognizer , asNat , asBvNat , asUnsignedConcreteBv + , asArrayValue , asStringLit , asLambda , asLambdaList @@ -75,6 +76,7 @@ import Control.Lens import Control.Monad import Data.Map (Map) import qualified Data.Map as Map +import qualified Data.Vector as V import Data.Text (Text) import Numeric.Natural (Natural) @@ -287,6 +289,11 @@ asUnsignedConcreteBv term = do (n :*: v) <- asBvNat term return $ mod v (2 ^ n) +asArrayValue :: Recognizer Term (Term, [Term]) +asArrayValue (unwrapTermF -> FTermF (ArrayValue tp tms)) = + return (tp, V.toList tms) +asArrayValue _ = Nothing + asStringLit :: Recognizer Term Text asStringLit t = do StringLit i <- asFTermF t; return i diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index b42322f8a3..3ece60ca01 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -18,6 +18,7 @@ namely 'mrProvable' and 'mrProveEq'. module SAWScript.Prover.MRSolver.SMT where import qualified Data.Vector as V +import Numeric.Natural (Natural) import Control.Monad.Except import qualified Control.Exception as X @@ -33,9 +34,9 @@ import Verifier.SAW.OpenTerm import Verifier.SAW.Prim (EvalError(..)) import qualified Verifier.SAW.Prim as Prim +import Verifier.SAW.Simulator.Value import Verifier.SAW.Simulator.TermModel import Verifier.SAW.Simulator.Prims -import Verifier.SAW.Simulator.MonadLazy import SAWScript.Proof (termToProp, propToTerm, prettyProp) import What4.Solver @@ -84,6 +85,14 @@ asGenBVVecTerm (asApplyAll -> = Just (n, len, a, e) asGenBVVecTerm _ = Nothing +-- | Match a term of the form @genFromBVVec n len a v def m@ +asGenFromBVVecTerm :: Recognizer Term (Term, Term, Term, Term, Term, Term) +asGenFromBVVecTerm (asApplyAll -> + (isGlobalDef "Prelude.genFromBVVec" -> Just _, + [n, len, a, v, def, m])) + = Just (n, len, a, v, def, m) +asGenFromBVVecTerm _ = Nothing + type TmPrim = Prim TermModel -- | Convert a Boolean value to a 'Term'; like 'readBackValue' but that function @@ -126,6 +135,73 @@ primBVTermFun sc = scVectorReduced sc tp tms v -> lift (putStrLn ("primBVTermFun: unhandled value: " ++ show v)) >> mzero +-- | A datatype representing either a @genFromBVVec n len _ v _ _@ term or +-- a vector literal, the latter being represented as a list of 'Term's +data FromBVVecOrLit = FromBVVec { fromBVVec_n :: Natural + , fromBVVec_len :: Term + , fromBVVec_vec :: Term } + | BVVecLit [Term] + +-- | An implementation of a primitive function that expects either a +-- @genFromBVVec@ term or a vector literal +primFromBVVecOrLit :: SharedContext -> TValue TermModel -> + (FromBVVecOrLit -> TmPrim) -> TmPrim +primFromBVVecOrLit sc a = + PrimFilterFun "primFromBVVecOrLit" $ + \case + VExtra (VExtraTerm _ (asGenFromBVVecTerm -> Just (asNat -> Just n, len, _, + v, _, _))) -> + return $ FromBVVec n len v + VVector vs -> + lift $ BVVecLit <$> + traverse (readBackValueNoConfig "primFromBVVecOrLit" sc a <=< force) + (V.toList vs) + _ -> mzero + +-- | Turn a 'FromBVVecOrLit' into a BVVec term, assuming it has the given +-- bit-width (given as both a 'Natural' and a 'Term'), length, and element type +-- FIXME: Properly handle empty vector literals +bvVecFromBVVecOrLit :: SharedContext -> Natural -> Term -> Term -> Term -> + FromBVVecOrLit -> IO Term +bvVecFromBVVecOrLit sc n _ len _ (FromBVVec n' len' v) = + do len_cvt_len' <- scConvertible sc True len len' + if n == n' && len_cvt_len' then return v + else error "bvVecFromBVVecOrLit: genFromBVVec type mismatch" +bvVecFromBVVecOrLit sc n n' len a (BVVecLit vs) = + do body <- mkBody 0 vs + i_tp <- scBitvector sc n + var0 <- scLocalVar sc 0 + pf_tp <- scGlobalApply sc "Prelude.is_bvult" [n', var0, len] + f <- scLambdaList sc [("i", i_tp), ("pf", pf_tp)] body + scGlobalApply sc "Prelude.genBVVec" [n', len, a, f] + where mkBody :: Integer -> [Term] -> IO Term + mkBody _ [] = error "bvVecFromBVVecOrLit: empty vector" + mkBody _ [x] = return $ x + mkBody i (x:xs) = + do var1 <- scLocalVar sc 1 + i' <- scBvConst sc n i + cond <- scBvEq sc n' var1 i' + body' <- mkBody (i+1) xs + scIte sc a cond x body' + +-- | A version of 'readBackTValue' which uses 'error' as the simulator config +-- Q: Is there every a case where this will actually error? +readBackTValueNoConfig :: String -> SharedContext -> + TValue TermModel -> IO Term +readBackTValueNoConfig err_str sc tv = + let ?recordEC = \_ec -> return () in + let cfg = error $ "FIXME: need the simulator config in " ++ err_str + in readBackTValue sc cfg tv + +-- | A version of 'readBackValue' which uses 'error' as the simulator config +-- Q: Is there every a case where this will actually error? +readBackValueNoConfig :: String -> SharedContext -> + TValue TermModel -> Value TermModel -> IO Term +readBackValueNoConfig err_str sc tv v = + let ?recordEC = \_ec -> return () in + let cfg = error $ "FIXME: need the simulator config in " ++ err_str + in readBackValue sc cfg tv v + -- | Implementations of primitives for normalizing Mr Solver terms smtNormPrims :: SharedContext -> Map Ident TmPrim smtNormPrims sc = Map.fromList @@ -133,8 +209,23 @@ smtNormPrims sc = Map.fromList ("Prelude.genBVVec", Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec" VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> - scGlobalDef sc "Prelude.genBVVec")), - + scGlobalDef sc "Prelude.genBVVec") + ), + ("Prelude.genBVVecFromVec", + natFun $ \_m -> tvalFun $ \a -> primFromBVVecOrLit sc a $ \eith -> + PrimFun $ \_def -> natFun $ \n -> primBVTermFun sc $ \len -> + Prim (do n' <- scNat sc n + a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" + sc a + tp <- scGlobalApply sc "Prelude.BVVec" [n', len, a'] + VExtra <$> VExtraTerm (VTyTerm (mkSort 0) tp) <$> + bvVecFromBVVecOrLit sc n n' len a' eith) + ), + ("Prelude.genFromBVVec", + Prim (do tp <- scTypeOfGlobal sc "Prelude.genFromBVVec" + VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> + scGlobalDef sc "Prelude.genFromBVVec") + ), ("Prelude.atBVVec", PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a -> primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> @@ -144,9 +235,7 @@ smtNormPrims sc = Map.fromList PrimFilterFun "CompM" (\case TValue tv -> return tv _ -> mzero) $ \tv -> - Prim (do let ?recordEC = \_ec -> return () - let cfg = error "FIXME: smtNormPrims: need the simulator config" - tv_trm <- readBackTValue sc cfg tv + Prim (do tv_trm <- readBackTValueNoConfig "smtNormPrims (CompM)" sc tv TValue <$> VTyTerm (mkSort 0) <$> scGlobalApply sc "Prelude.CompM" [tv_trm])) ] diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 90a3c4aca8..ce83c2a5d7 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -125,7 +125,8 @@ module SAWScript.Prover.MRSolver.Solver where import Data.Maybe import Data.Either -import Data.List (findIndices, intercalate) +import Data.List (findIndices, intercalate, foldl') +import Data.Bits (shiftL) import Control.Monad.Except import qualified Data.Map as Map import qualified Data.Text as Text @@ -147,6 +148,19 @@ import SAWScript.Prover.MRSolver.SMT -- * Normalizing and Matching on Terms ---------------------------------------------------------------------- +-- | Like 'asVectorType', but returns 'Nothing' if 'asBVVecType' returns 'Just' +asNonBVVecVectorType :: Recognizer Term (Term, Term) +asNonBVVecVectorType (asBVVecType -> Just _) = Nothing +asNonBVVecVectorType t = asVectorType t + +-- | Like 'scBvNat', but if given a bitvector literal it is converted to a +-- natural number literal +mrBvToNat :: Term -> Term -> MRM Term +mrBvToNat _ (asArrayValue -> Just (asBoolType -> Just _, + mapM asBool -> Just bits)) = + liftSC1 scNat $ foldl' (\n bit -> if bit then 2*n+1 else 2*n) 0 bits +mrBvToNat n len = liftSC2 scBvNat n len + -- | Pattern-match on a @LetRecTypes@ list in normal form and return a list of -- the types it specifies, each in normal form and with uvars abstracted out asLRTList :: Term -> MRM [Term] @@ -276,7 +290,8 @@ normComp (CompTerm t) = liftSC2 scGlobalApply "CryptolM.bvVecMapInvarM" [a, b, w, n, f, xs, invar] >>= normCompTerm - -- Convert `atM (bvToNat ...)` into the unfolding of `bvVecAtM` + -- Convert `atM (bvToNat ...) ... (bvToNat ...)` into the unfolding of + -- `bvVecAtM` (asGlobalDef -> Just "CryptolM.atM", [asBvToNat -> Just (w1, n), a, xs, asBvToNat -> Just (w2, i)]) -> do body <- mrGlobalDefBody "CryptolM.bvVecAtM" @@ -285,7 +300,25 @@ normComp (CompTerm t) = mrApplyAll body [w1, n, a, xs, i] >>= normCompTerm else throwMRFailure (MalformedComp t) - -- Convert `updateM (bvToNat ...)` into the unfolding of `bvVecUpdateM` + -- Convert `atM n ... xs (bvToNat ...)` for a constant `n` into the + -- unfolding of `bvVecAtM` after converting `n` to a bitvector constant + -- and applying `genBVVecFromVec` to `xs` + (asGlobalDef -> Just "CryptolM.atM", [n_tm@(asNat -> Just n), a, xs, + asBvToNat -> + Just (w_tm@(asNat -> Just w), + i)]) -> + do body <- mrGlobalDefBody "CryptolM.bvVecAtM" + if n < 1 `shiftL` fromIntegral w then do + n' <- liftSC2 scBvConst w (toInteger n) + err_str <- liftSC1 scString "FIXME: normComp (atM) error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str] + xs' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" + [n_tm, a, xs, err_tm, w_tm, n'] + mrApplyAll body [w_tm, n', a, xs', i] >>= normCompTerm + else throwMRFailure (MalformedComp t) + + -- Convert `updateM (bvToNat ...) ... (bvToNat ...)` into the unfolding of + -- `bvVecUpdateM` (asGlobalDef -> Just "CryptolM.updateM", [asBvToNat -> Just (w1, n), a, xs, asBvToNat -> Just (w2, i), x]) -> do body <- mrGlobalDefBody "CryptolM.bvVecUpdateM" @@ -294,6 +327,23 @@ normComp (CompTerm t) = mrApplyAll body [w1, n, a, xs, i, x] >>= normCompTerm else throwMRFailure (MalformedComp t) + -- Convert `updateM n ... xs (bvToNat ...)` for a constant `n` into the + -- unfolding of `bvVecUpdateM` after converting `n` to a bitvector constant + -- and applying `genBVVecFromVec` to `xs` + (asGlobalDef -> Just "CryptolM.updateM", [n_tm@(asNat -> Just n), a, xs, + asBvToNat -> + Just (w_tm@(asNat -> Just w), + i), x]) -> + do body <- mrGlobalDefBody "CryptolM.bvVecUpdateM" + if n < 1 `shiftL` fromIntegral w then do + n' <- liftSC2 scBvConst w (toInteger n) + err_str <- liftSC1 scString "FIXME: normComp (updateM) error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str] + xs' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" + [n_tm, a, xs, err_tm, w_tm, n'] + mrApplyAll body [w_tm, n', a, xs', i, x] >>= normCompTerm + else throwMRFailure (MalformedComp t) + -- Always unfold: sawLet, multiArgFixM, invariantHint, Num_rec (f@(asGlobalDef -> Just ident), args) | ident `elem` ["Prelude.sawLet", "Prelude.multiArgFixM", @@ -918,6 +968,44 @@ askMRSolverH vars (asPi -> Just (nm1, asDataType -> Just (primName -> "Cryptol.N t2'' <- mrApplyAll t2' [var] askMRSolverH (var : vars') body1' t1'' body2 t2'' +-- If we need to introduce a BVVec on one side and a non-BVVec vector on the +-- other, introduce a BVVec variable and substitute `genBVVecFromVec` of that +-- variable on the non-BVVec side +askMRSolverH vars tp1@(asPi -> Just (nm1, tp@(asBVVecType -> Just (n, len, a)), body1)) t1 + tp2@(asPi -> Just (nm2, asNonBVVecVectorType -> Just (m, a'), body2)) t2 = + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + as_are_eq <- mrConvertible a a' + if ms_are_eq && as_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + let nm = if Text.head nm2 == '_' then nm1 else nm2 + withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> + do err_str_tm <- liftSC1 scString "FIXME: askMRSolverH error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str_tm] + bvvec_tm <- liftSC2 scGlobalApply "Prelude.genFromBVVec" + [n, len, a, var, err_tm, m] + body2' <- substTerm 0 (bvvec_tm : vars') body2 + t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [bvvec_tm] + askMRSolverH (var : vars') body1 t1'' body2' t2'' +askMRSolverH vars tp1@(asPi -> Just (nm1, asNonBVVecVectorType -> Just (m, a'), body2)) t1 + tp2@(asPi -> Just (nm2, tp@(asBVVecType -> Just (n, len, a)), body1)) t2 = + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + as_are_eq <- mrConvertible a a' + if ms_are_eq && as_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + let nm = if Text.head nm2 == '_' then nm1 else nm2 + withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> + do err_str_tm <- liftSC1 scString "FIXME: askMRSolverH error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str_tm] + bvvec_tm <- liftSC2 scGlobalApply "Prelude.genFromBVVec" + [n, len, a, var, err_tm, m] + body1' <- substTerm 0 (bvvec_tm : vars') body1 + t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [bvvec_tm] + askMRSolverH (var : vars') body1' t1'' body2 t2'' + -- Introduce variables of the same type together askMRSolverH vars tp11@(asPi -> Just (nm1, tp1, body1)) t1 tp22@(asPi -> Just (nm2, tp2, body2)) t2 = @@ -937,11 +1025,8 @@ askMRSolverH _ tp1 _ tp2@(asPi -> Just _) _ = throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) -- The base case: both sides are CompM of the same type -askMRSolverH _ tp1@(asCompM -> Just _) t1 tp2@(asCompM -> Just _) t2 = - do tps_are_eq <- mrConvertible tp1 tp2 - if tps_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - m1 <- normCompTerm t1 +askMRSolverH _ (asCompM -> Just _) t1 (asCompM -> Just _) t2 = + do m1 <- normCompTerm t1 m2 <- normCompTerm t2 mrRefines m1 m2 -- If t1 is a named function, add forall xs. f1 xs |= m2 to the env