Skip to content

Commit

Permalink
Pass the rrev tests through rev'
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 11, 2023
1 parent 497accd commit 218b958
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 52 deletions.
17 changes: 13 additions & 4 deletions src/HordeAd/Core/DualNumber.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE QuantifiedConstraints, UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes, QuantifiedConstraints,
UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Dual numbers and arithmetic operations on them. This is a part of
Expand Down Expand Up @@ -241,13 +242,21 @@ instance (GoodScalar r, Sh.Shape sh, ShapedTensor (ADVal shaped))
LetS{} -> d -- should not happen, but older/lower id is safer anyway
_ -> wrapDeltaS d

instance IsPrimal (Clown (ADValClown dynamic)) r '() where
dZeroOfShape = undefined
instance ( GoodScalar r
, dynamic ~ DynamicOf (ShapedOf @() (Clown dynamic))
, ConvertTensor (RankedOf @() (Clown dynamic))
(ShapedOf @() (Clown dynamic)) )
=> IsPrimal (Clown (Flip (ADVal (Clown dynamic)) '())) r '() where
dZeroOfShape (Clown (Flip (D _ (Clown tsh) _))) =
let shL = dshape @(RankedOf @() (Clown dynamic)) tsh
in case someNatVal $ toInteger $ length shL of
Just (SomeNat @n _) -> RToD @n (ZeroR (listShapeToShape shL))
Nothing -> error "dZeroOfShape: impossible someNatVal error"
dScale = undefined
dAdd = undefined
intOfShape = undefined
recordSharingPrimal = undefined
recordSharing = undefined
recordSharing = undefined


-- * Auxiliary definitions
Expand Down
72 changes: 51 additions & 21 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (testEquality, (:~:) (Refl))
import qualified Data.Vector.Generic as V
import GHC.TypeLits (KnownNat, Nat, sameNat, type (+))
import GHC.TypeLits
(KnownNat, Nat, SomeNat (..), sameNat, someNatVal, type (+))
import System.IO.Unsafe (unsafePerformIO)
import Type.Reflection (typeRep)

Expand Down Expand Up @@ -90,13 +91,17 @@ instance (GoodScalar r, Sh.Shape sh)
LetS{} -> d -- should not happen, but older/lower id is safer anyway
_ -> wrapDeltaS d

instance IsPrimal (Clown (AstDynamic s)) r '() where
dZeroOfShape = undefined
instance GoodScalar r => IsPrimal (Clown (AstDynamic PrimalSpan)) r '() where
dZeroOfShape (Clown tsh) =
let shL = dshape @(AstRanked PrimalSpan) tsh
in case someNatVal $ toInteger $ length shL of
Just (SomeNat @n _) -> RToD @n (ZeroR (listShapeToShape shL))
Nothing -> error "dZeroOfShape: impossible someNatVal error"
dScale = undefined
dAdd = undefined
intOfShape = undefined
recordSharingPrimal = undefined
recordSharing = undefined
recordSharing = undefined


-- * Reverse and forward derivative stages instances
Expand Down Expand Up @@ -846,16 +851,28 @@ instance AstSpan s
rD u u' = AstNoVectorize $ astSpanD u u'
rScale s t = astDualPart $ AstConstant s * AstD (rzero (rshape s)) t

instance AstSpan s
=> ShapedTensor (AstNoVectorizeS s) where

instance ConvertTensor (AstNoVectorize 'PrimalSpan)
(AstNoVectorizeS 'PrimalSpan) where

instance DomainsTensor (AstNoVectorize s) (AstNoVectorizeS s) where

instance AstSpan s
=> RankedTensor (AstNoSimplify s) where
instance AstSpan s => ShapedTensor (AstNoVectorizeS s) where

instance AstSpan s => ConvertTensor (AstNoVectorize s) (AstNoVectorizeS s) where
rfromD = AstNoVectorize . rfromD @(AstRanked s)
rfromS = AstNoVectorize . rfromS @(AstRanked s) . unAstNoVectorizeS
dfromR = dfromR @(AstRanked s) . unAstNoVectorize
dfromS = dfromS @(AstRanked s) . unAstNoVectorizeS
sfromR = AstNoVectorizeS . sfromR @(AstRanked s) . unAstNoVectorize
sfromD = AstNoVectorizeS . sfromD @(AstRanked s)
ddummy = ddummy @(AstRanked s)
dIsDummy = dIsDummy @(AstRanked s)
dshape = dshape @(AstRanked s)

instance AstSpan s => DomainsTensor (AstNoVectorize s) (AstNoVectorizeS s) where
dmkDomains = dmkDomains @(AstRanked s)
rletInDomains u f =
rletInDomains @(AstRanked s) (unAstNoVectorize u) (f . AstNoVectorize)
sletInDomains u f =
sletInDomains @(AstRanked s) (unAstNoVectorizeS u) (f . AstNoVectorizeS)
rrev = rrev @(AstRanked s)

instance AstSpan s => RankedTensor (AstNoSimplify s) where
rlet a f =
AstNoSimplify
$ astLetFunUnSimp (unAstNoSimplify a) (unAstNoSimplify . f . AstNoSimplify)
Expand Down Expand Up @@ -910,10 +927,23 @@ astLetFunUnSimp a f =
(var, ast) = funToAstR sh f
in AstLet var a ast

instance AstSpan s
=> ShapedTensor (AstNoSimplifyS s) where

instance ConvertTensor (AstNoSimplify 'PrimalSpan)
(AstNoSimplifyS 'PrimalSpan) where

instance DomainsTensor (AstNoSimplify s) (AstNoSimplifyS s) where
instance AstSpan s => ShapedTensor (AstNoSimplifyS s) where

instance AstSpan s => ConvertTensor (AstNoSimplify s) (AstNoSimplifyS s) where
rfromD = AstNoSimplify . rfromD @(AstRanked s)
rfromS = AstNoSimplify . rfromS @(AstRanked s) . unAstNoSimplifyS
dfromR = dfromR @(AstRanked s) . unAstNoSimplify
dfromS = dfromS @(AstRanked s) . unAstNoSimplifyS
sfromR = AstNoSimplifyS . sfromR @(AstRanked s) . unAstNoSimplify
sfromD = AstNoSimplifyS . sfromD @(AstRanked s)
ddummy = ddummy @(AstRanked s)
dIsDummy = dIsDummy @(AstRanked s)
dshape = dshape @(AstRanked s)

instance AstSpan s => DomainsTensor (AstNoSimplify s) (AstNoSimplifyS s) where
dmkDomains = dmkDomains @(AstRanked s)
rletInDomains u f =
rletInDomains @(AstRanked s) (unAstNoSimplify u) (f . AstNoSimplify)
sletInDomains u f =
sletInDomains @(AstRanked s) (unAstNoSimplifyS u) (f . AstNoSimplifyS)
rrev = rrev @(AstRanked s)
75 changes: 48 additions & 27 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,14 @@ testTrees =
, testCase "2Sin0Rrev" testSin0Rrev
, testCase "2Sin0RrevPP1" testSin0RrevPP1
, testCase "2Sin0RrevPP2" testSin0RrevPP2
, testCase "2Sin0Rrev2" testSin0Rrev2
, testCase "2Sin0Rrev3" testSin0Rrev3
, testCase "2Sin0RrevPP3" testSin0RrevPP3
, testCase "2Sin0Rrev4" testSin0Rrev4
, testCase "2Sin0RrevPP4" testSin0RrevPP4
, testCase "2Sin0Rrev5" testSin0Rrev5
, testCase "2Sin0RrevPP5" testSin0RrevPP5
, testCase "2Sin0Rrev3'" testSin0Rrev3'
, testCase "2Sin0Rrev4'" testSin0Rrev4'
, testCase "2Sin0Rrev5'" testSin0Rrev5'
]

testZero :: Assertion
Expand Down Expand Up @@ -1986,9 +1989,9 @@ testFooRrev3 = do
0
(crev f 1.1)

sin0Rrev :: forall g a. (ADReady g, GoodScalar a)
=> (forall f. ADReady f => f a 0 -> f a 0) -> g a 0 -> g a 0
sin0Rrev f u =
rrev00 :: forall g a. (ADReady g, GoodScalar a)
=> (forall f. ADReady f => f a 0 -> f a 0) -> g a 0 -> g a 0
rrev00 f u =
let fromDynamicExists :: forall f. ADReady f
=> DynamicExists (DynamicOf f) -> f a 0
fromDynamicExists (DynamicExists @r d)
Expand All @@ -1997,9 +2000,9 @@ sin0Rrev f u =
fromDoms :: forall f. ADReady f
=> Domains (DynamicOf f) -> f a 0
fromDoms v = fromDynamicExists $ v V.! 0
fooDomains :: forall f. ADReady f
=> Domains (DynamicOf f) -> f a 0
fooDomains v = f (fromDoms v)
fDomains :: forall f. ADReady f
=> Domains (DynamicOf f) -> f a 0
fDomains v = f (fromDoms v)
toDynamicExists :: forall f. ADReady f
=> f a 0 -> DynamicExists (DynamicOf f)
toDynamicExists a = DynamicExists $ dfromR a
Expand All @@ -2008,7 +2011,7 @@ sin0Rrev f u =
shapes = V.fromList [zero]
domsOf =
rrev @g
fooDomains
fDomains
shapes
(V.fromList $ map (toDynamicExists @g) [u])
in rletDomainsIn shapes domsOf (\v -> fromDynamicExists $ v V.! 0)
Expand All @@ -2017,48 +2020,66 @@ testSin0Rrev :: Assertion
testSin0Rrev = do
assertEqualUpToEpsilon 1e-10
0.4535961214255773
(sin0Rrev @(Flip OR.Array) @Double sin 1.1)
(rrev00 @(Flip OR.Array) @Double sin 1.1)

testSin0RrevPP1 :: Assertion
testSin0RrevPP1 = do
resetVarCounter
let a1 = sin0Rrev @(AstRanked FullSpan) @Double sin 1.1
let a1 = rrev00 @(AstRanked FullSpan) @Double sin 1.1
printAstPretty IM.empty a1
@?= "rletDomainsIn (cos (rconst 1.1) * rreshape [] (rreplicate 1 (rconst 1.0))) (\\[dret] -> dret)"

testSin0RrevPP2 :: Assertion
testSin0RrevPP2 = do
let a1 = sin0Rrev @(AstRanked FullSpan) @Double sin 1.1
let a1 = rrev00 @(AstRanked FullSpan) @Double sin 1.1
printAstSimple IM.empty a1
@?= "rletDomainsIn (dmkDomains (fromList [dfromR (cos (rconst 1.1) * rreshape [] (rreplicate 1 (rconst 1.0)))])) (\\[dret] -> dret)"

testSin0Rrev2 :: Assertion
testSin0Rrev2 = do
let f = sin0Rrev @(ADVal (Flip OR.Array)) @Double sin
testSin0Rrev3 :: Assertion
testSin0Rrev3 = do
let f = rrev00 @(ADVal (Flip OR.Array)) @Double sin
assertEqualUpToEpsilon 1e-10
(-0.8912073600614354)
(crev f 1.1)

testSin0Rrev3 :: Assertion
testSin0Rrev3 = do
testSin0Rrev4 :: Assertion
testSin0Rrev4 = do
assertEqualUpToEpsilon 1e-10
0.8988770945225438
((sin0Rrev sin . sin0Rrev @(Flip OR.Array) @Double sin) 1.1)
((rrev00 sin . rrev00 @(Flip OR.Array) @Double sin) 1.1)

testSin0RrevPP3 :: Assertion
testSin0RrevPP3 = do
let a1 = (sin0Rrev sin . sin0Rrev @(AstRanked FullSpan) @Double sin) 1.1
testSin0RrevPP4 :: Assertion
testSin0RrevPP4 = do
let a1 = (rrev00 sin . rrev00 @(AstRanked FullSpan) @Double sin) 1.1
printAstPretty IM.empty (simplifyAst6 a1)
@?= "rletDomainsIn (cos (rletDomainsIn (cos (rconst 1.1) * rconst 1.0) (\\[dret] -> dret)) * rconst 1.0) (\\[x4] -> x4)"

testSin0Rrev4 :: Assertion
testSin0Rrev4 = do
testSin0Rrev5 :: Assertion
testSin0Rrev5 = do
assertEqualUpToEpsilon 1e-10
(-0.8912073600614354)
(sin0Rrev @(Flip OR.Array) @Double (sin0Rrev sin) 1.1)
(rrev00 @(Flip OR.Array) @Double (rrev00 sin) 1.1)

testSin0RrevPP4 :: Assertion
testSin0RrevPP4 = do
let a1 = sin0Rrev @(AstRanked FullSpan) @Double (sin0Rrev sin) 1.1
testSin0RrevPP5 :: Assertion
testSin0RrevPP5 = do
let a1 = rrev00 @(AstRanked FullSpan) @Double (rrev00 sin) 1.1
printAstPretty IM.empty (simplifyAst6 a1)
@?= "rletDomainsIn (negate (sin (rconst 1.1)) * (rconst 1.0 * rconst 1.0)) (\\[x7] -> x7)"

testSin0Rrev3' :: Assertion
testSin0Rrev3' = do
assertEqualUpToEpsilon' 1e-10
(-0.8912073600614354 :: OR.Array 0 Double)
(rev' (rrev00 sin) 1.1)

testSin0Rrev4' :: Assertion
testSin0Rrev4' = do
assertEqualUpToEpsilon' 1e-10
(0.39052780643689855 :: OR.Array 0 Double)
(rev' (rrev00 sin . rrev00 sin) 1.1)

testSin0Rrev5' :: Assertion
testSin0Rrev5' = do
assertEqualUpToEpsilon' 1e-10
(-0.4535961214255773 :: OR.Array 0 Double)
(rev' (rrev00 (rrev00 sin)) 1.1)

0 comments on commit 218b958

Please sign in to comment.