Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Perl] - ndarray to native array conversion fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tlby committed Oct 26, 2019
1 parent aadef2d commit 12d74ee
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
6 changes: 5 additions & 1 deletion perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,11 @@ method STORABLE_thaw($cloning, $buf, $writable)

method split_array(@args)
{
$self->shape->[0] > 1 ? $self->split(num_outputs => $self->shape->[0], squeeze_axis => @{ $self->shape } > 1 ? 1 : 0, axis => 0) : [$self];
my $shape = $self->shape;
return [] if $shape->[0] == 0;
my $list = $self->split(num_outputs=>$shape->[0],
squeeze_axis=>int(@$shape > 1), axis=>0);
$shape->[0] == 1 ? [ $list ] : $list;
}

method at(Index @indices)
Expand Down
19 changes: 18 additions & 1 deletion perl-package/AI-MXNet/t/test_ndarray.t
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(almost_equal same rand_ndarray randint zip);
use Test::More tests => 251;
use Test::More tests => 261;
use PDL;
use File::Temp qw(tempdir);
use IO::File;
Expand Down Expand Up @@ -217,6 +217,22 @@ sub test_histogram
ok(same($bins->aspdl, pdl([10, 20, 30, 60])));
}

sub test_array_overload
{
# array conversions are largely calls to mx->nd->split(), but have
# special cases around dimensions of length 0 and 1.
is_deeply([ @{ mx->nd->array(zeros(7, 0)) } ], []);
is_deeply(mx->nd->zeros([3, 7])->[0]->shape, [ 7 ]);
is_deeply(mx->nd->zeros([2, 7])->[0]->shape, [ 7 ]);
is_deeply(mx->nd->zeros([1, 7])->[0]->shape, [ 7 ]);
is_deeply(mx->nd->zeros([3, 7, 11])->[0]->shape, [7, 11]);
is_deeply(mx->nd->zeros([2, 7, 11])->[0]->shape, [7, 11]);
is_deeply(mx->nd->zeros([1, 7, 11])->[0]->shape, [7, 11]);
is_deeply(mx->nd->zeros([3, 7, 11, 13])->[0]->shape, [7, 11, 13]);
is_deeply(mx->nd->zeros([2, 7, 11, 13])->[0]->shape, [7, 11, 13]);
is_deeply(mx->nd->zeros([1, 7, 11, 13])->[0]->shape, [7, 11, 13]);
}

test_ndarray_slice();
test_ndarray_reshape();
test_moveaxis();
Expand All @@ -226,3 +242,4 @@ test_linalg_gemm2();
test_image_to_tensor();
test_buffer_load();
test_histogram();
test_array_overload();

0 comments on commit 12d74ee

Please sign in to comment.