From 30b227196f7cd40a5feb10acd511d50d9ddfef66 Mon Sep 17 00:00:00 2001
From: zha0q1
Date: Mon, 8 Jul 2019 16:36:36 -0700
Subject: [PATCH 01/29] initial. Supports Tensor, TBlob, NDArray

---
 src/common/tensor_inspector.h | 521 ++++++++++++++++++++++++++++++++++
 1 file changed, 521 insertions(+)
 create mode 100644 src/common/tensor_inspector.h

diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h
new file mode 100644
index 000000000000..053469e6c6a5
--- /dev/null
+++ b/src/common/tensor_inspector.h
@@ -0,0 +1,521 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file tensor_inspector.h
+ * \brief utility to inspector tensor objects
+ * \author Zhaoqi Zhu
+*/
+
+#ifndef MXNET_COMMON_TENSOR_INSPECTOR_H_
+#define MXNET_COMMON_TENSOR_INSPECTOR_H_
+
+#include
+#include
+#include "../../3rdparty/mshadow/mshadow/base.h"
+#include "../../tests/cpp/include/test_util.h"
+
+namespace mxnet{
+
+/*!
+ * \brief This singleton struct mediates individual TensorInspector objects
+ * so that we can control the global behavior from each of them
+ */
+struct InspectorManager {
+  static InspectorManager* get() {
+    static std::mutex mtx;
+    static std::unique_ptr<InspectorManager> im = nullptr;
+    if (!im) {
+      std::unique_lock<std::mutex> lk(mtx);
+      if (!im)
+        im = std::make_unique<InspectorManager>();
+    }
+    return im.get();
+  }
+  /* !\brief mutex used to lock interactive_print() and check_value() */
+  std::mutex mutex_;
+  /* !\brief skip all interactive prints */
+  bool interactive_print_skip_all_ = false;
+  /* !\brief skip all value checks */
+  bool check_value_skip_all_ = false;
+  /* !\brief visit count for interactive print tags */
+  std::unordered_map<std::string, int> interactive_print_tag_counter_;
+  /* !\brief visit count for check value tags */
+  std::unordered_map<std::string, int> check_value_tag_counter_;
+};
+
+/*!
+ * \brief Enum for building value checkers for TensorInspector::check_value()
+ */
+enum CheckerType {
+  NegativeChecker, // check if is negative
+  PositiveChecker, // check if is positive
+  NanChecker // check if is Nan, will always return false if DType is not a float type
+};
+
+/**
+ *  _______ _____ _
+ * |__ __| |_ _| | |
+ * | | ___ _ __ ___ ___ _ __| | _ __ ___ _ __ ___ ___| |_ ___ _ __
+ * | |/ _ \ '_ \/ __|/ _ \| '__| | | '_ \/ __| '_ \ / _ \/ __| __/ _ \| '__|
+ * | | __/ | | \__ \ (_) | | _| |_| | | \__ \ |_) | __/ (__| || (_) | |
+ * |_|\___|_| |_|___/\___/|_||_____|_| |_|___/ .__/ \___|\___|\__\___/|_|
+ * | |
+ * |_|
+ */
+
+/*!
+ * \brief This class provides a unified interface to inspect the value of all data types
+ * including Tensor, TBlob, and NDArray. If the tensor resides on GPU, then it will be
+ * copied from GPU memory back to CPU memory to be operated on. Internally, all data types
+ * are stored as a TBlob object tb_.
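+ *
+ * A minimal usage sketch (hypothetical example, not prescribed by this patch; `data`
+ * is assumed to be a TBlob and `ctx` the RunContext available inside an operator):
+ *
+ *   TensorInspector ti(data);
+ *   ti.print_string(ctx);                    // prints all values, e.g. "[[1, 2], [3, 4]]",
+ *                                            // followed by an info line like "<f Tensor 2x2>"
+ *   ti.interactive_print(ctx, "fc_weight");  // pause and browse the tensor interactively
+ *   auto bad = ti.check_value(ctx, NegativeChecker, true, "fc_weight");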
+ */ +class TensorInspector { + /*! + * \brief generate the tensor info, including data type and shape + * \tparam DType the data type + * \tparam StreamType the type of the stream object + * \param os stream object to output to + */ + template + inline void tensor_info_to_string(StreamType& os) { + int dimension = tb_.ndim(); + os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; + os << tb_.shape_[0]; + for (int i = 1; i < dimension; i++) { + os << 'x' << tb_.shape_[i]; + } + os << ">" << std::endl; + } + + /*! + * \brief output the tensor info, including data type and shape + * \tparam DType the data type + * \tparam StreamType the type of the stream object + * \param os stream object to output to + * \param shape the shape of the tensor + */ + template + inline void tensor_info_to_string(StreamType& os, const std::vector& shape) { + int dimension = shape.size(); + os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; + os << shape[0]; + for (int i = 1; i < dimension; i++) { + os << 'x' << shape[i]; + } + os << ">" << std::endl; + } + + /*! + * \brief output the tensor in a structed format + * \tparam DType the data type + * \tparam StreamType the type of the stream object + * \param ctx the run context of the tensor + * \param os stream object to output to + */ + template + inline void to_string_helper(const RunContext& ctx, StreamType& os) { + int dimension = tb_.ndim(); + std::vector multiples; + int multiple = 1; + for (int i = dimension-1; i >= 0; i--) { + multiple *= tb_.shape_[i]; + multiples.push_back(multiple); + } + os << std::string(dimension, '['); + os << tb_.dptr()[0]; + for (size_t i = 1; i < tb_.shape_.Size(); i++) { + int n = 0; + for (auto divisor : multiples) { + n += (i % divisor == 0); + } + if (n) { + os << std::string(n, ']') << ", " << std::string(n, '['); + } else { + os << ", "; + } + os << tb_.dptr()[i]; + } + os << std::string(dimension, ']') << std::endl; + tensor_info_to_string(os); + } + + /*! + * \brief output the tensor in a structed format + * \tparam DType the data type + * \tparam StreamType the type of the stream object + * \param ctx the run context of the tensor + * \param os stream object to output to + * \param dptr the data pointer + */ + template + inline void to_string_helper(const RunContext& ctx, StreamType& os, const DType* dptr) { + os << *dptr << std::endl; + os << "<" << typeid(*dptr).name() << ">" << std::endl; + } + + /*! 
+ * \brief output a part of the tensor in a structed format + * \tparam DType the data type + * \tparam StreamType the type of the stream object + * \param ctx the run context of the tensor + * \param os stream object to output to + * \param sub_shape the sub-shape of the desired part of the tensor + * \param offset the position of the first value of the desired part of the tensor + */ + template + inline void to_string_helper(const RunContext& ctx, StreamType& os, const std::vector& sub_shape, size_t offset) { + DType* dptr = tb_.dptr() + offset; + if (sub_shape.size() == 0) { + to_string_helper(ctx, os, dptr); + return; + } + int dimension = sub_shape.size(); + std::vector multiples; + size_t multiple = 1; + for (int i = dimension-1; i >= 0; i--) { + multiple *= sub_shape[i]; + multiples.push_back(multiple); + } + std::stringstream ss; + os << std::string(dimension, '['); + os << dptr[0]; + for (size_t i = 1; i < multiple; i++) { + int n = 0; + for (auto divisor : multiples) { + n += (i % divisor == 0); + } + if (n) { + os << std::string(n, ']') << ", " << std::string(n, '['); + } else { + os << ", "; + } + os << dptr[i]; + } + os << std::string(dimension, ']') << std::endl; + tensor_info_to_string(os, sub_shape); + } + + /*! + * \brief helper functino to calculate the sub_shape and offset for the desired part of the tensor, + * given its coordinates in the original tensor + * \param pos the coordinates of the desired part of the tensor + * \param sub_shape the sub-shape of the desired part of the tensor; calculated here + * \param offset the position of the first value of the desired part of the tensor; calculated here + */ + inline void print_locator(const std::vector& pos, std::vector& sub_shape, size_t& offset) { + int dimension = tb_.ndim(); + int sub_dim = dimension - pos.size(); + sub_shape.resize(sub_dim); + int multiple = 1; + for (int i = pos.size(), j = 0; i < dimension; i++, j++) { + sub_shape[j] = tb_.shape_[i]; + multiple *= tb_.shape_[i]; + } + int sum = 0; + int m = 1; + for (int i = pos.size()-1; i >= 0; i--) { + sum += pos[i] * m; + m *= tb_.shape_[i]; + } + offset = sum * multiple; + } + + /*! + * \brief parse the coordinate of the desired part of the tensor, given a string that represents that + * coordinate + * \param pos the coordinates of the desired part of the tensor, calculated here + * \param str the string that represents the coordinate + */ + inline bool parse_position(std::vector& pos, const std::string& str) { + int dimension = tb_.ndim(); + std::stringstream ss(str); + int i; + while (ss >> i) { + pos.push_back(i); + if (ss.peek() == ',') { + ss.ignore(); + } + } + if (pos.size() > dimension) { + return false; + } + for (unsigned i = 0; i < pos.size(); i++) { + if (pos[i] > (tb_.shape_[i]-1)) { + return false; + } + } + return !pos.empty(); + } + + /*! 
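+   * Illustrative behavior (summary, not normative): each visit prints the tensor info
+   * and then prompts for input; a comma-separated position such as "0,1" prints the
+   * sub-tensor at that position, "e" prints the entire tensor, "b" breaks out of this
+   * call, and "s" skips all further interactive prints.
+   *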
+ * \brief interactive print the tensor value + * \tparam DType the data type + * \param ctx the run context of the tensor + * \param tag the name given to this call + */ + template + inline void interactive_print_helper(const RunContext& ctx, std::string tag) { + std::lock_guard lock(InspectorManager::get()->mutex_); + InspectorManager::get()->interactive_print_tag_counter_[tag] += 1; + while (!InspectorManager::get()->interactive_print_skip_all_) { + std::cout << "----------Interactive Print----------" << std::endl; + if (tag != "") { + std::cout << "Tag: " << tag << " Visit: " << + InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl; + } + tensor_info_to_string(std::cout); + std::cout << "Please specify the position, seperated by \",\"" << std::endl + << "\"e\" for the entire tensor, \"b\" to break, \"s\" to skip all: " << std::endl; + std::string str; + std::cin >> str; + if (str == "b") { + break; + } else if (str == "e") { + to_string_helper(ctx, std::cout); + continue; + } else if (str == "s") { + InspectorManager::get()->interactive_print_skip_all_ = true; + break; + } + std::vector pos; + if (parse_position(pos, str)) { + std::vector sub_shape; + size_t offset; + print_locator(pos, sub_shape, offset); + to_string_helper(ctx, std::cout, sub_shape, offset); + } else { + std::cout << "invalid input" << std::endl; + } + } + } + + /*! + * \brief calculate the coordinate of a value in the tensor, given its index + * \param idx the index of the value in the tensor + */ + inline std::vector index_to_coordinates(size_t idx){ + int dimension = tb_.ndim(); + std::vector ret; + for (int i = dimension-1; i >= 0; i--) { + ret.push_back(idx % tb_.shape_[i]); + idx /= tb_.shape_[i]; + } + std::reverse(ret.begin(), ret.end()); + return ret; + } + + /*! + * \brief check/validate the values within the tensor, return the coordinates + * where the lambda evaluates to true + * \tparam DType the data type + * \param ctx the run context of the tensor + * \param checker the lambda function to check each value of within the tensor + * \param interactive wherether to allow the user to interactively check the coordinates + * \param tag the name given to this call + */ + template + inline std::vector> check_value_helper(const RunContext& ctx, + const std::function& checker, bool interactive, std::string tag) { + std::vector> ret; + int count = 0; + std::stringstream ss; + ss << "["; + bool first_pass = true; + for (size_t i = 0; i < tb_.shape_.Size(); i++) { + if (checker(tb_.dptr()[i])) { + count += 1; + if (!first_pass) { + ss << ", "; + } + first_pass = false; + std::vector coords = index_to_coordinates(i); + ss << "(" << coords[0]; + for (size_t i = 1; i < coords.size(); i++) { + ss << ", " << coords[i]; + } + ss << ")"; + ret.push_back(coords); + } + } + ss << "]" << std::endl; + if (interactive) { + std::lock_guard lock(InspectorManager::get()->mutex_); + InspectorManager::get()->check_value_tag_counter_[tag] += 1; + while (!InspectorManager::get()->check_value_skip_all_) { + std::cout << "----------Value Check----------" << std::endl; + if (tag != "") { + std::cout << "Tag: " << tag << " Visit: " << InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; + } + std::cout << count << " value(s) found. 
\"p\" to print the coordinates, \"b\" to break, \"s\" to skip all: "; + std::string str; + std::cin >> str; + if (str == "b") { + break; + } else if (str == "p") { + std::cout << ss.str() << std::endl; + } else if (str == "s") { + InspectorManager::get()->check_value_skip_all_ = true; + } + } + } + + return ret; + } + + /*! + * \brief build the lambda function, aka the checker, given its type + * \tparam DType the data type + * \param ct the type of the checker + */ + template + inline std::function build_checker(CheckerType ct){ + switch (ct) { + case NegativeChecker: + return [] (DType x) { + return x < 0; + }; + case PositiveChecker: + return [] (DType x) { + return x < 0; + }; + case NanChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x != x; + }; + } else { + LOG(WARNING) << "NanChecker only applies to float types. " << + "Lambda will always return false."; + } + break; + default: + return [] (DType x) { + return false; + }; + } + return [] (DType x) {return false;}; + } + + public: + /*! + * \brief Construct from Tensor object + * \tparam Device the device the tensor resides in + * \tparam dimension the dimension of the tensor + * \tparam DType the data type + * \param ts the source tensor obeject + */ + template + TensorInspector(const Tensor& ts) : tb_(ts) {} + + /*! + * \brief Construct from TBlob object + * \tparam Device the device the tensor resides in + * \tparam dimension the dimension of the tensor + * \tparam DType the data type + * \param ts the source tensor obeject + */ + TensorInspector(const TBlob& tb) : tb_(tb) {} + + /*! + * \brief Construct from NDArray object. Currently this only works with kDefaultStorage + * \tparam Device the device the tensor resides in + * \tparam dimension the dimension of the tensor + * \tparam DType the data type + * \param ts the source tensor obeject + */ + TensorInspector(const NDArray& arr) : tb_(arr.data()){} + + /*! + * \brief print the tensor to std::cout + * \param ctx the run context of the tensor + */ + inline void print_string(const RunContext& ctx) { + std::cout << to_string(ctx) << std::endl; + } + + /*! + * \brief return a string which contains the values and other info of the tensor + * \param ctx the run context of the tensor + */ + inline std::string to_string(const RunContext& ctx) { + std::stringstream ss; + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { + to_string_helper(ctx, ss); + }); + return ss.str(); + } + + /*! + * \brief interactive print the tensor value + * \param ctx the run context of the tensor + * \param tag the name given to this call + */ + inline void interactive_print(const RunContext& ctx, std::string tag = "") { + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { + interactive_print_helper(ctx, tag); + }); + } + + /*! + * \brief check/validate the values within the tensor, return the coordinates + * where the lambda evaluates to true + * \tparam ValueChecker the type of the lambda + * \param ctx the run context of the tensor + * \param checker the lambda function to check each value of within the tensor + * \param interactive wherether to allow the user to interactively check the coordinates + * \param tag the name given to this call + */ + template + std::vector> check_value(const RunContext& ctx, const ValueChecker& checker, + bool interactive = false, std::string tag = "") { + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { + return check_value_helper(ctx, checker, interactive, tag); + }); + return std::vector>(); + } + + /*! 
+ * \brief check/validate the values within the tensor, return the coordinates + * where the lambda evaluates to true + * \param ctx the run context of the tensor + * \param ct the type of the checker + * \param interactive wherether to allow the user to interactively check the coordinates + * \param tag the name given to this call + */ + std::vector> check_value(const RunContext& ctx, CheckerType ct, + bool interactive = false, std::string tag = "") { + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { + return check_value_helper(ctx, build_checker(ct), interactive, tag); + }); + return std::vector>(); + } + + private: + /* !\brief the tensor blob */ + const TBlob tb_; +}; + + +} // namespace mxnet + +#endif // MXNET_COMMON_TENSOR_INSPECTOR_H_ From 571c65a51e72ee5b73e64668b577067eb073711a Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 8 Jul 2019 16:41:03 -0700 Subject: [PATCH 02/29] add GPU tensor support --- src/common/tensor_inspector.h | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 053469e6c6a5..6de928df4886 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -132,6 +132,12 @@ class TensorInspector { */ template inline void to_string_helper(const RunContext& ctx, StreamType& os) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).to_string_helper(ctx, os); + return; + } +#endif // MXNET_USE_CUDA int dimension = tb_.ndim(); std::vector multiples; int multiple = 1; @@ -167,6 +173,12 @@ class TensorInspector { */ template inline void to_string_helper(const RunContext& ctx, StreamType& os, const DType* dptr) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).to_string_helper(ctx, os, dptr); + return; + } +#endif // MXNET_USE_CUDA os << *dptr << std::endl; os << "<" << typeid(*dptr).name() << ">" << std::endl; } @@ -182,6 +194,12 @@ class TensorInspector { */ template inline void to_string_helper(const RunContext& ctx, StreamType& os, const std::vector& sub_shape, size_t offset) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).to_string_helper(ctx, os, sub_shape, offset); + return; + } +#endif // MXNET_USE_CUDA DType* dptr = tb_.dptr() + offset; if (sub_shape.size() == 0) { to_string_helper(ctx, os, dptr); @@ -273,6 +291,12 @@ class TensorInspector { */ template inline void interactive_print_helper(const RunContext& ctx, std::string tag) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).interactive_print_helper(ctx, tag); + return; + } +#endif // MXNET_USE_CUDA std::lock_guard lock(InspectorManager::get()->mutex_); InspectorManager::get()->interactive_print_tag_counter_[tag] += 1; while (!InspectorManager::get()->interactive_print_skip_all_) { @@ -334,6 +358,12 @@ class TensorInspector { template inline std::vector> check_value_helper(const RunContext& ctx, const std::function& checker, bool interactive, std::string tag) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + return TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).check_value_helper(ctx, + checker, interactive, tag); + } +#endif // MXNET_USE_CUDA std::vector> ret; int count = 0; std::stringstream ss; From 0ae0bae733e7c3aed30af881dbb1680686917f09 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 8 Jul 2019 
16:58:38 -0700 Subject: [PATCH 03/29] move run context reference to constructors --- src/common/tensor_inspector.h | 74 +++++++++++++++++------------------ 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 6de928df4886..5843f59a2724 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -127,14 +127,13 @@ class TensorInspector { * \brief output the tensor in a structed format * \tparam DType the data type * \tparam StreamType the type of the stream object - * \param ctx the run context of the tensor * \param os stream object to output to */ template - inline void to_string_helper(const RunContext& ctx, StreamType& os) { + inline void to_string_helper(StreamType& os) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).to_string_helper(ctx, os); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os); return; } #endif // MXNET_USE_CUDA @@ -167,15 +166,14 @@ class TensorInspector { * \brief output the tensor in a structed format * \tparam DType the data type * \tparam StreamType the type of the stream object - * \param ctx the run context of the tensor * \param os stream object to output to * \param dptr the data pointer */ template - inline void to_string_helper(const RunContext& ctx, StreamType& os, const DType* dptr) { + inline void to_string_helper(StreamType& os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).to_string_helper(ctx, os, dptr); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, dptr); return; } #endif // MXNET_USE_CUDA @@ -187,22 +185,21 @@ class TensorInspector { * \brief output a part of the tensor in a structed format * \tparam DType the data type * \tparam StreamType the type of the stream object - * \param ctx the run context of the tensor * \param os stream object to output to * \param sub_shape the sub-shape of the desired part of the tensor * \param offset the position of the first value of the desired part of the tensor */ template - inline void to_string_helper(const RunContext& ctx, StreamType& os, const std::vector& sub_shape, size_t offset) { + inline void to_string_helper(StreamType& os, const std::vector& sub_shape, size_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).to_string_helper(ctx, os, sub_shape, offset); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, sub_shape, offset); return; } #endif // MXNET_USE_CUDA DType* dptr = tb_.dptr() + offset; if (sub_shape.size() == 0) { - to_string_helper(ctx, os, dptr); + to_string_helper(os, dptr); return; } int dimension = sub_shape.size(); @@ -286,14 +283,13 @@ class TensorInspector { /*! 
* \brief interactive print the tensor value * \tparam DType the data type - * \param ctx the run context of the tensor * \param tag the name given to this call */ template - inline void interactive_print_helper(const RunContext& ctx, std::string tag) { + inline void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).interactive_print_helper(ctx, tag); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).interactive_print_helper(tag); return; } #endif // MXNET_USE_CUDA @@ -313,7 +309,7 @@ class TensorInspector { if (str == "b") { break; } else if (str == "e") { - to_string_helper(ctx, std::cout); + to_string_helper(std::cout); continue; } else if (str == "s") { InspectorManager::get()->interactive_print_skip_all_ = true; @@ -324,7 +320,7 @@ class TensorInspector { std::vector sub_shape; size_t offset; print_locator(pos, sub_shape, offset); - to_string_helper(ctx, std::cout, sub_shape, offset); + to_string_helper(std::cout, sub_shape, offset); } else { std::cout << "invalid input" << std::endl; } @@ -350,17 +346,17 @@ class TensorInspector { * \brief check/validate the values within the tensor, return the coordinates * where the lambda evaluates to true * \tparam DType the data type - * \param ctx the run context of the tensor * \param checker the lambda function to check each value of within the tensor * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ template - inline std::vector> check_value_helper(const RunContext& ctx, - const std::function& checker, bool interactive, std::string tag) { + inline std::vector> + check_value_helper(const std::function& checker, + bool interactive, std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - return TensorInspector(test::CAccessAsCPU(ctx, tb_, false)()).check_value_helper(ctx, + return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).check_value_helper( checker, interactive, tag); } #endif // MXNET_USE_CUDA @@ -455,7 +451,8 @@ class TensorInspector { */ template - TensorInspector(const Tensor& ts) : tb_(ts) {} + TensorInspector(const Tensor& ts, const RunContext& ctx): + tb_(ts), ctx_(ctx) {} /*! * \brief Construct from TBlob object @@ -464,7 +461,8 @@ class TensorInspector { * \tparam DType the data type * \param ts the source tensor obeject */ - TensorInspector(const TBlob& tb) : tb_(tb) {} + TensorInspector(const TBlob& tb, const RunContext& ctx): + tb_(tb), ctx_(ctx) {} /*! * \brief Construct from NDArray object. Currently this only works with kDefaultStorage @@ -473,36 +471,34 @@ class TensorInspector { * \tparam DType the data type * \param ts the source tensor obeject */ - TensorInspector(const NDArray& arr) : tb_(arr.data()){} + TensorInspector(const NDArray& arr, const RunContext& ctx): + tb_(arr.data()), ctx_(ctx) {} /*! * \brief print the tensor to std::cout - * \param ctx the run context of the tensor */ - inline void print_string(const RunContext& ctx) { - std::cout << to_string(ctx) << std::endl; + inline void print_string() { + std::cout << to_string() << std::endl; } /*! 
* \brief return a string which contains the values and other info of the tensor - * \param ctx the run context of the tensor */ - inline std::string to_string(const RunContext& ctx) { + inline std::string to_string() { std::stringstream ss; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - to_string_helper(ctx, ss); + to_string_helper(ss); }); return ss.str(); } /*! * \brief interactive print the tensor value - * \param ctx the run context of the tensor * \param tag the name given to this call */ - inline void interactive_print(const RunContext& ctx, std::string tag = "") { + inline void interactive_print(std::string tag = "") { MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - interactive_print_helper(ctx, tag); + interactive_print_helper(tag); }); } @@ -510,16 +506,15 @@ class TensorInspector { * \brief check/validate the values within the tensor, return the coordinates * where the lambda evaluates to true * \tparam ValueChecker the type of the lambda - * \param ctx the run context of the tensor * \param checker the lambda function to check each value of within the tensor * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ template - std::vector> check_value(const RunContext& ctx, const ValueChecker& checker, - bool interactive = false, std::string tag = "") { + std::vector> check_value(const ValueChecker& checker, bool interactive = false, + std::string tag = "") { MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - return check_value_helper(ctx, checker, interactive, tag); + return check_value_helper(checker, interactive, tag); }); return std::vector>(); } @@ -527,15 +522,14 @@ class TensorInspector { /*! * \brief check/validate the values within the tensor, return the coordinates * where the lambda evaluates to true - * \param ctx the run context of the tensor * \param ct the type of the checker * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - std::vector> check_value(const RunContext& ctx, CheckerType ct, - bool interactive = false, std::string tag = "") { + std::vector> check_value(CheckerType ct, bool interactive = false, + std::string tag = "") { MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - return check_value_helper(ctx, build_checker(ct), interactive, tag); + return check_value_helper(build_checker(ct), interactive, tag); }); return std::vector>(); } @@ -543,6 +537,8 @@ class TensorInspector { private: /* !\brief the tensor blob */ const TBlob tb_; + /* !\brief the run context of the tensor */ + const RunContext& ctx_; }; From 5123a142c1eea8b01a4e953ff3d81c20441f9979 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 8 Jul 2019 17:25:36 -0700 Subject: [PATCH 04/29] sanity fix --- src/common/tensor_inspector.h | 131 ++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 63 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 5843f59a2724..6c022fc7a0fb 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -27,12 +27,14 @@ #ifndef MXNET_COMMON_TENSOR_INSPECTOR_H_ #define MXNET_COMMON_TENSOR_INSPECTOR_H_ -#include -#include +#include +#include +#include +#include #include "../../3rdparty/mshadow/mshadow/base.h" #include "../../tests/cpp/include/test_util.h" -namespace mxnet{ +namespace mxnet { /*! 
* \brief This singleton struct mediates individual TensorInspector objects @@ -49,15 +51,15 @@ struct InspectorManager { } return im.get(); } - /* !\brief mutex used to lock interactive_print() and check_value() */ + /* !\brief mutex used to lock interactive_print() and check_value() */ std::mutex mutex_; - /* !\brief skip all interactive prints */ + /* !\brief skip all interactive prints */ bool interactive_print_skip_all_ = false; - /* !\brief skip all value checks */ + /* !\brief skip all value checks */ bool check_value_skip_all_ = false; - /* !\brief visit count for interactive print tags */ + /* !\brief visit count for interactive print tags */ std::unordered_map interactive_print_tag_counter_; - /* !\brief visit count for check value tags */ + /* !\brief visit count for check value tags */ std::unordered_map check_value_tag_counter_; }; @@ -65,9 +67,9 @@ struct InspectorManager { * \brief Enum for building value checkers for TensorInspector::check_value() */ enum CheckerType { - NegativeChecker, // check if is negative - PositiveChecker, // check if is positive - NanChecker // check if is Nan, will always return false if DType is not a float type + NegativeChecker, // check if is negative + PositiveChecker, // check if is positive + NanChecker // check if is Nan, will always return false if DType is not a float type }; /** @@ -95,14 +97,14 @@ class TensorInspector { * \param os stream object to output to */ template - inline void tensor_info_to_string(StreamType& os) { + inline void tensor_info_to_string(StreamType* os) { int dimension = tb_.ndim(); - os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; - os << tb_.shape_[0]; + *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; + *os << tb_.shape_[0]; for (int i = 1; i < dimension; i++) { - os << 'x' << tb_.shape_[i]; + *os << 'x' << tb_.shape_[i]; } - os << ">" << std::endl; + *os << ">" << std::endl; } /*! @@ -113,14 +115,14 @@ class TensorInspector { * \param shape the shape of the tensor */ template - inline void tensor_info_to_string(StreamType& os, const std::vector& shape) { + inline void tensor_info_to_string(StreamType* os, const std::vector& shape) { int dimension = shape.size(); - os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; - os << shape[0]; + *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; + *os << shape[0]; for (int i = 1; i < dimension; i++) { - os << 'x' << shape[i]; + *os << 'x' << shape[i]; } - os << ">" << std::endl; + *os << ">" << std::endl; } /*! 
@@ -130,13 +132,13 @@ class TensorInspector { * \param os stream object to output to */ template - inline void to_string_helper(StreamType& os) { + inline void to_string_helper(StreamType* os) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os); return; } -#endif // MXNET_USE_CUDA +#endif // MXNET_USE_CUDA int dimension = tb_.ndim(); std::vector multiples; int multiple = 1; @@ -144,21 +146,21 @@ class TensorInspector { multiple *= tb_.shape_[i]; multiples.push_back(multiple); } - os << std::string(dimension, '['); - os << tb_.dptr()[0]; + *os << std::string(dimension, '['); + *os << tb_.dptr()[0]; for (size_t i = 1; i < tb_.shape_.Size(); i++) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); } if (n) { - os << std::string(n, ']') << ", " << std::string(n, '['); + *os << std::string(n, ']') << ", " << std::string(n, '['); } else { - os << ", "; + *os << ", "; } - os << tb_.dptr()[i]; + *os << tb_.dptr()[i]; } - os << std::string(dimension, ']') << std::endl; + *os << std::string(dimension, ']') << std::endl; tensor_info_to_string(os); } @@ -170,15 +172,15 @@ class TensorInspector { * \param dptr the data pointer */ template - inline void to_string_helper(StreamType& os, const DType* dptr) { + inline void to_string_helper(StreamType* os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, dptr); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, dptr); return; } -#endif // MXNET_USE_CUDA - os << *dptr << std::endl; - os << "<" << typeid(*dptr).name() << ">" << std::endl; +#endif // MXNET_USE_CUDA + *os << *dptr << std::endl; + *os << "<" << typeid(*dptr).name() << ">" << std::endl; } /*! 
@@ -190,13 +192,14 @@ class TensorInspector { * \param offset the position of the first value of the desired part of the tensor */ template - inline void to_string_helper(StreamType& os, const std::vector& sub_shape, size_t offset) { + inline void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, sub_shape, offset); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, + sub_shape, offset); return; } -#endif // MXNET_USE_CUDA +#endif // MXNET_USE_CUDA DType* dptr = tb_.dptr() + offset; if (sub_shape.size() == 0) { to_string_helper(os, dptr); @@ -210,21 +213,21 @@ class TensorInspector { multiples.push_back(multiple); } std::stringstream ss; - os << std::string(dimension, '['); - os << dptr[0]; + *os << std::string(dimension, '['); + *os << dptr[0]; for (size_t i = 1; i < multiple; i++) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); } if (n) { - os << std::string(n, ']') << ", " << std::string(n, '['); + *os << std::string(n, ']') << ", " << std::string(n, '['); } else { - os << ", "; + *os << ", "; } - os << dptr[i]; + *os << dptr[i]; } - os << std::string(dimension, ']') << std::endl; + *os << std::string(dimension, ']') << std::endl; tensor_info_to_string(os, sub_shape); } @@ -235,13 +238,14 @@ class TensorInspector { * \param sub_shape the sub-shape of the desired part of the tensor; calculated here * \param offset the position of the first value of the desired part of the tensor; calculated here */ - inline void print_locator(const std::vector& pos, std::vector& sub_shape, size_t& offset) { + inline void print_locator(const std::vector& pos, std::vector* sub_shape, + size_t* offset) { int dimension = tb_.ndim(); int sub_dim = dimension - pos.size(); - sub_shape.resize(sub_dim); + sub_shape->resize(sub_dim); int multiple = 1; for (int i = pos.size(), j = 0; i < dimension; i++, j++) { - sub_shape[j] = tb_.shape_[i]; + (*sub_shape)[j] = tb_.shape_[i]; multiple *= tb_.shape_[i]; } int sum = 0; @@ -250,7 +254,7 @@ class TensorInspector { sum += pos[i] * m; m *= tb_.shape_[i]; } - offset = sum * multiple; + *offset = sum * multiple; } /*! 
@@ -289,10 +293,10 @@ class TensorInspector { inline void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).interactive_print_helper(tag); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).interactive_print_helper(tag); return; } -#endif // MXNET_USE_CUDA +#endif // MXNET_USE_CUDA std::lock_guard lock(InspectorManager::get()->mutex_); InspectorManager::get()->interactive_print_tag_counter_[tag] += 1; while (!InspectorManager::get()->interactive_print_skip_all_) { @@ -301,7 +305,7 @@ class TensorInspector { std::cout << "Tag: " << tag << " Visit: " << InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl; } - tensor_info_to_string(std::cout); + tensor_info_to_string(&std::cout); std::cout << "Please specify the position, seperated by \",\"" << std::endl << "\"e\" for the entire tensor, \"b\" to break, \"s\" to skip all: " << std::endl; std::string str; @@ -309,7 +313,7 @@ class TensorInspector { if (str == "b") { break; } else if (str == "e") { - to_string_helper(std::cout); + to_string_helper(&std::cout); continue; } else if (str == "s") { InspectorManager::get()->interactive_print_skip_all_ = true; @@ -319,8 +323,8 @@ class TensorInspector { if (parse_position(pos, str)) { std::vector sub_shape; size_t offset; - print_locator(pos, sub_shape, offset); - to_string_helper(std::cout, sub_shape, offset); + print_locator(pos, &sub_shape, &offset); + to_string_helper(&std::cout, sub_shape, offset); } else { std::cout << "invalid input" << std::endl; } @@ -331,7 +335,7 @@ class TensorInspector { * \brief calculate the coordinate of a value in the tensor, given its index * \param idx the index of the value in the tensor */ - inline std::vector index_to_coordinates(size_t idx){ + inline std::vector index_to_coordinates(size_t idx) { int dimension = tb_.ndim(); std::vector ret; for (int i = dimension-1; i >= 0; i--) { @@ -359,7 +363,7 @@ class TensorInspector { return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).check_value_helper( checker, interactive, tag); } -#endif // MXNET_USE_CUDA +#endif // MXNET_USE_CUDA std::vector> ret; int count = 0; std::stringstream ss; @@ -388,9 +392,11 @@ class TensorInspector { while (!InspectorManager::get()->check_value_skip_all_) { std::cout << "----------Value Check----------" << std::endl; if (tag != "") { - std::cout << "Tag: " << tag << " Visit: " << InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; + std::cout << "Tag: " << tag << " Visit: " << + InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; } - std::cout << count << " value(s) found. \"p\" to print the coordinates, \"b\" to break, \"s\" to skip all: "; + std::cout << count << " value(s) found. \"p\" to print the coordinates," << + " \"b\" to break, \"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -402,17 +408,16 @@ class TensorInspector { } } } - return ret; } - + /*! 
* \brief build the lambda function, aka the checker, given its type * \tparam DType the data type * \param ct the type of the checker */ template - inline std::function build_checker(CheckerType ct){ + inline std::function build_checker(CheckerType ct) { switch (ct) { case NegativeChecker: return [] (DType x) { @@ -487,7 +492,7 @@ class TensorInspector { inline std::string to_string() { std::stringstream ss; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - to_string_helper(ss); + to_string_helper(&ss); }); return ss.str(); } @@ -542,6 +547,6 @@ class TensorInspector { }; -} // namespace mxnet +} // namespace mxnet #endif // MXNET_COMMON_TENSOR_INSPECTOR_H_ From 969cf73d98c1dabb255b74cc57579966de661df6 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 8 Jul 2019 17:32:07 -0700 Subject: [PATCH 05/29] sanity fix --- src/common/tensor_inspector.h | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 6c022fc7a0fb..2dcf497ce331 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -175,7 +175,8 @@ class TensorInspector { inline void to_string_helper(StreamType* os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, dptr); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + .to_string_helper(os, dptr); return; } #endif // MXNET_USE_CUDA @@ -195,8 +196,8 @@ class TensorInspector { inline void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os, - sub_shape, offset); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + .to_string_helper(os, sub_shape, offset); return; } #endif // MXNET_USE_CUDA @@ -263,25 +264,25 @@ class TensorInspector { * \param pos the coordinates of the desired part of the tensor, calculated here * \param str the string that represents the coordinate */ - inline bool parse_position(std::vector& pos, const std::string& str) { + inline bool parse_position(std::vector* pos, const std::string& str) { int dimension = tb_.ndim(); std::stringstream ss(str); int i; while (ss >> i) { - pos.push_back(i); + pos->push_back(i); if (ss.peek() == ',') { ss.ignore(); } } - if (pos.size() > dimension) { + if (pos->size() > dimension) { return false; } - for (unsigned i = 0; i < pos.size(); i++) { - if (pos[i] > (tb_.shape_[i]-1)) { + for (unsigned i = 0; i < pos->size(); i++) { + if ((*pos)[i] > (tb_.shape_[i]-1)) { return false; } } - return !pos.empty(); + return !pos->empty(); } /*! 
@@ -293,7 +294,8 @@ class TensorInspector { inline void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).interactive_print_helper(tag); + explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + .interactive_print_helper(tag); return; } #endif // MXNET_USE_CUDA @@ -320,7 +322,7 @@ class TensorInspector { break; } std::vector pos; - if (parse_position(pos, str)) { + if (parse_position(&pos, str)) { std::vector sub_shape; size_t offset; print_locator(pos, &sub_shape, &offset); From eafcadc3052b3546ab9319b396db8a74d563fee4 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 8 Jul 2019 22:57:16 -0700 Subject: [PATCH 06/29] fix checker bug & add new checker type --- src/common/tensor_inspector.h | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 2dcf497ce331..c1f64e93b8f4 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -22,7 +22,7 @@ * \file tensor_inspector.h * \brief utility to inspector tensor objects * \author Zhaoqi Zhu -*/ + */ #ifndef MXNET_COMMON_TENSOR_INSPECTOR_H_ #define MXNET_COMMON_TENSOR_INSPECTOR_H_ @@ -69,6 +69,7 @@ struct InspectorManager { enum CheckerType { NegativeChecker, // check if is negative PositiveChecker, // check if is positive + ZeroChecker, // check if is zero NanChecker // check if is Nan, will always return false if DType is not a float type }; @@ -90,6 +91,7 @@ enum CheckerType { * are stored as a TBlob object tb_. */ class TensorInspector { + private: /*! * \brief generate the tensor info, including data type and shape * \tparam DType the data type @@ -427,7 +429,11 @@ class TensorInspector { }; case PositiveChecker: return [] (DType x) { - return x < 0; + return x > 0; + }; + case ZeroChecker: + return [] (DType x) { + return x == 0; }; case NanChecker: if (std::is_same::value || std::is_same::value || @@ -448,6 +454,11 @@ class TensorInspector { return [] (DType x) {return false;}; } + /* !\brief the tensor blob */ + const TBlob tb_; + /* !\brief the run context of the tensor */ + const RunContext& ctx_; + public: /*! 
* \brief Construct from Tensor object @@ -540,12 +551,6 @@ class TensorInspector { }); return std::vector>(); } - - private: - /* !\brief the tensor blob */ - const TBlob tb_; - /* !\brief the run context of the tensor */ - const RunContext& ctx_; }; From 275fb27b5d0cc3d9d8c969fa8957c64cf19001cc Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 8 Jul 2019 23:54:48 -0700 Subject: [PATCH 07/29] add more checker types --- src/common/tensor_inspector.h | 99 +++++++++++++++++++++++++++++------ 1 file changed, 82 insertions(+), 17 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index c1f64e93b8f4..975017156f39 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -70,7 +70,14 @@ enum CheckerType { NegativeChecker, // check if is negative PositiveChecker, // check if is positive ZeroChecker, // check if is zero - NanChecker // check if is Nan, will always return false if DType is not a float type + NaNChecker, // check if is NaN, will always return false if DType is not a float type + InfChecker, // check if is infinity, will always return false if DType is not a float type + PositiveInfChecker, // check if is positive infinity, + // will always return false if DType is not a float type + NegativeInfChecker, // check if is nagative infinity, + // will always return false if DType is not a float type + FiniteChecker, // check if is finite, will always return false if DType is not a float type + NormalChecker, // check if is neither infinity nor NaN }; /** @@ -235,7 +242,7 @@ class TensorInspector { } /*! - * \brief helper functino to calculate the sub_shape and offset for the desired part of the tensor, + * \brief helper function to calculate the sub_shape and offset for the desired part of the tensor, * given its coordinates in the original tensor * \param pos the coordinates of the desired part of the tensor * \param sub_shape the sub-shape of the desired part of the tensor; calculated here @@ -364,8 +371,8 @@ class TensorInspector { bool interactive, std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).check_value_helper( - checker, interactive, tag); + return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + .check_value_helper(checker, interactive, tag); } #endif // MXNET_USE_CUDA std::vector> ret; @@ -435,14 +442,76 @@ class TensorInspector { return [] (DType x) { return x == 0; }; - case NanChecker: + case NaNChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value) { + std::is_same::value || + std::is_same::value) { return [] (DType x) { return x != x; }; } else { - LOG(WARNING) << "NanChecker only applies to float types. " << + LOG(WARNING) << "NaNChecker only applies to float types. " << + "Lambda will always return false."; + } + break; + case InfChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x == (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0; + }; + } else { + LOG(WARNING) << "InfChecker only applies to float types. " << + "Lambda will always return false."; + } + break; + case PositiveInfChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x == (DType)1.0 / (DType)0.0; + }; + } else { + LOG(WARNING) << "PositiveInfChecker only applies to float types. 
" << + "Lambda will always return false."; + } + break; + case NegativeInfChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x == -(DType)1.0 / (DType)0.0; + }; + } else { + LOG(WARNING) << "NegativeInfChecker only applies to float types. " << + "Lambda will always return false."; + } + break; + case FiniteChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x != (DType)1.0 / (DType)0.0 && x != -(DType)1.0 / (DType)0.0; + }; + } else { + LOG(WARNING) << "FiniteChecker only applies to float types. " << + "Lambda will always return false."; + } + break; + case NormalChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x != (DType)1.0 / (DType)0.0 && x != -(DType)1.0 / (DType)0.0 && + x == x; + }; + } else { + LOG(WARNING) << "NormalChecker only applies to float types. " << "Lambda will always return false."; } break; @@ -465,7 +534,8 @@ class TensorInspector { * \tparam Device the device the tensor resides in * \tparam dimension the dimension of the tensor * \tparam DType the data type - * \param ts the source tensor obeject + * \param ts the source tensor object + * \param ctx the run context of the tensor */ template @@ -474,20 +544,16 @@ class TensorInspector { /*! * \brief Construct from TBlob object - * \tparam Device the device the tensor resides in - * \tparam dimension the dimension of the tensor - * \tparam DType the data type - * \param ts the source tensor obeject + * \param tb the source tblob object + * \param ctx the run context of the tensor */ TensorInspector(const TBlob& tb, const RunContext& ctx): tb_(tb), ctx_(ctx) {} /*! * \brief Construct from NDArray object. 
Currently this only works with kDefaultStorage - * \tparam Device the device the tensor resides in - * \tparam dimension the dimension of the tensor - * \tparam DType the data type - * \param ts the source tensor obeject + * \param arr the source ndarray object + * \param ctx the run context of the tensor */ TensorInspector(const NDArray& arr, const RunContext& ctx): tb_(arr.data()), ctx_(ctx) {} @@ -553,7 +619,6 @@ class TensorInspector { } }; - } // namespace mxnet #endif // MXNET_COMMON_TENSOR_INSPECTOR_H_ From 9c06300033815e17243038a96b7b038812593718 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 9 Jul 2019 11:05:57 -0700 Subject: [PATCH 08/29] fix gpu tensor constructor call --- src/common/tensor_inspector.h | 41 ++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 975017156f39..212595321d8d 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -144,7 +144,8 @@ class TensorInspector { inline void to_string_helper(StreamType* os) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()).to_string_helper(os); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) + .to_string_helper(os); return; } #endif // MXNET_USE_CUDA @@ -184,7 +185,7 @@ class TensorInspector { inline void to_string_helper(StreamType* os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) .to_string_helper(os, dptr); return; } @@ -205,7 +206,7 @@ class TensorInspector { inline void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) .to_string_helper(os, sub_shape, offset); return; } @@ -303,7 +304,7 @@ class TensorInspector { inline void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - explicit TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) .interactive_print_helper(tag); return; } @@ -318,7 +319,7 @@ class TensorInspector { } tensor_info_to_string(&std::cout); std::cout << "Please specify the position, seperated by \",\"" << std::endl - << "\"e\" for the entire tensor, \"b\" to break, \"s\" to skip all: " << std::endl; + << "\"e\" for the entire tensor, \"b\" to break, \"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -358,24 +359,23 @@ class TensorInspector { } /*! 
- * \brief check/validate the values within the tensor, return the coordinates + * \brief check/validate the values within the tensor, find the coordinates * where the lambda evaluates to true * \tparam DType the data type + * \param ret a vector of coordinates which itself is a vector of int; calculated here * \param checker the lambda function to check each value of within the tensor * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ template - inline std::vector> - check_value_helper(const std::function& checker, - bool interactive, std::string tag) { + inline void check_value_helper(std::vector>* ret, + const std::function& checker,bool interactive, std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)()) - .check_value_helper(checker, interactive, tag); + return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) + .check_value_helper(ret, checker, interactive, tag); } #endif // MXNET_USE_CUDA - std::vector> ret; int count = 0; std::stringstream ss; ss << "["; @@ -393,7 +393,7 @@ class TensorInspector { ss << ", " << coords[i]; } ss << ")"; - ret.push_back(coords); + ret->push_back(coords); } } ss << "]" << std::endl; @@ -419,7 +419,6 @@ class TensorInspector { } } } - return ret; } /*! @@ -595,12 +594,13 @@ class TensorInspector { * \param tag the name given to this call */ template - std::vector> check_value(const ValueChecker& checker, bool interactive = false, + inline std::vector> check_value(const ValueChecker& checker, bool interactive = false, std::string tag = "") { + std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - return check_value_helper(checker, interactive, tag); + check_value_helper(&ret, checker, ret, interactive, tag); }); - return std::vector>(); + return ret; } /*! @@ -610,12 +610,13 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - std::vector> check_value(CheckerType ct, bool interactive = false, + inline std::vector> check_value(CheckerType ct, bool interactive = false, std::string tag = "") { + std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - return check_value_helper(build_checker(ct), interactive, tag); + check_value_helper(&ret, build_checker(ct), interactive, tag); }); - return std::vector>(); + return ret; } }; From b013c2f92b4829cbe45e9b418bc25a8c3471bfe8 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 9 Jul 2019 15:20:47 -0700 Subject: [PATCH 09/29] add value dumping funtionality --- src/common/tensor_inspector.h | 251 ++++++++++++++++++++++------------ 1 file changed, 165 insertions(+), 86 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 212595321d8d..f435f8fcad67 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -31,6 +31,7 @@ #include #include #include +#include #include "../../3rdparty/mshadow/mshadow/base.h" #include "../../tests/cpp/include/test_util.h" @@ -61,6 +62,8 @@ struct InspectorManager { std::unordered_map interactive_print_tag_counter_; /* !\brief visit count for check value tags */ std::unordered_map check_value_tag_counter_; + /* !\brief visit count for dump value tags */ + std::unordered_map dump_value_tag_counter_; }; /*! @@ -343,84 +346,6 @@ class TensorInspector { } } - /*! 
- * \brief calculate the coordinate of a value in the tensor, given its index - * \param idx the index of the value in the tensor - */ - inline std::vector index_to_coordinates(size_t idx) { - int dimension = tb_.ndim(); - std::vector ret; - for (int i = dimension-1; i >= 0; i--) { - ret.push_back(idx % tb_.shape_[i]); - idx /= tb_.shape_[i]; - } - std::reverse(ret.begin(), ret.end()); - return ret; - } - - /*! - * \brief check/validate the values within the tensor, find the coordinates - * where the lambda evaluates to true - * \tparam DType the data type - * \param ret a vector of coordinates which itself is a vector of int; calculated here - * \param checker the lambda function to check each value of within the tensor - * \param interactive wherether to allow the user to interactively check the coordinates - * \param tag the name given to this call - */ - template - inline void check_value_helper(std::vector>* ret, - const std::function& checker,bool interactive, std::string tag) { -#if MXNET_USE_CUDA - if (tb_.dev_mask() == gpu::kDevMask) { - return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) - .check_value_helper(ret, checker, interactive, tag); - } -#endif // MXNET_USE_CUDA - int count = 0; - std::stringstream ss; - ss << "["; - bool first_pass = true; - for (size_t i = 0; i < tb_.shape_.Size(); i++) { - if (checker(tb_.dptr()[i])) { - count += 1; - if (!first_pass) { - ss << ", "; - } - first_pass = false; - std::vector coords = index_to_coordinates(i); - ss << "(" << coords[0]; - for (size_t i = 1; i < coords.size(); i++) { - ss << ", " << coords[i]; - } - ss << ")"; - ret->push_back(coords); - } - } - ss << "]" << std::endl; - if (interactive) { - std::lock_guard lock(InspectorManager::get()->mutex_); - InspectorManager::get()->check_value_tag_counter_[tag] += 1; - while (!InspectorManager::get()->check_value_skip_all_) { - std::cout << "----------Value Check----------" << std::endl; - if (tag != "") { - std::cout << "Tag: " << tag << " Visit: " << - InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; - } - std::cout << count << " value(s) found. \"p\" to print the coordinates," << - " \"b\" to break, \"s\" to skip all: "; - std::string str; - std::cin >> str; - if (str == "b") { - break; - } else if (str == "p") { - std::cout << ss.str() << std::endl; - } else if (str == "s") { - InspectorManager::get()->check_value_skip_all_ = true; - } - } - } - } - /*! 
* \brief build the lambda function, aka the checker, given its type * \tparam DType the data type @@ -443,7 +368,6 @@ class TensorInspector { }; case NaNChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value) { return [] (DType x) { return x != x; @@ -455,7 +379,6 @@ class TensorInspector { break; case InfChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value) { return [] (DType x) { return x == (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0; @@ -467,7 +390,6 @@ class TensorInspector { break; case PositiveInfChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value) { return [] (DType x) { return x == (DType)1.0 / (DType)0.0; @@ -479,7 +401,6 @@ class TensorInspector { break; case NegativeInfChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value) { return [] (DType x) { return x == -(DType)1.0 / (DType)0.0; @@ -491,7 +412,6 @@ class TensorInspector { break; case FiniteChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value) { return [] (DType x) { return x != (DType)1.0 / (DType)0.0 && x != -(DType)1.0 / (DType)0.0; @@ -503,7 +423,6 @@ class TensorInspector { break; case NormalChecker: if (std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value) { return [] (DType x) { return x != (DType)1.0 / (DType)0.0 && x != -(DType)1.0 / (DType)0.0 && @@ -522,6 +441,155 @@ class TensorInspector { return [] (DType x) {return false;}; } + /*! + * \brief calculate the coordinate of a value in the tensor, given its index + * \param idx the index of the value in the tensor + */ + inline std::vector index_to_coordinates(size_t idx) { + int dimension = tb_.ndim(); + std::vector ret; + for (int i = dimension-1; i >= 0; i--) { + ret.push_back(idx % tb_.shape_[i]); + idx /= tb_.shape_[i]; + } + std::reverse(ret.begin(), ret.end()); + return ret; + } + + /*! 
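A note on the float-only checkers being reorganized above: they rely on IEEE-754 identities (NaN is the only value that compares unequal to itself, and an infinity compares equal to a fresh 1.0/0.0) rather than on <cmath> helpers, presumably so the same lambda body also works for mshadow::half::half_t. A small standalone sketch of the idea; the variable names and test values are illustrative and not part of the patch, and std::isnan/std::isinf appear only as a cross-check:

#include <cmath>
#include <iostream>
#include <limits>

int main() {
  const float nan_v = std::numeric_limits<float>::quiet_NaN();
  const float inf_v = std::numeric_limits<float>::infinity();

  // NaN is the only float that is not equal to itself.
  auto nan_checker = [](float x) { return x != x; };
  // +/- infinity compare equal to an infinity built by dividing by zero.
  auto inf_checker = [](float x) { return x == 1.0f / 0.0f || x == -1.0f / 0.0f; };

  std::cout << nan_checker(nan_v) << " " << std::isnan(nan_v) << std::endl;  // prints: 1 1
  std::cout << inf_checker(inf_v) << " " << std::isinf(inf_v) << std::endl;  // prints: 1 1
  std::cout << nan_checker(1.0f) << " " << inf_checker(1.0f) << std::endl;   // prints: 0 0
  return 0;
}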
+ * \brief check/validate the values within the tensor, find the coordinates + * where the lambda evaluates to true + * \tparam DType the data type + * \param ret a vector of coordinates which itself is a vector of int; calculated here + * \param checker the lambda function to check each value of within the tensor + * \param interactive wherether to allow the user to interactively check the coordinates + * \param tag the name given to this call + */ + template + inline void check_value_helper(std::vector>* ret, + const std::function& checker,bool interactive, std::string tag) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) + .check_value_helper(ret, checker, interactive, tag); + } +#endif // MXNET_USE_CUDA + int count = 0; + std::stringstream ss; + ss << "["; + bool first_pass = true; + for (size_t i = 0; i < tb_.shape_.Size(); i++) { + if (checker(tb_.dptr()[i])) { + count += 1; + if (!first_pass) { + ss << ", "; + } + first_pass = false; + std::vector coords = index_to_coordinates(i); + ss << "(" << coords[0]; + for (size_t i = 1; i < coords.size(); i++) { + ss << ", " << coords[i]; + } + ss << ")"; + ret->push_back(coords); + } + } + ss << "]" << std::endl; + if (interactive) { + std::lock_guard lock(InspectorManager::get()->mutex_); + InspectorManager::get()->check_value_tag_counter_[tag] += 1; + while (!InspectorManager::get()->check_value_skip_all_) { + std::cout << "----------Value Check----------" << std::endl; + if (tag != "") { + std::cout << "Tag: " << tag << " Visit: " << + InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; + } + std::cout << count << " value(s) found. \"p\" to print the coordinates," << + " \"b\" to break, \"s\" to skip all: "; + std::string str; + std::cin >> str; + if (str == "b") { + break; + } else if (str == "p") { + std::cout << ss.str() << std::endl; + } else if (str == "s") { + InspectorManager::get()->check_value_skip_all_ = true; + } + } + } + } + + /*! + * \brief infer the python type, given the c++ type + * \tparam ti the type info + */ + inline char infer_type(const std::type_info& ti) { + if(ti == typeid(float)) return 'f'; + else if(ti == typeid(double)) return 'f'; + else if(ti == typeid(mshadow::half::half_t) ) return 'f'; + else if(ti == typeid(uint8_t)) return 'u'; + else if(ti == typeid(int32_t)) return 'i'; + else if(ti == typeid(int64_t)) return 'i'; + else return '?'; + } + + /*! + * \brief check if the host machine is big or small endian + */ + inline char endian_test() { + int x = 1; + return (((char*)&x)[0]) ? '<' : '>'; + } + + /*! 
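check_value_helper above reports every hit as a coordinate tuple computed from the flat index by the modulo/divide loop of index_to_coordinates. A self-contained sketch of that row-major arithmetic; the 2x3x4 shape and the function name are made up for illustration:

#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

// Convert a flat row-major index into per-dimension coordinates,
// mirroring the modulo/divide loop used by index_to_coordinates().
std::vector<int> index_to_coords(size_t idx, const std::vector<int>& shape) {
  std::vector<int> coords;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    coords.push_back(static_cast<int>(idx % shape[i]));
    idx /= shape[i];
  }
  std::reverse(coords.begin(), coords.end());
  return coords;
}

int main() {
  const std::vector<int> shape = {2, 3, 4};  // a toy 2x3x4 tensor
  // Flat index 17 = 1*(3*4) + 1*4 + 1, so it maps to coordinates (1, 1, 1).
  std::vector<int> c = index_to_coords(17, shape);
  assert((c == std::vector<int>{1, 1, 1}));
  std::cout << "(" << c[0] << ", " << c[1] << ", " << c[2] << ")" << std::endl;
  return 0;
}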
+ * \brief dump the value of the tensor to a file with name "tag_[visit count].npy" in npy format + * \tparam DType the data type + * \param tag the name given to this call + */ + template + inline void dump_value_helper(const std::string& tag) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) + .dump_value_helper(tag); + return; + } +#endif // MXNET_USE_CUDA + std::string dict; + dict += "{'descr':'"; + dict += endian_test(); + dict += infer_type(typeid(DType)); + dict += std::to_string(sizeof(DType)); + dict += "','fortran_order':False,'shape':("; + dict += std::to_string(tb_.shape_[0]); + for (int i = 1; i < tb_.ndim(); i++) { + dict += ','; + dict += std::to_string(tb_.shape_[i]); + } + if (tb_.ndim() == 1) { + dict += ","; + } + dict += ")} "; + int padding_size = 64 - ((10 + dict.size()) % 64); + dict += std::string(padding_size, ' '); + dict[dict.size()-1] = '\n'; + std::string header; + header += (char)0x93; + header += "NUMPY"; + header += (char)0x01; + header += (char)0x00; + header += (char)((uint16_t)dict.size() & 0x00ff); + header += (char)(((uint16_t)dict.size() >> 8) & 0x00ff); + header += dict; + InspectorManager::get()->dump_value_tag_counter_[tag] += 1; + int visit = InspectorManager::get()->dump_value_tag_counter_[tag]; + std::ofstream file (tag + "_" + std::to_string(visit) + ".npy", + std::ios::out | std::ios::binary); + file.write(header.c_str(), header.size()); + file.write((char*)tb_.dptr(), sizeof(DType) * tb_.shape_.Size()); + file.close(); + } + /* !\brief the tensor blob */ const TBlob tb_; /* !\brief the run context of the tensor */ @@ -594,8 +662,8 @@ class TensorInspector { * \param tag the name given to this call */ template - inline std::vector> check_value(const ValueChecker& checker, bool interactive = false, - std::string tag = "") { + inline std::vector> check_value(const ValueChecker& checker, + bool interactive = false, std::string tag = "") { std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { check_value_helper(&ret, checker, ret, interactive, tag); @@ -618,6 +686,17 @@ class TensorInspector { }); return ret; } + + /*! + * \brief dump the value of the tensor to a file with name "tag_[visit count].npy" in npy format + * \param tag the name given to this call + */ + inline void dump_value(std::string tag) { + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { + dump_value_helper(tag); + }); + } + }; } // namespace mxnet From b8c01ac3bee67c7465cc485d28b47542b6b6aa62 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 9 Jul 2019 15:40:33 -0700 Subject: [PATCH 10/29] sanity fix --- src/common/tensor_inspector.h | 36 +++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index f435f8fcad67..6d78529c1360 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -20,7 +20,7 @@ /*! 
* Copyright (c) 2019 by Contributors * \file tensor_inspector.h - * \brief utility to inspector tensor objects + * \brief utility to inspect tensor objects * \author Zhaoqi Zhu */ @@ -467,7 +467,7 @@ class TensorInspector { */ template inline void check_value_helper(std::vector>* ret, - const std::function& checker,bool interactive, std::string tag) { + const std::function& checker, bool interactive, std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -524,13 +524,14 @@ class TensorInspector { * \tparam ti the type info */ inline char infer_type(const std::type_info& ti) { - if(ti == typeid(float)) return 'f'; - else if(ti == typeid(double)) return 'f'; - else if(ti == typeid(mshadow::half::half_t) ) return 'f'; - else if(ti == typeid(uint8_t)) return 'u'; - else if(ti == typeid(int32_t)) return 'i'; - else if(ti == typeid(int64_t)) return 'i'; - else return '?'; + if (ti == typeid(float)) return 'f'; + else if (ti == typeid(double)) return 'f'; + else if (ti == typeid(mshadow::half::half_t) ) return 'f'; + else if (ti == typeid(uint8_t)) return 'u'; + else if (ti == typeid(int32_t)) return 'i'; + else if (ti == typeid(int64_t)) return 'i'; + else + return '?'; } /*! @@ -538,7 +539,7 @@ class TensorInspector { */ inline char endian_test() { int x = 1; - return (((char*)&x)[0]) ? '<' : '>'; + return (reinterpret_cast(&x)[0]) ? '<' : '>'; } /*! @@ -574,19 +575,19 @@ class TensorInspector { dict += std::string(padding_size, ' '); dict[dict.size()-1] = '\n'; std::string header; - header += (char)0x93; + header += static_cast(0x93); header += "NUMPY"; - header += (char)0x01; - header += (char)0x00; - header += (char)((uint16_t)dict.size() & 0x00ff); - header += (char)(((uint16_t)dict.size() >> 8) & 0x00ff); + header += static_cast(0x01); + header += static_cast(0x00); + header += static_cast((uint16_t)dict.size() & 0x00ff); + header += static_cast(((uint16_t)dict.size() >> 8) & 0x00ff); header += dict; InspectorManager::get()->dump_value_tag_counter_[tag] += 1; int visit = InspectorManager::get()->dump_value_tag_counter_[tag]; - std::ofstream file (tag + "_" + std::to_string(visit) + ".npy", + std::ofstream file(tag + "_" + std::to_string(visit) + ".npy", std::ios::out | std::ios::binary); file.write(header.c_str(), header.size()); - file.write((char*)tb_.dptr(), sizeof(DType) * tb_.shape_.Size()); + file.write(reinterpret_cast(tb_.dptr()), sizeof(DType) * tb_.shape_.Size()); file.close(); } @@ -696,7 +697,6 @@ class TensorInspector { dump_value_helper(tag); }); } - }; } // namespace mxnet From ec8b44ff2461b703416a1f0d701d0010ba7e03c9 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 9 Jul 2019 15:44:36 -0700 Subject: [PATCH 11/29] sanity fix --- src/common/tensor_inspector.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 6d78529c1360..c7f4d3dfa0e5 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -530,7 +530,7 @@ class TensorInspector { else if (ti == typeid(uint8_t)) return 'u'; else if (ti == typeid(int32_t)) return 'i'; else if (ti == typeid(int64_t)) return 'i'; - else + else return '?'; } From eef4884f28dbbceae9b296687e176b7a4d8237c0 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 9 Jul 2019 16:27:20 -0700 Subject: [PATCH 12/29] add dumping support to interactive print --- src/common/tensor_inspector.h | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) 
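For reference, the preamble that dump_value_helper assembles (and that the casts above tidy up) is the standard NPY format 1.0 layout: the magic string \x93NUMPY, version bytes 1 and 0, a little-endian uint16 header length, then an ASCII dict padded with spaces to a 64-byte boundary and terminated by a newline. A rough standalone sketch that writes such a preamble for a hypothetical little-endian float32 tensor of shape (3, 2); the shape, data values, and file name here are made up for illustration:

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

int main() {
  // Dict describing a little-endian float32 array of shape (3, 2).
  std::string dict = "{'descr':'<f4','fortran_order':False,'shape':(3,2)} ";
  // Magic (6) + version (2) + header-length field (2) = 10 bytes precede the dict;
  // pad with spaces so the whole preamble is a multiple of 64, ending in '\n'.
  const size_t padding = 64 - ((10 + dict.size()) % 64);
  dict += std::string(padding, ' ');
  dict.back() = '\n';

  std::string header;
  header += static_cast<char>(0x93);
  header += "NUMPY";
  header += static_cast<char>(0x01);                        // major version
  header += static_cast<char>(0x00);                        // minor version
  header += static_cast<char>(dict.size() & 0xff);          // uint16 dict length,
  header += static_cast<char>((dict.size() >> 8) & 0xff);   // little-endian
  header += dict;

  std::cout << "preamble size: " << header.size() << " bytes" << std::endl;  // 64

  // Followed by 3 * 2 float32 values, this is a complete .npy file.
  std::ofstream file("example_3x2.npy", std::ios::out | std::ios::binary);
  file.write(header.c_str(), header.size());
  const float data[6] = {0, 1, 2, 3, 4, 5};
  file.write(reinterpret_cast<const char*>(data), sizeof(data));
  return 0;
}

A file written this way should open with numpy.load, which is also a convenient cross-check for the header the patch emits.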
diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index c7f4d3dfa0e5..bf9ea49ce85a 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -322,7 +322,7 @@ class TensorInspector { } tensor_info_to_string(&std::cout); std::cout << "Please specify the position, seperated by \",\"" << std::endl - << "\"e\" for the entire tensor, \"b\" to break, \"s\" to skip all: "; + << "\"e\" for the entire tensor, \"d\" to dump value to file, \"b\" to break, \"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -333,6 +333,18 @@ class TensorInspector { } else if (str == "s") { InspectorManager::get()->interactive_print_skip_all_ = true; break; + } else if (str == "d") { + while (true) { + std::cout << "Please enter a tag: "; + std::cin >> str; + if (str.find(' ') != std::string::npos) { + std::cout << "Invalid input. "; + continue; + } + dump_value_helper(str); + break; + } + continue; } std::vector pos; if (parse_position(&pos, str)) { @@ -504,8 +516,8 @@ class TensorInspector { std::cout << "Tag: " << tag << " Visit: " << InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; } - std::cout << count << " value(s) found. \"p\" to print the coordinates," << - " \"b\" to break, \"s\" to skip all: "; + std::cout << count << " value(s) found." << std::endl; + std::cout << "\"p\" to print the coordinates, \"b\" to break, \"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { From 108d74db9a23702a5eded356cb80c213ddcfa5ae Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 9 Jul 2019 16:40:49 -0700 Subject: [PATCH 13/29] sanity fix --- src/common/tensor_inspector.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index bf9ea49ce85a..02ea34495d68 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -321,8 +321,9 @@ class TensorInspector { InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl; } tensor_info_to_string(&std::cout); - std::cout << "Please specify the position, seperated by \",\"" << std::endl - << "\"e\" for the entire tensor, \"d\" to dump value to file, \"b\" to break, \"s\" to skip all: "; + std::cout << "Please specify the position, seperated by \",\"" << std::endl; + std::cout << "\"e\" for the entire tensor, \"d\" to dump value to file," << + " \"b\" to break, \"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { From 95c7c4479158bbc5ef50471d14fb499a849778eb Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 11 Jul 2019 10:59:53 -0700 Subject: [PATCH 14/29] Re-Trigger build From a41720d86aa69c3d7147f7ddfea2ea127079bb54 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 11 Jul 2019 14:06:34 -0700 Subject: [PATCH 15/29] add namespace before Tensor --- src/common/tensor_inspector.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 02ea34495d68..9966cea6a723 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -38,7 +38,7 @@ namespace mxnet { /*! - * \brief This singleton struct mediates individual TensorInspector objects + * \brief this singleton struct mediates individual TensorInspector objects * so that we can control the global behavior from each of them */ struct InspectorManager { @@ -611,7 +611,7 @@ class TensorInspector { public: /*! 
- * \brief Construct from Tensor object + * \brief construct from Tensor object * \tparam Device the device the tensor resides in * \tparam dimension the dimension of the tensor * \tparam DType the data type @@ -620,11 +620,11 @@ class TensorInspector { */ template - TensorInspector(const Tensor& ts, const RunContext& ctx): + TensorInspector(const mshadow::Tensor& ts, const RunContext& ctx): tb_(ts), ctx_(ctx) {} /*! - * \brief Construct from TBlob object + * \brief construct from TBlob object * \param tb the source tblob object * \param ctx the run context of the tensor */ @@ -632,7 +632,7 @@ class TensorInspector { tb_(tb), ctx_(ctx) {} /*! - * \brief Construct from NDArray object. Currently this only works with kDefaultStorage + * \brief construct from NDArray object. Currently this only works with kDefaultStorage * \param arr the source ndarray object * \param ctx the run context of the tensor */ @@ -658,7 +658,7 @@ class TensorInspector { } /*! - * \brief interactive print the tensor value + * \brief interactively print the tensor value * \param tag the name given to this call */ inline void interactive_print(std::string tag = "") { From e2e16f4217cacd25e2bbee5d0a51fa5c44193f06 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 11 Jul 2019 14:48:47 -0700 Subject: [PATCH 16/29] add more checker types --- src/common/tensor_inspector.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 9966cea6a723..2dc95176408a 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -81,6 +81,7 @@ enum CheckerType { // will always return false if DType is not a float type FiniteChecker, // check if is finite, will always return false if DType is not a float type NormalChecker, // check if is neither infinity nor NaN + AbnormalChecker, // chekc if is infinity or nan }; /** @@ -446,6 +447,18 @@ class TensorInspector { "Lambda will always return false."; } break; + case AbnormalChecker: + if (std::is_same::value || std::is_same::value || + std::is_same::value) { + return [] (DType x) { + return x += (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0 && + x != x; + }; + } else { + LOG(WARNING) << "AbnormalChecker only applies to float types. 
" << + "Lambda will always return false."; + } + break; default: return [] (DType x) { return false; From 36c8d9f77599789ea829aca3b99817be33f3d7a7 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 11 Jul 2019 14:51:51 -0700 Subject: [PATCH 17/29] bug fix --- src/common/tensor_inspector.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 2dc95176408a..3c6762320f97 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -451,7 +451,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x += (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0 && + return x == (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0 || x != x; }; } else { From e3356d5b538b8a72cfa17cca88340364b812b412 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Thu, 11 Jul 2019 17:04:39 -0700 Subject: [PATCH 18/29] fix comments --- src/common/tensor_inspector.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 3c6762320f97..35dfab2a7e5e 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -81,7 +81,7 @@ enum CheckerType { // will always return false if DType is not a float type FiniteChecker, // check if is finite, will always return false if DType is not a float type NormalChecker, // check if is neither infinity nor NaN - AbnormalChecker, // chekc if is infinity or nan + AbnormalChecker, // chekck if is infinity or nan }; /** @@ -484,7 +484,7 @@ class TensorInspector { /*! * \brief check/validate the values within the tensor, find the coordinates - * where the lambda evaluates to true + * where the value checker evaluates to true * \tparam DType the data type * \param ret a vector of coordinates which itself is a vector of int; calculated here * \param checker the lambda function to check each value of within the tensor @@ -569,7 +569,8 @@ class TensorInspector { } /*! - * \brief dump the value of the tensor to a file with name "tag_[visit count].npy" in npy format + * \brief dump the value of the tensor to a file with name "[tag]_[visit count].npy" in npy format + * the dump file follows npy 1.0 stantand * \tparam DType the data type * \param tag the name given to this call */ @@ -615,6 +616,8 @@ class TensorInspector { file.write(header.c_str(), header.size()); file.write(reinterpret_cast(tb_.dptr()), sizeof(DType) * tb_.shape_.Size()); file.close(); + std::cout << "Tensor dumped to file: " << + tag + "_" + std::to_string(visit) + ".npy" << std::endl; } /* !\brief the tensor blob */ @@ -682,7 +685,7 @@ class TensorInspector { /*! 
* \brief check/validate the values within the tensor, return the coordinates - * where the lambda evaluates to true + * where the value checker evaluates to true * \tparam ValueChecker the type of the lambda * \param checker the lambda function to check each value of within the tensor * \param interactive wherether to allow the user to interactively check the coordinates From 202ddd41f69b38c3a392f3a4e7c89634675dc4f4 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 12 Jul 2019 11:59:28 -0700 Subject: [PATCH 19/29] change int to size_t --- src/common/tensor_inspector.h | 40 +++++++++++++++++------------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 35dfab2a7e5e..9ed33849bd03 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -114,7 +114,7 @@ class TensorInspector { int dimension = tb_.ndim(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << tb_.shape_[0]; - for (int i = 1; i < dimension; i++) { + for (int i = 1; i < dimension; ++i) { *os << 'x' << tb_.shape_[i]; } *os << ">" << std::endl; @@ -132,7 +132,7 @@ class TensorInspector { int dimension = shape.size(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << shape[0]; - for (int i = 1; i < dimension; i++) { + for (int i = 1; i < dimension; ++i) { *os << 'x' << shape[i]; } *os << ">" << std::endl; @@ -154,15 +154,15 @@ class TensorInspector { } #endif // MXNET_USE_CUDA int dimension = tb_.ndim(); - std::vector multiples; - int multiple = 1; - for (int i = dimension-1; i >= 0; i--) { + std::vector multiples; + size_t multiple = 1; + for (int i = dimension-1; i >= 0; --i) { multiple *= tb_.shape_[i]; multiples.push_back(multiple); } *os << std::string(dimension, '['); *os << tb_.dptr()[0]; - for (size_t i = 1; i < tb_.shape_.Size(); i++) { + for (size_t i = 1; i < tb_.shape_.Size(); ++i) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); @@ -221,16 +221,16 @@ class TensorInspector { return; } int dimension = sub_shape.size(); - std::vector multiples; + std::vector multiples; size_t multiple = 1; - for (int i = dimension-1; i >= 0; i--) { + for (int i = dimension-1; i >= 0; --i) { multiple *= sub_shape[i]; multiples.push_back(multiple); } std::stringstream ss; *os << std::string(dimension, '['); *os << dptr[0]; - for (size_t i = 1; i < multiple; i++) { + for (size_t i = 1; i < multiple; ++i) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); @@ -258,14 +258,14 @@ class TensorInspector { int dimension = tb_.ndim(); int sub_dim = dimension - pos.size(); sub_shape->resize(sub_dim); - int multiple = 1; - for (int i = pos.size(), j = 0; i < dimension; i++, j++) { + size_t multiple = 1; + for (int i = pos.size(), j = 0; i < dimension; ++i, ++j) { (*sub_shape)[j] = tb_.shape_[i]; multiple *= tb_.shape_[i]; } - int sum = 0; - int m = 1; - for (int i = pos.size()-1; i >= 0; i--) { + size_t sum = 0; + size_t m = 1; + for (int i = pos.size()-1; i >= 0; --i) { sum += pos[i] * m; m *= tb_.shape_[i]; } @@ -291,7 +291,7 @@ class TensorInspector { if (pos->size() > dimension) { return false; } - for (unsigned i = 0; i < pos->size(); i++) { + for (unsigned i = 0; i < pos->size(); ++i) { if ((*pos)[i] > (tb_.shape_[i]-1)) { return false; } @@ -474,7 +474,7 @@ class TensorInspector { inline std::vector index_to_coordinates(size_t idx) { int dimension = tb_.ndim(); std::vector ret; - for (int i = dimension-1; i >= 0; i--) { + for (int i = dimension-1; i >= 0; --i) { 
ret.push_back(idx % tb_.shape_[i]); idx /= tb_.shape_[i]; } @@ -500,11 +500,11 @@ class TensorInspector { .check_value_helper(ret, checker, interactive, tag); } #endif // MXNET_USE_CUDA - int count = 0; + size_t count = 0; std::stringstream ss; ss << "["; bool first_pass = true; - for (size_t i = 0; i < tb_.shape_.Size(); i++) { + for (size_t i = 0; i < tb_.shape_.Size(); ++i) { if (checker(tb_.dptr()[i])) { count += 1; if (!first_pass) { @@ -513,7 +513,7 @@ class TensorInspector { first_pass = false; std::vector coords = index_to_coordinates(i); ss << "(" << coords[0]; - for (size_t i = 1; i < coords.size(); i++) { + for (unsigned int i = 1; i < coords.size(); ++i) { ss << ", " << coords[i]; } ss << ")"; @@ -590,7 +590,7 @@ class TensorInspector { dict += std::to_string(sizeof(DType)); dict += "','fortran_order':False,'shape':("; dict += std::to_string(tb_.shape_[0]); - for (int i = 1; i < tb_.ndim(); i++) { + for (int i = 1; i < tb_.ndim(); ++i) { dict += ','; dict += std::to_string(tb_.shape_[i]); } From ace747f9606b779f7bcf454491289e4b21442343 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 12 Jul 2019 14:10:40 -0700 Subject: [PATCH 20/29] miscellaneous --- src/common/tensor_inspector.h | 136 +++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 53 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 9ed33849bd03..5d5af11c7779 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -63,7 +63,7 @@ struct InspectorManager { /* !\brief visit count for check value tags */ std::unordered_map check_value_tag_counter_; /* !\brief visit count for dump value tags */ - std::unordered_map dump_value_tag_counter_; + std::unordered_map dump_to_file_tag_counter_; }; /*! @@ -109,9 +109,9 @@ class TensorInspector { * \tparam StreamType the type of the stream object * \param os stream object to output to */ - template + template inline void tensor_info_to_string(StreamType* os) { - int dimension = tb_.ndim(); + const int dimension = tb_.ndim(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << tb_.shape_[0]; for (int i = 1; i < dimension; ++i) { @@ -127,9 +127,9 @@ class TensorInspector { * \param os stream object to output to * \param shape the shape of the tensor */ - template + template inline void tensor_info_to_string(StreamType* os, const std::vector& shape) { - int dimension = shape.size(); + const int dimension = shape.size(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << shape[0]; for (int i = 1; i < dimension; ++i) { @@ -144,7 +144,7 @@ class TensorInspector { * \tparam StreamType the type of the stream object * \param os stream object to output to */ - template + template inline void to_string_helper(StreamType* os) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -153,7 +153,7 @@ class TensorInspector { return; } #endif // MXNET_USE_CUDA - int dimension = tb_.ndim(); + const int dimension = tb_.ndim(); std::vector multiples; size_t multiple = 1; for (int i = dimension-1; i >= 0; --i) { @@ -175,7 +175,7 @@ class TensorInspector { *os << tb_.dptr()[i]; } *os << std::string(dimension, ']') << std::endl; - tensor_info_to_string(os); + tensor_info_to_string(os); } /*! 
@@ -185,7 +185,7 @@ class TensorInspector { * \param os stream object to output to * \param dptr the data pointer */ - template + template inline void to_string_helper(StreamType* os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -206,7 +206,7 @@ class TensorInspector { * \param sub_shape the sub-shape of the desired part of the tensor * \param offset the position of the first value of the desired part of the tensor */ - template + template inline void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -220,7 +220,7 @@ class TensorInspector { to_string_helper(os, dptr); return; } - int dimension = sub_shape.size(); + const int dimension = sub_shape.size(); std::vector multiples; size_t multiple = 1; for (int i = dimension-1; i >= 0; --i) { @@ -243,7 +243,7 @@ class TensorInspector { *os << dptr[i]; } *os << std::string(dimension, ']') << std::endl; - tensor_info_to_string(os, sub_shape); + tensor_info_to_string(os, sub_shape); } /*! @@ -255,7 +255,7 @@ class TensorInspector { */ inline void print_locator(const std::vector& pos, std::vector* sub_shape, size_t* offset) { - int dimension = tb_.ndim(); + const int dimension = tb_.ndim(); int sub_dim = dimension - pos.size(); sub_shape->resize(sub_dim); size_t multiple = 1; @@ -279,7 +279,7 @@ class TensorInspector { * \param str the string that represents the coordinate */ inline bool parse_position(std::vector* pos, const std::string& str) { - int dimension = tb_.ndim(); + const int dimension = tb_.ndim(); std::stringstream ss(str); int i; while (ss >> i) { @@ -304,7 +304,7 @@ class TensorInspector { * \tparam DType the data type * \param tag the name given to this call */ - template + template inline void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -321,10 +321,12 @@ class TensorInspector { std::cout << "Tag: " << tag << " Visit: " << InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl; } - tensor_info_to_string(&std::cout); + tensor_info_to_string(&std::cout); std::cout << "Please specify the position, seperated by \",\"" << std::endl; - std::cout << "\"e\" for the entire tensor, \"d\" to dump value to file," << - " \"b\" to break, \"s\" to skip all: "; + std::cout << "\"e\" for the entire tensor, " << + "\"d\" to dump value to file, " << + "\"b\" to break, " << + "\"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -343,7 +345,7 @@ class TensorInspector { std::cout << "Invalid input. 
"; continue; } - dump_value_helper(str); + dump_to_file_helper(str); break; } continue; @@ -365,8 +367,8 @@ class TensorInspector { * \tparam DType the data type * \param ct the type of the checker */ - template - inline std::function build_checker(CheckerType ct) { + template + inline std::function get_checker(CheckerType ct) { switch (ct) { case NegativeChecker: return [] (DType x) { @@ -472,7 +474,7 @@ class TensorInspector { * \param idx the index of the value in the tensor */ inline std::vector index_to_coordinates(size_t idx) { - int dimension = tb_.ndim(); + const int dimension = tb_.ndim(); std::vector ret; for (int i = dimension-1; i >= 0; --i) { ret.push_back(idx % tb_.shape_[i]); @@ -491,7 +493,7 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - template + template inline void check_value_helper(std::vector>* ret, const std::function& checker, bool interactive, std::string tag) { #if MXNET_USE_CUDA @@ -531,7 +533,9 @@ class TensorInspector { InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; } std::cout << count << " value(s) found." << std::endl; - std::cout << "\"p\" to print the coordinates, \"b\" to break, \"s\" to skip all: "; + std::cout << "\"p\" to print the coordinates, " << + "\"b\" to break, " << + "\"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -569,20 +573,12 @@ class TensorInspector { } /*! - * \brief dump the value of the tensor to a file with name "[tag]_[visit count].npy" in npy format - * the dump file follows npy 1.0 stantand + * \brief generate the header following npy 1.0 format * \tparam DType the data type - * \param tag the name given to this call */ - template - inline void dump_value_helper(const std::string& tag) { -#if MXNET_USE_CUDA - if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) - .dump_value_helper(tag); - return; - } -#endif // MXNET_USE_CUDA + template + inline std::string get_header() { + const int dimension = tb_.ndim(); std::string dict; dict += "{'descr':'"; dict += endian_test(); @@ -590,17 +586,17 @@ class TensorInspector { dict += std::to_string(sizeof(DType)); dict += "','fortran_order':False,'shape':("; dict += std::to_string(tb_.shape_[0]); - for (int i = 1; i < tb_.ndim(); ++i) { + for (int i = 1; i < dimension; ++i) { dict += ','; dict += std::to_string(tb_.shape_[i]); } - if (tb_.ndim() == 1) { + if (dimension == 1) { dict += ","; } dict += ")} "; int padding_size = 64 - ((10 + dict.size()) % 64); dict += std::string(padding_size, ' '); - dict[dict.size()-1] = '\n'; + dict.back() = '\n'; std::string header; header += static_cast(0x93); header += "NUMPY"; @@ -609,15 +605,50 @@ class TensorInspector { header += static_cast((uint16_t)dict.size() & 0x00ff); header += static_cast(((uint16_t)dict.size() >> 8) & 0x00ff); header += dict; - InspectorManager::get()->dump_value_tag_counter_[tag] += 1; - int visit = InspectorManager::get()->dump_value_tag_counter_[tag]; - std::ofstream file(tag + "_" + std::to_string(visit) + ".npy", - std::ios::out | std::ios::binary); - file.write(header.c_str(), header.size()); - file.write(reinterpret_cast(tb_.dptr()), sizeof(DType) * tb_.shape_.Size()); - file.close(); - std::cout << "Tensor dumped to file: " << - tag + "_" + std::to_string(visit) + ".npy" << std::endl; + return header; + } + + /*! 
+ * \brief write the header and the date to an npy file + * \tparam DType the data type + * \param header the header of the file + * \param filename the file name + */ + template + void write_npy(const std::string& header, const std::string& filename) { + std::ofstream file; + file.exceptions(std::ofstream::failbit | std::ofstream::badbit); + try { + file.open(filename, std::ios::out | std::ios::binary); + file.write(header.c_str(), header.size()); + file.write(reinterpret_cast(tb_.dptr()), sizeof(DType) * tb_.shape_.Size()); + file.close(); + std::cout << "Tensor dumped to file: " << filename << std::endl; + } catch (std::ofstream::failure e) { + std::cerr << "Exception opening/writing/closing file " << filename << std::endl; + } + } + + /*! + * \brief dump the value of the tensor to a file with name "[tag]_[visit count].npy" in npy format + * the dump file follows npy 1.0 stantand + * \tparam DType the data type + * \param tag the name given to this call + */ + template + inline void dump_to_file_helper(const std::string& tag) { +#if MXNET_USE_CUDA + if (tb_.dev_mask() == gpu::kDevMask) { + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) + .dump_to_file_helper(tag); + return; + } +#endif // MXNET_USE_CUDA + std::string header = get_header(); + InspectorManager::get()->dump_to_file_tag_counter_[tag] += 1; + const int visit = InspectorManager::get()->dump_to_file_tag_counter_[tag]; + std::string filename = tag + "_" + std::to_string(visit) + ".npy"; + write_npy(header, filename); } /* !\brief the tensor blob */ @@ -634,8 +665,7 @@ class TensorInspector { * \param ts the source tensor object * \param ctx the run context of the tensor */ - template + template TensorInspector(const mshadow::Tensor& ts, const RunContext& ctx): tb_(ts), ctx_(ctx) {} @@ -712,7 +742,7 @@ class TensorInspector { std::string tag = "") { std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - check_value_helper(&ret, build_checker(ct), interactive, tag); + check_value_helper(&ret, get_checker(ct), interactive, tag); }); return ret; } @@ -721,9 +751,9 @@ class TensorInspector { * \brief dump the value of the tensor to a file with name "tag_[visit count].npy" in npy format * \param tag the name given to this call */ - inline void dump_value(std::string tag) { + inline void dump_to_file(std::string tag) { MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - dump_value_helper(tag); + dump_to_file_helper(tag); }); } }; From 09a74d04fc2dc86057b659066781f4db7677d77b Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 12 Jul 2019 14:38:57 -0700 Subject: [PATCH 21/29] sanity fix --- src/common/tensor_inspector.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 5d5af11c7779..671c50a5b298 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -626,7 +626,7 @@ class TensorInspector { std::cout << "Tensor dumped to file: " << filename << std::endl; } catch (std::ofstream::failure e) { std::cerr << "Exception opening/writing/closing file " << filename << std::endl; - } + } } /*! 
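With write_npy above now surfacing I/O failures through stream exceptions, a quick way to sanity-check a finished dump is to read back only its preamble and print the embedded dict. A minimal sketch that is not part of the patch; the file name assumes a dump produced with tag "conv_out" on its first visit:

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

int main() {
  std::ifstream file("conv_out_1.npy", std::ios::in | std::ios::binary);
  if (!file) {
    std::cerr << "cannot open dump file" << std::endl;
    return 1;
  }
  char magic[6];
  file.read(magic, 6);                        // expected "\x93NUMPY"
  char version[2];
  file.read(version, 2);                      // expected 0x01 0x00 for format 1.0
  unsigned char len_bytes[2];
  file.read(reinterpret_cast<char*>(len_bytes), 2);
  const uint16_t dict_len = static_cast<uint16_t>(len_bytes[0]) |
                            (static_cast<uint16_t>(len_bytes[1]) << 8);
  std::string dict(dict_len, ' ');
  file.read(&dict[0], dict_len);              // "{'descr':..., 'shape':(...)} ...\n"

  const bool magic_ok = magic[0] == static_cast<char>(0x93) &&
                        std::string(magic + 1, 5) == "NUMPY";
  std::cout << "magic ok: " << magic_ok << std::endl;
  std::cout << "header dict: " << dict;       // dtype and shape are recorded in here
  return 0;
}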
From 2434b0a6696445a2b9e92307a20967502a315001 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Fri, 12 Jul 2019 15:28:53 -0700 Subject: [PATCH 22/29] remove unnecessary inlines --- src/common/tensor_inspector.h | 38 +++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 671c50a5b298..f7c523abe62a 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -110,7 +110,7 @@ class TensorInspector { * \param os stream object to output to */ template - inline void tensor_info_to_string(StreamType* os) { + void tensor_info_to_string(StreamType* os) { const int dimension = tb_.ndim(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << tb_.shape_[0]; @@ -128,7 +128,7 @@ class TensorInspector { * \param shape the shape of the tensor */ template - inline void tensor_info_to_string(StreamType* os, const std::vector& shape) { + void tensor_info_to_string(StreamType* os, const std::vector& shape) { const int dimension = shape.size(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << shape[0]; @@ -145,7 +145,7 @@ class TensorInspector { * \param os stream object to output to */ template - inline void to_string_helper(StreamType* os) { + void to_string_helper(StreamType* os) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -186,7 +186,7 @@ class TensorInspector { * \param dptr the data pointer */ template - inline void to_string_helper(StreamType* os, const DType* dptr) { + void to_string_helper(StreamType* os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -207,7 +207,7 @@ class TensorInspector { * \param offset the position of the first value of the desired part of the tensor */ template - inline void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { + void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -253,7 +253,7 @@ class TensorInspector { * \param sub_shape the sub-shape of the desired part of the tensor; calculated here * \param offset the position of the first value of the desired part of the tensor; calculated here */ - inline void print_locator(const std::vector& pos, std::vector* sub_shape, + void print_locator(const std::vector& pos, std::vector* sub_shape, size_t* offset) { const int dimension = tb_.ndim(); int sub_dim = dimension - pos.size(); @@ -278,7 +278,7 @@ class TensorInspector { * \param pos the coordinates of the desired part of the tensor, calculated here * \param str the string that represents the coordinate */ - inline bool parse_position(std::vector* pos, const std::string& str) { + bool parse_position(std::vector* pos, const std::string& str) { const int dimension = tb_.ndim(); std::stringstream ss(str); int i; @@ -305,7 +305,7 @@ class TensorInspector { * \param tag the name given to this call */ template - inline void interactive_print_helper(std::string tag) { + void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -368,7 +368,7 @@ class TensorInspector { * \param ct the type of the checker */ template - inline std::function get_checker(CheckerType 
ct) { + std::function get_checker(CheckerType ct) { switch (ct) { case NegativeChecker: return [] (DType x) { @@ -473,7 +473,7 @@ class TensorInspector { * \brief calculate the coordinate of a value in the tensor, given its index * \param idx the index of the value in the tensor */ - inline std::vector index_to_coordinates(size_t idx) { + std::vector index_to_coordinates(size_t idx) { const int dimension = tb_.ndim(); std::vector ret; for (int i = dimension-1; i >= 0; --i) { @@ -494,7 +494,7 @@ class TensorInspector { * \param tag the name given to this call */ template - inline void check_value_helper(std::vector>* ret, + void check_value_helper(std::vector>* ret, const std::function& checker, bool interactive, std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -577,7 +577,7 @@ class TensorInspector { * \tparam DType the data type */ template - inline std::string get_header() { + std::string get_header() { const int dimension = tb_.ndim(); std::string dict; dict += "{'descr':'"; @@ -636,7 +636,7 @@ class TensorInspector { * \param tag the name given to this call */ template - inline void dump_to_file_helper(const std::string& tag) { + void dump_to_file_helper(const std::string& tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -688,14 +688,14 @@ class TensorInspector { /*! * \brief print the tensor to std::cout */ - inline void print_string() { + void print_string() { std::cout << to_string() << std::endl; } /*! * \brief return a string which contains the values and other info of the tensor */ - inline std::string to_string() { + std::string to_string() { std::stringstream ss; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { to_string_helper(&ss); @@ -707,7 +707,7 @@ class TensorInspector { * \brief interactively print the tensor value * \param tag the name given to this call */ - inline void interactive_print(std::string tag = "") { + void interactive_print(std::string tag = "") { MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { interactive_print_helper(tag); }); @@ -722,7 +722,7 @@ class TensorInspector { * \param tag the name given to this call */ template - inline std::vector> check_value(const ValueChecker& checker, + std::vector> check_value(const ValueChecker& checker, bool interactive = false, std::string tag = "") { std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { @@ -738,7 +738,7 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - inline std::vector> check_value(CheckerType ct, bool interactive = false, + std::vector> check_value(CheckerType ct, bool interactive = false, std::string tag = "") { std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { @@ -751,7 +751,7 @@ class TensorInspector { * \brief dump the value of the tensor to a file with name "tag_[visit count].npy" in npy format * \param tag the name given to this call */ - inline void dump_to_file(std::string tag) { + void dump_to_file(std::string tag) { MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { dump_to_file_helper(tag); }); From 877047debd7e274e8c79a3533ec8181a403c99eb Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 15 Jul 2019 15:02:28 -0700 Subject: [PATCH 23/29] change size_t to index_t --- src/common/tensor_inspector.h | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h 
index f7c523abe62a..561a21c3e1ca 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -139,7 +139,7 @@ class TensorInspector { } /*! - * \brief output the tensor in a structed format + * \brief output the tensor in a structured format * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to @@ -154,15 +154,15 @@ class TensorInspector { } #endif // MXNET_USE_CUDA const int dimension = tb_.ndim(); - std::vector multiples; - size_t multiple = 1; + std::vector multiples; + index_t multiple = 1; for (int i = dimension-1; i >= 0; --i) { multiple *= tb_.shape_[i]; multiples.push_back(multiple); } *os << std::string(dimension, '['); *os << tb_.dptr()[0]; - for (size_t i = 1; i < tb_.shape_.Size(); ++i) { + for (index_t i = 1; i < tb_.shape_.Size(); ++i) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); @@ -179,7 +179,7 @@ class TensorInspector { } /*! - * \brief output the tensor in a structed format + * \brief output the tensor in a structured format * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to @@ -207,7 +207,7 @@ class TensorInspector { * \param offset the position of the first value of the desired part of the tensor */ template - void to_string_helper(StreamType* os, const std::vector& sub_shape, size_t offset) { + void to_string_helper(StreamType* os, const std::vector& sub_shape, index_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -221,8 +221,8 @@ class TensorInspector { return; } const int dimension = sub_shape.size(); - std::vector multiples; - size_t multiple = 1; + std::vector multiples; + index_t multiple = 1; for (int i = dimension-1; i >= 0; --i) { multiple *= sub_shape[i]; multiples.push_back(multiple); @@ -230,7 +230,7 @@ class TensorInspector { std::stringstream ss; *os << std::string(dimension, '['); *os << dptr[0]; - for (size_t i = 1; i < multiple; ++i) { + for (index_t i = 1; i < multiple; ++i) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); @@ -253,18 +253,17 @@ class TensorInspector { * \param sub_shape the sub-shape of the desired part of the tensor; calculated here * \param offset the position of the first value of the desired part of the tensor; calculated here */ - void print_locator(const std::vector& pos, std::vector* sub_shape, - size_t* offset) { + void print_locator(const std::vector& pos, std::vector* sub_shape, index_t* offset) { const int dimension = tb_.ndim(); int sub_dim = dimension - pos.size(); sub_shape->resize(sub_dim); - size_t multiple = 1; + index_t multiple = 1; for (int i = pos.size(), j = 0; i < dimension; ++i, ++j) { (*sub_shape)[j] = tb_.shape_[i]; multiple *= tb_.shape_[i]; } - size_t sum = 0; - size_t m = 1; + index_t sum = 0; + index_t m = 1; for (int i = pos.size()-1; i >= 0; --i) { sum += pos[i] * m; m *= tb_.shape_[i]; @@ -353,7 +352,7 @@ class TensorInspector { std::vector pos; if (parse_position(&pos, str)) { std::vector sub_shape; - size_t offset; + index_t offset; print_locator(pos, &sub_shape, &offset); to_string_helper(&std::cout, sub_shape, offset); } else { @@ -473,7 +472,7 @@ class TensorInspector { * \brief calculate the coordinate of a value in the tensor, given its index * \param idx the index of the value in the tensor */ - std::vector index_to_coordinates(size_t idx) { + std::vector index_to_coordinates(index_t idx) { const int 
dimension = tb_.ndim(); std::vector ret; for (int i = dimension-1; i >= 0; --i) { @@ -502,11 +501,11 @@ class TensorInspector { .check_value_helper(ret, checker, interactive, tag); } #endif // MXNET_USE_CUDA - size_t count = 0; + index_t count = 0; std::stringstream ss; ss << "["; bool first_pass = true; - for (size_t i = 0; i < tb_.shape_.Size(); ++i) { + for (index_t i = 0; i < tb_.shape_.Size(); ++i) { if (checker(tb_.dptr()[i])) { count += 1; if (!first_pass) { From 4f4fca426dcca58e8e903157bdb3d9f3cb515ad1 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Mon, 15 Jul 2019 16:37:21 -0700 Subject: [PATCH 24/29] bug fixes and add print value options in value_check() --- src/common/tensor_inspector.h | 89 +++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 26 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index 561a21c3e1ca..a14cc031e2f4 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -128,7 +128,7 @@ class TensorInspector { * \param shape the shape of the tensor */ template - void tensor_info_to_string(StreamType* os, const std::vector& shape) { + void tensor_info_to_string(StreamType* os, const std::vector& shape) { const int dimension = shape.size(); *os << "<" << typeid(tb_.dptr()[0]).name() << " Tensor "; *os << shape[0]; @@ -156,7 +156,7 @@ class TensorInspector { const int dimension = tb_.ndim(); std::vector multiples; index_t multiple = 1; - for (int i = dimension-1; i >= 0; --i) { + for (int i = dimension - 1; i >= 0; --i) { multiple *= tb_.shape_[i]; multiples.push_back(multiple); } @@ -207,7 +207,7 @@ class TensorInspector { * \param offset the position of the first value of the desired part of the tensor */ template - void to_string_helper(StreamType* os, const std::vector& sub_shape, index_t offset) { + void to_string_helper(StreamType* os, const std::vector& sub_shape, index_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -223,7 +223,7 @@ class TensorInspector { const int dimension = sub_shape.size(); std::vector multiples; index_t multiple = 1; - for (int i = dimension-1; i >= 0; --i) { + for (int i = dimension - 1; i >= 0; --i) { multiple *= sub_shape[i]; multiples.push_back(multiple); } @@ -253,7 +253,8 @@ class TensorInspector { * \param sub_shape the sub-shape of the desired part of the tensor; calculated here * \param offset the position of the first value of the desired part of the tensor; calculated here */ - void print_locator(const std::vector& pos, std::vector* sub_shape, index_t* offset) { + void print_locator(const std::vector& pos, std::vector* sub_shape, + index_t* offset) { const int dimension = tb_.ndim(); int sub_dim = dimension - pos.size(); sub_shape->resize(sub_dim); @@ -264,7 +265,7 @@ class TensorInspector { } index_t sum = 0; index_t m = 1; - for (int i = pos.size()-1; i >= 0; --i) { + for (int i = pos.size() - 1; i >= 0; --i) { sum += pos[i] * m; m *= tb_.shape_[i]; } @@ -277,12 +278,12 @@ class TensorInspector { * \param pos the coordinates of the desired part of the tensor, calculated here * \param str the string that represents the coordinate */ - bool parse_position(std::vector* pos, const std::string& str) { + bool parse_position(std::vector* pos, const std::string& str) { const int dimension = tb_.ndim(); std::stringstream ss(str); - int i; - while (ss >> i) { - pos->push_back(i); + index_t n; + while (ss >> n) { + pos->push_back(n); if (ss.peek() == ',') { ss.ignore(); } @@ 
-291,7 +292,7 @@ class TensorInspector { return false; } for (unsigned i = 0; i < pos->size(); ++i) { - if ((*pos)[i] > (tb_.shape_[i]-1)) { + if ((*pos)[i] > (tb_.shape_[i] - 1)) { return false; } } @@ -349,9 +350,9 @@ class TensorInspector { } continue; } - std::vector pos; + std::vector pos; if (parse_position(&pos, str)) { - std::vector sub_shape; + std::vector sub_shape; index_t offset; print_locator(pos, &sub_shape, &offset); to_string_helper(&std::cout, sub_shape, offset); @@ -472,10 +473,10 @@ class TensorInspector { * \brief calculate the coordinate of a value in the tensor, given its index * \param idx the index of the value in the tensor */ - std::vector index_to_coordinates(index_t idx) { + std::vector index_to_coordinates(index_t idx) { const int dimension = tb_.ndim(); - std::vector ret; - for (int i = dimension-1; i >= 0; --i) { + std::vector ret; + for (int i = dimension - 1; i >= 0; --i) { ret.push_back(idx % tb_.shape_[i]); idx /= tb_.shape_[i]; } @@ -493,7 +494,7 @@ class TensorInspector { * \param tag the name given to this call */ template - void check_value_helper(std::vector>* ret, + void check_value_helper(std::vector>* ret, const std::function& checker, bool interactive, std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -512,7 +513,7 @@ class TensorInspector { ss << ", "; } first_pass = false; - std::vector coords = index_to_coordinates(i); + std::vector coords = index_to_coordinates(i); ss << "(" << coords[0]; for (unsigned int i = 1; i < coords.size(); ++i) { ss << ", " << coords[i]; @@ -527,22 +528,40 @@ class TensorInspector { InspectorManager::get()->check_value_tag_counter_[tag] += 1; while (!InspectorManager::get()->check_value_skip_all_) { std::cout << "----------Value Check----------" << std::endl; + tensor_info_to_string(&std::cout); if (tag != "") { std::cout << "Tag: " << tag << " Visit: " << InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; } std::cout << count << " value(s) found." << std::endl; - std::cout << "\"p\" to print the coordinates, " << + std::cout << "To print a part of the tensor," << + " please specify a position, seperated by \",\"" << std::endl; + std::cout << "\"e\" for the entire tensor, " << + "\"p\" to print the coordinates of the values found, " << "\"b\" to break, " << "\"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { break; + } else if (str == "e") { + to_string_helper(&std::cout); + continue; } else if (str == "p") { std::cout << ss.str() << std::endl; + continue; } else if (str == "s") { InspectorManager::get()->check_value_skip_all_ = true; + break; + } + std::vector pos; + if (parse_position(&pos, str)) { + std::vector sub_shape; + index_t offset; + print_locator(pos, &sub_shape, &offset); + to_string_helper(&std::cout, sub_shape, offset); + } else { + std::cout << "invalid input" << std::endl; } } } @@ -650,6 +669,18 @@ class TensorInspector { write_npy(header, filename); } + /*! 
+ * \brief validate that the shape + */ + inline void validate_shape() { + const int dimension = tb_.ndim(); + CHECK(dimension > 0) << "Tensor Inspector does not support empty tensors " << + "or tensors of unknow shape."; + for (int i = 0; i < dimension; ++i) { + CHECK(tb_.shape_[i] != 0) << "Invalid tensor shape: shape_[" << i << "] is 0"; + } + } + /* !\brief the tensor blob */ const TBlob tb_; /* !\brief the run context of the tensor */ @@ -666,7 +697,9 @@ class TensorInspector { */ template TensorInspector(const mshadow::Tensor& ts, const RunContext& ctx): - tb_(ts), ctx_(ctx) {} + tb_(ts), ctx_(ctx) { + validate_shape(); + } /*! * \brief construct from TBlob object @@ -674,7 +707,9 @@ class TensorInspector { * \param ctx the run context of the tensor */ TensorInspector(const TBlob& tb, const RunContext& ctx): - tb_(tb), ctx_(ctx) {} + tb_(tb), ctx_(ctx) { + validate_shape(); + } /*! * \brief construct from NDArray object. Currently this only works with kDefaultStorage @@ -682,7 +717,9 @@ class TensorInspector { * \param ctx the run context of the tensor */ TensorInspector(const NDArray& arr, const RunContext& ctx): - tb_(arr.data()), ctx_(ctx) {} + tb_(arr.data()), ctx_(ctx) { + validate_shape(); + } /*! * \brief print the tensor to std::cout @@ -721,9 +758,9 @@ class TensorInspector { * \param tag the name given to this call */ template - std::vector> check_value(const ValueChecker& checker, + std::vector> check_value(const ValueChecker& checker, bool interactive = false, std::string tag = "") { - std::vector> ret; + std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { check_value_helper(&ret, checker, ret, interactive, tag); }); @@ -737,9 +774,9 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - std::vector> check_value(CheckerType ct, bool interactive = false, + std::vector> check_value(CheckerType ct, bool interactive = false, std::string tag = "") { - std::vector> ret; + std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { check_value_helper(&ret, get_checker(ct), interactive, tag); }); From 15c46e675d3b4ac2b292fc2275ee060ec82b2e47 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 16 Jul 2019 11:01:48 -0700 Subject: [PATCH 25/29] fix warnings --- src/common/tensor_inspector.h | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index a14cc031e2f4..a46c63563abe 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -162,7 +162,7 @@ class TensorInspector { } *os << std::string(dimension, '['); *os << tb_.dptr()[0]; - for (index_t i = 1; i < tb_.shape_.Size(); ++i) { + for (index_t i = 1; static_cast(i) < tb_.shape_.Size(); ++i) { int n = 0; for (auto divisor : multiples) { n += (i % divisor == 0); @@ -256,7 +256,7 @@ class TensorInspector { void print_locator(const std::vector& pos, std::vector* sub_shape, index_t* offset) { const int dimension = tb_.ndim(); - int sub_dim = dimension - pos.size(); + const int sub_dim = dimension - pos.size(); sub_shape->resize(sub_dim); index_t multiple = 1; for (int i = pos.size(), j = 0; i < dimension; ++i, ++j) { @@ -288,7 +288,7 @@ class TensorInspector { ss.ignore(); } } - if (pos->size() > dimension) { + if (pos->size() > static_cast(dimension)) { return false; } for (unsigned i = 0; i < pos->size(); ++i) { @@ -322,7 +322,8 @@ class TensorInspector { 
InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl; } tensor_info_to_string(&std::cout); - std::cout << "Please specify the position, seperated by \",\"" << std::endl; + std::cout << "To print a part of the tensor," << + " please specify a position, seperated by \",\"" << std::endl; std::cout << "\"e\" for the entire tensor, " << "\"d\" to dump value to file, " << "\"b\" to break, " << @@ -397,7 +398,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x == (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0; + return x == (DType)1.0 / 0.0f || x == -(DType)1.0 / 0.0f; }; } else { LOG(WARNING) << "InfChecker only applies to float types. " << @@ -408,7 +409,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x == (DType)1.0 / (DType)0.0; + return x == (DType)1.0 / 0.0f; }; } else { LOG(WARNING) << "PositiveInfChecker only applies to float types. " << @@ -419,7 +420,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x == -(DType)1.0 / (DType)0.0; + return x == -(DType)1.0 / 0.0f; }; } else { LOG(WARNING) << "NegativeInfChecker only applies to float types. " << @@ -430,7 +431,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x != (DType)1.0 / (DType)0.0 && x != -(DType)1.0 / (DType)0.0; + return x != (DType)1.0 / 0.0f && x != -(DType)1.0 / 0.0f; }; } else { LOG(WARNING) << "FiniteChecker only applies to float types. " << @@ -441,7 +442,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x != (DType)1.0 / (DType)0.0 && x != -(DType)1.0 / (DType)0.0 && + return x != (DType)1.0 / 0.0f && x != -(DType)1.0 / 0.0f && x == x; }; } else { @@ -453,7 +454,7 @@ class TensorInspector { if (std::is_same::value || std::is_same::value || std::is_same::value) { return [] (DType x) { - return x == (DType)1.0 / (DType)0.0 || x == -(DType)1.0 / (DType)0.0 || + return x == (DType)1.0 / 0.0f || x == -(DType)1.0 / 0.0f || x != x; }; } else { @@ -506,7 +507,7 @@ class TensorInspector { std::stringstream ss; ss << "["; bool first_pass = true; - for (index_t i = 0; i < tb_.shape_.Size(); ++i) { + for (index_t i = 0; static_cast(i) < tb_.shape_.Size(); ++i) { if (checker(tb_.dptr()[i])) { count += 1; if (!first_pass) { @@ -674,7 +675,7 @@ class TensorInspector { */ inline void validate_shape() { const int dimension = tb_.ndim(); - CHECK(dimension > 0) << "Tensor Inspector does not support empty tensors " << + CHECK(dimension > 0) << "Tensor Inspector does not support empty tensors " << "or tensors of unknow shape."; for (int i = 0; i < dimension; ++i) { CHECK(tb_.shape_[i] != 0) << "Invalid tensor shape: shape_[" << i << "] is 0"; From 2af9226cf2d44534ac130a55687a75a9d5145bd5 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Tue, 16 Jul 2019 14:07:13 -0700 Subject: [PATCH 26/29] add negative check --- src/common/tensor_inspector.h | 61 ++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index a46c63563abe..3af5a2a19183 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -112,7 +112,7 @@ class TensorInspector { template void tensor_info_to_string(StreamType* os) 
     const int dimension = tb_.ndim();
-    *os << "<" << typeid(tb_.dptr<DType>()[0]).name() << " Tensor ";
+    *os << "<" << infer_type_string(typeid(DType)) << " Tensor ";
     *os << tb_.shape_[0];
     for (int i = 1; i < dimension; ++i) {
       *os << 'x' << tb_.shape_[i];
@@ -130,7 +130,7 @@ class TensorInspector {
   template<typename DType, typename StreamType>
   void tensor_info_to_string(StreamType* os, const std::vector<index_t>& shape) {
     const int dimension = shape.size();
-    *os << "<" << typeid(tb_.dptr<DType>()[0]).name() << " Tensor ";
+    *os << "<" << infer_type_string(typeid(DType)) << " Tensor ";
     *os << shape[0];
     for (int i = 1; i < dimension; ++i) {
       *os << 'x' << shape[i];
@@ -154,18 +154,18 @@ class TensorInspector {
     }
 #endif  // MXNET_USE_CUDA
     const int dimension = tb_.ndim();
-    std::vector<index_t> multiples;
+    std::vector<index_t> offsets;
     index_t multiple = 1;
     for (int i = dimension - 1; i >= 0; --i) {
       multiple *= tb_.shape_[i];
-      multiples.push_back(multiple);
+      offsets.push_back(multiple);
     }
     *os << std::string(dimension, '[');
     *os << tb_.dptr<DType>()[0];
-    for (index_t i = 1; static_cast<size_t>(i) < tb_.shape_.Size(); ++i) {
+    for (index_t i = 1; i < static_cast<index_t>(tb_.shape_.Size()); ++i) {
       int n = 0;
-      for (auto divisor : multiples) {
-        n += (i % divisor == 0);
+      for (auto off : offsets) {
+        n += (i % off == 0);
       }
       if (n) {
         *os << std::string(n, ']') << ", " << std::string(n, '[');
@@ -221,19 +221,19 @@ class TensorInspector {
       return;
     }
     const int dimension = sub_shape.size();
-    std::vector<index_t> multiples;
+    std::vector<index_t> offsets;
     index_t multiple = 1;
     for (int i = dimension - 1; i >= 0; --i) {
       multiple *= sub_shape[i];
-      multiples.push_back(multiple);
+      offsets.push_back(multiple);
     }
     std::stringstream ss;
     *os << std::string(dimension, '[');
     *os << dptr[0];
     for (index_t i = 1; i < multiple; ++i) {
       int n = 0;
-      for (auto divisor : multiples) {
-        n += (i % divisor == 0);
+      for (auto off : offsets) {
+        n += (i % off == 0);
       }
       if (n) {
         *os << std::string(n, ']') << ", " << std::string(n, '[');
@@ -254,7 +254,7 @@ class TensorInspector {
    * \param offset the position of the first value of the desired part of the tensor; calculated here
    */
   void print_locator(const std::vector<index_t>& pos, std::vector<index_t>* sub_shape,
-      index_t* offset) {
+      index_t* offset) {
     const int dimension = tb_.ndim();
     const int sub_dim = dimension - pos.size();
     sub_shape->resize(sub_dim);
@@ -280,7 +280,7 @@ class TensorInspector {
    */
   bool parse_position(std::vector<index_t>* pos, const std::string& str) {
     const int dimension = tb_.ndim();
-    std::stringstream ss(str);
+    std::istringstream ss(str);
     index_t n;
     while (ss >> n) {
       pos->push_back(n);
@@ -288,11 +288,11 @@ class TensorInspector {
         ss.ignore();
       }
     }
-    if (pos->size() > static_cast<unsigned>(dimension)) {
+    if (pos->size() > static_cast<size_t>(dimension)) {
       return false;
     }
     for (unsigned i = 0; i < pos->size(); ++i) {
-      if ((*pos)[i] > (tb_.shape_[i] - 1)) {
+      if ((*pos)[i] > (tb_.shape_[i] - 1) || (*pos)[i] < 0) {
         return false;
       }
     }
@@ -322,8 +322,8 @@ class TensorInspector {
           InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl;
     }
     tensor_info_to_string<DType>(&std::cout);
-    std::cout << "To print a part of the tensor," <<
-        " please specify a position, seperated by \",\"" << std::endl;
+    std::cout << "To print a part of the tensor, " <<
+        "please specify a position, seperated by \",\"" << std::endl;
     std::cout << "\"e\" for the entire tensor, " <<
         "\"d\" to dump value to file, " <<
         "\"b\" to break, " <<
@@ -343,7 +343,7 @@ class TensorInspector {
       std::cout << "Please enter a tag: ";
      std::cin >> str;
       if (str.find(' ') != std::string::npos) {
-        std::cout << "Invalid input. ";
+        std::cout << "Invalid tag name. No space allowed.";
         continue;
       }
       dump_to_file_helper<DType>(str);
@@ -358,7 +358,7 @@ class TensorInspector {
         print_locator(pos, &sub_shape, &offset);
         to_string_helper<DType>(&std::cout, sub_shape, offset);
       } else {
-        std::cout << "invalid input" << std::endl;
+        std::cout << "invalid command/indices" << std::endl;
       }
     }
   }
@@ -507,16 +507,16 @@ class TensorInspector {
     std::stringstream ss;
     ss << "[";
     bool first_pass = true;
-    for (index_t i = 0; static_cast<size_t>(i) < tb_.shape_.Size(); ++i) {
+    for (index_t i = 0; i < static_cast<index_t>(tb_.shape_.Size()); ++i) {
       if (checker(tb_.dptr<DType>()[i])) {
-        count += 1;
+        ++count;
         if (!first_pass) {
           ss << ", ";
         }
         first_pass = false;
         std::vector<index_t> coords = index_to_coordinates(i);
         ss << "(" << coords[0];
-        for (unsigned int i = 1; i < coords.size(); ++i) {
+        for (size_t i = 1; i < coords.size(); ++i) {
          ss << ", " << coords[i];
         }
         ss << ")";
@@ -562,7 +562,7 @@ class TensorInspector {
         print_locator(pos, &sub_shape, &offset);
         to_string_helper<DType>(&std::cout, sub_shape, offset);
       } else {
-        std::cout << "invalid input" << std::endl;
+        std::cout << "invalid command/indices" << std::endl;
      }
     }
   }
@@ -583,6 +583,21 @@ class TensorInspector {
     return '?';
   }
 
+  /*!
+   * \brief infer the python type, given the c++ type
+   * \param ti the type info
+   */
+  inline std::string infer_type_string(const std::type_info& ti) {
+    if (ti == typeid(float)) return "float";
+    else if (ti == typeid(double)) return "double";
+    else if (ti == typeid(mshadow::half::half_t)) return "mshadow::half::half_t";
+    else if (ti == typeid(uint8_t)) return "uint8_t";
+    else if (ti == typeid(int32_t)) return "int32_t";
+    else if (ti == typeid(int64_t)) return "int64_t";
+    else
+      return "unknown type";
+  }
+
   /*!
    * \brief check if the host machine is big or small endian
    */

From 175678986e74c8b4a111d330c88b4c89fe51ac14 Mon Sep 17 00:00:00 2001
From: zha0q1
Date: Tue, 16 Jul 2019 15:59:33 -0700
Subject: [PATCH 27/29] Re-Trigger build

From 728b773088fc54097d56b47c680e305c11cb1f69 Mon Sep 17 00:00:00 2001
From: zha0q1
Date: Tue, 16 Jul 2019 17:11:30 -0700
Subject: [PATCH 28/29] Re-Trigger build

From 1b1073340e151e959c15a0e8b1325cad0011ba01 Mon Sep 17 00:00:00 2001
From: zha0q1
Date: Wed, 17 Jul 2019 11:23:23 -0700
Subject: [PATCH 29/29] change int to size_t

---
 src/common/tensor_inspector.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h
index 3af5a2a19183..2df94b7fc04f 100644
--- a/src/common/tensor_inspector.h
+++ b/src/common/tensor_inspector.h
@@ -259,13 +259,13 @@ class TensorInspector {
     const int sub_dim = dimension - pos.size();
     sub_shape->resize(sub_dim);
     index_t multiple = 1;
-    for (int i = pos.size(), j = 0; i < dimension; ++i, ++j) {
+    for (size_t i = pos.size(), j = 0; i < static_cast<size_t>(dimension); ++i, ++j) {
       (*sub_shape)[j] = tb_.shape_[i];
       multiple *= tb_.shape_[i];
     }
     index_t sum = 0;
     index_t m = 1;
-    for (int i = pos.size() - 1; i >= 0; --i) {
+    for (index_t i = pos.size() - 1; i >= 0; --i) {
       sum += pos[i] * m;
       m *= tb_.shape_[i];
     }
@@ -291,7 +291,7 @@ class TensorInspector {
     if (pos->size() > static_cast<size_t>(dimension)) {
       return false;
     }
-    for (unsigned i = 0; i < pos->size(); ++i) {
+    for (size_t i = 0; i < pos->size(); ++i) {
       if ((*pos)[i] > (tb_.shape_[i] - 1) || (*pos)[i] < 0) {
         return false;
       }