Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
in img.imdecode,when py3 check buf param,if it is str,convert it to b…
Browse files Browse the repository at this point in the history
…ytes.
  • Loading branch information
yajiedesign committed Apr 27, 2018
1 parent f236687 commit 5234b10
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import random
import logging
import json
import sys

import numpy as np

try:
Expand Down Expand Up @@ -132,6 +134,9 @@ def imdecode(buf, *args, **kwargs):
<NDArray 224x224x3 @cpu(0)>
"""
if not isinstance(buf, nd.NDArray):
if sys.version_info[0] == 3:
if isinstance(buf, str):
buf = bytes(buf, "ascii")
buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8)
return _internal._cvimdecode(buf, *args, **kwargs)

Expand Down Expand Up @@ -481,6 +486,7 @@ def random_size_crop(src, size, min_area, ratio, interp=2):

class Augmenter(object):
"""Image Augmenter base class"""

def __init__(self, **kwargs):
self._kwargs = kwargs
for k, v in self._kwargs.items():
Expand Down Expand Up @@ -513,6 +519,7 @@ class SequentialAug(Augmenter):
ts : list of augmenters
A series of augmenters to be applied in sequential order.
"""

def __init__(self, ts):
super(SequentialAug, self).__init__()
self.ts = ts
Expand All @@ -538,6 +545,7 @@ class ResizeAug(Augmenter):
interp : int, optional, default=2
Interpolation method. See resize_short for details.
"""

def __init__(self, size, interp=2):
super(ResizeAug, self).__init__(size=size, interp=interp)
self.size = size
Expand All @@ -558,6 +566,7 @@ class ForceResizeAug(Augmenter):
interp : int, optional, default=2
Interpolation method. See resize_short for details.
"""

def __init__(self, size, interp=2):
super(ForceResizeAug, self).__init__(size=size, interp=interp)
self.size = size
Expand All @@ -579,6 +588,7 @@ class RandomCropAug(Augmenter):
interp : int, optional, default=2
Interpolation method. See resize_short for details.
"""

def __init__(self, size, interp=2):
super(RandomCropAug, self).__init__(size=size, interp=interp)
self.size = size
Expand All @@ -603,6 +613,7 @@ class RandomSizedCropAug(Augmenter):
interp: int, optional, default=2
Interpolation method. See resize_short for details.
"""

def __init__(self, size, min_area, ratio, interp=2):
super(RandomSizedCropAug, self).__init__(size=size, min_area=min_area,
ratio=ratio, interp=interp)
Expand All @@ -626,6 +637,7 @@ class CenterCropAug(Augmenter):
interp : int, optional, default=2
Interpolation method. See resize_short for details.
"""

def __init__(self, size, interp=2):
super(CenterCropAug, self).__init__(size=size, interp=interp)
self.size = size
Expand All @@ -644,6 +656,7 @@ class RandomOrderAug(Augmenter):
ts : list of augmenters
A series of augmenters to be applied in random order
"""

def __init__(self, ts):
super(RandomOrderAug, self).__init__()
self.ts = ts
Expand All @@ -668,6 +681,7 @@ class BrightnessJitterAug(Augmenter):
brightness : float
The brightness jitter ratio range, [0, 1]
"""

def __init__(self, brightness):
super(BrightnessJitterAug, self).__init__(brightness=brightness)
self.brightness = brightness
Expand All @@ -687,6 +701,7 @@ class ContrastJitterAug(Augmenter):
contrast : float
The contrast jitter ratio range, [0, 1]
"""

def __init__(self, contrast):
super(ContrastJitterAug, self).__init__(contrast=contrast)
self.contrast = contrast
Expand All @@ -710,6 +725,7 @@ class SaturationJitterAug(Augmenter):
saturation : float
The saturation jitter ratio range, [0, 1]
"""

def __init__(self, saturation):
super(SaturationJitterAug, self).__init__(saturation=saturation)
self.saturation = saturation
Expand All @@ -734,6 +750,7 @@ class HueJitterAug(Augmenter):
hue : float
The hue jitter ratio range, [0, 1]
"""

def __init__(self, hue):
super(HueJitterAug, self).__init__(hue=hue)
self.hue = hue
Expand Down Expand Up @@ -772,6 +789,7 @@ class ColorJitterAug(RandomOrderAug):
saturation : float
The saturation jitter ratio range, [0, 1]
"""

def __init__(self, brightness, contrast, saturation):
ts = []
if brightness > 0:
Expand All @@ -795,6 +813,7 @@ class LightingAug(Augmenter):
eigvec : 3x3 np.array
Eigen vectors
"""

def __init__(self, alphastd, eigval, eigvec):
super(LightingAug, self).__init__(alphastd=alphastd, eigval=eigval, eigvec=eigvec)
self.alphastd = alphastd
Expand All @@ -819,6 +838,7 @@ class ColorNormalizeAug(Augmenter):
std : NDArray
RGB standard deviation to be divided
"""

def __init__(self, mean, std):
super(ColorNormalizeAug, self).__init__(mean=mean, std=std)
self.mean = nd.array(mean) if mean is not None else None
Expand All @@ -837,6 +857,7 @@ class RandomGrayAug(Augmenter):
p : float
Probability to convert to grayscale
"""

def __init__(self, p):
super(RandomGrayAug, self).__init__(p=p)
self.p = p
Expand All @@ -859,6 +880,7 @@ class HorizontalFlipAug(Augmenter):
p : float
Probability to flip image horizontally
"""

def __init__(self, p):
super(HorizontalFlipAug, self).__init__(p=p)
self.p = p
Expand All @@ -872,6 +894,7 @@ def __call__(self, src):

class CastAug(Augmenter):
"""Cast to float32"""

def __init__(self, typ='float32'):
super(CastAug, self).__init__(type=typ)
self.typ = typ
Expand Down Expand Up @@ -1058,7 +1081,8 @@ def __init__(self, batch_size, data_shape, label_width=1,
logging.info('%s: loading recordio %s...',
class_name, path_imgrec)
if path_imgidx:
self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') # pylint: disable=redefined-variable-type
self.imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec,
'r') # pylint: disable=redefined-variable-type
self.imgidx = list(self.imgrec.keys)
else:
self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type
Expand Down Expand Up @@ -1204,6 +1228,7 @@ def check_valid_image(self, data):
def imdecode(self, s):
"""Decodes a string or byte string to an NDArray.
See mx.img.imdecode for more details."""

def locate():
"""Locate the image file/index if decode fails."""
if self.seq is not None:
Expand All @@ -1216,6 +1241,7 @@ def locate():
else:
msg = "index: {}".format(idx)
return "Broken image " + msg

try:
img = imdecode(s)
except Exception as e:
Expand Down

0 comments on commit 5234b10

Please sign in to comment.