Skip to content

Commit

Permalink
Simplify AstLetDomainsIn a little
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 24, 2023
1 parent 43f58f1 commit b585479
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 17 deletions.
28 changes: 25 additions & 3 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import HordeAd.Core.AstSimplify
import HordeAd.Core.AstTools
import HordeAd.Core.TensorClass
import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances (sameShape)
import HordeAd.Internal.OrthotopeOrphanInstances (matchingRank, sameShape)
import HordeAd.Util.ShapedList (ShapedList (..))
import HordeAd.Util.SizedIndex

Expand Down Expand Up @@ -127,7 +127,19 @@ interpretAst !env = \case
`blame` (sh, rshape t, var, t, env)) t
_ -> error "interpretAst: type mismatch"
_ -> error "interpretAst: wrong shape in environment"
Just{} -> error "interpretAst: wrong tensor kind in environment"
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- Just{} -> error "interpretAst: wrong tensor kind in environment"
Just (AstEnvElemS @sh2 @r2 t) -> case shapeToList sh == Sh.shapeT @sh2 of
True -> case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> rfromS @_ @_ @r2 @sh2 t
_ -> error "interpretAst: type mismatch"
_ -> error "interpretAst: wrong rank"
False -> error $ "interpretAst: wrong shape in environment"
`showFailure`
(sh, Sh.shapeT @sh2, var, t, env)
Nothing -> error $ "interpretAst: unknown variable " ++ show var
++ " in environment " ++ show env
AstLet var u v ->
Expand Down Expand Up @@ -679,7 +691,17 @@ interpretAstS !env = \case
Nothing -> error $ "interpretAstS: wrong shape in environment"
`showFailure`
(Sh.shapeT @sh, Sh.shapeT @sh2, var, t, env)
Just{} -> error "interpretAstS: wrong tensor kind in environment"
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- Just{} -> error "interpretAstS: wrong tensor kind in environment"
Just (AstEnvElemR @n2 @r2 t) -> case matchingRank @sh @n2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> assert (Sh.shapeT @sh == shapeToList (rshape t)
`blame` (Sh.shapeT @sh, rshape t, var, t, env))
$ sfromR @_ @_ @r2 @sh t
_ -> error "interpretAstS: type mismatch"
_ -> error "interpretAstS: wrong shape in environment"
Nothing -> error $ "interpretAstS: unknown variable " ++ show var
AstLetS var u v ->
-- We assume there are no nested lets with the same variable.
Expand Down
85 changes: 80 additions & 5 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1651,18 +1651,70 @@ astLetInDomainsS var u v | astIsSmallS True u =
astLetInDomainsS var u v = Ast.AstLetInDomainsS var u v

astLetDomainsIn
:: forall n s s2 r. (AstSpan s, KnownNat n)
:: forall n s s2 r. (AstSpan s, GoodScalar r, KnownNat n)
=> [AstDynamicVarName] -> AstDomains s
-> AstRanked s2 r n
-> AstRanked s2 r n
astLetDomainsIn vars l v = Ast.AstLetDomainsIn vars l v
astLetDomainsIn vars l v =
let sh = shapeAst v
in Sh.withShapeP (shapeToList sh) $ \proxy -> case proxy of
Proxy @sh | Just Refl <- matchingRank @sh @n -> case l of
Ast.AstDomains l3 -> -- TODO: other cases: collect AstLetInDomains
let f :: (AstDynamicVarName, DynamicExists (AstDynamic s))
-> AstRanked s2 r n
-> AstRanked s2 r n
f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId)
, DynamicExists @r4 (Ast.AstRToD @n4 v3) )
acc
| Just Refl <- matchingRank @sh3 @n4
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- , Just Refl <- testEquality (typeRep @k) (typeRep @Nat)
, Just Refl <- testEquality (typeRep @r3) (typeRep @r4) =
Ast.AstLet (AstVarName varId) v3 acc
f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId)
, DynamicExists @r4 (Ast.AstSToD @sh4 v3) )
acc
| Just Refl <- sameShape @sh3 @sh4
, Just Refl <- testEquality (typeRep @r3) (typeRep @r4) =
Ast.AstSToR @sh
$ Ast.AstLetS (AstVarName varId) v3 $ Ast.AstRToS acc
f _ _ = error "astLetDomainsIn: corrupted arguments"
in foldr f v (zip vars (V.toList l3))
_ -> Ast.AstLetDomainsIn vars l v
_ -> error "astLetDomainsIn: wrong rank of the argument"

astLetDomainsInS
:: forall sh s s2 r. (AstSpan s, Sh.Shape sh)
=> [AstDynamicVarName] -> AstDomains s
-> AstShaped s2 r sh
-> AstShaped s2 r sh
astLetDomainsInS vars l v = Ast.AstLetDomainsInS vars l v
astLetDomainsInS vars l v =
case someNatVal $ toInteger (length (Sh.shapeT @sh)) of
Just (SomeNat @n _) -> gcastWith (unsafeCoerce Refl :: n :~: Sh.Rank sh)
$ case l of
Ast.AstDomains l3 -> -- TODO: other cases: collect AstLetInDomainsS
let f :: (AstDynamicVarName, DynamicExists (AstDynamic s))
-> AstShaped s2 r sh
-> AstShaped s2 r sh
f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId)
, DynamicExists @r4 (Ast.AstRToD @n4 v3) )
acc
| Just Refl <- matchingRank @sh3 @n4
, Just Refl <- testEquality (typeRep @r3) (typeRep @r4) =
Ast.AstRToS @sh
$ Ast.AstLet (AstVarName varId) v3 $ Ast.AstSToR acc
f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId)
, DynamicExists @r4 (Ast.AstSToD @sh4 v3) )
acc
| Just Refl <- sameShape @sh3 @sh4
, Just Refl <- testEquality (typeRep @r3) (typeRep @r4) =
Ast.AstLetS (AstVarName varId) v3 acc
f _ _ = error "astLetDomainsInS: corrupted arguments"
in foldr f v (zip vars (V.toList l3))
_ -> Ast.AstLetDomainsInS vars l v
_ -> error "astLetDomainsInS: impossible someNatVal"


-- * The simplifying bottom-up pass
Expand Down Expand Up @@ -2211,7 +2263,19 @@ substitute1Ast i var v1 = case v1 of
_ -> error "substitute1Ast: scalar"
_ -> error "substitute1Ast: rank"
_ -> error "substitute1Ast: span"
_ -> error "substitute1Ast: type"
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- _ -> error "substitute1Ast: type"
SubstitutionPayloadShaped @_ @_ @sh2 t -> case sameAstSpan @s @s2 of
Just Refl -> case shapeToList sh == Sh.shapeT @sh2 of
True -> case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> Just $ astSToR t
_ -> error "substitute1Ast: scalar"
_ -> error "substitute1Ast: rank"
False -> error "substitute1Ast: shape"
_ -> error "substitute1Ast: span"
else Nothing
Ast.AstLet var2 u v ->
case (substitute1Ast i var u, substitute1Ast i var v) of
Expand Down Expand Up @@ -2464,7 +2528,18 @@ substitute1AstS i var = \case
_ -> error "substitute1AstS: scalar"
_ -> error "substitute1AstS: shape"
_ -> error "substitute1Ast: span"
_ -> error "substitute1AstS: type"
-- To impose such checks, we'd need to switch from OD tensors
-- to existential OR/OS tensors so that we can inspect
-- which it is and then seed Delta evaluation maps with that.
-- _ -> error "substitute1AstS: type"
SubstitutionPayloadRanked @_ @_ @m t -> case sameAstSpan @s @s2 of
Just Refl -> case matchingRank @sh @m of
Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of
Just Refl -> assert (Sh.shapeT @sh == shapeToList (shapeAst t))
$ Just $ astRToS t
_ -> error "substitute1Ast: scalar"
_ -> error "substitute1Ast: rank"
_ -> error "substitute1Ast: span"
else Nothing
Ast.AstLetS var2 u v ->
case (substitute1AstS i var u, substitute1AstS i var v) of
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ isTensorDummyAst t = case t of

-- TODO: move the impure part to AstFreshId
astLetDomainsInFun
:: forall n s r. (AstSpan s, KnownNat n)
:: forall n s r. (AstSpan s, GoodScalar r, KnownNat n)
=> DomainsOD -> AstDomains s -> (Domains (AstDynamic s) -> AstRanked s r n)
-> AstRanked s r n
{-# NOINLINE astLetDomainsInFun #-}
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class ( Integral (IntOf ranked), CRanked ranked Num
rzero :: (GoodScalar r, KnownNat n)
=> ShapeInt n -> ranked r n
rzero sh = rreplicate0N sh 0
rletDomainsIn :: KnownNat n
rletDomainsIn :: (KnownNat n, GoodScalar r)
=> DomainsOD
-> DomainsOf ranked
-> (Domains (DynamicOf ranked) -> ranked r n)
Expand Down
Loading

0 comments on commit b585479

Please sign in to comment.