Skip to content

Commit

Permalink
Made templates a little smarter for those functions which require a s…
Browse files Browse the repository at this point in the history
…eparate return DType by adding MagnitudeDType. Also added a magnitude function to replace std::abs and abs -- but mostly it just calls those.
  • Loading branch information
John Woods committed Apr 11, 2016
1 parent 92da6e5 commit 56f6eae
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 116 deletions.
46 changes: 23 additions & 23 deletions ext/nmatrix/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand Down Expand Up @@ -134,6 +134,7 @@
#include "math/cblas_enums.h"

#include "data/data.h"
#include "math/magnitude.h"
#include "math/imax.h"
#include "math/scal.h"
#include "math/laswp.h"
Expand Down Expand Up @@ -237,11 +238,14 @@ namespace nm {
int col_index[M];

for (int k = 0;k < M; ++k) {
DType akk = std::abs( matrix[k * (M + 1)] ) ; // diagonal element
typename MagnitudeDType<DType>::type akk;
akk = magnitude( matrix[k * (M + 1)] ); // diagonal element

int interchange = k;

for (int row = k + 1; row < M; ++row) {
DType big = std::abs( matrix[M*row + k] ); // element below the temp pivot
typename MagnitudeDType<DType>::type big;
big = magnitude( matrix[M*row + k] ); // element below the temp pivot

if ( big > akk ) {
interchange = row;
Expand Down Expand Up @@ -694,16 +698,12 @@ static VALUE nm_cblas_rot(VALUE self, VALUE n, VALUE x, VALUE incx, VALUE y, VAL
static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) {

static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = {
/* nm::math::cblas_nrm2<uint8_t,uint8_t>,
nm::math::cblas_nrm2<int8_t,int8_t>,
nm::math::cblas_nrm2<int16_t,int16_t>,
nm::math::cblas_nrm2<int32_t,int32_t>, */
NULL, NULL, NULL, NULL, NULL, // no help for integers
nm::math::cblas_nrm2<float32_t,float32_t>,
nm::math::cblas_nrm2<float64_t,float64_t>,
nm::math::cblas_nrm2<float32_t,nm::Complex64>,
nm::math::cblas_nrm2<float64_t,nm::Complex128>,
nm::math::cblas_nrm2<nm::RubyObject,nm::RubyObject>
nm::math::cblas_nrm2<float32_t>,
nm::math::cblas_nrm2<float64_t>,
nm::math::cblas_nrm2<nm::Complex64>,
nm::math::cblas_nrm2<nm::Complex128>,
nm::math::cblas_nrm2<nm::RubyObject>
};

nm::dtype_t dtype = NM_DTYPE(x);
Expand Down Expand Up @@ -748,16 +748,16 @@ static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) {
static VALUE nm_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) {

static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = {
nm::math::cblas_asum<uint8_t,uint8_t>,
nm::math::cblas_asum<int8_t,int8_t>,
nm::math::cblas_asum<int16_t,int16_t>,
nm::math::cblas_asum<int32_t,int32_t>,
nm::math::cblas_asum<int64_t,int64_t>,
nm::math::cblas_asum<float32_t,float32_t>,
nm::math::cblas_asum<float64_t,float64_t>,
nm::math::cblas_asum<float32_t,nm::Complex64>,
nm::math::cblas_asum<float64_t,nm::Complex128>,
nm::math::cblas_asum<nm::RubyObject,nm::RubyObject>
nm::math::cblas_asum<uint8_t>,
nm::math::cblas_asum<int8_t>,
nm::math::cblas_asum<int16_t>,
nm::math::cblas_asum<int32_t>,
nm::math::cblas_asum<int64_t>,
nm::math::cblas_asum<float32_t>,
nm::math::cblas_asum<float64_t>,
nm::math::cblas_asum<nm::Complex64>,
nm::math::cblas_asum<nm::Complex128>,
nm::math::cblas_asum<nm::RubyObject>
};

nm::dtype_t dtype = NM_DTYPE(x);
Expand Down
41 changes: 10 additions & 31 deletions ext/nmatrix/math/asum.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand Down Expand Up @@ -60,6 +60,8 @@
#define ASUM_H


#include "math/magnitude.h"

namespace nm { namespace math {

/*
Expand All @@ -73,44 +75,21 @@ namespace nm { namespace math {
* complex64 -> float or double
* complex128 -> double
*/
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
ReturnDType sum = 0;
if ((N > 0) && (incX > 0)) {
for (int i = 0; i < N; ++i) {
sum += std::abs(X[i*incX]);
}
}
return sum;
}


template <>
inline float asum(const int N, const Complex64* X, const int incX) {
float sum = 0;
if ((N > 0) && (incX > 0)) {
for (int i = 0; i < N; ++i) {
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
}
}
return sum;
}

template <>
inline double asum(const int N, const Complex128* X, const int incX) {
double sum = 0;
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType asum(const int N, const DType* X, const int incX) {
MDType sum = 0;
if ((N > 0) && (incX > 0)) {
for (int i = 0; i < N; ++i) {
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
sum += magnitude(X[i*incX]);
}
}
return sum;
}


template <typename ReturnDType, typename DType>
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
*reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
*reinterpret_cast<MDType*>( sum ) = asum<DType,MDType>( N, reinterpret_cast<const DType*>(X), incX );
}


Expand Down
20 changes: 10 additions & 10 deletions ext/nmatrix/math/cblas_templates_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int i
* complex64 -> float or double
* complex128 -> double
*/
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
return nm::math::asum<ReturnDType,DType>(N,X,incX);
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType asum(const int N, const DType* X, const int incX) {
return nm::math::asum<DType,MDType>(N,X,incX);
}


Expand All @@ -134,9 +134,9 @@ inline double asum(const int N, const Complex128* X, const int incX) {
}


template <typename ReturnDType, typename DType>
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
*static_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, static_cast<const DType*>(X), incX );
*static_cast<MDType*>( sum ) = asum<DType, MDType>( N, static_cast<const DType*>(X), incX );
}

/*
Expand All @@ -149,9 +149,9 @@ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
* complex64 -> float or double
* complex128 -> double
*/
template <typename ReturnDType, typename DType>
inline ReturnDType nrm2(const int N, const DType* X, const int incX) {
return nm::math::nrm2<ReturnDType,DType>(N, X, incX);
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType nrm2(const int N, const DType* X, const int incX) {
return nm::math::nrm2<DType,MDType>(N, X, incX);
}


Expand All @@ -175,9 +175,9 @@ inline double nrm2(const int N, const Complex128* X, const int incX) {
return cblas_dznrm2(N, X, incX);
}

template <typename ReturnDType, typename DType>
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) {
*static_cast<ReturnDType*>( result ) = nrm2<ReturnDType, DType>( N, static_cast<const DType*>(X), incX );
*static_cast<MDType*>( result ) = nrm2<DType, MDType>( N, static_cast<const DType*>(X), incX );
}

//imax
Expand Down
4 changes: 2 additions & 2 deletions ext/nmatrix/math/getrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand Down
21 changes: 12 additions & 9 deletions ext/nmatrix/math/imax.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand All @@ -29,8 +29,11 @@
#ifndef IMAX_H
#define IMAX_H

#include "math/magnitude.h"

namespace nm { namespace math {


template<typename DType>
inline int imax(const int n, const DType *x, const int incx) {

Expand All @@ -41,28 +44,28 @@ inline int imax(const int n, const DType *x, const int incx) {
return 0;
}

DType dmax;
typename MagnitudeDType<DType>::type dmax;
int imax = 0;

if (incx == 1) { // if incrementing by 1

dmax = abs(x[0]);
dmax = magnitude(x[0]);

for (int i = 1; i < n; ++i) {
if (std::abs(x[i]) > dmax) {
if (magnitude(x[i]) > dmax) {
imax = i;
dmax = std::abs(x[i]);
dmax = magnitude(x[i]);
}
}

} else { // if incrementing by more than 1

dmax = std::abs(x[0]);
dmax = magnitude(x[0]);

for (int i = 1, ix = incx; i < n; ++i, ix += incx) {
if (std::abs(x[ix]) > dmax) {
if (magnitude(x[ix]) > dmax) {
imax = i;
dmax = std::abs(x[ix]);
dmax = magnitude(x[ix]);
}
}
}
Expand Down
19 changes: 16 additions & 3 deletions ext/nmatrix/math/long_dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand All @@ -23,7 +23,8 @@
//
// == long_dtype.h
//
// Declarations necessary for the native versions of GEMM and GEMV.
// Declarations necessary for the native versions of GEMM and GEMV,
// as well as for IMAX.
//

#ifndef LONG_DTYPE_H
Expand All @@ -44,6 +45,18 @@ namespace nm { namespace math {
template <> struct LongDType<Complex128> { typedef Complex128 type; };
template <> struct LongDType<RubyObject> { typedef RubyObject type; };

template <typename DType> struct MagnitudeDType;
template <> struct MagnitudeDType<uint8_t> { typedef uint8_t type; };
template <> struct MagnitudeDType<int8_t> { typedef int8_t type; };
template <> struct MagnitudeDType<int16_t> { typedef int16_t type; };
template <> struct MagnitudeDType<int32_t> { typedef int32_t type; };
template <> struct MagnitudeDType<int64_t> { typedef int64_t type; };
template <> struct MagnitudeDType<float> { typedef float type; };
template <> struct MagnitudeDType<double> { typedef double type; };
template <> struct MagnitudeDType<Complex64> { typedef float type; };
template <> struct MagnitudeDType<Complex128> { typedef double type; };
template <> struct MagnitudeDType<RubyObject> { typedef RubyObject type; };

}} // end of namespace nm::math

#endif
54 changes: 54 additions & 0 deletions ext/nmatrix/math/magnitude.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/////////////////////////////////////////////////////////////////////
// = NMatrix
//
// A linear algebra library for scientific computation in Ruby.
// NMatrix is part of SciRuby.
//
// NMatrix was originally inspired by and derived from NArray, by
// Masahiro Tanaka: http://narray.rubyforge.org
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
// == Contributing
//
// By contributing source code to SciRuby, you agree to be bound by
// our Contributor Agreement:
//
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
//
// == math/magnitude.h
//
// Takes the absolute value (meaning magnitude) of each DType.
// Needed for a variety of BLAS/LAPACK functions.
//

#ifndef MAGNITUDE_H
#define MAGNITUDE_H

#include "math/long_dtype.h"

namespace nm { namespace math {

/* Magnitude -- may be complicated for unsigned types, and need to call the correct STL abs for floats/doubles */
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType magnitude(const DType& v) {
return v.abs();
}
template <> inline float magnitude(const float& v) { return std::abs(v); }
template <> inline double magnitude(const double& v) { return std::abs(v); }
template <> inline uint8_t magnitude(const uint8_t& v) { return v; }
template <> inline int8_t magnitude(const int8_t& v) { return std::abs(v); }
template <> inline int16_t magnitude(const int16_t& v) { return std::abs(v); }
template <> inline int32_t magnitude(const int32_t& v) { return std::abs(v); }
template <> inline int64_t magnitude(const int64_t& v) { return std::abs(v); }
template <> inline float magnitude(const nm::Complex64& v) { return std::sqrt(v.r * v.r + v.i * v.i); }
template <> inline double magnitude(const nm::Complex128& v) { return std::sqrt(v.r * v.r + v.i * v.i); }

}}

#endif // MAGNITUDE_H
Loading

0 comments on commit 56f6eae

Please sign in to comment.