diff --git a/.gitignore b/.gitignore
index 6b1fe6fa172d..24d1338bbf77 100644
--- a/.gitignore
+++ b/.gitignore
@@ -97,3 +97,4 @@ tags
 miniconda.sh
 deps_version
+*.noseids
diff --git a/Makefile b/Makefile
index 98d42e8bb428..164fafe9b9fe 100644
--- a/Makefile
+++ b/Makefile
@@ -96,7 +96,8 @@ endif
 
 # SFrame flexible_type
 FLEXIBLE_TYPE = $(ROOTDIR)/flexible_type
-LIB_DEP += $(FLEXIBLE_TYPE)/build/libflexible_type.a
+LIB_DEP += $(FLEXIBLE_TYPE)/build/libflexible_type.a
+LDFLAGS += -lpng -ljpeg -lz
 
 # plugins
 include $(MXNET_PLUGINS)
diff --git a/flexible_type/Makefile b/flexible_type/Makefile
index 56c9b579c716..42ad479148df 100644
--- a/flexible_type/Makefile
+++ b/flexible_type/Makefile
@@ -16,7 +16,7 @@ clean:
 flexible_type: build/libflexible_type.a
 test: build/flexible_type_test
 
-OBJS = $(addprefix build/, flexible_type/flexible_type.o image/image_type.o)
+OBJS = $(addprefix build/, flexible_type/flexible_type.o image/image_type.o image/jpeg_io.o image/png_io.o)
 
 build/libflexible_type.a: $(OBJS)
 	ar crv $@ $(filter %.o, $?)
@@ -27,4 +27,4 @@ build/%.o: $(SFRAME_SRC)/%.cpp
 	$(CXX) $(CFLAGS) -c $< -o $@
 
 build/flexible_type_test: test/flexible_type_test.cpp
-	$(CXX) $(CFLAGS) $< -o $@ -L build -lflexible_type
+	$(CXX) $(CFLAGS) $< -o $@ -L build -lflexible_type -lpng -ljpeg -lz
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 45bb25976006..935ea472ed8e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -227,21 +227,30 @@ MXNET_DLL int MXNDArrayLoad(const char* fname,
 MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
                                        const void *data,
                                        size_t size);
+/*!
+ * \brief Callback state passed to MXNDArraySyncCopyFromSFrame.
+ * \param handle Handle of NDArray
+ * \param idx the offset of the dest array for write
+ * \param batch_size size of current batch
+ * \param field_length_p array containing the size (number of floats) of the i-th data field. Length of the array is "size".
+ */
+struct SFrameCallbackHandle {
+  NDArrayHandle handle;
+  size_t idx;
+  size_t batch_size;
+  size_t* field_length_p;
+};
+
 /*!
  * \brief Perform a synchonize copy by using as SFrame callback
  * This function will call WaitToWrite before the copy is performed.
  * This is useful to copy data from existing memory region that are
  * not wrapped by NDArray(thus dependency not being tracked).
  *
- * \param callback_handle pointer to struct of callback function state
  * \param data the data source to copy from
  * \param size the element size to copy
+ * \param callback_handle pointer to struct of callback function state
  */
-struct SFrameCallbackHandle {
-  NDArrayHandle handle;
-  size_t idx;
-  size_t batch_size;
-};
 MXNET_DLL int MXNDArraySyncCopyFromSFrame(const void *data,
                                           size_t size,
                                           void* callback_handle);
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index e55094183b45..58798660e0b1 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -217,11 +217,13 @@ class NDArray {
    * \param size field lenth of data
    * \param idx index in batch
    * \param batch_size total batch size
+   * \param field_length_p size (in number of floats) of each data element in the row
    */
   void SyncCopyFromSFrame(const graphlab::flexible_type *data,
                           size_t size,
                           size_t idx,
-                          size_t batch_size) const;
+                          size_t batch_size,
+                          size_t* field_length_p) const;
   /*!
   * \brief Do a synchronize copy to a continugous CPU memory region.
   *
diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 35219cff68bf..9bd8098a67a1 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -14,12 +14,14 @@
 from .base import c_array, c_str, mx_uint, py_str
 from .base import DataIterHandle, NDArrayHandle
 from .base import check_call, ctypes2docstring
+from array import array as _array
 from .ndarray import NDArray
 from .ndarray import array
 from .ndarray import _copy_from_sarray, _copy_from_sframe
 
 DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad', 'index'])
 
+
 class DataIter(object):
     """DataIter object in mxnet. """
@@ -356,36 +358,57 @@ def getpad(self):
         else:
             return 0
 
+try:
+    import graphlab as gl
+except:
+    try:
+        import sframe as gl
+    except:
+        pass
+
+
 class SFrameIter(DataIter):
-    def __init__(self, sframe, data_field, data_shape, label_field=None, batch_size=1):
+    def __init__(self, sframe, data_field, label_field=None, batch_size=1):
         """
-        Single data input, single label SFrame iterator
+        Iterator over an SFrame
 
         Parameters
         ----------
         sframe: SFrame object
-            source SFrmae
+            source SFrame
         data_field: string or list(string)
             select fields of training data. For image or array type, only support string
-        data_shape: tuple
-            input data shape
         label_field: string (optional)
             label field in SFrame
-        batch_size: int
+        batch_size: int (optional)
             batch size
         """
         super(SFrameIter, self).__init__()
+        if not isinstance(sframe, gl.SFrame):
+            raise TypeError
+        if not (isinstance(data_field, str) or isinstance(data_field, list)):
+            raise TypeError
+        if not (label_field is None or isinstance(label_field, str)):
+            raise TypeError
+
+        if type(data_field) is str:
+            data_field = [data_field]
+
+        self._type_check(sframe, data_field, label_field)
         self.data_field = data_field
         self.label_field = label_field
        self.data_sframe = sframe[data_field]
-        if label_field != None:
+        if label_field is not None:
            self.label_sframe = sframe[label_field]
+
         # allocate ndarray
-        data_shape = list(data_shape)
+        inferred_shape = self.infer_shape()
+        data_shape = list(inferred_shape["final_shape"])
         data_shape.insert(0, batch_size)
         self.data_shape = tuple(data_shape)
         self.label_shape = (batch_size, )
+        self.field_length = inferred_shape["field_length"]
         self.data_ndarray = array(np.zeros(self.data_shape))
         self.label_ndarray = array(np.zeros(self.label_shape))
         self.data = _init_data(self.data_ndarray, allow_empty=False, default_name="data")
@@ -399,6 +422,7 @@ def __init__(self, sframe, data_field, data_shape, label_field=None, batch_size=
     def provide_data(self):
         """The name and shape of data provided by this iterator"""
         return [(k, tuple([self.batch_size] + list(v.shape[1:]))) for k, v in self.data]
+
     @property
     def provide_label(self):
         """The name and shape of label provided by this iterator"""
@@ -409,13 +433,59 @@ def reset(self):
         self.cursor = 0
         self.has_next = True
 
-    def _copy(self, start, end, bias=0):
-        if isinstance(self.data_field, list):
-            _copy_from_sframe(self.data_sframe, self.data_ndarray, start, end, bias)
+    def _type_check(self, sframe, data_field, label_field):
+        if label_field is not None:
+            label_column_type = sframe[label_field].dtype()
+            if label_column_type not in [int, float]:
+                raise TypeError('Unexpected type for label_field \"%s\". Expect int or float, got %s' %
+                                (label_field, str(label_column_type)))
+        for col in data_field:
+            col_type = sframe[col].dtype()
+            if col_type not in [int, float, _array, gl.Image]:
+                raise TypeError('Unexpected type for data_field \"%s\". Expect int, float, array or image, got %s' %
+                                (col, str(col_type)))
+
+    def _infer_column_shape(self, sarray):
+        dtype = sarray.dtype()
+        if (dtype in [int, float]):
+            return (1, )
+        elif dtype is _array:
+            lengths = sarray.item_length()
+            if lengths.min() != lengths.max():
+                raise ValueError('Array column elements do not all have the same length')
+            else:
+                return (lengths.max(), )
+        elif dtype is gl.Image:
+            first_image = sarray.dropna()[0]
+            return (first_image.channels, first_image.height, first_image.width)
+
+    def infer_shape(self):
+        ret = {"field_length": [], "final_shape": None}
+        features = self.data_sframe.column_names()
+        assert len(features) > 0
+        if len(features) > 1:
+            # If more than one feature, all features must be numeric or array
+            shape = 0
+            for col in features:
+                colshape = self._infer_column_shape(self.data_sframe[col])
+                if len(colshape) != 1:
+                    raise ValueError('Only one column is allowed if input is image type')
+                shape += colshape[0]
+                ret["field_length"].append(colshape[0])
+            ret["final_shape"] = (shape,)
         else:
-            _copy_from_sarray(self.data_sframe, self.data_ndarray, start, end, bias)
-        if isinstance(self.label_field, str):
-            _copy_from_sarray(self.label_sframe, self.label_ndarray, start, end)
+            col_shape = self._infer_column_shape(self.data_sframe[features[0]])
+            ret["final_shape"] = col_shape
+            length = 1
+            for x in col_shape:
+                length = length * x
+            ret["field_length"].append(length)
+        return ret
+
+    def _copy(self, start, end, bias=0):
+        _copy_from_sframe(self.data_sframe, self.data_ndarray, start, end, self.field_length, bias)
+        if self.label_field is not None:
+            _copy_from_sarray(self.label_sframe, self.label_ndarray, start, end, 1, bias)
 
     def iter_next(self):
         if self.has_next:
@@ -448,7 +518,6 @@ def getpad(self):
         return self.pad
 
-
 class MXDataIter(DataIter):
     """DataIter built in MXNet. List all the needed functions here.
 
     Parameters
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index ef2fc17cf92c..eb0d43c944eb 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -62,31 +62,44 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
                                       ctypes.byref(hdl)))
     return hdl
 
+try:
+    import sframe as gl
+except:
+    pass
+
+try:
+    import graphlab as gl
+except:
+    pass
+
+
 class SFrameCallBackHandle(ctypes.Structure):
-    _fields_ = [('handle', ctypes.c_void_p), ('idx', ctypes.c_ulonglong), ('batch_size', ctypes.c_ulonglong)]
+    _fields_ = [('handle', ctypes.c_void_p),
+                ('idx', ctypes.c_ulonglong),
+                ('batch_size', ctypes.c_ulonglong),
+                ('field_length_p', ctypes.POINTER(ctypes.c_ulonglong))]
 
-def _copy_from_sframe(sf, arr, start, end, bias=0):
+
+def _copy_from_sframe(sf, arr, start, end, field_length, bias=0):
+    assert isinstance(sf, gl.SFrame)
     callback = _LIB.MXNDArraySyncCopyFromSFrame
     callback.argtypes = [ctypes.c_void_p, ctypes.c_ulonglong, ctypes.c_void_p]
     callback.restype = ctypes.c_int
     addr = ctypes.cast(callback, ctypes.c_void_p).value
-    callback_handle = SFrameCallBackHandle(arr.handle.value, bias, (end - start))
-    try:
-        import graphlab as gl
-    except:
-        import sframe as gl
-    assert isinstance(sf, gl.SFrame)
+    num_fields = sf.num_columns()
+    c_field_length_arr = (ctypes.c_ulonglong * num_fields)()
+    for i in range(num_fields):
+        c_field_length_arr[i] = int(field_length[i])
+    callback_handle = SFrameCallBackHandle(arr.handle.value, bias, (end - start),
+                                           ctypes.cast(c_field_length_arr, ctypes.POINTER(ctypes.c_ulonglong)))
     gl.extensions.sframe_callback(sf, addr, ctypes.addressof(callback_handle), start, end)
 
-def _copy_from_sarray(sa, arr, start, end, bias=0):
-    try:
-        import graphlab as gl
-    except:
-        import sframe as gl
+
+def _copy_from_sarray(sa, arr, start, end, field_length, bias=0):
     assert isinstance(sa, gl.SArray)
     sf = gl.SFrame({'__tmp__': sa})
-    _copy_from_sframe(sf, arr, start, end, bias)
+    _copy_from_sframe(sf, arr, start, end, [field_length], bias)
+
 
 def waitall():
     """Wait all async operation to finish in MXNet
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 34cebf3aff33..97364ac57813 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -192,7 +192,7 @@ int MXNDArraySyncCopyFromSFrame(const void *data,
   const graphlab::flexible_type *flex_data = reinterpret_cast<const graphlab::flexible_type*>(data);
   static_cast<NDArray*>(callback_handle->handle)->SyncCopyFromSFrame(flex_data,
-      size, callback_handle->idx, callback_handle->batch_size);
+      size, callback_handle->idx, callback_handle->batch_size, callback_handle->field_length_p);
   ++callback_handle->idx;
   API_END();
 }
 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index a6c45dc0f998..1a88d7528be8 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -12,6 +12,10 @@
 #include
 #include "./ndarray_function.h"
+// from graphlab
+#include
+#include
+
 namespace dmlc {
 DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg);
 }  // namespace dmlc
@@ -564,53 +568,84 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
 }
 
 void NDArray::SyncCopyFromSFrame(const graphlab::flexible_type *data, size_t size,
-                                 size_t idx, size_t batch_size) const {
+                                 size_t idx, size_t batch_size, size_t* field_length_p) const {
   CHECK_GE(size, 1);
   this->WaitToWrite();
   TShape dshape = this->shape();
   TBlob dst = this->data();
-  CHECK_EQ(this->ctx().dev_type, Context::kCPU)
-      << "Only support copy SFrame to CPU NDarray";
+  CHECK_EQ(this->ctx().dev_type, Context::kCPU) << "Only support copy SFrame to CPU NDArray";
   CHECK_EQ(dst.type_flag_, mshadow::DataType<float>::kFlag);
+
+  for (size_t i = 0; i < size; ++i) {
+    CHECK_NE(data[i].get_type(), graphlab::flex_type_enum::UNDEFINED) << "Missing value is not supported. Please use fillna() or dropna() to remove missing values.";
+  }
+
+  // Case 1: Image type
   auto type = data[0].get_type();
-  // TODO(bing): segfault in sframe if get 2d
-  if (size == 1 &&
-      (type == graphlab::flex_type_enum::FLOAT || type == graphlab::flex_type_enum::INTEGER)) {
-    mshadow::Tensor<cpu, 1, float> obj = dst.GetWithShape<cpu, 1, float>(
-        mshadow::Shape1(dshape[0]));
-    obj[idx] = data[0].to<float>();
-  } else {
-    mshadow::Tensor<cpu, 2, float> obj = dst.GetWithShape<cpu, 2, float>(
-        mshadow::Shape2(dshape[0], dshape.Size() / dshape[0]));
-    if (type == graphlab::flex_type_enum::VECTOR) {
-      CHECK_EQ(size, 1) << "Only support 1 field if input is array";
-      graphlab::flex_vec v = data[0].to<graphlab::flex_vec>();
-      for (size_t i = 0; i < v.size(); ++i) {
-        obj[idx][i] = static_cast<float>(v[i]);
-      }
-    } else if (type == graphlab::flex_type_enum::IMAGE) {
-      CHECK_EQ(size, 1) << "Only support 1 field if input is image";
-      auto img = data[0].to<graphlab::image_type>();
-      mshadow::Tensor<cpu, 4, float> batch_tensor = dst.GetWithShape<cpu, 4, float>(
+  if (type == graphlab::flex_type_enum::IMAGE) {
+    CHECK_EQ(size, 1) << "Image data only supports one input field";
+    graphlab::image_type img = data[0].get<graphlab::image_type>();
+    mshadow::Tensor<cpu, 4, float> batch_tensor = dst.GetWithShape<cpu, 4, float>(
         mshadow::Shape4(dshape[0], img.m_channels, img.m_height, img.m_width));
-      CHECK_EQ(img.is_decoded(), true) << "image must be decoded by using SFrame";
-      size_t cnt = 0;
-      const unsigned char* raw_data = img.get_image_data();
-      for (size_t i = 0; i < img.m_height; ++i) {
-        for (size_t j = 0; j < img.m_width; ++j) {
-          for (size_t k = 0; k < img.m_channels; ++k) {
-            batch_tensor[idx][k][i][j] = raw_data[cnt++];
-          }
+
+    // Shape check
+    CHECK_EQ(dshape[1], img.m_channels) << "Unexpected image shape. Please use gl.image_analysis.resize() to resize the images. Expected channels "
+        << dshape[1] << " actual " << img.m_channels;
+    CHECK_EQ(dshape[2], img.m_height) << "Unexpected image shape. Please use gl.image_analysis.resize() to resize the images. Expected height "
+        << dshape[2] << " actual " << img.m_height;
+    CHECK_EQ(dshape[3], img.m_width) << "Unexpected image shape. Please use gl.image_analysis.resize() to resize the images. Expected width "
+        << dshape[3] << " actual " << img.m_width;
+    // Decode if needed
+    if (!img.is_decoded()) {
+      char* buf = NULL;
+      size_t length = 0;
+      if (img.m_format == graphlab::Format::JPG) {
+        graphlab::decode_jpeg((const char*)img.get_image_data(), img.m_image_data_size, &buf, length);
+      } else if (img.m_format == graphlab::Format::PNG) {
+        graphlab::decode_png((const char*)img.get_image_data(), img.m_image_data_size, &buf, length);
+      }
+      img.m_image_data.reset(buf);
+      img.m_image_data_size = length;
+      img.m_format = graphlab::Format::RAW_ARRAY;
+    }
+    size_t cnt = 0;
+    const unsigned char* raw_data = img.get_image_data();
+    for (size_t i = 0; i < img.m_height; ++i) {
+      for (size_t j = 0; j < img.m_width; ++j) {
+        for (size_t k = 0; k < img.m_channels; ++k) {
+          batch_tensor[idx][k][i][j] = raw_data[cnt++];
         }
       }
-    } else if (type == graphlab::flex_type_enum::FLOAT ||
-               type == graphlab::flex_type_enum::INTEGER) {
-      CHECK_EQ(obj.shape_[1], size) << "Input dimension doesn't match";
-      for (size_t i = 0; i < size; ++i) {
-        obj[idx][i] = data[i].to<float>();
+    }
+    return;
+  }
+  // Case 2: Single value type (should really get rid of this special case)
+  if (size == 1 && (type == graphlab::flex_type_enum::FLOAT || type == graphlab::flex_type_enum::INTEGER)) {
+    auto shape = mshadow::Shape1(dshape[0]);
+    mshadow::Tensor<cpu, 1, float> obj = dst.GetWithShape<cpu, 1, float>(shape);
+    obj[idx] = (float)(data[0]);
+    return;
+  }
+  // Case 3: Array type or mixed types
+  auto shape = mshadow::Shape2(dshape[0], dshape.Size() / dshape[0]);
+  mshadow::Tensor<cpu, 2, float> obj = dst.GetWithShape<cpu, 2, float>(shape);
+  size_t pos = 0;
+  for (size_t i = 0; i < size; ++i) {
+    auto type = data[i].get_type();
+    if (type == graphlab::flex_type_enum::VECTOR) {
+      const graphlab::flex_vec& v = data[i].to<graphlab::flex_vec>();
+      CHECK_EQ(v.size(), field_length_p[i]);
+      for (size_t j = 0; j < v.size(); ++j) {
+        obj[idx][pos++] = (float)(v[j]);
       }
+    } else if (type == graphlab::flex_type_enum::INTEGER ||
+               type == graphlab::flex_type_enum::FLOAT) {
+      obj[idx][pos++] = (float)(data[i]);
+    } else {
+      CHECK(false) << "Unsupported type";
     }
   }
+  return;
 }
 
 void NDArray::SyncCopyToCPU(void *data, size_t size) const {
diff --git a/tests/python/unittest/test_sframe_iter.py b/tests/python/unittest/test_sframe_iter.py
index 8b96aea60821..d6bbe293d8ef 100644
--- a/tests/python/unittest/test_sframe_iter.py
+++ b/tests/python/unittest/test_sframe_iter.py
@@ -20,23 +20,18 @@
 @unittest.skipIf(__has_sframe__ is False and __has_graphlab__ is False, 'graphlab or sframe not found')
 class SFrameIteratorBaseTest(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        if __has_sframe__ is False and __has_graphlab__ is False:
-            return
-        cls.data = gl.SFrame({'x': np.random.randn(10),
+    def setUp(self):
+        self.data = gl.SFrame({'x': np.random.randn(10),
                                'y': np.random.randint(2, size=10)})
-        cls.shape = [1]
-        cls.label_field = 'y'
-        cls.data_field = 'x'
-        cls.data_size = len(cls.data)
-        cls.data_expected = list(cls.data['x'])
-        cls.label_expected = list(cls.data['y'])
-        return cls
+        self.shape = (1,)
+        self.label_field = 'y'
+        self.data_field = 'x'
+        self.data_size = len(self.data)
+        self.data_expected = list(self.data['x'])
+        self.label_expected = list(self.data['y'])
 
     def test_one_batch(self):
         it = mxnet.io.SFrameIter(self.data, data_field=self.data_field,
-                                 data_shape=self.shape,
                                  label_field=self.label_field,
                                  batch_size=self.data_size)
         label_actual = []
@@ -50,7 +45,6 @@ def test_one_batch(self):
     def test_non_divisible_batch(self):
         batch_size = self.data_size + 1
         it = mxnet.io.SFrameIter(self.data, data_field=self.data_field,
-                                 data_shape=self.shape,
                                  label_field=self.label_field,
                                  batch_size=batch_size)
         label_actual = []
@@ -70,11 +64,10 @@ def test_padding(self):
         padding = 5
         batch_size = self.data_size + padding
         it = mxnet.io.SFrameIter(self.data, data_field=self.data_field,
-                                 data_shape=self.shape,
                                  label_field=self.label_field,
                                  batch_size=batch_size)
         label_expected = self.label_expected + [0.0] * padding
-        data_expected = self.data_expected + list(np.ndarray(self.shape).flatten()) * padding
+        data_expected = self.data_expected + list(np.zeros(self.shape).flatten()) * padding
         label_actual = []
         data_actual = []
         for d in it:
@@ -84,48 +77,125 @@ def test_padding(self):
         np.testing.assert_almost_equal(data_actual, data_expected)
 
     def test_shape_inference(self):
-        # TODO
-        pass
+        it = mxnet.io.SFrameIter(self.data, data_field=self.data_field,
+                                 label_field=self.label_field,
+                                 batch_size=1)
+        self.assertEquals(it.infer_shape()["final_shape"], self.shape)
+
+    def test_missing_value(self):
+        data = self.data.copy()
+        if not isinstance(self.data_field, list):
+            self.data_field = [self.data_field]
+        for col in self.data_field:
+            ls = list(data[col])
+            ls[0] = None
+            data[col] = ls
+        it = mxnet.io.SFrameIter(data, data_field=self.data_field)
+        self.assertRaises(Exception, lambda: list(it))
+
 
 class SFrameArrayIteratorTest(SFrameIteratorBaseTest):
-    @classmethod
-    def setUpClass(cls):
-        if __has_sframe__ is False and __has_graphlab__ is False:
-            return
-        cls.data = gl.SFrame({'x': [np.random.randn(5)] * 10,
+    def setUp(self):
+        self.data = gl.SFrame({'x': [np.random.randn(8)] * 10,
                                'y': np.random.randint(2, size=10)})
-        cls.shape = [5]
-        cls.label_field = 'y'
-        cls.data_field = 'x'
-        cls.data_size = len(cls.data)
-        cls.data_expected = list(x for arr in cls.data['x'] for x in arr)
-        cls.label_expected = list(cls.data['y'])
-        return cls
+        self.shape = (8,)
+        self.label_field = 'y'
+        self.data_field = 'x'
+        self.data_size = len(self.data)
+        self.data_expected = list(x for arr in self.data['x'] for x in arr)
+        self.label_expected = list(self.data['y'])
+
+    def test_size1_array(self):
+        # setup data
+        self.data = gl.SFrame({'x': [np.random.randn(1)] * 10,
+                               'y': np.random.randint(2, size=10)})
+        self.shape = (1,)
+        self.label_field = 'y'
+        self.data_field = 'x'
+        self.data_size = len(self.data)
+        self.data_expected = list(x for arr in self.data['x'] for x in arr)
+        self.label_expected = list(self.data['y'])
+
+        self.test_one_batch()
+        self.test_non_divisible_batch()
+        self.test_padding()
+        self.test_shape_inference()
+
+    def test_zero_size_array(self):
+        self.data = gl.SFrame()
+        self.data['x'] = [array.array('d')] * 10
+        it = mxnet.io.SFrameIter(self.data, data_field='x')
+        data_actual = []
+        for d in it:
+            data_actual.extend(d.data[0].asnumpy().flatten())
+        self.assertEquals(data_actual, [])
+
+    def test_variable_size_array(self):
+        self.data = gl.SFrame({'x': [[0], [0, 1], [0, 1, 2]]})
+        self.assertRaises(ValueError, lambda: mxnet.io.SFrameIter(self.data, data_field='x'))
 
 
 class SFrameImageIteratorTest(SFrameIteratorBaseTest):
-    @classmethod
-    def setUpClass(cls):
-        if __has_sframe__ is False and __has_graphlab__ is False:
-            return
+    def setUp(self):
         w = 2
         h = 3
         c = 1
-        d = w * h * c
-        cls.data = gl.SFrame({'arr': [array.array('d', range(x, x+d)) for x in range(10)],
+        d = 6
+        self.data = gl.SFrame({'arr': [array.array('d', range(x, x + d)) for x in range(10)],
                               'y': np.random.randint(2, size=10)})
-        cls.data['img'] = cls.data['arr'].pixel_array_to_image(w, h, c)
-        cls.shape = (c, h, w)
-        cls.label_field = 'y'
-        cls.data_field = 'img'
-        cls.data_size = len(cls.data)
-        cls.data_expected = list(x for arr in cls.data['arr'] for x in arr)
-        cls.label_expected = list(cls.data['y'])
-        return cls
+        self.data['img'] = self.data['arr'].pixel_array_to_image(w, h, c)
+        self.shape = (c, h, w)
+        self.label_field = 'y'
+        self.data_field = 'img'
+        self.data_size = len(self.data)
+        self.data_expected = list(x for arr in self.data['arr'] for x in arr)
+        self.label_expected = list(self.data['y'])
+
+    def test_encoded_image(self):
+        # resize encodes the image
+        self.data['img'] = gl.image_analysis.resize(self.data['img'], 2, 3, 1)
+        self.test_shape_inference()
+        self.test_padding()
+        self.test_one_batch()
+        self.test_missing_value()
+        self.test_non_divisible_batch()
+
+    def test_variable_size_image(self):
+        shape1 = (2, 3, 1)
+        shape2 = (2, 2, 2)
+        tmp1 = gl.SArray([array.array('d', [0] * 6)])
+        tmp2 = gl.SArray([array.array('d', [0] * 8)])
+        data = gl.SFrame({'x': [tmp1.pixel_array_to_image(*shape1)[0], tmp2.pixel_array_to_image(*shape2)[0]]})
+        it = mxnet.io.SFrameIter(data, data_field='x')
+        self.assertRaises(Exception, lambda: list(it))
 
 
 class SFrameMultiColumnIteratorTest(SFrameIteratorBaseTest):
-    @classmethod
-    def setUpClass(cls):
-        # TODO
-        pass
+    def setUp(self):
+        self.data = gl.SFrame({'i': [x for x in range(10)],
+                               '-i': [-x for x in range(10)],
+                               'f': [float(x) for x in range(10)],
+                               '-f': [-float(x) for x in range(10)],
+                               'arr': [range(2) for x in range(10)],
+                               'y': np.random.randint(2, size=10)})
+        self.shape = (6,)
+        self.label_field = 'y'
+        self.data_field = ['i', '-i', 'f', '-f', 'arr']
+        self.data_size = len(self.data)
+        def val_iter():
+            for row in self.data:
+                for col in self.data_field:
+                    v = row[col]
+                    if type(v) is array.array:
+                        for x in v:
+                            yield x
+                    else:
+                        yield float(v)
+        self.data_expected = list(val_iter())
+        self.label_expected = list(self.data['y'])
+
+    def test_image_input_with_more_than_one_column(self):
+        data = self.data.copy()
+        data['img'] = [array.array('d', range(8))] * len(data)
+        data['img'] = data['img'].pixel_array_to_image(2, 2, 2)
+        self.assertRaises(ValueError, lambda: mxnet.io.SFrameIter(data, data_field=self.data_field + ['img']))
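Usage sketch (not part of the diff above): a minimal, illustrative example of how the reworked `SFrameIter` is meant to be driven once this branch is built with SFrame/GraphLab Create support. The column names and toy data are hypothetical; the point is that `data_shape` is no longer passed — field lengths and the final per-row shape are inferred from the selected columns.

```python
import numpy as np
import graphlab as gl      # or: import sframe as gl
import mxnet

# Toy SFrame: a scalar column, a fixed-length array column, and an integer label.
sf = gl.SFrame({'x': np.random.randn(100),
                'arr': [np.random.randn(4) for _ in range(100)],
                'y': np.random.randint(2, size=100)})

# No data_shape argument: the iterator infers the field lengths (1 and 4 floats)
# and the final per-row shape (5,) from the selected columns.
it = mxnet.io.SFrameIter(sf, data_field=['x', 'arr'], label_field='y', batch_size=10)

for batch in it:
    # batch.data[0] should be a (10, 5) NDArray, batch.label[0] a (10,) NDArray.
    features = batch.data[0].asnumpy()
    labels = batch.label[0].asnumpy()
it.reset()
```

Numeric and array columns are concatenated into one flat feature vector per row, which is why an image column cannot be mixed with other fields and must be the only `data_field`.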