@@ -27,31 +27,38 @@ cdef class nmod_ctx:
2727 Context object for creating :class:`~.nmod` initalised
2828 with modulus :math:`N`.
2929
30- >>> nmod_ctx.get_ctx(17)
30+ >>> ctx = nmod_ctx.new(17)
31+ >>> ctx
3132 nmod_ctx(17)
33+ >>> ctx.modulus()
34+ 17
35+ >>> e = ctx(10)
36+ >>> e
37+ 10
38+ >>> e + 10
39+ 3
3240
3341 """
34-
3542 def __init__ (self , *args , **kwargs ):
36- raise TypeError (" cannot create nmod_ctx directly: use nmod_ctx.get_ctx()" )
43+ raise TypeError (" cannot create nmod_ctx directly: use nmod_ctx.new()" )
44+
45+ @staticmethod
46+ def new (mod ):
47+ """ Get an nmod context with modulus ``mod``."""
48+ return nmod_ctx._get_ctx(mod)
3749
3850 @staticmethod
3951 cdef nmod_ctx any_as_nmod_ctx(obj):
40- """ Convert an int to an nmod_ctx."""
52+ """ Convert an ``nmod_ctx`` or `` int`` to an `` nmod_ctx`` ."""
4153 if typecheck(obj, nmod_ctx):
4254 return obj
4355 if typecheck(obj, int ):
4456 return nmod_ctx._get_ctx(obj)
4557 return NotImplemented
4658
47- @staticmethod
48- def get_ctx (mod ):
49- """ Create a new nmod context."""
50- return nmod_ctx._get_ctx(mod)
51-
5259 @staticmethod
5360 cdef _get_ctx(int mod):
54- """ Create a new nmod context."""
61+ """ Retrieve an nmod context from the cache or create a new one ."""
5562 ctx = _nmod_ctx_cache.get(mod)
5663 if ctx is None :
5764 _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
@@ -91,13 +98,16 @@ cdef class nmod_ctx:
9198 return 1
9299 return 0
93100
94- def __repr__ (self ):
95- return f" nmod_ctx({self.modulus()})"
101+ @cython.final
102+ cdef nmod new_nmod(self ):
103+ cdef nmod r = nmod.__new__ (nmod)
104+ r.ctx = self
105+ return r
96106
97107 def modulus (self ):
98108 """ Get the modulus of the context.
99109
100- >>> ctx = nmod_ctx.get_ctx (17)
110+ >>> ctx = nmod_ctx.new (17)
101111 >>> ctx.modulus()
102112 17
103113
@@ -107,7 +117,7 @@ cdef class nmod_ctx:
107117 def is_prime (self ):
108118 """ Check if the modulus is prime.
109119
110- >>> ctx = nmod_ctx.get_ctx (17)
120+ >>> ctx = nmod_ctx.new (17)
111121 >>> ctx.is_prime()
112122 True
113123
@@ -117,7 +127,7 @@ cdef class nmod_ctx:
117127 def zero (self ):
118128 """ Return the zero element of the context.
119129
120- >>> ctx = nmod_ctx.get_ctx (17)
130+ >>> ctx = nmod_ctx.new (17)
121131 >>> ctx.zero()
122132 0
123133
@@ -127,7 +137,7 @@ cdef class nmod_ctx:
127137 def one (self ):
128138 """ Return the one element of the context.
129139
130- >>> ctx = nmod_ctx.get_ctx (17)
140+ >>> ctx = nmod_ctx.new (17)
131141 >>> ctx.one()
132142 1
133143
@@ -140,23 +150,17 @@ cdef class nmod_ctx:
140150 def __repr__ (self ):
141151 return f" nmod_ctx({self.modulus()})"
142152
143- cdef nmod _new(self , mp_limb_t * val):
144- cdef nmod r = nmod.__new__ (nmod)
145- r.val = val[0 ]
146- r.ctx = self
147- return r
148-
149153 def __call__ (self , val ):
150154 """ Create an nmod element from an integer.
151155
152- >>> ctx = nmod_ctx.get_ctx (17)
156+ >>> ctx = nmod_ctx.new (17)
153157 >>> ctx(10)
154158 10
155159
156160 """
157- cdef mp_limb_t v
158- v = val % self .mod.n
159- return self ._new( & v)
161+ r = self .new_nmod()
162+ self .any_as_nmod( & r.val, val)
163+ return r
160164
161165
162166@cython.no_gc
@@ -218,8 +222,7 @@ cdef class nmod(flint_scalar):
218222 return self
219223
220224 def __neg__ (self ):
221- cdef nmod r = nmod.__new__ (nmod)
222- r.ctx = self .ctx
225+ r = self .ctx.new_nmod()
223226 r.val = nmod_neg(self .val, self .ctx.mod)
224227 return r
225228
@@ -228,8 +231,7 @@ cdef class nmod(flint_scalar):
228231 cdef mp_limb_t val
229232 s2 = s
230233 if s2.ctx.any_as_nmod(& val, t):
231- r = nmod.__new__ (nmod)
232- r.ctx = s2.ctx
234+ r = s2.ctx.new_nmod()
233235 r.val = nmod_add(val, s2.val, s2.ctx.mod)
234236 return r
235237 return NotImplemented
@@ -239,8 +241,7 @@ cdef class nmod(flint_scalar):
239241 cdef mp_limb_t val
240242 s2 = s
241243 if s2.ctx.any_as_nmod(& val, t):
242- r = nmod.__new__ (nmod)
243- r.ctx = s2.ctx
244+ r = s2.ctx.new_nmod()
244245 r.val = nmod_add(s2.val, val, s2.ctx.mod)
245246 return r
246247 return NotImplemented
@@ -250,8 +251,7 @@ cdef class nmod(flint_scalar):
250251 cdef mp_limb_t val
251252 s2 = s
252253 if s2.ctx.any_as_nmod(& val, t):
253- r = nmod.__new__ (nmod)
254- r.ctx = s2.ctx
254+ r = s2.ctx.new_nmod()
255255 r.val = nmod_sub(s2.val, val, s2.ctx.mod)
256256 return r
257257 return NotImplemented
@@ -261,8 +261,7 @@ cdef class nmod(flint_scalar):
261261 cdef mp_limb_t val
262262 s2 = s
263263 if s2.ctx.any_as_nmod(& val, t):
264- r = nmod.__new__ (nmod)
265- r.ctx = s2.ctx
264+ r = s2.ctx.new_nmod()
266265 r.val = nmod_sub(val, s2.val, s2.ctx.mod)
267266 return r
268267 return NotImplemented
@@ -272,8 +271,7 @@ cdef class nmod(flint_scalar):
272271 cdef mp_limb_t val
273272 s2 = s
274273 if s2.ctx.any_as_nmod(& val, t):
275- r = nmod.__new__ (nmod)
276- r.ctx = s2.ctx
274+ r = s2.ctx.new_nmod()
277275 r.val = nmod_mul(val, s2.val, s2.ctx.mod)
278276 return r
279277 return NotImplemented
@@ -283,8 +281,7 @@ cdef class nmod(flint_scalar):
283281 cdef mp_limb_t val
284282 s2 = s
285283 if s2.ctx.any_as_nmod(& val, t):
286- r = nmod.__new__ (nmod)
287- r.ctx = s2.ctx
284+ r = s2.ctx.new_nmod()
288285 r.val = nmod_mul(s2.val, val, s2.ctx.mod)
289286 return r
290287 return NotImplemented
@@ -318,8 +315,7 @@ cdef class nmod(flint_scalar):
318315 if g != 1 :
319316 raise ZeroDivisionError (" %s is not invertible mod %s " % (tval, ctx.mod.n))
320317
321- r = nmod.__new__ (nmod)
322- r.ctx = ctx
318+ r = ctx.new_nmod()
323319 r.val = nmod_mul(sval, < mp_limb_t> tinvval, ctx.mod)
324320 return r
325321
@@ -339,8 +335,7 @@ cdef class nmod(flint_scalar):
339335 g = n_gcdinv(& inv, sval, ctx.mod.n)
340336 if g != 1 :
341337 raise ZeroDivisionError (" %s is not invertible mod %s " % (sval, ctx.mod.n))
342- r = nmod.__new__ (nmod)
343- r.ctx = ctx
338+ r = ctx.new_nmod()
344339 r.val = < mp_limb_t> inv
345340 return r
346341
@@ -371,8 +366,7 @@ cdef class nmod(flint_scalar):
371366 rval = < mp_limb_t> rinv
372367 e = - e
373368
374- r = nmod.__new__ (nmod)
375- r.ctx = ctx
369+ r = ctx.new_nmod()
376370 r.val = nmod_pow_fmpz(rval, (< fmpz> e).val, ctx.mod)
377371 return r
378372
@@ -395,8 +389,7 @@ cdef class nmod(flint_scalar):
395389 """
396390 cdef nmod r
397391 cdef mp_limb_t val
398- r = nmod.__new__ (nmod)
399- r.ctx = self .ctx
392+ r = self .ctx.new_nmod()
400393
401394 if self .val == 0 :
402395 return r
0 commit comments