Skip to content

Commit a88f881

Browse files
committed
Use nmod contexts for nmod_mat
1 parent 27e5e4c commit a88f881

4 files changed

Lines changed: 93 additions & 39 deletions

File tree

src/flint/types/nmod.pyx

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,23 @@ cdef class nmod_ctx:
5252
nmod_init(&self.mod, m)
5353
self._is_prime = n_is_prime(m)
5454

55+
def __eq__(self, other):
56+
# XXX: If we could ensure uniqueness of nmod_ctx for given modulus then
57+
# we would need to implement __eq__ and __hash__ at all...
58+
#
59+
# It isn't possible to ensure uniqueness in __new__ like it is in
60+
# Python because we can't return an existing object from __new__. What
61+
# we could do though is make it so that __init__ raises an error and
62+
# use a static method .new() to create new objects.
63+
if self is other:
64+
return True
65+
if not typecheck(other, nmod_ctx):
66+
return NotImplemented
67+
return self.mod.n == (<nmod_ctx>other).mod.n
68+
69+
def __repr__(self):
70+
return f"nmod_ctx({self.modulus()})"
71+
5572
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
5673
"""Convert an object to an nmod element."""
5774
cdef int success
@@ -85,7 +102,7 @@ cdef class nmod_ctx:
85102
17
86103
87104
"""
88-
return fmpz(self.mod)
105+
return fmpz(self.mod.n)
89106

90107
def is_prime(self):
91108
"""Check if the modulus is prime.
@@ -121,10 +138,10 @@ cdef class nmod_ctx:
121138
return hash(self.mod)
122139

123140
def __eq__(self, other):
124-
if not typecheck(other, nmod_ctx):
125-
return NotImplemented
141+
if typecheck(other, nmod_ctx):
142+
return self.mod.n == (<nmod_ctx>other).mod.n
126143
else:
127-
return self.mod == other.mod
144+
return NotImplemented
128145

129146
def __str__(self):
130147
return f"Context for nmod with modulus: {self.modulus()}"
@@ -165,6 +182,7 @@ cdef class nmod(flint_scalar):
165182
raise TypeError("Invalid context/modulus for nmod: %s" % mod)
166183
if not ctx.any_as_nmod(&self.val, val):
167184
raise TypeError("cannot create nmod from object of type %s" % type(val))
185+
self.ctx = ctx
168186

169187
def repr(self):
170188
return "nmod(%s, %s)" % (self.val, self.ctx.mod.n)
@@ -385,14 +403,14 @@ cdef class nmod(flint_scalar):
385403
cdef nmod r
386404
cdef mp_limb_t val
387405
r = nmod.__new__(nmod)
388-
r.mod = self.mod
406+
r.ctx = self.ctx
389407

390408
if self.val == 0:
391409
return r
392410

393-
val = n_sqrtmod(self.val, self.mod.n)
411+
val = n_sqrtmod(self.val, self.ctx.mod.n)
394412
if val == 0:
395-
raise DomainError("no square root exists for %s mod %s" % (self.val, self.mod.n))
413+
raise DomainError("no square root exists for %s mod %s" % (self.val, self.ctx.mod.n))
396414

397415
r.val = val
398416
return r

src/flint/types/nmod_mat.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@ from flint.flint_base.flint_base cimport flint_mat
33
from flint.flintlib.nmod_mat cimport nmod_mat_t
44
from flint.flintlib.flint cimport mp_limb_t
55

6+
from flint.types.nmod cimport nmod_ctx
7+
8+
69
cdef class nmod_mat(flint_mat):
710
cdef nmod_mat_t val
11+
cdef nmod_ctx ctx
12+
813
cpdef long nrows(self)
914
cpdef long ncols(self)
1015
cpdef mp_limb_t modulus(self)

src/flint/types/nmod_mat.pyx

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ from flint.flintlib.nmod_mat cimport (
4343
from flint.utils.typecheck cimport typecheck
4444
from flint.types.fmpz_mat cimport any_as_fmpz_mat
4545
from flint.types.fmpz_mat cimport fmpz_mat
46-
from flint.types.nmod cimport nmod
47-
from flint.types.nmod cimport any_as_nmod
46+
from flint.types.nmod cimport nmod, any_as_nmod_ctx
4847
from flint.types.nmod_poly cimport nmod_poly
4948
from flint.pyflint cimport global_random_state
5049
from flint.flint_base.flint_context cimport thectx
@@ -87,20 +86,24 @@ cdef class nmod_mat(flint_mat):
8786
def __init__(self, *args):
8887
cdef long m, n, i, j
8988
cdef mp_limb_t mod
89+
cdef nmod_ctx ctx
9090
if len(args) == 1:
9191
val = args[0]
9292
if typecheck(val, nmod_mat):
9393
nmod_mat_init_set(self.val, (<nmod_mat>val).val)
94+
self.ctx = (<nmod_mat>val).ctx
9495
return
9596
mod = args[-1]
9697
args = args[:-1]
9798
if mod == 0:
9899
raise ValueError("modulus must be nonzero")
100+
ctx = any_as_nmod_ctx(mod)
101+
self.ctx = ctx
99102
if len(args) == 1:
100103
val = args[0]
101104
if typecheck(val, fmpz_mat):
102105
nmod_mat_init(self.val, fmpz_mat_nrows((<fmpz_mat>val).val),
103-
fmpz_mat_ncols((<fmpz_mat>val).val), mod)
106+
fmpz_mat_ncols((<fmpz_mat>val).val), ctx.mod.n)
104107
fmpz_mat_get_nmod_mat(self.val, (<fmpz_mat>val).val)
105108
elif isinstance(val, (list, tuple)):
106109
m = len(val)
@@ -112,11 +115,11 @@ cdef class nmod_mat(flint_mat):
112115
for i from 1 <= i < m:
113116
if len(val[i]) != n:
114117
raise ValueError("input rows have different lengths")
115-
nmod_mat_init(self.val, m, n, mod)
118+
nmod_mat_init(self.val, m, n, ctx.mod.n)
116119
for i from 0 <= i < m:
117120
row = val[i]
118121
for j from 0 <= j < n:
119-
x = nmod(row[j], mod)
122+
x = nmod(row[j], ctx) # XXX: slow
120123
self.val.rows[i][j] = (<nmod>x).val
121124
else:
122125
raise TypeError("cannot create nmod_mat from input of type %s" % type(val))
@@ -131,7 +134,7 @@ cdef class nmod_mat(flint_mat):
131134
raise ValueError("list of entries has the wrong length")
132135
for i from 0 <= i < m:
133136
for j from 0 <= j < n:
134-
x = nmod(entries[i*n + j], mod) # XXX: slow
137+
x = nmod(entries[i*n + j], ctx) # XXX: slow
135138
self.val.rows[i][j] = (<nmod>x).val
136139
else:
137140
raise TypeError("nmod_mat: expected 1-3 arguments plus modulus")
@@ -207,7 +210,7 @@ cdef class nmod_mat(flint_mat):
207210
i, j = index
208211
if i < 0 or i >= self.nrows() or j < 0 or j >= self.ncols():
209212
raise IndexError("index %i,%i exceeds matrix dimensions" % (i, j))
210-
if any_as_nmod(&v, value, self.val.mod):
213+
if self.ctx.any_as_nmod(&v, value):
211214
nmod_mat_set_entry(self.val, i, j, v)
212215
else:
213216
raise TypeError("cannot set item of type %s" % type(value))
@@ -306,7 +309,7 @@ cdef class nmod_mat(flint_mat):
306309
sv = &(<nmod_mat>s).val[0]
307310
u = any_as_nmod_mat(t, sv.mod)
308311
if u is NotImplemented:
309-
if any_as_nmod(&c, t, sv.mod):
312+
if s.ctx.any_as_nmod(&c, t):
310313
return (<nmod_mat>s).__mul_nmod(c)
311314
return NotImplemented
312315
tv = &(<nmod_mat>u).val[0]
@@ -323,7 +326,7 @@ cdef class nmod_mat(flint_mat):
323326
cdef nmod_mat_struct *sv
324327
cdef mp_limb_t c
325328
sv = &(<nmod_mat>s).val[0]
326-
if any_as_nmod(&c, t, sv.mod):
329+
if s.ctx.any_as_nmod(&c, t):
327330
return (<nmod_mat>s).__mul_nmod(c)
328331
u = any_as_nmod_mat(t, sv.mod)
329332
if u is NotImplemented:
@@ -348,7 +351,7 @@ cdef class nmod_mat(flint_mat):
348351
@staticmethod
349352
def _div_(nmod_mat s, t):
350353
cdef mp_limb_t v
351-
if not any_as_nmod(&v, t, s.val.mod):
354+
if not s.ctx.any_as_nmod(&v, t):
352355
return NotImplemented
353356
t = nmod(v, s.val.mod.n)
354357
return s * (~t)

0 commit comments

Comments
 (0)