From 4ae5024f7eccda9fb02d6ab01bf51c355d2f2060 Mon Sep 17 00:00:00 2001 From: Christiaan Baaij Date: Wed, 14 Feb 2024 16:13:52 +0100 Subject: [PATCH] Refactor DEC transformation The previous code was a big mess where we partitioned arguments into shared and non-shared and then filtered the case-tree depending on whether they were part of the shared arguments or not. But then with the normalisation of type arguments, the second filter did not work properly. This then resulted in shared arguments becoming part of the tuple in the alternatives of the case-expression for the non-shared arguments. The new code is also more robust in the sense that shared and non-shared arguments no longer need to be partitioned (shared occur left-most, non-shared occur right-most). They can now be interleaved. The old code would also generate bad Core if ever type and term arguments occurred interleaved; this is no longer the case for the new code. Fixes #2628 --- changelog/2024-02-23T16_56_56+01_00_fix2628 | 1 + .../Clash/Normalize/Transformations/DEC.hs | 214 +++++++++--------- tests/Main.hs | 1 + tests/shouldwork/Issues/T2628.hs | 156 +++++++++++++ 4 files changed, 267 insertions(+), 105 deletions(-) create mode 100644 changelog/2024-02-23T16_56_56+01_00_fix2628 create mode 100644 tests/shouldwork/Issues/T2628.hs diff --git a/changelog/2024-02-23T16_56_56+01_00_fix2628 b/changelog/2024-02-23T16_56_56+01_00_fix2628 new file mode 100644 index 0000000000..e84c479203 --- /dev/null +++ b/changelog/2024-02-23T16_56_56+01_00_fix2628 @@ -0,0 +1 @@ +FIXED: Clash errors out in the netlist-generation stage when a polymorphic function is applied to type X in one alternative of a case-statement and applied to a new-type wrapper of type X in a different alternative. 
See [#2628](https://github.com/clash-lang/clash-compiler/issues/2628) diff --git a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs index aaa0c30cb3..003f754459 100644 --- a/clash-lib/src/Clash/Normalize/Transformations/DEC.hs +++ b/clash-lib/src/Clash/Normalize/Transformations/DEC.hs @@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC ) where import Control.Concurrent.Supply (splitSupply) +#if !MIN_VERSION_base(4,18,0) +import Control.Applicative (liftA2) +#endif import Control.Lens ((^.), _1) import qualified Control.Lens as Lens import qualified Control.Monad as Monad @@ -72,13 +75,14 @@ import Constants (mAX_TUPLE_SIZE) #endif -- internal -import Clash.Core.DataCon (DataCon) +import Clash.Core.DataCon (DataCon) import Clash.Core.Evaluator.Types (whnf') import Clash.Core.FreeVars (termFreeVars', typeFreeVars', localVarsDoNotOccurIn) import Clash.Core.HasType import Clash.Core.Literal (Literal(..)) import Clash.Core.Name (nameOcc) +import Clash.Core.Pretty (showPpr) import Clash.Core.Term ( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..) , collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks) @@ -86,7 +90,7 @@ import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons) import Clash.Core.Type (Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView) import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings) -import Clash.Core.Var (isGlobalId, isLocalId, varName) +import Clash.Core.Var (Id, isGlobalId, isLocalId, varName) import Clash.Core.VarEnv ( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList , notElemInScopeSet, unionInScope) @@ -138,6 +142,24 @@ import qualified GHC.Prim -- B -> f_out -- C -> h x -- @ +-- +-- Though that's a lie. 
It actually converts it into: +-- +-- @ +-- let tupIn = case x of {A -> (3,y); B -> (x,x)} +-- f_arg0 = case tupIn of (l,_) -> l +-- f_arg1 = case tupIn of (_,r) -> r +-- f_out = f f_arg0 f_arg1 +-- in case x of +-- A -> f_out +-- B -> f_out +-- C -> h x +-- @ +-- +-- In order to share the expression that's in the subject of the case expression, +-- and to share the /decoder/ circuit that logic synthesis will create to map the +-- bits of the subject expression to the bits needed to make the selection in the +-- multiplexer. disjointExpressionConsolidation :: HasCallStack => NormRewrite disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _ty _alts@(_:_:_)) = do -- Collect all (the applications of) global binders (and certain primitives) @@ -150,11 +172,12 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t else do -- For every to-lift expression create (the generalization of): -- - -- let fargs = case x of {A -> (3,y); B -> (x,x)} - -- in f (fst fargs) (snd fargs) + -- let djArg0 = case x of {A -> 3; B -> x} + -- djArg1 = case x of {A -> y; B -> x} + -- in f djArg0 djArg1 -- - -- the let-expression is not created when `f` has only one (selectable) - -- argument + -- if an argument is non-representable, the case-expression is inlined, + -- and no let-binding will be created for it. 
-- -- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine -- whether expressions reference variables from the context, or @@ -255,6 +278,13 @@ data CaseTree a | Branch Term [(Pat,CaseTree a)] deriving (Eq,Show,Functor,Foldable) +instance Applicative CaseTree where + pure = Leaf + liftA2 f (Leaf a) (Leaf b) = Leaf (f a b) + liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2) + liftA2 f (Branch scrut alts1) (Branch _ alts2) = Branch scrut (zipWith (\(p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2) + liftA2 _ _ _ = error "bad" + -- | Test if a 'CaseTree' collected from an expression indicates that -- application of a global binder is disjoint: occur in separate branches of a -- case-expression. @@ -269,18 +299,6 @@ isDisjoint ct = go ct go (Branch _ [(_,x)]) = go x go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b)) --- Remove empty branches from a 'CaseTree' -removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a] -removeEmpty l@(Leaf _) = l -removeEmpty (LB lb ct) = - case removeEmpty ct of - Leaf [] -> Leaf [] - ct' -> LB lb ct' -removeEmpty (Branch s bs) = - case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of - [] -> Leaf [] - bs' -> Branch s bs' - -- | Test if all elements in a list are equal to each other. allEqual :: Eq a => [a] -> Bool allEqual [] = True @@ -464,90 +482,89 @@ collectGlobalsLbs is0 substitution seen lbs = do -- function-position\", return a let-expression: where the let-binding holds -- a case-expression selecting between the distinct arguments of the case-tree, -- and the body is an application of the term applied to the shared arguments of --- the case tree, and projections of let-binding corresponding to the distinct --- argument positions. +-- the case tree, and variable references to the created let-bindings. 
+-- +-- case-expressions whose type would be non-representable are not let-bound, +-- but occur directly in the argument position of the application in the body +-- of the let-expression. mkDisjointGroup :: InScopeSet -- ^ Variables in scope at the very top of the case-tree, i.e., the original -- expression - -> (Term,([Term],CaseTree [(Either Term Type)])) + -> (Term,([Term],CaseTree [Either Term Type])) -- ^ Case-tree of arguments belonging to the applied term. -> NormalizeSession (Term,[Term]) mkDisjointGroup inScope (fun,(seen,cs)) = do tcm <- Lens.view tcCache - let argss = Foldable.toList cs - argssT = zip [0..] (List.transpose argss) - (sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT - -- TODO: find a better solution than "maybe undefined fst . uncons" - shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT - distinct = map (Either.lefts) (List.transpose (map snd distinctT)) - cs' = fmap (zip [0..]) cs - cs'' = removeEmpty - $ fmap (Either.lefts . map snd) - (if null shared - then cs' - else fmap (filter (`notElem` shared)) cs') - (distinctCaseM,distinctProjections) <- case distinct of - -- only shared arguments: do nothing. - [] -> return (Nothing,[]) - -- Create selectors and projections - (uc:_) -> do - let argTys = map (inferCoreTypeOf tcm) uc - disJointSelProj inScope argTys cs'' - let newArgs = mkDJArgs 0 shared distinctProjections - case distinctCaseM of - Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen) - Nothing -> return (mkApps fun newArgs, seen) - --- | Create a single selector for all the representable distinct arguments by --- selecting between tuples. This selector is only ('Just') created when the --- number of representable uncommmon arguments is larger than one, otherwise it --- is not ('Nothing'). 
--- --- It also returns: --- --- * For all the non-representable distinct arguments: a selector --- * For all the representable distinct arguments: a projection out of the tuple --- created by the larger selector. If this larger selector does not exist, a --- single selector is created for the single representable distinct argument. + let argLen = case Foldable.toList cs of + [] -> error "mkDisjointGroup: no disjoint groups" + l:_ -> length l + csT :: [CaseTree (Either Term Type)] -- "Transposed" 'CaseTree [Either Term Type]' + csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] -- sequenceA does the wrong thing + (lbs,newArgs) <- List.mapAccumLM (\lbs c -> do + let cL = Foldable.toList c + case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of + (Right ty:_, True) -> + return (lbs,Right ty) + (Right _:_, False) -> + error ("mkDisjointGroup: non-equal type arguments: " <> + showPpr (Either.rights cL)) + (Left tm:_, True) -> + return (lbs,Left tm) + (Left tm:_, False) -> do + let ty = inferCoreTypeOf tcm tm + let err = error ("mkDisjointGroup: mixed type and term arguments: " <> show cL) + (lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c) + case lbM of + Just lb -> return (lb:lbs,Left arg) + _ -> return (lbs,Left arg) + ([], _) -> + error "mkDisjointGroup: no arguments" + ) [] csT + let funApp = mkApps fun newArgs + tupTcm <- Lens.view tupleTcCache + case lbs of + [] -> + return (funApp, seen) + [(v,(ty,ct))] -> do + let e = genCase tcm tupTcm ty [ty] (fmap (:[]) ct) + return (Letrec [(v,e)] funApp, seen) + _ -> do + let (vs,zs) = unzip lbs + csL :: [CaseTree Term] + (tys,csL) = unzip zs + csLT :: CaseTree [Term] + csLT = fmap ($ []) (foldr1 (liftA2 (.)) (fmap (fmap (:)) csL)) + bigTupTy = mkBigTupTy tcm tupTcm tys + ct = genCase tcm tupTcm bigTupTy tys csLT + tupIn <- mkInternalVar inScope "tupIn" bigTupTy + projections <- + Monad.zipWithM (\v n -> + (v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n) + vs [0..] 
+ return (Letrec ((tupIn,ct):projections) funApp, seen) + +-- | Create a selector for the case-tree of the argument. If the argument is +-- representable create a let-binding for the created selector, and return +-- a variable reference to this let-binding. If the argument is not representable +-- return the selector directly. disJointSelProj :: InScopeSet - -> [Type] - -- ^ Types of the arguments - -> CaseTree [Term] - -- The case-tree of arguments - -> NormalizeSession (Maybe LetBinding,[Term]) -disJointSelProj _ _ (Leaf []) = return (Nothing,[]) -disJointSelProj inScope argTys cs = do - tcm <- Lens.view tcCache + -> Type + -- ^ Types of the argument + -> CaseTree Term + -- The case-tree of argument + -> NormalizeSession (Maybe (Id, (Type, CaseTree Term)),Term) +disJointSelProj inScope argTy cs = do + tcm <- Lens.view tcCache tupTcm <- Lens.view tupleTcCache - let maxIndex = length argTys - 1 - css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex] - (untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys) - let untranCs = map (css!!) (map fst untran) - untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs') - untran untranCs - (lbM,projs) <- case tran of - [] -> return (Nothing,[]) - [(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)]) - tys -> do - let m = length tys - (tyIxs,tys') = unzip tys - tupTy = mkBigTupTy tcm tupTcm tys' - cs' = fmap (\es -> map (es !!) 
tyIxs) cs - djCase = genCase tcm tupTcm tupTy tys' cs' - scrutId <- mkInternalVar inScope "tupIn" tupTy - projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1] - return (Just (scrutId,djCase),projections) - let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs - - return (lbM,selProjs) - where - tranOrUnTran _ [] projs = projs - tranOrUnTran _ sels [] = map snd sels - tranOrUnTran n ((ut,s):uts) (p:projs) - | n == ut = s : tranOrUnTran (n+1) uts (p:projs) - | otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs + let sel = genCase tcm tupTcm argTy [argTy] (fmap (:[]) cs) + untran <- isUntranslatableType False argTy + case untran of + True -> return (Nothing, sel) + False -> do + argId <- mkInternalVar inScope "djArg" argTy + return (Just (argId,(argTy,cs)), Var argId) -- | Arguments are shared between invocations if: -- @@ -579,23 +596,11 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs) _ -> False isProof _ = False --- | Create a list of arguments given a map of positions to common arguments, --- and a list of arguments -mkDJArgs :: Int -- ^ Current position - -> [(Int,Either Term Type)] -- ^ map from position to common argument - -> [Term] -- ^ (projections for) distinct arguments - -> [Either Term Type] -mkDJArgs _ cms [] = map snd cms -mkDJArgs _ [] uncms = map Left uncms -mkDJArgs n ((m,x):cms) (y:uncms) - | n == m = x : mkDJArgs (n+1) cms (y:uncms) - | otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms - -- | Create a case-expression that selects between the distinct arguments given -- a case-tree genCase :: TyConMap -> IntMap TyConName - -> Type -- ^ Type of the alternatives + -> Type -> [Type] -- ^ Types of the arguments -> CaseTree [Term] -- ^ CaseTree of arguments -> Term @@ -678,7 +683,6 @@ mkBigTupSelector inScope tcm tupTcm scrut tys n = go (chunkify tys) inner <- mkSmallTupSelector inScope tcm tupTcm outer (tyss List.!! 
nOuter) nInner return inner - -- | Determine if a term in a function position is interesting to lift out of -- of a case-expression. -- diff --git a/tests/Main.hs b/tests/Main.hs index 719c33b0f8..da9f9949d1 100755 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -802,6 +802,7 @@ runClashTest = defaultMain $ clashTestRoot , outputTest "T2542" def{hdlTargets=[VHDL]} , runTest "T2593" def{hdlSim=[]} , runTest "T2623CaseConFVs" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]} + , runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]} ] <> if compiledWith == Cabal then -- This tests fails without environment files present, which are only diff --git a/tests/shouldwork/Issues/T2628.hs b/tests/shouldwork/Issues/T2628.hs new file mode 100644 index 0000000000..2d6a001bb5 --- /dev/null +++ b/tests/shouldwork/Issues/T2628.hs @@ -0,0 +1,156 @@ +module T2628 where + +import Clash.Prelude + +-- idx cacheline entries are Just(tag,Just addr) to translate idx++tag->addr +-- and Just(tag,Nothing) for invalidated idx++tag entry +-- and Nothing for no entry there +type CacheLine m tag addr -- 2^m tags per line, 2^n lines + = Vec (2^m) (Maybe(tag,Maybe addr)) + +{-# ANN tacache_server_step32 + (Synthesize { t_name = "TACacheServerStep" + , t_inputs = [ PortName "dx" -- user B + , PortName "d_x" -- tlb C + , PortName "dw" -- tlb D + , PortName "out2" -- cache B + , PortName "out3" -- cache C + ] + , t_output = PortProduct "" + [ PortName "win1" -- cache A1 + , PortName "win2" -- cache A2 + ] + }) #-} + +{-# NOINLINE tacache_server_step32 #-} +tacache_server_step32 = tacache_server_step' + where + tacache_server_step' + :: forall (m::Nat) (n::Nat) (p::Nat) (q::Nat) + cxdr addr idx tag cacheline + . 
( KnownNat q, KnownNat n, KnownNat m, KnownNat p + , n <= p + , cxdr ~ Signed p + , addr ~ Signed q + , idx ~ Signed n + , tag ~ Signed (p-n) + , cacheline ~ CacheLine m tag addr + , p ~ 132 + , q ~ 32 + , n ~ 6 + , m ~ 0 + ) + -- SNat n -- 2^n lines + -- SNat m -- of 2^m entries each + => ( Maybe cxdr -- input frnt invalidate addr req to server + , Maybe cxdr -- input back/weak invalidate req to server + , Maybe (cxdr,addr) -- input back/weak write req to server + , Maybe (idx,cacheline) + , Maybe (idx,cacheline) + ) + -> ( Maybe(idx,cacheline) + , Maybe(idx,cacheline) + ) + tacache_server_step' = tacache_server_step (SNat::SNat n) (SNat::SNat m) + +tacache_server_step + :: forall (m::Nat) (n::Nat) (p::Nat) (q::Nat) + cxdr addr idx tag cacheline + . ( KnownNat q, KnownNat n, KnownNat m, KnownNat p + , n <= p + , cxdr ~ Signed p + , addr ~ Signed q + , idx ~ Signed n + , tag ~ Signed (p-n) + , cacheline ~ CacheLine m tag addr +-- , p ~ 132 +-- , q ~ 32 + ) + => SNat n -- 2^n lines + -> SNat m -- of 2^m entries each + -> ( Maybe cxdr -- input frnt invalidate addr req to server + , Maybe cxdr -- input back/weak invalidate req to server + , Maybe (cxdr,addr) -- input back/weak write req to server + , Maybe (idx,cacheline) + , Maybe (idx,cacheline) + ) + -> ( Maybe(idx,cacheline) + , Maybe(idx,cacheline) + ) +tacache_server_step n m (dx,d_x,dw,out1,out2) = (win1,win2) + + where + -- outs1 and outs2 are prev state + -- (may need to write two lines in one cycle) + win1,win2 :: Maybe(idx,CacheLine m tag addr) + (win1,win2) = + case (dx, d_x, dw, out1, out2) of + + -- !!! FIX for HDL from here on, replace (v,_) = with v = fst $ !!! -- + + (Just x1,Just x2,Nothing,Just (idx1,v1),Just (idx2,v2)) -> + let (idx2',tag2) = tacache_split_cxdr x2 + in + if 1 /= idx2' then + ( Just(1,v1) + , Just(idx2',v2) + ) + else + let (v1',_) = tazcache_line_inval_step v1 2 -- HERE + (v2',_) = tazcache_line_weak_inval_step v1' tag2 -- HERE + in ( Just(idx2',v2') + , Nothing + ) + + -- !!! 
FIX for HDL from here, as above, and make cases top level fns !!! --- + + (Nothing,Just x,Nothing,_,Just (idx,v)) -> + let (v',_) = tazcache_line_weak_inval_step v 4 -- HERE + in ( Nothing + , Just(3,v') + ) + + _ -> (Nothing,Nothing) + + -------------------- DUMMY NOINLINE support ----------------------- + +-- split incoming addr for translation into a cacheline index and tag +{-# NOINLINE tacache_split_cxdr #-} +tacache_split_cxdr + :: forall (n::Nat) (p::Nat) tag cxdr idx f + . ( KnownNat n, KnownNat p + , Resize f -- might as well be just Signed + , n <= p, (n + (p-n)) ~ p, ((p-n) + n) ~ p + , BitPack cxdr, p ~ BitSize cxdr, cxdr ~ f p + , BitPack idx, n ~ BitSize idx, idx ~ f n + , BitPack tag, (p-n) ~ BitSize tag, tag ~ f (p-n) + ) + => cxdr + -> (idx,tag) +tacache_split_cxdr x = (unpack 5, unpack 6) + + ------------------ DUMMY NOINLINE cacheline ops --------------------- + +-- remove element with matching tag from cacheline, report position +{-# NOINLINE tazcache_line_inval_step #-} +tazcache_line_inval_step :: + ( KnownNat m, KnownNat p_n, KnownNat q + , BitPack tag, p_n ~ BitSize tag, Eq tag + , BitPack addr, q ~ BitSize addr + ) + => CacheLine m tag addr + -> tag + -> (CacheLine m tag addr, Maybe(Index(2^m))) +tazcache_line_inval_step v tag = (v,Nothing) + +-- add placeholder invalidated entry to cacheline, replace entry if was there +{-# NOINLINE tazcache_line_weak_inval_step #-} +tazcache_line_weak_inval_step :: + ( KnownNat m, KnownNat p_n, KnownNat q + , BitPack tag, p_n ~ BitSize tag, Eq tag + , BitPack addr, q ~ BitSize addr + ) + => CacheLine m tag addr + -> tag + -> (CacheLine m tag addr, Maybe(Index(2^m))) +tazcache_line_weak_inval_step v tag = (v,Nothing)