diff --git a/src/NumHask/Array/Dynamic.hs b/src/NumHask/Array/Dynamic.hs index 51feef1..69db2ee 100644 --- a/src/NumHask/Array/Dynamic.hs +++ b/src/NumHask/Array/Dynamic.hs @@ -22,6 +22,10 @@ module NumHask.Array.Dynamic unsafeModifyShape, unsafeModifyVector, + -- * Dimensions + Dim, + Dims, + -- * Conversion FromVector (..), FromArray (..), @@ -388,6 +392,12 @@ unsafeModifyShape f (UnsafeArray s v) = UnsafeArray (f s) v unsafeModifyVector :: (FromVector u a) => (FromVector v b) => (u -> v) -> Array a -> Array b unsafeModifyVector f (UnsafeArray s v) = UnsafeArray s (asVector (f (vectorAs v))) +-- | Representation of an index into a shape (an [Int]). The index is a dimension of the shape. +type Dim = Int + +-- | Representation of indexes into a shape (an [Int]). The indexes are dimensions of the shape. +type Dims = [Int] + -- | shape of an Array -- -- >>> shape a @@ -642,7 +652,7 @@ imap f a = zipWith f (indices (shape a)) a -- -- >>> rowWise indexes [1,0] a -- UnsafeArray [4] [12,13,14,15] -rowWise :: ([Int] -> [x] -> Array a -> Array a) -> [x] -> Array a -> Array a +rowWise :: (Dims -> [x] -> Array a -> Array a) -> [x] -> Array a -> Array a rowWise f xs a = f [0..(S.rank xs - 1)] xs a -- | Apply a function that takes dimensions & parameters and applies a parameter list to the the last dimensions (in reverse). ie @@ -651,14 +661,14 @@ rowWise f xs a = f [0..(S.rank xs - 1)] xs a -- -- >>> colWise indexes [1,0] a -- UnsafeArray [2] [1,13] -colWise :: ([Int] -> [x] -> Array a -> Array a) -> [x] -> Array a -> Array a +colWise :: (Dims -> [x] -> Array a -> Array a) -> [x] -> Array a -> Array a colWise f xs a = f (List.reverse [(rank a - (S.rank xs)) .. (rank a - 1)]) xs a -- | Apply a function that takes a dimension and parameter, and folds a (dimension,parameter) list over an array. -- -- >>> dimsWise take [0,2] [1,2] a -- UnsafeArray [1,3,2] [0,1,4,5,8,9] -dimsWise :: (Int -> x -> Array a -> Array a) -> [Int] -> [x] -> Array a -> Array a +dimsWise :: (Dim -> x -> Array a -> Array a) -> Dims -> [x] -> Array a -> Array a dimsWise f ds xs a = foldl' (\a' (d, x) -> f d x a') a (List.zip ds xs) -- | Take the top-most elements across the specified dimension. Negative values take the bottom-most. No index check is performed. @@ -680,7 +690,7 @@ dimsWise f ds xs a = foldl' (\a' (d, x) -> f d x a') a (List.zip ds xs) -- [19], -- [23]]] take :: - Int -> + Dim -> Int -> Array a -> Array a @@ -706,7 +716,7 @@ take d t a = backpermute dsNew (modifyDim d (\x -> x + bool 0 (d' + t) (t < 0))) -- [16,17,18], -- [20,21,22]]] drop :: - Int -> + Dim -> Int -> Array a -> Array a @@ -721,7 +731,7 @@ drop d t a = backpermute dsNew (modifyDim d (\x -> x + bool t 0 (t < 0))) a -- [[3,7,11], -- [15,19,23]] select :: - Int -> + Dim -> Int -> Array a -> Array a @@ -739,7 +749,7 @@ select d x a = backpermute (deleteDim d) (insertDim d x) a -- >>> D.insert 0 0 (D.toScalar 1) (D.toScalar 2) -- UnsafeArray [2] [2,1] insert :: - Int -> + Dim -> Int -> Array a -> Array a -> @@ -761,7 +771,7 @@ insert d i a b = tabulate (incAt d (shape (asSingleton a))) go -- [16,17,18], -- [20,21,22]]] delete :: - Int -> + Dim -> Int -> Array a -> Array a @@ -777,7 +787,7 @@ delete d i a = backpermute (decAt d) (\s -> bool s (incAt d s) (s !! d < i)) (as -- [16,17,18,19,0], -- [20,21,22,23,0]]] append :: - Int -> + Dim -> Array a -> Array a -> Array a @@ -793,7 +803,7 @@ append d a b = insert d (getDim d (shape a)) a b -- [0,16,17,18,19], -- [0,20,21,22,23]]] prepend :: - Int -> + Dim -> Array a -> Array a -> Array a @@ -808,7 +818,7 @@ prepend d a b = insert d 0 b a -- >>> concatenate 0 (toScalar 0) (asArray [1..3]) -- UnsafeArray [4] [0,1,2,3] concatenate :: - Int -> + Dim -> Array a -> Array a -> Array a @@ -849,7 +859,7 @@ couple a a' = concatenate 0 (elongate 0 a) (elongate 0 a') -- [17,18], -- [21,22]]] slice :: - Int -> + Dim -> Int -> Int -> Array a -> @@ -867,7 +877,7 @@ slice d o l a = backpermute (setDim d l) (modifyDim d (+ o)) a -- [5,6,7], -- [9,10,11]]] takes :: - [Int] -> + Dims -> [Int] -> Array a -> Array a @@ -882,7 +892,7 @@ takes ds xs a = backpermute dsNew (List.zipWith (+) start) a -- >>> pretty $ drops [0,1,2] [1,2,-3] a -- [[[20]]] drops :: - [Int] -> + Dims -> [Int] -> Array a -> Array a @@ -898,7 +908,7 @@ drops ds xs a = backpermute dsNew (List.zipWith (\d' s' -> bool (d' + s') s' (d' -- >>> pretty s -- [16,17,18,19] indexes :: - [Int] -> + Dims -> [Int] -> Array a -> Array a @@ -912,7 +922,7 @@ indexes ds xs a = backpermute (deleteDims ds) (insertDims ds xs) a -- [17,18], -- [21,22]]] slices :: - [Int] -> + Dims -> [Int] -> [Int] -> Array a -> @@ -923,14 +933,14 @@ slices ds os ls a = dimsWise (\d (o,l) -> slice d o l) ds (List.zip os ls) a -- -- >>> pretty $ heads [0,2] a -- [0,4,8] -heads :: [Int] -> Array a -> Array a +heads :: Dims -> Array a -> Array a heads ds a = indexes ds (List.replicate (S.rank ds) 0) a -- | Select the last element along the supplied dimensions -- -- >>> pretty $ lasts [0,2] a -- [15,19,23] -lasts :: [Int] -> Array a -> Array a +lasts :: Dims -> Array a -> Array a lasts ds a = indexes ds lastds a where lastds = (\i -> shape a !! i - 1) <$> ds @@ -941,7 +951,7 @@ lasts ds a = indexes ds lastds a -- [[[13,14,15], -- [17,18,19], -- [21,22,23]]] -tails :: [Int] -> Array a -> Array a +tails :: Dims -> Array a -> Array a tails ds a = slices ds os ls a where os = List.replicate (S.rank ls) 1 @@ -953,7 +963,7 @@ tails ds a = slices ds os ls a -- [[[0,1,2], -- [4,5,6], -- [8,9,10]]] -inits :: [Int] -> Array a -> Array a +inits :: Dims -> Array a -> Array a inits ds a = slices ds os ls a where os = List.replicate (S.rank ls) 0 @@ -967,7 +977,7 @@ inits ds a = slices ds os ls a -- >>> pretty $ shape <$> extracts [0] a -- [[3,4],[3,4]] extracts :: - [Int] -> + Dims -> Array a -> Array (Array a) extracts ds a = tabulate (getDims ds (shape a)) go @@ -980,7 +990,7 @@ extracts ds a = tabulate (getDims ds (shape a)) go -- >>> pretty $ shape <$> extracts [0] a -- [[3,4],[3,4]] extractsExcept :: - [Int] -> + Dims -> Array a -> Array (Array a) extractsExcept ds a = extracts (exclude (rank a) ds) a @@ -994,7 +1004,7 @@ extractsExcept ds a = extracts (exclude (rank a) ds) a -- [48,51,54,57]] -- reduces :: - [Int] -> + Dims -> (Array a -> b) -> Array a -> Array b @@ -1007,7 +1017,7 @@ reduces ds f a = fmap f (extracts ds a) -- >>> a == j -- True joins :: - [Int] -> + Dims -> Array (Array a) -> Array a joins ds a = tabulate (insertDims ds so si) go @@ -1023,7 +1033,7 @@ joins ds a = tabulate (insertDims ds so si) go -- >>> a == j -- True joinsSafe :: - [Int] -> + Dims -> Array (Array a) -> Either NumHaskException (Array a) joinsSafe ds a = @@ -1063,7 +1073,7 @@ allEqual a = case arrayAs a of -- | Traverse along specified dimensions. traverses :: (Applicative f) => - [Int] -> + Dims -> (a -> f b) -> Array a -> f (Array b) @@ -1074,7 +1084,7 @@ traverses ds f a = join <$> traverse (traverse f) (extracts ds a) -- >>> shape $ maps [1] transpose a -- [4,3,2] maps :: - [Int] -> + Dims -> (Array a -> Array b) -> Array a -> Array b @@ -1088,7 +1098,7 @@ maps ds f a = joins ds (fmap f (extracts ds a)) -- [12,13,14,15], -- [20,21,22,23]] filters :: - [Int] -> + Dims -> (Array a -> Bool) -> Array a -> Array a @@ -1104,7 +1114,7 @@ filters ds p a = join (asArray $ V.filter p $ asVector (extracts ds a)) -- [(16,4),(17,5),(18,6),(19,7)], -- [(20,8),(21,9),(22,10),(23,11)]]] zips :: - [Int] -> + Dims -> (Array a -> Array b -> Array c) -> Array a -> Array b -> @@ -1116,7 +1126,7 @@ zips ds f a b = joins ds (zipWith f (extracts ds a) (extracts ds b)) -- >>> zipsSafe [0] (zipWith (,)) (asArray [1::Int]) (asArray [1,2::Int]) -- Left (NumHaskException {errorMessage = "MisMatched zip"}) zipsSafe :: - [Int] -> + Dims -> (Array a -> Array b -> Array c) -> Array a -> Array b -> @@ -1138,7 +1148,7 @@ zipsSafe ds f a b = -- [120,21,22,23]]] modifies :: (Array a -> Array a) -> - [Int] -> + Dims -> [Int] -> Array a -> Array a @@ -1151,7 +1161,7 @@ modifies f ds ps a = joins ds $ modify ps f (extracts ds a) -- [4,4,4,4]], -- [[4,4,4,4], -- [4,4,4,4]]] -diffs :: [Int] -> [Int] -> (Array a -> Array a -> Array b) -> Array a -> Array b +diffs :: Dims -> [Int] -> (Array a -> Array a -> Array b) -> Array a -> Array b diffs ds xs f a = zips ds f (drops ds xs a) (drops ds (fmap P.negate xs) a) -- | Product two arrays using the supplied binary function. @@ -1219,7 +1229,7 @@ expandr f a b = tabulate (shape a <> shape b) (\i -> f (index a (List.drop r i)) -- [32,77]] contract :: (Array a -> b) -> - [Int] -> + Dims -> Array a -> Array b contract f xs a = f . diag <$> extractsExcept xs a @@ -1516,7 +1526,7 @@ rerank r a = unsafeModifyShape (S.rerank r) a -- [[3,7,11], -- [15,19,23]]] reorder :: - [Int] -> + Dims -> Array a -> Array a reorder ds a = backpermute (`S.reorder` ds) (\s -> insertDims ds s []) a @@ -1538,7 +1548,7 @@ squeeze a = unsafeModifyShape S.squeeze a -- >>> elongate 0 (toScalar 1) -- UnsafeArray [1] [1] elongate :: - Int -> + Dim -> Array a -> Array a elongate d a = unsafeModifyShape (insertDim d 1) a @@ -1563,7 +1573,7 @@ transpose a = backpermute List.reverse List.reverse a -- [[0,1,2], -- [0,1,2]] inflate :: - Int -> + Dim -> Int -> Array a -> Array a @@ -1578,7 +1588,7 @@ inflate d n a = backpermute (insertDim d n) (deleteDim d) a -- [[12,0,13,0,14,0,15], -- [16,0,17,0,18,0,19], -- [20,0,21,0,22,0,23]]] -intercalate:: [Int] -> Array a -> Array a -> Array a +intercalate:: Dims -> Array a -> Array a -> Array a intercalate ds i a = joins ds $ asArray (List.intersperse i (arrayAs (extracts ds a))) -- | Intersperse an element along dimensions. @@ -1590,7 +1600,7 @@ intercalate ds i a = joins ds $ asArray (List.intersperse i (arrayAs (extracts d -- [[12,0,13,0,14,0,15], -- [16,0,17,0,18,0,19], -- [20,0,21,0,22,0,23]]] -intersperse :: [Int] -> a -> Array a -> Array a +intersperse :: Dims -> a -> Array a -> Array a intersperse ds i a = intercalate ds (konst (deleteDims ds (shape a)) i) a -- | Concatenate and replace dimensions, creating a new dimension at the supplied postion. @@ -1601,7 +1611,7 @@ intersperse ds i a = intercalate ds (konst (deleteDims ds (shape a)) i) a -- [2,6,10,14,18,22], -- [3,7,11,15,19,23]] concats :: - [Int] -> + Dims -> Int -> Array a -> Array a @@ -1620,7 +1630,7 @@ concats ds n a = backpermute concatDims unconcatDims a -- [12,13,14,15], -- [16,17,18,19]]] rotate :: - Int -> + Dim -> Int -> Array a -> Array a @@ -1636,7 +1646,7 @@ rotate d r a = backpermute id (modifyDim d (\i -> (r + i) `mod` (shape a !! d))) -- [4,5,6,7], -- [0,1,2,3]]] reverses :: - [Int] -> + Dims -> Array a -> Array a reverses ds a = backpermute id (reverseIndex ds (shape a)) a @@ -1651,7 +1661,7 @@ reverses ds a = backpermute id (reverseIndex ds (shape a)) a -- UnsafeArray [2,2] [2,3,1,4] -- >>> sorts [0,1] (array [2,2] [2,3,1,4]) -- UnsafeArray [2,2] [1,2,3,4] -sorts :: (Ord a) => [Int] -> Array a -> Array a +sorts :: (Ord a) => Dims -> Array a -> Array a sorts ds a = joins ds $ unsafeModifyVector sortV (extracts ds a) -- | The indices into the array if it were sorted by a comparison function along the dimensions supplied. @@ -1659,14 +1669,14 @@ sorts ds a = joins ds $ unsafeModifyVector sortV (extracts ds a) -- >>> import Data.Ord (Down (..)) -- >>> sortsBy [0] (fmap Down) (array [2,2] [2,3,1,4]) -- UnsafeArray [2,2] [2,3,1,4] -sortsBy :: (Ord b) => [Int] -> (Array a -> Array b) -> Array a -> Array a +sortsBy :: (Ord b) => Dims -> (Array a -> Array b) -> Array a -> Array a sortsBy ds c a = joins ds $ unsafeModifyVector (sortByV c) (extracts ds a) -- | The indices into the array if it were sorted along the dimensions supplied. -- -- >>> orders [0] (array [2,2] [2,3,1,4]) -- UnsafeArray [2] [1,0] -orders :: (Ord a) => [Int] -> Array a -> Array Int +orders :: (Ord a) => Dims -> Array a -> Array Int orders ds a = unsafeModifyVector orderV (extracts ds a) -- | The indices into the array if it were sorted by a comparison function along the dimensions supplied. @@ -1674,7 +1684,7 @@ orders ds a = unsafeModifyVector orderV (extracts ds a) -- >>> import Data.Ord (Down (..)) -- >>> ordersBy [0] (fmap Down) (array [2,2] [2,3,1,4]) -- UnsafeArray [2] [0,1] -ordersBy :: (Ord b) => [Int] -> (Array a -> Array b) -> Array a -> Array Int +ordersBy :: (Ord b) => Dims -> (Array a -> Array b) -> Array a -> Array Int ordersBy ds c a = unsafeModifyVector (orderByV c) (extracts ds a) -- * transmission @@ -1687,7 +1697,7 @@ ordersBy ds c a = unsafeModifyVector (orderByV c) (extracts ds a) -- [[0,1,2], -- [3,4,5], -- [0,1,2]] -telecasts :: [Int] -> [Int] -> (Array a -> Array b -> Array c) -> Array a -> Array b -> Array c +telecasts :: Dims -> Dims -> (Array a -> Array b -> Array c) -> Array a -> Array b -> Array c telecasts dsa dsb f a b = zipWith f (extracts dsa a) (extracts dsb b) & joins dsa -- | Apply a binary array function to two arrays with matching shapes across the supplied dimensions. Checks shape. @@ -1696,7 +1706,7 @@ telecasts dsa dsb f a b = zipWith f (extracts dsa a) (extracts dsb b) & joins ds -- >>> b = D.array [1] [1] -- >>> telecastsSafe [0] [0] (zipWith (+)) a b -- Left (NumHaskException {errorMessage = "MisMatched telecasting"}) -telecastsSafe :: [Int] -> [Int] -> (Array a -> Array b -> Array c) -> Array a -> Array b -> Either NumHaskException (Array c) +telecastsSafe :: Dims -> Dims -> (Array a -> Array b -> Array c) -> Array a -> Array b -> Either NumHaskException (Array c) telecastsSafe dsa dsb f a b = bool (Right $ telecasts dsa dsb f a b) diff --git a/src/NumHask/Array/Fixed.hs b/src/NumHask/Array/Fixed.hs index bace549..afff2c7 100644 --- a/src/NumHask/Array/Fixed.hs +++ b/src/NumHask/Array/Fixed.hs @@ -25,7 +25,11 @@ module NumHask.Array.Fixed unsafeModifyShape, unsafeModifyVector, - -- * Dependent type + -- * Dimensions + Dim, + Dims, + + -- * Dependent type SomeArray (..), someArray, @@ -464,6 +468,13 @@ unsafeModifyShape a = unsafeArray (asVector a) unsafeModifyVector :: (KnownNats s) => (FromVector u a) => (FromVector v b) => (u -> v) -> Array s a -> Array s b unsafeModifyVector f a = unsafeArray (asVector (f (vectorAs (asVector a)))) +-- | Representation of an index into a shape (a type-level [Nat]). The index is a dimension of the shape. +type Dim = SNat + +-- | Representation of indexes into a shape (a type-level [Nat]). The indexes are dimensions of the shape. +type Dims = SNats + + -- | A fixed Array with a hidden shape. -- -- The library design encourages the use of dynamic arrays in preference to dependent-type styles such as this. In particular, no attempt has been made to prove to the compiler that a particular Shape (resulting from any of the supplied functions) exists. Life is short. diff --git a/src/NumHask/Array/Shape.hs b/src/NumHask/Array/Shape.hs index 8bc75b2..ea993a5 100644 --- a/src/NumHask/Array/Shape.hs +++ b/src/NumHask/Array/Shape.hs @@ -55,10 +55,6 @@ module NumHask.Array.Shape Fins (..), toFins, - -- * Dimensions - Dim, - Dims, - -- operators rank, Rank, @@ -345,12 +341,6 @@ instance Show (Fins n) where toFins :: forall s. (KnownNats s) => [Int] -> Maybe (Fins s) toFins xs = bool Nothing (Just (UnsafeFins xs)) (isFins xs (valuesOf @s)) --- | An SNat (a type-level Nat) that represents an index into an SNats (a type-level [Nat]). The index is a dimension of the shape. -type Dim = SNat - --- | An SNats (a type-level [Nat]) that represents indexes into an SNats (a type-level [Nat]). The indexes are dimensions of the shape. -type Dims = SNats - -- | Number of dimensions -- -- >>> rank @Int [2,3,4]