@@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC
38
38
) where
39
39
40
40
import Control.Concurrent.Supply (splitSupply )
41
+ #if !MIN_VERSION_base(4,18,0)
42
+ import Control.Applicative (liftA2 )
43
+ #endif
41
44
import Control.Lens ((^.) , _1 )
42
45
import qualified Control.Lens as Lens
43
46
import qualified Control.Monad as Monad
@@ -72,21 +75,22 @@ import Constants (mAX_TUPLE_SIZE)
72
75
#endif
73
76
74
77
-- internal
75
- import Clash.Core.DataCon (DataCon )
78
+ import Clash.Core.DataCon (DataCon )
76
79
import Clash.Core.Evaluator.Types (whnf' )
77
80
import Clash.Core.FreeVars
78
81
(termFreeVars' , typeFreeVars' , localVarsDoNotOccurIn )
79
82
import Clash.Core.HasType
80
83
import Clash.Core.Literal (Literal (.. ))
81
84
import Clash.Core.Name (nameOcc )
85
+ import Clash.Core.Pretty (showPpr )
82
86
import Clash.Core.Term
83
87
( Alt , LetBinding , Pat (.. ), PrimInfo (.. ), Term (.. ), TickInfo (.. )
84
88
, collectArgs , collectArgsTicks , mkApps , mkTicks , patIds , stripTicks )
85
89
import Clash.Core.TyCon (TyConMap , TyConName , tyConDataCons )
86
90
import Clash.Core.Type
87
91
(Type , TypeView (.. ), isPolyFunTy , mkTyConApp , splitFunForallTy , tyView )
88
92
import Clash.Core.Util (mkInternalVar , mkSelectorCase , sccLetBindings )
89
- import Clash.Core.Var (isGlobalId , isLocalId , varName )
93
+ import Clash.Core.Var (Id , isGlobalId , isLocalId , varName )
90
94
import Clash.Core.VarEnv
91
95
( InScopeSet , elemInScopeSet , extendInScopeSet , extendInScopeSetList
92
96
, notElemInScopeSet , unionInScope )
@@ -138,6 +142,24 @@ import qualified GHC.Prim
138
142
-- B -> f_out
139
143
-- C -> h x
140
144
-- @
145
+ --
146
+ -- Though that's a lie. It actually converts it into:
147
+ --
148
+ -- @
149
+ -- let tupIn = case x of {A -> (3,y); B -> (x,x)}
150
+ -- f_arg0 = case tupIn of (l,_) -> l
151
+ -- f_arg1 = case tupIn of (_,r) -> r
152
+ -- f_out = f f_arg0 f_arg1
153
+ -- in case x of
154
+ -- A -> f_out
155
+ -- B -> f_out
156
+ -- C -> h x
157
+ -- @
158
+ --
159
+ -- In order to share the expression that's in the subject of the case expression,
160
+ -- and to share the /decoder/ circuit that logic synthesis will create to map the
161
+ -- bits of the subject expression to the bits needed to make the selection in the
162
+ -- multiplexer.
141
163
disjointExpressionConsolidation :: HasCallStack => NormRewrite
142
164
disjointExpressionConsolidation ctx@ (TransformContext isCtx _) e@ (Case _scrut _ty _alts@ (_: _: _)) = do
143
165
-- Collect all (the applications of) global binders (and certain primitives)
@@ -150,11 +172,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
150
172
else do
151
173
-- For every to-lift expression create (the generalization of):
152
174
--
153
- -- let fargs = case x of {A -> (3,y); B -> (x,x)}
154
- -- in f (fst fargs) (snd fargs)
175
+ -- let djArg0 = case x of {A -> 3; B -> x}
176
+ -- djArg1 = case x of {A -> y; B -> x}
177
+ -- in f djArg0 djArg1
155
178
--
156
- -- the let-expression is not created when `f` has only one (selectable)
157
- -- argument
179
+ -- if an argument is non-representable, the case-expression is inlined,
180
+ -- and no let-binding will be created for it.
158
181
--
159
182
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
160
183
-- whether expressions reference variables from the context, or
@@ -255,6 +278,13 @@ data CaseTree a
255
278
| Branch Term [(Pat ,CaseTree a )]
256
279
deriving (Eq ,Show ,Functor ,Foldable )
257
280
281
+ instance Applicative CaseTree where
282
+ pure = Leaf
283
+ liftA2 f (Leaf a) (Leaf b) = Leaf (f a b)
284
+ liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2)
285
+ liftA2 f (Branch scrut alts1) (Branch _ alts2) = Branch scrut (zipWith (\ (p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2)
286
+ liftA2 _ _ _ = error " bad"
287
+
258
288
-- | Test if a 'CaseTree' collected from an expression indicates that
259
289
-- application of a global binder is disjoint: occur in separate branches of a
260
290
-- case-expression.
@@ -269,18 +299,6 @@ isDisjoint ct = go ct
269
299
go (Branch _ [(_,x)]) = go x
270
300
go b@ (Branch _ (_: _: _)) = allEqual (map Either. rights (Foldable. toList b))
271
301
272
- -- Remove empty branches from a 'CaseTree'
273
- removeEmpty :: Eq a => CaseTree [a ] -> CaseTree [a ]
274
- removeEmpty l@ (Leaf _) = l
275
- removeEmpty (LB lb ct) =
276
- case removeEmpty ct of
277
- Leaf [] -> Leaf []
278
- ct' -> LB lb ct'
279
- removeEmpty (Branch s bs) =
280
- case filter ((/= (Leaf [] )) . snd ) (map (second removeEmpty) bs) of
281
- [] -> Leaf []
282
- bs' -> Branch s bs'
283
-
284
302
-- | Test if all elements in a list are equal to each other.
285
303
allEqual :: Eq a => [a ] -> Bool
286
304
allEqual [] = True
@@ -464,90 +482,89 @@ collectGlobalsLbs is0 substitution seen lbs = do
464
482
-- function-position\", return a let-expression: where the let-binding holds
465
483
-- a case-expression selecting between the distinct arguments of the case-tree,
466
484
-- and the body is an application of the term applied to the shared arguments of
467
- -- the case tree, and projections of let-binding corresponding to the distinct
468
- -- argument positions.
485
+ -- the case tree, and variable references to the created let-bindings.
486
+ --
487
+ -- case-expressions whose type would be non-representable are not let-bound,
488
+ -- but occur directly in the argument position of the application in the body
489
+ -- of the let-expression.
469
490
mkDisjointGroup
470
491
:: InScopeSet
471
492
-- ^ Variables in scope at the very top of the case-tree, i.e., the original
472
493
-- expression
473
- -> (Term ,([Term ],CaseTree [( Either Term Type ) ]))
494
+ -> (Term ,([Term ],CaseTree [Either Term Type ]))
474
495
-- ^ Case-tree of arguments belonging to the applied term.
475
496
-> NormalizeSession (Term ,[Term ])
476
497
mkDisjointGroup inScope (fun,(seen,cs)) = do
477
498
tcm <- Lens. view tcCache
478
- let argss = Foldable. toList cs
479
- argssT = zip [0 .. ] (List. transpose argss)
480
- (sharedT,distinctT) = List. partition (areShared tcm inScope . fmap (first stripTicks) . snd ) argssT
481
- -- TODO: find a better solution than "maybe undefined fst . uncons"
482
- shared = map (second (maybe (error " impossible" ) fst . List. uncons)) sharedT
483
- distinct = map (Either. lefts) (List. transpose (map snd distinctT))
484
- cs' = fmap (zip [0 .. ]) cs
485
- cs'' = removeEmpty
486
- $ fmap (Either. lefts . map snd )
487
- (if null shared
488
- then cs'
489
- else fmap (filter (`notElem` shared)) cs')
490
- (distinctCaseM,distinctProjections) <- case distinct of
491
- -- only shared arguments: do nothing.
492
- [] -> return (Nothing ,[] )
493
- -- Create selectors and projections
494
- (uc: _) -> do
495
- let argTys = map (inferCoreTypeOf tcm) uc
496
- disJointSelProj inScope argTys cs''
497
- let newArgs = mkDJArgs 0 shared distinctProjections
498
- case distinctCaseM of
499
- Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
500
- Nothing -> return (mkApps fun newArgs, seen)
501
-
502
- -- | Create a single selector for all the representable distinct arguments by
503
- -- selecting between tuples. This selector is only ('Just') created when the
504
- -- number of representable uncommmon arguments is larger than one, otherwise it
505
- -- is not ('Nothing').
506
- --
507
- -- It also returns:
508
- --
509
- -- * For all the non-representable distinct arguments: a selector
510
- -- * For all the representable distinct arguments: a projection out of the tuple
511
- -- created by the larger selector. If this larger selector does not exist, a
512
- -- single selector is created for the single representable distinct argument.
499
+ let argLen = case Foldable. toList cs of
500
+ [] -> error " mkDisjointGroup: no disjoint groups"
501
+ l: _ -> length l
502
+ csT :: [CaseTree (Either Term Type )] -- "Transposed" 'CaseTree [Either Term Type]'
503
+ csT = map (\ i -> fmap (!! i) cs) [0 .. (argLen- 1 )] -- sequenceA does the wrong thing
504
+ (lbs,newArgs) <- List. mapAccumLM (\ lbs c -> do
505
+ let cL = Foldable. toList c
506
+ case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
507
+ (Right ty: _, True ) ->
508
+ return (lbs,Right ty)
509
+ (Right _: _, False ) ->
510
+ error (" mkDisjointGroup: non-equal type arguments: " <>
511
+ showPpr (Either. rights cL))
512
+ (Left tm: _, True ) ->
513
+ return (lbs,Left tm)
514
+ (Left tm: _, False ) -> do
515
+ let ty = inferCoreTypeOf tcm tm
516
+ let err = error (" mkDisjointGroup: mixed type and term arguments: " <> show cL)
517
+ (lbM,arg) <- disJointSelProj inScope ty (Either. fromLeft err <$> c)
518
+ case lbM of
519
+ Just lb -> return (lb: lbs,Left arg)
520
+ _ -> return (lbs,Left arg)
521
+ ([] , _) ->
522
+ error " mkDisjointGroup: no arguments"
523
+ ) [] csT
524
+ let funApp = mkApps fun newArgs
525
+ tupTcm <- Lens. view tupleTcCache
526
+ case lbs of
527
+ [] ->
528
+ return (funApp, seen)
529
+ [(v,(ty,ct))] -> do
530
+ let e = genCase tcm tupTcm ty [ty] (fmap (: [] ) ct)
531
+ return (Letrec [(v,e)] funApp, seen)
532
+ _ -> do
533
+ let (vs,zs) = unzip lbs
534
+ csL :: [CaseTree Term ]
535
+ (tys,csL) = unzip zs
536
+ csLT :: CaseTree [Term ]
537
+ csLT = fmap ($ [] ) (foldr1 (liftA2 (.) ) (fmap (fmap (:) ) csL))
538
+ bigTupTy = mkBigTupTy tcm tupTcm tys
539
+ ct = genCase tcm tupTcm bigTupTy tys csLT
540
+ tupIn <- mkInternalVar inScope " tupIn" bigTupTy
541
+ projections <-
542
+ Monad. zipWithM (\ v n ->
543
+ (v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n)
544
+ vs [0 .. ]
545
+ return (Letrec ((tupIn,ct): projections) funApp, seen)
546
+
547
+ -- | Create a selector for the case-tree of the argument. If the argument is
548
+ -- representable create a let-binding for the created selector, and return
549
+ -- a variable reference to this let-binding. If the argument is not representable
550
+ -- return the selector directly.
513
551
disJointSelProj
514
552
:: InScopeSet
515
- -> [Type ]
516
- -- ^ Types of the arguments
517
- -> CaseTree [Term ]
518
- -- The case-tree of arguments
519
- -> NormalizeSession (Maybe LetBinding ,[Term ])
520
- disJointSelProj _ _ (Leaf [] ) = return (Nothing ,[] )
521
- disJointSelProj inScope argTys cs = do
522
- tcm <- Lens. view tcCache
553
+ -> Type
554
+ -- ^ Types of the argument
555
+ -> CaseTree Term
556
+ -- The case-tree of argument
557
+ -> NormalizeSession (Maybe (Id , (Type , CaseTree Term )),Term )
558
+ disJointSelProj inScope argTy cs = do
559
+ tcm <- Lens. view tcCache
523
560
tupTcm <- Lens. view tupleTcCache
524
- let maxIndex = length argTys - 1
525
- css = map (\ i -> fmap ((: [] ) . (!! i)) cs) [0 .. maxIndex]
526
- (untran,tran) <- List. partitionM (isUntranslatableType False . snd ) (zip [0 .. ] argTys)
527
- let untranCs = map (css!! ) (map fst untran)
528
- untranSels = zipWith (\ (_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
529
- untran untranCs
530
- (lbM,projs) <- case tran of
531
- [] -> return (Nothing ,[] )
532
- [(i,ty)] -> return (Nothing ,[genCase tcm tupTcm ty [ty] (css!! i)])
533
- tys -> do
534
- let m = length tys
535
- (tyIxs,tys') = unzip tys
536
- tupTy = mkBigTupTy tcm tupTcm tys'
537
- cs' = fmap (\ es -> map (es !! ) tyIxs) cs
538
- djCase = genCase tcm tupTcm tupTy tys' cs'
539
- scrutId <- mkInternalVar inScope " tupIn" tupTy
540
- projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0 .. m- 1 ]
541
- return (Just (scrutId,djCase),projections)
542
- let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
543
-
544
- return (lbM,selProjs)
545
- where
546
- tranOrUnTran _ [] projs = projs
547
- tranOrUnTran _ sels [] = map snd sels
548
- tranOrUnTran n ((ut,s): uts) (p: projs)
549
- | n == ut = s : tranOrUnTran (n+ 1 ) uts (p: projs)
550
- | otherwise = p : tranOrUnTran (n+ 1 ) ((ut,s): uts) projs
561
+ let sel = genCase tcm tupTcm argTy [argTy] (fmap (: [] ) cs)
562
+ untran <- isUntranslatableType False argTy
563
+ case untran of
564
+ True -> return (Nothing , sel)
565
+ False -> do
566
+ argId <- mkInternalVar inScope " djArg" argTy
567
+ return (Just (argId,(argTy,cs)), Var argId)
551
568
552
569
-- | Arguments are shared between invocations if:
553
570
--
@@ -579,18 +596,6 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
579
596
_ -> False
580
597
isProof _ = False
581
598
582
- -- | Create a list of arguments given a map of positions to common arguments,
583
- -- and a list of arguments
584
- mkDJArgs :: Int -- ^ Current position
585
- -> [(Int ,Either Term Type )] -- ^ map from position to common argument
586
- -> [Term ] -- ^ (projections for) distinct arguments
587
- -> [Either Term Type ]
588
- mkDJArgs _ cms [] = map snd cms
589
- mkDJArgs _ [] uncms = map Left uncms
590
- mkDJArgs n ((m,x): cms) (y: uncms)
591
- | n == m = x : mkDJArgs (n+ 1 ) cms (y: uncms)
592
- | otherwise = Left y : mkDJArgs (n+ 1 ) ((m,x): cms) uncms
593
-
594
599
-- | Create a case-expression that selects between the distinct arguments given
595
600
-- a case-tree
596
601
genCase :: TyConMap
0 commit comments