Skip to content

Commit

Permalink
Capture the normal form of UnWind in RepD
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 10, 2024
1 parent 2972f15 commit 707a1b0
Showing 1 changed file with 16 additions and 32 deletions.
48 changes: 16 additions & 32 deletions src/HordeAd/Core/HVectorOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Util.SizedList

-- This captures the normal form of type family UnWind and also
-- corresponds to the portion of ox-arrays that has Num defined.
type role RepD nominal nominal
data RepD target y where
DTKScalar :: GoodScalar r
Expand Down Expand Up @@ -125,6 +127,8 @@ addRepD ::
addRepD a b = case (a, b) of
(DTKScalar ta, DTKScalar tb) ->
DTKScalar $ rtoScalar $ rfromScalar ta + rfromScalar tb
-- somehow this results in shorter terms than @ta + tb@
-- TODO: re-evaluate once scalar term simplification is complete
(DTKR ta, DTKR tb) -> DTKR $ ta + tb
(DTKS ta, DTKS tb) -> DTKS $ ta + tb
(DTKX ta, DTKX tb) -> DTKX $ ta + tb
Expand Down Expand Up @@ -174,6 +178,7 @@ type family UnWind tk where
UnWind TKUntyped =
TKUntyped

-- TODO: should be unused now that we removed addWindShare?
unWindSTK :: STensorKindType y -> STensorKindType (UnWind y)
unWindSTK = \case
stk@STKScalar{} -> stk
Expand Down Expand Up @@ -214,29 +219,29 @@ unWindSTK = \case
STKS _ STKUntyped -> error "unWindSTK: STKUntyped can't be nested in arrays"
STKX _ STKUntyped -> error "unWindSTK: STKUntyped can't be nested in arrays"

-- Alternatively the codomain could be RepD, which clearly indicates
-- what the normal form of UnWind is.
unWindShare :: (BaseTensor target, ShareTensor target)
=> STensorKindType y -> target y -> target (UnWind y)
=> STensorKindType y -> target y -> RepD target (UnWind y)
unWindShare stk t = case stk of
STKScalar{} -> t
STKR _ STKScalar{} -> t
STKS _ STKScalar{} -> t
STKScalar{} -> DTKScalar t
STKR SNat STKScalar{} -> DTKR t
STKS sh STKScalar{} -> withKnownShS sh $ DTKS t
STKS sh (STKS sh2 stk2) | Dict <- lemTensorKindOfSTK stk2 ->
withKnownShS sh $ withKnownShS sh2 $ withKnownShS (shsAppend sh sh2)
$ unWindShare (STKS (shsAppend sh sh2) stk2) (sunNest t)
STKS sh (STKProduct stk1 stk2) | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2 ->
withKnownShS sh
$ unWindShare (STKProduct (STKS sh stk1) (STKS sh stk2)) (sunzip t)
STKX _ STKScalar{} -> t
STKX sh STKScalar{} -> withKnownShX sh $ DTKX t
STKProduct stk1 stk2 | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2
, (Dict, Dict) <- lemTensorKind1OfSTK (unWindSTK stk1)
, (Dict, Dict) <- lemTensorKind1OfSTK (unWindSTK stk2) ->
let (t1, t2) = tunpair t
in tpair (unWindShare stk1 t1) (unWindShare stk2 t2)
STKUntyped -> t
in DTKProduct (unWindShare stk1 t1) (unWindShare stk2 t2)
STKUntyped ->
let vt = tunvector t
in DTKUntyped vt
_ -> error "TODO"

windShare :: (BaseTensor target, ShareTensor target)
Expand All @@ -262,35 +267,14 @@ windShare stk t = case stk of
STKUntyped -> t
_ -> error "TODO"

addWindShare ::
(ADReadyNoLet target, ShareTensor target)
=> STensorKindType y
-> target y -> target y -> target y
addWindShare stk a b = case stk of
STKScalar{} -> a + b
STKR SNat STKScalar{} -> a + b
STKS sh STKScalar{} -> withKnownShS sh $ a + b
STKX sh STKScalar{} -> withKnownShX sh $ a + b
STKProduct stk1 stk2 | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2 ->
let (a1, a2) = tunpair a
(b1, b2) = tunpair b
in tpair (addWindShare stk1 a1 b1) (addWindShare stk2 a2 b2)
STKUntyped ->
let va = tunvector a
vb = tunvector b
in dmkHVector $ V.zipWith addDynamic va vb
_ -> error "addWindShare: impossible normal form of UnWind"

addShare ::
(ADReadyNoLet target, ShareTensor target)
=> STensorKindType y
-> target y -> target y -> target y
addShare stk a b =
let stk2 = unWindSTK stk
a2 = unWindShare stk a
let a2 = unWindShare stk a
b2 = unWindShare stk b
in windShare stk $ addWindShare stk2 a2 b2
in windShare stk $ fromRepD $ addRepD a2 b2


-- * Dynamic
Expand Down

0 comments on commit 707a1b0

Please sign in to comment.