This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
ndarray.py
5143 lines (4367 loc) · 173 KB
/
ndarray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=too-many-lines, protected-access
# pylint: disable=import-error, no-name-in-module, undefined-variable
"""NDArray API of MXNet."""
try:
from __builtin__ import slice as py_slice
except ImportError:
from builtins import slice as py_slice
from array import array as native_array
import ctypes
import warnings
import operator
from functools import reduce # pylint: disable=redefined-builtin
import numpy as np
from ..base import _LIB, numeric_types, integer_types
from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle, mx_int, mx_int64
from ..base import ctypes2buffer
from ..runtime import Features
from ..context import Context, current_context
from ..util import is_np_array
from . import _internal
from . import op
from ._internal import NDArrayBase
# Public names exported by a star-import of this module.
__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
           "ones", "add", "arange", "linspace", "eye", "divide", "equal", "full", "greater",
           "greater_equal", "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or",
           "logical_xor", "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal",
           "onehot_encode", "power", "subtract", "true_divide", "waitall", "_new_empty_handle",
           "histogram", "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack",
           "from_numpy", "zeros", "indexing_key_expand_implicit_axes", "get_indexing_dispatch_code",
           "get_oshape_of_gather_nd_op"]
# Integer storage-type ids; the string <-> id lookups are defined further below.
_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
_STORAGE_TYPE_ROW_SPARSE = 1
_STORAGE_TYPE_CSR = 2
# Largest element count accepted when MXNet is built without INT64 tensor-size support.
_SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)
# pylint: disable= no-member
# numpy scalar type -> MXNet integer dtype code.
# bfloat16 has no numpy scalar type, so it is represented by a structured dtype with a
# single uint16 field, and that dtype object itself is used as the key.
_DTYPE_NP_TO_MX = {
    None: -1,
    np.float32: 0,
    np.float64: 1,
    np.float16: 2,
    np.uint8: 3,
    np.int32: 4,
    np.int8: 5,
    np.int64: 6,
    np.bool_: 7,
    np.dtype([('bfloat16', np.uint16)]): 12,
}
# Inverse of _DTYPE_NP_TO_MX: MXNet integer dtype code -> numpy type (or bfloat16 dtype).
_DTYPE_MX_TO_NP = {
    -1: None,
    0: np.float32,
    1: np.float64,
    2: np.float16,
    3: np.uint8,
    4: np.int32,
    5: np.int8,
    6: np.int64,
    7: np.bool_,
    12: np.dtype([('bfloat16', np.uint16)]),
}
# Storage-type name -> integer id.
_STORAGE_TYPE_STR_TO_ID = {
    'undefined': _STORAGE_TYPE_UNDEFINED,
    'default': _STORAGE_TYPE_DEFAULT,
    'row_sparse': _STORAGE_TYPE_ROW_SPARSE,
    'csr': _STORAGE_TYPE_CSR,
}
# Integer id -> storage-type name (inverse of the map above).
_STORAGE_TYPE_ID_TO_STR = {
    _STORAGE_TYPE_UNDEFINED: 'undefined',
    _STORAGE_TYPE_DEFAULT: 'default',
    _STORAGE_TYPE_ROW_SPARSE: 'row_sparse',
    _STORAGE_TYPE_CSR: 'csr',
}
# grad_req name -> integer request code. Code 2 is not mapped here; the values presumably
# mirror the backend's request-type enum -- TODO confirm against the C++ side.
_GRAD_REQ_MAP = {
    'null': 0,
    'write': 1,
    'add': 3
}
# pylint: enable= no-member
# Return code for dispatching indexing function call
# (compared against the result of get_indexing_dispatch_code).
_NDARRAY_UNSUPPORTED_INDEXING = -1
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1
_NDARRAY_EMPTY_TUPLE_INDEXING = 2
# Return code for 0-d boolean array handler
_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1
# Caching whether MXNet was built with INT64 support or not.
# None means "not queried yet"; _int64_enabled() fills it in on first call.
_INT64_TENSOR_SIZE_ENABLED = None
def _int64_enabled():
    """Report whether this MXNet build has the INT64_TENSOR_SIZE feature.

    The runtime feature query is performed only once; the answer is memoized in the
    module-level ``_INT64_TENSOR_SIZE_ENABLED`` and reused on later calls.
    """
    global _INT64_TENSOR_SIZE_ENABLED
    if _INT64_TENSOR_SIZE_ENABLED is not None:
        return _INT64_TENSOR_SIZE_ENABLED
    _INT64_TENSOR_SIZE_ENABLED = Features().is_enabled('INT64_TENSOR_SIZE')
    return _INT64_TENSOR_SIZE_ENABLED
def _new_empty_handle():
    """Create a handle to a backend NDArray that owns no data yet.

    Such a handle is meant to receive the result of an operation.

    Returns
    -------
    handle
        A new empty `NDArray` handle, filled in by ``MXNDArrayCreateNone``.
    """
    empty = NDArrayHandle()
    status = _LIB.MXNDArrayCreateNone(ctypes.byref(empty))
    check_call(status)
    return empty
def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
    """Return a new handle with specified shape and context.

    Empty handle is only used to hold results.

    Parameters
    ----------
    shape : sequence of int
        Shape of the array to create.
    ctx : Context
        Target device; its ``device_typeid`` and ``device_id`` are passed to the backend.
    delay_alloc : bool
        Passed to the backend as an int flag (presumably defers the actual memory
        allocation -- TODO confirm against the C API).
    dtype : str or numpy dtype, optional
        Element type. The structured bfloat16 dtype
        ``np.dtype([('bfloat16', np.uint16)])`` is accepted as well.

    Returns
    -------
    handle
        A new empty `NDArray` handle.

    Raises
    ------
    Exception
        If MXNet was built without INT64 tensor-size support and the total number of
        elements exceeds ``2**31 - 1``.
    """
    hdl = NDArrayHandle()
    if _int64_enabled():
        create = _LIB.MXNDArrayCreateEx64
        c_shape = c_array_buf(mx_int64, native_array('q', shape))
        c_ndim = ctypes.c_int(len(shape))
    else:
        # Without INT64 support the element count must fit a signed int32. Check it here
        # so the user gets an actionable message before the backend call is attempted.
        size = reduce(operator.mul, shape, 1)
        if size > _SIGNED_INT32_UPPER_LIMIT:
            raise Exception("[_new_alloc_handle] Size of tensor you are trying to allocate is " +
                            "larger than 2^31 elements. Please build with flag " +
                            "USE_INT64_TENSOR_SIZE=1")
        create = _LIB.MXNDArrayCreateEx
        c_shape = c_array_buf(mx_uint, native_array('I', shape))
        c_ndim = mx_uint(len(shape))
    # bfloat16 is a structured dtype whose ``.type`` is the generic ``np.void``, so it is
    # looked up by the dtype object itself. Both build variants need this; previously
    # only the 32-bit branch did it and the INT64 branch raised KeyError for bfloat16.
    np_dtype = np.dtype(dtype)
    if np_dtype == np.dtype([('bfloat16', np.uint16)]):
        dtype_key = np_dtype
    else:
        dtype_key = np_dtype.type
    check_call(create(
        c_shape,
        c_ndim,
        ctypes.c_int(ctx.device_typeid),
        ctypes.c_int(ctx.device_id),
        ctypes.c_int(int(delay_alloc)),
        ctypes.c_int(int(_DTYPE_NP_TO_MX[dtype_key])),
        ctypes.byref(hdl)))
    return hdl
def _new_from_shared_mem(shared_pid, shared_id, shape, dtype):
    """Create an NDArray handle from a shared-memory segment identified by (pid, id).

    The ``(shared_pid, shared_id, shape, dtype)`` arguments have the same layout as the
    tuple returned by ``NDArray._to_shared_mem``.

    Parameters
    ----------
    shared_pid : int
        Process id component of the shared-memory handle.
    shared_id : int
        Segment id component of the shared-memory handle.
    shape : sequence of int
        Shape of the array.
    dtype : str or numpy dtype
        Element type; the structured bfloat16 dtype is accepted as well.

    Returns
    -------
    handle
        A new `NDArray` handle created by ``MXNDArrayCreateFromSharedMemEx``.
    """
    # Resolve the dtype key the same way as _new_alloc_handle: bfloat16 is a structured
    # dtype whose ``.type`` is ``np.void``, so it is keyed by the dtype object itself.
    np_dtype = np.dtype(dtype)
    if np_dtype == np.dtype([('bfloat16', np.uint16)]):
        dtype_key = np_dtype
    else:
        dtype_key = np_dtype.type
    hdl = NDArrayHandle()
    check_call(_LIB.MXNDArrayCreateFromSharedMemEx(
        ctypes.c_int(shared_pid),
        ctypes.c_int(shared_id),
        c_array(mx_int, shape),
        mx_int(len(shape)),
        ctypes.c_int(int(_DTYPE_NP_TO_MX[dtype_key])),
        ctypes.byref(hdl)))
    return hdl
def waitall():
    """Block until every pending asynchronous MXNet operation has finished.

    Intended for benchmarking only.

    .. note::

        If your mxnet code throws an exception, then waitall can cause performance impact.
    """
    status = _LIB.MXNDArrayWaitAll()
    check_call(status)
def _storage_type(handle):
    """Ask the backend for the integer storage-type id of the array behind ``handle``.

    The id is presumably one of the ``_STORAGE_TYPE_*`` constants of this module.
    """
    stype_id = ctypes.c_int(0)
    check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(stype_id)))
    return stype_id.value
class NDArray(NDArrayBase):
"""An array object representing a multidimensional, homogeneous array of
fixed-size items.
"""
__slots__ = []
# make numpy functions return NDArray instead of numpy object array
__array_priority__ = 1000.0
# Extension type code for TVM function.
# See C++ side of definition(kTVMNDArrayTypeCode) at include/mxmet/tensor_blob.h
_tvm_tcode = 19
# pylint: disable= no-member, undefined-variable
def as_np_ndarray(self):
    """Convert mxnet.ndarray.NDArray to mxnet.numpy.ndarray.

    The result wraps a shallow copy of this array's handle (``MXShallowCopyNDArray``)
    and inherits this array's ``writable`` flag.

    Raises
    ------
    ValueError
        If this array's storage type is not ``'default'``.
    """
    storage_type = self.stype
    if storage_type != 'default':
        # Report the offending storage type itself. Formatting type(storage_type) always
        # printed "<class 'str'>", which hid the actual stype from the user.
        raise ValueError('cannot convert ndarray of stype {} to numpy ndarray'
                         .format(storage_type))
    from ..numpy import ndarray
    hdl = NDArrayHandle()
    check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
    return ndarray(handle=hdl, writable=self.writable)
def as_nd_ndarray(self):
    """Return a classic ndarray view of this array without copying.

    This object already is a classic ndarray, so it is returned unchanged.
    """
    same = self
    return same
@property
def _tvm_handle(self):
    """Integer value of ``self.handle``, exposed for TVM interoperability."""
    raw_handle = self.handle
    return raw_handle.value
def __repr__(self):
    """Returns a string representation of the array.

    The layout is the numpy rendering of the data followed by
    ``<ClassName d0xd1x... @context>``.
    """
    dims = 'x'.join('%d' % extent for extent in self.shape)
    data_text = str(self.asnumpy())
    class_name = self.__class__.__name__
    return '\n%s\n<%s %s @%s>' % (data_text, class_name, dims, self.ctx)
def __reduce__(self):
    """Pickle support: rebuild via ``NDArray(None)`` and then restore with ``__setstate__``."""
    state = self.__getstate__()
    return (NDArray, (None,), state)
def _to_shared_mem(self):
    """Return ``(shared_pid, shared_id, shape, dtype)`` describing this array's shared memory.

    The pid/id pair comes from ``MXNDArrayGetSharedMemHandle``. Whether the array must
    already reside in shared memory is decided by the backend; not visible here.
    """
    pid = ctypes.c_int()
    segment = ctypes.c_int()
    check_call(_LIB.MXNDArrayGetSharedMemHandle(
        self.handle, ctypes.byref(pid), ctypes.byref(segment)))
    return pid.value, segment.value, self.shape, self.dtype
def __abs__(self):
    """x.__abs__() <=> abs(x) <=> x.abs() <=> mx.nd.abs(x)

    Element-wise absolute value; delegates to ``self.abs()``.
    """
    return self.abs()
def __add__(self, other):
    """x.__add__(y) <=> x+y <=> mx.nd.add(x, y)

    Delegates to the module-level ``add``.
    """
    total = add(self, other)
    return total
def __iadd__(self, other):
    """x.__iadd__(y) <=> x+=y

    Adds ``other`` into ``self`` in place (the operator writes into ``out=self``).

    Raises
    ------
    ValueError
        If ``self`` is not writable.
    TypeError
        If ``other`` is neither an NDArray nor a numeric scalar.
    """
    if not self.writable:
        raise ValueError('trying to add to a readonly NDArray')
    if isinstance(other, NDArray):
        return op.broadcast_add(self, other, out=self)
    if isinstance(other, numeric_types):
        return _internal._plus_scalar(self, float(other), out=self)
    raise TypeError('type %s not supported' % str(type(other)))
def __radd__(self, other):
    """x.__radd__(y) <=> y+x; addition commutes, so this forwards to ``__add__``."""
    forward = self.__add__
    return forward(other)
def __sub__(self, other):
    """x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y)

    Delegates to the module-level ``subtract``.
    """
    difference = subtract(self, other)
    return difference
def __isub__(self, other):
    """x.__isub__(y) <=> x-=y

    Subtracts ``other`` from ``self`` in place (the operator writes into ``out=self``).

    Raises
    ------
    ValueError
        If ``self`` is not writable.
    TypeError
        If ``other`` is neither an NDArray nor a numeric scalar.
    """
    if not self.writable:
        raise ValueError('trying to subtract from a readonly NDArray')
    if isinstance(other, NDArray):
        return op.broadcast_sub(self, other, out=self)
    if isinstance(other, numeric_types):
        return _internal._minus_scalar(self, float(other), out=self)
    raise TypeError('type %s not supported' % str(type(other)))
def __rsub__(self, other):
    """x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x)

    Operands are swapped before delegating to ``subtract``.
    """
    difference = subtract(other, self)
    return difference
def __mul__(self, other):
    """x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y)

    Delegates to the module-level ``multiply``.
    """
    product = multiply(self, other)
    return product
def __neg__(self):
    """x.__neg__() <=> -x

    Implemented as a scalar multiplication by ``-1.0``.
    """
    return _internal._mul_scalar(self, -1.0)
def __imul__(self, other):
    """x.__imul__(y) <=> x*=y

    Multiplies ``self`` by ``other`` in place (the operator writes into ``out=self``).

    Raises
    ------
    ValueError
        If ``self`` is not writable.
    TypeError
        If ``other`` is neither an NDArray nor a numeric scalar.
    """
    if not self.writable:
        raise ValueError('trying to multiply to a readonly NDArray')
    if isinstance(other, NDArray):
        return op.broadcast_mul(self, other, out=self)
    if isinstance(other, numeric_types):
        return _internal._mul_scalar(self, float(other), out=self)
    raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
    """x.__rmul__(y) <=> y*x; multiplication commutes, so this forwards to ``__mul__``."""
    forward = self.__mul__
    return forward(other)
def __div__(self, other):
    """x.__div__(y) <=> x/y <=> mx.nd.divide(x, y)

    Python 2 division hook; delegates to the module-level ``divide``.
    """
    quotient = divide(self, other)
    return quotient
def __rdiv__(self, other):
    """x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x)

    Operands are swapped before delegating to ``divide``.
    """
    quotient = divide(other, self)
    return quotient
def __idiv__(self, other):
    """x.__idiv__(y) <=> x/=y

    Divides ``self`` by ``other`` in place (the operator writes into ``out=self``).

    Raises
    ------
    ValueError
        If ``self`` is not writable.
    TypeError
        If ``other`` is neither an NDArray nor a numeric scalar.
    """
    if not self.writable:
        raise ValueError('trying to divide from a readonly NDArray')
    if isinstance(other, NDArray):
        return op.broadcast_div(self, other, out=self)
    elif isinstance(other, numeric_types):
        return _internal._div_scalar(self, float(other), out=self)
    else:
        raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
    """x.__truediv__(y) <=> x/y <=> mx.nd.divide(x, y)"""
    quotient = divide(self, other)
    return quotient
def __rtruediv__(self, other):
    """x.__rtruediv__(y) <=> y/x <=> mx.nd.divide(y, x)"""
    quotient = divide(other, self)
    return quotient
def __itruediv__(self, other):
    """x.__itruediv__(y) <=> x/=y; shares the implementation of ``__idiv__``."""
    in_place_div = self.__idiv__
    return in_place_div(other)
def __mod__(self, other):
    """x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y)

    Delegates to the module-level ``modulo``.
    """
    remainder = modulo(self, other)
    return remainder
def __rmod__(self, other):
    """x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x)

    Operands are swapped before delegating to ``modulo``.
    """
    remainder = modulo(other, self)
    return remainder
def __imod__(self, other):
    """x.__imod__(y) <=> x%=y

    Replaces ``self`` with ``self % other`` in place (the operator writes into ``out=self``).

    Raises
    ------
    ValueError
        If ``self`` is not writable.
    TypeError
        If ``other`` is neither an NDArray nor a numeric scalar.
    """
    if not self.writable:
        raise ValueError('trying to take modulo from a readonly NDArray')
    if isinstance(other, NDArray):
        return op.broadcast_mod(self, other, out=self)
    elif isinstance(other, numeric_types):
        return _internal._mod_scalar(self, float(other), out=self)
    else:
        raise TypeError('type %s not supported' % str(type(other)))
def __pow__(self, other):
    """x.__pow__(y) <=> x**y <=> mx.nd.power(x,y)

    Delegates to the module-level ``power``.
    """
    raised = power(self, other)
    return raised
def __rpow__(self, other):
    """x.__rpow__(y) <=> y**x <=> mx.nd.power(y, x)

    Operands are swapped before delegating to ``power``.
    """
    return power(other, self)
def __eq__(self, other):
    """x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y)

    Element-wise comparison; the result comes from the module-level ``equal``.
    """
    mask = equal(self, other)
    return mask
def __hash__(self):
    """Default hash function: object identity with the low four bits dropped."""
    return id(self) >> 4
def __ne__(self, other):
    """x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y)

    Element-wise comparison; the result comes from the module-level ``not_equal``.
    """
    mask = not_equal(self, other)
    return mask
def __gt__(self, other):
    """x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y)

    Element-wise comparison; the result comes from the module-level ``greater``.
    """
    mask = greater(self, other)
    return mask
def __ge__(self, other):
    """x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y)

    Element-wise comparison; the result comes from the module-level ``greater_equal``.
    """
    mask = greater_equal(self, other)
    return mask
def __lt__(self, other):
    """x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y)

    Element-wise comparison; the result comes from the module-level ``lesser``.
    """
    mask = lesser(self, other)
    return mask
def __le__(self, other):
    """x.__le__(y) <=> x<=y <=> mx.nd.lesser_equal(x, y)

    Element-wise comparison; the result comes from the module-level ``lesser_equal``.
    """
    return lesser_equal(self, other)
def __bool__(self):
    """Truth value of the array.

    An empty array is falsy, a single-element array takes the truth value of that
    element, and any larger array raises ``ValueError`` because the answer is ambiguous.
    """
    count = reduce(operator.mul, self.shape, 1)
    if count == 1:
        return bool(self.asscalar())
    if count == 0:
        return False
    raise ValueError("The truth value of an NDArray with multiple elements "
                     "is ambiguous.")
__nonzero__ = __bool__
def __len__(self):
    """Number of element along the first axis, i.e. ``self.shape[0]``."""
    dims = self.shape
    return dims[0]
def __getstate__(self):
    """Pickle state: ``{'handle': <raw bytes from MXNDArraySaveRawBytes> or None}``."""
    state = {'handle': None}
    if self.handle is None:
        return state
    nbytes = ctypes.c_size_t()
    raw_ptr = ctypes.POINTER(ctypes.c_char)()
    check_call(_LIB.MXNDArraySaveRawBytes(self.handle,
                                          ctypes.byref(nbytes),
                                          ctypes.byref(raw_ptr)))
    state['handle'] = ctypes2buffer(raw_ptr, nbytes.value)
    return state
def __setstate__(self, state):
    """Restore from the dict produced by ``__getstate__``.

    A ``None`` payload yields a ``None`` handle. Otherwise the raw bytes are loaded
    through ``MXNDArrayLoadFromRawBytes``; ``from_buffer`` requires a writable buffer.
    """
    # pylint: disable=assigning-non-slot
    payload = state['handle']
    if payload is None:
        self.handle = None
        return
    restored = NDArrayHandle()
    c_bytes = (ctypes.c_char * len(payload)).from_buffer(payload)
    nbytes = ctypes.c_size_t(len(payload))
    check_call(_LIB.MXNDArrayLoadFromRawBytes(c_bytes, nbytes, ctypes.byref(restored)))
    self.handle = restored
def __setitem__(self, key, value):
    """x.__setitem__(i, y) <=> x[i]=y

    Sets ``self[key]`` to ``value``.

    This function supports advanced indexing as defined in `the NumPy
    advanced indexing documentation
    <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_,
    with the restriction that boolean array indexing is not supported.

    Parameters
    ----------
    key : int, mxnet.ndarray.slice, list, np.ndarray, NDArray, or tuple of all previous types
        The indexing key.
    value : scalar or array-like object that can be broadcast to the shape of self[key]
        The value to set.

    Raises
    ------
    IndexError
        For a 0-d array, if ``key`` is not ``()`` or a slice; otherwise if the key
        has more (non-``None``) entries than ``self.ndim``.
    ValueError
        For a 0-d array, if ``value`` is not a scalar or a single-element array;
        otherwise if the key maps to an unsupported indexing kind.
    RuntimeError
        If key normalization produced fewer indices than ``self.ndim`` (internal bug).

    Examples
    --------
    >>> x = mx.nd.zeros((2, 3))
    >>> x[:] = 1
    >>> x.asnumpy()
    array([[ 1., 1., 1.],
           [ 1., 1., 1.]], dtype=float32)
    >>> x[:, 1:2] = 2
    >>> x.asnumpy()
    array([[ 1., 2., 1.],
           [ 1., 2., 1.]], dtype=float32)
    >>> x[1:2, 1:] = 3
    >>> x.asnumpy()
    array([[ 1., 2., 1.],
           [ 1., 3., 3.]], dtype=float32)
    >>> x[1:, 0:2] = mx.nd.zeros((1, 2))
    >>> x.asnumpy()
    array([[ 1., 2., 1.],
           [ 0., 0., 3.]], dtype=float32)
    >>> x[1, 2] = 4
    >>> x.asnumpy()
    array([[ 1., 2., 1.],
           [ 0., 0., 4.]], dtype=float32)
    >>> x[[0], [1, 2]] = 5
    >>> x.asnumpy()
    array([[ 1., 5., 5.],
           [ 0., 0., 4.]], dtype=float32)
    >>> x[::-1, 0:2:2] = [6]
    >>> x.asnumpy()
    array([[ 6., 5., 5.],
           [ 6., 0., 4.]], dtype=float32)
    """
    if self.ndim == 0:
        # 0-d array: only `()` or a slice is accepted as key, and the value must hold
        # exactly one element.
        if not isinstance(key, (tuple, py_slice)):
            raise IndexError('scalar tensor can only accept `()` and `:` as index')
        if isinstance(key, tuple) and len(key) != 0:
            raise IndexError('scalar tensor can only accept `()` and `:` as index')
        if isinstance(value, numeric_types):
            self._full(value)
        elif isinstance(value, NDArray) and value.size == 1:
            if value.shape != self.shape:
                value = value.reshape(self.shape)
            value.copyto(self)
        elif isinstance(value, (np.ndarray, np.generic)) and value.size == 1:
            # np.generic scalars are always reshaped to the 0-d target shape.
            if isinstance(value, np.generic) or value.shape != self.shape:
                value = value.reshape(self.shape)
            self._sync_copyfrom(value)
        else:
            raise ValueError('setting an array element with a sequence.')
    elif self.size == 0:
        # Nothing to write into an empty array; the key is not validated in this case.
        return
    else:
        key, _ = indexing_key_expand_implicit_axes(key, self.shape)
        # `None` entries add new axes and do not consume a dimension of `self`.
        slc_key = tuple(idx for idx in key if idx is not None)
        if len(slc_key) < self.ndim:
            raise RuntimeError(
                'too few indices after normalization: expected `ndim` ({}) '
                'but got {}. This is a bug, please report it!'
                ''.format(self.ndim, len(slc_key))
            )
        if len(slc_key) > self.ndim:
            raise IndexError(
                'too many indices ({}) for array with {} dimensions'
                ''.format(len(slc_key), self.ndim)
            )
        indexing_dispatch_code = get_indexing_dispatch_code(slc_key)
        # Note: the full `key` (including `None`s) is forwarded to the setters.
        if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
            self._set_nd_basic_indexing(key, value)
        elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
            self._set_nd_advanced_indexing(key, value)
        else:
            raise ValueError(
                'Indexing NDArray with index {} of type {} is not supported'
                ''.format(key, type(key))
            )
def __getitem__(self, key): # pylint: disable=too-many-return-statements
    """x.__getitem__(i) <=> x[i]

    Returns a sliced view of this array if the elements fetched are contiguous in memory;
    otherwise, returns a newly created NDArray.
    This function supports advanced indexing defined in the following reference with
    some restrictions.

    For basic indexing, i.e., if ``key`` consists only of integers,
    ``slice``, ``Ellipsis`` (``...``) and ``None``, a mutable view is
    returned that shares memory with this array if the accessed portion is
    contiguous in memory.
    Otherwise, a newly created ``NDArray`` is returned.

    This function supports advanced indexing as defined in `the NumPy
    advanced indexing documentation
    <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_,
    with the restriction that boolean array indexing is not supported.

    Parameters
    ----------
    key : int, mxnet.ndarray.slice, list, np.ndarray, NDArray, or tuple of all previous types
        Indexing key.

    Returns
    -------
    NDArray
        ``self`` for the trivial keys ``()`` (and ``:`` on a 0-d array); otherwise the
        indexed result.

    Raises
    ------
    IndexError
        If an integer key is larger than the last index of axis 0, or the array is 0-d.
    ValueError
        If a slice step is zero or the normalized key is empty.
    RuntimeError
        If the key cannot be dispatched to basic or advanced indexing.

    Examples
    --------
    The default is to give explicit indices for all axes:

    >>> x = mx.nd.arange(0, 6).reshape((2, 3))
    >>> x.asnumpy()
    array([[ 0., 1., 2.],
           [ 3., 4., 5.]], dtype=float32)
    >>> x[0, :].asnumpy()
    array([0., 1., 2.], dtype=float32)
    >>> x[0, :2].asnumpy()
    array([0., 1.], dtype=float32)
    >>> x[:, :-1].asnumpy()
    array([[0., 1.],
           [3., 4.]], dtype=float32)

    If fewer indices are given, they are automatically supplemented by an
    appropriate number of ``slice(None)`` ("``:``") to the right. For
    instance, a single integer indexes along the first axis:

    >>> x = mx.nd.arange(0, 6).reshape((2, 3))
    >>> x[0].asnumpy()
    array([0., 1., 2.], dtype=float32)
    >>> x[1:].asnumpy()
    array([[3., 4., 5.]], dtype=float32)

    To omit a range of axes that should be kept as-is, an `Ellipsis`
    ("``...``") can be used:

    >>> x = mx.nd.arange(0, 16).reshape((2, 2, 2, 2))
    >>> x[0, ..., 1].asnumpy()
    array([[1., 3.],
           [5., 7.]], dtype=float32)
    >>> x[0, :, :, 1].asnumpy()  # equivalent
    array([[1., 3.],
           [5., 7.]], dtype=float32)

    New axes of length 1 can be created by inserting ``None``
    (`numpy.newaxis`) in the index:

    >>> x = mx.nd.arange(0, 6).reshape((2, 3))
    >>> x[None, :, :].asnumpy()
    array([[[0., 1., 2.],
            [3., 4., 5.]]], dtype=float32)
    >>> x[None, :, :].shape
    (1, 2, 3)

    If the indexed portion of the array is contiguous in memory, no data
    is copied. Instead, a shared-memory view of the original array is
    returned, and changes to that view affect the original array:

    >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2))
    >>> y = x[0]  # contiguous
    >>> y.asnumpy()
    array([[0., 1.],
           [2., 3.]], dtype=float32)
    >>> y[:] = -1
    >>> x.asnumpy()
    array([[[-1., -1.],
            [-1., -1.]],
    <BLANKLINE>
           [[ 4., 5.],
            [ 6., 7.]]], dtype=float32)
    >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2))
    >>> y = x[1, :1, :]  # contiguous
    >>> y.asnumpy()
    array([[4., 5.]], dtype=float32)
    >>> y[:] = -1
    >>> x.asnumpy()
    array([[[ 0., 1.],
            [ 2., 3.]],
    <BLANKLINE>
           [[-1., -1.],
            [ 6., 7.]]], dtype=float32)
    >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2))
    >>> y = x[:, :, 1]  # not contiguous
    >>> y.asnumpy()
    array([[1., 3.],
           [5., 7.]], dtype=float32)
    >>> y[:] = -1
    >>> x.asnumpy()
    array([[[0., 1.],
            [2., 3.]],
    <BLANKLINE>
           [[4., 5.],
            [6., 7.]]], dtype=float32)

    If the indexing key contains `list`, `numpy.ndarray` or `NDArray`
    objects, advanced indexing is triggered, which always returns a
    copy:

    >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2))
    >>> x[[0, 1]].asnumpy()
    array([[[0., 1.],
            [2., 3.]],
    <BLANKLINE>
           [[4., 5.],
            [6., 7.]]], dtype=float32)
    >>> x[[0, 1], :].asnumpy()  # equivalent
    array([[[0., 1.],
            [2., 3.]],
    <BLANKLINE>
           [[4., 5.],
            [6., 7.]]], dtype=float32)
    >>> y = np.array([0, 1], dtype='int32')
    >>> x[1:, y].asnumpy()
    array([[[4., 5.],
            [6., 7.]]], dtype=float32)
    >>> y = mx.nd.array([0, 1], dtype='int32')
    >>> x[1:, y].asnumpy()
    array([[[4., 5.],
            [6., 7.]]], dtype=float32)
    """
    ndim = self.ndim
    shape = self.shape

    # Fast paths for trivial keys. The type is checked before comparing so that
    # array-like keys (NDArray / np.ndarray) are never compared element-wise to `()`.
    if isinstance(key, tuple) and len(key) == 0:
        return self
    if ndim == 0 and isinstance(key, py_slice) and key == py_slice(None, None, None):
        return self
    # A full tuple of integers selects one element: index one axis at a time.
    if isinstance(key, tuple) and len(key) == ndim \
            and all(isinstance(idx, integer_types) for idx in key):
        out = self
        for idx in key:
            out = out[idx]
        return out
    if isinstance(key, integer_types):
        if ndim == 0:
            raise IndexError('too many indices for array: array is 0-dimensional, '
                             'but 1 was indexed')
        if key > shape[0] - 1:
            raise IndexError(
                'index {} is out of bounds for axis 0 with size {}'.format(
                    key, shape[0]))
        # Only the upper bound is checked here; negative keys are handed to `_at`.
        return self._at(key)
    elif isinstance(key, py_slice):
        if key.step is None or key.step == 1:
            if key.start is not None or key.stop is not None:
                return self._slice(key.start, key.stop)
            else:
                return self
        elif key.step == 0:
            raise ValueError("slice step cannot be zero")
        # Other steps fall through to the general indexing path below.

    key, _ = indexing_key_expand_implicit_axes(key, self.shape)
    if len(key) == 0:
        raise ValueError('indexing key cannot be an empty tuple')
    indexing_dispatch_code = get_indexing_dispatch_code(key)
    if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
        return self._get_nd_basic_indexing(key)
    elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
        return self._get_nd_advanced_indexing(key)
    raise RuntimeError(
        'Indexing NDArray with index {} of type {} is not supported'
        ''.format(key, type(key))
    )
def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
    """Return a broadcast `NDArray` with same context and dtype as ``self``.

    For setting items, the returned array is squeezed according to ``squeeze_axes``
    because ``value_nd`` is assigned to a region of the original array in which the
    new (``None``) axes have not been expanded.

    Parameters
    ----------
    value : numeric type or array-like
        The value to convert.
    bcast_shape : tuple of int
        Target shape to broadcast (or reshape, for empty values) to.
    squeeze_axes : sequence of int, optional
        Axes to squeeze out of the value array.

    Returns
    -------
    NDArray
        ``value`` placed on ``self.ctx`` with ``self.dtype`` and shape ``bcast_shape``.

    Raises
    ------
    TypeError
        If ``value`` is not numeric and cannot be converted to an array.
    """
    if isinstance(value, numeric_types):
        value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype)
    elif type(value) == self.__class__:  # pylint: disable=unidiomatic-typecheck
        value_nd = value.as_in_context(self.ctx)
        if value_nd.dtype != self.dtype:
            value_nd = value_nd.astype(self.dtype)
    else:
        try:
            value_nd = array(value, ctx=self.ctx, dtype=self.dtype)
        except Exception:  # pylint: disable=broad-except
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
            # propagate instead of being reported as a TypeError.
            raise TypeError('{} does not support assignment with non-array-like '
                            'object {} of type {}'.format(self.__class__, value, type(value)))

    # For setitem, if there is None in indices, we need to squeeze the assigned value_nd
    # since None is also ignored in slicing the original array.
    if squeeze_axes and value_nd.ndim > len(bcast_shape):
        squeeze_axes = tuple(ax for ax in squeeze_axes if ax < len(value_nd.shape))
        value_nd = value_nd.squeeze(axis=squeeze_axes)

    # handle the cases like the following
    # a = nd.zeros((3, 3)), b = nd.ones((1, 1, 1, 1, 3)), a[0] = b
    # b cannot broadcast directly to a[0].shape unless its leading 1-size axes are trimmed
    if value_nd.ndim > len(bcast_shape):
        squeeze_axes = []
        for i in range(value_nd.ndim - len(bcast_shape)):
            if value_nd.shape[i] == 1:
                squeeze_axes.append(i)
            else:
                break
        if squeeze_axes:
            value_nd = value_nd.squeeze(squeeze_axes)

    if value_nd.shape != bcast_shape:
        if value_nd.size == 0:
            # An empty value cannot be broadcast; reshape it instead.
            value_nd = value_nd.reshape(bcast_shape)
        else:
            value_nd = value_nd.broadcast_to(bcast_shape)
    return value_nd
# pylint: disable=invalid-name
@staticmethod
def _basic_indexing_key_to_begin_end_step(idcs, shape, keep_none=True):
    """Map a tuple of ``slice``/int entries (``None`` entries ignored) to begin, end, step tuples.

    Integers are first converted with ``_int_to_slice``. With ``keep_none=True`` the raw
    ``start``/``stop``/``step`` attributes (possibly ``None``) are used; otherwise each
    slice is normalized against its axis length with ``slice.indices``. Only as many
    entries as ``shape`` has axes are taken.
    """
    slices = [idx if isinstance(idx, py_slice) else _int_to_slice(idx)
              for idx in idcs if idx is not None]
    if keep_none:
        triples = [(slc.start, slc.stop, slc.step) for slc, _ in zip(slices, shape)]
    else:
        triples = [slc.indices(extent) for slc, extent in zip(slices, shape)]
    return tuple(zip(*triples))
# pylint: enable=invalid-name
# pylint: disable=invalid-name
@staticmethod
def _basic_indexing_key_int_to_slice(idcs):
    """Return the converted indexing tuple and the integer axes.

    Every integer entry is replaced by ``_int_to_slice(entry)``; its position is
    recorded in the second returned tuple. Other entries are passed through unchanged.
    """
    pairs = [(_int_to_slice(idx), pos) if isinstance(idx, integer_types) else (idx, None)
             for pos, idx in enumerate(idcs)]
    converted = tuple(entry for entry, _ in pairs)
    int_positions = tuple(pos for _, pos in pairs if pos is not None)
    return converted, int_positions
# pylint: enable=invalid-name
@staticmethod
def _new_axes_after_basic_indexing(axes, key):
    """Return indices of ``axes`` after slicing with ``key``.

    Used to work out where new (``None``) axes end up after indexing, given that
    integer entries remove their axis. ``key`` must be the expanded key containing
    slices, integers and ``None``. ``axes`` is used as a numpy index into the running
    offset table, so it should be a list of positions in ``key``.
    """
    offsets = [0]
    offsets.extend(0 if isinstance(idx, integer_types) else 1 for idx in key)
    running = np.cumsum(offsets)
    return tuple(running[axes])
@staticmethod
def _new_axes_after_advanced_indexing(key, adv_axs, bcast_adv_ndim, adv_are_adjacent): # pylint: disable=invalid-name
    """
    Return the output positions of the ``None`` entries of ``key`` after indexing.

    This accounts for the removal of axes consumed by advanced (and integer) indices.
    ``key`` is the expanded key containing slices, array-like objects, integers and
    ``None``. ``adv_axs`` holds the positions in ``key`` of the advanced indices, and
    ``bcast_adv_ndim`` is the number of dimensions of the advanced-indexing subspace.
    When ``adv_are_adjacent`` is false, that subspace is placed first, shifting every
    other position by ``bcast_adv_ndim``.
    Note: integer indices are also considered advanced indices here.
    """
    advanced = set(adv_axs)
    lead = 0 if adv_are_adjacent else bcast_adv_ndim
    increments = [lead] + [0 if pos in advanced else 1 for pos in range(len(key))]
    positions = np.cumsum(increments)
    none_positions = [pos for pos, idx in enumerate(key) if idx is None]
    return tuple(positions[none_positions])
# pylint: disable=invalid-name
@staticmethod
def _basic_indexing_slice_is_contiguous(slc_key, shape):
    """Whether indexing with the given key results in a contiguous array.

    Selections of zero or one element in total are always contiguous. Otherwise the
    axes are scanned from last to first. Any axis selecting more than one element must
    use step 1. Once an axis selects a proper subset of its extent, every axis further
    to the left may select at most one element.

    ``slc_key`` must have the same length as ``shape`` and contain only `slice` objects.
    """
    assert len(slc_key) == len(shape)
    lengths = [_get_slice_len(slc, extent) for slc, extent in zip(slc_key, shape)]
    if np.prod(lengths) in (0, 1):
        return True
    narrowed = False
    for slc, extent, count in zip(reversed(slc_key), reversed(shape), reversed(lengths)):
        step = slc.indices(extent)[2]
        if count == 0:
            return True
        if count > 1 and (step > 1 or step < 0):
            # Strided or reversed selections of several elements are never contiguous.
            return False
        if narrowed:
            if count > 1:
                return False
        elif count < extent:
            narrowed = True
    return True
# pylint: enable=invalid-name
@staticmethod
def _basic_indexing_sliced_shape(slc_key, shape):
    """Return the shape after slicing with the given key (one length per axis)."""
    assert len(slc_key) == len(shape)
    return tuple(_get_slice_len(slc, extent) for slc, extent in zip(slc_key, shape))
# pylint: disable=invalid-name
@staticmethod
def _basic_indexing_contiguous_flat_begin_end(slc_key, shape):
    """Return the flat indices of begin and end for contiguous slicing.

    Offsets are accumulated in row-major order. The end is exclusive, and ``(0, 0)`` is
    returned as soon as any axis selects no elements.
    """
    assert len(slc_key) == len(shape)
    first = last = 0
    for slc, extent in zip(slc_key, shape):
        start = slc.indices(extent)[0]
        count = _get_slice_len(slc, extent)
        if count == 0:
            return 0, 0
        first = first * extent + start
        last = last * extent + start + count - 1
    return first, last + 1
# pylint: enable=invalid-name
@staticmethod
def _drop_int_axes(indexed_shape, int_axes):
    """Drop the entries of ``indexed_shape`` at the positions listed in ``int_axes``.

    If every axis is dropped, ``(1,)`` is returned so the result is never an empty shape.
    """
    kept = tuple(extent for pos, extent in enumerate(indexed_shape) if pos not in int_axes)
    return kept if kept else (1,)
def _set_nd_basic_indexing(self, key, value):
"""This function indexes ``self`` with a tuple of ``slice`` objects only."""
for idx in key:
if idx is not None and not isinstance(idx, (py_slice, integer_types)):
raise RuntimeError(
'`key` may only contain `slice` or integer objects in the '
'basic implementation, got object of type {}. '
'This is a bug, please report it!'
''.format(type(idx)))
key_nd = tuple(idx for idx in key if idx is not None)
int_axes = [
ax for ax in range(len(key_nd)) if isinstance(key_nd[ax], integer_types)
]
# Check bounds for integer axes
for ax in int_axes: # pylint: disable=invalid-name
if not -self.shape[ax] <= key_nd[ax] < self.shape[ax]:
raise IndexError(
'index {} is out of bounds for axis {} with size {}'
''.format(key_nd[ax], ax, self.shape[ax]))
begin, end, step = self._basic_indexing_key_to_begin_end_step(
key, self.shape, keep_none=False
)
indexed_shape = tuple(
_get_dim_size(b, e, s) for b, e, s in zip(begin, end, step)
)
can_assign_directly = (
(indexed_shape == self.shape) and all(s > 0 for s in step)
)
begin, end, step = self._basic_indexing_key_to_begin_end_step(
key, self.shape, keep_none=True
)
none_axes = [ax for ax in range(len(key)) if key[ax] is None]
new_axes = self._new_axes_after_basic_indexing(none_axes, key)
if can_assign_directly:
# Easy case, overwrite whole array.
if type(value) == self.__class__: # pylint: disable=unidiomatic-typecheck
if value.handle is not self.handle:
# Need to do this before `broadcast_to`.
bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
value_nd = value_nd.reshape(indexed_shape)
value_nd.copyto(self)
elif isinstance(value, numeric_types):
if isinstance(value, bool):
self._full(int(value))
else:
self._full(value)
elif isinstance(value, (np.ndarray, np.generic)):
tmp_shape = _shape_for_bcast(
value.shape, target_ndim=self.ndim, new_axes=int_axes
)
value = value.reshape(tmp_shape)
if isinstance(value, np.generic) or value.shape != self.shape:
value = np.broadcast_to(value, self.shape)
self._sync_copyfrom(value)
else:
# Other array-like
# drop the axis of indexed_shape corresponding to int axes
bcast_shape = self._drop_int_axes(indexed_shape, int_axes)
value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes)
value_nd = value_nd.reshape(indexed_shape)
value_nd.copyto(self)
elif isinstance(value, numeric_types):
self.slice_assign_scalar(float(value), begin, end, step)