Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix predict c api #4537

Merged
merged 3 commits into from
Jan 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/image-classification/predict-cpp/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ LDFLAGS+=`pkg-config --libs opencv`
# Added for mxnet
export MXNET_ROOT=`pwd`/../../../../mxnet

CFLAGS+= -I$(MXNET_ROOT)/include
CFLAGS+=-Wall -I$(MXNET_ROOT)/include
LDFLAGS+=$(MXNET_ROOT)/lib/libmxnet.so

image-classification-predict: image-classification-predict.o
Expand Down
4 changes: 2 additions & 2 deletions example/image-classification/predict-cpp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ It uses opencv for image reading
1. Edit image-classification-predict.cc file, change the following lines to your model paths:
```bash
// Models path for your model, you have to modify it
std::string json_file = "model/Inception/Inception_BN-symbol.json";
std::string param_file = "model/Inception/Inception_BN-0039.params";
std::string json_file = "model/Inception/Inception-BN-symbol.json";
std::string param_file = "model/Inception/Inception-BN-0126.params";
std::string synset_file = "model/Inception/synset.txt";
std::string nd_file = "model/Inception/mean_224.nd";
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
// Path for c_predict_api
#include <mxnet/c_predict_api.h>

#include <opencv2/core/core.hpp>
#include <opencv2/contrib/contrib.hpp>
#include <opencv2/highgui/highgui.hpp>

#include <iostream>
#include <fstream>
#include <string>
#include <vector>

#include <opencv2/core/core.hpp>
#include <opencv2/contrib/contrib.hpp>
#include <opencv2/highgui/highgui.hpp>


const mx_float DEFAULT_MEAN = 117.0;

// Read file to buffer
Expand All @@ -44,7 +45,9 @@ class BufferFile {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
assert(false);
length_ = 0;
buffer_ = NULL;
return;
}

ifs.seekg(0, std::ios::end);
Expand All @@ -65,8 +68,10 @@ class BufferFile {
}

~BufferFile() {
delete[] buffer_;
buffer_ = NULL;
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};

Expand Down Expand Up @@ -98,7 +103,7 @@ void GetImageFile(const std::string image_file,
uchar* data = im.ptr<uchar>(i);

for (int j = 0; j < im.cols; j++) {
if (image_data) {
if (mean_data) {
mean_r = *mean_data;
if (channels > 1) {
mean_g = *(mean_data + size / 3);
Expand All @@ -112,7 +117,6 @@ void GetImageFile(const std::string image_file,
}

*ptr_image_r++ = static_cast<mx_float>(*data++) - mean_r;;

}
}
}
Expand Down Expand Up @@ -172,8 +176,8 @@ int main(int argc, char* argv[]) {
test_file = std::string(argv[1]);

// Models path for your model, you have to modify it
std::string json_file = "model/Inception/Inception_BN-symbol.json";
std::string param_file = "model/Inception/Inception_BN-0039.params";
std::string json_file = "model/Inception/Inception-BN-symbol.json";
std::string param_file = "model/Inception/Inception-BN-0126.params";
std::string synset_file = "model/Inception/synset.txt";
std::string nd_file = "model/Inception/mean_224.nd";

Expand All @@ -199,7 +203,12 @@ int main(int argc, char* argv[]) {
static_cast<mx_uint>(height) };
PredictorHandle pred_hnd = 0;

//-- Create Predictor
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
return -1;
}

// Create Predictor
MXPredCreate((const char*)json_data.GetBuffer(),
(const char*)param_data.GetBuffer(),
static_cast<size_t>(param_data.GetLength()),
Expand All @@ -210,45 +219,47 @@ int main(int argc, char* argv[]) {
input_shape_indptr,
input_shape_data,
&pred_hnd);
assert(pred_hnd);

int image_size = width * height * channels;

//-- Read Mean Data
// Read Mean Data
const mx_float* nd_data = NULL;
NDListHandle nd_hnd = 0;
BufferFile nd_buf(nd_file);

NDListHandle nd_hnd;

mx_uint nd_index = 0;
mx_uint nd_len;
const mx_uint* nd_shape = 0;
const char* nd_key = 0;
const mx_float* nd_data = 0;
mx_uint nd_ndim = 0;
if (nd_buf.GetLength() > 0) {
mx_uint nd_index = 0;
mx_uint nd_len;
const mx_uint* nd_shape = 0;
const char* nd_key = 0;
mx_uint nd_ndim = 0;

MXNDListCreate((const char*)nd_buf.GetBuffer(),
MXNDListCreate((const char*)nd_buf.GetBuffer(),
nd_buf.GetLength(),
&nd_hnd, &nd_len);

MXNDListGet(nd_hnd, nd_index, &nd_key, &nd_data, &nd_shape, &nd_ndim);
MXNDListGet(nd_hnd, nd_index, &nd_key, &nd_data, &nd_shape, &nd_ndim);
}

//-- Read Image Data
// Read Image Data
std::vector<mx_float> image_data = std::vector<mx_float>(image_size);

GetImageFile(test_file, image_data.data(),
channels, cv::Size(width, height), nd_data);

//-- Set Input Image
// Set Input Image
MXPredSetInput(pred_hnd, "data", image_data.data(), image_size);

//-- Do Predict Forward
// Do Predict Forward
MXPredForward(pred_hnd);

mx_uint output_index = 0;

mx_uint *shape = 0;
mx_uint shape_len;

//-- Get Output Result
// Get Output Result
MXPredGetOutputShape(pred_hnd, output_index, &shape, &shape_len);

size_t size = 1;
Expand All @@ -258,18 +269,17 @@ int main(int argc, char* argv[]) {

MXPredGetOutput(pred_hnd, output_index, &(data[0]), size);

//-- Release

// Release NDList
MXNDListFree(nd_hnd);
if (nd_hnd)
MXNDListFree(nd_hnd);

// Release Predictor
MXPredFree(pred_hnd);

// Synset path for your model, you have to modify it
std::vector<std::string> synset = LoadSynset(synset_file);

//-- Print Output Data
// Print Output Data
PrintOutputResult(data, synset);

return 0;
Expand Down
12 changes: 10 additions & 2 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,17 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
MXAPIPredictor* ret = new MXAPIPredictor();
API_BEGIN();
Symbol sym;
// make sure symbols are registered
{
mx_uint outSize;
const char **outArray;
MXListAllOpNames(&outSize, &outArray);
}
// load in the symbol.
{
std::string json = symbol_json_str;
sym.outputs = nnvm::pass::LoadJSON(json).outputs;
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
}
// looks likely to output the internal results
if (num_output_nodes != 0) {
Expand Down Expand Up @@ -206,6 +213,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
arg_arrays,
grad_store, grad_req,
aux_arrays));
ret->out_shapes = out_shapes;
ret->out_arrays = ret->exec->outputs();
}
*out = ret;
Expand Down
1 change: 1 addition & 0 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Comm {
Context pinned_ctx() const {
return pinned_ctx_;
}

protected:
Context pinned_ctx_;
};
Expand Down