diff --git a/ext/nmatrix/data/complex.h b/ext/nmatrix/data/complex.h index f455757b..9da1f766 100644 --- a/ext/nmatrix/data/complex.h +++ b/ext/nmatrix/data/complex.h @@ -110,11 +110,11 @@ class Complex { } /* - * Complex inverse function -- creates a copy, but inverted. + * Complex inverse (reciprocal) function -- computes 1 / n. * * FIXME: Check that this doesn't duplicate functionality of NativeType / Complex */ - inline Complex inverse() const { + inline Complex reciprocal() const { Complex conj = conjugate(); Type denom = this->r * this->r + this->i * this->i; return Complex(conj.r / denom, conj.i / denom); diff --git a/ext/nmatrix/data/ruby_object.h b/ext/nmatrix/data/ruby_object.h index a9fcaeac..4e62c1df 100644 --- a/ext/nmatrix/data/ruby_object.h +++ b/ext/nmatrix/data/ruby_object.h @@ -54,6 +54,12 @@ * Classes and Functions */ +extern "C" { + inline VALUE quo_reciprocal(const VALUE rval) { + return rb_funcall(INT2FIX(1), nm_rb_quo, 1, rval); + } +} + namespace nm { template struct made_from_same_template : std::false_type {}; @@ -125,10 +131,32 @@ class RubyObject { inline RubyObject(const RubyObject& other) : rval(other.rval) {} /* - * Inverse operator. + * Quotient operator (typically for rational divisions) + */ + inline RubyObject quo(const RubyObject& other) const { + return RubyObject(rb_funcall(this->rval, nm_rb_quo, 1, other.rval)); + } + + /* + * Numeric inverse (reciprocal) operator, usually used for matrix + * operations like getrf. For this to work, Fixnum.quo(this) must not + * return a TypeError. */ - inline RubyObject inverse() const { - rb_raise(rb_eNotImpError, "RubyObject#inverse needs to be implemented"); + inline RubyObject reciprocal() const { + int exception; + + // Attempt to call 1.quo(this). + VALUE result = rb_protect(quo_reciprocal, this->rval, &exception); + if (exception) { + ID rb_reciprocal = rb_intern("reciprocal"); + // quo failed, so let's see if the object has a reciprocal method. + if (rb_respond_to(this->rval, rb_reciprocal)) { + return RubyObject(rb_funcall(this->rval, rb_reciprocal, 0)); + } else { + rb_raise(rb_eNoMethodError, "expected reciprocal method, since 1.quo(object) raises an error"); + } + } + return result; } /* diff --git a/ext/nmatrix/math.cpp b/ext/nmatrix/math.cpp index 095f8909..d17864a7 100644 --- a/ext/nmatrix/math.cpp +++ b/ext/nmatrix/math.cpp @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -134,6 +134,7 @@ #include "math/cblas_enums.h" #include "data/data.h" +#include "math/magnitude.h" #include "math/imax.h" #include "math/scal.h" #include "math/laswp.h" @@ -206,7 +207,7 @@ namespace nm { } else if (M == 3) { x = A[lda+1] * A[2*lda+2] - A[lda+2] * A[2*lda+1]; // ei - fh y = A[lda] * A[2*lda+2] - A[lda+2] * A[2*lda]; // fg - di - x = A[0]*x - A[1]*y ; // a*(ei-fh) - b*(fg-di) + x = A[0]*x - A[1]*y; // a*(ei-fh) - b*(fg-di) y = A[lda] * A[2*lda+1] - A[lda+1] * A[2*lda]; // dh - eg *result = A[2]*y + x; // c*(dh-eg) + _ @@ -237,11 +238,14 @@ namespace nm { int col_index[M]; for (int k = 0;k < M; ++k) { - DType akk = std::abs( matrix[k * (M + 1)] ) ; // diagonal element + typename MagnitudeDType::type akk; + akk = magnitude( matrix[k * (M + 1)] ); // diagonal element + int interchange = k; for (int row = k + 1; row < M; ++row) { - DType big = std::abs( matrix[M*row + k] ); // element below the temp pivot + typename MagnitudeDType::type big; + big = magnitude( matrix[M*row + k] ); // element below the temp pivot if ( big > akk ) { interchange = row; @@ -748,16 +752,16 @@ static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static VALUE nm_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum, - nm::math::cblas_asum + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum, + nm::math::cblas_asum }; nm::dtype_t dtype = NM_DTYPE(x); diff --git a/ext/nmatrix/math/asum.h b/ext/nmatrix/math/asum.h index 697b9af9..dbd9ab65 100644 --- a/ext/nmatrix/math/asum.h +++ b/ext/nmatrix/math/asum.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -60,6 +60,8 @@ #define ASUM_H +#include "math/magnitude.h" + namespace nm { namespace math { /* @@ -73,44 +75,21 @@ namespace nm { namespace math { * complex64 -> float or double * complex128 -> double */ -template -inline ReturnDType asum(const int N, const DType* X, const int incX) { - ReturnDType sum = 0; - if ((N > 0) && (incX > 0)) { - for (int i = 0; i < N; ++i) { - sum += std::abs(X[i*incX]); - } - } - return sum; -} - - -template <> -inline float asum(const int N, const Complex64* X, const int incX) { - float sum = 0; - if ((N > 0) && (incX > 0)) { - for (int i = 0; i < N; ++i) { - sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i); - } - } - return sum; -} - -template <> -inline double asum(const int N, const Complex128* X, const int incX) { - double sum = 0; +template ::type> +inline MDType asum(const int N, const DType* X, const int incX) { + MDType sum = 0; if ((N > 0) && (incX > 0)) { for (int i = 0; i < N; ++i) { - sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i); + sum += magnitude(X[i*incX]); } } return sum; } -template +template ::type> inline void cblas_asum(const int N, const void* X, const int incX, void* sum) { - *reinterpret_cast( sum ) = asum( N, reinterpret_cast(X), incX ); + *reinterpret_cast( sum ) = asum( N, reinterpret_cast(X), incX ); } diff --git a/ext/nmatrix/math/cblas_templates_core.h b/ext/nmatrix/math/cblas_templates_core.h index 5060ae08..4631d791 100644 --- a/ext/nmatrix/math/cblas_templates_core.h +++ b/ext/nmatrix/math/cblas_templates_core.h @@ -107,9 +107,9 @@ inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int i * complex64 -> float or double * complex128 -> double */ -template -inline ReturnDType asum(const int N, const DType* X, const int incX) { - return nm::math::asum(N,X,incX); +template ::type> +inline MDType asum(const int N, const DType* X, const int incX) { + return nm::math::asum(N,X,incX); } @@ -134,9 +134,9 @@ inline double asum(const int N, const Complex128* X, const int incX) { } -template +template ::type> inline void cblas_asum(const int N, const void* X, const int incX, void* sum) { - *static_cast( sum ) = asum( N, static_cast(X), incX ); + *static_cast( sum ) = asum( N, static_cast(X), incX ); } /* @@ -149,9 +149,9 @@ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) { * complex64 -> float or double * complex128 -> double */ -template +template ::type> inline ReturnDType nrm2(const int N, const DType* X, const int incX) { - return nm::math::nrm2(N, X, incX); + return nm::math::nrm2(N, X, incX); } @@ -175,9 +175,9 @@ inline double nrm2(const int N, const Complex128* X, const int incX) { return cblas_dznrm2(N, X, incX); } -template +template ::type> inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) { - *static_cast( result ) = nrm2( N, static_cast(X), incX ); + *static_cast( result ) = nrm2( N, static_cast(X), incX ); } //imax diff --git a/ext/nmatrix/math/getrf.h b/ext/nmatrix/math/getrf.h index 03b702bf..f1cd050c 100644 --- a/ext/nmatrix/math/getrf.h +++ b/ext/nmatrix/math/getrf.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -59,6 +59,7 @@ #ifndef GETRF_H #define GETRF_H +#include "math/reciprocal.h" #include "math/laswp.h" #include "math/math.h" #include "math/trsm.h" @@ -68,14 +69,6 @@ namespace nm { namespace math { -/* Numeric inverse -- usually just 1 / f, but a little more complicated for complex. */ -template -inline DType numeric_inverse(const DType& n) { - return n.inverse(); -} -template <> inline float numeric_inverse(const float& n) { return 1 / n; } -template <> inline double numeric_inverse(const double& n) { return 1 / n; } - /* * Templated version of row-order and column-order getrf, derived from ATL_getrfR.c (from ATLAS 3.8.0). * @@ -132,10 +125,8 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* An = &(Ar[N_ul]); nm::math::laswp(N_dr, Ar, lda, 0, N_ul, ipiv, 1); - nm::math::trsm(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, N_dr, N_ul, one, A, lda, Ar, lda); nm::math::gemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N_dr, N-N_ul, N_ul, &neg_one, Ar, lda, Ac, lda, &one, An, lda); - i = getrf_nothrow(N_dr, N-N_ul, An, lda, ipiv+N_ul); } else { Ar = NULL; @@ -143,10 +134,8 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* An = &(Ac[N_ul]); nm::math::laswp(N_dr, Ac, lda, 0, N_ul, ipiv, 1); - nm::math::trsm(CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N_ul, N_dr, one, A, lda, Ac, lda); nm::math::gemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M-N_ul, N_dr, N_ul, &neg_one, &(A[N_ul]), lda, Ac, lda, &one, An, lda); - i = getrf_nothrow(M-N_ul, N_dr, An, lda, ipiv+N_ul); } @@ -157,7 +146,6 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* } nm::math::laswp(N_ul, A, lda, N_ul, MN, ipiv, 1); /* apply pivots */ - } else if (MN == 1) { // there's another case for the colmajor version, but it doesn't seem to be necessary. int i; @@ -170,13 +158,14 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int* DType tmp = A[i]; if (tmp != 0) { - nm::math::scal((RowMajor ? N : M), nm::math::numeric_inverse(tmp), A, 1); + nm::math::scal((RowMajor ? N : M), nm::math::reciprocal(tmp), A, 1); A[i] = *A; *A = tmp; } else ierr = 1; } + return(ierr); } diff --git a/ext/nmatrix/math/imax.h b/ext/nmatrix/math/imax.h index 01d4fa8e..4ff57201 100644 --- a/ext/nmatrix/math/imax.h +++ b/ext/nmatrix/math/imax.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -29,8 +29,11 @@ #ifndef IMAX_H #define IMAX_H +#include "math/magnitude.h" + namespace nm { namespace math { + template inline int imax(const int n, const DType *x, const int incx) { @@ -41,28 +44,28 @@ inline int imax(const int n, const DType *x, const int incx) { return 0; } - DType dmax; + typename MagnitudeDType::type dmax; int imax = 0; if (incx == 1) { // if incrementing by 1 - dmax = abs(x[0]); + dmax = magnitude(x[0]); for (int i = 1; i < n; ++i) { - if (std::abs(x[i]) > dmax) { + if (magnitude(x[i]) > dmax) { imax = i; - dmax = std::abs(x[i]); + dmax = magnitude(x[i]); } } } else { // if incrementing by more than 1 - dmax = std::abs(x[0]); + dmax = magnitude(x[0]); for (int i = 1, ix = incx; i < n; ++i, ix += incx) { - if (std::abs(x[ix]) > dmax) { + if (magnitude(x[ix]) > dmax) { imax = i; - dmax = std::abs(x[ix]); + dmax = magnitude(x[ix]); } } } diff --git a/ext/nmatrix/math/long_dtype.h b/ext/nmatrix/math/long_dtype.h index 7c6138db..dd4a1dbb 100644 --- a/ext/nmatrix/math/long_dtype.h +++ b/ext/nmatrix/math/long_dtype.h @@ -9,8 +9,8 @@ // // == Copyright Information // -// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation -// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation // // Please see LICENSE.txt for additional copyright notices. // @@ -23,7 +23,8 @@ // // == long_dtype.h // -// Declarations necessary for the native versions of GEMM and GEMV. +// Declarations necessary for the native versions of GEMM and GEMV, +// as well as for IMAX. // #ifndef LONG_DTYPE_H @@ -44,6 +45,18 @@ namespace nm { namespace math { template <> struct LongDType { typedef Complex128 type; }; template <> struct LongDType { typedef RubyObject type; }; + template struct MagnitudeDType; + template <> struct MagnitudeDType { typedef uint8_t type; }; + template <> struct MagnitudeDType { typedef int8_t type; }; + template <> struct MagnitudeDType { typedef int16_t type; }; + template <> struct MagnitudeDType { typedef int32_t type; }; + template <> struct MagnitudeDType { typedef int64_t type; }; + template <> struct MagnitudeDType { typedef float type; }; + template <> struct MagnitudeDType { typedef double type; }; + template <> struct MagnitudeDType { typedef float type; }; + template <> struct MagnitudeDType { typedef double type; }; + template <> struct MagnitudeDType { typedef RubyObject type; }; + }} // end of namespace nm::math #endif diff --git a/ext/nmatrix/math/magnitude.h b/ext/nmatrix/math/magnitude.h new file mode 100644 index 00000000..e53303d3 --- /dev/null +++ b/ext/nmatrix/math/magnitude.h @@ -0,0 +1,54 @@ +///////////////////////////////////////////////////////////////////// +// = NMatrix +// +// A linear algebra library for scientific computation in Ruby. +// NMatrix is part of SciRuby. +// +// NMatrix was originally inspired by and derived from NArray, by +// Masahiro Tanaka: http://narray.rubyforge.org +// +// == Copyright Information +// +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation +// +// Please see LICENSE.txt for additional copyright notices. +// +// == Contributing +// +// By contributing source code to SciRuby, you agree to be bound by +// our Contributor Agreement: +// +// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement +// +// == math/magnitude.h +// +// Takes the absolute value (meaning magnitude) of each DType. +// Needed for a variety of BLAS/LAPACK functions. +// + +#ifndef MAGNITUDE_H +#define MAGNITUDE_H + +#include "math/long_dtype.h" + +namespace nm { namespace math { + +/* Magnitude -- may be complicated for unsigned types, and need to call the correct STL abs for floats/doubles */ +template ::type> +inline MDType magnitude(const DType& v) { + return v.abs(); +} +template <> inline float magnitude(const float& v) { return std::abs(v); } +template <> inline double magnitude(const double& v) { return std::abs(v); } +template <> inline uint8_t magnitude(const uint8_t& v) { return v; } +template <> inline int8_t magnitude(const int8_t& v) { return std::abs(v); } +template <> inline int16_t magnitude(const int16_t& v) { return std::abs(v); } +template <> inline int32_t magnitude(const int32_t& v) { return std::abs(v); } +template <> inline int64_t magnitude(const int64_t& v) { return std::abs(v); } +template <> inline float magnitude(const nm::Complex64& v) { return std::sqrt(v.r * v.r + v.i * v.i); } +template <> inline double magnitude(const nm::Complex128& v) { return std::sqrt(v.r * v.r + v.i * v.i); } + +}} + +#endif // MAGNITUDE_H diff --git a/ext/nmatrix/math/math.h b/ext/nmatrix/math/math.h index 02290955..6e045a13 100644 --- a/ext/nmatrix/math/math.h +++ b/ext/nmatrix/math/math.h @@ -715,30 +715,6 @@ int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, con } */ -/* - * Macro for declaring LAPACK specializations of the getrf function. - * - * type is the DType; call is the specific function to call; cast_as is what the DType* should be - * cast to in order to pass it to LAPACK. - */ -#define LAPACK_GETRF(type, call, cast_as) \ -template <> \ -inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) { \ - int info = call(Order, M, N, reinterpret_cast(A), lda, ipiv); \ - if (!info) return info; \ - else { \ - rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info); \ - return info; \ - } \ -} - -/* Specialize for ATLAS types */ -/*LAPACK_GETRF(float, clapack_sgetrf, float) -LAPACK_GETRF(double, clapack_dgetrf, double) -LAPACK_GETRF(Complex64, clapack_cgetrf, void) -LAPACK_GETRF(Complex128, clapack_zgetrf, void) -*/ - }} // end namespace nm::math diff --git a/ext/nmatrix/math/reciprocal.h b/ext/nmatrix/math/reciprocal.h new file mode 100644 index 00000000..5045ed16 --- /dev/null +++ b/ext/nmatrix/math/reciprocal.h @@ -0,0 +1,48 @@ +///////////////////////////////////////////////////////////////////// +// = NMatrix +// +// A linear algebra library for scientific computation in Ruby. +// NMatrix is part of SciRuby. +// +// NMatrix was originally inspired by and derived from NArray, by +// Masahiro Tanaka: http://narray.rubyforge.org +// +// == Copyright Information +// +// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation +// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation +// +// Please see LICENSE.txt for additional copyright notices. +// +// == Contributing +// +// By contributing source code to SciRuby, you agree to be bound by +// our Contributor Agreement: +// +// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement +// +// == math/reciprocal.h +// +// Helper function for computing the reciprocal of a template type. +// Called by getrf and possibly other things. +// + +#ifndef RECIPROCAL_H +#define RECIPROCAL_H + + +namespace nm { namespace math { + +/* Numeric inverse -- basically just 1 / n, which in Ruby is 1.quo(n). */ +template +inline DType reciprocal(const DType& n) { + return n.reciprocal(); +} +template <> inline float reciprocal(const float& n) { return 1 / n; } +template <> inline double reciprocal(const double& n) { return 1 / n; } + +}} + + +#endif // RECIPROCAL_H + diff --git a/ext/nmatrix/math/trsm.h b/ext/nmatrix/math/trsm.h index 9af41952..764f4ebf 100644 --- a/ext/nmatrix/math/trsm.h +++ b/ext/nmatrix/math/trsm.h @@ -58,6 +58,7 @@ #ifndef TRSM_H #define TRSM_H +#include "math/reciprocal.h" namespace nm { namespace math { @@ -189,7 +190,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, } } if (diag == CblasNonUnit) { - DType temp = 1 / a[j + j * lda]; + DType temp = reciprocal(a[j + j * lda]); for (int i = 1; i <= m; ++i) { b[i + j * ldb] = temp * b[i + j * ldb]; } @@ -211,7 +212,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, } } if (diag == CblasNonUnit) { - DType temp = 1 / a[j + j * lda]; + DType temp = reciprocal(a[j + j * lda]); for (int i = 1; i <= m; ++i) { b[i + j * ldb] = temp * b[i + j * ldb]; @@ -226,7 +227,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, if (uplo == CblasUpper) { for (int k = n; k >= 1; --k) { if (diag == CblasNonUnit) { - DType temp= 1 / a[k + k * lda]; + DType temp = reciprocal(a[k + k * lda]); for (int i = 1; i <= m; ++i) { b[i + k * ldb] = temp * b[i + k * ldb]; } @@ -248,7 +249,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo, } else { for (int k = 1; k <= n; ++k) { if (diag == CblasNonUnit) { - DType temp = 1 / a[k + k * lda]; + DType temp = reciprocal(a[k + k * lda]); for (int i = 1; i <= m; ++i) { b[i + k * ldb] = temp * b[i + k * ldb]; } diff --git a/ext/nmatrix/ruby_constants.cpp b/ext/nmatrix/ruby_constants.cpp index 3479eb74..52ba4568 100644 --- a/ext/nmatrix/ruby_constants.cpp +++ b/ext/nmatrix/ruby_constants.cpp @@ -67,6 +67,7 @@ ID nm_rb_dtype, nm_rb_sub, nm_rb_mul, nm_rb_div, + nm_rb_quo, nm_rb_both, nm_rb_none, @@ -124,6 +125,7 @@ void nm_init_ruby_constants(void) { nm_rb_sub = rb_intern("-"); nm_rb_mul = rb_intern("*"); nm_rb_div = rb_intern("/"); + nm_rb_quo = rb_intern("quo"); nm_rb_negate = rb_intern("-@"); diff --git a/ext/nmatrix/ruby_constants.h b/ext/nmatrix/ruby_constants.h index 583fd87a..b1375109 100644 --- a/ext/nmatrix/ruby_constants.h +++ b/ext/nmatrix/ruby_constants.h @@ -71,6 +71,7 @@ extern ID nm_rb_dtype, nm_rb_sub, nm_rb_mul, nm_rb_div, + nm_rb_quo, nm_rb_negate, diff --git a/ext/nmatrix/ruby_nmatrix.c b/ext/nmatrix/ruby_nmatrix.c index 1790f50b..8c890a67 100644 --- a/ext/nmatrix/ruby_nmatrix.c +++ b/ext/nmatrix/ruby_nmatrix.c @@ -2693,11 +2693,11 @@ static SLICE* get_slice(size_t dim, int argc, VALUE* arg, size_t* shape) { VALUE begin_end = rb_funcall(v, rb_intern("shift"), 0); // rb_hash_shift nm_register_value(&begin_end); - if (rb_ary_entry(begin_end, 0) >= 0) + if (FIX2INT(rb_ary_entry(begin_end, 0)) >= 0) slice->coords[r] = FIX2INT(rb_ary_entry(begin_end, 0)); else slice->coords[r] = shape[r] + FIX2INT(rb_ary_entry(begin_end, 0)); - if (rb_ary_entry(begin_end, 1) >= 0) + if (FIX2INT(rb_ary_entry(begin_end, 1)) >= 0) slice->lengths[r] = FIX2INT(rb_ary_entry(begin_end, 1)) - slice->coords[r]; else slice->lengths[r] = shape[r] + FIX2INT(rb_ary_entry(begin_end, 1)) - slice->coords[r]; diff --git a/ext/nmatrix_atlas/math_atlas.cpp b/ext/nmatrix_atlas/math_atlas.cpp index 34ac1a1c..c91d9e37 100644 --- a/ext/nmatrix_atlas/math_atlas.cpp +++ b/ext/nmatrix_atlas/math_atlas.cpp @@ -417,16 +417,16 @@ static VALUE nm_atlas_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static VALUE nm_atlas_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum, - nm::math::atlas::cblas_asum + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum, + nm::math::atlas::cblas_asum }; nm::dtype_t dtype = NM_DTYPE(x); diff --git a/ext/nmatrix_lapacke/math_lapacke.cpp b/ext/nmatrix_lapacke/math_lapacke.cpp index 55b90044..e57c1d6b 100644 --- a/ext/nmatrix_lapacke/math_lapacke.cpp +++ b/ext/nmatrix_lapacke/math_lapacke.cpp @@ -369,16 +369,16 @@ static VALUE nm_lapacke_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) { static VALUE nm_lapacke_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) { static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = { - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum, - nm::math::lapacke::cblas_asum + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum, + nm::math::lapacke::cblas_asum }; nm::dtype_t dtype = NM_DTYPE(x); @@ -1096,4 +1096,4 @@ static VALUE nm_lapacke_lapacke_unmqr(VALUE self, VALUE order, VALUE side, VALUE } } -} \ No newline at end of file +} diff --git a/lib/nmatrix/atlas.rb b/lib/nmatrix/atlas.rb index e941ef4f..69d1f316 100644 --- a/lib/nmatrix/atlas.rb +++ b/lib/nmatrix/atlas.rb @@ -205,7 +205,7 @@ def invert! end def potrf!(which) - raise(StorageTypeError, "ATLAS functions only work on dense matrices") unless self.dense? + raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense? raise(ShapeError, "Cholesky decomposition only valid for square matrices") unless self.dim == 2 && self.shape[0] == self.shape[1] NMatrix::LAPACK::clapack_potrf(:row, which, self.shape[0], self, self.shape[1]) diff --git a/lib/nmatrix/lapacke.rb b/lib/nmatrix/lapacke.rb index d130fdee..3492eedd 100644 --- a/lib/nmatrix/lapacke.rb +++ b/lib/nmatrix/lapacke.rb @@ -338,7 +338,7 @@ def ormqr(tau, side=:left, transpose=false, c=nil) # - +TypeError+ -> c must have the same dtype as the calling NMatrix # def unmqr(tau, side=:left, transpose=false, c=nil) - raise(StorageTypeError, "ATLAS functions only work on dense matrices") unless self.dense? + raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense? raise(TypeError, "Works only on complex matrices, use ormqr for normal floating point matrices") unless self.complex_dtype? raise(TypeError, "c must have the same dtype as the calling NMatrix") if c and c.dtype != self.dtype diff --git a/lib/nmatrix/math.rb b/lib/nmatrix/math.rb index 36d1d98d..b5009b11 100644 --- a/lib/nmatrix/math.rb +++ b/lib/nmatrix/math.rb @@ -130,10 +130,10 @@ def invert # * *Returns* : # - The IPIV vector. The L and U matrices are stored in A. # * *Raises* : - # - +StorageTypeError+ -> ATLAS functions only work on dense matrices. + # - +StorageTypeError+ -> LAPACK functions only work on dense matrices. # def getrf! - raise(StorageTypeError, "ATLAS functions only work on dense matrices") unless self.dense? + raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense? #For row-major matrices, clapack_getrf uses a different convention than #described above (U has unit diagonal elements instead of L and columns @@ -142,15 +142,28 @@ def getrf! #and after calling clapack_getrf. #Unfortunately, this is not a very good way, uses a lot of memory. temp = self.transpose - ipiv = NMatrix::LAPACK::clapack_getrf(:col, self.shape[0], self.shape[1], temp, self.shape[0]) - temp = temp.transpose - self[0...self.shape[0], 0...self.shape[1]] = temp - - #for some reason, in clapack_getrf, the indices in ipiv start from 0 - #instead of 1 as in LAPACK. - ipiv.each_index { |i| ipiv[i]+=1 } - - return ipiv + begin + ipiv = NMatrix::LAPACK::clapack_getrf(:col, self.shape[0], self.shape[1], temp, self.shape[0]) + temp = temp.transpose + self[0...self.shape[0], 0...self.shape[1]] = temp + + #for some reason, in clapack_getrf, the indices in ipiv start from 0 + #instead of 1 as in LAPACK. + ipiv.each_index { |i| ipiv[i]+=1 } + + return ipiv + rescue NoMethodError => e + if e.message =~ /abs/ || e.message =~ /reciprocal/ + raise(NoMethodError, "getrf! requires #abs and #reciprocal methods to be defined on the Ruby object stored in the matrix (or, instead of #reciprocal, 1.quo(object) must work)") + else + raise(e) + end + rescue TypeError => e + if e.message =~ /coerced/ + STDERR.puts "Error: matrix content class can't be coerced into a numeric type; you probably need to re-define the math operators on your numeric classes" + end + raise(e) + end end # @@ -270,7 +283,7 @@ def unmqr(tau, side=:left, transpose=false, c=nil) # * *Returns* : # the triangular portion specified by the parameter # * *Raises* : - # - +StorageTypeError+ -> ATLAS functions only work on dense matrices. + # - +StorageTypeError+ -> LAPACK functions only work on dense matrices. # - +ShapeError+ -> Must be square. # - +NotImplementedError+ -> If called without nmatrix-atlas or nmatrix-lapacke gem # @@ -568,7 +581,7 @@ def gesdd(workspace_size=nil) # * +:covention+ - Possible values are +:lapack+ and +:intuitive+. Default is +:intuitive+. See above for details. # def laswp!(ary, opts={}) - raise(StorageTypeError, "ATLAS functions only work on dense matrices") unless self.dense? + raise(StorageTypeError, "LAPACK functions only work on dense matrices") unless self.dense? opts = { convention: :intuitive }.merge(opts) if opts[:convention] == :intuitive diff --git a/spec/math_spec.rb b/spec/math_spec.rb index 113f0e4a..e3c04c49 100644 --- a/spec/math_spec.rb +++ b/spec/math_spec.rb @@ -902,8 +902,81 @@ end context "determinants" do + context :object do + before(:all) do + class StringWithAbs + def initialize contents + @s = contents + end + def abs + @s.to_r.abs.to_s + end + + def to_r + @s.to_r + end + + def * rhs + if rhs.is_a?(StringWithAbs) + @s.to_r * rhs.to_r + else + @s.to_r * rhs + end + end + + def quo(rhs) + @s.to_r.quo(rhs.to_r) + end + + def / rhs + @s.quo(rhs) + end + + def - rhs + @s.to_r - rhs.to_r + end + + def + rhs + @s.to_r + rhs.to_r + end + end + @string_ary = ["1", "0", "1", "1", + "1", "2", "3", "1", + "3", "3", "3", "1", + "1", "2", "3", "4"] + @c_ary = @string_ary.map { |s| StringWithAbs.new(s) } + @c = NMatrix.new([4,4], @c_ary, dtype: :object) + end + + it "raises an exception when Numeric#* is not properly defined" do + pending "Need to figure out how to refine Numeric in rspec" + class OtherStringWithAbs + def initialize s; @s = s; end + def abs; @s.to_r.abs; end + def to_r; @s.to_r; end + end + d = NMatrix.new([4,4], @string_ary.map { |s| OtherStringWithAbs.new(s) }, dtype: :object) + expect(d.det).to be_within(1e-64).of(-18) + end + + it "raises an exception when a quotient is not defined" do + expect { @c.det }.to raise_error(NoMethodError, /reciprocal/) + end + + it "does not raise an exception when object responds to :reciprocal" do + pending "Need to figure out how to refine Numeric in rspec" + class StringWithAbs + def reciprocal + 1.quo(self.to_r) + end + end + + expect(@c.det).to be_within(1e-64).of(-18) + end + end + ALL_DTYPES.each do |dtype| - next if dtype == :object + #next if dtype == :object context dtype do before do @a = NMatrix.new([2,2], [1,2, @@ -934,6 +1007,7 @@ expect(@c.det).to be_within(@err).of(-18) end it "computes the exact determinant of 2x2 matrix" do + pending if dtype == :object if dtype == :byte expect{@a.det_exact}.to raise_error(DataTypeError) else @@ -941,6 +1015,7 @@ end end it "computes the exact determinant of 3x3 matrix" do + pending if dtype == :object if dtype == :byte expect{@a.det_exact}.to raise_error(DataTypeError) else