Skip to content

Commit

Permalink
miscellanea
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyday567 committed Jul 31, 2024
1 parent 1259a5d commit c7cc6a4
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 82 deletions.
10 changes: 5 additions & 5 deletions readme.org
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Ok, five modules loaded.
#+end_example

#+begin_src haskell-ng :results output
a = iota [2,3,4] :: D.Array Int
a = range [2,3,4] :: D.Array Int
pretty a
:t S.modifyDim
-- :t \d o l a -> backpermute (S.replaceDim d l) (S.modifyDim d (+o)) a
Expand All @@ -134,7 +134,7 @@ pretty a

#+begin_src haskell-ng :results output
import qualified Data.List as List
x = iota [2,3]
x = range [2,3]
x
D.backpermute (List.drop 1 :: [Int] -> [Int]) x
#+end_src
Expand All @@ -143,7 +143,7 @@ D.backpermute (List.drop 1 :: [Int] -> [Int]) x
: UnsafeArray [2,3] [0,1,2,3,4,5]
: UnsafeArray [3] [0,0,0]

** iota
** range

#+begin_src haskell-ng :results output
D.range (D.toScalar 3)
Expand Down Expand Up @@ -533,8 +533,8 @@ D.drops [1,0] m
* scalar applications

#+begin_src haskell-ng :results output
S.shapenL [] 20
S.flattenL [] []
S.shapen [] 20
S.flatten [] []
S.deleteDim [] 2
S.replaceDim 0 1 []
S.modifyDim 0 (+1) []
Expand Down
75 changes: 48 additions & 27 deletions src/NumHask/Array/Dynamic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ module NumHask.Array.Dynamic
flatten,
shapen,
backpermute,

-- Scalar conversions
fromScalar,
toScalar,
Expand All @@ -44,7 +45,7 @@ module NumHask.Array.Dynamic
asScalar,

-- * Creation
iota,
range,
indices,
ident,
konst,
Expand Down Expand Up @@ -79,6 +80,7 @@ module NumHask.Array.Dynamic
delete,
append,
prepend,
modify,
couple,
expand,
expandr,
Expand Down Expand Up @@ -187,10 +189,10 @@ data Array a = UnsafeArray (V.Vector Int) (V.Vector a)
deriving stock (Eq, Ord, Show)

instance Functor Array where
fmap f (UnsafeArray s a) = UnsafeArray s (V.map f a)
fmap f = unsafeModifyVector (V.map f)

instance Foldable Array where
foldr x a (UnsafeArray _ v) = V.foldr x a v
foldr f x0 a = V.foldr f x0 (asVector a)

instance Traversable Array where
traverse f (UnsafeArray s v) =
Expand Down Expand Up @@ -361,7 +363,7 @@ flattenV :: V.Vector Int -> V.Vector Int -> Int
flattenV ns xs = V.sum $ V.zipWith (*) xs (V.drop 1 $ V.scanr (*) one ns)

shapenV :: V.Vector Int -> Int -> V.Vector Int
shapenV ns x = V.fromList $ S.shapenL (V.toList ns) x
shapenV ns x = V.fromList $ S.shapen (V.toList ns) x

-- | convert from a shape index to a flat index
--
Expand All @@ -385,7 +387,7 @@ shapen ns x = vectorAs $ shapenV (asVector ns) x
-- >>> index a [1,2,3]
-- 24
index :: (FromVector u Int) => Array a -> u -> a
index (UnsafeArray s v) i = V.unsafeIndex v (flatten s (asVector i))
index (UnsafeArray s v) i = V.unsafeIndex v (flattenV s (asVector i))

-- | tabulate an array supplying a shape and a generating function
--
Expand Down Expand Up @@ -444,40 +446,40 @@ isScalar a = rank a == zero
-- >>> asSingleton (toScalar 4)
-- UnsafeArray [1] [4]
asSingleton :: Array a -> Array a
asSingleton (UnsafeArray s v) = UnsafeArray (bool s (V.singleton 1) (V.null s)) v
asSingleton = unsafeModifyShape (\s -> bool s (V.singleton 1) (V.null s))

-- | convert arrays with shape [1] to scalars
--
-- >>> asScalar (singleton 3)
-- UnsafeArray [] [3]
asScalar :: Array a -> Array a
asScalar (UnsafeArray s v) = UnsafeArray (bool s V.empty (s == V.singleton 1)) v
asScalar = unsafeModifyShape (\s -> bool s V.empty (s == V.singleton 1))

-- | A flat enumeration
--
-- >>> pretty $ iota [2,3]
-- >>> pretty $ range [2,3]
-- [[0,1,2],
-- [3,4,5]]
iota :: [Int] -> Array Int
iota xs = tabulate xs (flatten xs)
range :: [Int] -> Array Int
range xs = tabulate xs (flatten xs)

-- * operations

-- | apply a function that takes a [(dimension,paramter)] to a paramter list and the first dimensions.
-- | apply a function that takes a [(dimension,parameter)] to a parameter list and the first dimensions.
--
-- >>> rowWise selects [1,0] a
-- UnsafeArray [4] [13,14,15,16]
rowWise :: ([(Int, x)] -> Array a -> Array a) -> [x] -> Array a -> Array a
rowWise f xs a = f (List.zip [0 ..] xs) a

-- | apply a function that takes a [(dimension,paramter)] to a paramter list and the last dimensions (in reverse).
-- | apply a function that takes a [(dimension,parameter)] to a parameter list and the last dimensions (in reverse).
--
-- >>> colWise selects [1,0] a
-- UnsafeArray [2] [2,14]
colWise :: ([(Int, x)] -> Array a -> Array a) -> [x] -> Array a -> Array a
colWise f xs a = f (List.zip (List.reverse [0 .. (rank a - 1)]) xs) a

-- | apply a function that takes a [(dimension,paramter)] to a paramter list and the first dimensions. In a perfect world, if the function is a backpermute, it should fuse.
-- | apply a function that takes a dimension and parameter and modifies an array and folds a [(dimension,parameter)] list. In a perfect world, if the function is a backpermute, it should fuse.
--
-- >>> dimsWise take [(0,1),(2,2)] a
-- UnsafeArray [1,3,2] [1,2,5,6,9,10]
Expand Down Expand Up @@ -506,7 +508,7 @@ takes ::
[(Int, Int)] ->
Array a ->
Array a
takes ts a = backpermute dsNew (\s -> List.zipWith3 (\d' s' a' -> bool s' (s' + a' + d') (d' < 0)) xsNew s (shape a)) a
takes ts a = backpermute dsNew (List.zipWith3 (\d' a' s' -> bool s' (s' + a' + d') (d' < 0)) xsNew (shape a)) a
where
dsNew = S.replaceDims ds xsAbs
xsNew = S.replaceDims ds xs (replicate (rank a) 0)
Expand Down Expand Up @@ -587,7 +589,7 @@ pad d s' a = tabulate s' (\s -> bool d (index a' s) (s `S.inside` shape a'))
--
-- >>> lpad 0 [5] (array [4] [0..3] :: Array Int)
-- UnsafeArray [5] [0,0,1,2,3]
-- >>> pretty $ lpad 0 [3,3] (iota [2,2] :: Array Int)
-- >>> pretty $ lpad 0 [3,3] (range [2,2] :: Array Int)
-- [[0,0,0],
-- [0,0,1],
-- [0,2,3]]
Expand Down Expand Up @@ -649,11 +651,11 @@ cycle ::
[Int] ->
Array a ->
Array a
cycle s a = backpermute (const s) (shapen (shape a) . (\x -> mod x (size a)) . flatten s) a
cycle s a = backpermute (const s) (shapen (shape a) . (`mod` (size a)) . flatten s) a

Check warning on line 654 in src/NumHask/Array/Dynamic.hs

View workflow job for this annotation

GitHub Actions / hlint

Suggestion in cycle in module NumHask.Array.Dynamic: Redundant bracket ▫︎ Found: "(`mod` (size a))" ▫︎ Perhaps: "(`mod` size a)"

-- | windows xs are xs-sized windows of an array
--
-- >>> D.shape @[Int] $ D.windows [2,2] (D.iota [4,3,2])
-- >>> D.shape @[Int] $ D.windows [2,2] (D.range [4,3,2])
-- [3,2,2,2,2]
windows :: [Int] -> Array a -> Array a
windows xs a = backpermute df wf a
Expand Down Expand Up @@ -717,7 +719,7 @@ undiag r a = tabulate (replicate r (head (shape a))) (\xs -> bool zero (index a
-- [[1,1],
-- [1,1],
-- [1,1]]
konst :: (FromVector u Int) => u -> a -> Array a
konst :: [Int] -> a -> Array a
konst ds a = tabulate ds (const a)

-- | Create an array of rank 1, shape [1].
Expand Down Expand Up @@ -1015,6 +1017,8 @@ append ::
Array a
append d a b = insert d (S.indexOf d (shape a)) a b



-- | Insert along a dimension at the beginning.
--
-- >>> pretty $ prepend 2 (array [2,3] [100..105]) a
Expand All @@ -1031,6 +1035,23 @@ prepend ::
Array a
prepend d a b = insert d 0 b a

-- | Modify using the supplied function along dimension(s) at a position(s).
--
-- >>> pretty $ modify (fmap (100+)) [2] [0] a
-- [[[101,2,3,4],
-- [105,6,7,8],
-- [109,10,11,12]],
-- [[113,14,15,16],
-- [117,18,19,20],
-- [121,22,23,24]]]
modify ::
(Array a -> Array a) ->
[Int] ->
[Int] ->
Array a ->
Array a
modify f ds xs a = joins ds $ modifyE xs f (extracts ds a)

-- | Combine two arrays as rows of a new array.
--
-- >>> pretty $ couple (asArray [1,2,3]) (asArray [4,5,6::Int])
Expand Down Expand Up @@ -1206,7 +1227,7 @@ slices ps a = dimsWise slice ps a

-- | find the starting positions of occurences of one array in another.
--
-- >>> a = D.cycle [4,4] (iota [3]) :: Array Int
-- >>> a = D.cycle [4,4] (range [3]) :: Array Int
-- >>> i = array [2,2] [1,2,2,0] :: Array Int
-- >>> pretty $ D.find i a
-- [[False,True,False],
Expand All @@ -1217,7 +1238,7 @@ find i a = xs
where
i' = rerank (rank a) i
ws = windows (shape i') a
xs = fmap (== i') (extracts (arrayAs (iota [rank a]) <> [rank a * 2 .. (rank ws - 1)]) ws)
xs = fmap (== i') (extracts (arrayAs (range [rank a]) <> [rank a * 2 .. (rank ws - 1)]) ws)

-- | find the ending positions of one array in another except where the array overlaps with another copy.
--
Expand Down Expand Up @@ -1277,7 +1298,7 @@ reverses ::
[Int] ->
Array a ->
Array a
reverses ds a = tabulate (shape a) (index a . S.reverseIndex ds (shape a))
reverses ds a = backpermute id (S.reverseIndex ds (shape a)) a

-- | Remove single dimensions.
--
Expand Down Expand Up @@ -1373,7 +1394,7 @@ telecasts ds f a b = zipWithE f (extracts dsa a) (extracts dsb b) & joins dsa
dsb = fmap snd ds
dsa = fmap fst ds

-- | Apply a binary array function to two arrays with matching shapes across the supplied dimensions. No check on shapes.
-- | Apply a binary array function to two arrays with matching shapes across the supplied dimensions. Checks shape.
--
-- >>> a = D.array [2,3] [0..5]
-- >>> b = D.array [1] [1]
Expand All @@ -1400,7 +1421,7 @@ transmit f a b = extracts ds b & fmap (f a) & joins ds
where
ds = [(rank a) .. (rank b - 1)]

-- | Apply a binary array function to two arrays where the shape of the first array is a prefix of the second array. No checks on shape.
-- | Apply a binary array function to two arrays where the shape of the first array is a prefix of the second array. Checks shape.
--
-- >>> a = D.array [2,3] [0..5]
-- >>> D.transmitSafe (D.zipWithE (+)) (array [3] [1,2,3]) a
Expand Down Expand Up @@ -1450,7 +1471,7 @@ pattern x :| xs <- (uncons -> (x, xs))

-- | zip two arrays at an element level. Could also be called liftS2 or sometink like that.
--
-- > zipWithE == \f a b -> zips (iota (rank a)) (\f a b -> f (D.toScalar a) (D.toScalar b))
-- > zipWithE == \f a b -> zips (range (rank a)) (\f a b -> f (D.toScalar a) (D.toScalar b))
--
-- >>> zipWithE (-) v v
-- UnsafeArray [3] [0,0,0]
Expand All @@ -1459,7 +1480,7 @@ zipWithE f (UnsafeArray s v) (UnsafeArray _ v') = UnsafeArray s (V.zipWith f v v

-- | zip two arrays at an element level, checking for shape consistency.
--
-- > zipWithE == \f a b -> zips (iota (rank a)) (\f a b -> f (D.toScalar a) (D.toScalar b))
-- > zipWithE == \f a b -> zips (range (rank a)) (\f a b -> f (D.toScalar a) (D.toScalar b))
--
-- >>> zipWithESafe (-) v (array [7] [0..6])
-- Left (NumHaskException {errorMessage = "Mismatched zip"})
Expand All @@ -1468,7 +1489,7 @@ zipWithESafe f (UnsafeArray s v) (UnsafeArray s' v') = bool (Left (NumHaskExcept

-- | row-wise difference an array using the supplied function with a lag.
--
-- >>> pretty $ diffE 1 (-) (iota [3,2])
-- >>> pretty $ diffE 1 (-) (range [3,2])
-- [[2,2],
-- [2,2]]
diffE :: Int -> (a -> a -> b) -> Array a -> Array b
Expand All @@ -1480,5 +1501,5 @@ diffE n f a = zipWithE f (rowWise (dimsWise drop) [n] a) (rowWise (dimsWise drop
-- [[100,1,2,3],
-- [4,5,6,7],
-- [8,9,10,11]]
modifyE :: (Eq u, FromVector u Int) => u -> (a -> a) -> Array a -> Array a
modifyE :: [Int] -> (a -> a) -> Array a -> Array a
modifyE ds f a = tabulate (shape a) (\s -> bool id f (s == ds) (index a s))
6 changes: 3 additions & 3 deletions src/NumHask/Array/Fixed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ instance
type Rep (Array s) = [Int]

tabulate f =
UnsafeArray . V.generate (S.size s) $ (f . shapenL s)
UnsafeArray . V.generate (S.size s) $ (f . shapen s)
where
s = shapeOf @s
{-# INLINE tabulate #-}

index (Array v) i = V.unsafeIndex v (flattenL s i)
index (Array v) i = V.unsafeIndex v (flatten s i)
where
s = shapeOf @s
{-# INLINE index #-}
Expand Down Expand Up @@ -337,7 +337,7 @@ reshape ::
) =>
Array s a ->
Array s' a
reshape a = tabulate (index a . shapenL s . flattenL s')
reshape a = tabulate (index a . shapen s . flatten s')
where
s = shapeOf @s
s' = shapeOf @s'
Expand Down
Loading

0 comments on commit c7cc6a4

Please sign in to comment.