Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made templates a little smarter for those functions which require a s… #499

Merged
merged 1 commit into from
Apr 12, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions ext/nmatrix/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand Down Expand Up @@ -134,6 +134,7 @@
#include "math/cblas_enums.h"

#include "data/data.h"
#include "math/magnitude.h"
#include "math/imax.h"
#include "math/scal.h"
#include "math/laswp.h"
Expand Down Expand Up @@ -237,11 +238,14 @@ namespace nm {
int col_index[M];

for (int k = 0;k < M; ++k) {
DType akk = std::abs( matrix[k * (M + 1)] ) ; // diagonal element
typename MagnitudeDType<DType>::type akk;
akk = magnitude( matrix[k * (M + 1)] ); // diagonal element

int interchange = k;

for (int row = k + 1; row < M; ++row) {
DType big = std::abs( matrix[M*row + k] ); // element below the temp pivot
typename MagnitudeDType<DType>::type big;
big = magnitude( matrix[M*row + k] ); // element below the temp pivot

if ( big > akk ) {
interchange = row;
Expand Down Expand Up @@ -694,16 +698,12 @@ static VALUE nm_cblas_rot(VALUE self, VALUE n, VALUE x, VALUE incx, VALUE y, VAL
static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) {

static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = {
/* nm::math::cblas_nrm2<uint8_t,uint8_t>,
nm::math::cblas_nrm2<int8_t,int8_t>,
nm::math::cblas_nrm2<int16_t,int16_t>,
nm::math::cblas_nrm2<int32_t,int32_t>, */
NULL, NULL, NULL, NULL, NULL, // no help for integers
nm::math::cblas_nrm2<float32_t,float32_t>,
nm::math::cblas_nrm2<float64_t,float64_t>,
nm::math::cblas_nrm2<float32_t,nm::Complex64>,
nm::math::cblas_nrm2<float64_t,nm::Complex128>,
nm::math::cblas_nrm2<nm::RubyObject,nm::RubyObject>
nm::math::cblas_nrm2<float32_t>,
nm::math::cblas_nrm2<float64_t>,
nm::math::cblas_nrm2<nm::Complex64>,
nm::math::cblas_nrm2<nm::Complex128>,
nm::math::cblas_nrm2<nm::RubyObject>
};

nm::dtype_t dtype = NM_DTYPE(x);
Expand Down Expand Up @@ -748,16 +748,16 @@ static VALUE nm_cblas_nrm2(VALUE self, VALUE n, VALUE x, VALUE incx) {
static VALUE nm_cblas_asum(VALUE self, VALUE n, VALUE x, VALUE incx) {

static void (*ttable[nm::NUM_DTYPES])(const int N, const void* X, const int incX, void* sum) = {
nm::math::cblas_asum<uint8_t,uint8_t>,
nm::math::cblas_asum<int8_t,int8_t>,
nm::math::cblas_asum<int16_t,int16_t>,
nm::math::cblas_asum<int32_t,int32_t>,
nm::math::cblas_asum<int64_t,int64_t>,
nm::math::cblas_asum<float32_t,float32_t>,
nm::math::cblas_asum<float64_t,float64_t>,
nm::math::cblas_asum<float32_t,nm::Complex64>,
nm::math::cblas_asum<float64_t,nm::Complex128>,
nm::math::cblas_asum<nm::RubyObject,nm::RubyObject>
nm::math::cblas_asum<uint8_t>,
nm::math::cblas_asum<int8_t>,
nm::math::cblas_asum<int16_t>,
nm::math::cblas_asum<int32_t>,
nm::math::cblas_asum<int64_t>,
nm::math::cblas_asum<float32_t>,
nm::math::cblas_asum<float64_t>,
nm::math::cblas_asum<nm::Complex64>,
nm::math::cblas_asum<nm::Complex128>,
nm::math::cblas_asum<nm::RubyObject>
};

nm::dtype_t dtype = NM_DTYPE(x);
Expand Down
41 changes: 10 additions & 31 deletions ext/nmatrix/math/asum.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand Down Expand Up @@ -60,6 +60,8 @@
#define ASUM_H


#include "math/magnitude.h"

namespace nm { namespace math {

/*
Expand All @@ -73,44 +75,21 @@ namespace nm { namespace math {
* complex64 -> float or double
* complex128 -> double
*/
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
ReturnDType sum = 0;
if ((N > 0) && (incX > 0)) {
for (int i = 0; i < N; ++i) {
sum += std::abs(X[i*incX]);
}
}
return sum;
}


template <>
inline float asum(const int N, const Complex64* X, const int incX) {
float sum = 0;
if ((N > 0) && (incX > 0)) {
for (int i = 0; i < N; ++i) {
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
}
}
return sum;
}

template <>
inline double asum(const int N, const Complex128* X, const int incX) {
double sum = 0;
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType asum(const int N, const DType* X, const int incX) {
MDType sum = 0;
if ((N > 0) && (incX > 0)) {
for (int i = 0; i < N; ++i) {
sum += std::abs(X[i*incX].r) + std::abs(X[i*incX].i);
sum += magnitude(X[i*incX]);
}
}
return sum;
}


template <typename ReturnDType, typename DType>
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
*reinterpret_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, reinterpret_cast<const DType*>(X), incX );
*reinterpret_cast<MDType*>( sum ) = asum<DType,MDType>( N, reinterpret_cast<const DType*>(X), incX );
}


Expand Down
20 changes: 10 additions & 10 deletions ext/nmatrix/math/cblas_templates_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int i
* complex64 -> float or double
* complex128 -> double
*/
template <typename ReturnDType, typename DType>
inline ReturnDType asum(const int N, const DType* X, const int incX) {
return nm::math::asum<ReturnDType,DType>(N,X,incX);
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType asum(const int N, const DType* X, const int incX) {
return nm::math::asum<DType,MDType>(N,X,incX);
}


Expand All @@ -134,9 +134,9 @@ inline double asum(const int N, const Complex128* X, const int incX) {
}


template <typename ReturnDType, typename DType>
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
*static_cast<ReturnDType*>( sum ) = asum<ReturnDType, DType>( N, static_cast<const DType*>(X), incX );
*static_cast<MDType*>( sum ) = asum<DType, MDType>( N, static_cast<const DType*>(X), incX );
}

/*
Expand All @@ -149,9 +149,9 @@ inline void cblas_asum(const int N, const void* X, const int incX, void* sum) {
* complex64 -> float or double
* complex128 -> double
*/
template <typename ReturnDType, typename DType>
inline ReturnDType nrm2(const int N, const DType* X, const int incX) {
return nm::math::nrm2<ReturnDType,DType>(N, X, incX);
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType nrm2(const int N, const DType* X, const int incX) {
return nm::math::nrm2<DType,MDType>(N, X, incX);
}


Expand All @@ -175,9 +175,9 @@ inline double nrm2(const int N, const Complex128* X, const int incX) {
return cblas_dznrm2(N, X, incX);
}

template <typename ReturnDType, typename DType>
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline void cblas_nrm2(const int N, const void* X, const int incX, void* result) {
*static_cast<ReturnDType*>( result ) = nrm2<ReturnDType, DType>( N, static_cast<const DType*>(X), incX );
*static_cast<MDType*>( result ) = nrm2<DType, MDType>( N, static_cast<const DType*>(X), incX );
}

//imax
Expand Down
4 changes: 2 additions & 2 deletions ext/nmatrix/math/getrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand Down
21 changes: 12 additions & 9 deletions ext/nmatrix/math/imax.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand All @@ -29,8 +29,11 @@
#ifndef IMAX_H
#define IMAX_H

#include "math/magnitude.h"

namespace nm { namespace math {


template<typename DType>
inline int imax(const int n, const DType *x, const int incx) {

Expand All @@ -41,28 +44,28 @@ inline int imax(const int n, const DType *x, const int incx) {
return 0;
}

DType dmax;
typename MagnitudeDType<DType>::type dmax;
int imax = 0;

if (incx == 1) { // if incrementing by 1

dmax = abs(x[0]);
dmax = magnitude(x[0]);

for (int i = 1; i < n; ++i) {
if (std::abs(x[i]) > dmax) {
if (magnitude(x[i]) > dmax) {
imax = i;
dmax = std::abs(x[i]);
dmax = magnitude(x[i]);
}
}

} else { // if incrementing by more than 1

dmax = std::abs(x[0]);
dmax = magnitude(x[0]);

for (int i = 1, ix = incx; i < n; ++i, ix += incx) {
if (std::abs(x[ix]) > dmax) {
if (magnitude(x[ix]) > dmax) {
imax = i;
dmax = std::abs(x[ix]);
dmax = magnitude(x[ix]);
}
}
}
Expand Down
19 changes: 16 additions & 3 deletions ext/nmatrix/math/long_dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
Expand All @@ -23,7 +23,8 @@
//
// == long_dtype.h
//
// Declarations necessary for the native versions of GEMM and GEMV.
// Declarations necessary for the native versions of GEMM and GEMV,
// as well as for IMAX.
//

#ifndef LONG_DTYPE_H
Expand All @@ -44,6 +45,18 @@ namespace nm { namespace math {
template <> struct LongDType<Complex128> { typedef Complex128 type; };
template <> struct LongDType<RubyObject> { typedef RubyObject type; };

template <typename DType> struct MagnitudeDType;
template <> struct MagnitudeDType<uint8_t> { typedef uint8_t type; };
template <> struct MagnitudeDType<int8_t> { typedef int8_t type; };
template <> struct MagnitudeDType<int16_t> { typedef int16_t type; };
template <> struct MagnitudeDType<int32_t> { typedef int32_t type; };
template <> struct MagnitudeDType<int64_t> { typedef int64_t type; };
template <> struct MagnitudeDType<float> { typedef float type; };
template <> struct MagnitudeDType<double> { typedef double type; };
template <> struct MagnitudeDType<Complex64> { typedef float type; };
template <> struct MagnitudeDType<Complex128> { typedef double type; };
template <> struct MagnitudeDType<RubyObject> { typedef RubyObject type; };

}} // end of namespace nm::math

#endif
54 changes: 54 additions & 0 deletions ext/nmatrix/math/magnitude.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/////////////////////////////////////////////////////////////////////
// = NMatrix
//
// A linear algebra library for scientific computation in Ruby.
// NMatrix is part of SciRuby.
//
// NMatrix was originally inspired by and derived from NArray, by
// Masahiro Tanaka: http://narray.rubyforge.org
//
// == Copyright Information
//
// SciRuby is Copyright (c) 2010 - present, Ruby Science Foundation
// NMatrix is Copyright (c) 2012 - present, John Woods and the Ruby Science Foundation
//
// Please see LICENSE.txt for additional copyright notices.
//
// == Contributing
//
// By contributing source code to SciRuby, you agree to be bound by
// our Contributor Agreement:
//
// * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
//
// == math/magnitude.h
//
// Takes the absolute value (meaning magnitude) of each DType.
// Needed for a variety of BLAS/LAPACK functions.
//

#ifndef MAGNITUDE_H
#define MAGNITUDE_H

#include "math/long_dtype.h"

namespace nm { namespace math {

/* Magnitude -- may be complicated for unsigned types, and need to call the correct STL abs for floats/doubles */
template <typename DType, typename MDType = typename MagnitudeDType<DType>::type>
inline MDType magnitude(const DType& v) {
return v.abs();
}
template <> inline float magnitude(const float& v) { return std::abs(v); }
template <> inline double magnitude(const double& v) { return std::abs(v); }
template <> inline uint8_t magnitude(const uint8_t& v) { return v; }
template <> inline int8_t magnitude(const int8_t& v) { return std::abs(v); }
template <> inline int16_t magnitude(const int16_t& v) { return std::abs(v); }
template <> inline int32_t magnitude(const int32_t& v) { return std::abs(v); }
template <> inline int64_t magnitude(const int64_t& v) { return std::abs(v); }
template <> inline float magnitude(const nm::Complex64& v) { return std::sqrt(v.r * v.r + v.i * v.i); }
template <> inline double magnitude(const nm::Complex128& v) { return std::sqrt(v.r * v.r + v.i * v.i); }

}}

#endif // MAGNITUDE_H
Loading