Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions clash-protocols/src/Protocols/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,26 @@ module Protocols.Vec (
unzip3,
concat,
unconcat,
repeat,
replicate,
) where

-- base
import Data.Tuple
import Prelude ()

-- clash-prelude
import Clash.Prelude hiding (concat, split, unconcat, unzip, unzip3, zip, zip3)
import Clash.Prelude hiding (
concat,
repeat,
replicate,
split,
unconcat,
unzip,
unzip3,
zip,
zip3,
)
import Clash.Prelude qualified as C

-- clash-protocols-base
Expand Down Expand Up @@ -114,7 +126,7 @@ unconcat SNat = Circuit (swap . bimap (C.unconcat SNat) C.concat)
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f (a, b, c) = f a b c

-- Append three vectors of `a` into one vector of `a`.
-- | Append three vectors of `a` into one vector of `a`.
append3Vec ::
(KnownNat n0, KnownNat n1, KnownNat n2) =>
C.Vec n0 a ->
Expand All @@ -123,11 +135,23 @@ append3Vec ::
C.Vec (n0 + n1 + n2) a
append3Vec v0 v1 v2 = v0 ++ v1 ++ v2

-- Split a C.Vector of 3-tuples into three vectors of the same length.
-- | Split a C.Vector of 3-tuples into three vectors of the same length.
split3Vec ::
(KnownNat n0, KnownNat n1, KnownNat n2) =>
C.Vec (n0 + n1 + n2) a ->
(C.Vec n0 a, C.Vec n1 a, C.Vec n2 a)
split3Vec v = (v0, v1, v2)
where
(v0, splitAtI -> (v1, v2)) = splitAtI v

{- | repeat a circuit for a number of times, the number of times the circuit is repeated
is determined by the type-level natural number `n`.
-}
repeat :: (C.KnownNat n) => Circuit a b -> Circuit (Vec n a) (Vec n b)
repeat (Circuit function) = Circuit (C.unzip . uncurry (zipWith (curry function)))

{- | replicate a circuit for a given number of times, the number of times the circuit is replicated
is given by the supplied `SNat n`.
-}
replicate :: SNat n -> Circuit a b -> Circuit (Vec n a) (Vec n b)
replicate SNat = repeat
29 changes: 29 additions & 0 deletions clash-protocols/tests/Tests/Protocols/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Test.Tasty.TH (testGroupGenerator)

-- clash-protocols (me!)
import Protocols
import Protocols.Df qualified as Df
import Protocols.Vec qualified as Vec

import Clash.Hedgehog.Sized.Vector (genVec)
Expand Down Expand Up @@ -176,6 +177,34 @@ prop_unconcat =
dut = Vec.unconcat C.d2
model = C.unconcat C.d2

prop_repeat :: Property
prop_repeat =
idWithModel
@(C.Vec 3 (Df System Int))
@(C.Vec 3 (Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVecData genSmallInt
dut = Vec.repeat (Df.map succ)
model = fmap (fmap succ)

prop_replicate :: Property
prop_replicate =
idWithModel
@(C.Vec 3 (Df System Int))
@(C.Vec 3 (Df System Int))
defExpectOptions
gen
model
dut
where
gen = genVecData genSmallInt
dut = Vec.replicate C.SNat (Df.map succ)
model = fmap (fmap succ)

tests :: TestTree
tests =
-- TODO: Move timeout option to hedgehog for better error messages.
Expand Down
Loading