Skip to content

Commit

Permalink
Implement rfold in a way that unrolls only in Delta eval
Browse files Browse the repository at this point in the history
TODO: don't unroll at all
  • Loading branch information
Mikolaj committed Dec 15, 2023
1 parent 25b6f13 commit 1587916
Show file tree
Hide file tree
Showing 12 changed files with 470 additions and 6 deletions.
22 changes: 22 additions & 0 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,24 @@ data AstRanked :: AstSpanType -> RankedTensorKind where
-> Domains (AstDynamic s)
-> Domains (AstDynamic s)
-> AstRanked s r n
AstFold :: forall rn rm n m s. (GoodScalar rm, KnownNat m)
=> ( AstVarName (AstRanked PrimalSpan) rn n
, AstVarName (AstRanked PrimalSpan) rm m
, AstRanked PrimalSpan rn n )
-> AstRanked s rn n
-> AstRanked s rm (1 + m)
-> AstRanked s rn n
AstFoldRev :: forall rn rm n m s. (GoodScalar rm, KnownNat m)
=> ( AstVarName (AstRanked PrimalSpan) rn n
, AstVarName (AstRanked PrimalSpan) rm m
, AstRanked PrimalSpan rn n )
-> ( AstVarName (AstRanked PrimalSpan) rn n
, AstVarName (AstRanked PrimalSpan) rn n
, AstVarName (AstRanked PrimalSpan) rm m
, AstDomains PrimalSpan )
-> AstRanked s rn n
-> AstRanked s rm (1 + m)
-> AstRanked s rn n

deriving instance GoodScalar r => Show (AstRanked s r n)

Expand Down Expand Up @@ -434,6 +452,10 @@ type role AstDomains nominal
data AstDomains s where
-- There are existential variables inside DynamicExists here.
AstDomains :: Domains (AstDynamic s) -> AstDomains s
-- This operation is why we need AstDomains and so DomainsOf.
-- If we kept a vector of terms instead, we'd need to let-bind in each
-- of the terms separately, duplicating the let-bound term.
--
-- The r variable is existential here, so a proper specialization needs
-- to be picked explicitly at runtime.
AstLetInDomains :: (KnownNat n, GoodScalar r, AstSpan s)
Expand Down
35 changes: 35 additions & 0 deletions src/HordeAd/Core/AstEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module HordeAd.Core.AstEnv
, interpretLambdaIndex, interpretLambdaIndexS
, interpretLambdaIndexToIndex, interpretLambdaIndexToIndexS
, interpretLambdaDomains, interpretLambdaDomainsS
, interpretLambda2, interpretLambda3
-- * Interpretation of arithmetic, boolean and relation operations
, interpretAstN1, interpretAstN2, interpretAstR1, interpretAstR2
, interpretAstI2, interpretAstB2, interpretAstRelOp
Expand Down Expand Up @@ -238,6 +239,40 @@ interpretLambdaDomainsS
interpretLambdaDomainsS f !env (!vars, !ast) =
\pars -> f (extendEnvParsS vars pars env) ast

interpretLambda2
:: forall s ranked shaped rn rm n m.
(GoodScalar rn, GoodScalar rm, KnownNat n, KnownNat m)
=> (AstEnv ranked shaped -> AstRanked s rn n -> ranked rn n)
-> AstEnv ranked shaped
-> ( AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rm m
, AstRanked s rn n )
-> ranked rn n -> ranked rm m
-> ranked rn n
{-# INLINE interpretLambda2 #-}
interpretLambda2 f !env (!varn, !varm, !ast) =
\x0 as -> let envE = extendEnvR varn x0
$ extendEnvR varm as env
in f envE ast

interpretLambda3
:: forall s ranked shaped rn rm n m.
(GoodScalar rn, GoodScalar rm, KnownNat n, KnownNat m)
=> (AstEnv ranked shaped -> AstDomains s -> DomainsOf ranked)
-> AstEnv ranked shaped
-> ( AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rm m
, AstDomains s )
-> ranked rn n -> ranked rn n -> ranked rm m
-> DomainsOf ranked
{-# INLINE interpretLambda3 #-}
interpretLambda3 f !env (!varDt, !varn, !varm, !ast) =
\dt x0 as -> let envE = extendEnvR varDt dt
$ extendEnvR varn x0
$ extendEnvR varm as env
in f envE ast


-- * Interpretation of arithmetic, boolean and relation operations

Expand Down
93 changes: 92 additions & 1 deletion src/HordeAd/Core/AstFreshId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
-- with @unsafePerformIO@ outside, so some of it escapes.
module HordeAd.Core.AstFreshId
( astRegisterFun, astRegisterADShare, astRegisterADShareS
, funToAstIOR, funToAstR, funToAstDomains, funToAstDomainsS
, funToAstIOR, funToAstR, fun2ToAstR, fun2ToAstS, fun3ToAstR, fun3ToAstS
, funToAstDomains, funToAstDomainsS
, funToAstRevIO, funToAstRev, funToAstRevIOS, funToAstRevS
, funToAstFwdIO, funToAstFwd, funToAstFwdIOS, funToAstFwdS
, funToAstIOI, funToAstI, funToAstIndexIO, funToAstIndex
Expand Down Expand Up @@ -139,6 +140,96 @@ funToAstS f = unsafePerformIO $ do
(!var, _, !ast) <- funToAstIOS f
return (var, ast)

fun2ToAstIOR :: ShapeInt n
-> ShapeInt m
-> (AstRanked s rn n -> AstRanked s rm m -> AstRanked s rn n)
-> IO ( AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rm m
, AstRanked s rn n )
{-# INLINE fun2ToAstIOR #-}
fun2ToAstIOR shn shm f = do
nvarName <- unsafeGetFreshAstVarName
mvarName <- unsafeGetFreshAstVarName
let !x = f (AstVar shn nvarName) (AstVar shm mvarName)
return (nvarName, mvarName, x)

fun2ToAstR :: ShapeInt n
-> ShapeInt m
-> (AstRanked s rn n -> AstRanked s rm m -> AstRanked s rn n)
-> ( AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rm m
, AstRanked s rn n )
{-# NOINLINE fun2ToAstR #-}
fun2ToAstR shn shm f = unsafePerformIO $ fun2ToAstIOR shn shm f

fun2ToAstIOS :: (AstShaped s rn shn -> AstShaped s rm shm -> AstShaped s rn shn)
-> IO ( AstVarName (AstShaped s) rn shn
, AstVarName (AstShaped s) rm shm
, AstShaped s rn shn )
{-# INLINE fun2ToAstIOS #-}
fun2ToAstIOS f = do
nvarName <- unsafeGetFreshAstVarName
mvarName <- unsafeGetFreshAstVarName
let !x = f (AstVarS nvarName) (AstVarS mvarName)
return (nvarName, mvarName, x)

fun2ToAstS :: (AstShaped s rn shn -> AstShaped s rm shm -> AstShaped s rn shn)
-> ( AstVarName (AstShaped s) rn shn
, AstVarName (AstShaped s) rm shm
, AstShaped s rn shn )
{-# NOINLINE fun2ToAstS #-}
fun2ToAstS f = unsafePerformIO $ fun2ToAstIOS f

fun3ToAstIOR :: ShapeInt n
-> ShapeInt m
-> (AstRanked s rn n -> AstRanked s rn n -> AstRanked s rm m
-> AstDomains s)
-> IO ( AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rm m
, AstDomains s )
{-# INLINE fun3ToAstIOR #-}
fun3ToAstIOR shn shm f = do
nvarName <- unsafeGetFreshAstVarName
nvarName2 <- unsafeGetFreshAstVarName
mvarName <- unsafeGetFreshAstVarName
let !x = f (AstVar shn nvarName) (AstVar shn nvarName2) (AstVar shm mvarName)
return (nvarName, nvarName2, mvarName, x)

fun3ToAstR :: ShapeInt n
-> ShapeInt m
-> (AstRanked s rn n -> AstRanked s rn n -> AstRanked s rm m
-> AstDomains s)
-> ( AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rn n
, AstVarName (AstRanked s) rm m
, AstDomains s )
{-# NOINLINE fun3ToAstR #-}
fun3ToAstR shn shm f = unsafePerformIO $ fun3ToAstIOR shn shm f

fun3ToAstIOS :: (AstShaped s rn shn -> AstShaped s rn shn -> AstShaped s rm shm
-> AstDomains s)
-> IO ( AstVarName (AstShaped s) rn shn
, AstVarName (AstShaped s) rn shn
, AstVarName (AstShaped s) rm shm
, AstDomains s )
{-# INLINE fun3ToAstIOS #-}
fun3ToAstIOS f = do
nvarName <- unsafeGetFreshAstVarName
nvarName2 <- unsafeGetFreshAstVarName
mvarName <- unsafeGetFreshAstVarName
let !x = f (AstVarS nvarName) (AstVarS nvarName2) (AstVarS mvarName)
return (nvarName, nvarName2, mvarName, x)

fun3ToAstS :: (AstShaped s rn shn -> AstShaped s rn shn -> AstShaped s rm shm
-> AstDomains s)
-> ( AstVarName (AstShaped s) rn shn
, AstVarName (AstShaped s) rn shn
, AstVarName (AstShaped s) rm shm
, AstDomains s )
{-# NOINLINE fun3ToAstS #-}
fun3ToAstS f = unsafePerformIO $ fun3ToAstIOS f

funToAstDomainsIO
:: (Domains (AstDynamic s) -> AstRanked s r n)
-> DomainsOD
Expand Down
22 changes: 22 additions & 0 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,18 @@ inlineAst memo v0 = case v0 of
(memo1, l1) = mapAccumR inlineAstDynamic memo l
(memo2, ds2) = mapAccumR inlineAstDynamic memo1 ds
in (memo2, Ast.AstFwd (vars, v2) l1 ds2)
Ast.AstFold (nvar, mvar, v) x0 as ->
let (_, v2) = inlineAst EM.empty v
(memo1, x02) = inlineAst memo x0
(memo2, as2) = inlineAst memo1 as
in (memo2, Ast.AstFold (nvar, mvar, v2) x02 as2)
Ast.AstFoldRev (nvar, mvar, v) (varDt2, nvar2, mvar2, doms) x0 as ->
let (_, v2) = inlineAst EM.empty v
(_, doms2) = inlineAstDomains EM.empty doms
(memo1, x02) = inlineAst memo x0
(memo2, as2) = inlineAst memo1 as
in (memo2, Ast.AstFoldRev (nvar, mvar, v2)
(varDt2, nvar2, mvar2, doms2) x02 as2)

inlineAstDynamic
:: AstSpan s
Expand Down Expand Up @@ -527,6 +539,16 @@ unletAst env t = case t of
Ast.AstFwd (vars, unletAst (emptyUnletEnv emptyADShare) v)
(V.map (unletAstDynamic env) l)
(V.map (unletAstDynamic env) ds)
Ast.AstFold (nvar, mvar, v) x0 as ->
Ast.AstFold (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
(unletAst env x0)
(unletAst env as)
Ast.AstFoldRev (nvar, mvar, v) (varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstFoldRev (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
( varDt2, nvar2, mvar2
, unletAstDomains (emptyUnletEnv emptyADShare) doms )
(unletAst env x0)
(unletAst env as)

unletAstDynamic
:: AstSpan s
Expand Down
19 changes: 19 additions & 0 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,25 @@ interpretAst !env = \case
pars = interpretAstDynamic @ranked env <$> parameters
d = interpretAstDynamic @ranked env <$> ds
in rfwd @ranked g parameters0 pars d
AstFold @_ @rm @_ @m f x0 as ->
let g :: forall f. ADReady f => f r n -> f rm m -> f r n
g = interpretLambda2 interpretAst EM.empty f
-- Interpretation in empty environment --- makes sense only
-- if there are no free variables outside of those listed.
-- Note that @f@ is in @PrimalSpan@, but this does not affect
-- the interpretation, only what term can be built.
x0i = interpretAst @ranked env x0
asi = interpretAst @ranked env as
in rfold @ranked g x0i asi
AstFoldRev @_ @rm @_ @m f df x0 as ->
let g :: forall f. ADReady f => f r n -> f rm m -> f r n
g = interpretLambda2 interpretAst EM.empty f
h :: forall f. ADReady f
=> f r n -> f r n -> f rm m -> DomainsOf f
h = interpretLambda3 interpretAstDomains EM.empty df
x0i = interpretAst @ranked env x0
asi = interpretAst @ranked env as
in rfoldRev @ranked g h x0i asi

interpretAstDynamic
:: forall ranked shaped s. (ADReadyBoth ranked shaped, AstSpan s)
Expand Down
40 changes: 40 additions & 0 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ areAllArgsInts = \case
AstD{} -> False -- dual number
AstLetDomainsIn{} -> True -- too early to tell
AstFwd{} -> False
AstFold{} -> False
AstFoldRev{} -> False


-- * Pretty-print variables
Expand Down Expand Up @@ -342,6 +344,44 @@ printAstAux cfg d = \case
. printDomainsAst cfg parameters
. showString " "
. printDomainsAst cfg ds
AstFold (nvar, mvar, v) x0 as ->
showParen (d > 10)
$ showString "rfold "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) nvar)
. showString " "
. showString (printAstVarName (varRenames cfg) mvar)
. showString " -> "
. printAst cfg 0 v)
. showString " "
. printAst cfg 11 x0
. showString " "
. printAst cfg 11 as
AstFoldRev (nvar, mvar, v) (varDt2, nvar2, mvar2, doms) x0 as ->
showParen (d > 10)
$ showString "rfold "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) nvar)
. showString " "
. showString (printAstVarName (varRenames cfg) mvar)
. showString " -> "
. printAst cfg 0 v)
. showString " "
. (showParen True
$ showString "\\"
. showString (printAstVarName (varRenames cfg) varDt2)
. showString " "
. showString (printAstVarName (varRenames cfg) nvar2)
. showString " "
. showString (printAstVarName (varRenames cfg) mvar2)
. showString " -> "
. printAstDomains cfg 0 doms)
. showString " "
. printAst cfg 11 x0
. showString " "
. printAst cfg 11 as

-- Differs from standard only in the space after comma.
showListWith :: (a -> ShowS) -> [a] -> ShowS
Expand Down
29 changes: 27 additions & 2 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ simplifyStepNonIndex t = case t of
Ast.AstD{} -> t
Ast.AstLetDomainsIn{} -> t
Ast.AstFwd{} -> t
Ast.AstFold{} -> t
Ast.AstFoldRev{} -> t

simplifyStepNonIndexS
:: ()
Expand Down Expand Up @@ -417,6 +419,8 @@ astIndexROrStepOnly stepOnly v0 ix@(i1 :. (rest1 :: AstIndex m1)) =
Ast.AstLetDomainsIn vars l v ->
Ast.AstLetDomainsIn vars l (astIndexRec v ix)
Ast.AstFwd{} -> Ast.AstIndex v0 ix
Ast.AstFold{} -> Ast.AstIndex v0 ix
Ast.AstFoldRev{} -> Ast.AstIndex v0 ix

-- TODO: compared to tletIx, it adds many lets, not one, but does not
-- create other (and non-simplified!) big terms and also uses astIsSmall,
Expand Down Expand Up @@ -740,6 +744,8 @@ astGatherROrStepOnly stepOnly sh0 v0 (vars0, ix0) =
Ast.AstLetDomainsIn vars l v ->
Ast.AstLetDomainsIn vars l (astGatherCase sh4 v (vars4, ix4))
Ast.AstFwd{} -> Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstFold{} -> Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstFoldRev{} -> Ast.AstGather sh4 v4 (vars4, ix4)

gatherFromNF :: forall m p. (KnownNat m, KnownNat p)
=> AstVarList m -> AstIndex (1 + p) -> Bool
Expand Down Expand Up @@ -1514,6 +1520,8 @@ astPrimalPart t = case t of
Ast.AstLetDomainsIn vars l v -> Ast.AstLetDomainsIn vars l (astPrimalPart v)
Ast.AstCond b a2 a3 -> astCond b (astPrimalPart a2) (astPrimalPart a3)
Ast.AstFwd{} -> Ast.AstPrimalPart t -- the other only normal form
Ast.AstFold{} -> Ast.AstPrimalPart t
Ast.AstFoldRev{} -> Ast.AstPrimalPart t

astPrimalPartS :: (GoodScalar r, Sh.Shape sh)
=> AstShaped FullSpan r sh -> AstShaped PrimalSpan r sh
Expand Down Expand Up @@ -1583,6 +1591,8 @@ astDualPart t = case t of
Ast.AstLetDomainsIn vars l v -> Ast.AstLetDomainsIn vars l (astDualPart v)
Ast.AstCond b a2 a3 -> astCond b (astDualPart a2) (astDualPart a3)
Ast.AstFwd{} -> Ast.AstDualPart t
Ast.AstFold{} -> Ast.AstDualPart t
Ast.AstFoldRev{} -> Ast.AstDualPart t

astDualPartS :: (GoodScalar r, Sh.Shape sh)
=> AstShaped FullSpan r sh -> AstShaped DualSpan r sh
Expand Down Expand Up @@ -1738,6 +1748,12 @@ simplifyAst t = case t of
Ast.AstFwd (var, v) l ds -> Ast.AstFwd (var, simplifyAst v)
(V.map simplifyAstDynamic l)
(V.map simplifyAstDynamic ds)
Ast.AstFold (nvar, mvar, v) x0 as ->
Ast.AstFold (nvar, mvar, simplifyAst v) (simplifyAst x0) (simplifyAst as)
Ast.AstFoldRev (nvar, mvar, v) (varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstFoldRev (nvar, mvar, simplifyAst v)
(varDt2, nvar2, mvar2, simplifyAstDomains doms)
(simplifyAst x0) (simplifyAst as)

simplifyAstDynamic
:: AstSpan s
Expand Down Expand Up @@ -2270,7 +2286,7 @@ substitute1Ast i var v1 = case v1 of
(Nothing, Nothing) -> Nothing
(ml, mv) ->
Just $ Ast.AstLetDomainsIn vars (fromMaybe l ml) (fromMaybe v mv)
Ast.AstFwd (vars, v) args ds ->
Ast.AstFwd f args ds ->
-- No other free variables in v and var is not among vars.
let margs = V.map (\(DynamicExists d) ->
DynamicExists <$> substitute1AstDynamic i var d) args
Expand All @@ -2284,7 +2300,16 @@ substitute1Ast i var v1 = case v1 of
else Nothing
in case (marg, md) of
(Nothing, Nothing) -> Nothing
_ -> Just $ Ast.AstFwd (vars, v) (fromMaybe args marg) (fromMaybe ds md)
_ -> Just $ Ast.AstFwd f (fromMaybe args marg) (fromMaybe ds md)
Ast.AstFold f x0 as ->
case (substitute1Ast i var x0, substitute1Ast i var as) of
(Nothing, Nothing) -> Nothing
(mx0, mas) -> Just $ Ast.AstFold f (fromMaybe x0 mx0) (fromMaybe as mas)
Ast.AstFoldRev f df x0 as ->
case (substitute1Ast i var x0, substitute1Ast i var as) of
(Nothing, Nothing) -> Nothing
(mx0, mas) ->
Just $ Ast.AstFoldRev f df (fromMaybe x0 mx0) (fromMaybe as mas)

substitute1AstIndex
:: (GoodScalar r2, AstSpan s2)
Expand Down
Loading

0 comments on commit 1587916

Please sign in to comment.