Skip to content

Commit c6ea352

Browse files
committed
fmpz_mat uses exact division
Previously fmpz_mat / fmpz would give an fmpq_mat. Now fmpz_mat / fmpz succeeds and returns fmpz_mat if the division is exact and raises DomainError otherwise.
1 parent d47e0b5 commit c6ea352

File tree

2 files changed

+61
-14
lines changed

2 files changed

+61
-14
lines changed

src/flint/test/test.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,8 @@ def test_fmpz_mat():
547547
assert str(M(2,2,[1,2,3,4])) == '[1, 2]\n[3, 4]'
548548
assert M(1,2,[3,4]) * flint.fmpq(1,3) == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)])
549549
assert flint.fmpq(1,3) * M(1,2,[3,4]) == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)])
550-
assert M(1,2,[3,4]) / 3 == flint.fmpq_mat(1, 2, [1, flint.fmpq(4,3)])
550+
assert raises(lambda: M(1,2,[3,4]) / 3, DomainError)
551+
assert M(1,2,[2,4]) / 2 == M(1,2,[1,2])
551552
assert M(2,2,[1,2,3,4]).inv().det() == flint.fmpq(1) / M(2,2,[1,2,3,4]).det()
552553
assert M(2,2,[1,2,3,4]).inv().inv() == M(2,2,[1,2,3,4])
553554
assert raises(lambda: M.randrank(4,3,4,1), ValueError)
@@ -2318,6 +2319,44 @@ def test_division_scalar():
23182319
assert raises(lambda: "AAA" / R(5), TypeError)
23192320

23202321

2322+
def test_division_matrix():
2323+
Z = flint.fmpz
2324+
Q = flint.fmpq
2325+
F17 = lambda x: flint.nmod(x, 17)
2326+
ctx = flint.fmpz_mod_ctx(163)
2327+
F163 = lambda a: flint.fmpz_mod(a, ctx)
2328+
MZ = lambda x: flint.fmpz_mat(x)
2329+
MQ = lambda x: flint.fmpq_mat(x)
2330+
MF17 = lambda x: flint.nmod_mat(x, 17)
2331+
MF163 = lambda x: flint.fmpz_mod_mat(x, ctx)
2332+
# fmpz exact division
2333+
assert MZ([[2, 4]]) / Z(2) == MZ([[1, 2]])
2334+
assert MZ([[2, 4]]) / 2 == MZ([[1, 2]])
2335+
assert raises(lambda: MZ([[2, 5]]) / Z(2), DomainError)
2336+
assert raises(lambda: MZ([[2, 5]]) / 2, DomainError)
2337+
# field division by scalar
2338+
for (K, MK) in [(Q, MQ), (F17, MF17), (F163, MF163)]:
2339+
assert MK([[2, 5]]) / K(2) == MK([[K(2)/K(2), K(5)/K(2)]])
2340+
assert MK([[2, 5]]) / 2 == MK([[K(2)/K(2), K(5)/K(2)]])
2341+
# No other division is allowed
2342+
for (R, MR) in [(Z, MZ), (Q, MQ), (F17, MF17), (F163, MF163)]:
2343+
M = MR([[2, 5]])
2344+
for s in (2, R(2)):
2345+
assert raises(lambda: s / M, TypeError)
2346+
assert raises(lambda: M // s, TypeError)
2347+
assert raises(lambda: s // M, TypeError)
2348+
assert raises(lambda: M % s, TypeError)
2349+
assert raises(lambda: s % M, TypeError)
2350+
assert raises(lambda: divmod(s, M), TypeError)
2351+
assert raises(lambda: divmod(M, s), TypeError)
2352+
assert raises(lambda: M / M, TypeError)
2353+
assert raises(lambda: M // M, TypeError)
2354+
assert raises(lambda: M % M, TypeError)
2355+
assert raises(lambda: divmod(M, M), TypeError)
2356+
assert raises(lambda: M / 0, ZeroDivisionError)
2357+
assert raises(lambda: M / R(0), ZeroDivisionError)
2358+
2359+
23212360
def _all_polys():
23222361
return [
23232362
# (poly_type, scalar_type, is_field)
@@ -3066,6 +3105,7 @@ def test_all_tests():
30663105
test_fmpz_mod_mat,
30673106

30683107
test_division_scalar,
3108+
test_division_matrix,
30693109

30703110
test_polys,
30713111

src/flint/types/fmpz_mat.pyx

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ from flint.flintlib.fmpq_mat cimport fmpq_mat_init
1818
from flint.flintlib.fmpq_mat cimport fmpq_mat_set_fmpz_mat_div_fmpz
1919
from flint.flintlib.fmpq_mat cimport fmpq_mat_solve_fmpz_mat
2020

21+
from flint.utils.flint_exceptions import DomainError
22+
23+
2124
cdef any_as_fmpz_mat(obj):
2225
if typecheck(obj, fmpz_mat):
2326
return obj
@@ -131,13 +134,10 @@ cdef class fmpz_mat(flint_mat):
131134
def __nonzero__(self):
132135
return not fmpz_mat_is_zero(self.val)
133136

134-
def __richcmp__(s, t, int op):
137+
def __richcmp__(fmpz_mat s, t, int op):
135138
cdef bint r
136139
if op != 2 and op != 3:
137140
raise TypeError("matrices cannot be ordered")
138-
s = any_as_fmpz_mat(s)
139-
if t is NotImplemented:
140-
return s
141141
t = any_as_fmpz_mat(t)
142142
if t is NotImplemented:
143143
return t
@@ -282,15 +282,22 @@ cdef class fmpz_mat(flint_mat):
282282
return fmpq_mat(s) * t
283283
return NotImplemented
284284

285-
@staticmethod
286-
def _div_(fmpz_mat s, t):
287-
return s * (1 / fmpq(t))
288-
289-
def __truediv__(s, t):
290-
return fmpz_mat._div_(s, t)
291-
292-
def __div__(s, t):
293-
return fmpz_mat._div_(s, t)
285+
def __truediv__(fmpz_mat s, t):
286+
cdef fmpz_mat u
287+
cdef fmpz_mat_struct *sval
288+
t = any_as_fmpz(t)
289+
if t is NotImplemented:
290+
return t
291+
if fmpz_is_zero((<fmpz>t).val):
292+
raise ZeroDivisionError("division by zero")
293+
sval = &(<fmpz_mat>s).val[0]
294+
u = fmpz_mat.__new__(fmpz_mat)
295+
fmpz_mat_init(u.val, fmpz_mat_nrows(sval), fmpz_mat_ncols(sval))
296+
fmpz_mat_scalar_divexact_fmpz(u.val, sval, (<fmpz>t).val)
297+
# XXX: check for exact division - there should be a better way!
298+
if u * t != s:
299+
raise DomainError("fmpz_mat division is not exact")
300+
return u
294301

295302
def __pow__(self, e, m):
296303
cdef fmpz_mat t

0 commit comments

Comments
 (0)