@@ -19,112 +19,63 @@ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod
1919from flint.utils.flint_exceptions import DomainError
2020
2121
22- # cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
23- # return mod.ctx.any_as_nmod(val, obj)
24-
25- _nmod_ctx_cache = {}
26-
27-
28- cdef nmod_ctx any_as_nmod_ctx(obj):
29- """ Convert an int to an nmod_ctx."""
30- if typecheck(obj, nmod_ctx):
31- return obj
32- if typecheck(obj, int ):
33- ctx = _nmod_ctx_cache.get(obj)
34- if ctx is None :
35- ctx = nmod_ctx(obj)
36- _nmod_ctx_cache[obj] = ctx
37- return ctx
38- return NotImplemented
39-
40-
41- cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except - 1 :
42- """ Convert an object to an nmod element."""
43- cdef int success
44- cdef fmpz_t t
45- if typecheck(obj, nmod):
46- if (< nmod> obj).ctx.mod.n != mod.n:
47- raise ValueError (" cannot coerce integers mod n with different n" )
48- val[0 ] = (< nmod> obj).val
49- return 1
50- z = any_as_fmpz(obj)
51- if z is not NotImplemented :
52- val[0 ] = fmpz_fdiv_ui((< fmpz> z).val, mod.n)
53- return 1
54- q = any_as_fmpq(obj)
55- if q is not NotImplemented :
56- fmpz_init(t)
57- fmpz_set_ui(t, mod.n)
58- success = fmpq_mod_fmpz(t, (< fmpq> q).val, t)
59- val[0 ] = fmpz_get_ui(t)
60- fmpz_clear(t)
61- if not success:
62- raise ZeroDivisionError (" %s does not exist mod %i !" % (q, mod.n))
63- return 1
64- return 0
22+ cdef dict _nmod_ctx_cache = {}
6523
6624
6725cdef class nmod_ctx:
6826 """
6927 Context object for creating :class:`~.nmod` initalised
7028 with modulus :math:`N`.
7129
72- >>> nmod_ctx(17)
30+ >>> nmod_ctx.get_ctx (17)
7331 nmod_ctx(17)
7432
7533 """
76- def __init__ (self , mod ):
77- cdef mp_limb_t m
78- m = mod
79- nmod_init(& self .mod, m)
80- self ._is_prime = n_is_prime(m)
8134
82- def __eq__ (self , other ):
83- # XXX: If we could ensure uniqueness of nmod_ctx for given modulus then
84- # we would need to implement __eq__ and __hash__ at all...
85- #
86- # It isn't possible to ensure uniqueness in __new__ like it is in
87- # Python because we can't return an existing object from __new__. What
88- # we could do though is make it so that __init__ raises an error and
89- # use a static method .new() to create new objects.
90- if self is other:
91- return True
92- if not typecheck(other, nmod_ctx):
93- return NotImplemented
94- return self .mod.n == (< nmod_ctx> other).mod.n
35+ def __init__ (self , *args , **kwargs ):
36+ raise TypeError (" cannot create nmod_ctx directly: use nmod_ctx.get_ctx()" )
37+
38+ @staticmethod
39+ cdef nmod_ctx any_as_nmod_ctx(obj):
40+ """ Convert an int to an nmod_ctx."""
41+ if typecheck(obj, nmod_ctx):
42+ return obj
43+ if typecheck(obj, int ):
44+ return nmod_ctx._get_ctx(obj)
45+ return NotImplemented
46+
47+ @staticmethod
48+ def get_ctx (mod ):
49+ """ Create a new nmod context."""
50+ return nmod_ctx._get_ctx(mod)
51+
52+ @staticmethod
53+ cdef _get_ctx(int mod):
54+ """ Create a new nmod context."""
55+ ctx = _nmod_ctx_cache.get(mod)
56+ if ctx is None :
57+ _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
58+ return ctx
59+
60+ @staticmethod
61+ cdef _new_ctx(ulong mod):
62+ """ Create a new nmod context."""
63+ cdef nmod_ctx ctx = nmod_ctx.__new__ (nmod_ctx)
64+ nmod_init(& ctx.mod, mod)
65+ ctx._is_prime = n_is_prime(mod)
66+ return ctx
9567
9668 def __repr__ (self ):
9769 return f" nmod_ctx({self.modulus()})"
9870
9971 cdef int any_as_nmod(self , mp_limb_t * val, obj) except - 1 :
10072 """ Convert an object to an nmod element."""
101- cdef int success
102- cdef fmpz_t t
103- if typecheck(obj, nmod):
104- if (< nmod> obj).ctx != self :
105- raise ValueError (" cannot coerce integers mod n with different n" )
106- val[0 ] = (< nmod> obj).val
107- return 1
108- z = any_as_fmpz(obj)
109- if z is not NotImplemented :
110- val[0 ] = fmpz_fdiv_ui((< fmpz> z).val, self .mod.n)
111- return 1
112- q = any_as_fmpq(obj)
113- if q is not NotImplemented :
114- fmpz_init(t)
115- fmpz_set_ui(t, self .mod.n)
116- success = fmpq_mod_fmpz(t, (< fmpq> q).val, t)
117- val[0 ] = fmpz_get_ui(t)
118- fmpz_clear(t)
119- if not success:
120- raise ZeroDivisionError (" %s does not exist mod %i !" % (q, self .mod.n))
121- return 1
122- return 0
73+ return any_as_nmod(val, obj, self )
12374
12475 def modulus (self ):
12576 """ Get the modulus of the context.
12677
127- >>> ctx = nmod_ctx(17)
78+ >>> ctx = nmod_ctx.get_ctx (17)
12879 >>> ctx.modulus()
12980 17
13081
@@ -134,7 +85,7 @@ cdef class nmod_ctx:
13485 def is_prime (self ):
13586 """ Check if the modulus is prime.
13687
137- >>> ctx = nmod_ctx(17)
88+ >>> ctx = nmod_ctx.get_ctx (17)
13889 >>> ctx.is_prime()
13990 True
14091
@@ -144,7 +95,7 @@ cdef class nmod_ctx:
14495 def zero (self ):
14596 """ Return the zero element of the context.
14697
147- >>> ctx = nmod_ctx(17)
98+ >>> ctx = nmod_ctx.get_ctx (17)
14899 >>> ctx.zero()
149100 0
150101
@@ -154,7 +105,7 @@ cdef class nmod_ctx:
154105 def one (self ):
155106 """ Return the one element of the context.
156107
157- >>> ctx = nmod_ctx(17)
108+ >>> ctx = nmod_ctx.get_ctx (17)
158109 >>> ctx.one()
159110 1
160111
@@ -185,16 +136,42 @@ cdef class nmod_ctx:
185136 def __call__ (self , val ):
186137 """ Create an nmod element from an integer.
187138
188- >>> ctx = nmod_ctx(17)
139+ >>> ctx = nmod_ctx.get_ctx (17)
189140 >>> ctx(10)
190141 10
191142
192143 """
193144 cdef mp_limb_t v
194- v = val
145+ v = val % self .mod.n
195146 return self ._new(& v)
196147
197148
149+ cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except - 1 :
150+ """ Convert an object to an nmod element."""
151+ cdef int success
152+ cdef fmpz_t t
153+ if typecheck(obj, nmod):
154+ if (< nmod> obj).ctx.mod.n != ctx.mod.n:
155+ raise ValueError (" cannot coerce integers mod n with different n" )
156+ val[0 ] = (< nmod> obj).val
157+ return 1
158+ z = any_as_fmpz(obj)
159+ if z is not NotImplemented :
160+ val[0 ] = fmpz_fdiv_ui((< fmpz> z).val, ctx.mod.n)
161+ return 1
162+ q = any_as_fmpq(obj)
163+ if q is not NotImplemented :
164+ fmpz_init(t)
165+ fmpz_set_ui(t, ctx.mod.n)
166+ success = fmpq_mod_fmpz(t, (< fmpq> q).val, t)
167+ val[0 ] = fmpz_get_ui(t)
168+ fmpz_clear(t)
169+ if not success:
170+ raise ZeroDivisionError (" %s does not exist mod %i !" % (q, ctx.mod.n))
171+ return 1
172+ return 0
173+
174+
198175@cython.no_gc
199176cdef class nmod(flint_scalar):
200177 """
@@ -205,10 +182,10 @@ cdef class nmod(flint_scalar):
205182
206183 """
207184 def __init__ (self , val , mod ):
208- ctx = any_as_nmod_ctx(mod)
185+ ctx = nmod_ctx. any_as_nmod_ctx(mod)
209186 if ctx is NotImplemented :
210187 raise TypeError (" Invalid context/modulus for nmod: %s " % mod)
211- if not ctx. any_as_nmod(& self .val, val):
188+ if not any_as_nmod(& self .val, val, ctx ):
212189 raise TypeError (" cannot create nmod from object of type %s " % type (val))
213190 self .ctx = ctx
214191
@@ -263,7 +240,7 @@ cdef class nmod(flint_scalar):
263240 cdef nmod r, s2
264241 cdef mp_limb_t val
265242 s2 = s
266- if s2.ctx. any_as_nmod(& val, t):
243+ if any_as_nmod(& val, t, s2.ctx ):
267244 r = nmod.__new__ (nmod)
268245 r.ctx = s2.ctx
269246 r.val = nmod_add(val, s2.val, s2.ctx.mod)
@@ -274,7 +251,7 @@ cdef class nmod(flint_scalar):
274251 cdef nmod r, s2
275252 cdef mp_limb_t val
276253 s2 = s
277- if s2.ctx. any_as_nmod(& val, t):
254+ if any_as_nmod(& val, t, s2.ctx ):
278255 r = nmod.__new__ (nmod)
279256 r.ctx = s2.ctx
280257 r.val = nmod_add(s2.val, val, s2.ctx.mod)
@@ -285,7 +262,7 @@ cdef class nmod(flint_scalar):
285262 cdef nmod r, s2
286263 cdef mp_limb_t val
287264 s2 = s
288- if s2.ctx. any_as_nmod(& val, t):
265+ if any_as_nmod(& val, t, s2.ctx ):
289266 r = nmod.__new__ (nmod)
290267 r.ctx = s2.ctx
291268 r.val = nmod_sub(s2.val, val, s2.ctx.mod)
@@ -296,7 +273,7 @@ cdef class nmod(flint_scalar):
296273 cdef nmod r
297274 cdef mp_limb_t val
298275 s2 = s
299- if s2.ctx. any_as_nmod(& val, t):
276+ if any_as_nmod(& val, t, s2.ctx ):
300277 r = nmod.__new__ (nmod)
301278 r.ctx = s2.ctx
302279 r.val = nmod_sub(val, s2.val, s2.ctx.mod)
@@ -307,7 +284,7 @@ cdef class nmod(flint_scalar):
307284 cdef nmod r, s2
308285 cdef mp_limb_t val
309286 s2 = s
310- if any_as_nmod(& val, t, s2.ctx.mod ):
287+ if any_as_nmod(& val, t, s2.ctx):
311288 r = nmod.__new__ (nmod)
312289 r.ctx = s2.ctx
313290 r.val = nmod_mul(val, s2.val, s2.ctx.mod)
@@ -318,7 +295,7 @@ cdef class nmod(flint_scalar):
318295 cdef nmod r, s2
319296 cdef mp_limb_t val
320297 s2 = s
321- if s2.ctx. any_as_nmod(& val, t):
298+ if any_as_nmod(& val, t, s2.ctx ):
322299 r = nmod.__new__ (nmod)
323300 r.ctx = s2.ctx
324301 r.val = nmod_mul(s2.val, val, s2.ctx.mod)
@@ -328,21 +305,21 @@ cdef class nmod(flint_scalar):
328305 @staticmethod
329306 def _div_ (s , t ):
330307 cdef nmod r, s2, t2
331- cdef mp_limb_t sval, tval, x
308+ cdef mp_limb_t sval, tval
332309 cdef nmod_ctx ctx
333310 cdef ulong tinvval
334311
335312 if typecheck(s, nmod):
336313 s2 = s
337314 ctx = s2.ctx
338315 sval = s2.val
339- if not ctx. any_as_nmod(& tval, t):
316+ if not any_as_nmod(& tval, t, ctx ):
340317 return NotImplemented
341318 else :
342319 t2 = t
343320 ctx = t2.ctx
344321 tval = t2.val
345- if not ctx. any_as_nmod(& sval, s):
322+ if not any_as_nmod(& sval, s, ctx ):
346323 return NotImplemented
347324
348325 if tval == 0 :
0 commit comments