1
1
{-|
2
2
Copyright : (C) 2015-2016, University of Twente,
3
- 2021-2022 , QBayLogic B.V.
3
+ 2021-2024 , QBayLogic B.V.
4
4
2022, LumiGuide Fietsdetectie B.V.
5
5
License : BSD2 (see the file LICENSE)
6
6
Maintainer : QBayLogic B.V. <[email protected] >
@@ -46,8 +46,6 @@ import Data.Coerce (coerce)
46
46
import qualified Data.Either as Either
47
47
import qualified Data.Foldable as Foldable
48
48
import qualified Data.Graph as Graph
49
- import Data.IntMap.Strict (IntMap )
50
- import qualified Data.IntMap.Strict as IntMap
51
49
import qualified Data.IntSet as IntSet
52
50
import qualified Data.List as List
53
51
import qualified Data.List.Extra as List
@@ -57,45 +55,32 @@ import Data.Monoid (All(..))
57
55
import qualified Data.Text as Text
58
56
import GHC.Stack (HasCallStack )
59
57
60
- #if MIN_VERSION_ghc(9,6,0)
61
- import GHC.Core.Make (chunkify , mkChunkified )
62
- #else
63
- import GHC.Hs.Utils (chunkify , mkChunkified )
64
- #endif
65
-
66
- #if MIN_VERSION_ghc(9,0,0)
67
- import GHC.Settings.Constants (mAX_TUPLE_SIZE )
68
- #else
69
- import Constants (mAX_TUPLE_SIZE )
70
- #endif
71
-
72
58
-- internal
73
- import Clash.Core.DataCon (DataCon )
74
59
import Clash.Core.Evaluator.Types (whnf' )
75
60
import Clash.Core.FreeVars
76
61
(termFreeVars' , typeFreeVars' , localVarsDoNotOccurIn )
77
62
import Clash.Core.HasType
78
63
import Clash.Core.Literal (Literal (.. ))
79
64
import Clash.Core.Name (nameOcc )
65
+ import Clash.Core.Pretty (showPpr )
80
66
import Clash.Core.Term
81
67
( Alt , LetBinding , Pat (.. ), PrimInfo (.. ), Term (.. ), TickInfo (.. )
82
68
, collectArgs , collectArgsTicks , mkApps , mkTicks , patIds , stripTicks )
83
- import Clash.Core.TyCon (TyConMap , TyConName , tyConDataCons )
69
+ import Clash.Core.TyCon (TyConMap )
84
70
import Clash.Core.Type
85
- (Type , TypeView (.. ), isPolyFunTy , mkTyConApp , splitFunForallTy , tyView )
86
- import Clash.Core.Util (mkInternalVar , mkSelectorCase , sccLetBindings )
71
+ (Type , TypeView (.. ), isPolyFunTy , splitFunForallTy , tyView )
72
+ import Clash.Core.Util (mkInternalVar , sccLetBindings )
87
73
import Clash.Core.Var (isGlobalId , isLocalId , varName )
88
74
import Clash.Core.VarEnv
89
75
( InScopeSet , elemInScopeSet , extendInScopeSet , extendInScopeSetList
90
76
, notElemInScopeSet , unionInScope )
91
- import qualified Clash.Data.UniqMap as UniqMap
92
77
import Clash.Normalize.Transformations.Letrec (deadCode )
93
78
import Clash.Normalize.Types (NormRewrite , NormalizeSession )
94
79
import Clash.Rewrite.Combinators (bottomupR )
95
80
import Clash.Rewrite.Types
96
81
import Clash.Rewrite.Util (changed , isUntranslatableType )
97
82
import Clash.Rewrite.WorkFree (isConstant )
98
- import Clash.Util (MonadUnique , curLoc )
83
+ import Clash.Util (curLoc )
99
84
100
85
-- | This transformation lifts applications of global binders out of
101
86
-- alternatives of case-statements.
@@ -132,11 +117,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
132
117
else do
133
118
-- For every to-lift expression create (the generalization of):
134
119
--
135
- -- let fargs = case x of {A -> (3,y); B -> (x,x)}
136
- -- in f (fst fargs) (snd fargs)
120
+ -- let djArg0 = case x of {A -> 3; B -> x}
121
+ -- djArg1 = case x of {A -> y; B -> x}
122
+ -- in f djArg0 djArg1
137
123
--
138
- -- the let-expression is not created when `f` has only one (selectable)
139
- -- argument
124
+ -- if an argument is non-representable, the case-expression is inlined,
125
+ -- and no let-binding will be created for it.
140
126
--
141
127
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
142
128
-- whether expressions reference variables from the context, or
@@ -251,18 +237,6 @@ isDisjoint ct = go ct
251
237
go (Branch _ [(_,x)]) = go x
252
238
go b@ (Branch _ (_: _: _)) = allEqual (map Either. rights (Foldable. toList b))
253
239
254
- -- Remove empty branches from a 'CaseTree'
255
- removeEmpty :: Eq a => CaseTree [a ] -> CaseTree [a ]
256
- removeEmpty l@ (Leaf _) = l
257
- removeEmpty (LB lb ct) =
258
- case removeEmpty ct of
259
- Leaf [] -> Leaf []
260
- ct' -> LB lb ct'
261
- removeEmpty (Branch s bs) =
262
- case filter ((/= (Leaf [] )) . snd ) (map (second removeEmpty) bs) of
263
- [] -> Leaf []
264
- bs' -> Branch s bs'
265
-
266
240
-- | Test if all elements in a list are equal to each other.
267
241
allEqual :: Eq a => [a ] -> Bool
268
242
allEqual [] = True
@@ -464,8 +438,11 @@ collectGlobalsLbs is0 substitution seen lbs = do
464
438
-- function-position\", return a let-expression: where the let-binding holds
465
439
-- a case-expression selecting between the distinct arguments of the case-tree,
466
440
-- and the body is an application of the term applied to the shared arguments of
467
- -- the case tree, and projections of let-binding corresponding to the distinct
468
- -- argument positions.
441
+ -- the case tree, and variable references to the created let-bindings.
442
+ --
443
+ -- case-expressions whose type would be non-representable are not let-bound,
444
+ -- but occur directly in the argument position of the application in the body
445
+ -- of the let-expression.
469
446
mkDisjointGroup
470
447
:: InScopeSet
471
448
-- ^ Variables in scope at the very top of the case-tree, i.e., the original
@@ -475,79 +452,59 @@ mkDisjointGroup
475
452
-> NormalizeSession (Term ,[Term ])
476
453
mkDisjointGroup inScope (fun,(seen,cs)) = do
477
454
tcm <- Lens. view tcCache
478
- let argss = Foldable. toList cs
479
- argssT = zip [0 .. ] (List. transpose argss)
480
- (sharedT,distinctT) = List. partition (areShared tcm inScope . fmap (first stripTicks) . snd ) argssT
481
- -- TODO: find a better solution than "maybe undefined fst . uncons"
482
- shared = map (second (maybe (error " impossible" ) fst . List. uncons)) sharedT
483
- distinct = map (Either. lefts) (List. transpose (map snd distinctT))
484
- cs' = fmap (zip [0 .. ]) cs
485
- cs'' = removeEmpty
486
- $ fmap (Either. lefts . map snd )
487
- (if null shared
488
- then cs'
489
- else fmap (filter (`notElem` shared)) cs')
490
- (distinctCaseM,distinctProjections) <- case distinct of
491
- -- only shared arguments: do nothing.
492
- [] -> return (Nothing ,[] )
493
- -- Create selectors and projections
494
- (uc: _) -> do
495
- let argTys = map (inferCoreTypeOf tcm) uc
496
- disJointSelProj inScope argTys cs''
497
- let newArgs = mkDJArgs 0 shared distinctProjections
498
- case distinctCaseM of
499
- Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
500
- Nothing -> return (mkApps fun newArgs, seen)
501
-
502
- -- | Create a single selector for all the representable distinct arguments by
503
- -- selecting between tuples. This selector is only ('Just') created when the
504
- -- number of representable uncommmon arguments is larger than one, otherwise it
505
- -- is not ('Nothing').
506
- --
507
- -- It also returns:
508
- --
509
- -- * For all the non-representable distinct arguments: a selector
510
- -- * For all the representable distinct arguments: a projection out of the tuple
511
- -- created by the larger selector. If this larger selector does not exist, a
512
- -- single selector is created for the single representable distinct argument.
455
+ let argLen = case Foldable. toList cs of
456
+ [] -> error ($ curLoc <> " mkDisjointGroup: no disjoint groups" )
457
+ l: _ -> length l
458
+ csT :: [CaseTree (Either Term Type )]
459
+ csT = map (\ i -> fmap (!! i) cs) [0 .. (argLen- 1 )]
460
+ (lbs,newArgs) <- List. mapAccumLM (\ lbs c -> do
461
+ let cL :: [Either Term Type ]
462
+ cL = Foldable. toList c
463
+ case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
464
+ (Right ty: _, True ) ->
465
+ return (lbs,Right ty)
466
+ (Right _: _, False ) ->
467
+ error ($ curLoc <> " mkDisjointGroup: non-equal type arguments: " <>
468
+ showPpr (Either. rights cL))
469
+ (Left tm: _, True ) ->
470
+ return (lbs,Left tm)
471
+ (Left tm: _, False ) -> do
472
+ let ty = inferCoreTypeOf tcm tm
473
+ let err = error $
474
+ $ curLoc <>
475
+ " mkDisjointGroup: mixed type and term arguments: " <>
476
+ show cL
477
+ (lbM,arg) <- disJointSelProj inScope ty (Either. fromLeft err <$> c)
478
+ case lbM of
479
+ Just lb -> return (lb: lbs,Left arg)
480
+ _ -> return (lbs,Left arg)
481
+ ([] , _) ->
482
+ error ($ curLoc ++ " mkDisjointGroup: no arguments" )
483
+ ) [] csT
484
+ let funApp = mkApps fun newArgs
485
+ case lbs of
486
+ [] -> return (funApp, seen)
487
+ _ -> return (Letrec lbs funApp, seen)
488
+
489
+ -- | Create a selector for the case-tree of the argument. If the argument is
490
+ -- representable create a let-binding for the created selector, and return
491
+ -- a variable reference to this let-binding. If the argument is not representable
492
+ -- return the selector directly.
513
493
disJointSelProj
514
494
:: InScopeSet
515
- -> [Type ]
516
- -- ^ Types of the arguments
517
- -> CaseTree [Term ]
518
- -- The case-tree of arguments
519
- -> NormalizeSession (Maybe LetBinding ,[Term ])
520
- disJointSelProj _ _ (Leaf [] ) = return (Nothing ,[] )
521
- disJointSelProj inScope argTys cs = do
522
- tcm <- Lens. view tcCache
523
- tupTcm <- Lens. view tupleTcCache
524
- let maxIndex = length argTys - 1
525
- css = map (\ i -> fmap ((: [] ) . (!! i)) cs) [0 .. maxIndex]
526
- (untran,tran) <- List. partitionM (isUntranslatableType False . snd ) (zip [0 .. ] argTys)
527
- let untranCs = map (css!! ) (map fst untran)
528
- untranSels = zipWith (\ (_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
529
- untran untranCs
530
- (lbM,projs) <- case tran of
531
- [] -> return (Nothing ,[] )
532
- [(i,ty)] -> return (Nothing ,[genCase tcm tupTcm ty [ty] (css!! i)])
533
- tys -> do
534
- let m = length tys
535
- (tyIxs,tys') = unzip tys
536
- tupTy = mkBigTupTy tcm tupTcm tys'
537
- cs' = fmap (\ es -> map (es !! ) tyIxs) cs
538
- djCase = genCase tcm tupTcm tupTy tys' cs'
539
- scrutId <- mkInternalVar inScope " tupIn" tupTy
540
- projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0 .. m- 1 ]
541
- return (Just (scrutId,djCase),projections)
542
- let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
543
-
544
- return (lbM,selProjs)
545
- where
546
- tranOrUnTran _ [] projs = projs
547
- tranOrUnTran _ sels [] = map snd sels
548
- tranOrUnTran n ((ut,s): uts) (p: projs)
549
- | n == ut = s : tranOrUnTran (n+ 1 ) uts (p: projs)
550
- | otherwise = p : tranOrUnTran (n+ 1 ) ((ut,s): uts) projs
495
+ -> Type
496
+ -- ^ Types of the argument
497
+ -> CaseTree Term
498
+ -- The case-tree of argument
499
+ -> NormalizeSession (Maybe LetBinding ,Term )
500
+ disJointSelProj inScope argTy cs = do
501
+ let sel = genCase argTy cs
502
+ untran <- isUntranslatableType False argTy
503
+ case untran of
504
+ True -> return (Nothing , sel)
505
+ False -> do
506
+ argId <- mkInternalVar inScope " djArg" argTy
507
+ return (Just (argId,sel), Var argId)
551
508
552
509
-- | Arguments are shared between invocations if:
553
510
--
@@ -579,30 +536,15 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
579
536
_ -> False
580
537
isProof _ = False
581
538
582
- -- | Create a list of arguments given a map of positions to common arguments,
583
- -- and a list of arguments
584
- mkDJArgs :: Int -- ^ Current position
585
- -> [(Int ,Either Term Type )] -- ^ map from position to common argument
586
- -> [Term ] -- ^ (projections for) distinct arguments
587
- -> [Either Term Type ]
588
- mkDJArgs _ cms [] = map snd cms
589
- mkDJArgs _ [] uncms = map Left uncms
590
- mkDJArgs n ((m,x): cms) (y: uncms)
591
- | n == m = x : mkDJArgs (n+ 1 ) cms (y: uncms)
592
- | otherwise = Left y : mkDJArgs (n+ 1 ) ((m,x): cms) uncms
593
-
594
539
-- | Create a case-expression that selects between the distinct arguments given
595
540
-- a case-tree
596
- genCase :: TyConMap
597
- -> IntMap TyConName
598
- -> Type -- ^ Type of the alternatives
599
- -> [Type ] -- ^ Types of the arguments
600
- -> CaseTree [Term ] -- ^ CaseTree of arguments
541
+ genCase :: Type -- ^ Types of the arguments
542
+ -> CaseTree Term -- ^ CaseTree of arguments
601
543
-> Term
602
- genCase tcm tupTcm ty argTys = go
544
+ genCase ty = go
603
545
where
604
- go (Leaf tms ) =
605
- mkBigTupTm tcm tupTcm ( List. zipEqual argTys tms)
546
+ go (Leaf tm ) =
547
+ tm
606
548
607
549
go (LB lb ct) =
608
550
Letrec lb (go ct)
@@ -617,68 +559,6 @@ genCase tcm tupTcm ty argTys = go
617
559
go (Branch scrut pats) =
618
560
Case scrut ty (map (second go) pats)
619
561
620
- -- | Lookup the TyConName and DataCon for a tuple of size n
621
- findTup :: TyConMap -> IntMap TyConName -> Int -> (TyConName ,DataCon )
622
- findTup tcm tupTcm n =
623
- Maybe. fromMaybe (error (" Cannot build " <> show n <> " -tuble" )) $ do
624
- tupTcNm <- IntMap. lookup n tupTcm
625
- tupTc <- UniqMap. lookup tupTcNm tcm
626
- tupDc <- Maybe. listToMaybe (tyConDataCons tupTc)
627
- return (tupTcNm,tupDc)
628
-
629
- mkBigTupTm :: TyConMap -> IntMap TyConName -> [(Type ,Term )] -> Term
630
- mkBigTupTm tcm tupTcm args = snd $ mkBigTup tcm tupTcm args
631
-
632
- mkSmallTup ,mkBigTup :: TyConMap -> IntMap TyConName -> [(Type ,Term )] -> (Type ,Term )
633
- mkSmallTup _ _ [] = error $ $ curLoc ++ " mkSmallTup: Can't create 0-tuple"
634
- mkSmallTup _ _ [(ty,tm)] = (ty,tm)
635
- mkSmallTup tcm tupTcm args = (ty,tm)
636
- where
637
- (argTys,tms) = unzip args
638
- (tupTcNm,tupDc) = findTup tcm tupTcm (length args)
639
- tm = mkApps (Data tupDc) (map Right argTys ++ map Left tms)
640
- ty = mkTyConApp tupTcNm argTys
641
-
642
- mkBigTup tcm tupTcm = mkChunkified (mkSmallTup tcm tupTcm)
643
-
644
- mkSmallTupTy,mkBigTupTy
645
- :: TyConMap
646
- -> IntMap TyConName
647
- -> [Type ]
648
- -> Type
649
- mkSmallTupTy _ _ [] = error $ $ curLoc ++ " mkSmallTupTy: Can't create 0-tuple"
650
- mkSmallTupTy _ _ [ty] = ty
651
- mkSmallTupTy tcm tupTcm tys = mkTyConApp tupTcNm tys
652
- where
653
- m = length tys
654
- (tupTcNm,_) = findTup tcm tupTcm m
655
-
656
- mkBigTupTy tcm tupTcm = mkChunkified (mkSmallTupTy tcm tupTcm)
657
-
658
- mkSmallTupSelector,mkBigTupSelector
659
- :: MonadUnique m
660
- => InScopeSet
661
- -> TyConMap
662
- -> IntMap TyConName
663
- -> Term
664
- -> [Type ]
665
- -> Int
666
- -> m Term
667
- mkSmallTupSelector _ _ _ scrut [_] 0 = return scrut
668
- mkSmallTupSelector _ _ _ _ [_] n = error $ $ curLoc ++ " mkSmallTupSelector called with one type, but to select " ++ show n
669
- mkSmallTupSelector inScope tcm _ scrut _ n = mkSelectorCase ($ curLoc ++ " mkSmallTupSelector" ) inScope tcm scrut 1 n
670
-
671
- mkBigTupSelector inScope tcm tupTcm scrut tys n = go (chunkify tys)
672
- where
673
- go [_] = mkSmallTupSelector inScope tcm tupTcm scrut tys n
674
- go tyss = do
675
- let (nOuter,nInner) = divMod n mAX_TUPLE_SIZE
676
- tyss' = map (mkSmallTupTy tcm tupTcm) tyss
677
- outer <- mkSmallTupSelector inScope tcm tupTcm scrut tyss' nOuter
678
- inner <- mkSmallTupSelector inScope tcm tupTcm outer (tyss List. !! nOuter) nInner
679
- return inner
680
-
681
-
682
562
-- | Determine if a term in a function position is interesting to lift out of
683
563
-- of a case-expression.
684
564
--
0 commit comments