Skip to content

Commit

Permalink
Check types of refs right before passing to CCache
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisPenner committed Sep 24, 2024
1 parent 67f4597 commit 2bbf1f7
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 49 deletions.
2 changes: 1 addition & 1 deletion stack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ allow-newer-deps:

ghc-options:
# All packages
"$locals": -Wall -Werror -Wno-name-shadowing -Wno-missing-pattern-synonym-signatures -fprint-expanded-synonyms -fwrite-ide-info -Wunused-packages #-freverse-errors
"$locals": -Wall -Werror -Wno-name-shadowing -Wno-missing-pattern-synonym-signatures -fprint-expanded-synonyms -fwrite-ide-info -Wunused-packages -debug #-freverse-errors

# See https://github.com/haskell/haskell-language-server/issues/208
"$everything": -haddock
Expand Down
35 changes: 21 additions & 14 deletions unison-runtime/src/Unison/Runtime/ANF.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ module Unison.Runtime.ANF
Direction (..),
SuperNormal (..),
SuperGroup (..),
Cacheability (..),
POp (..),
FOp,
close,
Expand Down Expand Up @@ -80,7 +81,7 @@ module Unison.Runtime.ANF
where

import Control.Exception (throw)
import Control.Lens (snoc, unsnoc)
import Control.Lens (over, snoc, traversed, unsnoc, _2)
import Control.Monad.Reader (ReaderT (..), ask, local)
import Control.Monad.State (MonadState (..), State, gets, modify, runState)
import Data.Bifoldable (Bifoldable (..))
Expand Down Expand Up @@ -402,7 +403,7 @@ freshFloat avoid (Var.freshIn avoid -> v0) =
groupFloater ::
(Var v, Monoid a) =>
(Term v a -> FloatM v a (Term v a)) ->
[(v, Term v a)] ->
[(v, Term v a, Cacheability)] ->
FloatM v a (Map v v)
groupFloater rec vbs = do
cvs <- gets (\(vs, _, _) -> vs)
Expand Down Expand Up @@ -556,8 +557,8 @@ floatGroup ::
(Var v) =>
(Monoid a) =>
Map v Reference ->
[(v, Term v a)] ->
([(v, Id)], [(Reference, Term v a)], [(Reference, Term v a)])
[(v, Term v a, Cacheability)] ->
([(v, Id)], [(Reference, Term v a, Cacheability)], [(Reference, Term v a)])
floatGroup orig grp = case runState go0 (Set.empty, [], []) of
(_, st) -> case postFloat orig st of
(_, subvs, tops, dcmp) -> (subvs, tops, dcmp)
Expand Down Expand Up @@ -601,9 +602,9 @@ lamLiftGroup ::
(Var v) =>
(Monoid a) =>
Map v Reference ->
[(v, Term v a)] ->
([(v, Id)], [(Reference, Term v a)], [(Reference, Term v a)])
lamLiftGroup orig gr = floatGroup orig . (fmap . fmap) (close keep) $ gr
[(v, Term v a, Cacheability)] ->
([(v, Id)], [(Reference, Term v a, Cacheability)], [(Reference, Term v a)])
lamLiftGroup orig gr = floatGroup orig . (over (traversed . _2)) (close keep) $ gr
where
keep = Set.fromList $ map fst gr

Expand Down Expand Up @@ -1470,9 +1471,15 @@ type DNormal v = Directed () (ANormal v)
data SuperNormal v = Lambda {conventions :: [Mem], bound :: ANormal v}
deriving (Show, Eq)

-- | Whether the evaluation of a given definition is cacheable or not.
-- i.e. it's a top-level pure value.
data Cacheability = Cacheable | Uncacheable
deriving stock (Eq, Show)

data SuperGroup v = Rec
{ group :: [(v, SuperNormal v)],
entry :: SuperNormal v
entry :: SuperNormal v,
cacheable :: Cacheability
}
deriving (Show)

Expand All @@ -1496,7 +1503,7 @@ equivocate ::
SuperGroup v ->
SuperGroup v ->
Either (SGEqv v) ()
equivocate g0@(Rec bs0 e0) g1@(Rec bs1 e1)
equivocate g0@(Rec bs0 e0 _c0) g1@(Rec bs1 e1 _c1)
| length bs0 == length bs1 =
traverse_ eqvSN (zip ns0 ns1) *> eqvSN (e0, e1)
| otherwise = Left $ NumDefns g0 g1
Expand Down Expand Up @@ -1586,8 +1593,8 @@ bindDirection = traverse (const binder)
record :: (Var v) => (v, SuperNormal v) -> ANFM v ()
record p = modify $ \(fr, bnd, to) -> (fr, bnd, p : to)

superNormalize :: (Var v) => Term v a -> SuperGroup v
superNormalize tm = Rec l c
superNormalize :: (Var v) => Cacheability -> Term v a -> SuperGroup v
superNormalize cacheable tm = Rec l c cacheable
where
(bs, e)
| LetRecNamed' bs e <- tm = (bs, e)
Expand Down Expand Up @@ -2004,8 +2011,8 @@ traverseGroupLinks ::
(Bool -> Reference -> f Reference) ->
SuperGroup v ->
f (SuperGroup v)
traverseGroupLinks f (Rec bs e) =
Rec <$> (traverse . traverse) (normalLinks f) bs <*> normalLinks f e
traverseGroupLinks f (Rec bs e cacheable) =
Rec <$> (traverse . traverse) (normalLinks f) bs <*> normalLinks f e <*> pure cacheable

foldGroupLinks ::
(Monoid r, Var v) =>
Expand Down Expand Up @@ -2149,7 +2156,7 @@ indent :: Int -> ShowS
indent ind = showString (replicate (ind * 2) ' ')

prettyGroup :: (Var v) => String -> SuperGroup v -> ShowS
prettyGroup s (Rec grp ent) =
prettyGroup s (Rec grp ent _c) =
showString ("let rec[" ++ s ++ "]\n")
. foldr f id grp
. showString "entry"
Expand Down
16 changes: 14 additions & 2 deletions unison-runtime/src/Unison/Runtime/ANF/Serialize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,11 @@ putGroup ::
EC.EnumMap FOp Text ->
SuperGroup v ->
m ()
putGroup refrep fops (Rec bs e) =
putGroup refrep fops (Rec bs e cacheable) =
putLength n
*> traverse_ (putComb refrep fops ctx) cs
*> putComb refrep fops ctx e
*> putCacheability cacheable
where
n = length us
(us, cs) = unzip bs
Expand All @@ -328,7 +329,18 @@ getGroup = do
vs = getFresh <$> take l [0 ..]
ctx = pushCtx vs []
cs <- replicateM l (getComb ctx n)
Rec (zip vs cs) <$> getComb ctx n
Rec (zip vs cs) <$> getComb ctx n <*> getCacheability

putCacheability :: (MonadPut m) => Cacheability -> m ()
putCacheability c = putBool $ case c of
Cacheable -> True
Uncacheable -> False

getCacheability :: (MonadGet m) => m Cacheability
getCacheability =
getBool <&> \case
True -> Cacheable
False -> Uncacheable

putComb ::
(MonadPut m) =>
Expand Down
40 changes: 19 additions & 21 deletions unison-runtime/src/Unison/Runtime/Interface.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import Unison.Codebase.Runtime (Error, Runtime (..))
import Unison.ConstructorReference (ConstructorReference, GConstructorReference (..))
import Unison.ConstructorReference qualified as RF
import Unison.DataDeclaration (Decl, declFields, declTypeDependencies)
import Unison.Debug qualified as Debug
import Unison.Hashing.V2.Convert qualified as Hashing
import Unison.LabeledDependency qualified as RF
import Unison.Parser.Ann (Ann (External))
Expand Down Expand Up @@ -119,7 +120,6 @@ import Unison.Runtime.MCode.Serialize
import Unison.Runtime.Machine
( ActiveThreads,
CCache (..),
Cacheability (..),
MCombs,
Tracer (..),
apply0,
Expand Down Expand Up @@ -200,17 +200,6 @@ resolveTermRef cl r@(RF.DerivedId i) =
Nothing -> die $ "Unknown term reference: " ++ show r
Just tm -> pure tm

resolveTermRefType ::
CodeLookup Symbol IO () ->
RF.Reference ->
IO (Type Symbol)
resolveTermRefType _ b@(RF.Builtin _) =
die $ "Unknown builtin term reference: " ++ show b
resolveTermRefType cl r@(RF.DerivedId i) =
getTypeOfTerm cl i >>= \case
Nothing -> die $ "Unknown term reference: " ++ show r
Just typ -> pure typ

allocType ::
EvalCtx ->
RF.Reference ->
Expand Down Expand Up @@ -467,10 +456,20 @@ loadDeps cl ppe ctx tyrs tmrs = do
where
checkCacheability :: (Reference, sprgrp) -> IO (Reference, sprgrp, Cacheability)
checkCacheability (r, sg) = do
typ <- resolveTermRefType cl r
if ABT.cata hasArrows typ
then pure (r, sg, Uncacheable)
else pure (r, sg, Cacheable)
getTermType r >>= \case
Just typ | not (ABT.cata hasArrows typ) -> pure (r, sg, Cacheable)
_ -> pure (r, sg, Uncacheable)
getTermType :: Reference -> IO (Maybe (Type Symbol))
getTermType = \case
ref@(RF.DerivedId i) ->
getTypeOfTerm cl i >>= \case
Just t -> do
Debug.debugM Debug.Temp "Found type for: " ref
pure $ Just t
Nothing -> do
Debug.debugM Debug.Temp "NO type for: " ref
pure Nothing
RF.Builtin {} -> pure $ Nothing
hasArrows :: a -> ABT.ABT Type.F v Bool -> Bool
hasArrows _ = \case
ABT.Tm f -> case f of
Expand Down Expand Up @@ -718,7 +717,7 @@ intermediateTerms ::
(HasCallStack) =>
PrettyPrintEnv ->
EvalCtx ->
Map RF.Id (Symbol, Term Symbol) ->
Map RF.Id (Symbol, Term Symbol, Cacheability) ->
( Map.Map Symbol Reference,
Map.Map Reference (SuperGroup Symbol),
Map.Map Reference (Map.Map Word64 (Term Symbol))
Expand All @@ -729,7 +728,7 @@ intermediateTerms ppe ctx rtms =
(subvs, Map.mapWithKey f cmbs, Map.map (Map.singleton 0) dcmp)
where
f ref =
superNormalize
superNormalize _cacheable
. splitPatterns (dspec ctx)
. addDefaultCases tmName
where
Expand Down Expand Up @@ -769,9 +768,9 @@ normalizeTerm ctx tm =
normalizeGroup ::
EvalCtx ->
Map Symbol Reference ->
[(Symbol, Term Symbol)] ->
[(Symbol, Term Symbol, Cacheability)] ->
( Map Symbol Reference,
Map Reference (Term Symbol),
Map Reference (Term Symbol, Cacheability),
Map Reference (Term Symbol)
)
normalizeGroup ctx orig gr0 = case lamLiftGroup orig gr of
Expand Down Expand Up @@ -814,7 +813,6 @@ prepareEvaluation ::
EvalCtx ->
IO (EvalCtx, [(Reference, SuperGroup Symbol)], Reference)
prepareEvaluation ppe tm ctx = do
-- TODO: Check whether we need to set cacheability here, I think probably not?
missing <- cacheAdd rgrp (ccache ctx')
when (not . null $ missing) . fail $
reportBug "E029347" $
Expand Down
2 changes: 1 addition & 1 deletion unison-runtime/src/Unison/Runtime/MCode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ emitCombs ::
Word64 ->
SuperGroup v ->
EnumMap Word64 Comb
emitCombs rns grpr grpn (Rec grp ent) =
emitCombs rns grpr grpn (Rec grp ent _cacheable) =
emitComb rns grpr grpn rec (0, ent) <> aux
where
(rvs, cmbs) = unzip grp
Expand Down
15 changes: 5 additions & 10 deletions unison-runtime/src/Unison/Runtime/Machine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,6 @@ type MComb = RComb Closure

type MRef = RRef Closure

-- | Whether the evaluation of a given definition is cacheable or not.
-- i.e. it's a top-level pure value.
data Cacheability = Cacheable | Uncacheable
deriving stock (Eq, Show)

data Tracer
= NoTrace
| MsgTrace String String String
Expand Down Expand Up @@ -370,7 +365,7 @@ exec !env !denv !_activeThreads !ustk !bstk !k _ (BPrim1 LKUP i)
Just sn <- EC.lookup w numberedTermLookup -> do
poke ustk 1
bstk <- bump bstk
bstk <$ pokeBi bstk (ANF.Rec [] sn)
bstk <$ pokeBi bstk (ANF.Rec [] sn ANF.Uncacheable)
| otherwise -> bstk <$ poke ustk 0
Just sg -> do
poke ustk 1
Expand Down Expand Up @@ -2124,7 +2119,7 @@ evaluateSTM x = unsafeIOToSTM (evaluate x)

cacheAdd0 ::
S.Set Reference ->
[(Reference, SuperGroup Symbol, Cacheability)] ->
[(Reference, SuperGroup Symbol, ANF.Cacheability)] ->
[(Reference, Set Reference)] ->
CCache ->
IO ()
Expand All @@ -2133,8 +2128,8 @@ cacheAdd0 ntys0 termSuperGroups sands cc = do
termSuperGroups
& mapMaybe
( \case
(ref, _gr, Cacheable) -> Just ref
(_ref, _gr, Uncacheable) -> Nothing
(ref, _gr, ANF.Cacheable) -> Just ref
(_ref, _gr, ANF.Uncacheable) -> Nothing
)
& Set.fromList
let toAdd = M.fromList (termSuperGroups <&> \(r, g, _) -> (r, g))
Expand Down Expand Up @@ -2233,7 +2228,7 @@ cacheAdd l cc = do
-- Terms added via cacheAdd will have already been eval'd and cached if possible when
-- they were originally loaded, so we
-- don't need to re-check for cacheability here as part of a dynamic cache add.
l'' = l' <&> (\(r, g) -> (r, g, Uncacheable))
l'' = l' <&> (\(r, g) -> (r, g, ANF.Uncacheable))
if S.null missing
then [] <$ cacheAdd0 tys l'' (expandSandbox sand l') cc
else pure $ S.toList missing
Expand Down

0 comments on commit 2bbf1f7

Please sign in to comment.