Skip to content

Commit

Permalink
Use addTarget instead of taddLet
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 13, 2024
1 parent 200e9f9 commit 907ea35
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 35 deletions.
24 changes: 1 addition & 23 deletions src/HordeAd/Core/HVectorOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
-- API of the horde-ad library and it's relatively orthogonal to the
-- differentiation interface in "HordeAd.Core.Engine".
module HordeAd.Core.HVectorOps
( addTarget
, RepD(..)
, toRepDDuplicable, fromRepD, addRepD, addDynamic
( addTarget, RepD(..), addDynamic
, sizeHVector, shapeDynamic, dynamicsMatch, voidHVectorMatches
, voidFromDynamic, voidFromHVector, dynamicFromVoid
, fromDynamicR, fromDynamicS, unravelHVector, ravelHVector
Expand Down Expand Up @@ -74,26 +72,6 @@ data RepD target y where
DTKUntyped :: HVector target
-> RepD target TKUntyped

-- The argument of the first call (but not of recursive calls)
-- is assumed to be duplicable. In AST case, this creates
-- a tower of projections for product, but if it's balanced,
-- that's of logarithmic length, so maybe even better than sharing
-- excessively, which is hard for technical typing reasons.
toRepDDuplicable
:: BaseTensor target
=> STensorKindType x -> target x -> RepD target x
toRepDDuplicable stk t = case stk of
STKScalar _ -> DTKScalar t
STKR SNat STKScalar{} -> DTKR t
STKS sh STKScalar{} -> withKnownShS sh $ DTKS t
STKX sh STKScalar{} -> withKnownShX sh $ DTKX t
STKProduct stk1 stk2 | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2 ->
DTKProduct (toRepDDuplicable stk1 (tproject1 t))
(toRepDDuplicable stk2 (tproject2 t))
STKUntyped{} -> DTKUntyped $ dunHVector t
_ -> error "TODO"

fromRepD :: BaseTensor target
=> RepD target y -> target y
fromRepD = \case
Expand Down
16 changes: 4 additions & 12 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -661,8 +661,8 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
in tlet db $ \ !db1 ->
let dx_dbRes = tpair dx (tproject2 db1)
in tlet (unHFun rf (tpair dx_dbRes acc_e)) $ \ !daccRes_deRes ->
let added = taddLet stensorKind (tproject1 daccRes_deRes)
(tproject1 db1)
let added = addTarget stensorKind (tproject1 daccRes_deRes)
(tproject1 db1)
in tpair added (tproject2 daccRes_deRes)
p = dmapAccumRDer (Proxy @target)
k accShs codomainShs eShs
Expand Down Expand Up @@ -748,8 +748,8 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
in tlet db $ \ !db1 ->
let dx_dbRes = tpair dx (tproject2 db1)
in tlet (unHFun rf (tpair dx_dbRes acc_e)) $ \ !daccRes_deRes ->
let added = taddLet stensorKind (tproject1 daccRes_deRes)
(tproject1 db1)
let added = addTarget stensorKind (tproject1 daccRes_deRes)
(tproject1 db1)
in tpair added (tproject2 daccRes_deRes)
p = dmapAccumLDer (Proxy @target)
k accShs codomainShs eShs
Expand All @@ -773,14 +773,6 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
df rf acc0' es'
in dD (tpair accFin bs) dual

taddLet :: ADReady target
=> STensorKindType y -> target y -> target y -> target y
taddLet stk t1 t2 | Dict <- lemTensorKindOfSTK stk =
tlet t1 $ \ !u1 ->
tlet t2 $ \ !u2 ->
fromRepD $ addRepD (toRepDDuplicable stk u1)
(toRepDDuplicable stk u2)

unADValDynamicTensor
:: DynamicTensor (ADVal f)
-> (DynamicTensor f, DynamicTensor (Delta f))
Expand Down

0 comments on commit 907ea35

Please sign in to comment.