Skip to content

Commit 5480a34

Browse files
committed
Refactor DEC transformation
Create multiple selectors, one for each non-shared argument. The previous code was a big mess where we partioned arguments into shared and non-shared and then filtered the case-tree depending on whether they were part of the shared arguments or not. But then with the normalisation of type arguments, the second filter did not work properly. This then resulted in shared arguments becoming part of the tuple in the alternatives of the case-expression for the non-shared arguments. The new code is also more robust in the sense that shared and non-shared arguments no longer need to be partioned (shared occur left-most, non-shared occur right-most). They can now be interleaved. The old code would also generate bad Core if ever type and term arguments occured interleaved, this is no longer the case for the new code. Fixes #2628
1 parent cb331a8 commit 5480a34

File tree

3 files changed

+229
-192
lines changed

3 files changed

+229
-192
lines changed

clash-lib/src/Clash/Normalize/Transformations/DEC.hs

+72-192
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{-|
22
Copyright : (C) 2015-2016, University of Twente,
3-
2021-2022, QBayLogic B.V.
3+
2021-2024, QBayLogic B.V.
44
2022, LumiGuide Fietsdetectie B.V.
55
License : BSD2 (see the file LICENSE)
66
Maintainer : QBayLogic B.V. <[email protected]>
@@ -46,8 +46,6 @@ import Data.Coerce (coerce)
4646
import qualified Data.Either as Either
4747
import qualified Data.Foldable as Foldable
4848
import qualified Data.Graph as Graph
49-
import Data.IntMap.Strict (IntMap)
50-
import qualified Data.IntMap.Strict as IntMap
5149
import qualified Data.IntSet as IntSet
5250
import qualified Data.List as List
5351
import qualified Data.List.Extra as List
@@ -57,45 +55,32 @@ import Data.Monoid (All(..))
5755
import qualified Data.Text as Text
5856
import GHC.Stack (HasCallStack)
5957

60-
#if MIN_VERSION_ghc(9,6,0)
61-
import GHC.Core.Make (chunkify, mkChunkified)
62-
#else
63-
import GHC.Hs.Utils (chunkify, mkChunkified)
64-
#endif
65-
66-
#if MIN_VERSION_ghc(9,0,0)
67-
import GHC.Settings.Constants (mAX_TUPLE_SIZE)
68-
#else
69-
import Constants (mAX_TUPLE_SIZE)
70-
#endif
71-
7258
-- internal
73-
import Clash.Core.DataCon (DataCon)
7459
import Clash.Core.Evaluator.Types (whnf')
7560
import Clash.Core.FreeVars
7661
(termFreeVars', typeFreeVars', localVarsDoNotOccurIn)
7762
import Clash.Core.HasType
7863
import Clash.Core.Literal (Literal(..))
7964
import Clash.Core.Name (nameOcc)
65+
import Clash.Core.Pretty (showPpr)
8066
import Clash.Core.Term
8167
( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..)
8268
, collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks)
83-
import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons)
69+
import Clash.Core.TyCon (TyConMap)
8470
import Clash.Core.Type
85-
(Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView)
86-
import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings)
71+
(Type, TypeView (..), isPolyFunTy, splitFunForallTy, tyView)
72+
import Clash.Core.Util (mkInternalVar, sccLetBindings)
8773
import Clash.Core.Var (isGlobalId, isLocalId, varName)
8874
import Clash.Core.VarEnv
8975
( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList
9076
, notElemInScopeSet, unionInScope)
91-
import qualified Clash.Data.UniqMap as UniqMap
9277
import Clash.Normalize.Transformations.Letrec (deadCode)
9378
import Clash.Normalize.Types (NormRewrite, NormalizeSession)
9479
import Clash.Rewrite.Combinators (bottomupR)
9580
import Clash.Rewrite.Types
9681
import Clash.Rewrite.Util (changed, isUntranslatableType)
9782
import Clash.Rewrite.WorkFree (isConstant)
98-
import Clash.Util (MonadUnique, curLoc)
83+
import Clash.Util (curLoc)
9984

10085
-- | This transformation lifts applications of global binders out of
10186
-- alternatives of case-statements.
@@ -132,11 +117,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
132117
else do
133118
-- For every to-lift expression create (the generalization of):
134119
--
135-
-- let fargs = case x of {A -> (3,y); B -> (x,x)}
136-
-- in f (fst fargs) (snd fargs)
120+
-- let djArg0 = case x of {A -> 3; B -> x}
121+
-- djArg1 = case x of {A -> y; B -> x}
122+
-- in f djArg0 djArg1
137123
--
138-
-- the let-expression is not created when `f` has only one (selectable)
139-
-- argument
124+
-- if an argument is non-representable, the case-expression is inlined,
125+
-- and no let-binding will be created for it.
140126
--
141127
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
142128
-- whether expressions reference variables from the context, or
@@ -251,18 +237,6 @@ isDisjoint ct = go ct
251237
go (Branch _ [(_,x)]) = go x
252238
go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b))
253239

254-
-- Remove empty branches from a 'CaseTree'
255-
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
256-
removeEmpty l@(Leaf _) = l
257-
removeEmpty (LB lb ct) =
258-
case removeEmpty ct of
259-
Leaf [] -> Leaf []
260-
ct' -> LB lb ct'
261-
removeEmpty (Branch s bs) =
262-
case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of
263-
[] -> Leaf []
264-
bs' -> Branch s bs'
265-
266240
-- | Test if all elements in a list are equal to each other.
267241
allEqual :: Eq a => [a] -> Bool
268242
allEqual [] = True
@@ -464,8 +438,11 @@ collectGlobalsLbs is0 substitution seen lbs = do
464438
-- function-position\", return a let-expression: where the let-binding holds
465439
-- a case-expression selecting between the distinct arguments of the case-tree,
466440
-- and the body is an application of the term applied to the shared arguments of
467-
-- the case tree, and projections of let-binding corresponding to the distinct
468-
-- argument positions.
441+
-- the case tree, and variable references to the created let-bindings.
442+
--
443+
-- case-expressions whose type would be non-representable are not let-bound,
444+
-- but occur directly in the argument position of the application in the body
445+
-- of the let-expression.
469446
mkDisjointGroup
470447
:: InScopeSet
471448
-- ^ Variables in scope at the very top of the case-tree, i.e., the original
@@ -475,79 +452,59 @@ mkDisjointGroup
475452
-> NormalizeSession (Term,[Term])
476453
mkDisjointGroup inScope (fun,(seen,cs)) = do
477454
tcm <- Lens.view tcCache
478-
let argss = Foldable.toList cs
479-
argssT = zip [0..] (List.transpose argss)
480-
(sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT
481-
-- TODO: find a better solution than "maybe undefined fst . uncons"
482-
shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT
483-
distinct = map (Either.lefts) (List.transpose (map snd distinctT))
484-
cs' = fmap (zip [0..]) cs
485-
cs'' = removeEmpty
486-
$ fmap (Either.lefts . map snd)
487-
(if null shared
488-
then cs'
489-
else fmap (filter (`notElem` shared)) cs')
490-
(distinctCaseM,distinctProjections) <- case distinct of
491-
-- only shared arguments: do nothing.
492-
[] -> return (Nothing,[])
493-
-- Create selectors and projections
494-
(uc:_) -> do
495-
let argTys = map (inferCoreTypeOf tcm) uc
496-
disJointSelProj inScope argTys cs''
497-
let newArgs = mkDJArgs 0 shared distinctProjections
498-
case distinctCaseM of
499-
Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
500-
Nothing -> return (mkApps fun newArgs, seen)
501-
502-
-- | Create a single selector for all the representable distinct arguments by
503-
-- selecting between tuples. This selector is only ('Just') created when the
504-
-- number of representable uncommmon arguments is larger than one, otherwise it
505-
-- is not ('Nothing').
506-
--
507-
-- It also returns:
508-
--
509-
-- * For all the non-representable distinct arguments: a selector
510-
-- * For all the representable distinct arguments: a projection out of the tuple
511-
-- created by the larger selector. If this larger selector does not exist, a
512-
-- single selector is created for the single representable distinct argument.
455+
let argLen = case Foldable.toList cs of
456+
[] -> error ($curLoc <> "mkDisjointGroup: no disjoint groups")
457+
l:_ -> length l
458+
csT :: [CaseTree (Either Term Type)]
459+
csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)]
460+
(lbs,newArgs) <- List.mapAccumLM (\lbs c -> do
461+
let cL :: [Either Term Type]
462+
cL = Foldable.toList c
463+
case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
464+
(Right ty:_, True) ->
465+
return (lbs,Right ty)
466+
(Right _:_, False) ->
467+
error ($curLoc <> "mkDisjointGroup: non-equal type arguments: " <>
468+
showPpr (Either.rights cL))
469+
(Left tm:_, True) ->
470+
return (lbs,Left tm)
471+
(Left tm:_, False) -> do
472+
let ty = inferCoreTypeOf tcm tm
473+
let err = error $
474+
$curLoc <>
475+
"mkDisjointGroup: mixed type and term arguments: " <>
476+
show cL
477+
(lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c)
478+
case lbM of
479+
Just lb -> return (lb:lbs,Left arg)
480+
_ -> return (lbs,Left arg)
481+
([], _) ->
482+
error ($curLoc ++ "mkDisjointGroup: no arguments")
483+
) [] csT
484+
let funApp = mkApps fun newArgs
485+
case lbs of
486+
[] -> return (funApp, seen)
487+
_ -> return (Letrec lbs funApp, seen)
488+
489+
-- | Create a selector for the case-tree of the argument. If the argument is
490+
-- representable create a let-binding for the created selector, and return
491+
-- a variable reference to this let-binding. If the argument is not representable
492+
-- return the selector directly.
513493
disJointSelProj
514494
:: InScopeSet
515-
-> [Type]
516-
-- ^ Types of the arguments
517-
-> CaseTree [Term]
518-
-- The case-tree of arguments
519-
-> NormalizeSession (Maybe LetBinding,[Term])
520-
disJointSelProj _ _ (Leaf []) = return (Nothing,[])
521-
disJointSelProj inScope argTys cs = do
522-
tcm <- Lens.view tcCache
523-
tupTcm <- Lens.view tupleTcCache
524-
let maxIndex = length argTys - 1
525-
css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
526-
(untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
527-
let untranCs = map (css!!) (map fst untran)
528-
untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
529-
untran untranCs
530-
(lbM,projs) <- case tran of
531-
[] -> return (Nothing,[])
532-
[(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)])
533-
tys -> do
534-
let m = length tys
535-
(tyIxs,tys') = unzip tys
536-
tupTy = mkBigTupTy tcm tupTcm tys'
537-
cs' = fmap (\es -> map (es !!) tyIxs) cs
538-
djCase = genCase tcm tupTcm tupTy tys' cs'
539-
scrutId <- mkInternalVar inScope "tupIn" tupTy
540-
projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1]
541-
return (Just (scrutId,djCase),projections)
542-
let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
543-
544-
return (lbM,selProjs)
545-
where
546-
tranOrUnTran _ [] projs = projs
547-
tranOrUnTran _ sels [] = map snd sels
548-
tranOrUnTran n ((ut,s):uts) (p:projs)
549-
| n == ut = s : tranOrUnTran (n+1) uts (p:projs)
550-
| otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
495+
-> Type
496+
-- ^ Types of the argument
497+
-> CaseTree Term
498+
-- The case-tree of argument
499+
-> NormalizeSession (Maybe LetBinding,Term)
500+
disJointSelProj inScope argTy cs = do
501+
let sel = genCase argTy cs
502+
untran <- isUntranslatableType False argTy
503+
case untran of
504+
True -> return (Nothing, sel)
505+
False -> do
506+
argId <- mkInternalVar inScope "djArg" argTy
507+
return (Just (argId,sel), Var argId)
551508

552509
-- | Arguments are shared between invocations if:
553510
--
@@ -579,30 +536,15 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
579536
_ -> False
580537
isProof _ = False
581538

582-
-- | Create a list of arguments given a map of positions to common arguments,
583-
-- and a list of arguments
584-
mkDJArgs :: Int -- ^ Current position
585-
-> [(Int,Either Term Type)] -- ^ map from position to common argument
586-
-> [Term] -- ^ (projections for) distinct arguments
587-
-> [Either Term Type]
588-
mkDJArgs _ cms [] = map snd cms
589-
mkDJArgs _ [] uncms = map Left uncms
590-
mkDJArgs n ((m,x):cms) (y:uncms)
591-
| n == m = x : mkDJArgs (n+1) cms (y:uncms)
592-
| otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms
593-
594539
-- | Create a case-expression that selects between the distinct arguments given
595540
-- a case-tree
596-
genCase :: TyConMap
597-
-> IntMap TyConName
598-
-> Type -- ^ Type of the alternatives
599-
-> [Type] -- ^ Types of the arguments
600-
-> CaseTree [Term] -- ^ CaseTree of arguments
541+
genCase :: Type -- ^ Types of the arguments
542+
-> CaseTree Term -- ^ CaseTree of arguments
601543
-> Term
602-
genCase tcm tupTcm ty argTys = go
544+
genCase ty = go
603545
where
604-
go (Leaf tms) =
605-
mkBigTupTm tcm tupTcm (List.zipEqual argTys tms)
546+
go (Leaf tm) =
547+
tm
606548

607549
go (LB lb ct) =
608550
Letrec lb (go ct)
@@ -617,68 +559,6 @@ genCase tcm tupTcm ty argTys = go
617559
go (Branch scrut pats) =
618560
Case scrut ty (map (second go) pats)
619561

620-
-- | Lookup the TyConName and DataCon for a tuple of size n
621-
findTup :: TyConMap -> IntMap TyConName -> Int -> (TyConName,DataCon)
622-
findTup tcm tupTcm n =
623-
Maybe.fromMaybe (error ("Cannot build " <> show n <> "-tuble")) $ do
624-
tupTcNm <- IntMap.lookup n tupTcm
625-
tupTc <- UniqMap.lookup tupTcNm tcm
626-
tupDc <- Maybe.listToMaybe (tyConDataCons tupTc)
627-
return (tupTcNm,tupDc)
628-
629-
mkBigTupTm :: TyConMap -> IntMap TyConName -> [(Type,Term)] -> Term
630-
mkBigTupTm tcm tupTcm args = snd $ mkBigTup tcm tupTcm args
631-
632-
mkSmallTup,mkBigTup :: TyConMap -> IntMap TyConName -> [(Type,Term)] -> (Type,Term)
633-
mkSmallTup _ _ [] = error $ $curLoc ++ "mkSmallTup: Can't create 0-tuple"
634-
mkSmallTup _ _ [(ty,tm)] = (ty,tm)
635-
mkSmallTup tcm tupTcm args = (ty,tm)
636-
where
637-
(argTys,tms) = unzip args
638-
(tupTcNm,tupDc) = findTup tcm tupTcm (length args)
639-
tm = mkApps (Data tupDc) (map Right argTys ++ map Left tms)
640-
ty = mkTyConApp tupTcNm argTys
641-
642-
mkBigTup tcm tupTcm = mkChunkified (mkSmallTup tcm tupTcm)
643-
644-
mkSmallTupTy,mkBigTupTy
645-
:: TyConMap
646-
-> IntMap TyConName
647-
-> [Type]
648-
-> Type
649-
mkSmallTupTy _ _ [] = error $ $curLoc ++ "mkSmallTupTy: Can't create 0-tuple"
650-
mkSmallTupTy _ _ [ty] = ty
651-
mkSmallTupTy tcm tupTcm tys = mkTyConApp tupTcNm tys
652-
where
653-
m = length tys
654-
(tupTcNm,_) = findTup tcm tupTcm m
655-
656-
mkBigTupTy tcm tupTcm = mkChunkified (mkSmallTupTy tcm tupTcm)
657-
658-
mkSmallTupSelector,mkBigTupSelector
659-
:: MonadUnique m
660-
=> InScopeSet
661-
-> TyConMap
662-
-> IntMap TyConName
663-
-> Term
664-
-> [Type]
665-
-> Int
666-
-> m Term
667-
mkSmallTupSelector _ _ _ scrut [_] 0 = return scrut
668-
mkSmallTupSelector _ _ _ _ [_] n = error $ $curLoc ++ "mkSmallTupSelector called with one type, but to select " ++ show n
669-
mkSmallTupSelector inScope tcm _ scrut _ n = mkSelectorCase ($curLoc ++ "mkSmallTupSelector") inScope tcm scrut 1 n
670-
671-
mkBigTupSelector inScope tcm tupTcm scrut tys n = go (chunkify tys)
672-
where
673-
go [_] = mkSmallTupSelector inScope tcm tupTcm scrut tys n
674-
go tyss = do
675-
let (nOuter,nInner) = divMod n mAX_TUPLE_SIZE
676-
tyss' = map (mkSmallTupTy tcm tupTcm) tyss
677-
outer <- mkSmallTupSelector inScope tcm tupTcm scrut tyss' nOuter
678-
inner <- mkSmallTupSelector inScope tcm tupTcm outer (tyss List.!! nOuter) nInner
679-
return inner
680-
681-
682562
-- | Determine if a term in a function position is interesting to lift out of
683563
-- of a case-expression.
684564
--

tests/Main.hs

+1
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,7 @@ runClashTest = defaultMain $ clashTestRoot
788788
, outputTest "T2510" def{hdlTargets=[VHDL], clashFlags=["-DNOINLINE=OPAQUE"]}
789789
#endif
790790
, outputTest "T2542" def{hdlTargets=[VHDL]}
791+
, runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]}
791792
] <>
792793
if compiledWith == Cabal then
793794
-- This tests fails without environment files present, which are only

0 commit comments

Comments
 (0)