Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix predict c api
Browse files Browse the repository at this point in the history
  • Loading branch information
howard0su committed Jan 5, 2017
1 parent b5988bd commit ed70c8f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 36 deletions.
6 changes: 3 additions & 3 deletions example/image-classification/predict-cpp/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ LDFLAGS+=`pkg-config --libs opencv`
# Added for mxnet
export MXNET_ROOT=`pwd`/../../../../mxnet

CFLAGS+= -I$(MXNET_ROOT)/include
CFLAGS+=-Wall -I$(MXNET_ROOT)/include
LDFLAGS+=$(MXNET_ROOT)/lib/libmxnet.so

image-classification-predict: image-classification-predict.o
g++ -O3 -o image-classification-predict image-classification-predict.o $(LDFLAGS)
g++ -O0 -o image-classification-predict image-classification-predict.o $(LDFLAGS)

image-classification-predict.o: image-classification-predict.cc
g++ -O3 -c image-classification-predict.cc ${CFLAGS}
g++ -O0 -c image-classification-predict.cc ${CFLAGS}

clean:
rm image-classification-predict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
// Path for c_predict_api
#include <mxnet/c_predict_api.h>

#include <opencv2/core/core.hpp>
#include <opencv2/contrib/contrib.hpp>
#include <opencv2/highgui/highgui.hpp>

#include <iostream>
#include <fstream>
#include <string>
#include <vector>

#include <opencv2/core/core.hpp>
#include <opencv2/contrib/contrib.hpp>
#include <opencv2/highgui/highgui.hpp>


const mx_float DEFAULT_MEAN = 117.0;

// Read file to buffer
Expand All @@ -44,7 +45,9 @@ class BufferFile {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
assert(false);
length_ = 0;
buffer_ = NULL;
return;
}

ifs.seekg(0, std::ios::end);
Expand All @@ -65,8 +68,10 @@ class BufferFile {
}

~BufferFile() {
delete[] buffer_;
buffer_ = NULL;
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};

Expand Down Expand Up @@ -98,7 +103,7 @@ void GetImageFile(const std::string image_file,
uchar* data = im.ptr<uchar>(i);

for (int j = 0; j < im.cols; j++) {
if (image_data) {
if (mean_data) {
mean_r = *mean_data;
if (channels > 1) {
mean_g = *(mean_data + size / 3);
Expand All @@ -112,7 +117,6 @@ void GetImageFile(const std::string image_file,
}

*ptr_image_r++ = static_cast<mx_float>(*data++) - mean_r;;

}
}
}
Expand Down Expand Up @@ -172,8 +176,8 @@ int main(int argc, char* argv[]) {
test_file = std::string(argv[1]);

// Models path for your model, you have to modify it
std::string json_file = "model/Inception/Inception_BN-symbol.json";
std::string param_file = "model/Inception/Inception_BN-0039.params";
std::string json_file = "model/Inception/Inception-BN-symbol.json";
std::string param_file = "model/Inception/Inception-BN-0126.params";
std::string synset_file = "model/Inception/synset.txt";
std::string nd_file = "model/Inception/mean_224.nd";

Expand All @@ -199,7 +203,12 @@ int main(int argc, char* argv[]) {
static_cast<mx_uint>(height) };
PredictorHandle pred_hnd = 0;

//-- Create Predictor
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
return -1;
}

// Create Predictor
MXPredCreate((const char*)json_data.GetBuffer(),
(const char*)param_data.GetBuffer(),
static_cast<size_t>(param_data.GetLength()),
Expand All @@ -210,45 +219,47 @@ int main(int argc, char* argv[]) {
input_shape_indptr,
input_shape_data,
&pred_hnd);
assert(pred_hnd);

int image_size = width * height * channels;

//-- Read Mean Data
// Read Mean Data
const mx_float* nd_data = NULL;
NDListHandle nd_hnd = 0;
BufferFile nd_buf(nd_file);

NDListHandle nd_hnd;

mx_uint nd_index = 0;
mx_uint nd_len;
const mx_uint* nd_shape = 0;
const char* nd_key = 0;
const mx_float* nd_data = 0;
mx_uint nd_ndim = 0;
if (nd_buf.GetLength() > 0) {
mx_uint nd_index = 0;
mx_uint nd_len;
const mx_uint* nd_shape = 0;
const char* nd_key = 0;
mx_uint nd_ndim = 0;

MXNDListCreate((const char*)nd_buf.GetBuffer(),
MXNDListCreate((const char*)nd_buf.GetBuffer(),
nd_buf.GetLength(),
&nd_hnd, &nd_len);

MXNDListGet(nd_hnd, nd_index, &nd_key, &nd_data, &nd_shape, &nd_ndim);
MXNDListGet(nd_hnd, nd_index, &nd_key, &nd_data, &nd_shape, &nd_ndim);
}

//-- Read Image Data
// Read Image Data
std::vector<mx_float> image_data = std::vector<mx_float>(image_size);

GetImageFile(test_file, image_data.data(),
channels, cv::Size(width, height), nd_data);

//-- Set Input Image
// Set Input Image
MXPredSetInput(pred_hnd, "data", image_data.data(), image_size);

//-- Do Predict Forward
// Do Predict Forward
MXPredForward(pred_hnd);

mx_uint output_index = 0;

mx_uint *shape = 0;
mx_uint shape_len;

//-- Get Output Result
// Get Output Result
MXPredGetOutputShape(pred_hnd, output_index, &shape, &shape_len);

size_t size = 1;
Expand All @@ -258,18 +269,17 @@ int main(int argc, char* argv[]) {

MXPredGetOutput(pred_hnd, output_index, &(data[0]), size);

//-- Release

// Release NDList
MXNDListFree(nd_hnd);
if (nd_hnd)
MXNDListFree(nd_hnd);

// Release Predictor
MXPredFree(pred_hnd);

// Synset path for your model, you have to modify it
std::vector<std::string> synset = LoadSynset(synset_file);

//-- Print Output Data
// Print Output Data
PrintOutputResult(data, synset);

return 0;
Expand Down
12 changes: 10 additions & 2 deletions src/c_api/c_predict_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,17 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
MXAPIPredictor* ret = new MXAPIPredictor();
API_BEGIN();
Symbol sym;
// make sure symbols are registered
{
mx_uint outSize;
const char **outArray;
MXListAllOpNames(&outSize, &outArray);
}
// load in the symbol.
{
std::string json = symbol_json_str;
sym.outputs = nnvm::pass::LoadJSON(json).outputs;
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
}
// looks likely to output the internal results
if (num_output_nodes != 0) {
Expand Down Expand Up @@ -206,6 +213,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str,
arg_arrays,
grad_store, grad_req,
aux_arrays));
ret->out_shapes = out_shapes;
ret->out_arrays = ret->exec->outputs();
}
*out = ret;
Expand Down

0 comments on commit ed70c8f

Please sign in to comment.