Skip to content

Commit

Permalink
Merge pull request #1357 from GaloisInc/rust-recursive-types-redo
Browse files Browse the repository at this point in the history
Define recursive Heapster shapes from recursive Rust types
  • Loading branch information
mergify[bot] authored Jun 28, 2021
2 parents bfac705 + 2da41f9 commit ec8a060
Show file tree
Hide file tree
Showing 10 changed files with 256 additions and 84 deletions.
17 changes: 12 additions & 5 deletions heapster-saw/examples/rust_data.saw
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ heapster_define_llvmshape env "u32" 64 "" "fieldsh(32,int32<>)";
// bool type
heapster_define_llvmshape env "bool" 64 "" "fieldsh(1,int1<>)";

// Box type
heapster_define_llvmshape env "Box" 64 "T:llvmshape 64" "ptrsh(T)";

// Result type
heapster_define_rust_type env "pub enum Result<X,Y> { Ok (X), Err (Y) }";

Expand All @@ -27,8 +30,9 @@ heapster_define_rust_type env "pub enum Sum<X,Y> { Left (X), Right (Y) }";
heapster_define_rust_type env "pub enum Option<X> { None, Some (X) }";

// List type
heapster_define_llvmshape env "List" 64 "L:perm(llvmptr 64),X:llvmshape 64" "(fieldsh(eq(llvmword(0)))) orsh (fieldsh(eq(llvmword(1)));X;fieldsh(L))";
heapster_define_recursive_perm env "ListPerm" "X:llvmshape 64, Xlen:bv 64, rw:rwmodality, l:lifetime" "llvmptr 64" ["[l]memblock(rw,0,Xlen + 16,List<ListPerm<X,Xlen,rw,l>,X>)"] "\\ (X:sort 0) (_:Vec 64 Bool) -> List X" "\\ (X:sort 0) (_:Vec 64 Bool) -> foldListPermH X" "\\ (X:sort 0) (_:Vec 64 Bool) -> unfoldListPermH X";
//heapster_define_llvmshape env "List" 64 "L:perm(llvmptr 64),X:llvmshape 64" "(fieldsh(eq(llvmword(0)))) orsh (fieldsh(eq(llvmword(1)));X;fieldsh(L))";
//heapster_define_recursive_perm env "ListPerm" "X:llvmshape 64, Xlen:bv 64, rw:rwmodality, l:lifetime" "llvmptr 64" ["[l]memblock(rw,0,Xlen + 16,List<ListPerm<X,Xlen,rw,l>,X>)"] "\\ (X:sort 0) (_:Vec 64 Bool) -> List X" "\\ (X:sort 0) (_:Vec 64 Bool) -> foldListPermH X" "\\ (X:sort 0) (_:Vec 64 Bool) -> unfoldListPermH X";
heapster_define_rust_type env "pub enum List<X> { Nil, Cons (X,Box<List<X>>) }";

// The String type
heapster_define_llvmshape env "String" 64 "" "exsh cap:bv 64. ptrsh(arraysh(cap,1,[(8,int8<>)]));fieldsh(int64<>);fieldsh(eq(llvmword(cap)))";
Expand Down Expand Up @@ -135,15 +139,18 @@ cycle_true_enum_sym <- heapster_find_symbol env "15cycle_true_enum";

// list_is_empty
list_is_empty_sym <- heapster_find_symbol env "13list_is_empty";
heapster_typecheck_fun_rename env list_is_empty_sym "list_is_empty" "(rw:rwmodality).arg0:ListPerm<fieldsh(int64<>),8,rw,always> -o ret:int1<>";
heapster_typecheck_fun_rename env list_is_empty_sym "list_is_empty" "<'a> fn (l: &'a List<u64>) -> bool";
//heapster_typecheck_fun_rename env list_is_empty_sym "list_is_empty" "(rw:rwmodality).arg0:ListPerm<fieldsh(int64<>),8,rw,always> -o ret:int1<>";

// list_head
list_head_sym <- heapster_find_symbol env "9list_head";
heapster_typecheck_fun_rename env list_head_sym "list_head" "(rw:rwmodality).arg0:ListPerm<fieldsh(int64<>),8,rw,always> -o ret:memblock(W,0,16,Result<fieldsh(int64<>),emptysh>)";
heapster_typecheck_fun_rename env list_head_sym "list_head" "<'a> fn (l: &'a List<u64>) -> Box<Sum<u64,()>>";
//heapster_typecheck_fun_rename env list_head_sym "list_head" "(rw:rwmodality).arg0:List<fieldsh(int64<>),8,rw,always> -o ret:memblock(W,0,16,Result<fieldsh(int64<>),emptysh>)";

// list_head_impl
list_head_impl_sym <- heapster_find_symbol env "14list_head_impl";
heapster_typecheck_fun_rename env list_head_impl_sym "list_head_impl" "(rw:rwmodality).arg0:ListPerm<fieldsh(int64<>),8,rw,always> -o ret:(struct(eq(llvmword(0)),exists z:bv 64. eq(llvmword(z)))) or (struct(eq(llvmword(1)),true))";
//heapster_typecheck_fun_rename env list_head_impl_sym "list_head_impl" "<'a> fn (l: &'a List<u64>) -> Result<u64,()>";
//heapster_typecheck_fun_rename env list_head_impl_sym "list_head_impl" "(rw:rwmodality).arg0:List<fieldsh(int64<>),8,rw,always> -o ret:(struct(eq(llvmword(0)),exists z:bv 64. eq(llvmword(z)))) or (struct(eq(llvmword(1)),true))";

// StrStruct::new
str_struct_new <- heapster_find_symbol env "9StrStruct3new";
Expand Down
100 changes: 80 additions & 20 deletions heapster-saw/examples/rust_data.v

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion heapster-saw/examples/rust_data_proofs.v
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Print list_is_empty__tuple_fun.

Print list_head__tuple_fun.

Print list_head_impl__tuple_fun.
(* Print list_head_impl__tuple_fun. *)

Print str_struct_new__tuple_fun.

Expand Down
7 changes: 4 additions & 3 deletions heapster-saw/src/Verifier/SAW/Heapster/IRTTranslation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,10 @@ instance IRTTyVars (PermExpr (LLVMShapeType w)) where
, [nuMP| Nothing |] <- mbMatch maybe_rw
, [nuMP| Nothing |] <- mbMatch maybe_l
-> return ([], IRTRecVar)
IRTRecShapeName _ _
-> throwError $ "recursive shape applied to different"
++ " arguments in its definition!"
IRTRecShapeName _ nmsh_rec
| mbLift $ (namedShapeName nmsh_rec ==) . namedShapeName <$> nmsh
-> throwError $ "recursive shape applied to different"
++ " arguments in its definition!"
_ -> case mbMatch $ namedShapeBody <$> nmsh of
[nuMP| DefinedShapeBody _ |] ->
irtTyVars (mbMap2 unfoldNamedShape nmsh args)
Expand Down
3 changes: 2 additions & 1 deletion heapster-saw/src/Verifier/SAW/Heapster/Implication.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5720,7 +5720,8 @@ proveVarLLVMBlocks' x ps psubst mb_bps_in mb_ps = case mbMatch mb_bps_in of
, Nothing <- psubstLookup psubst memb
, Just off <- partialSubst psubst (fmap llvmBlockOffset mb_bp)
, Just i <- findIndex (isLLVMAtomicPermWithOffset off) ps
, Just len1 <- llvmAtomicPermLen (ps!!i) ->
, Just len1 <- llvmAtomicPermLen (ps!!i)
, not (bvIsZero len1) ->

-- Build existential memblock perms with fresh variables for shapes, where
-- the first one has the length of the atomic perm we found and the other
Expand Down
2 changes: 1 addition & 1 deletion heapster-saw/src/Verifier/SAW/Heapster/PermParser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,5 @@ parseRustTypeString ::
PermEnv {- ^ permission environment -} ->
prx w {- ^ pointer bit-width proxy -} ->
String {- ^ input text -} ->
m SomeNamedShape
m (SomePartialNamedShape w)
parseRustTypeString = parseNamedShapeFromRustDecl
9 changes: 9 additions & 0 deletions heapster-saw/src/Verifier/SAW/Heapster/Permissions.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5842,6 +5842,15 @@ instance AbstractVars (NamedPermName ns args a) where
-- * Abstracting out named shapes
----------------------------------------------------------------------

-- | An existentially quantified, partially defined LLVM shape applied to
-- some arguments
data SomePartialNamedShape w where
NonRecShape :: String -> CruCtx args -> Mb args (PermExpr (LLVMShapeType w))
-> SomePartialNamedShape w
RecShape :: String -> CruCtx args
-> Mb (args :> LLVMShapeType w) (PermExpr (LLVMShapeType w))
-> SomePartialNamedShape w

-- | An existentially quantified LLVM shape applied to some arguments
data SomeNamedShapeApp w where
SomeNamedShapeApp :: String -> CruCtx args -> PermExprs args ->
Expand Down
152 changes: 113 additions & 39 deletions heapster-saw/src/Verifier/SAW/Heapster/RustTypes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import qualified Data.Binding.Hobbits.NameSet as NameSet

import Language.Rust.Syntax
import Language.Rust.Parser
import Language.Rust.Data.Ident (Ident(..), mkIdent, name)
import Language.Rust.Data.Ident (Ident(..), name)

import Prettyprinter as PP

Expand All @@ -72,14 +72,16 @@ data SomeLLVMPerm =
$(mkNuMatching [t| SomeLLVMPerm |])

-- | Info used for converting Rust types to shapes
-- NOTE: @rciRecType@ should probably have some info about lifetimes
data RustConvInfo =
RustConvInfo { rciPermEnv :: PermEnv,
rciCtx :: [(String, TypedName)] }
rciCtx :: [(String, TypedName)],
rciRecType :: Maybe (RustName, [RustName], TypedName) }

-- | The default, top-level 'RustConvInfo' for a given 'PermEnv'
mkRustConvInfo :: PermEnv -> RustConvInfo
mkRustConvInfo env =
RustConvInfo { rciPermEnv = env, rciCtx = [] }
RustConvInfo { rciPermEnv = env, rciCtx = [], rciRecType = Nothing }

-- | The Rust conversion monad is just a state-error monad
newtype RustConvM a =
Expand Down Expand Up @@ -129,20 +131,30 @@ rustCtx1 name tp = MNil :>: Pair (Constant name) tp
rustCtxCtx :: RustCtx ctx -> CruCtx ctx
rustCtxCtx = cruCtxOfTypes . RL.map (\(Pair _ tp) -> tp)

-- | Extend a 'RustCtx' with a single binding on the right
rustCtxCons :: RustCtx ctx -> String -> TypeRepr a -> RustCtx (ctx :> a)
rustCtxCons ctx nm tp = ctx :>: Pair (Constant nm) tp

-- | Build a 'RustCtx' from the given variable names, all having the same type
rustCtxOfNames :: TypeRepr a -> [String] -> Some RustCtx
rustCtxOfNames tp =
foldl (\(Some ctx) name -> Some (ctx :>: Pair (Constant name) tp)) (Some MNil)

-- | Run a 'RustConvM' computation in a context of bound type-level variables
inRustCtx :: NuMatching a => RustCtx ctx -> RustConvM a ->
RustConvM (Mb ctx a)
inRustCtx ctx m =
mbM $ nuMulti (RL.map (\_-> Proxy) ctx) $ \ns ->
-- | Run a 'RustConvM' computation in a context of bound type-level variables,
-- where the bound names are passed to the computation
inRustCtxF :: NuMatching a => RustCtx ctx -> (RAssign Name ctx -> RustConvM a) ->
RustConvM (Mb ctx a)
inRustCtxF ctx m =
mbM $ nuMulti (RL.map (\_ -> Proxy) ctx) $ \ns ->
let ns_ctx =
RL.toList $ RL.map2 (\n (Pair (Constant str) tp) ->
Constant (str, Some (Typed tp n))) ns ctx in
local (\info -> info { rciCtx = ns_ctx ++ rciCtx info }) m
local (\info -> info { rciCtx = ns_ctx ++ rciCtx info }) (m ns)

-- | Run a 'RustConvM' computation in a context of bound type-level variables
inRustCtx :: NuMatching a => RustCtx ctx -> RustConvM a ->
RustConvM (Mb ctx a)
inRustCtx ctx m = inRustCtxF ctx (const m)

-- | Class for a generic "conversion from Rust" function, given the bit width of
-- the pointer type
Expand Down Expand Up @@ -252,7 +264,7 @@ namedTypeTable w =

-- | A fully qualified Rust path without any of the parameters; e.g.,
-- @Foo<X>::Bar<Y,Z>::Baz@ just becomes @[Foo,Bar,Baz]@
newtype RustName = RustName [Ident]
newtype RustName = RustName [Ident] deriving (Eq)

instance Show RustName where
show (RustName ids) = concat $ intersperse "::" $ map show ids
Expand All @@ -279,6 +291,31 @@ rsPathParams :: Path a -> [PathParameters a]
rsPathParams (Path _ segments _) =
mapMaybe (\(PathSegment _ maybe_params _) -> maybe_params) segments

-- | Get the 'RustName' of a type, if it's a 'PathTy'
tyName :: Ty a -> Maybe RustName
tyName (PathTy _ path _) = Just $ rsPathName path
tyName _ = Nothing

-- | Decide whether a Rust type is named (i.e. a 'PathTy')
isNamedType :: Ty a -> Bool
isNamedType (PathTy _ _ _) = True
isNamedType _ = False

-- | Decide whether 'PathParameters' are all named types (angle-bracketed only)
isNamedParams :: PathParameters a -> Bool
isNamedParams (AngleBracketed _ tys _ _) = all isNamedType tys
isNamedParams _ = error "Parenthesized types not supported"

-- | Get all of the 'RustName's of path parameters, if they're angle-bracketed
pParamNames :: PathParameters a -> [RustName]
pParamNames (AngleBracketed _ tys _ _) = mapMaybe tyName tys
pParamNames _ = error "Parenthesized types not supported"

-- | Modify a 'RustConvM' to be run with a recursive type
withRecType :: (1 <= w, KnownNat w) => RustName -> [RustName] -> Name (LLVMShapeType w) ->
RustConvM a -> RustConvM a
withRecType rust_n rust_ns rec_n = local (\info -> info { rciRecType = Just (rust_n, rust_ns, Some (Typed knownRepr rec_n)) })


----------------------------------------------------------------------
-- * Converting Rust Types to Heapster Shapes
Expand Down Expand Up @@ -320,12 +357,27 @@ instance RsConvert w (Ty Span) (PermExpr (LLVMShapeType w)) where
sh <- rsConvert w tp'
return $ PExpr_PtrShape (Just PExpr_Read) (Just l) sh
rsConvert w (PathTy Nothing path _) =
do someShapeFn <- rsConvert w (rsPathName path)
someTypedArgs <- rsConvert w (rsPathParams path)
case tryApplySomeShapeFun someShapeFn someTypedArgs of
Just shTp -> return shTp
Nothing ->
fail $ renderDoc $ pretty "Failed to apply shape funtion to arguments"
do mrec <- asks rciRecType
case mrec of
Just (rec_n, rec_arg_ns, sh_nm)
| rec_n == rsPathName path &&
all isNamedParams (rsPathParams path) &&
rec_arg_ns == concatMap pParamNames (rsPathParams path) ->
PExpr_Var <$> castTypedM "TypedName" (LLVMShapeRepr (natRepr w)) sh_nm
Just (rec_n, _, _)
| rec_n == rsPathName path -> fail "Arguments do not match"
_ ->
do someShapeFn@(SomeShapeFun expected _ ) <- rsConvert w (rsPathName path)
someTypedArgs@(Some tyArgs) <- rsConvert w (rsPathParams path)
let actual = typedPermExprsCtx tyArgs
case tryApplySomeShapeFun someShapeFn someTypedArgs of
Just shTp -> return shTp
Nothing ->
fail $ renderDoc $ fillSep
[ pretty "Converting PathTy: " <+> pretty (show $ rsPathName path)
, pretty "Expected arguments:" <+> pretty expected
, pretty "Actual arguments:" <+> pretty actual
]
rsConvert (w :: prx w) (BareFn _ abi rust_ls2 fn_tp span) =
do Some3FunPerm fun_perm <- rsConvertMonoFun w span abi rust_ls2 fn_tp
let args = funPermArgs fun_perm
Expand All @@ -335,6 +387,9 @@ instance RsConvert w (Ty Span) (PermExpr (LLVMShapeType w)) where
Perm_LLVMFunPtr
(FunctionHandleRepr (cruCtxToRepr args) (funPermRet fun_perm)) $
ValPerm_Conj1 $ Perm_Fun fun_perm
rsConvert w (TupTy tys _) =
do tyShs <- mapM (rsConvert w) tys
return $ foldr PExpr_SeqShape PExpr_EmptyShape tyShs
rsConvert _ tp = fail ("Rust type not supported: " ++ show tp)

instance RsConvert w (Arg Span) (PermExpr (LLVMShapeType w)) where
Expand All @@ -360,11 +415,27 @@ isRecursiveDef item =
_ -> False

where
-- TODO: I hate this, it needs to be better
isBoxed :: Ident -> Ty Span -> Bool
isBoxed i (PathTy _ (Path _ [PathSegment box (Just (AngleBracketed _ [PathTy _ (Path _ [PathSegment i' _ _] _) _] _ _)) _] _) _) =
box == mkIdent "Box" && i == i'
isBoxed _ _ = False
tyContainsName :: Ident -> Ty Span -> Bool
tyContainsName i ty =
case ty of
Slice t _ -> tyContainsName i t
Language.Rust.Syntax.Array t _ _ -> tyContainsName i t
Ptr _ t _ -> tyContainsName i t
Rptr _ _ t _ -> tyContainsName i t
TupTy ts _ -> any (tyContainsName i) ts
PathTy _ (Path _ segs _) _ -> any (segContainsName i) segs
ParenTy t _ -> tyContainsName i t
_ -> False

segContainsName :: Ident -> PathSegment Span -> Bool
segContainsName i (PathSegment i' mParams _) =
i == i' || case mParams of
Nothing -> False
Just params -> paramsContainName i params

paramsContainName :: Ident -> PathParameters Span -> Bool
paramsContainName i (AngleBracketed _ tys _ _) = any (tyContainsName i) tys
paramsContainName _ (Parenthesized _ _ _) = error "Parenthesized types not supported"

typeOf :: StructField Span -> Ty Span
typeOf (StructField _ _ t _ _) = t
Expand All @@ -373,31 +444,34 @@ isRecursiveDef item =
getVD (Variant _ _ vd _ _) = vd

containsName :: Ident -> VariantData Span -> Bool
containsName i (StructD fields _) = any (isBoxed i) $ typeOf <$> fields
containsName i (TupleD fields _) = any (isBoxed i) $ typeOf <$> fields
containsName i (StructD fields _) = any (tyContainsName i) $ typeOf <$> fields
containsName i (TupleD fields _) = any (tyContainsName i) $ typeOf <$> fields
containsName _ (UnitD _) = False

instance RsConvert w (Item Span) SomeNamedShape where
rsConvert w s@(StructItem _ _ ident vd generics _)
| isRecursiveDef s = error "Recursive struct definitions not yet supported"
-- | NOTE: The translation of recursive types ignores lifetime parameters for now
instance RsConvert w (Item Span) (SomePartialNamedShape w) where
rsConvert w s@(StructItem _ _ ident vd generics@(Generics _ tys _ _) _)
| isRecursiveDef s =
do Some ctx <- rsConvert w generics
let ctx' = rustCtxCons ctx (name ident) (LLVMShapeRepr $ natRepr w)
tyIdents = (\(TyParam _ i _ _ _) -> [i]) <$> tys
sh <- inRustCtxF ctx' $ \(_ :>: rec_n) -> withRecType (RustName [ident]) (RustName <$> tyIdents) rec_n $ rsConvert w vd
return $ RecShape (name ident) (rustCtxCtx ctx) sh
| otherwise =
do Some ctx <- rsConvert w generics
sh <- inRustCtx ctx $ rsConvert w vd
let nsh = NamedShape { namedShapeName = name ident
, namedShapeArgs = rustCtxCtx ctx
, namedShapeBody = DefinedShapeBody sh
}
return $ SomeNamedShape nsh
rsConvert w e@(Enum _ _ ident variants generics _)
| isRecursiveDef e = error "Recursive enum definitions not yet supported"
return $ NonRecShape (name ident) (rustCtxCtx ctx) sh
rsConvert w e@(Enum _ _ ident variants generics@(Generics _ tys _ _) _)
| isRecursiveDef e =
do Some ctx <- rsConvert w generics
let ctx' = rustCtxCons ctx (name ident) (LLVMShapeRepr $ natRepr w)
tyIdents = (\(TyParam _ i _ _ _) -> [i]) <$> tys
sh <- inRustCtxF ctx' $ \(_ :>: rec_n) -> withRecType (RustName [ident]) (RustName <$> tyIdents) rec_n $ rsConvert w variants
return $ RecShape (name ident) (rustCtxCtx ctx) sh
| otherwise =
do Some ctx <- rsConvert w generics
sh <- inRustCtx ctx $ rsConvert w variants
let nsh = NamedShape { namedShapeName = name ident
, namedShapeArgs = rustCtxCtx ctx
, namedShapeBody = DefinedShapeBody sh
}
return $ SomeNamedShape nsh
return $ NonRecShape (name ident) (rustCtxCtx ctx) sh
rsConvert _ item = fail ("Top-level item not supported: " ++ show item)

instance RsConvert w [Variant Span] (PermExpr (LLVMShapeType w)) where
Expand Down Expand Up @@ -942,7 +1016,7 @@ parseFunPermFromRust _ _ _ _ str =
-- Note: No CruCtx / TypeRepr as arguments for now
parseNamedShapeFromRustDecl :: (Fail.MonadFail m, 1 <= w, KnownNat w) =>
PermEnv -> prx w -> String ->
m SomeNamedShape
m (SomePartialNamedShape w)
parseNamedShapeFromRustDecl env w str
| Right item <- parse @(Item Span) (inputStreamFromString str) =
runLiftRustConvM (mkRustConvInfo env) $ rsConvert w item
Expand Down
Loading

0 comments on commit ec8a060

Please sign in to comment.