Skip to content

Commit

Permalink
[io] sframe iter
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Jan 13, 2016
1 parent 08466f8 commit 25f5a26
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 4 deletions.
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ ifeq ($(USE_DIST_KVSTORE), 1)
LDFLAGS += $(PS_LDFLAGS_A)
endif

include $(MXNET_PLUGINS)

.PHONY: clean all test lint doc clean_all rcpplint rcppexport roxygen

all: lib/libmxnet.a lib/libmxnet.so $(BIN)
Expand Down Expand Up @@ -117,7 +119,7 @@ ifeq ($(USE_TORCH), 1)
ifeq ($(USE_CUDA), 1)
LDFLAGS += -lcutorch -lcunn
endif

TORCH_SRC = $(wildcard plugin/torch/*.cc)
PLUGIN_OBJ += $(patsubst %.cc, build/%.o, $(TORCH_SRC))
TORCH_CUSRC = $(wildcard plugin/torch/*.cu)
Expand Down Expand Up @@ -200,7 +202,7 @@ include tests/cpp/unittest.mk
test: $(TEST)

lint: rcpplint
python2 dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src plugin scripts python predict/python
python2 dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src plugin scripts python predict/python

doc: doxygen

Expand Down
5 changes: 5 additions & 0 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,8 @@ EXTRA_OPERATORS =
# whether to use torch integration. This requires installing torch.
USE_TORCH = 0
TORCH_PATH = $(HOME)/torch

# whether to use sframe integration. This requires build sframe
# [email protected]:dato-code/SFrame.git
# SFRAME_PATH = $(HOME)/SFrame
# MXNET_PLUGINS += plugin/sframe/SFrame.mk
7 changes: 7 additions & 0 deletions plugin/sframe/SFrame.mk
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SFRMAE_SRC = plugin/sframe/iter_sframe.cc
PLUGIN_OBJ += build/plugin/sframe/iter_sframe.o
CFLAGS += -I$(SFRAME_PATH)/oss_src/unity/lib/
CFLAGS += -I$(SFRAME_PATH)/oss_src/
LDFLAGS += -L$(SFRAME_PATH)/release/oss_src/unity/python/sframe/
LDFLAGS += -lunity_shared
LDFLAGS += -lboost_system
225 changes: 225 additions & 0 deletions plugin/sframe/iter_sframe.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*!
* Copyright (c) 2015 by Contributors
* \file iter_sframe_image.cc
* \brief
* \author Bing Xu
*/

#include <unity/lib/image_util.hpp>
#include <unity/lib/gl_sframe.hpp>
#include <unity/lib/gl_sarray.hpp>
#include <mxnet/io.h>
#include <dmlc/base.h>
#include <dmlc/io.h>
#include <dmlc/omp.h>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <string>
#include <memory>
#include "../../src/io/inst_vector.h"
#include "../../src/io/image_recordio.h"
#include "../../src/io/image_augmenter.h"
#include "../../src/io/iter_prefetcher.h"
#include "../../src/io/iter_normalize.h"
#include "../../src/io/iter_batchloader.h"

namespace mxnet {
namespace io {

struct SFrameParam : public dmlc::Parameter<SFrameParam> {
/*! \brief sframe path */
std::string path_sframe;
std::string data_field;
std::string label_field;
TShape data_shape;
TShape label_shape;
DMLC_DECLARE_PARAMETER(SFrameParam) {
DMLC_DECLARE_FIELD(path_sframe).set_default("")
.describe("Dataset Param: path to image dataset sframe");
DMLC_DECLARE_FIELD(data_field).set_default("data")
.describe("Dataset Param: data column in sframe");
DMLC_DECLARE_FIELD(label_field).set_default("label")
.describe("Dataset Param: label column in sframe");
DMLC_DECLARE_FIELD(data_shape)
.describe("Dataset Param: input data instance shape");
DMLC_DECLARE_FIELD(label_shape)
.describe("Dataset Param: input label instance shape");
}
}; // struct SFrameImageParam

class SFrameIterBase : public IIterator<DataInst> {
public:
SFrameIterBase() {}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.InitAllowUnknown(kwargs);
sframe_ = graphlab::gl_sframe(param_.path_sframe)[{param_.data_field, param_.label_field}];
range_it_.reset(new graphlab::gl_sframe_range(sframe_.range_iterator()));
this->BeforeFirst();
}

virtual ~SFrameIterBase() {}

virtual void BeforeFirst() {
idx_ = 0;
*range_it_ = sframe_.range_iterator();
current_it_ = range_it_->begin();
}

virtual const DataInst &Value(void) const {
return out_;
}

virtual bool Next() = 0;

protected:
/*! \brief index of instance */
index_t idx_;
/*! \brief output of sframe iterator */
DataInst out_;
/*! \brief temp space */
InstVector tmp_;
/*! \brief sframe iter parameter */
SFrameParam param_;
/*! \brief sframe object*/
graphlab::gl_sframe sframe_;
/*! \brief sframe range iterator */
std::unique_ptr<graphlab::gl_sframe_range> range_it_;
/*! \brief current iterator in range iterator */
graphlab::gl_sframe_range::iterator current_it_;

protected:
/*! \brief copy data */
template<int dim>
void Copy_(mshadow::Tensor<cpu, dim> tensor, const graphlab::flex_vec &vec) {
CHECK_EQ(tensor.shape_.Size(), vec.size());
CHECK_EQ(tensor.CheckContiguous(), true);
mshadow::Tensor<cpu, 1> flatten(tensor.dptr_, mshadow::Shape1(tensor.shape_.Size()));
for (index_t i = 0; i < vec.size(); ++i) {
flatten[i] = static_cast<float>(vec[i]);
}
}
}; // class SFrameIterBase

class SFrameImageIter : public SFrameIterBase {
public:
SFrameImageIter() :
augmenter_(new ImageAugmenter()), prnd_(new common::RANDOM_ENGINE(8964)) {}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
Parent::Init(kwargs);
augmenter_->Init(kwargs);
CHECK_EQ(Parent::param_.data_shape.ndim(), 3)
<< "Image shpae must be (channel, height, width)";
}

bool Next(void) override {
if (Parent::current_it_ == Parent::range_it_->end()) {
return false;
}
graphlab::image_type gl_img = (*Parent::current_it_)[0];
graphlab::flex_vec gl_label = (*Parent::current_it_)[1];
// TODO(bing): check not decoded
// TODO(bing): check img shape
CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match";
const unsigned char *raw_data = gl_img.get_image_data();
cv::Mat res;
cv::Mat buf(1, gl_img.m_image_data_size, CV_8U, const_cast<unsigned char*>(raw_data));
res = cv::imdecode(buf, -1);
res = augmenter_->Process(res, prnd_.get());
const int n_channels = res.channels();
if (!tmp_.Size()) {
tmp_.Push(Parent::idx_++,
Parent::param_.data_shape.get<3>(),
Parent::param_.label_shape.get<1>());
}
mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back();
std::vector<int> swap_indices;
if (n_channels == 1) swap_indices = {0};
if (n_channels == 3) swap_indices = {2, 1, 0};
for (int i = 0; i < res.rows; ++i) {
uchar* im_data = res.ptr<uchar>(i);
for (int j = 0; j < res.cols; ++j) {
for (int k = 0; k < n_channels; ++k) {
data[k][i][j] = im_data[swap_indices[k]];
}
im_data += n_channels;
}
}
mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back();
Parent::Copy_<1>(label, gl_label);
res.release();
out_ = Parent::tmp_[0];
++current_it_;
return true;
}

private:
/*! \brief parent type */
typedef SFrameIterBase Parent;
/*! \brief image augmenter */
std::unique_ptr<ImageAugmenter> augmenter_;
/*! \brief randim generator*/
std::unique_ptr<common::RANDOM_ENGINE> prnd_;
}; // class SFrameImageIter

class SFrameDataIter : public SFrameIterBase {
public:
bool Next() override {
if (Parent::current_it_ == Parent::range_it_->end()) {
return false;
}
graphlab::flex_vec gl_data = (*Parent::current_it_)[0];
graphlab::flex_vec gl_label = (*Parent::current_it_)[1];
CHECK_EQ(gl_data.size(), Parent::param_.data_shape.Size()) << "Data shape does not match";
CHECK_EQ(gl_label.size(), Parent::param_.label_shape.Size()) << "Label shape does not match";
if (!Parent::tmp_.Size()) {
Parent::tmp_.Push(Parent::idx_++,
Parent::param_.data_shape.get<3>(),
Parent::param_.label_shape.get<1>());
}
mshadow::Tensor<cpu, 3> data = Parent::tmp_.data().Back();
Parent::Copy_<3>(data, gl_data);
mshadow::Tensor<cpu, 1> label = Parent::tmp_.label().Back();
Parent::Copy_<1>(label, gl_label);
out_ = Parent::tmp_[0];
++current_it_;
return true;
}

private:
/*! \brief parent type */
typedef SFrameIterBase Parent;
}; // class SFrameDataIter

DMLC_REGISTER_PARAMETER(SFrameParam);

MXNET_REGISTER_IO_ITER(SFrameImageIter)
.describe("Naive SFrame image iterator prototype")
.add_arguments(SFrameParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.add_arguments(ImageAugmentParam::__FIELDS__())
.add_arguments(ImageNormalizeParam::__FIELDS__())
.set_body([]() {
return new PrefetcherIter(
new BatchLoader(
new ImageNormalizeIter(
new SFrameImageIter())));
});

MXNET_REGISTER_IO_ITER(SFrameDataIter)
.describe("Naive SFrame data iterator prototype")
.add_arguments(SFrameParam::__FIELDS__())
.add_arguments(BatchParam::__FIELDS__())
.add_arguments(PrefetcherParam::__FIELDS__())
.set_body([]() {
return new PrefetcherIter(
new BatchLoader(
new SFrameDataIter()));
});


} // namespace io
} // namespace mxnet

2 changes: 1 addition & 1 deletion ps-lite
Submodule ps-lite updated 1 files
+36 −19 include/ps/sarray.h

0 comments on commit 25f5a26

Please sign in to comment.