Commit

[C++] Fix type inconsistency issue when loading quantized parameters (#15038)

* fix type inconsistency when using the C++ API to load a params file

* add test case

* fix cpplint

* address comment

* retrigger CI

* fix comments

* modify ci_test

* fix indentation
wuxun-zhang authored and pengzhao-intel committed May 24, 2019
1 parent 5763ba9 commit 93fdcad
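
In short, the C++ package's NDArray::Copy allocated its destination with the default dtype (float32), so fp16 or int8 arrays, such as the weights of a quantized model, silently changed element type when copied. The sketch below is a minimal illustration of the use case this commit targets; it is not part of the commit, the params file name is a placeholder, and it combines the existing NDArray::LoadToMap helper of the C++ package with the Copy and GetDType calls shown in the diffs further down.

// Hedged sketch: load (hypothetical) quantized parameters with the C++ API and
// copy them to a target context without losing their dtype.
#include <map>
#include <string>
#include "dmlc/logging.h"
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;

int main() {
  // Placeholder path; any .params file containing int8/fp16 arrays shows the effect.
  std::map<std::string, NDArray> params =
      NDArray::LoadToMap("./model/quantized-0000.params");

  Context ctx = Context::cpu();
  for (const auto &kv : params) {
    NDArray copy = kv.second.Copy(ctx);  // used to come back as float32
    NDArray::WaitAll();
    // With this fix, the copy keeps the source element type.
    CHECK_EQ(copy.GetDType(), kv.second.GetDType());
  }
  return 0;
}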
Showing 4 changed files with 79 additions and 10 deletions.
62 changes: 62 additions & 0 deletions cpp-package/example/test_ndarray_copy.cpp
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
#include <vector>
#include "mxnet/c_api.h"
#include "dmlc/logging.h"
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;

enum TypeFlag {
  kFloat32 = 0,
  kFloat64 = 1,
  kFloat16 = 2,
  kUint8 = 3,
  kInt32 = 4,
  kInt8 = 5,
  kInt64 = 6,
};

/*
 * This file tests whether a type inconsistency occurs when the Copy API
 * is used to create a new NDArray.
 * Run it as: build/cpp-package/example/test_ndarray_copy.
 */
int main(int argc, char** argv) {
  std::vector<mx_uint> shape1{128, 2, 32};
  Shape shape2(32, 8, 64);

  int gpu_count = 0;
  if (MXGetGPUCount(&gpu_count) != 0) {
    LOG(ERROR) << "MXGetGPUCount failed";
    return -1;
  }

  Context context = (gpu_count > 0) ? Context::gpu() : Context::cpu();

  NDArray src1(shape1, context, true, kFloat16);
  NDArray src2(shape2, context, false, kInt8);
  NDArray dst1, dst2;
  dst1 = src1.Copy(context);
  dst2 = src2.Copy(context);
  NDArray::WaitAll();
  CHECK_EQ(src1.GetDType(), dst1.GetDType());
  CHECK_EQ(src2.GetDType(), dst2.GetDType());
  return 0;
}
9 changes: 6 additions & 3 deletions cpp-package/include/mxnet-cpp/ndarray.h
@@ -131,18 +131,21 @@ class NDArray {
   /*!
    * \brief construct a new dynamic NDArray
    * \param shape the shape of array
-   * \param constext context of NDArray
+   * \param context context of NDArray
    * \param delay_alloc whether delay the allocation
+   * \param dtype data type of NDArray
    */
   NDArray(const std::vector<mx_uint> &shape, const Context &context,
-          bool delay_alloc = true);
+          bool delay_alloc = true, int dtype = 0);
   /*!
    * \brief construct a new dynamic NDArray
    * \param shape the shape of array
    * \param constext context of NDArray
    * \param delay_alloc whether delay the allocation
+   * \param dtype data type of NDArray
    */
-  NDArray(const Shape &shape, const Context &context, bool delay_alloc = true);
+  NDArray(const Shape &shape, const Context &context,
+          bool delay_alloc = true, int dtype = 0);
   NDArray(const mx_float *data, size_t size);
   /*!
    * \brief construct a new dynamic NDArray
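
The two declarations above gain an optional dtype argument (defaulting to 0, i.e. float32), so callers can allocate non-float32 arrays directly. A minimal sketch of calling the new overloads, not part of the commit, assuming the dtype flag values from the test added above (kFloat16 = 2, kInt8 = 5):

#include <vector>
#include "mxnet/c_api.h"
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;

constexpr int kFloat16 = 2;  // dtype flags as in the TypeFlag enum of the test
constexpr int kInt8 = 5;

void MakeTypedArrays() {
  // Shape overload: allocate an int8 array on the CPU right away.
  NDArray int8_arr(Shape(64, 32), Context::cpu(), false /*delay_alloc*/, kInt8);
  // Vector-shape overload: an fp16 array with delayed allocation.
  NDArray fp16_arr(std::vector<mx_uint>{128, 2, 32}, Context::cpu(), true, kFloat16);
}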
15 changes: 8 additions & 7 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
@@ -47,17 +47,18 @@ inline NDArray::NDArray(const NDArrayHandle &handle) {
   blob_ptr_ = std::make_shared<NDBlob>(handle);
 }
 inline NDArray::NDArray(const std::vector<mx_uint> &shape, const Context &context,
-                        bool delay_alloc) {
+                        bool delay_alloc, int dtype) {
   NDArrayHandle handle;
-  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.size(), context.GetDeviceType(),
-                           context.GetDeviceId(), delay_alloc, &handle),
+  CHECK_EQ(MXNDArrayCreateEx(shape.data(), shape.size(), context.GetDeviceType(),
+                             context.GetDeviceId(), delay_alloc, dtype, &handle),
            0);
   blob_ptr_ = std::make_shared<NDBlob>(handle);
 }
-inline NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) {
+inline NDArray::NDArray(const Shape &shape, const Context &context,
+                        bool delay_alloc, int dtype) {
   NDArrayHandle handle;
-  CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(),
-                           context.GetDeviceId(), delay_alloc, &handle),
+  CHECK_EQ(MXNDArrayCreateEx(shape.data(), shape.ndim(), context.GetDeviceType(),
+                             context.GetDeviceId(), delay_alloc, dtype, &handle),
            0);
   blob_ptr_ = std::make_shared<NDBlob>(handle);
 }
@@ -208,7 +209,7 @@ inline void NDArray::SyncCopyToCPU(std::vector<mx_float> *data, size_t size) {
   MXNDArraySyncCopyToCPU(blob_ptr_->handle_, data->data(), size);
 }
 inline NDArray NDArray::Copy(const Context &ctx) const {
-  NDArray ret(GetShape(), ctx);
+  NDArray ret(GetShape(), ctx, true, this->GetDType());
   Operator("_copyto")(*this).Invoke(ret);
   return ret;
 }
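
The Copy change is the core of the fix: the destination array is now created with this->GetDType(), so MXNDArrayCreateEx allocates the result with a matching element type instead of the float32 default. A short sketch of the resulting behavior, not part of the commit (the helper name is hypothetical and the GPU context assumes a GPU is present):

#include "dmlc/logging.h"
#include "mxnet-cpp/MxNetCpp.h"
using namespace mxnet::cpp;

void CopyPreservesDtype() {
  const int kFloat16 = 2;  // dtype flag, as in the test above
  NDArray half(Shape(8, 8), Context::cpu(), false, kFloat16);
  // Copy to another context; before this change the result was float32.
  NDArray on_gpu = half.Copy(Context::gpu());
  NDArray::WaitAll();
  CHECK_EQ(on_gpu.GetDType(), half.GetDType());
}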
3 changes: 3 additions & 0 deletions cpp-package/tests/ci_test.sh
@@ -57,6 +57,9 @@ cp ../../build/cpp-package/example/test_kvstore .
 cp ../../build/cpp-package/example/test_score .
 ./test_score 0.93

+cp ../../build/cpp-package/example/test_ndarray_copy .
+./test_ndarray_copy
+
 sh unittests/unit_test_mlp_csv.sh

 cd inference
