Skip to content

Commit

Permalink
type HasShape = KnownNats
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyday567 committed Sep 12, 2024
1 parent 75f95f4 commit b6e2e4d
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 26 deletions.
51 changes: 31 additions & 20 deletions src/NumHask/Array/Fixed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ instance
(HasShape s) =>
Data.Distributive.Distributive (Array s)
where
distribute :: (HasShape s, Functor f) => f (Array s a) -> Array s (f a)
distribute = distributeRep
{-# INLINE distribute #-}

Expand Down Expand Up @@ -1176,7 +1177,7 @@ indexesT _ a = unsafeBackpermute (S.insertDims (List.zip (shapeOf @ds) (shapeOf

-- | Select an index /except/ along specified dimensions.
--
-- >>> let s = indexesExcept (Proxy :: Proxy '[2]) [1,1] a
-- >>> let s = indexesExcept (S.SNats @'[2]) [1,1] a
-- >>> :t s
-- s :: Array '[4] Int
--
Expand All @@ -1187,13 +1188,14 @@ indexesExcept ::
( HasShape s,
HasShape ds,
HasShape s',
KnownNats ds,
s' ~ Eval (TakeDims ds s)
) =>
Proxy ds ->
SNats ds ->
[Int] ->
Array s a ->
Array s' a
indexesExcept _ i a = unsafeBackpermute (\s -> insertDims (List.zip (shapeOf @ds) s) i) a
indexesExcept ds i a = unsafeBackpermute (\s -> insertDims (List.zip (Prelude.fromIntegral <$> natVals ds) s) i) a

-- | Select the first element along the supplied dimensions
--
Expand Down Expand Up @@ -1321,7 +1323,7 @@ extracts d a = tabulate (\s -> indexes d (fromFins s) a)

-- | Extracts /except/ dimensions to an outer layer.
--
-- >>> let e = extractsExcept (Proxy :: Proxy '[1,2]) a
-- >>> let e = extractsExcept (S.SNats @[1,2]) a
-- >>> pretty $ shape <$> e
-- [[3,4],[3,4]]
extractsExcept ::
Expand All @@ -1330,15 +1332,16 @@ extractsExcept ::
HasShape ds,
HasShape si,
HasShape so,
KnownNats ds,
so ~ Eval (DeleteDims ds st),
si ~ Eval (TakeDims ds st)
) =>
Proxy ds ->
SNats ds ->
Array st a ->
Array so (Array si a)
extractsExcept d a = tabulate go
extractsExcept ds a = tabulate go
where
go s = indexesExcept d (fromFins s) a
go s = indexesExcept ds (fromFins s) a

-- | Reduce along specified dimensions, using the supplied fold.
--
Expand Down Expand Up @@ -1628,10 +1631,11 @@ contract ::
HasShape ss,
HasShape s',
s' ~ Eval (DeleteDims ds s),
ss ~ '[Eval (Minimum (Eval (TakeDims ds s)))]
ss ~ '[Eval (Minimum (Eval (TakeDims ds s)))],
KnownNats ds
) =>
(Array ss a -> b) ->
Proxy ds ->
SNats ds ->
Array s a ->
Array s' b
contract f xs a = f . diag <$> extractsExcept xs a
Expand Down Expand Up @@ -1660,7 +1664,7 @@ contract f xs a = f . diag <$> extractsExcept xs a
-- > pretty $ dot sum (*) b v
-- [14,32]
dot ::
forall a b c d sa sb s' ss se.
forall a b c d sa sb s' ss se x.
( HasShape sa,
HasShape sb,
HasShape (Eval ((++) sa sb)),
Expand All @@ -1671,15 +1675,17 @@ dot ::
KnownNat (Eval (Rank sa)),
ss ~ '[Eval (Minimum se)],
HasShape ss,
s' ~ Eval (DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb))),
HasShape s'
s' ~ Eval (DeleteDims x (Eval ((++) sa sb))),
HasShape s',
KnownNats x,
x ~ '[Eval (Rank sa) - 1, Eval (Rank sa)]
) =>
(Array ss c -> d) ->
(a -> b -> c) ->
Array sa a ->
Array sb b ->
Array s' d
dot f g a b = contract f (Proxy :: Proxy '[Eval (Rank sa) - 1, Eval (Rank sa)]) (expand g a b)
dot f g a b = contract f (SNats :: SNats x) (expand g a b)

-- | Array multiplication.
--
Expand All @@ -1704,7 +1710,7 @@ dot f g a b = contract f (Proxy :: Proxy '[Eval (Rank sa) - 1, Eval (Rank sa)])
-- > pretty $ mult b v
-- [14,32]
mult ::
forall a sa sb s' ss se.
forall a sa sb s' ss se x.
( Additive a,
Multiplicative a,
HasShape sa,
Expand All @@ -1717,8 +1723,11 @@ mult ::
KnownNat (Eval (Rank sa)),
ss ~ '[Eval (Minimum se)],
HasShape ss,
s' ~ Eval (DeleteDims '[Eval (Rank sa) - 1, Eval (Rank sa)] (Eval ((++) sa sb))),
HasShape s'
s' ~ Eval (DeleteDims x (Eval ((++) sa sb))),
HasShape s',
KnownNats x,
x ~ '[Eval (Rank sa) - 1, Eval (Rank sa)]

) =>
Array sa a ->
Array sb a ->
Expand Down Expand Up @@ -2408,7 +2417,8 @@ instance
P.Distributive a,
Subtractive a,
KnownNat m,
HasShape '[m, m]
HasShape '[m, m],
KnownNats '[1,2]
) =>
Multiplicative (Matrix m m a)
where
Expand All @@ -2423,7 +2433,8 @@ instance
Eq a,
ExpField a,
KnownNat m,
HasShape '[m, m]
HasShape '[m, m],
KnownNats '[1,2]
) =>
Divisive (Matrix m m a)
where
Expand Down Expand Up @@ -2637,7 +2648,7 @@ uniform g r = do
-- [2.1111111111111107,-0.5555555555555555,0.1111111111111111]]
--
-- > D.mult (D.inverse a) a == a
inverse :: (Eq a, ExpField a, KnownNat m) => Matrix m m a -> Matrix m m a
inverse :: (Eq a, ExpField a, KnownNat m, KnownNats [1,2]) => Matrix m m a -> Matrix m m a
inverse a = mult (invtri (transpose (chol a))) (invtri (chol a))

-- | [Inversion of a Triangular Matrix](https://math.stackexchange.com/questions/1003801/inverse-of-an-invertible-upper-triangular-matrix-of-order-3)
Expand All @@ -2649,7 +2660,7 @@ inverse a = mult (invtri (transpose (chol a))) (invtri (chol a))
-- [0.0,0.0,1.0]]
-- >>> ident == mult t (invtri t)
-- True
invtri :: forall a n. (KnownNat n, ExpField a, Eq a) => Array '[n, n] a -> Array '[n, n] a
invtri :: forall a n. (KnownNat n, KnownNats [1,2], ExpField a, Eq a) => Array '[n, n] a -> Array '[n, n] a
invtri a = sum (fmap (l ^) (iota @n)) * ti
where
ti = undiag (fmap recip (diag a))
Expand Down
65 changes: 59 additions & 6 deletions src/NumHask/Array/Shape.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ module NumHask.Array.Shape
withSomeNat,
valueOf,
int,
Shape (..),
HasShape (..),
SNats (..),
pattern SNats,
fromSNats,
KnownNats (..),
natVals,
HasShape,
shapeOf,
rankOf,
sizeOf,
Expand Down Expand Up @@ -193,6 +197,52 @@ valueOf = Prelude.fromIntegral $ natVal (Proxy :: Proxy n)
int :: SNat n -> Int
int = Prelude.fromIntegral . fromSNat

-- | Mimics SNat from GHC.TypeNats
newtype SNats (ns :: [Nat]) = UnsafeSNats [Nat]

instance (KnownNats ns) => Show (SNats ns)
where
show s = "SNats @" <> bool "" "'" (length (natVals s) < 2) <> "[" <> mconcat (List.intersperse ", " (show <$> (natVals s))) <> "]"

type role SNats nominal

pattern SNats :: forall ns. () => KnownNats ns => SNats ns
pattern SNats <- (knownNatsInstance -> KnownNatsInstance)
where SNats = natsSing

fromSNats :: SNats s -> [Nat]
fromSNats (UnsafeSNats s) = s

-- An internal data type that is only used for defining the SNat pattern
-- synonym.
data KnownNatsInstance (ns :: [Nat]) where
KnownNatsInstance :: KnownNats ns => KnownNatsInstance ns

-- An internal function that is only used for defining the SNat pattern
-- synonym.
knownNatsInstance :: SNats ns -> KnownNatsInstance ns
knownNatsInstance dims = withKnownNats dims KnownNatsInstance

-- | Reflect a list of Nats
class KnownNats (ns :: [Nat]) where
natsSing :: SNats ns

instance KnownNats '[] where
natsSing = UnsafeSNats []

instance (KnownNat n, KnownNats s) => KnownNats (n ': s)
where
natsSing = UnsafeSNats (fromSNat (SNat :: SNat n) : fromSNats (SNats :: SNats s))

natVals :: forall ns proxy. KnownNats ns => proxy ns -> [Nat]
natVals _ = case natsSing :: SNats ns of
UnsafeSNats xs -> xs

withKnownNats :: forall ns rep (r :: TYPE rep).
SNats ns -> (KnownNats ns => r) -> r
withKnownNats = withDict @(KnownNats ns)

{-
-- | The Shape type holds a [Nat] at type level and the equivalent [Int] at value level.
--
-- >>> toShape @[2,3,4]
Expand All @@ -211,29 +261,32 @@ instance HasShape '[] where
instance (KnownNat n, HasShape s) => HasShape (n : s) where
toShape = Shape $ Prelude.fromIntegral (natVal (Proxy :: Proxy n)) : shapeVal (toShape :: Shape s)
-}

type HasShape = KnownNats

-- | Supply the value-level of a 'HasShape'
-- | Supply the value-level of a 'HasShape' as an [Int]
--
-- >>> shapeOf @[2,3,4]
-- [2,3,4]
shapeOf :: forall s. (HasShape s) => [Int]
shapeOf = shapeVal (toShape @s)
shapeOf = Prelude.fromIntegral <$> natVals (Proxy :: Proxy s)
{-# INLINE shapeOf #-}

-- | The rank of a 'Shape'.
--
-- >>> rankOf @[2,3,4]
-- 3
rankOf :: forall s. (HasShape s) => Int
rankOf = length (shapeVal (toShape @s))
rankOf = length (shapeOf @s)
{-# INLINE rankOf #-}

-- | The size of a 'Shape'.
--
-- >>> sizeOf @[2,3,4]
-- 24
sizeOf :: forall s. (HasShape s) => Int
sizeOf = product (shapeVal (toShape @s))
sizeOf = product (shapeOf @s)
{-# INLINE sizeOf #-}

-- | Fin most often represents a (finite) zer-based index for a single dimension (of a multi-dimensioned hyper-rectangular array).
Expand Down

0 comments on commit b6e2e4d

Please sign in to comment.