Skip to content

Commit f47188c

Browse files
committed
Use any_as_nmod rather than ctx.any_as_nmod
1 parent 23389a5 commit f47188c

6 files changed

Lines changed: 95 additions & 108 deletions

File tree

src/flint/test/test_all.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,6 +1357,10 @@ def test_nmod():
13571357
assert str(G(3,5)) == "3"
13581358
assert G(3,5).repr() == "nmod(3, 5)"
13591359

1360+
G = flint.nmod_ctx.get_ctx(7)
1361+
assert G(0) == G(7) == G(-7)
1362+
1363+
13601364
def test_nmod_poly():
13611365
N = flint.nmod
13621366
P = flint.nmod_poly

src/flint/types/arb.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2259,7 +2259,7 @@ cdef class arb(flint_scalar):
22592259
>>> from flint import showgood
22602260
>>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5, abc=True), dps=25)
22612261
1.447530478120770807945697
2262-
>>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25)
2262+
>>> showgood(lambda: arb("9/10").hypgeom_2f1(arb(2).sqrt(), 0.5, arb(2).sqrt()+1.5), dps=25) # doctest: +SKIP
22632263
Traceback (most recent call last):
22642264
...
22652265
ValueError: no convergence (maxprec=960, try higher maxprec)

src/flint/types/nmod.pxd

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from flint.flint_base.flint_base cimport flint_scalar
2-
from flint.flintlib.flint cimport mp_limb_t
2+
from flint.flintlib.flint cimport mp_limb_t, ulong
33
from flint.flintlib.nmod cimport nmod_t
44

5-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1
6-
cdef nmod_ctx any_as_nmod_ctx(obj)
5+
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx mod) except -1
6+
#cdef nmod_ctx any_as_nmod_ctx(obj)
77

88

99
cdef class nmod_ctx:
1010
cdef nmod_t mod
1111
cdef bint _is_prime
1212

13+
@staticmethod
14+
cdef nmod_ctx any_as_nmod_ctx(obj)
15+
@staticmethod
16+
cdef _get_ctx(int mod)
17+
@staticmethod
18+
cdef _new_ctx(ulong mod)
19+
1320
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
1421
cdef nmod _new(self, mp_limb_t * val)
1522

src/flint/types/nmod.pyx

Lines changed: 77 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -19,112 +19,63 @@ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime, n_sqrtmod
1919
from flint.utils.flint_exceptions import DomainError
2020

2121

22-
#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
23-
# return mod.ctx.any_as_nmod(val, obj)
24-
25-
_nmod_ctx_cache = {}
26-
27-
28-
cdef nmod_ctx any_as_nmod_ctx(obj):
29-
"""Convert an int to an nmod_ctx."""
30-
if typecheck(obj, nmod_ctx):
31-
return obj
32-
if typecheck(obj, int):
33-
ctx = _nmod_ctx_cache.get(obj)
34-
if ctx is None:
35-
ctx = nmod_ctx(obj)
36-
_nmod_ctx_cache[obj] = ctx
37-
return ctx
38-
return NotImplemented
39-
40-
41-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
42-
"""Convert an object to an nmod element."""
43-
cdef int success
44-
cdef fmpz_t t
45-
if typecheck(obj, nmod):
46-
if (<nmod>obj).ctx.mod.n != mod.n:
47-
raise ValueError("cannot coerce integers mod n with different n")
48-
val[0] = (<nmod>obj).val
49-
return 1
50-
z = any_as_fmpz(obj)
51-
if z is not NotImplemented:
52-
val[0] = fmpz_fdiv_ui((<fmpz>z).val, mod.n)
53-
return 1
54-
q = any_as_fmpq(obj)
55-
if q is not NotImplemented:
56-
fmpz_init(t)
57-
fmpz_set_ui(t, mod.n)
58-
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
59-
val[0] = fmpz_get_ui(t)
60-
fmpz_clear(t)
61-
if not success:
62-
raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n))
63-
return 1
64-
return 0
22+
cdef dict _nmod_ctx_cache = {}
6523

6624

6725
cdef class nmod_ctx:
6826
"""
6927
Context object for creating :class:`~.nmod` initalised
7028
with modulus :math:`N`.
7129
72-
>>> nmod_ctx(17)
30+
>>> nmod_ctx.get_ctx(17)
7331
nmod_ctx(17)
7432
7533
"""
76-
def __init__(self, mod):
77-
cdef mp_limb_t m
78-
m = mod
79-
nmod_init(&self.mod, m)
80-
self._is_prime = n_is_prime(m)
8134

82-
def __eq__(self, other):
83-
# XXX: If we could ensure uniqueness of nmod_ctx for given modulus then
84-
# we would need to implement __eq__ and __hash__ at all...
85-
#
86-
# It isn't possible to ensure uniqueness in __new__ like it is in
87-
# Python because we can't return an existing object from __new__. What
88-
# we could do though is make it so that __init__ raises an error and
89-
# use a static method .new() to create new objects.
90-
if self is other:
91-
return True
92-
if not typecheck(other, nmod_ctx):
93-
return NotImplemented
94-
return self.mod.n == (<nmod_ctx>other).mod.n
35+
def __init__(self, *args, **kwargs):
36+
raise TypeError("cannot create nmod_ctx directly: use nmod_ctx.get_ctx()")
37+
38+
@staticmethod
39+
cdef nmod_ctx any_as_nmod_ctx(obj):
40+
"""Convert an int to an nmod_ctx."""
41+
if typecheck(obj, nmod_ctx):
42+
return obj
43+
if typecheck(obj, int):
44+
return nmod_ctx._get_ctx(obj)
45+
return NotImplemented
46+
47+
@staticmethod
48+
def get_ctx(mod):
49+
"""Create a new nmod context."""
50+
return nmod_ctx._get_ctx(mod)
51+
52+
@staticmethod
53+
cdef _get_ctx(int mod):
54+
"""Create a new nmod context."""
55+
ctx = _nmod_ctx_cache.get(mod)
56+
if ctx is None:
57+
_nmod_ctx_cache[mod] = ctx = nmod_ctx._new_ctx(mod)
58+
return ctx
59+
60+
@staticmethod
61+
cdef _new_ctx(ulong mod):
62+
"""Create a new nmod context."""
63+
cdef nmod_ctx ctx = nmod_ctx.__new__(nmod_ctx)
64+
nmod_init(&ctx.mod, mod)
65+
ctx._is_prime = n_is_prime(mod)
66+
return ctx
9567

9668
def __repr__(self):
9769
return f"nmod_ctx({self.modulus()})"
9870

9971
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
10072
"""Convert an object to an nmod element."""
101-
cdef int success
102-
cdef fmpz_t t
103-
if typecheck(obj, nmod):
104-
if (<nmod>obj).ctx != self:
105-
raise ValueError("cannot coerce integers mod n with different n")
106-
val[0] = (<nmod>obj).val
107-
return 1
108-
z = any_as_fmpz(obj)
109-
if z is not NotImplemented:
110-
val[0] = fmpz_fdiv_ui((<fmpz>z).val, self.mod.n)
111-
return 1
112-
q = any_as_fmpq(obj)
113-
if q is not NotImplemented:
114-
fmpz_init(t)
115-
fmpz_set_ui(t, self.mod.n)
116-
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
117-
val[0] = fmpz_get_ui(t)
118-
fmpz_clear(t)
119-
if not success:
120-
raise ZeroDivisionError("%s does not exist mod %i!" % (q, self.mod.n))
121-
return 1
122-
return 0
73+
return any_as_nmod(val, obj, self)
12374

12475
def modulus(self):
12576
"""Get the modulus of the context.
12677
127-
>>> ctx = nmod_ctx(17)
78+
>>> ctx = nmod_ctx.get_ctx(17)
12879
>>> ctx.modulus()
12980
17
13081
@@ -134,7 +85,7 @@ cdef class nmod_ctx:
13485
def is_prime(self):
13586
"""Check if the modulus is prime.
13687
137-
>>> ctx = nmod_ctx(17)
88+
>>> ctx = nmod_ctx.get_ctx(17)
13889
>>> ctx.is_prime()
13990
True
14091
@@ -144,7 +95,7 @@ cdef class nmod_ctx:
14495
def zero(self):
14596
"""Return the zero element of the context.
14697
147-
>>> ctx = nmod_ctx(17)
98+
>>> ctx = nmod_ctx.get_ctx(17)
14899
>>> ctx.zero()
149100
0
150101
@@ -154,7 +105,7 @@ cdef class nmod_ctx:
154105
def one(self):
155106
"""Return the one element of the context.
156107
157-
>>> ctx = nmod_ctx(17)
108+
>>> ctx = nmod_ctx.get_ctx(17)
158109
>>> ctx.one()
159110
1
160111
@@ -185,16 +136,42 @@ cdef class nmod_ctx:
185136
def __call__(self, val):
186137
"""Create an nmod element from an integer.
187138
188-
>>> ctx = nmod_ctx(17)
139+
>>> ctx = nmod_ctx.get_ctx(17)
189140
>>> ctx(10)
190141
10
191142
192143
"""
193144
cdef mp_limb_t v
194-
v = val
145+
v = val % self.mod.n
195146
return self._new(&v)
196147

197148

149+
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except -1:
150+
"""Convert an object to an nmod element."""
151+
cdef int success
152+
cdef fmpz_t t
153+
if typecheck(obj, nmod):
154+
if (<nmod>obj).ctx.mod.n != ctx.mod.n:
155+
raise ValueError("cannot coerce integers mod n with different n")
156+
val[0] = (<nmod>obj).val
157+
return 1
158+
z = any_as_fmpz(obj)
159+
if z is not NotImplemented:
160+
val[0] = fmpz_fdiv_ui((<fmpz>z).val, ctx.mod.n)
161+
return 1
162+
q = any_as_fmpq(obj)
163+
if q is not NotImplemented:
164+
fmpz_init(t)
165+
fmpz_set_ui(t, ctx.mod.n)
166+
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
167+
val[0] = fmpz_get_ui(t)
168+
fmpz_clear(t)
169+
if not success:
170+
raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n))
171+
return 1
172+
return 0
173+
174+
198175
@cython.no_gc
199176
cdef class nmod(flint_scalar):
200177
"""
@@ -205,10 +182,10 @@ cdef class nmod(flint_scalar):
205182
206183
"""
207184
def __init__(self, val, mod):
208-
ctx = any_as_nmod_ctx(mod)
185+
ctx = nmod_ctx.any_as_nmod_ctx(mod)
209186
if ctx is NotImplemented:
210187
raise TypeError("Invalid context/modulus for nmod: %s" % mod)
211-
if not ctx.any_as_nmod(&self.val, val):
188+
if not any_as_nmod(&self.val, val, ctx):
212189
raise TypeError("cannot create nmod from object of type %s" % type(val))
213190
self.ctx = ctx
214191

@@ -262,7 +239,7 @@ cdef class nmod(flint_scalar):
262239
cdef nmod r, s2
263240
cdef mp_limb_t val
264241
s2 = s
265-
if s2.ctx.any_as_nmod(&val, t):
242+
if any_as_nmod(&val, t, s2.ctx):
266243
r = nmod.__new__(nmod)
267244
r.ctx = s2.ctx
268245
r.val = nmod_add(val, s2.val, s2.ctx.mod)
@@ -273,7 +250,7 @@ cdef class nmod(flint_scalar):
273250
cdef nmod r, s2
274251
cdef mp_limb_t val
275252
s2 = s
276-
if s2.ctx.any_as_nmod(&val, t):
253+
if any_as_nmod(&val, t, s2.ctx):
277254
r = nmod.__new__(nmod)
278255
r.ctx = s2.ctx
279256
r.val = nmod_add(s2.val, val, s2.ctx.mod)
@@ -284,7 +261,7 @@ cdef class nmod(flint_scalar):
284261
cdef nmod r, s2
285262
cdef mp_limb_t val
286263
s2 = s
287-
if s2.ctx.any_as_nmod(&val, t):
264+
if any_as_nmod(&val, t, s2.ctx):
288265
r = nmod.__new__(nmod)
289266
r.ctx = s2.ctx
290267
r.val = nmod_sub(s2.val, val, s2.ctx.mod)
@@ -295,7 +272,7 @@ cdef class nmod(flint_scalar):
295272
cdef nmod r
296273
cdef mp_limb_t val
297274
s2 = s
298-
if s2.ctx.any_as_nmod(&val, t):
275+
if any_as_nmod(&val, t, s2.ctx):
299276
r = nmod.__new__(nmod)
300277
r.ctx = s2.ctx
301278
r.val = nmod_sub(val, s2.val, s2.ctx.mod)
@@ -306,7 +283,7 @@ cdef class nmod(flint_scalar):
306283
cdef nmod r, s2
307284
cdef mp_limb_t val
308285
s2 = s
309-
if any_as_nmod(&val, t, s2.ctx.mod):
286+
if any_as_nmod(&val, t, s2.ctx):
310287
r = nmod.__new__(nmod)
311288
r.ctx = s2.ctx
312289
r.val = nmod_mul(val, s2.val, s2.ctx.mod)
@@ -317,7 +294,7 @@ cdef class nmod(flint_scalar):
317294
cdef nmod r, s2
318295
cdef mp_limb_t val
319296
s2 = s
320-
if s2.ctx.any_as_nmod(&val, t):
297+
if any_as_nmod(&val, t, s2.ctx):
321298
r = nmod.__new__(nmod)
322299
r.ctx = s2.ctx
323300
r.val = nmod_mul(s2.val, val, s2.ctx.mod)
@@ -335,13 +312,13 @@ cdef class nmod(flint_scalar):
335312
s2 = s
336313
ctx = s2.ctx
337314
sval = s2.val
338-
if not ctx.any_as_nmod(&tval, t):
315+
if not any_as_nmod(&tval, t, ctx):
339316
return NotImplemented
340317
else:
341318
t2 = t
342319
ctx = t2.ctx
343320
tval = t2.val
344-
if not ctx.any_as_nmod(&sval, s):
321+
if not any_as_nmod(&sval, s, ctx):
345322
return NotImplemented
346323

347324
if tval == 0:

src/flint/types/nmod_mat.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ from flint.flintlib.nmod_mat cimport (
4343
from flint.utils.typecheck cimport typecheck
4444
from flint.types.fmpz_mat cimport any_as_fmpz_mat
4545
from flint.types.fmpz_mat cimport fmpz_mat
46-
from flint.types.nmod cimport nmod, any_as_nmod_ctx
46+
from flint.types.nmod cimport nmod, nmod_ctx
4747
from flint.types.nmod_poly cimport nmod_poly, nmod_poly_new_init, any_as_nmod_poly_ctx
4848
from flint.pyflint cimport global_random_state
4949
from flint.flint_base.flint_context cimport thectx
@@ -120,7 +120,7 @@ cdef class nmod_mat(flint_mat):
120120
if mod == 0:
121121
raise ValueError("modulus must be nonzero")
122122

123-
ctx = any_as_nmod_ctx(mod)
123+
ctx = nmod_ctx.any_as_nmod_ctx(mod)
124124
self.ctx = ctx
125125

126126
if len(args) == 1:

src/flint/types/nmod_poly.pyx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ from flint.utils.typecheck cimport typecheck
44
from flint.types.fmpz cimport fmpz, any_as_fmpz
55
from flint.types.fmpz_poly cimport any_as_fmpz_poly
66
from flint.types.fmpz_poly cimport fmpz_poly
7-
from flint.types.nmod cimport any_as_nmod_ctx
87
from flint.types.nmod cimport nmod, nmod_ctx
98

109
from flint.flintlib.nmod_vec cimport *
@@ -61,7 +60,7 @@ cdef class nmod_poly_ctx:
6160
cdef mp_limb_t m
6261
m = mod
6362
nmod_init(&self.mod, m)
64-
self.ctx = nmod_ctx(mod)
63+
self.ctx = nmod_ctx.get_ctx(mod)
6564
self._is_prime = n_is_prime(m)
6665

6766
def __repr__(self):

0 commit comments

Comments
 (0)