@@ -19,112 +19,63 @@ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod
1919from flint.utils.flint_exceptions import DomainError
2020
2121
22- # cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
23- # return mod.ctx.any_as_nmod(val, obj)
24-
25- _nmod_ctx_cache = {}
26-
27-
28- cdef nmod_ctx any_as_nmod_ctx(obj):
29- """ Convert an int to an nmod_ctx."""
30- if typecheck(obj, nmod_ctx):
31- return obj
32- if typecheck(obj, int ):
33- ctx = _nmod_ctx_cache.get(obj)
34- if ctx is None :
35- ctx = nmod_ctx(obj)
36- _nmod_ctx_cache[obj] = ctx
37- return ctx
38- return NotImplemented
39-
40-
41- cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except - 1 :
42- """ Convert an object to an nmod element."""
43- cdef int success
44- cdef fmpz_t t
45- if typecheck(obj, nmod):
46- if (< nmod> obj).ctx.mod.n != mod.n:
47- raise ValueError (" cannot coerce integers mod n with different n" )
48- val[0 ] = (< nmod> obj).val
49- return 1
50- z = any_as_fmpz(obj)
51- if z is not NotImplemented :
52- val[0 ] = fmpz_fdiv_ui((< fmpz> z).val, mod.n)
53- return 1
54- q = any_as_fmpq(obj)
55- if q is not NotImplemented :
56- fmpz_init(t)
57- fmpz_set_ui(t, mod.n)
58- success = fmpq_mod_fmpz(t, (< fmpq> q).val, t)
59- val[0 ] = fmpz_get_ui(t)
60- fmpz_clear(t)
61- if not success:
62- raise ZeroDivisionError (" %s does not exist mod %i !" % (q, mod.n))
63- return 1
64- return 0
22+ cdef dict _nmod_ctx_cache = {}
6523
6624
6725cdef class nmod_ctx:
6826 """
6927 Context object for creating :class:`~.nmod` initalised
7028 with modulus :math:`N`.
7129
72- >>> nmod_ctx(17)
30+ >>> nmod_ctx.get_ctx (17)
7331 nmod_ctx(17)
7432
7533 """
76- def __init__ (self , mod ):
77- cdef mp_limb_t m
78- m = mod
79- nmod_init(& self .mod, m)
80- self ._is_prime = n_is_prime(m)
8134
82- def __eq__ (self , other ):
83- # XXX: If we could ensure uniqueness of nmod_ctx for given modulus then
84- # we would need to implement __eq__ and __hash__ at all...
85- #
86- # It isn't possible to ensure uniqueness in __new__ like it is in
87- # Python because we can't return an existing object from __new__. What
88- # we could do though is make it so that __init__ raises an error and
89- # use a static method .new() to create new objects.
90- if self is other:
91- return True
92- if not typecheck(other, nmod_ctx):
93- return NotImplemented
94- return self .mod.n == (< nmod_ctx> other).mod.n
35+ def __init__ (self , *args , **kwargs ):
36+ raise TypeError (" cannot create nmod_ctx directly: use nmod_ctx.get_ctx()" )
37+
38+ @staticmethod
39+ cdef nmod_ctx any_as_nmod_ctx(obj):
40+ """ Convert an int to an nmod_ctx."""
41+ if typecheck(obj, nmod_ctx):
42+ return obj
43+ if typecheck(obj, int ):
44+ return nmod_ctx._get_ctx(obj)
45+ return NotImplemented
46+
47+ @staticmethod
48+ def get_ctx (mod ):
49+ """ Create a new nmod context."""
50+ return nmod_ctx._get_ctx(mod)
51+
52+ @staticmethod
53+ cdef _get_ctx(int mod):
54+ """ Create a new nmod context."""
55+ ctx = _nmod_ctx_cache.get(mod)
56+ if ctx is None :
57+ _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
58+ return ctx
59+
60+ @staticmethod
61+ cdef _new_ctx(ulong mod):
62+ """ Create a new nmod context."""
63+ cdef nmod_ctx ctx = nmod_ctx.__new__ (nmod_ctx)
64+ nmod_init(& ctx.mod, mod)
65+ ctx._is_prime = n_is_prime(mod)
66+ return ctx
9567
9668 def __repr__ (self ):
9769 return f" nmod_ctx({self.modulus()})"
9870
9971 cdef int any_as_nmod(self , mp_limb_t * val, obj) except - 1 :
10072 """ Convert an object to an nmod element."""
101- cdef int success
102- cdef fmpz_t t
103- if typecheck(obj, nmod):
104- if (< nmod> obj).ctx != self :
105- raise ValueError (" cannot coerce integers mod n with different n" )
106- val[0 ] = (< nmod> obj).val
107- return 1
108- z = any_as_fmpz(obj)
109- if z is not NotImplemented :
110- val[0 ] = fmpz_fdiv_ui((< fmpz> z).val, self .mod.n)
111- return 1
112- q = any_as_fmpq(obj)
113- if q is not NotImplemented :
114- fmpz_init(t)
115- fmpz_set_ui(t, self .mod.n)
116- success = fmpq_mod_fmpz(t, (< fmpq> q).val, t)
117- val[0 ] = fmpz_get_ui(t)
118- fmpz_clear(t)
119- if not success:
120- raise ZeroDivisionError (" %s does not exist mod %i !" % (q, self .mod.n))
121- return 1
122- return 0
73+ return any_as_nmod(val, obj, self )
12374
12475 def modulus (self ):
12576 """ Get the modulus of the context.
12677
127- >>> ctx = nmod_ctx(17)
78+ >>> ctx = nmod_ctx.get_ctx (17)
12879 >>> ctx.modulus()
12980 17
13081
@@ -134,7 +85,7 @@ cdef class nmod_ctx:
13485 def is_prime (self ):
13586 """ Check if the modulus is prime.
13687
137- >>> ctx = nmod_ctx(17)
88+ >>> ctx = nmod_ctx.get_ctx (17)
13889 >>> ctx.is_prime()
13990 True
14091
@@ -144,7 +95,7 @@ cdef class nmod_ctx:
14495 def zero (self ):
14596 """ Return the zero element of the context.
14697
147- >>> ctx = nmod_ctx(17)
98+ >>> ctx = nmod_ctx.get_ctx (17)
14899 >>> ctx.zero()
149100 0
150101
@@ -154,7 +105,7 @@ cdef class nmod_ctx:
154105 def one (self ):
155106 """ Return the one element of the context.
156107
157- >>> ctx = nmod_ctx(17)
108+ >>> ctx = nmod_ctx.get_ctx (17)
158109 >>> ctx.one()
159110 1
160111
@@ -185,16 +136,42 @@ cdef class nmod_ctx:
185136 def __call__ (self , val ):
186137 """ Create an nmod element from an integer.
187138
188- >>> ctx = nmod_ctx(17)
139+ >>> ctx = nmod_ctx.get_ctx (17)
189140 >>> ctx(10)
190141 10
191142
192143 """
193144 cdef mp_limb_t v
194- v = val
145+ v = val % self .mod.n
195146 return self ._new(& v)
196147
197148
149+ cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except - 1 :
150+ """ Convert an object to an nmod element."""
151+ cdef int success
152+ cdef fmpz_t t
153+ if typecheck(obj, nmod):
154+ if (< nmod> obj).ctx.mod.n != ctx.mod.n:
155+ raise ValueError (" cannot coerce integers mod n with different n" )
156+ val[0 ] = (< nmod> obj).val
157+ return 1
158+ z = any_as_fmpz(obj)
159+ if z is not NotImplemented :
160+ val[0 ] = fmpz_fdiv_ui((< fmpz> z).val, ctx.mod.n)
161+ return 1
162+ q = any_as_fmpq(obj)
163+ if q is not NotImplemented :
164+ fmpz_init(t)
165+ fmpz_set_ui(t, ctx.mod.n)
166+ success = fmpq_mod_fmpz(t, (< fmpq> q).val, t)
167+ val[0 ] = fmpz_get_ui(t)
168+ fmpz_clear(t)
169+ if not success:
170+ raise ZeroDivisionError (" %s does not exist mod %i !" % (q, ctx.mod.n))
171+ return 1
172+ return 0
173+
174+
198175@cython.no_gc
199176cdef class nmod(flint_scalar):
200177 """
@@ -205,10 +182,10 @@ cdef class nmod(flint_scalar):
205182
206183 """
207184 def __init__ (self , val , mod ):
208- ctx = any_as_nmod_ctx(mod)
185+ ctx = nmod_ctx. any_as_nmod_ctx(mod)
209186 if ctx is NotImplemented :
210187 raise TypeError (" Invalid context/modulus for nmod: %s " % mod)
211- if not ctx. any_as_nmod(& self .val, val):
188+ if not any_as_nmod(& self .val, val, ctx ):
212189 raise TypeError (" cannot create nmod from object of type %s " % type (val))
213190 self .ctx = ctx
214191
@@ -262,7 +239,7 @@ cdef class nmod(flint_scalar):
262239 cdef nmod r, s2
263240 cdef mp_limb_t val
264241 s2 = s
265- if s2.ctx. any_as_nmod(& val, t):
242+ if any_as_nmod(& val, t, s2.ctx ):
266243 r = nmod.__new__ (nmod)
267244 r.ctx = s2.ctx
268245 r.val = nmod_add(val, s2.val, s2.ctx.mod)
@@ -273,7 +250,7 @@ cdef class nmod(flint_scalar):
273250 cdef nmod r, s2
274251 cdef mp_limb_t val
275252 s2 = s
276- if s2.ctx. any_as_nmod(& val, t):
253+ if any_as_nmod(& val, t, s2.ctx ):
277254 r = nmod.__new__ (nmod)
278255 r.ctx = s2.ctx
279256 r.val = nmod_add(s2.val, val, s2.ctx.mod)
@@ -284,7 +261,7 @@ cdef class nmod(flint_scalar):
284261 cdef nmod r, s2
285262 cdef mp_limb_t val
286263 s2 = s
287- if s2.ctx. any_as_nmod(& val, t):
264+ if any_as_nmod(& val, t, s2.ctx ):
288265 r = nmod.__new__ (nmod)
289266 r.ctx = s2.ctx
290267 r.val = nmod_sub(s2.val, val, s2.ctx.mod)
@@ -295,7 +272,7 @@ cdef class nmod(flint_scalar):
295272 cdef nmod r
296273 cdef mp_limb_t val
297274 s2 = s
298- if s2.ctx. any_as_nmod(& val, t):
275+ if any_as_nmod(& val, t, s2.ctx ):
299276 r = nmod.__new__ (nmod)
300277 r.ctx = s2.ctx
301278 r.val = nmod_sub(val, s2.val, s2.ctx.mod)
@@ -306,7 +283,7 @@ cdef class nmod(flint_scalar):
306283 cdef nmod r, s2
307284 cdef mp_limb_t val
308285 s2 = s
309- if any_as_nmod(& val, t, s2.ctx.mod ):
286+ if any_as_nmod(& val, t, s2.ctx):
310287 r = nmod.__new__ (nmod)
311288 r.ctx = s2.ctx
312289 r.val = nmod_mul(val, s2.val, s2.ctx.mod)
@@ -317,7 +294,7 @@ cdef class nmod(flint_scalar):
317294 cdef nmod r, s2
318295 cdef mp_limb_t val
319296 s2 = s
320- if s2.ctx. any_as_nmod(& val, t):
297+ if any_as_nmod(& val, t, s2.ctx ):
321298 r = nmod.__new__ (nmod)
322299 r.ctx = s2.ctx
323300 r.val = nmod_mul(s2.val, val, s2.ctx.mod)
@@ -335,13 +312,13 @@ cdef class nmod(flint_scalar):
335312 s2 = s
336313 ctx = s2.ctx
337314 sval = s2.val
338- if not ctx. any_as_nmod(& tval, t):
315+ if not any_as_nmod(& tval, t, ctx ):
339316 return NotImplemented
340317 else :
341318 t2 = t
342319 ctx = t2.ctx
343320 tval = t2.val
344- if not ctx. any_as_nmod(& sval, s):
321+ if not any_as_nmod(& sval, s, ctx ):
345322 return NotImplemented
346323
347324 if tval == 0 :
0 commit comments