@@ -43,8 +43,7 @@ from flint.flintlib.nmod_mat cimport (
4343from flint.utils.typecheck cimport typecheck
4444from flint.types.fmpz_mat cimport any_as_fmpz_mat
4545from flint.types.fmpz_mat cimport fmpz_mat
46- from flint.types.nmod cimport nmod
47- from flint.types.nmod cimport any_as_nmod
46+ from flint.types.nmod cimport nmod, any_as_nmod_ctx
4847from flint.types.nmod_poly cimport nmod_poly
4948from flint.pyflint cimport global_random_state
5049from flint.flint_base.flint_context cimport thectx
@@ -87,20 +86,24 @@ cdef class nmod_mat(flint_mat):
8786 def __init__ (self , *args ):
8887 cdef long m, n, i, j
8988 cdef mp_limb_t mod
89+ cdef nmod_ctx ctx
9090 if len (args) == 1 :
9191 val = args[0 ]
9292 if typecheck(val, nmod_mat):
9393 nmod_mat_init_set(self .val, (< nmod_mat> val).val)
94+ self .ctx = (< nmod_mat> val).ctx
9495 return
9596 mod = args[- 1 ]
9697 args = args[:- 1 ]
9798 if mod == 0 :
9899 raise ValueError (" modulus must be nonzero" )
100+ ctx = any_as_nmod_ctx(mod)
101+ self .ctx = ctx
99102 if len (args) == 1 :
100103 val = args[0 ]
101104 if typecheck(val, fmpz_mat):
102105 nmod_mat_init(self .val, fmpz_mat_nrows((< fmpz_mat> val).val),
103- fmpz_mat_ncols((< fmpz_mat> val).val), mod)
106+ fmpz_mat_ncols((< fmpz_mat> val).val), ctx. mod.n )
104107 fmpz_mat_get_nmod_mat(self .val, (< fmpz_mat> val).val)
105108 elif isinstance (val, (list , tuple )):
106109 m = len (val)
@@ -112,11 +115,11 @@ cdef class nmod_mat(flint_mat):
112115 for i from 1 <= i < m:
113116 if len (val[i]) != n:
114117 raise ValueError (" input rows have different lengths" )
115- nmod_mat_init(self .val, m, n, mod)
118+ nmod_mat_init(self .val, m, n, ctx. mod.n )
116119 for i from 0 <= i < m:
117120 row = val[i]
118121 for j from 0 <= j < n:
119- x = nmod(row[j], mod)
122+ x = nmod(row[j], ctx) # XXX: slow
120123 self .val.rows[i][j] = (< nmod> x).val
121124 else :
122125 raise TypeError (" cannot create nmod_mat from input of type %s " % type (val))
@@ -131,7 +134,7 @@ cdef class nmod_mat(flint_mat):
131134 raise ValueError (" list of entries has the wrong length" )
132135 for i from 0 <= i < m:
133136 for j from 0 <= j < n:
134- x = nmod(entries[i* n + j], mod ) # XXX: slow
137+ x = nmod(entries[i* n + j], ctx ) # XXX: slow
135138 self .val.rows[i][j] = (< nmod> x).val
136139 else :
137140 raise TypeError (" nmod_mat: expected 1-3 arguments plus modulus" )
@@ -208,7 +211,7 @@ cdef class nmod_mat(flint_mat):
208211 i, j = index
209212 if i < 0 or i >= self .nrows() or j < 0 or j >= self .ncols():
210213 raise IndexError (" index %i ,%i exceeds matrix dimensions" % (i, j))
211- if any_as_nmod(& v, value, self .val.mod ):
214+ if self .ctx. any_as_nmod(& v, value):
212215 nmod_mat_set_entry(self .val, i, j, v)
213216 else :
214217 raise TypeError (" cannot set item of type %s " % type (value))
@@ -307,7 +310,7 @@ cdef class nmod_mat(flint_mat):
307310 sv = & (< nmod_mat> s).val[0 ]
308311 u = any_as_nmod_mat(t, sv.mod)
309312 if u is NotImplemented :
310- if any_as_nmod(& c, t, sv.mod ):
313+ if s.ctx. any_as_nmod(& c, t):
311314 return (< nmod_mat> s).__mul_nmod(c)
312315 return NotImplemented
313316 tv = & (< nmod_mat> u).val[0 ]
@@ -324,7 +327,7 @@ cdef class nmod_mat(flint_mat):
324327 cdef nmod_mat_struct * sv
325328 cdef mp_limb_t c
326329 sv = & (< nmod_mat> s).val[0 ]
327- if any_as_nmod(& c, t, sv.mod ):
330+ if s.ctx. any_as_nmod(& c, t):
328331 return (< nmod_mat> s).__mul_nmod(c)
329332 u = any_as_nmod_mat(t, sv.mod)
330333 if u is NotImplemented :
@@ -349,7 +352,7 @@ cdef class nmod_mat(flint_mat):
349352 @staticmethod
350353 def _div_ (nmod_mat s , t ):
351354 cdef mp_limb_t v
352- if not any_as_nmod(& v, t, s.val.mod ):
355+ if not s.ctx. any_as_nmod(& v, t):
353356 return NotImplemented
354357 t = nmod(v, s.val.mod.n)
355358 return s * (~ t)
0 commit comments