Skip to content


Actually load some sound and label files
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jul 5, 2022
1 parent ebeb9c9 commit 7ed4e58
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 10 deletions.
2 changes: 2 additions & 0 deletions horde-ad.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ library testLibrary
-- Other library packages from which modules are imported.
, cereal
, bytestring
, deepseq
, HUnit-approx
, hmatrix
Expand Down
95 changes: 85 additions & 10 deletions test/common/TestSpeechRNN.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@ module TestSpeechRNN (testTrees, shortTestForCITrees) where

import Prelude

import Control.Exception (assert)
import Control.Monad (foldM)
import qualified Data.Array.DynamicS as OT
import Data.Array.Internal (valueOf)
import qualified Data.Array.ShapedS as OS
import qualified Data.ByteString.Lazy as LBS
import Data.List (foldl', unfoldr)
import Data.Proxy (Proxy (Proxy))
import Data.Serialize
import qualified Data.Vector.Generic as V
import GHC.TypeLits (KnownNat)
import Numeric.LinearAlgebra (Matrix, Vector)
import Numeric.LinearAlgebra (Matrix, Numeric, Vector)
import qualified Numeric.LinearAlgebra as HM
import System.IO (hPutStrLn, stderr)
import System.IO
(IOMode (ReadMode), hPutStrLn, stderr, withBinaryFile)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
Expand All @@ -37,34 +41,97 @@ shortTestForCITrees = [ speechRNNTestsShort

-- | One mini-batch of speech data as a pair of statically shaped arrays:
-- the first component has shape @[batch_size, window_size]@ (the sound
-- windows), the second has shape @[batch_size, n_labels]@ (the labels).
-- Both components share the scalar type @r@.
type SpeechDataBatch batch_size window_size n_labels r =
  ( OS.Array '[batch_size, window_size] r
  , OS.Array '[batch_size, n_labels] r )

-- | Split a list into consecutive chunks of @n@ elements each.
--
-- The final chunk is shorter than @n@ when the list length is not
-- a multiple of @n@; an empty input yields no chunks. The result is
-- produced lazily, so it also works on infinite lists.
--
-- Note: for @n <= 0@ and a non-empty input, 'splitAt' never consumes
-- anything, so the result is an infinite list of empty chunks.
-- Callers are expected to pass a positive chunk size.
chunksOf :: Int -> [e] -> [[e]]
chunksOf n = unfoldr nextChunk
 where
  -- Stop on exhausted input; otherwise emit the next chunk and
  -- continue with the remainder.
  nextChunk [] = Nothing
  nextChunk l = Just (splitAt n l)

-- | Decode the raw contents of a sounds file and a labels file into
-- a list of fixed-size batches.
--
-- Each byte string is turned into a flat list of @r@ by prepending
-- a serialized element count and running 'decodeLazy' on the result
-- (presumably the list decoder expects such a leading count -- confirm
-- against the serialization library). The flat lists are then cut
-- into chunks of @batch_size * window_size@ (sounds) and
-- @batch_size * n_labels@ (labels) elements and paired up.
--
-- The @len@ argument determines the element counts: @len * window_size@
-- for sounds and @len * n_labels@ for labels (presumably @len@ is the
-- number of samples stored in the files -- TODO confirm with callers).
--
-- The last chunk is thrown away if smaller than batch size.
-- It crashes if the size of either file doesn't match the other
-- (checked via assertions, so only when assertions are enabled)
-- and calls 'error' if decoding fails.
-- TODO: perhaps then warn instead of failing an assertion.
-- TODO: perhaps warn about the last chunk, too.
-- TODO: this could be so much more elegant, e.g., if OS.fromList
-- returned the remaining list and so no manual size calculations would
-- be required.
-- TODO: performance (the reference this note pointed to is not
-- present here).
decodeSpeechData
  :: forall batch_size window_size n_labels r.
     ( Serialize r, Numeric r
     , KnownNat batch_size, KnownNat window_size, KnownNat n_labels )
  => Int -> LBS.ByteString -> LBS.ByteString
  -> [SpeechDataBatch batch_size window_size n_labels r]
decodeSpeechData len soundsBs labelsBs =
  let soundsChunkSize = valueOf @batch_size * valueOf @window_size
      labelsChunkSize = valueOf @batch_size * valueOf @n_labels
      -- The byte lengths of the two inputs must be in the same ratio
      -- as the per-batch element counts.
      !_A1 = assert (fromIntegral (LBS.length soundsBs) * labelsChunkSize
                     == fromIntegral (LBS.length labelsBs) * soundsChunkSize)
                    ()
      -- Decode a byte string to a flat list and cut it into full chunks.
      cutBs :: Int -> LBS.ByteString -> [[r]]
      cutBs chunkSize bs =
        let list :: [r] =
              case decodeLazy
                   $ LBS.append
                       (encodeLazy
                        $ len * chunkSize `div` valueOf @batch_size)
                       bs of
                Left err -> error err
                Right l -> l
        -- Only the trailing chunk can be short; drop it.
        in filter ((>= chunkSize) . length)
           $ chunksOf chunkSize list
      soundsChunks :: [[r]] = cutBs soundsChunkSize soundsBs
      labelsChunks :: [[r]] = cutBs labelsChunkSize labelsBs
      !_A2 = assert (not (null soundsChunks)) ()
      !_A3 = assert (length soundsChunks == length labelsChunks) ()
      makeSpeechDataBatch
        :: [r] -> [r] -> SpeechDataBatch batch_size window_size n_labels r
      makeSpeechDataBatch soundsCh labelsCh =
        (OS.fromList soundsCh, OS.fromList labelsCh)
  in zipWith makeSpeechDataBatch soundsChunks labelsChunks

-- | Read a sounds file and a labels file and decode them into batches
-- with 'decodeSpeechData'; @len@ is passed through unchanged.
--
-- The files are read with lazy 'LBS.hGetContents', so their contents
-- must be fully consumed before 'withBinaryFile' closes the handles.
-- Both byte strings are therefore forced explicitly (via their
-- lengths) instead of relying on assertions, which may be compiled
-- out, or on how much of the input the decoder happens to demand.
--
-- Fails an assertion (when assertions are enabled) if the sounds file
-- is empty; any failure of 'decodeSpeechData' propagates as well.
loadSpeechData
  :: forall batch_size window_size n_labels r.
     ( Serialize r, Numeric r
     , KnownNat batch_size, KnownNat window_size, KnownNat n_labels )
  => Int -> FilePath -> FilePath
  -> IO [SpeechDataBatch batch_size window_size n_labels r]
loadSpeechData len soundsPath labelsPath =
  withBinaryFile soundsPath ReadMode $ \soundsHandle ->
    withBinaryFile labelsPath ReadMode $ \labelsHandle -> do
      soundsContents <- LBS.hGetContents soundsHandle
      labelsContents <- LBS.hGetContents labelsHandle
      -- Computing the length walks every chunk, which reads the whole
      -- file while its handle is still open.
      let !soundsLength = LBS.length soundsContents
          !_labelsLength = LBS.length labelsContents
          !_A1 = assert (soundsLength > 0) ()
      return $! decodeSpeechData len soundsContents labelsContents

-- | Build a test case for a speech RNN.
--
-- Arguments, in order: the test name, then two 'Int's bound to
-- @epochs@ and @maxBatches@, a rank-2 polymorphic training function
-- (@trainWithLoss@), a rank-2 polymorphic test function (@ftest@),
-- a function from an output-width proxy to a tuple of sizes and
-- shapes (@flen@), and the expected result.
--
-- NOTE: the body is currently a placeholder. Only @prefix@ is used
-- (as the test name); every other argument is ignored and the test
-- asserts a trivially true equality.
speechTestCaseRNN
  :: forall out_width batch_size window_size n_labels d r m.
     ( KnownNat out_width, KnownNat batch_size, KnownNat window_size
     , KnownNat n_labels
     , r ~ Double, d ~ 'DModeGradient, m ~ DualMonadGradient Double )
  => String
  -> Int
  -> Int
  -> (forall out_width' batch_size' window_size' n_labels'.
      ( DualMonad d r m, KnownNat out_width', KnownNat batch_size'
      , KnownNat n_labels' )
      => Proxy out_width'
      -> SpeechDataBatch batch_size' window_size' n_labels' r
      -> DualNumberVariables d r
      -> m (DualNumber d r))
  -> (forall out_width' batch_size' window_size' n_labels'.
      ( IsScalar d r, KnownNat out_width', KnownNat batch_size'
      , KnownNat n_labels' )
      => Proxy out_width'
      -> SpeechDataBatch batch_size' window_size' n_labels' r
      -> Domains r
      -> r)
  -> (forall out_width'. KnownNat out_width'
      => Proxy out_width' -> (Int, [Int], [(Int, Int)], [OT.ShapeL]))
  -> Double
  -> TestTree
speechTestCaseRNN prefix epochs maxBatches trainWithLoss ftest flen expected =
  testCase prefix $
    1.0 @?= 1.0

Expand All @@ -74,4 +141,12 @@ mnistRNNTestsLong = testGroup "Speech RNN long tests"

speechRNNTestsShort :: TestTree
speechRNNTestsShort = testGroup "Speech RNN short tests"
[ testCase "Load and sanity check speech" $ do
speechDataBatchList <-
@64 @257 @1 @Float
length speechDataBatchList @?= 859 `div` 64

0 comments on commit 7ed4e58

Please sign in to comment.