From 12d74ee987b9da13f3a383367cf872a3b141d829 Mon Sep 17 00:00:00 2001 From: Robert Stone Date: Fri, 25 Oct 2019 17:39:59 -0700 Subject: [PATCH] [Perl] - ndarray to native array conversion fix --- perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm | 6 +++++- perl-package/AI-MXNet/t/test_ndarray.t | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm index f75cc84b2a8f..1d968c14a487 100644 --- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm +++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray.pm @@ -116,7 +116,11 @@ method STORABLE_thaw($cloning, $buf, $writable) method split_array(@args) { - $self->shape->[0] > 1 ? $self->split(num_outputs => $self->shape->[0], squeeze_axis => @{ $self->shape } > 1 ? 1 : 0, axis => 0) : [$self]; + my $shape = $self->shape; + return [] if $shape->[0] == 0; + my $list = $self->split(num_outputs=>$shape->[0], + squeeze_axis=>int(@$shape > 1), axis=>0); + $shape->[0] == 1 ? [ $list ] : $list; } method at(Index @indices) diff --git a/perl-package/AI-MXNet/t/test_ndarray.t b/perl-package/AI-MXNet/t/test_ndarray.t index a6cd113c3f89..1e290b4bc715 100644 --- a/perl-package/AI-MXNet/t/test_ndarray.t +++ b/perl-package/AI-MXNet/t/test_ndarray.t @@ -19,7 +19,7 @@ use strict; use warnings; use AI::MXNet qw(mx); use AI::MXNet::TestUtils qw(almost_equal same rand_ndarray randint zip); -use Test::More tests => 251; +use Test::More tests => 261; use PDL; use File::Temp qw(tempdir); use IO::File; @@ -217,6 +217,22 @@ sub test_histogram ok(same($bins->aspdl, pdl([10, 20, 30, 60]))); } +sub test_array_overload +{ + # array conversions are largely calls to mx->nd->split(), but have + # special cases around dimensions of length 0 and 1. + is_deeply([ @{ mx->nd->array(zeros(7, 0)) } ], []); + is_deeply(mx->nd->zeros([3, 7])->[0]->shape, [ 7 ]); + is_deeply(mx->nd->zeros([2, 7])->[0]->shape, [ 7 ]); + is_deeply(mx->nd->zeros([1, 7])->[0]->shape, [ 7 ]); + is_deeply(mx->nd->zeros([3, 7, 11])->[0]->shape, [7, 11]); + is_deeply(mx->nd->zeros([2, 7, 11])->[0]->shape, [7, 11]); + is_deeply(mx->nd->zeros([1, 7, 11])->[0]->shape, [7, 11]); + is_deeply(mx->nd->zeros([3, 7, 11, 13])->[0]->shape, [7, 11, 13]); + is_deeply(mx->nd->zeros([2, 7, 11, 13])->[0]->shape, [7, 11, 13]); + is_deeply(mx->nd->zeros([1, 7, 11, 13])->[0]->shape, [7, 11, 13]); +} + test_ndarray_slice(); test_ndarray_reshape(); test_moveaxis(); @@ -226,3 +242,4 @@ test_linalg_gemm2(); test_image_to_tensor(); test_buffer_load(); test_histogram(); +test_array_overload();