From 56f6eaeb2789c4bdf5b480c1f631ad2e688aa0da Mon Sep 17 00:00:00 2001 From: John Woods Date: Thu, 7 Apr 2016 16:26:32 -0500 Subject: [PATCH] Made templates a little smarter for those functions which require a separate return DType by adding MagnitudeDType. Also added a magnitude function to replace std::abs and abs -- but mostly it just calls those. --- ext/nmatrix/math.cpp | 46 ++++++++++----------- ext/nmatrix/math/asum.h | 41 +++++-------------- ext/nmatrix/math/cblas_templates_core.h | 20 ++++----- ext/nmatrix/math/getrf.h | 4 +- ext/nmatrix/math/imax.h | 21 +++++----- ext/nmatrix/math/long_dtype.h | 19 +++++++-- ext/nmatrix/math/magnitude.h | 54 +++++++++++++++++++++++++ ext/nmatrix/math/nrm2.h | 16 ++++---- ext/nmatrix_atlas/math_atlas.cpp | 30 +++++++------- ext/nmatrix_lapacke/math_lapacke.cpp | 32 +++++++-------- 10 files changed, 167 insertions(+), 116 deletions(-) create mode 100644 ext/nmatrix/math/magnitude.h diff --git a/ext/nmatrix/math.cpp b/ext/nmatrix/math.cpp index 095f8909..9288d8c8 100644 --- a/ext/nmatrix/math.cpp +++ b/ext/nmatrix/math.cpp @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. 
// @@ -134,6 +134,7 @@ #include "math/cblas_enums.h" #include "data/data.h" +#include "math/magnitude.h" #include "math/imax.h" #include "math/scal.h" #include "math/laswp.h" @@ -237,11 +238,14 @@ namespace nm { int col_index[M]; for (int k = 0;k < M; ++k) { - DType akk = std::abs( matrix[k * (M + 1)] ) ; // diagonal element + typename MagnitudeDType::type akk; + akk = magnitude( matrix[k * (M + 1)] ); // diagonal element + int interchange = k; for (int row = k + 1; row < M; ++row) { - DType big = std::abs( matrix[M*row + k] ); // element below the temp pivot + typename MagnitudeDType::type big; + big = magnitude( matrix[M*row + k] ); // element below the temp pivot if ( big > akk ) { interchange = row; @@ -694,16 +698,12 @@ static VALUE nm_cblas_rot(VALUE self, VALUE n, VALUE x, VALUE incx, VALUE y, VAL static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { -/* nm::math::cblas_nrm2, - nm::math::cblas_nrm2, - nm::math::cblas_nrm2, - nm::math::cblas_nrm2, */ NULL, NULL, NULL, NULL, NULL, // no help for integers - nm::math::cblas_nrm2, - nm::math::cblas_nrm2, - nm::math::cblas_nrm2, - nm::math::cblas_nrm2, - nm::math::cblas_nrm2 + nm::math::cblas_nrm2, + nm::math::cblas_nrm2, + nm::math::cblas_nrm2, + nm::math::cblas_nrm2, + nm::math::cblas_nrm2 }; nm::dtype_t dtype = NM_DTYPE(x); @@ -748,16 +748,16 @@ static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static VALUE nm_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + 
nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum }; nm::dtype_t dtype = NM_DTYPE(x); diff --git a/ext/nmatrix/math/asum.h b/ext/nmatrix/math/asum.h index 697b9af9..dbd9ab65 100644 --- a/ext/nmatrix/math/asum.h +++ b/ext/nmatrix/math/asum.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -60,6 +60,8 @@ #define ASUM_H +#include "math/magnitude.h" + namespace nm { namespace math { /* @@ -73,44 +75,21 @@ namespace nm { namespace math { * complex64 -> float or double * complex128 -> double */ -template -inline ReturnDType asum(const int N, const DType* X, const int incX) { - ReturnDType sum = 0; - if ((N > 0) && (incX > 0)) { - for (int i = 0; i < N; ++i) { - sum += std::abs(X[i*incX]); - } - } - return sum; -} - - -template <> -inline float asum(const int N, const Complex64* X, const int incX) { - float sum = 0; - if ((N > 0) && (incX > 0)) { - for (int i = 0; i < N; ++i) { - sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i); - } - } - return sum; -} - -template <> -inline double asum(const int N, const Complex128* X, const int incX) { - double sum = 0; +template ::type> +inline MDType asum(const int N, const DType* X, const int incX) { + MDType sum = 0; if ((N > 0) && (incX > 0)) { for (int i = 0; i < N; ++i) { - sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i); + sum += magnitude(X[i*incX]); } } return sum; } -template +template ::type> inline void cblas_asum(const int N, const void* X, const int incX, void* sum) { - *reinterpret_cast( sum ) = asum( N, reinterpret_cast(X), 
incX ); + *reinterpret_cast( sum ) = asum( N, reinterpret_cast(X), incX ); } diff --git a/ext/nmatrix/math/cblas_templates_core.h b/ext/nmatrix/math/cblas_templates_core.h index 5060ae08..25aec0a0 100644 --- a/ext/nmatrix/math/cblas_templates_core.h +++ b/ext/nmatrix/math/cblas_templates_core.h @@ -107,9 +107,9 @@ inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int i * complex64 -> float or double * complex128 -> double */ -template -inline ReturnDType asum(const int N, const DType* X, const int incX) { - return nm::math::asum(N,X,incX); +template ::type> +inline MDType asum(const int N, const DType* X, const int incX) { + return nm::math::asum(N,X,incX); } @@ -134,9 +134,9 @@ inline double asum(const int N, const Complex128* X, const int incX) { } -template +template ::type> inline void cblas_asum(const int N, const void* X, const int incX, void* sum) { - *static_cast( sum ) = asum( N, static_cast(X), incX ); + *static_cast( sum ) = asum( N, static_cast(X), incX ); } /* @@ -149,9 +149,9 @@ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) { * complex64 -> float or double * complex128 -> double */ -template -inline ReturnDType nrm2(const int N, const DType* X, const int incX) { - return nm::math::nrm2(N, X, incX); +template ::type> +inline MDType nrm2(const int N, const DType* X, const int incX) { + return nm::math::nrm2(N, X, incX); } @@ -175,9 +175,9 @@ inline double nrm2(const int N, const Complex128* X, const int incX) { return cblas_dznrm2(N, X, incX); } -template +template ::type> inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) { - *static_cast( result ) = nrm2( N, static_cast(X), incX ); + *static_cast( result ) = nrm2( N, static_cast(X), incX ); } //imax diff --git a/ext/nmatrix/math/getrf.h b/ext/nmatrix/math/getrf.h index 03b702bf..29ebe4e2 100644 --- a/ext/nmatrix/math/getrf.h +++ b/ext/nmatrix/math/getrf.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby 
is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // diff --git a/ext/nmatrix/math/imax.h b/ext/nmatrix/math/imax.h index 01d4fa8e..4ff57201 100644 --- a/ext/nmatrix/math/imax.h +++ b/ext/nmatrix/math/imax.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -29,8 +29,11 @@ #ifndef IMAX_H #define IMAX_H +#include "math/magnitude.h" + namespace nm { namespace math { + template inline int imax(const int n, const DType *x, const int incx) { @@ -41,28 +44,28 @@ inline int imax(const int n, const DType *x, const int incx) { return 0; } - DType dmax; + typename MagnitudeDType::type dmax; int imax = 0; if (incx == 1) { // if incrementing by 1 - dmax = abs(x[0]); + dmax = magnitude(x[0]); for (int i = 1; i < n; ++i) { - if (std::abs(x[i]) > dmax) { + if (magnitude(x[i]) > dmax) { imax = i; - dmax = std::abs(x[i]); + dmax = magnitude(x[i]); } } } else { // if incrementing by more than 1 - dmax = std::abs(x[0]); + dmax = magnitude(x[0]); for (int i = 1, ix = incx; i < n; ++i, ix += incx) { - if (std::abs(x[ix]) > dmax) { + if (magnitude(x[ix]) > dmax) { imax = i; - dmax = std::abs(x[ix]); + dmax = magnitude(x[ix]); } } } diff --git a/ext/nmatrix/math/long_dtype.h b/ext/nmatrix/math/long_dtype.h index 7c6138db..dd4a1dbb 100644 --- a/ext/nmatrix/math/long_dtype.h +++ b/ext/nmatrix/math/long_dtype.h @@ 
-9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -23,7 +23,8 @@ // // == long_dtype.h // -// Declarations necessary for the native versions of GEMM and GEMV. +// Declarations necessary for the native versions of GEMM and GEMV, +// as well as for IMAX. // #ifndef LONG_DTYPE_H @@ -44,6 +45,18 @@ namespace nm { namespace math { template <> struct LongDType { typedef Complex128 type; }; template <> struct LongDType { typedef RubyObject type; }; + template struct MagnitudeDType; + template <> struct MagnitudeDType { typedef uint8_t type; }; + template <> struct MagnitudeDType { typedef int8_t type; }; + template <> struct MagnitudeDType { typedef int16_t type; }; + template <> struct MagnitudeDType { typedef int32_t type; }; + template <> struct MagnitudeDType { typedef int64_t type; }; + template <> struct MagnitudeDType { typedef float type; }; + template <> struct MagnitudeDType { typedef double type; }; + template <> struct MagnitudeDType { typedef float type; }; + template <> struct MagnitudeDType { typedef double type; }; + template <> struct MagnitudeDType { typedef RubyObject type; }; + }} // end of namespace nm::math #endif diff --git a/ext/nmatrix/math/magnitude.h b/ext/nmatrix/math/magnitude.h new file mode 100644 index 00000000..e53303d3 --- /dev/null +++ b/ext/nmatrix/math/magnitude.h @@ -0,0 +1,54 @@ +///////////////////////////////////////////////////////////////////// +// = NMatrix +// +// A linear algebra library for scientific computation in Ruby. +// NMatrix is part of SciRuby. 
+// +// NMatrix was originally inspired by and derived from NArray, by +// Masahiro Tanaka: http://narray.rubyforge.org +// +// == Copyright Information +// +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation +// +// Please see LICENSE.txt for additional copyright notices. +// +// == Contributing +// +// By contributing source code to SciRuby, you agree to be bound by +// our Contributor Agreement: +// +// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement +// +// == math/magnitude.h +// +// Takes the absolute value (meaning magnitude) of each DType. +// Needed for a variety of BLAS/LAPACK functions. +// + +#ifndef MAGNITUDE_H +#define MAGNITUDE_H + +#include "math/long_dtype.h" + +namespace nm { namespace math { + +/* Magnitude -- may be complicated for unsigned types, and need to call the correct STL abs for floats/doubles */ +template ::type> +inline MDType magnitude(const DType& v) { + return v.abs(); +} +template <> inline float magnitude(const float& v) { return std::abs(v); } +template <> inline double magnitude(const double& v) { return std::abs(v); } +template <> inline uint8_t magnitude(const uint8_t& v) { return v; } +template <> inline int8_t magnitude(const int8_t& v) { return std::abs(v); } +template <> inline int16_t magnitude(const int16_t& v) { return std::abs(v); } +template <> inline int32_t magnitude(const int32_t& v) { return std::abs(v); } +template <> inline int64_t magnitude(const int64_t& v) { return std::abs(v); } +template <> inline float magnitude(const nm::Complex64& v) { return std::sqrt(v.r * v.r + v.i * v.i); } +template <> inline double magnitude(const nm::Complex128& v) { return std::sqrt(v.r * v.r + v.i * v.i); } + +}} + +#endif // MAGNITUDE_H diff --git a/ext/nmatrix/math/nrm2.h b/ext/nmatrix/math/nrm2.h index ac0c9ef2..deb112df 100644 --- a/ext/nmatrix/math/nrm2.h +++ b/ext/nmatrix/math/nrm2.h @@ -9,8 +9,8 @@ // 
// == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -74,8 +74,8 @@ namespace nm { namespace math { * complex64 -> float or double * complex128 -> double */ -template -ReturnDType nrm2(const int N, const DType* X, const int incX) { +template ::type> +MDType nrm2(const int N, const DType* X, const int incX) { const DType ONE = 1, ZERO = 0; typename LongDType::type scale = 0, ssq = 1, absxi, temp; @@ -96,7 +96,7 @@ ReturnDType nrm2(const int N, const DType* X, const int incX) { } } - return scale * std::sqrt( ssq ); + return (MDType)(scale * std::sqrt( ssq )); } @@ -138,6 +138,8 @@ float nrm2(const int N, const Complex64* X, const int incX) { return scale * std::sqrt( ssq ); } +// FIXME: Function above is duplicated here, should be writeable as a template using +// FIXME: MagnitudeDType. 
template <> double nrm2(const int N, const Complex128* X, const int incX) { double scale = 0, ssq = 1; @@ -151,9 +153,9 @@ double nrm2(const int N, const Complex128* X, const int incX) { return scale * std::sqrt( ssq ); } -template +template ::type> inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) { - *reinterpret_cast( result ) = nrm2( N, reinterpret_cast(X), incX ); + *reinterpret_cast( result ) = nrm2( N, reinterpret_cast(X), incX ); } diff --git a/ext/nmatrix_atlas/math_atlas.cpp b/ext/nmatrix_atlas/math_atlas.cpp index 34ac1a1c..371bff77 100644 --- a/ext/nmatrix_atlas/math_atlas.cpp +++ b/ext/nmatrix_atlas/math_atlas.cpp @@ -368,11 +368,11 @@ static VALUE nm_atlas_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { NULL, NULL, NULL, NULL, NULL, // no help for integers - nm::math::atlas::cblas_nrm2, - nm::math::atlas::cblas_nrm2, - nm::math::atlas::cblas_nrm2, - nm::math::atlas::cblas_nrm2, - nm::math::atlas::cblas_nrm2 + nm::math::atlas::cblas_nrm2, + nm::math::atlas::cblas_nrm2, + nm::math::atlas::cblas_nrm2, + nm::math::atlas::cblas_nrm2, + nm::math::atlas::cblas_nrm2 }; nm::dtype_t dtype = NM_DTYPE(x); @@ -417,16 +417,16 @@ static VALUE nm_atlas_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static VALUE nm_atlas_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + 
nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum }; nm::dtype_t dtype = NM_DTYPE(x); diff --git a/ext/nmatrix_lapacke/math_lapacke.cpp b/ext/nmatrix_lapacke/math_lapacke.cpp index 55b90044..5722e988 100644 --- a/ext/nmatrix_lapacke/math_lapacke.cpp +++ b/ext/nmatrix_lapacke/math_lapacke.cpp @@ -320,11 +320,11 @@ static VALUE nm_lapacke_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { NULL, NULL, NULL, NULL, NULL, // no help for integers - nm::math::lapacke::cblas_nrm2, - nm::math::lapacke::cblas_nrm2, - nm::math::lapacke::cblas_nrm2, - nm::math::lapacke::cblas_nrm2, - nm::math::lapacke::cblas_nrm2 + nm::math::lapacke::cblas_nrm2, + nm::math::lapacke::cblas_nrm2, + nm::math::lapacke::cblas_nrm2, + nm::math::lapacke::cblas_nrm2, + nm::math::lapacke::cblas_nrm2 }; nm::dtype_t dtype = NM_DTYPE(x); @@ -369,16 +369,16 @@ static VALUE nm_lapacke_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static VALUE nm_lapacke_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum }; nm::dtype_t dtype = NM_DTYPE(x); @@ -1096,4 
+1096,4 @@ static VALUE nm_lapacke_lapacke_unmqr(VALUE self, VALUE order, VALUE side, VALUE } } -} \ No newline at end of file +}