-
Notifications
You must be signed in to change notification settings - Fork 156
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create multiple selectors, one for each non-shared argument. The previous code was a big mess where we partioned arguments into shared and non-shared and then filtered the case-tree depending on whether they were part of the shared arguments or not. But then with the normalisation of type arguments, the second filter did not work properly. This then resulted in shared arguments becoming part of the tuple in the alternatives of the case-expression for the non-shared arguments. The new code is also more robust in the sense that shared and non-shared arguments no longer need to be partioned (shared occur left-most, non-shared occur right-most). They can now be interleaved. The old code would also generate bad Core if ever type and term arguments occured interleaved, this is no longer the case for the new code. Fixes #2628
- Loading branch information
1 parent
cb331a8
commit d04ec3e
Showing
3 changed files
with
225 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
{-| | ||
Copyright : (C) 2015-2016, University of Twente, | ||
2021-2022, QBayLogic B.V. | ||
2021-2024, QBayLogic B.V. | ||
2022, LumiGuide Fietsdetectie B.V. | ||
License : BSD2 (see the file LICENSE) | ||
Maintainer : QBayLogic B.V. <[email protected]> | ||
|
@@ -46,8 +46,6 @@ import Data.Coerce (coerce) | |
import qualified Data.Either as Either | ||
import qualified Data.Foldable as Foldable | ||
import qualified Data.Graph as Graph | ||
import Data.IntMap.Strict (IntMap) | ||
import qualified Data.IntMap.Strict as IntMap | ||
import qualified Data.IntSet as IntSet | ||
import qualified Data.List as List | ||
import qualified Data.List.Extra as List | ||
|
@@ -57,45 +55,31 @@ import Data.Monoid (All(..)) | |
import qualified Data.Text as Text | ||
import GHC.Stack (HasCallStack) | ||
|
||
#if MIN_VERSION_ghc(9,6,0) | ||
import GHC.Core.Make (chunkify, mkChunkified) | ||
#else | ||
import GHC.Hs.Utils (chunkify, mkChunkified) | ||
#endif | ||
|
||
#if MIN_VERSION_ghc(9,0,0) | ||
import GHC.Settings.Constants (mAX_TUPLE_SIZE) | ||
#else | ||
import Constants (mAX_TUPLE_SIZE) | ||
#endif | ||
|
||
-- internal | ||
import Clash.Core.DataCon (DataCon) | ||
import Clash.Core.Evaluator.Types (whnf') | ||
import Clash.Core.FreeVars | ||
(termFreeVars', typeFreeVars', localVarsDoNotOccurIn) | ||
import Clash.Core.HasType | ||
import Clash.Core.Literal (Literal(..)) | ||
import Clash.Core.Name (nameOcc) | ||
import Clash.Core.Pretty (showPpr) | ||
import Clash.Core.Term | ||
( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..) | ||
, collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks) | ||
import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons) | ||
import Clash.Core.TyCon (TyConMap) | ||
import Clash.Core.Type | ||
(Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView) | ||
import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings) | ||
(Type, TypeView (..), isPolyFunTy, splitFunForallTy, tyView) | ||
import Clash.Core.Util (mkInternalVar, sccLetBindings) | ||
import Clash.Core.Var (isGlobalId, isLocalId, varName) | ||
import Clash.Core.VarEnv | ||
( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList | ||
, notElemInScopeSet, unionInScope) | ||
import qualified Clash.Data.UniqMap as UniqMap | ||
import Clash.Normalize.Transformations.Letrec (deadCode) | ||
import Clash.Normalize.Types (NormRewrite, NormalizeSession) | ||
import Clash.Rewrite.Combinators (bottomupR) | ||
import Clash.Rewrite.Types | ||
import Clash.Rewrite.Util (changed, isUntranslatableType) | ||
import Clash.Rewrite.WorkFree (isConstant) | ||
import Clash.Util (MonadUnique, curLoc) | ||
|
||
-- | This transformation lifts applications of global binders out of | ||
-- alternatives of case-statements. | ||
|
@@ -132,11 +116,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t | |
else do | ||
-- For every to-lift expression create (the generalization of): | ||
-- | ||
-- let fargs = case x of {A -> (3,y); B -> (x,x)} | ||
-- in f (fst fargs) (snd fargs) | ||
-- let djArg0 = case x of {A -> 3; B -> x} | ||
-- djArg1 = case x of {A -> y; B -> x} | ||
-- in f djArg0 djArg1 | ||
-- | ||
-- the let-expression is not created when `f` has only one (selectable) | ||
-- argument | ||
-- if an argument is non-representable, the case-expression is inlined, | ||
-- and no let-binding will be created for it. | ||
-- | ||
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine | ||
-- whether expressions reference variables from the context, or | ||
|
@@ -251,18 +236,6 @@ isDisjoint ct = go ct | |
go (Branch _ [(_,x)]) = go x | ||
go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b)) | ||
|
||
-- Remove empty branches from a 'CaseTree' | ||
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a] | ||
removeEmpty l@(Leaf _) = l | ||
removeEmpty (LB lb ct) = | ||
case removeEmpty ct of | ||
Leaf [] -> Leaf [] | ||
ct' -> LB lb ct' | ||
removeEmpty (Branch s bs) = | ||
case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of | ||
[] -> Leaf [] | ||
bs' -> Branch s bs' | ||
|
||
-- | Test if all elements in a list are equal to each other. | ||
allEqual :: Eq a => [a] -> Bool | ||
allEqual [] = True | ||
|
@@ -464,90 +437,69 @@ collectGlobalsLbs is0 substitution seen lbs = do | |
-- function-position\", return a let-expression: where the let-binding holds | ||
-- a case-expression selecting between the distinct arguments of the case-tree, | ||
-- and the body is an application of the term applied to the shared arguments of | ||
-- the case tree, and projections of let-binding corresponding to the distinct | ||
-- argument positions. | ||
-- the case tree, and variable references to the created let-bindings. | ||
-- | ||
-- case-expressions whose type would be non-representable are not let-bound, | ||
-- but occur directly in the argument position of the application in the body | ||
-- of the let-expression. | ||
mkDisjointGroup | ||
:: InScopeSet | ||
-- ^ Variables in scope at the very top of the case-tree, i.e., the original | ||
-- expression | ||
-> (Term,([Term],CaseTree [(Either Term Type)])) | ||
-> (Term,([Term],CaseTree [Either Term Type])) | ||
-- ^ Case-tree of arguments belonging to the applied term. | ||
-> NormalizeSession (Term,[Term]) | ||
mkDisjointGroup inScope (fun,(seen,cs)) = do | ||
tcm <- Lens.view tcCache | ||
let argss = Foldable.toList cs | ||
argssT = zip [0..] (List.transpose argss) | ||
(sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT | ||
-- TODO: find a better solution than "maybe undefined fst . uncons" | ||
shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT | ||
distinct = map (Either.lefts) (List.transpose (map snd distinctT)) | ||
cs' = fmap (zip [0..]) cs | ||
cs'' = removeEmpty | ||
$ fmap (Either.lefts . map snd) | ||
(if null shared | ||
then cs' | ||
else fmap (filter (`notElem` shared)) cs') | ||
(distinctCaseM,distinctProjections) <- case distinct of | ||
-- only shared arguments: do nothing. | ||
[] -> return (Nothing,[]) | ||
-- Create selectors and projections | ||
(uc:_) -> do | ||
let argTys = map (inferCoreTypeOf tcm) uc | ||
disJointSelProj inScope argTys cs'' | ||
let newArgs = mkDJArgs 0 shared distinctProjections | ||
case distinctCaseM of | ||
Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen) | ||
Nothing -> return (mkApps fun newArgs, seen) | ||
|
||
-- | Create a single selector for all the representable distinct arguments by | ||
-- selecting between tuples. This selector is only ('Just') created when the | ||
-- number of representable uncommmon arguments is larger than one, otherwise it | ||
-- is not ('Nothing'). | ||
-- | ||
-- It also returns: | ||
-- | ||
-- * For all the non-representable distinct arguments: a selector | ||
-- * For all the representable distinct arguments: a projection out of the tuple | ||
-- created by the larger selector. If this larger selector does not exist, a | ||
-- single selector is created for the single representable distinct argument. | ||
let argLen = case Foldable.toList cs of | ||
[] -> error "mkDisjointGroup: no disjoint groups" | ||
l:_ -> length l | ||
csT :: [CaseTree (Either Term Type)] -- "Transposed" 'CaseTree [Either Term Type]' | ||
csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] -- sequenceA does the wrong thing | ||
(lbs,newArgs) <- List.mapAccumLM (\lbs c -> do | ||
let cL = Foldable.toList c | ||
case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of | ||
(Right ty:_, True) -> | ||
return (lbs,Right ty) | ||
(Right _:_, False) -> | ||
error ("mkDisjointGroup: non-equal type arguments: " <> | ||
showPpr (Either.rights cL)) | ||
(Left tm:_, True) -> | ||
return (lbs,Left tm) | ||
(Left tm:_, False) -> do | ||
let ty = inferCoreTypeOf tcm tm | ||
let err = error ("mkDisjointGroup: mixed type and term arguments: " <> show cL) | ||
(lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c) | ||
case lbM of | ||
Just lb -> return (lb:lbs,Left arg) | ||
_ -> return (lbs,Left arg) | ||
([], _) -> | ||
error "mkDisjointGroup: no arguments" | ||
) [] csT | ||
let funApp = mkApps fun newArgs | ||
case lbs of | ||
[] -> return (funApp, seen) | ||
_ -> return (Letrec lbs funApp, seen) | ||
|
||
-- | Create a selector for the case-tree of the argument. If the argument is | ||
-- representable create a let-binding for the created selector, and return | ||
-- a variable reference to this let-binding. If the argument is not representable | ||
-- return the selector directly. | ||
disJointSelProj | ||
:: InScopeSet | ||
-> [Type] | ||
-- ^ Types of the arguments | ||
-> CaseTree [Term] | ||
-- The case-tree of arguments | ||
-> NormalizeSession (Maybe LetBinding,[Term]) | ||
disJointSelProj _ _ (Leaf []) = return (Nothing,[]) | ||
disJointSelProj inScope argTys cs = do | ||
tcm <- Lens.view tcCache | ||
tupTcm <- Lens.view tupleTcCache | ||
let maxIndex = length argTys - 1 | ||
css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex] | ||
(untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys) | ||
let untranCs = map (css!!) (map fst untran) | ||
untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs') | ||
untran untranCs | ||
(lbM,projs) <- case tran of | ||
[] -> return (Nothing,[]) | ||
[(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)]) | ||
tys -> do | ||
let m = length tys | ||
(tyIxs,tys') = unzip tys | ||
tupTy = mkBigTupTy tcm tupTcm tys' | ||
cs' = fmap (\es -> map (es !!) tyIxs) cs | ||
djCase = genCase tcm tupTcm tupTy tys' cs' | ||
scrutId <- mkInternalVar inScope "tupIn" tupTy | ||
projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1] | ||
return (Just (scrutId,djCase),projections) | ||
let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs | ||
|
||
return (lbM,selProjs) | ||
where | ||
tranOrUnTran _ [] projs = projs | ||
tranOrUnTran _ sels [] = map snd sels | ||
tranOrUnTran n ((ut,s):uts) (p:projs) | ||
| n == ut = s : tranOrUnTran (n+1) uts (p:projs) | ||
| otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs | ||
-> Type | ||
-- ^ Types of the argument | ||
-> CaseTree Term | ||
-- The case-tree of argument | ||
-> NormalizeSession (Maybe LetBinding,Term) | ||
disJointSelProj inScope argTy cs = do | ||
let sel = genCase argTy cs | ||
untran <- isUntranslatableType False argTy | ||
case untran of | ||
True -> return (Nothing, sel) | ||
False -> do | ||
argId <- mkInternalVar inScope "djArg" argTy | ||
return (Just (argId,sel), Var argId) | ||
|
||
-- | Arguments are shared between invocations if: | ||
-- | ||
|
@@ -579,30 +531,15 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs) | |
_ -> False | ||
isProof _ = False | ||
|
||
-- | Create a list of arguments given a map of positions to common arguments, | ||
-- and a list of arguments | ||
mkDJArgs :: Int -- ^ Current position | ||
-> [(Int,Either Term Type)] -- ^ map from position to common argument | ||
-> [Term] -- ^ (projections for) distinct arguments | ||
-> [Either Term Type] | ||
mkDJArgs _ cms [] = map snd cms | ||
mkDJArgs _ [] uncms = map Left uncms | ||
mkDJArgs n ((m,x):cms) (y:uncms) | ||
| n == m = x : mkDJArgs (n+1) cms (y:uncms) | ||
| otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms | ||
|
||
-- | Create a case-expression that selects between the distinct arguments given | ||
-- a case-tree | ||
genCase :: TyConMap | ||
-> IntMap TyConName | ||
-> Type -- ^ Type of the alternatives | ||
-> [Type] -- ^ Types of the arguments | ||
-> CaseTree [Term] -- ^ CaseTree of arguments | ||
genCase :: Type -- ^ Types of the arguments | ||
-> CaseTree Term -- ^ CaseTree of arguments | ||
-> Term | ||
genCase tcm tupTcm ty argTys = go | ||
genCase ty = go | ||
where | ||
go (Leaf tms) = | ||
mkBigTupTm tcm tupTcm (List.zipEqual argTys tms) | ||
go (Leaf tm) = | ||
tm | ||
|
||
go (LB lb ct) = | ||
Letrec lb (go ct) | ||
|
@@ -617,68 +554,6 @@ genCase tcm tupTcm ty argTys = go | |
go (Branch scrut pats) = | ||
Case scrut ty (map (second go) pats) | ||
|
||
-- | Lookup the TyConName and DataCon for a tuple of size n | ||
findTup :: TyConMap -> IntMap TyConName -> Int -> (TyConName,DataCon) | ||
findTup tcm tupTcm n = | ||
Maybe.fromMaybe (error ("Cannot build " <> show n <> "-tuble")) $ do | ||
tupTcNm <- IntMap.lookup n tupTcm | ||
tupTc <- UniqMap.lookup tupTcNm tcm | ||
tupDc <- Maybe.listToMaybe (tyConDataCons tupTc) | ||
return (tupTcNm,tupDc) | ||
|
||
mkBigTupTm :: TyConMap -> IntMap TyConName -> [(Type,Term)] -> Term | ||
mkBigTupTm tcm tupTcm args = snd $ mkBigTup tcm tupTcm args | ||
|
||
mkSmallTup,mkBigTup :: TyConMap -> IntMap TyConName -> [(Type,Term)] -> (Type,Term) | ||
mkSmallTup _ _ [] = error $ $curLoc ++ "mkSmallTup: Can't create 0-tuple" | ||
mkSmallTup _ _ [(ty,tm)] = (ty,tm) | ||
mkSmallTup tcm tupTcm args = (ty,tm) | ||
where | ||
(argTys,tms) = unzip args | ||
(tupTcNm,tupDc) = findTup tcm tupTcm (length args) | ||
tm = mkApps (Data tupDc) (map Right argTys ++ map Left tms) | ||
ty = mkTyConApp tupTcNm argTys | ||
|
||
mkBigTup tcm tupTcm = mkChunkified (mkSmallTup tcm tupTcm) | ||
|
||
mkSmallTupTy,mkBigTupTy | ||
:: TyConMap | ||
-> IntMap TyConName | ||
-> [Type] | ||
-> Type | ||
mkSmallTupTy _ _ [] = error $ $curLoc ++ "mkSmallTupTy: Can't create 0-tuple" | ||
mkSmallTupTy _ _ [ty] = ty | ||
mkSmallTupTy tcm tupTcm tys = mkTyConApp tupTcNm tys | ||
where | ||
m = length tys | ||
(tupTcNm,_) = findTup tcm tupTcm m | ||
|
||
mkBigTupTy tcm tupTcm = mkChunkified (mkSmallTupTy tcm tupTcm) | ||
|
||
mkSmallTupSelector,mkBigTupSelector | ||
:: MonadUnique m | ||
=> InScopeSet | ||
-> TyConMap | ||
-> IntMap TyConName | ||
-> Term | ||
-> [Type] | ||
-> Int | ||
-> m Term | ||
mkSmallTupSelector _ _ _ scrut [_] 0 = return scrut | ||
mkSmallTupSelector _ _ _ _ [_] n = error $ $curLoc ++ "mkSmallTupSelector called with one type, but to select " ++ show n | ||
mkSmallTupSelector inScope tcm _ scrut _ n = mkSelectorCase ($curLoc ++ "mkSmallTupSelector") inScope tcm scrut 1 n | ||
|
||
mkBigTupSelector inScope tcm tupTcm scrut tys n = go (chunkify tys) | ||
where | ||
go [_] = mkSmallTupSelector inScope tcm tupTcm scrut tys n | ||
go tyss = do | ||
let (nOuter,nInner) = divMod n mAX_TUPLE_SIZE | ||
tyss' = map (mkSmallTupTy tcm tupTcm) tyss | ||
outer <- mkSmallTupSelector inScope tcm tupTcm scrut tyss' nOuter | ||
inner <- mkSmallTupSelector inScope tcm tupTcm outer (tyss List.!! nOuter) nInner | ||
return inner | ||
|
||
|
||
-- | Determine if a term in a function position is interesting to lift out of | ||
-- of a case-expression. | ||
-- | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.