Skip to content

Commit 82af09e

Browse files
committed
Add nmod_ctx.new_nmod function
1 parent 154b035 commit 82af09e

4 files changed

Lines changed: 44 additions & 51 deletions

File tree

src/flint/test/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1357,7 +1357,7 @@ def test_nmod():
13571357
assert str(G(3,5)) == "3"
13581358
assert G(3,5).repr() == "nmod(3, 5)"
13591359

1360-
G = flint.nmod_ctx.get_ctx(7)
1360+
G = flint.nmod_ctx.new(7)
13611361
assert G(0) == G(7) == G(-7)
13621362

13631363

src/flint/types/nmod.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ cdef class nmod_ctx:
1515
cdef _new_ctx(ulong mod)
1616

1717
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
18-
cdef nmod _new(self, mp_limb_t * val)
18+
cdef nmod new_nmod(self)
1919

2020

2121
cdef class nmod(flint_scalar):

src/flint/types/nmod.pyx

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,38 @@ cdef class nmod_ctx:
2727
Context object for creating :class:`~.nmod` initalised
2828
with modulus :math:`N`.
2929
30-
>>> nmod_ctx.get_ctx(17)
30+
>>> ctx = nmod_ctx.new(17)
31+
>>> ctx
3132
nmod_ctx(17)
33+
>>> ctx.modulus()
34+
17
35+
>>> e = ctx(10)
36+
>>> e
37+
10
38+
>>> e + 10
39+
3
3240
3341
"""
34-
3542
def __init__(self, *args, **kwargs):
36-
raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.get_ctx()")
43+
raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.new()")
44+
45+
@staticmethod
46+
def new(mod):
47+
"""Get an nmod context with modulus ``mod``."""
48+
return nmod_ctx._get_ctx(mod)
3749

3850
@staticmethod
3951
cdef nmod_ctx any_as_nmod_ctx(obj):
40-
"""Convert an int to an nmod_ctx."""
52+
"""Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``."""
4153
if typecheck(obj, nmod_ctx):
4254
return obj
4355
if typecheck(obj, int):
4456
return nmod_ctx._get_ctx(obj)
4557
return NotImplemented
4658

47-
@staticmethod
48-
def get_ctx(mod):
49-
"""Create a new nmod context."""
50-
return nmod_ctx._get_ctx(mod)
51-
5259
@staticmethod
5360
cdef _get_ctx(int mod):
54-
"""Create a new nmod context."""
61+
"""Retrieve an nmod context from the cache or create a new one."""
5562
ctx = _nmod_ctx_cache.get(mod)
5663
if ctx is None:
5764
_nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
@@ -91,13 +98,16 @@ cdef class nmod_ctx:
9198
return 1
9299
return 0
93100

94-
def __repr__(self):
95-
return f"nmod_ctx({self.modulus()})"
101+
@cython.final
102+
cdef nmod new_nmod(self):
103+
cdef nmod r = nmod.__new__(nmod)
104+
r.ctx = self
105+
return r
96106

97107
def modulus(self):
98108
"""Get the modulus of the context.
99109
100-
>>> ctx = nmod_ctx.get_ctx(17)
110+
>>> ctx = nmod_ctx.new(17)
101111
>>> ctx.modulus()
102112
17
103113
@@ -107,7 +117,7 @@ cdef class nmod_ctx:
107117
def is_prime(self):
108118
"""Check if the modulus is prime.
109119
110-
>>> ctx = nmod_ctx.get_ctx(17)
120+
>>> ctx = nmod_ctx.new(17)
111121
>>> ctx.is_prime()
112122
True
113123
@@ -117,7 +127,7 @@ cdef class nmod_ctx:
117127
def zero(self):
118128
"""Return the zero element of the context.
119129
120-
>>> ctx = nmod_ctx.get_ctx(17)
130+
>>> ctx = nmod_ctx.new(17)
121131
>>> ctx.zero()
122132
0
123133
@@ -127,7 +137,7 @@ cdef class nmod_ctx:
127137
def one(self):
128138
"""Return the one element of the context.
129139
130-
>>> ctx = nmod_ctx.get_ctx(17)
140+
>>> ctx = nmod_ctx.new(17)
131141
>>> ctx.one()
132142
1
133143
@@ -140,23 +150,17 @@ cdef class nmod_ctx:
140150
def __repr__(self):
141151
return f"nmod_ctx({self.modulus()})"
142152

143-
cdef nmod _new(self, mp_limb_t * val):
144-
cdef nmod r = nmod.__new__(nmod)
145-
r.val = val[0]
146-
r.ctx = self
147-
return r
148-
149153
def __call__(self, val):
150154
"""Create an nmod element from an integer.
151155
152-
>>> ctx = nmod_ctx.get_ctx(17)
156+
>>> ctx = nmod_ctx.new(17)
153157
>>> ctx(10)
154158
10
155159
156160
"""
157-
cdef mp_limb_t v
158-
v = val % self.mod.n
159-
return self._new(&v)
161+
r = self.new_nmod()
162+
self.any_as_nmod(&r.val, val)
163+
return r
160164

161165

162166
@cython.no_gc
@@ -217,8 +221,7 @@ cdef class nmod(flint_scalar):
217221
return self
218222

219223
def __neg__(self):
220-
cdef nmod r = nmod.__new__(nmod)
221-
r.ctx = self.ctx
224+
r = self.ctx.new_nmod()
222225
r.val = nmod_neg(self.val, self.ctx.mod)
223226
return r
224227

@@ -227,8 +230,7 @@ cdef class nmod(flint_scalar):
227230
cdef mp_limb_t val
228231
s2 = s
229232
if s2.ctx.any_as_nmod(&val, t):
230-
r = nmod.__new__(nmod)
231-
r.ctx = s2.ctx
233+
r = s2.ctx.new_nmod()
232234
r.val = nmod_add(val, s2.val, s2.ctx.mod)
233235
return r
234236
return NotImplemented
@@ -238,8 +240,7 @@ cdef class nmod(flint_scalar):
238240
cdef mp_limb_t val
239241
s2 = s
240242
if s2.ctx.any_as_nmod(&val, t):
241-
r = nmod.__new__(nmod)
242-
r.ctx = s2.ctx
243+
r = s2.ctx.new_nmod()
243244
r.val = nmod_add(s2.val, val, s2.ctx.mod)
244245
return r
245246
return NotImplemented
@@ -249,8 +250,7 @@ cdef class nmod(flint_scalar):
249250
cdef mp_limb_t val
250251
s2 = s
251252
if s2.ctx.any_as_nmod(&val, t):
252-
r = nmod.__new__(nmod)
253-
r.ctx = s2.ctx
253+
r = s2.ctx.new_nmod()
254254
r.val = nmod_sub(s2.val, val, s2.ctx.mod)
255255
return r
256256
return NotImplemented
@@ -260,8 +260,7 @@ cdef class nmod(flint_scalar):
260260
cdef mp_limb_t val
261261
s2 = s
262262
if s2.ctx.any_as_nmod(&val, t):
263-
r = nmod.__new__(nmod)
264-
r.ctx = s2.ctx
263+
r = s2.ctx.new_nmod()
265264
r.val = nmod_sub(val, s2.val, s2.ctx.mod)
266265
return r
267266
return NotImplemented
@@ -271,8 +270,7 @@ cdef class nmod(flint_scalar):
271270
cdef mp_limb_t val
272271
s2 = s
273272
if s2.ctx.any_as_nmod(&val, t):
274-
r = nmod.__new__(nmod)
275-
r.ctx = s2.ctx
273+
r = s2.ctx.new_nmod()
276274
r.val = nmod_mul(val, s2.val, s2.ctx.mod)
277275
return r
278276
return NotImplemented
@@ -282,8 +280,7 @@ cdef class nmod(flint_scalar):
282280
cdef mp_limb_t val
283281
s2 = s
284282
if s2.ctx.any_as_nmod(&val, t):
285-
r = nmod.__new__(nmod)
286-
r.ctx = s2.ctx
283+
r = s2.ctx.new_nmod()
287284
r.val = nmod_mul(s2.val, val, s2.ctx.mod)
288285
return r
289286
return NotImplemented
@@ -317,8 +314,7 @@ cdef class nmod(flint_scalar):
317314
if g != 1:
318315
raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n))
319316

320-
r = nmod.__new__(nmod)
321-
r.ctx = ctx
317+
r = ctx.new_nmod()
322318
r.val = nmod_mul(sval, <mp_limb_t>tinvval, ctx.mod)
323319
return r
324320

@@ -338,8 +334,7 @@ cdef class nmod(flint_scalar):
338334
g = n_gcdinv(&inv, sval, ctx.mod.n)
339335
if g != 1:
340336
raise ZeroDivisionError("%s is not invertible mod %s" % (sval, ctx.mod.n))
341-
r = nmod.__new__(nmod)
342-
r.ctx = ctx
337+
r = ctx.new_nmod()
343338
r.val = <mp_limb_t>inv
344339
return r
345340

@@ -370,8 +365,7 @@ cdef class nmod(flint_scalar):
370365
rval = <mp_limb_t>rinv
371366
e = -e
372367

373-
r = nmod.__new__(nmod)
374-
r.ctx = ctx
368+
r = ctx.new_nmod()
375369
r.val = nmod_pow_fmpz(rval, (<fmpz>e).val, ctx.mod)
376370
return r
377371

@@ -394,8 +388,7 @@ cdef class nmod(flint_scalar):
394388
"""
395389
cdef nmod r
396390
cdef mp_limb_t val
397-
r = nmod.__new__(nmod)
398-
r.ctx = self.ctx
391+
r = self.ctx.new_nmod()
399392

400393
if self.val == 0:
401394
return r

src/flint/types/nmod_poly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ cdef class nmod_poly_ctx:
6060
cdef mp_limb_t m
6161
m = mod
6262
nmod_init(&self.mod, m)
63-
self.ctx = nmod_ctx.get_ctx(mod)
63+
self.ctx = nmod_ctx.new(mod)
6464
self._is_prime = n_is_prime(m)
6565

6666
def __repr__(self):

0 commit comments

Comments
 (0)