Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make rewriteSharedTerm ensure that rule instantiations are well-typed. #1351

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions saw-core/src/Verifier/SAW/Rewriter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ reduceSharedTerm _ _ = Nothing
-- and returned in the result set.
rewriteSharedTerm :: forall a. Ord a => SharedContext -> Simpset a -> Term -> IO (Set a, Term)
rewriteSharedTerm sc ss t0 =
do cache <- newCache
do let ?env = []
cache <- newCache
let ?cache = cache
setRef <- newIORef mempty
let ?annSet = setRef
Expand All @@ -618,17 +619,39 @@ rewriteSharedTerm sc ss t0 =
pure (anns, t)

where
rewriteAll :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term
rewriteAll ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Term -> IO Term
rewriteAll (Unshared tf) =
traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop
rewriteSubterms tf >>= scTermF sc >>= rewriteTop
rewriteAll STApp{ stAppIndex = tidx, stAppTermF = tf } =
useCache ?cache tidx (traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop)

traverseTF :: forall b. (b -> IO b) -> TermF b -> IO (TermF b)
traverseTF _ tf@(Constant {}) = pure tf
traverseTF f tf = traverse f tf

rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term
useCache ?cache tidx (rewriteSubterms tf >>= scTermF sc >>= rewriteTop)

rewriteSubterms ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
TermF Term -> IO (TermF Term)
rewriteSubterms tf =
case tf of
FTermF ftf -> FTermF <$> traverse rewriteAll ftf
App t1 t2 -> App <$> rewriteAll t1 <*> rewriteAll t2
Lambda x t1 t2 ->
do t1' <- rewriteAll t1
localCache <- newCache
let localEnv = t1' : ?env
t2' <- let ?cache = localCache; ?env = localEnv in rewriteAll t2
pure (Lambda x t1' t2')
Pi x t1 t2 ->
do t1' <- rewriteAll t1
localCache <- newCache
let localEnv = t1' : ?env
t2' <- let ?cache = localCache; ?env = localEnv in rewriteAll t2
pure (Pi x t1' t2')
LocalVar{} -> pure tf
Constant{} -> pure tf

rewriteTop ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
Term -> IO Term
rewriteTop t =
case reduceSharedTerm sc t of
Nothing -> apply (Net.unify_term ss t) t
Expand All @@ -638,8 +661,9 @@ rewriteSharedTerm sc ss t0 =
recordAnn Nothing = return ()
recordAnn (Just a) = modifyIORef' ?annSet (Set.insert a)

apply :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
[Either (RewriteRule a) Conversion] -> Term -> IO Term
apply ::
(?env :: [Term], ?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) =>
[Either (RewriteRule a) Conversion] -> Term -> IO Term
apply [] t = return t
apply (Left (RewriteRule {ctxt, lhs, rhs, permutative, annotation}) : rules) t = do
result <- scMatch sc lhs t
Expand All @@ -665,8 +689,13 @@ rewriteSharedTerm sc ss t0 =
| otherwise ->
do -- putStrLn "REWRITING:"
-- print lhs
recordAnn annotation
rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) rhs
tys <- traverse (scTypeOf' sc ?env) (Map.elems inst)
if tys /= ctxt
then
do apply rules t
else
do recordAnn annotation
rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) rhs
apply (Right conv : rules) t =
do -- putStrLn "REWRITING:"
-- print (Net.toPat conv)
Expand Down