@@ -43,8 +43,7 @@ from flint.flintlib.nmod_mat cimport (
4343from flint.utils.typecheck cimport typecheck
4444from flint.types.fmpz_mat cimport any_as_fmpz_mat
4545from flint.types.fmpz_mat cimport fmpz_mat
46- from flint.types.nmod cimport nmod
47- from flint.types.nmod cimport any_as_nmod
46+ from flint.types.nmod cimport nmod, any_as_nmod_ctx
4847from flint.types.nmod_poly cimport nmod_poly
4948from flint.pyflint cimport global_random_state
5049from flint.flint_base.flint_context cimport thectx
@@ -87,20 +86,24 @@ cdef class nmod_mat(flint_mat):
8786 def __init__ (self , *args ):
8887 cdef long m, n, i, j
8988 cdef mp_limb_t mod
89+ cdef nmod_ctx ctx
9090 if len (args) == 1 :
9191 val = args[0 ]
9292 if typecheck(val, nmod_mat):
9393 nmod_mat_init_set(self .val, (< nmod_mat> val).val)
94+ self .ctx = (< nmod_mat> val).ctx
9495 return
9596 mod = args[- 1 ]
9697 args = args[:- 1 ]
9798 if mod == 0 :
9899 raise ValueError (" modulus must be nonzero" )
100+ ctx = any_as_nmod_ctx(mod)
101+ self .ctx = ctx
99102 if len (args) == 1 :
100103 val = args[0 ]
101104 if typecheck(val, fmpz_mat):
102105 nmod_mat_init(self .val, fmpz_mat_nrows((< fmpz_mat> val).val),
103- fmpz_mat_ncols((< fmpz_mat> val).val), mod)
106+ fmpz_mat_ncols((< fmpz_mat> val).val), ctx. mod.n )
104107 fmpz_mat_get_nmod_mat(self .val, (< fmpz_mat> val).val)
105108 elif isinstance (val, (list , tuple )):
106109 m = len (val)
@@ -112,11 +115,11 @@ cdef class nmod_mat(flint_mat):
112115 for i from 1 <= i < m:
113116 if len (val[i]) != n:
114117 raise ValueError (" input rows have different lengths" )
115- nmod_mat_init(self .val, m, n, mod)
118+ nmod_mat_init(self .val, m, n, ctx. mod.n )
116119 for i from 0 <= i < m:
117120 row = val[i]
118121 for j from 0 <= j < n:
119- x = nmod(row[j], mod)
122+ x = nmod(row[j], ctx) # XXX: slow
120123 self .val.rows[i][j] = (< nmod> x).val
121124 else :
122125 raise TypeError (" cannot create nmod_mat from input of type %s " % type (val))
@@ -131,7 +134,7 @@ cdef class nmod_mat(flint_mat):
131134 raise ValueError (" list of entries has the wrong length" )
132135 for i from 0 <= i < m:
133136 for j from 0 <= j < n:
134- x = nmod(entries[i* n + j], mod ) # XXX: slow
137+ x = nmod(entries[i* n + j], ctx ) # XXX: slow
135138 self .val.rows[i][j] = (< nmod> x).val
136139 else :
137140 raise TypeError (" nmod_mat: expected 1-3 arguments plus modulus" )
@@ -207,7 +210,7 @@ cdef class nmod_mat(flint_mat):
207210 i, j = index
208211 if i < 0 or i >= self .nrows() or j < 0 or j >= self .ncols():
209212 raise IndexError (" index %i ,%i exceeds matrix dimensions" % (i, j))
210- if any_as_nmod(& v, value, self .val.mod ):
213+ if self .ctx. any_as_nmod(& v, value):
211214 nmod_mat_set_entry(self .val, i, j, v)
212215 else :
213216 raise TypeError (" cannot set item of type %s " % type (value))
@@ -306,7 +309,7 @@ cdef class nmod_mat(flint_mat):
306309 sv = & (< nmod_mat> s).val[0 ]
307310 u = any_as_nmod_mat(t, sv.mod)
308311 if u is NotImplemented :
309- if any_as_nmod(& c, t, sv.mod ):
312+ if s.ctx. any_as_nmod(& c, t):
310313 return (< nmod_mat> s).__mul_nmod(c)
311314 return NotImplemented
312315 tv = & (< nmod_mat> u).val[0 ]
@@ -323,7 +326,7 @@ cdef class nmod_mat(flint_mat):
323326 cdef nmod_mat_struct * sv
324327 cdef mp_limb_t c
325328 sv = & (< nmod_mat> s).val[0 ]
326- if any_as_nmod(& c, t, sv.mod ):
329+ if s.ctx. any_as_nmod(& c, t):
327330 return (< nmod_mat> s).__mul_nmod(c)
328331 u = any_as_nmod_mat(t, sv.mod)
329332 if u is NotImplemented :
@@ -348,7 +351,7 @@ cdef class nmod_mat(flint_mat):
348351 @staticmethod
349352 def _div_ (nmod_mat s , t ):
350353 cdef mp_limb_t v
351- if not any_as_nmod(& v, t, s.val.mod ):
354+ if not s.ctx. any_as_nmod(& v, t):
352355 return NotImplemented
353356 t = nmod(v, s.val.mod.n)
354357 return s * (~ t)
0 commit comments