Skip to content

Commit 1725758

Browse files
npriyadarshitranslunar
authored andcommitted
fixed NMatrix#inverse_exact method for MRI (does not apply to JRuby): issues #444, #569, #581, #582
1 parent cfadf50 commit 1725758

File tree

3 files changed

+88
-15
lines changed

3 files changed

+88
-15
lines changed

ext/nmatrix/math.cpp

+14-14
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ extern "C" {
188188
// Math Functions //
189189
////////////////////
190190

191-
namespace nm {
191+
namespace nm {
192192
namespace math {
193193

194194
/*
@@ -335,18 +335,18 @@ namespace nm {
335335
for (int row = k + 1; row < M; ++row) {
336336
typename MagnitudeDType<DType>::type big;
337337
big = magnitude( matrix[M*row + k] ); // element below the temp pivot
338-
338+
339339
if ( big > akk ) {
340340
interchange = row;
341-
akk = big;
341+
akk = big;
342342
}
343-
}
343+
}
344344

345345
if (interchange != k) { // check if rows need flipping
346346
DType temp;
347347

348348
for (int col = 0; col < M; ++col) {
349-
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
349+
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
350350
}
351351
}
352352

@@ -360,7 +360,7 @@ namespace nm {
360360
DType pivot = matrix[k * (M + 1)];
361361
matrix[k * (M + 1)] = (DType)(1); // set diagonal as 1 for in-place inversion
362362

363-
for (int col = 0; col < M; ++col) {
363+
for (int col = 0; col < M; ++col) {
364364
// divide each element in the kth row with the pivot
365365
matrix[k*M + col] = matrix[k*M + col] / pivot;
366366
}
@@ -369,7 +369,7 @@ namespace nm {
369369
if (kk == k) continue;
370370

371371
DType dum = matrix[k + M*kk];
372-
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
372+
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
373373
for (int col = 0; col < M; ++col) {
374374
matrix[M*kk + col] = matrix[M*kk + col] - matrix[M*k + col] * dum;
375375
}
@@ -384,7 +384,7 @@ namespace nm {
384384

385385
for (int row = 0; row < M; ++row) {
386386
NM_SWAP(matrix[row * M + row_index[k]], matrix[row * M + col_index[k]],
387-
temp);
387+
temp);
388388
}
389389
}
390390
}
@@ -410,14 +410,14 @@ namespace nm {
410410
DType sum_of_squares, *p_row, *psubdiag, *p_a, scale, innerproduct;
411411
int i, k, col;
412412

413-
// For each column use a Householder transformation to zero all entries
413+
// For each column use a Householder transformation to zero all entries
414414
// below the subdiagonal.
415-
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
415+
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
416416
col++) {
417417
// Calculate the signed square root of the sum of squares of the
418418
// elements below the diagonal.
419419

420-
for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
420+
for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
421421
p_a += nrows, i++) {
422422
sum_of_squares += *p_a * *p_a;
423423
}
@@ -447,7 +447,7 @@ namespace nm {
447447
*p_a -= u[k] * innerproduct;
448448
}
449449
}
450-
450+
451451
// Postmultiply QA by Q
452452
for (p_row = a, i = 0; i < nrows; p_row += nrows, i++) {
453453
for (innerproduct = 0.0, k = col + 1; k < nrows; k++) {
@@ -485,7 +485,7 @@ namespace nm {
485485
B[0] = A[lda+1] / det;
486486
B[1] = -A[1] / det;
487487
B[ldb] = -A[lda] / det;
488-
B[ldb+1] = -A[0] / det;
488+
B[ldb+1] = A[0] / det;
489489

490490
} else if (M == 3) {
491491
// Calculate the exact determinant.
@@ -1313,7 +1313,7 @@ void nm_math_hessenberg(VALUE a) {
13131313
NULL, NULL, // does not support Complex
13141314
NULL // no support for Ruby Object
13151315
};
1316-
1316+
13171317
ttable[NM_DTYPE(a)](NM_SHAPE0(a), NM_STORAGE_DENSE(a)->elements);
13181318
}
13191319
/*

lib/nmatrix/math.rb

+57-1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,62 @@ def invert
112112
end
113113
alias :inverse :invert
114114

115+
#
116+
# call-seq:
117+
# invert_exact! -> NMatrix
118+
#
119+
# Calulates inverse_exact of a matrix of size 2 or 3.
120+
# Only works on dense matrices.
121+
#
122+
# * *Raises* :
123+
# - +StorageTypeError+ -> only implemented on dense matrices.
124+
# - +ShapeError+ -> matrix must be square.
125+
# - +DataTypeError+ -> cannot invert an integer matrix in-place.
126+
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
127+
#
128+
def invert_exact!
129+
raise(StorageTypeError, "invert only works on dense matrices currently") unless self.dense?
130+
raise(ShapeError, "Cannot invert non-square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
131+
raise(DataTypeError, "Cannot invert an integer matrix in-place") if self.integer_dtype?
132+
#No internal implementation of getri, so use this other function
133+
n = self.shape[0]
134+
if n>3
135+
raise(NotImplementedError, "Cannot find exact inverse of matrix of size greater than 3")
136+
else
137+
clond=self.clone
138+
__inverse_exact__(clond, n, n)
139+
end
140+
end
141+
142+
#
143+
# call-seq:
144+
# invert_exact -> NMatrix
145+
#
146+
# Make a copy of the matrix, then invert using exact_inverse
147+
#
148+
# * *Returns* :
149+
# - A dense NMatrix. Will be the same type as the input NMatrix,
150+
# except if the input is an integral dtype, in which case it will be a
151+
# :float64 NMatrix.
152+
#
153+
# * *Raises* :
154+
# - +StorageTypeError+ -> only implemented on dense matrices.
155+
# - +ShapeError+ -> matrix must be square.
156+
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
157+
#
158+
def invert_exact
159+
#write this in terms of invert_exact! so plugins will only have to overwrite
160+
#invert_exact! and not invert_exact
161+
if self.integer_dtype?
162+
cloned = self.cast(dtype: :float64)
163+
cloned.invert_exact!
164+
else
165+
cloned = self.clone
166+
cloned.invert_exact!
167+
end
168+
end
169+
alias :inverse_exact :invert_exact
170+
115171
#
116172
# call-seq:
117173
# adjugate! -> NMatrix
@@ -1393,4 +1449,4 @@ def dtype_for_floor_or_ceil
13931449
self.__dense_map__ { |l| l.send(op,rhs) }
13941450
end
13951451
end
1396-
end
1452+
end

spec/math_spec.rb

+17
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,23 @@
488488

489489
expect(a.invert).to be_within(err).of(b)
490490
end
491+
492+
it "should correctly find exact inverse" do
493+
pending("not yet implemented for NMatrix-JRuby") if jruby?
494+
a = NMatrix.new(:dense, 3, [1,2,3,0,1,4,5,6,0], dtype)
495+
b = NMatrix.new(:dense, 3, [-24,18,5,20,-15,-4,-5,4,1], dtype)
496+
497+
expect(a.invert_exact).to be_within(err).of(b)
498+
end
499+
500+
it "should correctly find exact inverse" do
501+
pending("not yet implemented for NMatrix-JRuby") if jruby?
502+
a = NMatrix.new(:dense, 2, [1,3,3,8,], dtype)
503+
b = NMatrix.new(:dense, 2, [-8,3,3,-1], dtype)
504+
505+
expect(a.invert_exact).to be_within(err).of(b)
506+
end
507+
491508
end
492509
end
493510

0 commit comments

Comments
 (0)