Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 31, 2024
1 parent df6e89e commit e11538c
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ testOverleafPP = do
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rsum (rgather [50] v1 (\\[i3] -> [remF i3 28]))"
show deltas
@?= "ShareG 100000002 (SumG (SNat @50) (ShareG 100000001 (GatherR [50] (InputG (FTKR [28] FTKScalar) (InputId 0)) <function>)))"
@?= "ShareG 100000002 (SumG (SNat @50) (STKR (SNat @0) (STKScalar Double)) (ShareG 100000001 (GatherR [50] (InputG (FTKR [28] FTKScalar) (InputId 0)) <function>)))"

foo :: RealFloatF a => (a, a, a) -> a
foo (x, y, z) =
Expand Down
2 changes: 1 addition & 1 deletion test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -780,4 +780,4 @@ testTomsSlicePP = do
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rsum (rgather [128] (rfromIntegral (rfromS siota) * rreplicate 32 (rsum (rgather [96] m1 (\\[i40] -> [remF (quotF i40 3) 32, remF i40 3]) * rgather [96] m1 (\\[i41] -> [remF (quotF i41 3) 32, 1 + remF i41 3])))) (\\[i39] -> [remF i39 32]))"
show delta
@?= "ShareG 100000010 (Sum0R (ShareG 100000009 (ReplicateG (SNat @4) (ShareG 100000008 (ScaleG (AstRaw {unAstRaw = AstShare (AstVarId 100000016) (AstFromIntegralR (AstRFromS AstIotaS))}) (ShareG 100000007 (ReplicateG (SNat @32) (ShareG 100000006 (AddG (Dot0R (AstRaw {unAstRaw = AstShare (AstVarId 100000015) (AstReshape [96] (AstGather [32,3] (AstVar (FTKR [32,4] FTKScalar) (AstVarId 100000001)) ([AstVarId 100000012,AstVarId 100000013],[AstVar FTKScalar (AstVarId 100000012),AstSumOfList [AstConcrete FTKScalar 1,AstVar FTKScalar (AstVarId 100000013)]])))}) (ShareG 100000002 (ReshapeR [96] (ShareG 100000001 (GatherR [32,3] (InputG (FTKR [32,4] FTKScalar) (InputId 0)) <function>))))) (Dot0R (AstRaw {unAstRaw = AstShare (AstVarId 100000014) (AstReshape [96] (AstGather [32,3] (AstVar (FTKR [32,4] FTKScalar) (AstVarId 100000001)) ([AstVarId 100000010,AstVarId 100000011],[AstVar FTKScalar (AstVarId 100000010),AstVar FTKScalar (AstVarId 100000011)])))}) (ShareG 100000005 (ReshapeR [96] (ShareG 100000004 (GatherR [32,3] (InputG (FTKR [32,4] FTKScalar) (InputId 0)) <function>))))))))))))))"
@?= "ShareG 100000010 (Sum0R (ShareG 100000009 (ReplicateG (SNat @4) (STKR (SNat @1) (STKScalar Double)) (ShareG 100000008 (ScaleG (AstRaw {unAstRaw = AstShare (AstVarId 100000016) (AstFromIntegralR (AstRFromS AstIotaS))}) (ShareG 100000007 (ReplicateG (SNat @32) (STKR (SNat @0) (STKScalar Double)) (ShareG 100000006 (AddG (Dot0R (AstRaw {unAstRaw = AstShare (AstVarId 100000015) (AstReshape [96] (AstGather [32,3] (AstVar (FTKR [32,4] FTKScalar) (AstVarId 100000001)) ([AstVarId 100000012,AstVarId 100000013],[AstVar FTKScalar (AstVarId 100000012),AstSumOfList [AstConcrete FTKScalar 1,AstVar FTKScalar (AstVarId 100000013)]])))}) (ShareG 100000002 (ReshapeR [96] (ShareG 100000001 (GatherR [32,3] (InputG (FTKR [32,4] FTKScalar) (InputId 0)) <function>))))) (Dot0R (AstRaw {unAstRaw = AstShare (AstVarId 100000014) (AstReshape [96] (AstGather [32,3] (AstVar (FTKR [32,4] FTKScalar) (AstVarId 100000001)) ([AstVarId 100000010,AstVarId 100000011],[AstVar FTKScalar (AstVarId 100000010),AstVar FTKScalar (AstVarId 100000011)])))}) (ShareG 100000005 (ReshapeR [96] (ShareG 100000004 (GatherR [32,3] (InputG (FTKR [32,4] FTKScalar) (InputId 0)) <function>))))))))))))))"
32 changes: 16 additions & 16 deletions test/simplified/TestGatherSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,15 @@ testGatherSimpPP23 = do
gatherReshape22 @(AstTensor AstMethodLet PrimalSpan)
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 219
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 530
length (show t1) @?= 289
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 565
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @2 @2 [2, 6]
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 219
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 530
length (show t2) @?= 289
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 565

-- Depending on if and how transpose it desugared, this may or may not result
-- in dozens of nested gathers that should vanish after simplification.
Expand Down Expand Up @@ -454,31 +454,31 @@ testGatherSimpPP33 = do
resetVarCounter
let !t1 = gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 625
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 625
length (show t1) @?= 730
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 730
resetVarCounter
let !t2 = (\t -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @10 [8, 16] t))
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 544
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 544
length (show t2) @?= 649
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 649

testGatherSimpPP34 :: Assertion
testGatherSimpPP34 = do
resetVarCounter
let !t1 = (\t -> rbuild1 4 (\i ->
gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan) (t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 972
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 972
length (show t1) @?= 1182
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 1182
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
(\t' -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @10 [8, 16] t'))
(t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 725
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 725
length (show t2) @?= 935
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 935

-- scatters instead of gathers

Expand Down Expand Up @@ -713,10 +713,10 @@ testReluSimpPP = do
resetVarCounter
let !t1 = barRelu10xSlower @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (FTKR [1,2,2,1,2,2,2,2,2,1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 14821
length (show (simplifyInline @(TKR 10 Float) t1)) @?= 14821
length (show t1) @?= 20421
length (show (simplifyInline @(TKR 10 Float) t1)) @?= 20421
resetVarCounter
let !t2 = barRelu @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (FTKR [1,2,2,1,2,2,2,2,2,1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 11813
length (show (simplifyInline @(TKR 10 Float) t2)) @?= 14821
length (show t2) @?= 12373
length (show (simplifyInline @(TKR 10 Float) t2)) @?= 20421
2 changes: 1 addition & 1 deletion test/simplified/TestMnistFCNNR.hs
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ testVT2OPPNonLin = do
(_, ast3) = funToAst (FTKR (0 :$: ZSR) (FTKScalar @Float))
(const $ afcnn2TnonLin constant)
"\\dummy" ++ " -> " ++ printAstSimple renames ast3
@?= "\\dummy -> tlet (exp (rsum (rtranspse [1,0] (rreplicate 2 (tlet (rcast (rsum (rtranspose [1,0] (rreplicate 5 (rcast (tlet (rsum (rfromPrimal (rtranspose [1,0] (rreplicate 4 (rreplicate 3 (rscalar 7.0)))) * rfromPrimal (tconcrete (FTKR [3,4] FTKScalar) (rfromListLinear [3,4] [1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0]))) + rcast (rfromPrimal (tconcrete (FTKR [4] FTKScalar) (rfromListLinear [4] [1.0,2.0,3.0,4.0])))) (\\v3 -> tlet (rfromPrimal (recip (rreplicate 4 (rscalar 1.0) + exp (negate (rprimalPart v3))))) (\\v4 -> rD (rprimalPart v4) (rdualPart (rfromPrimal (rprimalPart v4 * (rreplicate 4 (rscalar 1.0) - rprimalPart v4)) * rD (rreplicate 4 (rscalar 0.0)) (rdualPart v3)))))))) * rfromPrimal (tconcrete (FTKR [4,5] FTKScalar) (rfromListLinear [4,5] [1.0,1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0,3.0,4.0,4.0,4.0,4.0,4.0])))) + rfromPrimal (rcast (tconcrete (FTKR [5] FTKScalar) (rfromListLinear [5] [1.0,2.0,3.0,4.0,5.0])))) (\\v6 -> tlet (rfromPrimal (recip (rreplicate 5 (rscalar 1.0) + exp (negate (rprimalPart v6))))) (\\v7 -> rD (rprimalPart v7) (rdualPart (rfromPrimal (rprimalPart v7 * (rreplicate 5 (rscalar 1.0) - rprimalPart v7)) * rD (rreplicate 5 (rscalar 0.0)) (rdualPart v6))))))) * rfromPrimal (tconcrete (FTKR [5,2] FTKScalar) (rfromListLinear [5,2] [1.0,1.0,2.0,2.0,3.0,3.0,4.0,4.0,5.0,5.0]))) + rfromPrimal (rcast (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [1.0,2.0]))))) (\\v9 -> rreplicate 2 (recip (rsum v9)) * v9)"
@?= "\\dummy -> tlet (exp (rsum (rtranspose [1,0] (rreplicate 2 (tlet (rcast (rsum (rtranspose [1,0] (rreplicate 5 (rcast (tlet (rsum (rfromPrimal (rtranspose [1,0] (rreplicate 4 (rreplicate 3 (rscalar 7.0)))) * rfromPrimal (tconcrete (FTKR [3,4] FTKScalar) (rfromListLinear [3,4] [1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0]))) + rcast (rfromPrimal (tconcrete (FTKR [4] FTKScalar) (rfromListLinear [4] [1.0,2.0,3.0,4.0])))) (\\v3 -> tlet (rfromPrimal (recip (rreplicate 4 (rscalar 1.0) + exp (negate (rprimalPart v3))))) (\\v4 -> rD (rprimalPart v4) (rdualPart (rfromPrimal (rprimalPart v4 * (rreplicate 4 (rscalar 1.0) - rprimalPart v4)) * rD (rreplicate 4 (rscalar 0.0)) (rdualPart v3)))))))) * rfromPrimal (tconcrete (FTKR [4,5] FTKScalar) (rfromListLinear [4,5] [1.0,1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,2.0,3.0,3.0,3.0,3.0,3.0,4.0,4.0,4.0,4.0,4.0])))) + rfromPrimal (rcast (tconcrete (FTKR [5] FTKScalar) (rfromListLinear [5] [1.0,2.0,3.0,4.0,5.0])))) (\\v6 -> tlet (rfromPrimal (recip (rreplicate 5 (rscalar 1.0) + exp (negate (rprimalPart v6))))) (\\v7 -> rD (rprimalPart v7) (rdualPart (rfromPrimal (rprimalPart v7 * (rreplicate 5 (rscalar 1.0) - rprimalPart v7)) * rD (rreplicate 5 (rscalar 0.0)) (rdualPart v6))))))) * rfromPrimal (tconcrete (FTKR [5,2] FTKScalar) (rfromListLinear [5,2] [1.0,1.0,2.0,2.0,3.0,3.0,4.0,4.0,5.0,5.0]))) + rfromPrimal (rcast (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [1.0,2.0]))))) (\\v9 -> rreplicate 2 (recip (rsum v9)) * v9)"

testVT2OPPNonLin2 :: Assertion
testVT2OPPNonLin2 = do
Expand Down

0 comments on commit e11538c

Please sign in to comment.