diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index eee2ccf14a8e..b846700fb503 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -93,7 +93,7 @@ def imdecode(buf, *args, **kwargs): Parameters ---------- - buf : str/bytes or numpy.ndarray + buf : str/bytes/bytearray or numpy.ndarray Binary image data as string or numpy ndarray. flag : int, optional, default=1 1 for three channel color output. 0 for grayscale output. @@ -135,10 +135,15 @@ def imdecode(buf, *args, **kwargs): """ if not isinstance(buf, nd.NDArray): - if sys.version_info[0] == 3 and not isinstance(buf, (bytes, np.ndarray)): - raise ValueError('buf must be of type bytes or numpy.ndarray,' + if sys.version_info[0] == 3 and not isinstance(buf, (bytes, bytearray, np.ndarray)): + raise ValueError('buf must be of type bytes, bytearray or numpy.ndarray,' 'if you would like to input type str, please convert to bytes') buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8) + + if len(buf) == 0: + # empty buf causes OpenCV crash. + raise ValueError("input buf cannot be empty.") + return _internal._cvimdecode(buf, *args, **kwargs) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 0df08af317aa..c8022b67bee8 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -92,6 +92,22 @@ def test_imdecode(self): cv_image = cv2.imread(img) assert_almost_equal(image.asnumpy(), cv_image) + def test_imdecode_bytearray(self): + try: + import cv2 + except ImportError: + return + for img in TestImage.IMAGES: + with open(img, 'rb') as fp: + str_image = bytearray(fp.read()) + image = mx.image.imdecode(str_image, to_rgb=0) + cv_image = cv2.imread(img) + assert_almost_equal(image.asnumpy(), cv_image) + + @raises(ValueError) + def test_imdecode_empty_buffer(self): + mx.image.imdecode(b'', to_rgb=0) + def test_scale_down(self): assert mx.image.scale_down((640, 480), (720, 120)) == (640, 106) assert mx.image.scale_down((360, 1000), (480, 500)) == (360, 375)