Skip to content

Commit 6954600

Browse files
committed
Refactor DEC transformation
The previous code was a big mess where we partioned arguments into shared and non-shared and then filtered the case-tree depending on whether they were part of the shared arguments or not. But then with the normalisation of type arguments, the second filter did not work properly. This then resulted in shared arguments becoming part of the tuple in the alternatives of the case-expression for the non-shared arguments. The new code is also more robust in the sense that shared and non-shared arguments no longer need to be partioned (shared occur left-most, non-shared occur right-most). They can now be interleaved. The old code would also generate bad Core if ever type and term arguments occured interleaved, this is no longer the case for the new code. Fixes #2628
1 parent 0aa341a commit 6954600

File tree

4 files changed

+266
-103
lines changed

4 files changed

+266
-103
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
FIXED: Clash errors our in netlist-generation stage out when a polymorphic function is applied to type X in one alternative of a case-statement and applied to a new-type wrapper of type X in a different alternative. See [#2828](https://github.com/clash-lang/clash-compiler/issues/2628)

clash-lib/src/Clash/Normalize/Transformations/DEC.hs

+108-103
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC
3838
) where
3939

4040
import Control.Concurrent.Supply (splitSupply)
41+
#if !MIN_VERSION_base(4,18,0)
42+
import Control.Applicative (liftA2)
43+
#endif
4144
import Control.Lens ((^.), _1)
4245
import qualified Control.Lens as Lens
4346
import qualified Control.Monad as Monad
@@ -72,21 +75,22 @@ import Constants (mAX_TUPLE_SIZE)
7275
#endif
7376

7477
-- internal
75-
import Clash.Core.DataCon (DataCon)
78+
import Clash.Core.DataCon (DataCon)
7679
import Clash.Core.Evaluator.Types (whnf')
7780
import Clash.Core.FreeVars
7881
(termFreeVars', typeFreeVars', localVarsDoNotOccurIn)
7982
import Clash.Core.HasType
8083
import Clash.Core.Literal (Literal(..))
8184
import Clash.Core.Name (nameOcc)
85+
import Clash.Core.Pretty (showPpr)
8286
import Clash.Core.Term
8387
( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..)
8488
, collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks)
8589
import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons)
8690
import Clash.Core.Type
8791
(Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView)
8892
import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings)
89-
import Clash.Core.Var (isGlobalId, isLocalId, varName)
93+
import Clash.Core.Var (Id, isGlobalId, isLocalId, varName)
9094
import Clash.Core.VarEnv
9195
( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList
9296
, notElemInScopeSet, unionInScope)
@@ -138,6 +142,24 @@ import qualified GHC.Prim
138142
-- B -> f_out
139143
-- C -> h x
140144
-- @
145+
--
146+
-- Though that's a lie. It actually converts it into:
147+
--
148+
-- @
149+
-- let tupIn = case x of {A -> (3,y); B -> (x,x)}
150+
-- f_arg0 = case tupIn of (l,_) -> l
151+
-- f_arg1 = case tupIn of (_,r) -> r
152+
-- f_out = f f_arg0 f_arg1
153+
-- in case x of
154+
-- A -> f_out
155+
-- B -> f_out
156+
-- C -> h x
157+
-- @
158+
--
159+
-- In order to share the expression that's in the subject of the case expression,
160+
-- and to share the /decoder/ circuit that logic synthesis will create to map the
161+
-- bits of the subject expression to the bits needed to make the selection in the
162+
-- multiplexer.
141163
disjointExpressionConsolidation :: HasCallStack => NormRewrite
142164
disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _ty _alts@(_:_:_)) = do
143165
-- Collect all (the applications of) global binders (and certain primitives)
@@ -150,11 +172,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
150172
else do
151173
-- For every to-lift expression create (the generalization of):
152174
--
153-
-- let fargs = case x of {A -> (3,y); B -> (x,x)}
154-
-- in f (fst fargs) (snd fargs)
175+
-- let djArg0 = case x of {A -> 3; B -> x}
176+
-- djArg1 = case x of {A -> y; B -> x}
177+
-- in f djArg0 djArg1
155178
--
156-
-- the let-expression is not created when `f` has only one (selectable)
157-
-- argument
179+
-- if an argument is non-representable, the case-expression is inlined,
180+
-- and no let-binding will be created for it.
158181
--
159182
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
160183
-- whether expressions reference variables from the context, or
@@ -255,6 +278,13 @@ data CaseTree a
255278
| Branch Term [(Pat,CaseTree a)]
256279
deriving (Eq,Show,Functor,Foldable)
257280

281+
instance Applicative CaseTree where
282+
pure = Leaf
283+
liftA2 f (Leaf a) (Leaf b) = Leaf (f a b)
284+
liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2)
285+
liftA2 f (Branch scrut alts1) (Branch _ alts2) = Branch scrut (zipWith (\(p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2)
286+
liftA2 _ _ _ = error "bad"
287+
258288
-- | Test if a 'CaseTree' collected from an expression indicates that
259289
-- application of a global binder is disjoint: occur in separate branches of a
260290
-- case-expression.
@@ -269,18 +299,6 @@ isDisjoint ct = go ct
269299
go (Branch _ [(_,x)]) = go x
270300
go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b))
271301

272-
-- Remove empty branches from a 'CaseTree'
273-
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
274-
removeEmpty l@(Leaf _) = l
275-
removeEmpty (LB lb ct) =
276-
case removeEmpty ct of
277-
Leaf [] -> Leaf []
278-
ct' -> LB lb ct'
279-
removeEmpty (Branch s bs) =
280-
case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of
281-
[] -> Leaf []
282-
bs' -> Branch s bs'
283-
284302
-- | Test if all elements in a list are equal to each other.
285303
allEqual :: Eq a => [a] -> Bool
286304
allEqual [] = True
@@ -464,90 +482,89 @@ collectGlobalsLbs is0 substitution seen lbs = do
464482
-- function-position\", return a let-expression: where the let-binding holds
465483
-- a case-expression selecting between the distinct arguments of the case-tree,
466484
-- and the body is an application of the term applied to the shared arguments of
467-
-- the case tree, and projections of let-binding corresponding to the distinct
468-
-- argument positions.
485+
-- the case tree, and variable references to the created let-bindings.
486+
--
487+
-- case-expressions whose type would be non-representable are not let-bound,
488+
-- but occur directly in the argument position of the application in the body
489+
-- of the let-expression.
469490
mkDisjointGroup
470491
:: InScopeSet
471492
-- ^ Variables in scope at the very top of the case-tree, i.e., the original
472493
-- expression
473-
-> (Term,([Term],CaseTree [(Either Term Type)]))
494+
-> (Term,([Term],CaseTree [Either Term Type]))
474495
-- ^ Case-tree of arguments belonging to the applied term.
475496
-> NormalizeSession (Term,[Term])
476497
mkDisjointGroup inScope (fun,(seen,cs)) = do
477498
tcm <- Lens.view tcCache
478-
let argss = Foldable.toList cs
479-
argssT = zip [0..] (List.transpose argss)
480-
(sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT
481-
-- TODO: find a better solution than "maybe undefined fst . uncons"
482-
shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT
483-
distinct = map (Either.lefts) (List.transpose (map snd distinctT))
484-
cs' = fmap (zip [0..]) cs
485-
cs'' = removeEmpty
486-
$ fmap (Either.lefts . map snd)
487-
(if null shared
488-
then cs'
489-
else fmap (filter (`notElem` shared)) cs')
490-
(distinctCaseM,distinctProjections) <- case distinct of
491-
-- only shared arguments: do nothing.
492-
[] -> return (Nothing,[])
493-
-- Create selectors and projections
494-
(uc:_) -> do
495-
let argTys = map (inferCoreTypeOf tcm) uc
496-
disJointSelProj inScope argTys cs''
497-
let newArgs = mkDJArgs 0 shared distinctProjections
498-
case distinctCaseM of
499-
Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
500-
Nothing -> return (mkApps fun newArgs, seen)
501-
502-
-- | Create a single selector for all the representable distinct arguments by
503-
-- selecting between tuples. This selector is only ('Just') created when the
504-
-- number of representable uncommmon arguments is larger than one, otherwise it
505-
-- is not ('Nothing').
506-
--
507-
-- It also returns:
508-
--
509-
-- * For all the non-representable distinct arguments: a selector
510-
-- * For all the representable distinct arguments: a projection out of the tuple
511-
-- created by the larger selector. If this larger selector does not exist, a
512-
-- single selector is created for the single representable distinct argument.
499+
let argLen = case Foldable.toList cs of
500+
[] -> error "mkDisjointGroup: no disjoint groups"
501+
l:_ -> length l
502+
csT :: [CaseTree (Either Term Type)] -- "Transposed" 'CaseTree [Either Term Type]'
503+
csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] -- sequenceA does the wrong thing
504+
(lbs,newArgs) <- List.mapAccumLM (\lbs c -> do
505+
let cL = Foldable.toList c
506+
case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
507+
(Right ty:_, True) ->
508+
return (lbs,Right ty)
509+
(Right _:_, False) ->
510+
error ("mkDisjointGroup: non-equal type arguments: " <>
511+
showPpr (Either.rights cL))
512+
(Left tm:_, True) ->
513+
return (lbs,Left tm)
514+
(Left tm:_, False) -> do
515+
let ty = inferCoreTypeOf tcm tm
516+
let err = error ("mkDisjointGroup: mixed type and term arguments: " <> show cL)
517+
(lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c)
518+
case lbM of
519+
Just lb -> return (lb:lbs,Left arg)
520+
_ -> return (lbs,Left arg)
521+
([], _) ->
522+
error "mkDisjointGroup: no arguments"
523+
) [] csT
524+
let funApp = mkApps fun newArgs
525+
tupTcm <- Lens.view tupleTcCache
526+
case lbs of
527+
[] ->
528+
return (funApp, seen)
529+
[(v,(ty,ct))] -> do
530+
let e = genCase tcm tupTcm ty [ty] (fmap (:[]) ct)
531+
return (Letrec [(v,e)] funApp, seen)
532+
_ -> do
533+
let (vs,zs) = unzip lbs
534+
csL :: [CaseTree Term]
535+
(tys,csL) = unzip zs
536+
csLT :: CaseTree [Term]
537+
csLT = fmap ($ []) (foldr1 (liftA2 (.)) (fmap (fmap (:)) csL))
538+
bigTupTy = mkBigTupTy tcm tupTcm tys
539+
ct = genCase tcm tupTcm bigTupTy tys csLT
540+
tupIn <- mkInternalVar inScope "tupIn" bigTupTy
541+
projections <-
542+
Monad.zipWithM (\v n ->
543+
(v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n)
544+
vs [0..]
545+
return (Letrec ((tupIn,ct):projections) funApp, seen)
546+
547+
-- | Create a selector for the case-tree of the argument. If the argument is
548+
-- representable create a let-binding for the created selector, and return
549+
-- a variable reference to this let-binding. If the argument is not representable
550+
-- return the selector directly.
513551
disJointSelProj
514552
:: InScopeSet
515-
-> [Type]
516-
-- ^ Types of the arguments
517-
-> CaseTree [Term]
518-
-- The case-tree of arguments
519-
-> NormalizeSession (Maybe LetBinding,[Term])
520-
disJointSelProj _ _ (Leaf []) = return (Nothing,[])
521-
disJointSelProj inScope argTys cs = do
522-
tcm <- Lens.view tcCache
553+
-> Type
554+
-- ^ Types of the argument
555+
-> CaseTree Term
556+
-- The case-tree of argument
557+
-> NormalizeSession (Maybe (Id, (Type, CaseTree Term)),Term)
558+
disJointSelProj inScope argTy cs = do
559+
tcm <- Lens.view tcCache
523560
tupTcm <- Lens.view tupleTcCache
524-
let maxIndex = length argTys - 1
525-
css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
526-
(untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
527-
let untranCs = map (css!!) (map fst untran)
528-
untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
529-
untran untranCs
530-
(lbM,projs) <- case tran of
531-
[] -> return (Nothing,[])
532-
[(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)])
533-
tys -> do
534-
let m = length tys
535-
(tyIxs,tys') = unzip tys
536-
tupTy = mkBigTupTy tcm tupTcm tys'
537-
cs' = fmap (\es -> map (es !!) tyIxs) cs
538-
djCase = genCase tcm tupTcm tupTy tys' cs'
539-
scrutId <- mkInternalVar inScope "tupIn" tupTy
540-
projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1]
541-
return (Just (scrutId,djCase),projections)
542-
let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
543-
544-
return (lbM,selProjs)
545-
where
546-
tranOrUnTran _ [] projs = projs
547-
tranOrUnTran _ sels [] = map snd sels
548-
tranOrUnTran n ((ut,s):uts) (p:projs)
549-
| n == ut = s : tranOrUnTran (n+1) uts (p:projs)
550-
| otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
561+
let sel = genCase tcm tupTcm argTy [argTy] (fmap (:[]) cs)
562+
untran <- isUntranslatableType False argTy
563+
case untran of
564+
True -> return (Nothing, sel)
565+
False -> do
566+
argId <- mkInternalVar inScope "djArg" argTy
567+
return (Just (argId,(argTy,cs)), Var argId)
551568

552569
-- | Arguments are shared between invocations if:
553570
--
@@ -579,18 +596,6 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
579596
_ -> False
580597
isProof _ = False
581598

582-
-- | Create a list of arguments given a map of positions to common arguments,
583-
-- and a list of arguments
584-
mkDJArgs :: Int -- ^ Current position
585-
-> [(Int,Either Term Type)] -- ^ map from position to common argument
586-
-> [Term] -- ^ (projections for) distinct arguments
587-
-> [Either Term Type]
588-
mkDJArgs _ cms [] = map snd cms
589-
mkDJArgs _ [] uncms = map Left uncms
590-
mkDJArgs n ((m,x):cms) (y:uncms)
591-
| n == m = x : mkDJArgs (n+1) cms (y:uncms)
592-
| otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms
593-
594599
-- | Create a case-expression that selects between the distinct arguments given
595600
-- a case-tree
596601
genCase :: TyConMap

tests/Main.hs

+1
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,7 @@ runClashTest = defaultMain $ clashTestRoot
802802
, outputTest "T2542" def{hdlTargets=[VHDL]}
803803
, runTest "T2593" def{hdlSim=[]}
804804
, runTest "T2623CaseConFVs" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]}
805+
, runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]}
805806
] <>
806807
if compiledWith == Cabal then
807808
-- This tests fails without environment files present, which are only

0 commit comments

Comments
 (0)