Skip to content

Commit 210fbd0

Browse files
committed
Add nmod_ctx.new_nmod function
1 parent bf8c48b commit 210fbd0

4 files changed

Lines changed: 44 additions & 51 deletions

File tree

src/flint/test/test_all.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1357,7 +1357,7 @@ def test_nmod():
13571357
assert str(G(3,5)) == "3"
13581358
assert G(3,5).repr() == "nmod(3, 5)"
13591359

1360-
G = flint.nmod_ctx.get_ctx(7)
1360+
G = flint.nmod_ctx.new(7)
13611361
assert G(0) == G(7) == G(-7)
13621362

13631363

src/flint/types/nmod.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ cdef class nmod_ctx:
1515
cdef _new_ctx(ulong mod)
1616

1717
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
18-
cdef nmod _new(self, mp_limb_t * val)
18+
cdef nmod new_nmod(self)
1919

2020

2121
cdef class nmod(flint_scalar):

src/flint/types/nmod.pyx

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,38 @@ cdef class nmod_ctx:
2727
Context object for creating :class:`~.nmod` initalised
2828
with modulus :math:`N`.
2929
30-
>>> nmod_ctx.get_ctx(17)
30+
>>> ctx = nmod_ctx.new(17)
31+
>>> ctx
3132
nmod_ctx(17)
33+
>>> ctx.modulus()
34+
17
35+
>>> e = ctx(10)
36+
>>> e
37+
10
38+
>>> e + 10
39+
3
3240
3341
"""
34-
3542
def __init__(self, *args, **kwargs):
36-
raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.get_ctx()")
43+
raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.new()")
44+
45+
@staticmethod
46+
def new(mod):
47+
"""Get an nmod context with modulus ``mod``."""
48+
return nmod_ctx._get_ctx(mod)
3749

3850
@staticmethod
3951
cdef nmod_ctx any_as_nmod_ctx(obj):
40-
"""Convert an int to an nmod_ctx."""
52+
"""Convert an ``nmod_ctx`` or ``int`` to an ``nmod_ctx``."""
4153
if typecheck(obj, nmod_ctx):
4254
return obj
4355
if typecheck(obj, int):
4456
return nmod_ctx._get_ctx(obj)
4557
return NotImplemented
4658

47-
@staticmethod
48-
def get_ctx(mod):
49-
"""Create a new nmod context."""
50-
return nmod_ctx._get_ctx(mod)
51-
5259
@staticmethod
5360
cdef _get_ctx(int mod):
54-
"""Create a new nmod context."""
61+
"""Retrieve an nmod context from the cache or create a new one."""
5562
ctx = _nmod_ctx_cache.get(mod)
5663
if ctx is None:
5764
_nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
@@ -91,13 +98,16 @@ cdef class nmod_ctx:
9198
return 1
9299
return 0
93100

94-
def __repr__(self):
95-
return f"nmod_ctx({self.modulus()})"
101+
@cython.final
102+
cdef nmod new_nmod(self):
103+
cdef nmod r = nmod.__new__(nmod)
104+
r.ctx = self
105+
return r
96106

97107
def modulus(self):
98108
"""Get the modulus of the context.
99109
100-
>>> ctx = nmod_ctx.get_ctx(17)
110+
>>> ctx = nmod_ctx.new(17)
101111
>>> ctx.modulus()
102112
17
103113
@@ -107,7 +117,7 @@ cdef class nmod_ctx:
107117
def is_prime(self):
108118
"""Check if the modulus is prime.
109119
110-
>>> ctx = nmod_ctx.get_ctx(17)
120+
>>> ctx = nmod_ctx.new(17)
111121
>>> ctx.is_prime()
112122
True
113123
@@ -117,7 +127,7 @@ cdef class nmod_ctx:
117127
def zero(self):
118128
"""Return the zero element of the context.
119129
120-
>>> ctx = nmod_ctx.get_ctx(17)
130+
>>> ctx = nmod_ctx.new(17)
121131
>>> ctx.zero()
122132
0
123133
@@ -127,7 +137,7 @@ cdef class nmod_ctx:
127137
def one(self):
128138
"""Return the one element of the context.
129139
130-
>>> ctx = nmod_ctx.get_ctx(17)
140+
>>> ctx = nmod_ctx.new(17)
131141
>>> ctx.one()
132142
1
133143
@@ -140,23 +150,17 @@ cdef class nmod_ctx:
140150
def __repr__(self):
141151
return f"nmod_ctx({self.modulus()})"
142152

143-
cdef nmod _new(self, mp_limb_t * val):
144-
cdef nmod r = nmod.__new__(nmod)
145-
r.val = val[0]
146-
r.ctx = self
147-
return r
148-
149153
def __call__(self, val):
150154
"""Create an nmod element from an integer.
151155
152-
>>> ctx = nmod_ctx.get_ctx(17)
156+
>>> ctx = nmod_ctx.new(17)
153157
>>> ctx(10)
154158
10
155159
156160
"""
157-
cdef mp_limb_t v
158-
v = val % self.mod.n
159-
return self._new(&v)
161+
r = self.new_nmod()
162+
self.any_as_nmod(&r.val, val)
163+
return r
160164

161165

162166
@cython.no_gc
@@ -218,8 +222,7 @@ cdef class nmod(flint_scalar):
218222
return self
219223

220224
def __neg__(self):
221-
cdef nmod r = nmod.__new__(nmod)
222-
r.ctx = self.ctx
225+
r = self.ctx.new_nmod()
223226
r.val = nmod_neg(self.val, self.ctx.mod)
224227
return r
225228

@@ -228,8 +231,7 @@ cdef class nmod(flint_scalar):
228231
cdef mp_limb_t val
229232
s2 = s
230233
if s2.ctx.any_as_nmod(&val, t):
231-
r = nmod.__new__(nmod)
232-
r.ctx = s2.ctx
234+
r = s2.ctx.new_nmod()
233235
r.val = nmod_add(val, s2.val, s2.ctx.mod)
234236
return r
235237
return NotImplemented
@@ -239,8 +241,7 @@ cdef class nmod(flint_scalar):
239241
cdef mp_limb_t val
240242
s2 = s
241243
if s2.ctx.any_as_nmod(&val, t):
242-
r = nmod.__new__(nmod)
243-
r.ctx = s2.ctx
244+
r = s2.ctx.new_nmod()
244245
r.val = nmod_add(s2.val, val, s2.ctx.mod)
245246
return r
246247
return NotImplemented
@@ -250,8 +251,7 @@ cdef class nmod(flint_scalar):
250251
cdef mp_limb_t val
251252
s2 = s
252253
if s2.ctx.any_as_nmod(&val, t):
253-
r = nmod.__new__(nmod)
254-
r.ctx = s2.ctx
254+
r = s2.ctx.new_nmod()
255255
r.val = nmod_sub(s2.val, val, s2.ctx.mod)
256256
return r
257257
return NotImplemented
@@ -261,8 +261,7 @@ cdef class nmod(flint_scalar):
261261
cdef mp_limb_t val
262262
s2 = s
263263
if s2.ctx.any_as_nmod(&val, t):
264-
r = nmod.__new__(nmod)
265-
r.ctx = s2.ctx
264+
r = s2.ctx.new_nmod()
266265
r.val = nmod_sub(val, s2.val, s2.ctx.mod)
267266
return r
268267
return NotImplemented
@@ -272,8 +271,7 @@ cdef class nmod(flint_scalar):
272271
cdef mp_limb_t val
273272
s2 = s
274273
if s2.ctx.any_as_nmod(&val, t):
275-
r = nmod.__new__(nmod)
276-
r.ctx = s2.ctx
274+
r = s2.ctx.new_nmod()
277275
r.val = nmod_mul(val, s2.val, s2.ctx.mod)
278276
return r
279277
return NotImplemented
@@ -283,8 +281,7 @@ cdef class nmod(flint_scalar):
283281
cdef mp_limb_t val
284282
s2 = s
285283
if s2.ctx.any_as_nmod(&val, t):
286-
r = nmod.__new__(nmod)
287-
r.ctx = s2.ctx
284+
r = s2.ctx.new_nmod()
288285
r.val = nmod_mul(s2.val, val, s2.ctx.mod)
289286
return r
290287
return NotImplemented
@@ -318,8 +315,7 @@ cdef class nmod(flint_scalar):
318315
if g != 1:
319316
raise ZeroDivisionError("%s is not invertible mod %s" % (tval, ctx.mod.n))
320317

321-
r = nmod.__new__(nmod)
322-
r.ctx = ctx
318+
r = ctx.new_nmod()
323319
r.val = nmod_mul(sval, <mp_limb_t>tinvval, ctx.mod)
324320
return r
325321

@@ -339,8 +335,7 @@ cdef class nmod(flint_scalar):
339335
g = n_gcdinv(&inv, sval, ctx.mod.n)
340336
if g != 1:
341337
raise ZeroDivisionError("%s is not invertible mod %s" % (sval, ctx.mod.n))
342-
r = nmod.__new__(nmod)
343-
r.ctx = ctx
338+
r = ctx.new_nmod()
344339
r.val = <mp_limb_t>inv
345340
return r
346341

@@ -371,8 +366,7 @@ cdef class nmod(flint_scalar):
371366
rval = <mp_limb_t>rinv
372367
e = -e
373368

374-
r = nmod.__new__(nmod)
375-
r.ctx = ctx
369+
r = ctx.new_nmod()
376370
r.val = nmod_pow_fmpz(rval, (<fmpz>e).val, ctx.mod)
377371
return r
378372

@@ -395,8 +389,7 @@ cdef class nmod(flint_scalar):
395389
"""
396390
cdef nmod r
397391
cdef mp_limb_t val
398-
r = nmod.__new__(nmod)
399-
r.ctx = self.ctx
392+
r = self.ctx.new_nmod()
400393

401394
if self.val == 0:
402395
return r

src/flint/types/nmod_poly.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ cdef class nmod_poly_ctx:
6060
cdef mp_limb_t m
6161
m = mod
6262
nmod_init(&self.mod, m)
63-
self.ctx = nmod_ctx.get_ctx(mod)
63+
self.ctx = nmod_ctx.new(mod)
6464
self._is_prime = n_is_prime(m)
6565

6666
def __repr__(self):

0 commit comments

Comments
 (0)