|
| 1 | +@testset "batched_mul" begin |
| 2 | + using NNlib: batched_mul, batched_mul!, batched_vec, |
| 3 | + batched_adjoint, batched_transpose |
| 4 | + |
| 5 | + A = randn(Float32, 3,3,2); |
| 6 | + B = randn(Float32, 3,3,2); |
| 7 | + |
| 8 | + C = batched_mul(A, B) |
| 9 | + @test MtlArray(C) ≈ batched_mul(MtlArray(A), MtlArray(B)) |
| 10 | + |
| 11 | + Ct = batched_mul(batched_transpose(A), B) |
| 12 | + @test MtlArray(Ct) ≈ batched_mul(batched_transpose(MtlArray(A)), MtlArray(B)) |
| 13 | + |
| 14 | + Ca = batched_mul(A, batched_adjoint(B)) |
| 15 | + @test MtlArray(Ca) ≈ batched_mul(MtlArray(A), batched_adjoint(MtlArray(B))) |
| 16 | + |
| 17 | + # 5-arg batched_mul! |
| 18 | + C .= pi |
| 19 | + batched_mul!(C, A, B, 2f0, 3f0) |
| 20 | + gpuCpi = MtlArray(similar(C)) .= pi |
| 21 | + @test MtlArray(C) ≈ batched_mul!(gpuCpi, MtlArray(A), MtlArray(B), 2f0, 3f0) |
| 22 | + |
| 23 | + # PermutedDimsArray |
| 24 | + @test MtlArray(Ct) ≈ batched_mul(PermutedDimsArray(MtlArray(A), (2,1,3)), MtlArray(B)) |
| 25 | + |
| 26 | + D = permutedims(B, (1,3,2)) |
| 27 | + Cp = batched_mul(batched_adjoint(A), B) |
| 28 | + @test_broken MtlArray(Cp) ≈ batched_mul(batched_adjoint(MtlArray(A)), PermutedDimsArray(MtlArray(D), (1,3,2))) |
| 29 | + |
| 30 | + # Methods which reshape |
| 31 | + M = randn(Float32, 3,3) |
| 32 | + |
| 33 | + Cm = batched_mul(A, M) |
| 34 | + @test MtlArray(Cm) ≈ batched_mul(MtlArray(A), MtlArray(M)) |
| 35 | + |
| 36 | + Cv = batched_vec(permutedims(A,(3,1,2)), M) |
| 37 | + @test_broken MtlArray(Cv) ≈ batched_vec(PermutedDimsArray(MtlArray(A),(3,1,2)), MtlArray(M)) |
| 38 | +end |
| 39 | + |
| 40 | +function print_array_strs(x) |
| 41 | + str = sprint((io, x)->show(io, MIME"text/plain"(), x), x) |
| 42 | + return @view split(str, '\n')[2:end] |
| 43 | +end |
| 44 | + |
| 45 | +@testset "BatchedAdjOrTrans" begin |
| 46 | + x = rand(Float32, 3, 4, 2) |
| 47 | + y = MtlArray(x) |
| 48 | + |
| 49 | + bax = batched_adjoint(x) |
| 50 | + btx = batched_transpose(x) |
| 51 | + bay = batched_adjoint(y) |
| 52 | + bty = batched_transpose(y) |
| 53 | + |
| 54 | + @test sprint(show, bax) == sprint(show, bay) |
| 55 | + @test sprint(show, btx) == sprint(show, bty) |
| 56 | + |
| 57 | + @test print_array_strs(bax) == print_array_strs(bay) |
| 58 | + @test print_array_strs(btx) == print_array_strs(bty) |
| 59 | + |
| 60 | + @test Array(bax) == Array(bay) |
| 61 | + @test collect(bax) == collect(bay) |
| 62 | + @test Array(btx) == Array(bty) |
| 63 | + @test collect(btx) == collect(bty) |
| 64 | + |
| 65 | + for shape in (:, (12, 2)) |
| 66 | + rbax = reshape(bax, shape) |
| 67 | + rbtx = reshape(btx, shape) |
| 68 | + rbay = reshape(bay, shape) |
| 69 | + rbty = reshape(bty, shape) |
| 70 | + |
| 71 | + @test sprint(show, rbax) == sprint(show, rbay) |
| 72 | + @test sprint(show, rbtx) == sprint(show, rbty) |
| 73 | + |
| 74 | + @test print_array_strs(rbax) == print_array_strs(rbay) |
| 75 | + @test print_array_strs(rbtx) == print_array_strs(rbty) |
| 76 | + |
| 77 | + @test Array(rbax) == Array(rbay) |
| 78 | + @test collect(rbax) == collect(rbay) |
| 79 | + @test Array(rbtx) == Array(rbty) |
| 80 | + @test collect(rbtx) == collect(rbty) |
| 81 | + end |
| 82 | +end |
0 commit comments