Skip to content

Commit ba9c1fd

Browse files
committed
Inline ctx.any_as_nmod for faster nmod.__mul__
1 parent a88f881 commit ba9c1fd

2 files changed

Lines changed: 28 additions & 3 deletions

File tree

src/flint/types/nmod.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ from flint.flint_base.flint_base cimport flint_scalar
22
from flint.flintlib.flint cimport mp_limb_t
33
from flint.flintlib.nmod cimport nmod_t
44

5-
#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1
5+
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1
66
cdef nmod_ctx any_as_nmod_ctx(obj)
77

88
cdef class nmod_ctx:

src/flint/types/nmod.pyx

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ from flint.utils.flint_exceptions import DomainError
2020
#cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
2121
# return mod.ctx.any_as_nmod(val, obj)
2222

23-
2423
_nmod_ctx_cache = {}
2524

2625

@@ -37,6 +36,32 @@ cdef nmod_ctx any_as_nmod_ctx(obj):
3736
return NotImplemented
3837

3938

39+
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
40+
"""Convert an object to an nmod element."""
41+
cdef int success
42+
cdef fmpz_t t
43+
if typecheck(obj, nmod):
44+
if (<nmod>obj).ctx.mod.n != mod.n:
45+
raise ValueError("cannot coerce integers mod n with different n")
46+
val[0] = (<nmod>obj).val
47+
return 1
48+
z = any_as_fmpz(obj)
49+
if z is not NotImplemented:
50+
val[0] = fmpz_fdiv_ui((<fmpz>z).val, mod.n)
51+
return 1
52+
q = any_as_fmpq(obj)
53+
if q is not NotImplemented:
54+
fmpz_init(t)
55+
fmpz_set_ui(t, mod.n)
56+
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
57+
val[0] = fmpz_get_ui(t)
58+
fmpz_clear(t)
59+
if not success:
60+
raise ZeroDivisionError("%s does not exist mod %i!" % (q, mod.n))
61+
return 1
62+
return 0
63+
64+
4065
cdef class nmod_ctx:
4166
"""
4267
Context object for creating :class:`~.nmod` initalised
@@ -278,7 +303,7 @@ cdef class nmod(flint_scalar):
278303
cdef nmod r, s2
279304
cdef mp_limb_t val
280305
s2 = s
281-
if s2.ctx.any_as_nmod(&val, t):
306+
if any_as_nmod(&val, t, s2.ctx.mod):
282307
r = nmod.__new__(nmod)
283308
r.ctx = s2.ctx
284309
r.val = nmod_mul(val, s2.val, s2.ctx.mod)

0 commit comments

Comments
 (0)