@@ -27,31 +27,38 @@ cdef class nmod_ctx:
2727 Context object for creating :class:`~.nmod` initalised
2828 with modulus :math:`N`.
2929
30- >>> nmod_ctx.get_ctx(17)
30+ >>> ctx = nmod_ctx.new(17)
31+ >>> ctx
3132 nmod_ctx(17)
33+ >>> ctx.modulus()
34+ 17
35+ >>> e = ctx(10)
36+ >>> e
37+ 10
38+ >>> e + 10
39+ 3
3240
3341 """
34-
3542 def __init__ (self , *args , **kwargs ):
36- raise TypeError (" cannot create nmod_ctx directly: use nmod_ctx.get_ctx()" )
43+ raise TypeError (" cannot create nmod_ctx directly: use nmod_ctx.new()" )
44+
45+ @staticmethod
46+ def new (mod ):
47+ """ Get an nmod context with modulus ``mod``."""
48+ return nmod_ctx._get_ctx(mod)
3749
3850 @staticmethod
3951 cdef nmod_ctx any_as_nmod_ctx(obj):
40- """ Convert an int to an nmod_ctx."""
52+ """ Convert an ``nmod_ctx`` or `` int`` to an `` nmod_ctx`` ."""
4153 if typecheck(obj, nmod_ctx):
4254 return obj
4355 if typecheck(obj, int ):
4456 return nmod_ctx._get_ctx(obj)
4557 return NotImplemented
4658
47- @staticmethod
48- def get_ctx (mod ):
49- """ Create a new nmod context."""
50- return nmod_ctx._get_ctx(mod)
51-
5259 @staticmethod
5360 cdef _get_ctx(int mod):
54- """ Create a new nmod context."""
61+ """ Retrieve an nmod context from the cache or create a new one ."""
5562 ctx = _nmod_ctx_cache.get(mod)
5663 if ctx is None :
5764 _nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
@@ -91,13 +98,16 @@ cdef class nmod_ctx:
9198 return 1
9299 return 0
93100
94- def __repr__ (self ):
95- return f" nmod_ctx({self.modulus()})"
101+ @cython.final
102+ cdef nmod new_nmod(self ):
103+ cdef nmod r = nmod.__new__ (nmod)
104+ r.ctx = self
105+ return r
96106
97107 def modulus (self ):
98108 """ Get the modulus of the context.
99109
100- >>> ctx = nmod_ctx.get_ctx (17)
110+ >>> ctx = nmod_ctx.new (17)
101111 >>> ctx.modulus()
102112 17
103113
@@ -107,7 +117,7 @@ cdef class nmod_ctx:
107117 def is_prime (self ):
108118 """ Check if the modulus is prime.
109119
110- >>> ctx = nmod_ctx.get_ctx (17)
120+ >>> ctx = nmod_ctx.new (17)
111121 >>> ctx.is_prime()
112122 True
113123
@@ -117,7 +127,7 @@ cdef class nmod_ctx:
117127 def zero (self ):
118128 """ Return the zero element of the context.
119129
120- >>> ctx = nmod_ctx.get_ctx (17)
130+ >>> ctx = nmod_ctx.new (17)
121131 >>> ctx.zero()
122132 0
123133
@@ -127,7 +137,7 @@ cdef class nmod_ctx:
127137 def one (self ):
128138 """ Return the one element of the context.
129139
130- >>> ctx = nmod_ctx.get_ctx (17)
140+ >>> ctx = nmod_ctx.new (17)
131141 >>> ctx.one()
132142 1
133143
@@ -140,23 +150,17 @@ cdef class nmod_ctx:
140150 def __repr__ (self ):
141151 return f" nmod_ctx({self.modulus()})"
142152
143- cdef nmod _new(self , mp_limb_t * val):
144- cdef nmod r = nmod.__new__ (nmod)
145- r.val = val[0 ]
146- r.ctx = self
147- return r
148-
149153 def __call__ (self , val ):
150154 """ Create an nmod element from an integer.
151155
152- >>> ctx = nmod_ctx.get_ctx (17)
156+ >>> ctx = nmod_ctx.new (17)
153157 >>> ctx(10)
154158 10
155159
156160 """
157- cdef mp_limb_t v
158- v = val % self .mod.n
159- return self ._new( & v)
161+ r = self .new_nmod()
162+ self .any_as_nmod( & r.val, val)
163+ return r
160164
161165
162166@cython.no_gc
@@ -217,8 +221,7 @@ cdef class nmod(flint_scalar):
217221 return self
218222
219223 def __neg__ (self ):
220- cdef nmod r = nmod.__new__ (nmod)
221- r.ctx = self .ctx
224+ r = self .ctx.new_nmod()
222225 r.val = nmod_neg(self .val, self .ctx.mod)
223226 return r
224227
@@ -227,8 +230,7 @@ cdef class nmod(flint_scalar):
227230 cdef mp_limb_t val
228231 s2 = s
229232 if s2.ctx.any_as_nmod(& val, t):
230- r = nmod.__new__ (nmod)
231- r.ctx = s2.ctx
233+ r = s2.ctx.new_nmod()
232234 r.val = nmod_add(val, s2.val, s2.ctx.mod)
233235 return r
234236 return NotImplemented
@@ -238,8 +240,7 @@ cdef class nmod(flint_scalar):
238240 cdef mp_limb_t val
239241 s2 = s
240242 if s2.ctx.any_as_nmod(& val, t):
241- r = nmod.__new__ (nmod)
242- r.ctx = s2.ctx
243+ r = s2.ctx.new_nmod()
243244 r.val = nmod_add(s2.val, val, s2.ctx.mod)
244245 return r
245246 return NotImplemented
@@ -249,8 +250,7 @@ cdef class nmod(flint_scalar):
249250 cdef mp_limb_t val
250251 s2 = s
251252 if s2.ctx.any_as_nmod(& val, t):
252- r = nmod.__new__ (nmod)
253- r.ctx = s2.ctx
253+ r = s2.ctx.new_nmod()
254254 r.val = nmod_sub(s2.val, val, s2.ctx.mod)
255255 return r
256256 return NotImplemented
@@ -260,8 +260,7 @@ cdef class nmod(flint_scalar):
260260 cdef mp_limb_t val
261261 s2 = s
262262 if s2.ctx.any_as_nmod(& val, t):
263- r = nmod.__new__ (nmod)
264- r.ctx = s2.ctx
263+ r = s2.ctx.new_nmod()
265264 r.val = nmod_sub(val, s2.val, s2.ctx.mod)
266265 return r
267266 return NotImplemented
@@ -271,8 +270,7 @@ cdef class nmod(flint_scalar):
271270 cdef mp_limb_t val
272271 s2 = s
273272 if s2.ctx.any_as_nmod(& val, t):
274- r = nmod.__new__ (nmod)
275- r.ctx = s2.ctx
273+ r = s2.ctx.new_nmod()
276274 r.val = nmod_mul(val, s2.val, s2.ctx.mod)
277275 return r
278276 return NotImplemented
@@ -282,8 +280,7 @@ cdef class nmod(flint_scalar):
282280 cdef mp_limb_t val
283281 s2 = s
284282 if s2.ctx.any_as_nmod(& val, t):
285- r = nmod.__new__ (nmod)
286- r.ctx = s2.ctx
283+ r = s2.ctx.new_nmod()
287284 r.val = nmod_mul(s2.val, val, s2.ctx.mod)
288285 return r
289286 return NotImplemented
@@ -317,8 +314,7 @@ cdef class nmod(flint_scalar):
317314 if g != 1 :
318315 raise ZeroDivisionError (" %s is not invertible mod %s " % (tval, ctx.mod.n))
319316
320- r = nmod.__new__ (nmod)
321- r.ctx = ctx
317+ r = ctx.new_nmod()
322318 r.val = nmod_mul(sval, < mp_limb_t> tinvval, ctx.mod)
323319 return r
324320
@@ -338,8 +334,7 @@ cdef class nmod(flint_scalar):
338334 g = n_gcdinv(& inv, sval, ctx.mod.n)
339335 if g != 1 :
340336 raise ZeroDivisionError (" %s is not invertible mod %s " % (sval, ctx.mod.n))
341- r = nmod.__new__ (nmod)
342- r.ctx = ctx
337+ r = ctx.new_nmod()
343338 r.val = < mp_limb_t> inv
344339 return r
345340
@@ -370,8 +365,7 @@ cdef class nmod(flint_scalar):
370365 rval = < mp_limb_t> rinv
371366 e = - e
372367
373- r = nmod.__new__ (nmod)
374- r.ctx = ctx
368+ r = ctx.new_nmod()
375369 r.val = nmod_pow_fmpz(rval, (< fmpz> e).val, ctx.mod)
376370 return r
377371
@@ -394,8 +388,7 @@ cdef class nmod(flint_scalar):
394388 """
395389 cdef nmod r
396390 cdef mp_limb_t val
397- r = nmod.__new__ (nmod)
398- r.ctx = self .ctx
391+ r = self .ctx.new_nmod()
399392
400393 if self .val == 0 :
401394 return r
0 commit comments