Skip to content

Commit 8358f13

Browse files
committed
Use any_as_nmod rather than ctx.any_as_nmod
1 parent 3730225 commit 8358f13

6 files changed

Lines changed: 96 additions & 109 deletions

File tree

src/flint/test/test_all.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,10 @@ def test_nmod():
13571357
assert str(G(3,5)) == "3"
13581358
assert G(3,5).repr() == "nmod(3, 5)"
13591359

1360+
G = flint.nmod_ctx.get_ctx(7)
1361+
assert G(0) == G(7) == G(-7)
1362+
1363+
13601364
def test_nmod_poly():
13611365
N = flint.nmod
13621366
P = flint.nmod_poly

src/flint/types/arb.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2239,7 +2239,7 @@ cdef class arb(flint_scalar):
22392239
>>> from flint import showgood
22402240
>>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5, abc=True), dps=25)
22412241
1.447530478120770807945697
2242-
>>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25)
2242+
>>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25) # doctest: +SKIP
22432243
Traceback (most recent call last):
22442244
...
22452245
ValueError: no convergence (maxprec=960, try higher maxprec)

src/flint/types/nmod.pxd

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from flint.flint_base.flint_base cimport flint_scalar
2-
from flint.flintlib.flint cimport mp_limb_t
2+
from flint.flintlib.flint cimport mp_limb_t, ulong
33
from flint.flintlib.nmod cimport nmod_t
44

5-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1
6-
cdef nmod_ctx any_as_nmod_ctx(obj)
5+
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx mod) except -1
6+
#cdef nmod_ctx any_as_nmod_ctx(obj)
77

88

99
cdef class nmod_ctx:
1010
cdef nmod_t mod
1111
cdef bint _is_prime
1212

13+
@staticmethod
14+
cdef nmod_ctx any_as_nmod_ctx(obj)
15+
@staticmethod
16+
cdef _get_ctx(int mod)
17+
@staticmethod
18+
cdef _new_ctx(ulong mod)
19+
1320
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
1421
cdef nmod _new(self, mp_limb_t * val)
1522

src/flint/types/nmod.pyx

Lines changed: 78 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -19,112 +19,63 @@ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod
1919
from flint.utils.flint_exceptions import DomainError
2020

2121

22-
#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
23-
# return mod.ctx.any_as_nmod(val, obj)
24-
25-
_nmod_ctx_cache = {}
26-
27-
28-
cdef nmod_ctx any_as_nmod_ctx(obj):
29-
"""Convert an int to an nmod_ctx."""
30-
if typecheck(obj, nmod_ctx):
31-
return obj
32-
if typecheck(obj, int):
33-
ctx = _nmod_ctx_cache.get(obj)
34-
if ctx is None:
35-
ctx = nmod_ctx(obj)
36-
_nmod_ctx_cache[obj] = ctx
37-
return ctx
38-
return NotImplemented
39-
40-
41-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
42-
"""Convert an object to an nmod element."""
43-
cdef int success
44-
cdef fmpz_t t
45-
if typecheck(obj, nmod):
46-
if (<nmod>obj).ctx.mod.n != mod.n:
47-
raise ValueError("cannot coerce integers mod n with different n")
48-
val[0] = (<nmod>obj).val
49-
return 1
50-
z = any_as_fmpz(obj)
51-
if z is not NotImplemented:
52-
val[0] = fmpz_fdiv_ui((<fmpz>z).val, mod.n)
53-
return 1
54-
q = any_as_fmpq(obj)
55-
if q is not NotImplemented:
56-
fmpz_init(t)
57-
fmpz_set_ui(t, mod.n)
58-
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
59-
val[0] = fmpz_get_ui(t)
60-
fmpz_clear(t)
61-
if not success:
62-
raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n))
63-
return 1
64-
return 0
22+
cdef dict _nmod_ctx_cache = {}
6523

6624

6725
cdef class nmod_ctx:
6826
"""
6927
Context object for creating :class:`~.nmod` initalised
7028
with modulus :math:`N`.
7129
72-
>>> nmod_ctx(17)
30+
>>> nmod_ctx.get_ctx(17)
7331
nmod_ctx(17)
7432
7533
"""
76-
def __init__(self, mod):
77-
cdef mp_limb_t m
78-
m = mod
79-
nmod_init(&self.mod, m)
80-
self._is_prime = n_is_prime(m)
8134

82-
def __eq__(self, other):
83-
# XXX: If we could ensure uniqueness of nmod_ctx for given modulus then
84-
# we would need to implement __eq__ and __hash__ at all...
85-
#
86-
# It isn't possible to ensure uniqueness in __new__ like it is in
87-
# Python because we can't return an existing object from __new__. What
88-
# we could do though is make it so that __init__ raises an error and
89-
# use a static method .new() to create new objects.
90-
if self is other:
91-
return True
92-
if not typecheck(other, nmod_ctx):
93-
return NotImplemented
94-
return self.mod.n == (<nmod_ctx>other).mod.n
35+
def __init__(self, *args, **kwargs):
36+
raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.get_ctx()")
37+
38+
@staticmethod
39+
cdef nmod_ctx any_as_nmod_ctx(obj):
40+
"""Convert an int to an nmod_ctx."""
41+
if typecheck(obj, nmod_ctx):
42+
return obj
43+
if typecheck(obj, int):
44+
return nmod_ctx._get_ctx(obj)
45+
return NotImplemented
46+
47+
@staticmethod
48+
def get_ctx(mod):
49+
"""Create a new nmod context."""
50+
return nmod_ctx._get_ctx(mod)
51+
52+
@staticmethod
53+
cdef _get_ctx(int mod):
54+
"""Create a new nmod context."""
55+
ctx = _nmod_ctx_cache.get(mod)
56+
if ctx is None:
57+
_nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
58+
return ctx
59+
60+
@staticmethod
61+
cdef _new_ctx(ulong mod):
62+
"""Create a new nmod context."""
63+
cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx)
64+
nmod_init(&ctx.mod, mod)
65+
ctx._is_prime = n_is_prime(mod)
66+
return ctx
9567

9668
def __repr__(self):
9769
return f"nmod_ctx({self.modulus()})"
9870

9971
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
10072
"""Convert an object to an nmod element."""
101-
cdef int success
102-
cdef fmpz_t t
103-
if typecheck(obj, nmod):
104-
if (<nmod>obj).ctx != self:
105-
raise ValueError("cannot coerce integers mod n with different n")
106-
val[0] = (<nmod>obj).val
107-
return 1
108-
z = any_as_fmpz(obj)
109-
if z is not NotImplemented:
110-
val[0] = fmpz_fdiv_ui((<fmpz>z).val, self.mod.n)
111-
return 1
112-
q = any_as_fmpq(obj)
113-
if q is not NotImplemented:
114-
fmpz_init(t)
115-
fmpz_set_ui(t, self.mod.n)
116-
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
117-
val[0] = fmpz_get_ui(t)
118-
fmpz_clear(t)
119-
if not success:
120-
raise ZeroDivisionError("%s does not exist mod %i!" % (q, self.mod.n))
121-
return 1
122-
return 0
73+
return any_as_nmod(val, obj, self)
12374

12475
def modulus(self):
12576
"""Get the modulus of the context.
12677
127-
>>> ctx = nmod_ctx(17)
78+
>>> ctx = nmod_ctx.get_ctx(17)
12879
>>> ctx.modulus()
12980
17
13081
@@ -134,7 +85,7 @@ cdef class nmod_ctx:
13485
def is_prime(self):
13586
"""Check if the modulus is prime.
13687
137-
>>> ctx = nmod_ctx(17)
88+
>>> ctx = nmod_ctx.get_ctx(17)
13889
>>> ctx.is_prime()
13990
True
14091
@@ -144,7 +95,7 @@ cdef class nmod_ctx:
14495
def zero(self):
14596
"""Return the zero element of the context.
14697
147-
>>> ctx = nmod_ctx(17)
98+
>>> ctx = nmod_ctx.get_ctx(17)
14899
>>> ctx.zero()
149100
0
150101
@@ -154,7 +105,7 @@ cdef class nmod_ctx:
154105
def one(self):
155106
"""Return the one element of the context.
156107
157-
>>> ctx = nmod_ctx(17)
108+
>>> ctx = nmod_ctx.get_ctx(17)
158109
>>> ctx.one()
159110
1
160111
@@ -185,16 +136,42 @@ cdef class nmod_ctx:
185136
def __call__(self, val):
186137
"""Create an nmod element from an integer.
187138
188-
>>> ctx = nmod_ctx(17)
139+
>>> ctx = nmod_ctx.get_ctx(17)
189140
>>> ctx(10)
190141
10
191142
192143
"""
193144
cdef mp_limb_t v
194-
v = val
145+
v = val % self.mod.n
195146
return self._new(&v)
196147

197148

149+
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except -1:
150+
"""Convert an object to an nmod element."""
151+
cdef int success
152+
cdef fmpz_t t
153+
if typecheck(obj, nmod):
154+
if (<nmod>obj).ctx.mod.n != ctx.mod.n:
155+
raise ValueError("cannot coerce integers mod n with different n")
156+
val[0] = (<nmod>obj).val
157+
return 1
158+
z = any_as_fmpz(obj)
159+
if z is not NotImplemented:
160+
val[0] = fmpz_fdiv_ui((<fmpz>z).val, ctx.mod.n)
161+
return 1
162+
q = any_as_fmpq(obj)
163+
if q is not NotImplemented:
164+
fmpz_init(t)
165+
fmpz_set_ui(t, ctx.mod.n)
166+
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
167+
val[0] = fmpz_get_ui(t)
168+
fmpz_clear(t)
169+
if not success:
170+
raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n))
171+
return 1
172+
return 0
173+
174+
198175
@cython.no_gc
199176
cdef class nmod(flint_scalar):
200177
"""
@@ -205,10 +182,10 @@ cdef class nmod(flint_scalar):
205182
206183
"""
207184
def __init__(self, val, mod):
208-
ctx = any_as_nmod_ctx(mod)
185+
ctx = nmod_ctx.any_as_nmod_ctx(mod)
209186
if ctx is NotImplemented:
210187
raise TypeError("Invalid context/modulus for nmod: %s" % mod)
211-
if not ctx.any_as_nmod(&self.val, val):
188+
if not any_as_nmod(&self.val, val, ctx):
212189
raise TypeError("cannot create nmod from object of type %s" % type(val))
213190
self.ctx = ctx
214191

@@ -263,7 +240,7 @@ cdef class nmod(flint_scalar):
263240
cdef nmod r, s2
264241
cdef mp_limb_t val
265242
s2 = s
266-
if s2.ctx.any_as_nmod(&val, t):
243+
if any_as_nmod(&val, t, s2.ctx):
267244
r = nmod.__new__(nmod)
268245
r.ctx = s2.ctx
269246
r.val = nmod_add(val, s2.val, s2.ctx.mod)
@@ -274,7 +251,7 @@ cdef class nmod(flint_scalar):
274251
cdef nmod r, s2
275252
cdef mp_limb_t val
276253
s2 = s
277-
if s2.ctx.any_as_nmod(&val, t):
254+
if any_as_nmod(&val, t, s2.ctx):
278255
r = nmod.__new__(nmod)
279256
r.ctx = s2.ctx
280257
r.val = nmod_add(s2.val, val, s2.ctx.mod)
@@ -285,7 +262,7 @@ cdef class nmod(flint_scalar):
285262
cdef nmod r, s2
286263
cdef mp_limb_t val
287264
s2 = s
288-
if s2.ctx.any_as_nmod(&val, t):
265+
if any_as_nmod(&val, t, s2.ctx):
289266
r = nmod.__new__(nmod)
290267
r.ctx = s2.ctx
291268
r.val = nmod_sub(s2.val, val, s2.ctx.mod)
@@ -296,7 +273,7 @@ cdef class nmod(flint_scalar):
296273
cdef nmod r
297274
cdef mp_limb_t val
298275
s2 = s
299-
if s2.ctx.any_as_nmod(&val, t):
276+
if any_as_nmod(&val, t, s2.ctx):
300277
r = nmod.__new__(nmod)
301278
r.ctx = s2.ctx
302279
r.val = nmod_sub(val, s2.val, s2.ctx.mod)
@@ -307,7 +284,7 @@ cdef class nmod(flint_scalar):
307284
cdef nmod r, s2
308285
cdef mp_limb_t val
309286
s2 = s
310-
if any_as_nmod(&val, t, s2.ctx.mod):
287+
if any_as_nmod(&val, t, s2.ctx):
311288
r = nmod.__new__(nmod)
312289
r.ctx = s2.ctx
313290
r.val = nmod_mul(val, s2.val, s2.ctx.mod)
@@ -318,7 +295,7 @@ cdef class nmod(flint_scalar):
318295
cdef nmod r, s2
319296
cdef mp_limb_t val
320297
s2 = s
321-
if s2.ctx.any_as_nmod(&val, t):
298+
if any_as_nmod(&val, t, s2.ctx):
322299
r = nmod.__new__(nmod)
323300
r.ctx = s2.ctx
324301
r.val = nmod_mul(s2.val, val, s2.ctx.mod)
@@ -328,21 +305,21 @@ cdef class nmod(flint_scalar):
328305
@staticmethod
329306
def _div_(s, t):
330307
cdef nmod r, s2, t2
331-
cdef mp_limb_t sval, tval, x
308+
cdef mp_limb_t sval, tval
332309
cdef nmod_ctx ctx
333310
cdef ulong tinvval
334311

335312
if typecheck(s, nmod):
336313
s2 = s
337314
ctx = s2.ctx
338315
sval = s2.val
339-
if not ctx.any_as_nmod(&tval, t):
316+
if not any_as_nmod(&tval, t, ctx):
340317
return NotImplemented
341318
else:
342319
t2 = t
343320
ctx = t2.ctx
344321
tval = t2.val
345-
if not ctx.any_as_nmod(&sval, s):
322+
if not any_as_nmod(&sval, s, ctx):
346323
return NotImplemented
347324

348325
if tval == 0:

src/flint/types/nmod_mat.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ from flint.flintlib.nmod_mat cimport (
4343
from flint.utils.typecheck cimport typecheck
4444
from flint.types.fmpz_mat cimport any_as_fmpz_mat
4545
from flint.types.fmpz_mat cimport fmpz_mat
46-
from flint.types.nmod cimport nmod, any_as_nmod_ctx
46+
from flint.types.nmod cimport nmod, nmod_ctx
4747
from flint.types.nmod_poly cimport nmod_poly, nmod_poly_new_init, any_as_nmod_poly_ctx
4848
from flint.pyflint cimport global_random_state
4949
from flint.flint_base.flint_context cimport thectx
@@ -121,7 +121,7 @@ cdef class nmod_mat(flint_mat):
121121
if mod == 0:
122122
raise ValueError("modulus must be nonzero")
123123

124-
ctx = any_as_nmod_ctx(mod)
124+
ctx = nmod_ctx.any_as_nmod_ctx(mod)
125125
self.ctx = ctx
126126

127127
if len(args) == 1:

src/flint/types/nmod_poly.pyx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ from flint.utils.typecheck cimport typecheck
44
from flint.types.fmpz cimport fmpz, any_as_fmpz
55
from flint.types.fmpz_poly cimport any_as_fmpz_poly
66
from flint.types.fmpz_poly cimport fmpz_poly
7-
from flint.types.nmod cimport any_as_nmod_ctx
87
from flint.types.nmod cimport nmod, nmod_ctx
98

109
from flint.flintlib.nmod_vec cimport *
@@ -61,7 +60,7 @@ cdef class nmod_poly_ctx:
6160
cdef mp_limb_t m
6261
m = mod
6362
nmod_init(&self.mod, m)
64-
self.ctx = nmod_ctx(mod)
63+
self.ctx = nmod_ctx.get_ctx(mod)
6564
self._is_prime = n_is_prime(m)
6665

6766
def __repr__(self):

0 commit comments

Comments
 (0)