Skip to content

Commit

Permalink
Get rid of unwrapAstDomains
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 6, 2023
1 parent 1cc7b3a commit b4b56d0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 38 deletions.
10 changes: 1 addition & 9 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module HordeAd.Core.AstTools
-- * Determining if a term is too small to require sharing
, astIsSmall, astIsSmallS
-- * Odds and ends
, unwrapAstDomains, bindsToLet, bindsToLetS, bindsToDomainsLet
, bindsToLet, bindsToLetS, bindsToDomainsLet
) where

import Prelude hiding (foldl')
Expand All @@ -21,7 +21,6 @@ import qualified Data.Array.RankedS as OR
import qualified Data.Array.Shape as OS
import Data.List (foldl')
import Data.Proxy (Proxy (Proxy))
import qualified Data.Strict.Vector as Data.Vector
import Data.Type.Equality (gcastWith, (:~:) (Refl))
import qualified Data.Vector.Generic as V
import GHC.TypeLits
Expand Down Expand Up @@ -279,13 +278,6 @@ astIsSmallS relaxed = \case

-- * Odds and ends

unwrapAstDomains :: AstDomains s
-> Data.Vector.Vector (DynamicExists (AstDynamic s))
unwrapAstDomains = \case
AstDomains l -> l
AstDomainsLet _ _ v -> unwrapAstDomains v
AstDomainsLetS _ _ v -> unwrapAstDomains v

bindsToLet :: forall n s r. (KnownNat n, GoodScalar r, AstSpan s)
=> AstRanked s r n -> AstBindings (AstRanked s)
-> AstRanked s r n
Expand Down
48 changes: 19 additions & 29 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -246,28 +246,27 @@ build1V k (var, v00) =
(build1VOccurenceUnknown k (var, u'))
Ast.AstLetDomains @s1 vars l v ->
-- Here substitution traverses @v@ term tree @length vars@ times.
let subst (var1, DynamicExists (AstRToD u1)) =
let sh = shapeAst u1
projection =
Ast.AstIndex (Ast.AstVar (k :$ sh) $ AstVarName var1)
(Ast.AstIntVar var :. ZI)
in substituteAst (SubstitutionPayloadRanked @s1 @r projection)
(AstVarName var1)
subst (var1, DynamicExists (AstSToD @sh1 _)) =
--
-- We lose the type information surrounding var1 twice: first,
-- because we create a variable with one more dimension,
-- again, because the original variables might have been marked
-- with AstShaped and here we require AstRanked.
let subst (AstDynamicVarName @_ @sh1 (AstVarName var1)) =
let ls = OS.shapeT @sh1
in case someNatVal $ toInteger (length ls) of
Just (SomeNat @n2 _proxy) ->
let sh = listShapeToShape @n2 ls
Just (SomeNat @n2 _) ->
let shV = listShapeToShape @n2 ls
projection =
Ast.AstIndex (Ast.AstVar (k :$ sh) $ AstVarName var1)
Ast.AstIndex (Ast.AstVar (k :$ shV) $ AstVarName var1)
(Ast.AstIntVar var :. ZI)
in substituteAst (SubstitutionPayloadRanked @s1 @r projection)
(AstVarName var1)
Nothing -> error "build1V: impossible someNatVal error"
v2 = foldr subst v (zip (map dynamicVarNameToAstVarId vars)
(V.toList $ unwrapAstDomains l))
in Ast.AstLetDomains vars (build1VOccurenceUnknownDomains k (var, l))
(build1VOccurenceUnknownRefresh k (var, v2))
v2 = foldr subst v vars
in Ast.AstLetDomains
vars (build1VOccurenceUnknownDomains k (var, l))
(build1VOccurenceUnknownRefresh k (var, v2))
-- TODO: comment why @r instead of @r1 from AstDynamicVarName

build1VOccurenceUnknownDynamic
:: AstSpan s
Expand Down Expand Up @@ -570,26 +569,17 @@ build1VS (var, v00) =
Ast.AstDS (build1VOccurenceUnknownS (var, u))
(build1VOccurenceUnknownS (var, u'))
Ast.AstLetDomainsS @s1 vars l v ->
-- Here substitution traverses @v@ term tree @length vars@ times.
let subst (var1, DynamicExists (AstRToD u1)) =
OS.withShapeP (shapeToList $ shapeAst u1) $ \(_ :: Proxy sh1) ->
let projection =
Ast.AstIndexS (Ast.AstVarS @(k ': sh1) $ AstVarName var1)
(Ast.AstIntVar var :$: ZSH)
in substituteAstS (SubstitutionPayloadShaped @s1 @r projection)
(AstVarName var1)
subst (var1, DynamicExists (AstSToD @sh1 _)) =
-- See the AstLetDomains case for comments.
let subst (AstDynamicVarName @_ @sh1 (AstVarName var1)) =
let projection =
Ast.AstIndexS (Ast.AstVarS @(k ': sh1) $ AstVarName var1)
(Ast.AstIntVar var :$: ZSH)
in substituteAstS (SubstitutionPayloadShaped @s1 @r projection)
(AstVarName var1)
v2 = foldr subst v (zip (map dynamicVarNameToAstVarId vars)
(V.toList $ unwrapAstDomains l))
v2 = foldr subst v vars
in Ast.AstLetDomainsS
vars
(build1VOccurenceUnknownDomains (valueOf @k) (var, l))
(build1VOccurenceUnknownRefreshS (var, v2))
vars (build1VOccurenceUnknownDomains (valueOf @k) (var, l))
(build1VOccurenceUnknownRefreshS (var, v2))

build1VIndexS
:: forall k p sh s r.
Expand Down

0 comments on commit b4b56d0

Please sign in to comment.