diff --git a/cpp-package/example/test_ndarray_copy.cpp b/cpp-package/example/test_ndarray_copy.cpp new file mode 100644 index 000000000000..a3b3011993fa --- /dev/null +++ b/cpp-package/example/test_ndarray_copy.cpp @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ +#include +#include "mxnet/c_api.h" +#include "dmlc/logging.h" +#include "mxnet-cpp/MxNetCpp.h" +using namespace mxnet::cpp; + +enum TypeFlag { + kFloat32 = 0, + kFloat64 = 1, + kFloat16 = 2, + kUint8 = 3, + kInt32 = 4, + kInt8 = 5, + kInt64 = 6, +}; + +/* + * The file is used for testing if there exist type inconsistency + * when using Copy API to create a new NDArray. + * By running: build/test_ndarray. + */ +int main(int argc, char** argv) { + std::vector shape1{128, 2, 32}; + Shape shape2(32, 8, 64); + + int gpu_count = 0; + if (MXGetGPUCount(&gpu_count) != 0) { + LOG(ERROR) << "MXGetGPUCount failed"; + return -1; + } + + Context context = (gpu_count > 0) ? Context::gpu() : Context::cpu(); + + NDArray src1(shape1, context, true, kFloat16); + NDArray src2(shape2, context, false, kInt8); + NDArray dst1, dst2; + dst1 = src1.Copy(context); + dst2 = src2.Copy(context); + NDArray::WaitAll(); + CHECK_EQ(src1.GetDType(), dst1.GetDType()); + CHECK_EQ(src2.GetDType(), dst2.GetDType()); + return 0; +} diff --git a/cpp-package/include/mxnet-cpp/ndarray.h b/cpp-package/include/mxnet-cpp/ndarray.h index 6f37d91aa68e..7953b61393fe 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.h +++ b/cpp-package/include/mxnet-cpp/ndarray.h @@ -131,18 +131,21 @@ class NDArray { /*! * \brief construct a new dynamic NDArray * \param shape the shape of array - * \param constext context of NDArray + * \param context context of NDArray * \param delay_alloc whether delay the allocation + * \param dtype data type of NDArray */ NDArray(const std::vector &shape, const Context &context, - bool delay_alloc = true); + bool delay_alloc = true, int dtype = 0); /*! * \brief construct a new dynamic NDArray * \param shape the shape of array * \param constext context of NDArray * \param delay_alloc whether delay the allocation + * \param dtype data type of NDArray */ - NDArray(const Shape &shape, const Context &context, bool delay_alloc = true); + NDArray(const Shape &shape, const Context &context, + bool delay_alloc = true, int dtype = 0); NDArray(const mx_float *data, size_t size); /*! * \brief construct a new dynamic NDArray diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index d0438305a62e..283fff1a3b92 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -47,17 +47,18 @@ inline NDArray::NDArray(const NDArrayHandle &handle) { blob_ptr_ = std::make_shared(handle); } inline NDArray::NDArray(const std::vector &shape, const Context &context, - bool delay_alloc) { + bool delay_alloc, int dtype) { NDArrayHandle handle; - CHECK_EQ(MXNDArrayCreate(shape.data(), shape.size(), context.GetDeviceType(), - context.GetDeviceId(), delay_alloc, &handle), + CHECK_EQ(MXNDArrayCreateEx(shape.data(), shape.size(), context.GetDeviceType(), + context.GetDeviceId(), delay_alloc, dtype, &handle), 0); blob_ptr_ = std::make_shared(handle); } -inline NDArray::NDArray(const Shape &shape, const Context &context, bool delay_alloc) { +inline NDArray::NDArray(const Shape &shape, const Context &context, + bool delay_alloc, int dtype) { NDArrayHandle handle; - CHECK_EQ(MXNDArrayCreate(shape.data(), shape.ndim(), context.GetDeviceType(), - context.GetDeviceId(), delay_alloc, &handle), + CHECK_EQ(MXNDArrayCreateEx(shape.data(), shape.ndim(), context.GetDeviceType(), + context.GetDeviceId(), delay_alloc, dtype, &handle), 0); blob_ptr_ = std::make_shared(handle); } @@ -208,7 +209,7 @@ inline void NDArray::SyncCopyToCPU(std::vector *data, size_t size) { MXNDArraySyncCopyToCPU(blob_ptr_->handle_, data->data(), size); } inline NDArray NDArray::Copy(const Context &ctx) const { - NDArray ret(GetShape(), ctx); + NDArray ret(GetShape(), ctx, true, this->GetDType()); Operator("_copyto")(*this).Invoke(ret); return ret; } diff --git a/cpp-package/tests/ci_test.sh b/cpp-package/tests/ci_test.sh index 2d1f8e4f68e6..ef7fceacfd6e 100755 --- a/cpp-package/tests/ci_test.sh +++ b/cpp-package/tests/ci_test.sh @@ -57,6 +57,9 @@ cp ../../build/cpp-package/example/test_kvstore . cp ../../build/cpp-package/example/test_score . ./test_score 0.93 +cp ../../build/cpp-package/example/test_ndarray_copy . +./test_ndarray_copy + sh unittests/unit_test_mlp_csv.sh cd inference