Skip to content

Commit

Permalink
fix bugs in new InjConversion interface
Browse files Browse the repository at this point in the history
  • Loading branch information
m-yac committed Sep 21, 2022
1 parent f496f69 commit 37565f9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 35 deletions.
5 changes: 2 additions & 3 deletions heapster-saw/examples/sha512_mr_solver.saw
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ monadify_term {{ processBlock_spec }};
monadify_term {{ processBlocks_loop_spec }};
monadify_term {{ processBlocks_spec }};

// mr_solver_set_debug_level 3;
mr_solver_prove round_00_15 {{ round_00_15_spec }};
// mr_solver_prove round_16_80 {{ round_16_80_spec }};
// mr_solver_prove processBlock {{ processBlock_spec }};
mr_solver_prove round_16_80 {{ round_16_80_spec }};
mr_solver_prove processBlock {{ processBlock_spec }};
// mr_solver_prove processBlocks {{ processBlocks_spec }};
89 changes: 58 additions & 31 deletions src/SAWScript/Prover/MRSolver/SMT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,21 @@ mrProvable bool_tm =
instUVar :: LocalName -> Term -> MRM Term
instUVar nm tp = mrDebugPPPrefix 3 "instUVar" (nm, tp) >>
liftSC1 scWhnf tp >>= \case
(asNonBVVecVectorType -> Just (m, a)) ->
liftSC1 smtNorm m >>= \m' -> case asBvToNat m' of
-- For variables of type Vec of length which normalizes to
-- a bvToNat term, recurse and wrap the result in genFromBVVec
Just (n, len) -> do
tp' <- liftSC2 scVecType m' a
tm' <- instUVar nm tp'
mrGenFromBVVec n len a tm' "instUVar" m
-- Otherwise for variables of type Vec, create a @Nat -> a@
-- function as an ExtCns and apply genBVVec to it
Nothing -> do
nat_tp <- liftSC0 scNatType
tp' <- liftSC3 scPi "_" nat_tp =<< liftTermLike 0 1 a
tm' <- instUVar nm tp'
liftSC2 scGlobalApply "CryptolM.genCryM" [m, a, tm']
-- For variables of type BVVec, create a @Vec n Bool -> a@ function
-- as an ExtCns and apply genBVVec to it
(asBVVecType -> Just (n, len, a)) -> do
Expand Down Expand Up @@ -499,23 +514,26 @@ nonTrivialConv (ConvComp cs) = not (null cs)
-- and the arguments to those constructors are convertible via 'mrConvertible'
mrConvsConvertible :: InjConversion -> InjConversion -> MRM Bool
mrConvsConvertible (ConvComp cs1) (ConvComp cs2) =
and <$> zipWithM mrSingleConvsConvertible cs1 cs2
where mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM Bool
mrSingleConvsConvertible SingleNatToNum SingleNatToNum = return True
mrSingleConvsConvertible (SingleBVToNat n1) (SingleBVToNat n2) = return $ n1 == n2
mrSingleConvsConvertible (SingleBVVecToVec n1 len1 a1 m1)
(SingleBVVecToVec n2 len2 a2 m2) =
do ns_are_eq <- mrConvertible n1 n2
lens_are_eq <- mrConvertible len1 len2
as_are_eq <- mrConvertible a1 a2
ms_are_eq <- mrConvertible m1 m2
return $ ns_are_eq && lens_are_eq && as_are_eq && ms_are_eq
mrSingleConvsConvertible (SinglePairToPair cL1 cR1)
(SinglePairToPair cL2 cR2) =
do cLs_are_eq <- mrConvsConvertible cL1 cL2
cRs_are_eq <- mrConvsConvertible cR1 cR2
return $ cLs_are_eq && cRs_are_eq
mrSingleConvsConvertible _ _ = return False
if length cs1 /= length cs2 then return False
else and <$> zipWithM mrSingleConvsConvertible cs1 cs2

-- | Used in the definition of 'mrConvsConvertible'
mrSingleConvsConvertible :: SingleInjConversion -> SingleInjConversion -> MRM Bool
mrSingleConvsConvertible SingleNatToNum SingleNatToNum = return True
mrSingleConvsConvertible (SingleBVToNat n1) (SingleBVToNat n2) = return $ n1 == n2
mrSingleConvsConvertible (SingleBVVecToVec n1 len1 a1 m1)
(SingleBVVecToVec n2 len2 a2 m2) =
do ns_are_eq <- mrConvertible n1 n2
lens_are_eq <- mrConvertible len1 len2
as_are_eq <- mrConvertible a1 a2
ms_are_eq <- mrConvertible m1 m2
return $ ns_are_eq && lens_are_eq && as_are_eq && ms_are_eq
mrSingleConvsConvertible (SinglePairToPair cL1 cR1)
(SinglePairToPair cL2 cR2) =
do cLs_are_eq <- mrConvsConvertible cL1 cL2
cRs_are_eq <- mrConvsConvertible cR1 cR2
return $ cLs_are_eq && cRs_are_eq
mrSingleConvsConvertible _ _ = return False

-- | Apply the given 'InjConversion' to the given term, where compositions
-- @c1 <> c2 <> ... <> cn@ are applied from right to left as in function
Expand All @@ -540,11 +558,16 @@ mrApplyInvConv (ConvComp cs) = flip (foldlM go) cs
go t SingleNatToNum = case asNum t of
Just (Left t') -> return t'
_ -> error "mrApplyInvConv: Num term does not normalize to TCNum constructor"
go t (SingleBVToNat n) =
do n_tm <- liftSC1 scNat n
liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t]
go t (SingleBVVecToVec n len a m) =
mrGenBVVecFromVec m a t "mrApplyInvConv" n len
go t (SingleBVToNat n) = case asBvToNat t of
Just (asNat -> Just n', t') | n == n' -> return t'
_ -> do n_tm <- liftSC1 scNat n
liftSC2 scGlobalApply "Prelude.bvNat" [n_tm, t]
go t c@(SingleBVVecToVec n len a m) = case asGenFromBVVecTerm t of
Just (n', len', a', t', _, m') ->
do eq <- mrSingleConvsConvertible c (SingleBVVecToVec n' len' a' m')
if eq then return t'
else mrGenBVVecFromVec m a t "mrApplyInvConv" n len
_ -> mrGenBVVecFromVec m a t "mrApplyInvConv" n len
go t (SinglePairToPair c1 c2) =
do t1 <- mrApplyInvConv c1 =<< doTermProj t TermProjLeft
t2 <- mrApplyInvConv c2 =<< doTermProj t TermProjRight
Expand Down Expand Up @@ -635,16 +658,18 @@ findInjConvs tp1 t1 (asNonBVVecVectorType -> Just (m, _))
-- bit-width from the other side
findInjConvs (asNonBVVecVectorType -> Just (m, a')) _
(asBVVecType -> Just (n, len, a)) _ =
do bvvec_tp <- liftSC2 scVecType n a
lens_are_eq <- mrProveEq m =<< mrBvToNat n len
do len_nat <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len]
bvvec_tp <- liftSC2 scVecType len_nat a
lens_are_eq <- mrProveEq m len_nat
as_are_eq <- mrConvertible a a'
if lens_are_eq && as_are_eq
then return $ Just (bvvec_tp, BVVecToVec n len a m, NoConv)
else return $ Nothing
findInjConvs (asBVVecType -> Just (n, len, a)) _
(asNonBVVecVectorType -> Just (m, a')) _ =
do bvvec_tp <- liftSC2 scVecType n a
lens_are_eq <- mrProveEq m =<< mrBvToNat n len
do len_nat <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len]
bvvec_tp <- liftSC2 scVecType len_nat a
lens_are_eq <- mrProveEq m len_nat
as_are_eq <- mrConvertible a a'
if lens_are_eq && as_are_eq
then return $ Just (bvvec_tp, NoConv, BVVecToVec n len a m)
Expand Down Expand Up @@ -761,10 +786,12 @@ mrProveRel het t1 t2 =
mrDebugPPPrefixSep 2 nm t1 (if het then "~=" else "==") t2
tp1 <- mrTypeOf t1 >>= mrSubstEVars
tp2 <- mrTypeOf t2 >>= mrSubstEVars
cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2
res <- withTermInCtx cond_in_ctx mrProvable
debugPrint 2 $ nm ++ ": " ++ if res then "Success" else "Failure"
return res
tps_eq <- mrConvertible tp1 tp2
if not het && not tps_eq then return False
else do cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2
res <- withTermInCtx cond_in_ctx mrProvable
debugPrint 2 $ nm ++ ": " ++ if res then "Success" else "Failure"
return res

-- | Prove that two terms are related, heterogeneously iff the first argument,
-- is true, instantiating evars if necessary, or throwing an error if this is
Expand Down Expand Up @@ -883,7 +910,7 @@ mrProveRelH' _ het tp1 tp2 t1 t2 = findInjConvs tp1 (Just t1) tp2 (Just t2) >>=
-- injective conversions from a type @tp@ to @tp1@ and @tp2@, apply the
-- inverses of these conversions to @t1@ and @t2@ and continue checking
-- equality on the results
Just (tp, c1, c2) | het, nonTrivialConv c1 || nonTrivialConv c2 -> do
Just (tp, c1, c2) | nonTrivialConv c1 || nonTrivialConv c2 -> do
t1' <- mrApplyInvConv c1 t1
t2' <- mrApplyInvConv c2 t2
mrProveRelH True tp tp t1' t2'
Expand Down
5 changes: 4 additions & 1 deletion src/SAWScript/Prover/MRSolver/Solver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,8 @@ mrRefinesFunH k vars (asPi -> Just (nm1, tp1, _)) t1
-- @tp2@, introduce a variable of type @tp@, apply both conversions to it,
-- and substitute the results on the left and right sides, respectively
Just (tp, c1, c2) ->
mrDebugPPPrefixSep 3 "mrRefinesFunH calling findInjConvs" tp1 "," tp2 >>
mrDebugPPPrefix 3 "mrRefinesFunH got type" tp >>
let nm = maybe "_" id $ find ((/=) '_' . Text.head)
$ [nm1, nm2] ++ catMaybes [ asLambdaName t1
, asLambdaName t2 ] in
Expand Down Expand Up @@ -1151,7 +1153,8 @@ type MRSolverResult = Maybe (FunName, FunAssump)
askMRSolverH :: (NormComp -> NormComp -> MRM ()) ->
Term -> Term -> MRM MRSolverResult
askMRSolverH f t1 t2 =
do m1 <- normCompTerm t1
do mrUVars >>= mrDebugPPPrefix 1 "askMRSolverH uvars:"
m1 <- normCompTerm t1
m2 <- normCompTerm t2
f m1 m2
case (m1, m2) of
Expand Down

0 comments on commit 37565f9

Please sign in to comment.