Skip to content

Commit bf8c48b

Browse files
committed
perf: use @cython.final for nmod_ctx.any_as_nmod
1 parent 8358f13 commit bf8c48b

2 files changed

Lines changed: 35 additions & 51 deletions

File tree

src/flint/types/nmod.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ from flint.flint_base.flint_base cimport flint_scalar
22
from flint.flintlib.flint cimport mp_limb_t, ulong
33
from flint.flintlib.nmod cimport nmod_t
44

5-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx mod) except -1
6-
#cdef nmod_ctx any_as_nmod_ctx(obj)
7-
85

96
cdef class nmod_ctx:
107
cdef nmod_t mod

src/flint/types/nmod.pyx

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,35 @@ cdef class nmod_ctx:
6565
ctx._is_prime = n_is_prime(mod)
6666
return ctx
6767

68+
@cython.final
69+
cdef int any_as_nmod(nmod_ctx ctx, mp_limb_t * val, obj) except -1:
70+
"""Convert an object to an nmod element."""
71+
cdef int success
72+
cdef fmpz_t t
73+
if typecheck(obj, nmod):
74+
if (<nmod>obj).ctx.mod.n != ctx.mod.n:
75+
raise ValueError("cannot coerce integers mod n with different n")
76+
val[0] = (<nmod>obj).val
77+
return 1
78+
z = any_as_fmpz(obj)
79+
if z is not NotImplemented:
80+
val[0] = fmpz_fdiv_ui((<fmpz>z).val, ctx.mod.n)
81+
return 1
82+
q = any_as_fmpq(obj)
83+
if q is not NotImplemented:
84+
fmpz_init(t)
85+
fmpz_set_ui(t, ctx.mod.n)
86+
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
87+
val[0] = fmpz_get_ui(t)
88+
fmpz_clear(t)
89+
if not success:
90+
raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n))
91+
return 1
92+
return 0
93+
6894
def __repr__(self):
6995
return f"nmod_ctx({self.modulus()})"
7096

71-
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
72-
"""Convert an object to an nmod element."""
73-
return any_as_nmod(val, obj, self)
74-
7597
def modulus(self):
7698
"""Get the modulus of the context.
7799
@@ -112,15 +134,6 @@ cdef class nmod_ctx:
112134
"""
113135
return self(1)
114136

115-
def __hash__(self):
116-
return hash(self.mod)
117-
118-
def __eq__(self, other):
119-
if typecheck(other, nmod_ctx):
120-
return self.mod.n == (<nmod_ctx>other).mod.n
121-
else:
122-
return NotImplemented
123-
124137
def __str__(self):
125138
return f"Context for nmod with modulus: {self.modulus()}"
126139

@@ -146,32 +159,6 @@ cdef class nmod_ctx:
146159
return self._new(&v)
147160

148161

149-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except -1:
150-
"""Convert an object to an nmod element."""
151-
cdef int success
152-
cdef fmpz_t t
153-
if typecheck(obj, nmod):
154-
if (<nmod>obj).ctx.mod.n != ctx.mod.n:
155-
raise ValueError("cannot coerce integers mod n with different n")
156-
val[0] = (<nmod>obj).val
157-
return 1
158-
z = any_as_fmpz(obj)
159-
if z is not NotImplemented:
160-
val[0] = fmpz_fdiv_ui((<fmpz>z).val, ctx.mod.n)
161-
return 1
162-
q = any_as_fmpq(obj)
163-
if q is not NotImplemented:
164-
fmpz_init(t)
165-
fmpz_set_ui(t, ctx.mod.n)
166-
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
167-
val[0] = fmpz_get_ui(t)
168-
fmpz_clear(t)
169-
if not success:
170-
raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n))
171-
return 1
172-
return 0
173-
174-
175162
@cython.no_gc
176163
cdef class nmod(flint_scalar):
177164
"""
@@ -185,7 +172,7 @@ cdef class nmod(flint_scalar):
185172
ctx = nmod_ctx.any_as_nmod_ctx(mod)
186173
if ctx is NotImplemented:
187174
raise TypeError("Invalid context/modulus for nmod: %s" % mod)
188-
if not any_as_nmod(&self.val, val, ctx):
175+
if not ctx.any_as_nmod(&self.val, val):
189176
raise TypeError("cannot create nmod from object of type %s" % type(val))
190177
self.ctx = ctx
191178

@@ -240,7 +227,7 @@ cdef class nmod(flint_scalar):
240227
cdef nmod r, s2
241228
cdef mp_limb_t val
242229
s2 = s
243-
if any_as_nmod(&val, t, s2.ctx):
230+
if s2.ctx.any_as_nmod(&val, t):
244231
r = nmod.__new__(nmod)
245232
r.ctx = s2.ctx
246233
r.val = nmod_add(val, s2.val, s2.ctx.mod)
@@ -251,7 +238,7 @@ cdef class nmod(flint_scalar):
251238
cdef nmod r, s2
252239
cdef mp_limb_t val
253240
s2 = s
254-
if any_as_nmod(&val, t, s2.ctx):
241+
if s2.ctx.any_as_nmod(&val, t):
255242
r = nmod.__new__(nmod)
256243
r.ctx = s2.ctx
257244
r.val = nmod_add(s2.val, val, s2.ctx.mod)
@@ -262,7 +249,7 @@ cdef class nmod(flint_scalar):
262249
cdef nmod r, s2
263250
cdef mp_limb_t val
264251
s2 = s
265-
if any_as_nmod(&val, t, s2.ctx):
252+
if s2.ctx.any_as_nmod(&val, t):
266253
r = nmod.__new__(nmod)
267254
r.ctx = s2.ctx
268255
r.val = nmod_sub(s2.val, val, s2.ctx.mod)
@@ -273,7 +260,7 @@ cdef class nmod(flint_scalar):
273260
cdef nmod r
274261
cdef mp_limb_t val
275262
s2 = s
276-
if any_as_nmod(&val, t, s2.ctx):
263+
if s2.ctx.any_as_nmod(&val, t):
277264
r = nmod.__new__(nmod)
278265
r.ctx = s2.ctx
279266
r.val = nmod_sub(val, s2.val, s2.ctx.mod)
@@ -284,7 +271,7 @@ cdef class nmod(flint_scalar):
284271
cdef nmod r, s2
285272
cdef mp_limb_t val
286273
s2 = s
287-
if any_as_nmod(&val, t, s2.ctx):
274+
if s2.ctx.any_as_nmod(&val, t):
288275
r = nmod.__new__(nmod)
289276
r.ctx = s2.ctx
290277
r.val = nmod_mul(val, s2.val, s2.ctx.mod)
@@ -295,7 +282,7 @@ cdef class nmod(flint_scalar):
295282
cdef nmod r, s2
296283
cdef mp_limb_t val
297284
s2 = s
298-
if any_as_nmod(&val, t, s2.ctx):
285+
if s2.ctx.any_as_nmod(&val, t):
299286
r = nmod.__new__(nmod)
300287
r.ctx = s2.ctx
301288
r.val = nmod_mul(s2.val, val, s2.ctx.mod)
@@ -313,13 +300,13 @@ cdef class nmod(flint_scalar):
313300
s2 = s
314301
ctx = s2.ctx
315302
sval = s2.val
316-
if not any_as_nmod(&tval, t, ctx):
303+
if not ctx.any_as_nmod(&tval, t):
317304
return NotImplemented
318305
else:
319306
t2 = t
320307
ctx = t2.ctx
321308
tval = t2.val
322-
if not any_as_nmod(&sval, s, ctx):
309+
if not ctx.any_as_nmod(&sval, s):
323310
return NotImplemented
324311

325312
if tval == 0:

0 commit comments

Comments
 (0)