Skip to content

Commit

Permalink
Let rrev take and return dual components but not confuse cotangents
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 3, 2023
1 parent 66654d1 commit d715418
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
15 changes: 7 additions & 8 deletions src/HordeAd/Core/TensorADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,10 @@ instance ( Dual ranked ~ DeltaR ranked shaped
ddummy = undefined
dshape = undefined

instance ( ADReadySmall (ADVal ranked) (ADVal shaped)
, DomainsTensor ranked shaped
, DualPart ranked
, Dual (Clown (DynamicOf ranked)) ~ DeltaD ranked shaped )
instance ( ADReadyBoth (ADVal (ADVal ranked)) (ADVal (ADVal shaped))
, DualPart (ADVal ranked)
, Dual (Clown (DynamicOf (ADVal ranked)))
~ DeltaD (ADVal ranked) (ADVal shaped) )
=> DomainsTensor (ADVal ranked) (ADVal shaped) where
dmkDomains = id
rletDomainsOf = (&)
Expand All @@ -489,8 +489,7 @@ instance ( ADReadySmall (ADVal ranked) (ADVal shaped)
rrev :: (GoodScalar r, KnownNat n)
=> (forall f. ADReady f => Domains (DynamicOf f) -> f r n)
-> DomainsOD
-> Domains (DynamicOf ranked)
-> DomainsOf ranked
-> Domains (DynamicOf (ADVal ranked))
-> DomainsOf (ADVal ranked)
rrev f _parameters0 parameters =
dmkDomains @ranked $ fst
$ crevOnDomains Nothing (f @(ADVal ranked)) parameters
fst $ crevOnDomains Nothing (f @(ADVal (ADVal ranked))) parameters
4 changes: 2 additions & 2 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ instance AstSpan s => DomainsTensor (AstRanked s) (AstShaped s) where
rrev :: GoodScalar r
=> (forall f. ADReady f => Domains (DynamicOf f) -> f r n)
-> DomainsOD
-> Domains (AstDynamic PrimalSpan)
-> AstDomains PrimalSpan
-> Domains (AstDynamic s)
-> AstDomains s
rrev f parameters0 domains =
AstRev (funToAstDomains @PrimalSpan f parameters0)
parameters0 (AstDomains domains)
Expand Down
14 changes: 9 additions & 5 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -600,13 +600,13 @@ class DomainsTensor (ranked :: RankedTensorKind)
-- and the third has to have the same shapes as the second.
--
-- The function argument needs to be quantified (or an AST term),
-- because otherwise in the ADVal instance one could put illegal InputR there.
-- For the same reason there is PrimalOf in the last argument.
-- because otherwise in the ADVal instance one could put an illegal
-- InputR there, confusing two levels of contangents.
rrev :: (GoodScalar r, KnownNat n)
=> (forall f. ADReady f => Domains (DynamicOf f) -> f r n)
-> DomainsOD
-> Domains (DynamicOf (PrimalOf ranked))
-> DomainsOf (PrimalOf ranked)
-> Domains (DynamicOf ranked)
-> DomainsOf ranked


-- * The giga-constraint
Expand Down Expand Up @@ -656,7 +656,11 @@ type ADReadySmall ranked shaped =

type ADReadyBoth ranked shaped =
( ADReadySmall ranked shaped
, DomainsTensor ranked shaped
-- TODO: this doesn't type-check and not because rrev uses it,
-- but probably because ADVal instance of DomainsTensor uses ADReady
-- at one more ADVal nesting level:
--, DomainsTensor ranked shaped
-- so we can't nest rrev right now
, DomainsTensor (PrimalOf ranked) (PrimalOf shaped) )


Expand Down

0 comments on commit d715418

Please sign in to comment.