Skip to content

Commit 8c2221b

Browse files
committed
Use nmod_ctx consistently in nmod_poly and nmod_mat
1 parent ba9c1fd commit 8c2221b

7 files changed

Lines changed: 284 additions & 164 deletions

File tree

src/flint/test/test_all.py

Lines changed: 119 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3595,17 +3595,34 @@ def factor_sqf(p):
35953595

35963596
def _all_matrices():
35973597
"""Return a list of matrix types and scalar types."""
3598+
# Prime modulus
35983599
R163 = flint.fmpz_mod_ctx(163)
35993600
R127 = flint.fmpz_mod_ctx(2**127 - 1)
36003601
R255 = flint.fmpz_mod_ctx(2**255 - 19)
3602+
3603+
# Composite modulus
3604+
R164_C = flint.fmpz_mod_ctx(164)
3605+
R127_C = flint.fmpz_mod_ctx(2**127)
3606+
R255_C = flint.fmpz_mod_ctx(2**255)
3607+
36013608
return [
3602-
# (matrix_type, scalar_type, is_field)
3603-
(flint.fmpz_mat, flint.fmpz, False),
3604-
(flint.fmpq_mat, flint.fmpq, True),
3605-
(lambda *a: flint.nmod_mat(*a, 17), lambda x: flint.nmod(x, 17), True),
3606-
(lambda *a: flint.fmpz_mod_mat(*a, R163), lambda x: flint.fmpz_mod(x, R163), True),
3607-
(lambda *a: flint.fmpz_mod_mat(*a, R127), lambda x: flint.fmpz_mod(x, R127), True),
3608-
(lambda *a: flint.fmpz_mod_mat(*a, R255), lambda x: flint.fmpz_mod(x, R255), True),
3609+
# (matrix_type, scalar_type, is_field, characteristic)
3610+
3611+
# Z and Q
3612+
(flint.fmpz_mat, flint.fmpz, False, 0),
3613+
(flint.fmpq_mat, flint.fmpq, True, 0),
3614+
3615+
# Z/pZ
3616+
(lambda *a: flint.nmod_mat(*a, 17), lambda x: flint.nmod(x, 17), True, 17),
3617+
(lambda *a: flint.fmpz_mod_mat(*a, R163), lambda x: flint.fmpz_mod(x, R163), True, 163),
3618+
(lambda *a: flint.fmpz_mod_mat(*a, R127), lambda x: flint.fmpz_mod(x, R127), True, 2**127 - 1),
3619+
(lambda *a: flint.fmpz_mod_mat(*a, R255), lambda x: flint.fmpz_mod(x, R255), True, 2**255 - 19),
3620+
3621+
# Z/nZ (n composite)
3622+
(lambda *a: flint.nmod_mat(*a, 16), lambda x: flint.nmod(x, 16), False, 16),
3623+
(lambda *a: flint.fmpz_mod_mat(*a, R164_C), lambda x: flint.fmpz_mod(x, R164_C), False, 164),
3624+
(lambda *a: flint.fmpz_mod_mat(*a, R127_C), lambda x: flint.fmpz_mod(x, R127_C), False, 2**127),
3625+
(lambda *a: flint.fmpz_mod_mat(*a, R255_C), lambda x: flint.fmpz_mod(x, R255_C), False, 2**255),
36093626
]
36103627

36113628

@@ -3726,7 +3743,7 @@ def _poly_type_from_matrix_type(mat_type):
37263743

37273744

37283745
def test_matrices_eq():
3729-
for M, S, is_field in _all_matrices():
3746+
for M, S, is_field, characteristic in _all_matrices():
37303747
A1 = M([[1, 2], [3, 4]])
37313748
A2 = M([[1, 2], [3, 4]])
37323749
B = M([[5, 6], [7, 8]])
@@ -3751,7 +3768,7 @@ def test_matrices_eq():
37513768

37523769

37533770
def test_matrices_constructor():
3754-
for M, S, is_field in _all_matrices():
3771+
for M, S, is_field, characteristic in _all_matrices():
37553772
assert raises(lambda: M(), TypeError)
37563773

37573774
# Empty matrices
@@ -3823,7 +3840,7 @@ def _matrix_repr(M):
38233840

38243841

38253842
def test_matrices_strrepr():
3826-
for M, S, is_field in _all_matrices():
3843+
for M, S, is_field, characteristic in _all_matrices():
38273844
A = M([[1, 2], [3, 4]])
38283845
A_str = "[1, 2]\n[3, 4]"
38293846
A_repr = _matrix_repr(A)
@@ -3846,7 +3863,7 @@ def test_matrices_strrepr():
38463863

38473864

38483865
def test_matrices_getitem():
3849-
for M, S, is_field in _all_matrices():
3866+
for M, S, is_field, characteristic in _all_matrices():
38503867
M1234 = M([[1, 2], [3, 4]])
38513868
assert M1234[0, 0] == S(1)
38523869
assert M1234[0, 1] == S(2)
@@ -3862,7 +3879,7 @@ def test_matrices_getitem():
38623879

38633880

38643881
def test_matrices_setitem():
3865-
for M, S, is_field in _all_matrices():
3882+
for M, S, is_field, characteristic in _all_matrices():
38663883
M1234 = M([[1, 2], [3, 4]])
38673884

38683885
assert M1234[0, 0] == S(1)
@@ -3888,7 +3905,7 @@ def setbad(obj, key, val):
38883905

38893906

38903907
def test_matrices_bool():
3891-
for M, S, is_field in _all_matrices():
3908+
for M, S, is_field, characteristic in _all_matrices():
38923909
assert bool(M([])) is False
38933910
assert bool(M([[0]])) is False
38943911
assert bool(M([[1]])) is True
@@ -3899,14 +3916,14 @@ def test_matrices_bool():
38993916

39003917

39013918
def test_matrices_pos_neg():
3902-
for M, S, is_field in _all_matrices():
3919+
for M, S, is_field, characteristic in _all_matrices():
39033920
M1234 = M([[1, 2], [3, 4]])
39043921
assert +M1234 == M1234
39053922
assert -M1234 == M([[-1, -2], [-3, -4]])
39063923

39073924

39083925
def test_matrices_add():
3909-
for M, S, is_field in _all_matrices():
3926+
for M, S, is_field, characteristic in _all_matrices():
39103927
M1234 = M([[1, 2], [3, 4]])
39113928
M5678 = M([[5, 6], [7, 8]])
39123929
assert M1234 + M5678 == M([[6, 8], [10, 12]])
@@ -3926,7 +3943,7 @@ def test_matrices_add():
39263943

39273944

39283945
def test_matrices_sub():
3929-
for M, S, is_field in _all_matrices():
3946+
for M, S, is_field, characteristic in _all_matrices():
39303947
M1234 = M([[1, 2], [3, 4]])
39313948
M5678 = M([[5, 6], [7, 8]])
39323949
assert M1234 - M5678 == M([[-4, -4], [-4, -4]])
@@ -3946,7 +3963,7 @@ def test_matrices_sub():
39463963

39473964

39483965
def test_matrices_mul():
3949-
for M, S, is_field in _all_matrices():
3966+
for M, S, is_field, characteristic in _all_matrices():
39503967
M1234 = M([[1, 2], [3, 4]])
39513968
M5678 = M([[5, 6], [7, 8]])
39523969
assert M1234 * M5678 == M([[19, 22], [43, 50]])
@@ -3972,18 +3989,24 @@ def test_matrices_mul():
39723989

39733990

39743991
def test_matrices_pow():
3975-
for M, S, is_field in _all_matrices():
3992+
for M, S, is_field, characteristic in _all_matrices():
39763993
M1234 = M([[1, 2], [3, 4]])
3994+
39773995
assert M1234**0 == M([[1, 0], [0, 1]])
39783996
assert M1234**1 == M1234
39793997
assert M1234**2 == M([[7, 10], [15, 22]])
39803998
assert M1234**3 == M([[37, 54], [81, 118]])
3999+
39814000
if is_field:
39824001
assert M1234**-1 == M([[-4, 2], [3, -1]]) / 2
39834002
assert M1234**-2 == M([[22, -10], [-15, 7]]) / 4
39844003
assert M1234**-3 == M([[-118, 54], [81, -37]]) / 8
39854004
Ms = M([[1, 2], [3, 6]])
39864005
assert raises(lambda: Ms**-1, ZeroDivisionError)
4006+
else:
4007+
# XXX: Allow unimodular matrices?
4008+
assert raises(lambda: M1234**-1, DomainError)
4009+
39874010
Mr = M([[1, 2, 3], [4, 5, 6]])
39884011
assert raises(lambda: Mr**0, ValueError)
39894012
assert raises(lambda: Mr**1, ValueError)
@@ -3993,31 +4016,49 @@ def test_matrices_pow():
39934016

39944017

39954018
def test_matrices_div():
3996-
for M, S, is_field in _all_matrices():
4019+
4020+
for M, S, is_field, characteristic in _all_matrices():
39974021
M1234 = M([[1, 2], [3, 4]])
4022+
39984023
if is_field:
39994024
assert M1234 / 2 == M([[S(1)/2, S(1)], [S(3)/2, 2]])
40004025
assert M1234 / S(2) == M([[S(1)/2, S(1)], [S(3)/2, 2]])
40014026
assert raises(lambda: M1234 / 0, ZeroDivisionError)
40024027
assert raises(lambda: M1234 / S(0), ZeroDivisionError)
4028+
else:
4029+
assert raises(lambda: M1234 / 2, DomainError)
4030+
if characteristic == 0:
4031+
assert (2*M1234) / 2 == M1234
4032+
else:
4033+
assert raises(lambda: (2*M1234) / 2, DomainError)
4034+
40034035
raises(lambda: M1234 / None, TypeError)
40044036
raises(lambda: None / M1234, TypeError)
40054037

40064038

40074039
def test_matrices_inv():
4008-
for M, S, is_field in _all_matrices():
4009-
if is_field:
4010-
M1234 = M([[1, 2], [3, 4]])
4040+
4041+
for M, S, is_field, characteristic in _all_matrices():
4042+
4043+
M1234 = M([[1, 2], [3, 4]])
4044+
M1236 = M([[1, 2], [3, 6]])
4045+
Mr = M([[1, 2, 3], [4, 5, 6]])
4046+
4047+
if characteristic > 0 and not is_field:
4048+
assert raises(lambda: M([[1, 2], [3, 4]]).inv(), DomainError)
4049+
elif is_field:
40114050
assert M1234.inv() == M([[-2, 1], [S(3)/2, -S(1)/2]])
4012-
M1236 = M([[1, 2], [3, 6]])
40134051
assert raises(lambda: M1236.inv(), ZeroDivisionError)
4014-
Mr = M([[1, 2, 3], [4, 5, 6]])
40154052
assert raises(lambda: Mr.inv(), ValueError)
4016-
# XXX: Test non-field matrices. unimodular?
4053+
else:
4054+
# assert M1234.inv() == (M([[-4, 2], [3, -1]]), 2)
4055+
# assert M1236.inv() == (M([[-6, 2], [3, -1]]), 3)
4056+
# XXX: fmpz_mat.inv() return fmpq_mat...
4057+
assert M1234.inv() * M1234.det() == M([[4, -2], [-3, 1]])
40174058

40184059

40194060
def test_matrices_det():
4020-
for M, S, is_field in _all_matrices():
4061+
for M, S, is_field, characteristic in _all_matrices():
40214062
M1234 = M([[1, 2], [3, 4]])
40224063
assert M1234.det() == S(-2)
40234064
M9 = M([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
@@ -4027,7 +4068,7 @@ def test_matrices_det():
40274068

40284069

40294070
def test_matrices_charpoly():
4030-
for M, S, is_field in _all_matrices():
4071+
for M, S, is_field, characteristic in _all_matrices():
40314072
P = _poly_type_from_matrix_type(M)
40324073
M1234 = M([[1, 2], [3, 4]])
40334074
assert M1234.charpoly() == P([-2, -5, 1])
@@ -4038,18 +4079,21 @@ def test_matrices_charpoly():
40384079

40394080

40404081
def test_matrices_minpoly():
4041-
for M, S, is_field in _all_matrices():
4082+
for M, S, is_field, characteristic in _all_matrices():
4083+
if characteristic > 0 and not is_field:
4084+
assert raises(lambda: M([[1, 2], [3, 4]]).minpoly(), DomainError)
4085+
continue
40424086
P = _poly_type_from_matrix_type(M)
4043-
M1234 = M([[1, 2], [3, 4]])
4044-
assert M1234.minpoly() == P([-2, -5, 1])
4045-
M9 = M([[2, 1, 0], [0, 2, 0], [0, 0, 2]])
4046-
assert M9.minpoly() == P([4, -4, 1])
4047-
Mr = M([[1, 2, 3], [4, 5, 6]])
4048-
assert raises(lambda: Mr.minpoly(), ValueError)
4087+
assert M([[1, 2], [3, 4]]).minpoly() == P([-2, -5, 1])
4088+
assert M([[2, 1, 0], [0, 2, 0], [0, 0, 2]]).minpoly() == P([4, -4, 1])
4089+
assert raises(lambda: M([[1, 2, 3], [4, 5, 6]]).minpoly(), ValueError)
40494090

40504091

40514092
def test_matrices_rank():
4052-
for M, S, is_field in _all_matrices():
4093+
for M, S, is_field, characteristic in _all_matrices():
4094+
if characteristic > 0 and not is_field:
4095+
assert raises(lambda: M([[1, 2], [3, 4]]).rank(), DomainError)
4096+
continue
40534097
M1234 = M([[1, 2], [3, 4]])
40544098
assert M1234.rank() == 2
40554099
Mr = M([[1, 2, 3], [4, 5, 6]])
@@ -4061,37 +4105,57 @@ def test_matrices_rank():
40614105

40624106

40634107
def test_matrices_rref():
4064-
for M, S, is_field in _all_matrices():
4065-
if is_field:
4066-
Mr = M([[1, 2, 3], [4, 5, 6]])
4067-
Mr_rref = M([[1, 0, -1], [0, 1, 2]])
4108+
for M, S, is_field, characteristic in _all_matrices():
4109+
4110+
Mr = M([[1, 2, 3], [4, 5, 6]])
4111+
Mr_rref = M([[1, 0, -1], [0, 1, 2]])
4112+
4113+
if characteristic > 0 and not is_field:
4114+
# Z/nZ (n composite) raises
4115+
assert raises(lambda: Mr.rref(), DomainError)
4116+
elif is_field:
4117+
# Q, Z/pZ and GF(p^d) return usual RREF
40684118
assert Mr.rref() == (Mr_rref, 2)
40694119
assert Mr == M([[1, 2, 3], [4, 5, 6]])
40704120
assert Mr.rref(inplace=True) == (Mr_rref, 2)
40714121
assert Mr == Mr_rref
4122+
else:
4123+
# Z returns RREF with divisor -3
4124+
d = -3
4125+
assert Mr.rref() == (d*Mr_rref, d, 2)
4126+
assert Mr == M([[1, 2, 3], [4, 5, 6]])
4127+
assert Mr.rref(inplace=True) == (d*Mr_rref, d, 2)
4128+
assert Mr == d*Mr_rref
40724129

40734130

40744131
def test_matrices_solve():
4075-
for M, S, is_field in _all_matrices():
4076-
if is_field:
4077-
A = M([[1, 2], [3, 4]])
4078-
x = M([[1], [2]])
4079-
b = M([[5], [11]])
4080-
assert A*x == b
4132+
for M, S, is_field, characteristic in _all_matrices():
4133+
4134+
A = M([[1, 2], [3, 4]])
4135+
x = M([[1], [2]])
4136+
b = M([[5], [11]])
4137+
assert A*x == b
4138+
4139+
A2 = M([[1, 2], [2, 4]])
4140+
4141+
if characteristic > 0 and not is_field:
4142+
assert raises(lambda: A.solve(b), DomainError)
4143+
assert raises(lambda: A2.solve(b), DomainError)
4144+
else:
40814145
assert A.solve(b) == x
4082-
A22 = M([[1, 2], [3, 4]])
4083-
A23 = M([[1, 2, 3], [4, 5, 6]])
4084-
b2 = M([[5], [11]])
4085-
b3 = M([[5], [11], [17]])
4086-
assert raises(lambda: A22.solve(b3), ValueError)
4087-
assert raises(lambda: A23.solve(b2), ValueError)
4088-
assert raises(lambda: A.solve(None), TypeError)
4089-
A = M([[1, 2], [2, 4]])
4090-
assert raises(lambda: A.solve(b), ZeroDivisionError)
4146+
assert raises(lambda: A2.solve(b), ZeroDivisionError)
4147+
4148+
A22 = M([[1, 2], [3, 4]])
4149+
A23 = M([[1, 2, 3], [4, 5, 6]])
4150+
b2 = M([[5], [11]])
4151+
b3 = M([[5], [11], [17]])
4152+
assert raises(lambda: A22.solve(b3), ValueError)
4153+
assert raises(lambda: A23.solve(b2), ValueError)
4154+
assert raises(lambda: A.solve(None), TypeError)
40914155

40924156

40934157
def test_matrices_transpose():
4094-
for M, S, is_field in _all_matrices():
4158+
for M, S, is_field, characteristic in _all_matrices():
40954159
M1234 = M([[1, 2, 3], [4, 5, 6]])
40964160
assert M1234.transpose() == M([[1, 4], [2, 5], [3, 6]])
40974161

src/flint/types/fmpz_mat.pyx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,11 @@ cdef class fmpz_mat(flint_mat):
306306
raise ValueError("matrix must be square")
307307
if m is not None:
308308
raise NotImplementedError("modular matrix exponentiation")
309+
if e < 0:
310+
raise DomainError("negative power of integer matrix: M**%i" % e)
309311
ee = e
310-
t = fmpz_mat(self) # XXX
312+
t = fmpz_mat.__new__(fmpz_mat)
313+
fmpz_mat_init_set(t.val, (<fmpz_mat>self).val)
311314
fmpz_mat_pow(t.val, t.val, ee)
312315
return t
313316

0 commit comments

Comments
 (0)