Skip to content

Commit 154b035

Browse files
committed
perf: use @cython.final for nmod_ctx.any_as_nmod
1 parent f47188c commit 154b035

2 files changed

Lines changed: 35 additions & 51 deletions

File tree

src/flint/types/nmod.pxd

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ from flint.flint_base.flint_base cimport flint_scalar
22
from flint.flintlib.flint cimport mp_limb_t, ulong
33
from flint.flintlib.nmod cimport nmod_t
44

5-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx mod) except -1
6-
#cdef nmod_ctx any_as_nmod_ctx(obj)
7-
85

96
cdef class nmod_ctx:
107
cdef nmod_t mod

src/flint/types/nmod.pyx

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,35 @@ cdef class nmod_ctx:
6565
ctx._is_prime = n_is_prime(mod)
6666
return ctx
6767

68+
@cython.final
69+
cdef int any_as_nmod(nmod_ctx ctx, mp_limb_t * val, obj) except -1:
70+
"""Convert an object to an nmod element."""
71+
cdef int success
72+
cdef fmpz_t t
73+
if typecheck(obj, nmod):
74+
if (<nmod>obj).ctx.mod.n != ctx.mod.n:
75+
raise ValueError("cannot coerce integers mod n with different n")
76+
val[0] = (<nmod>obj).val
77+
return 1
78+
z = any_as_fmpz(obj)
79+
if z is not NotImplemented:
80+
val[0] = fmpz_fdiv_ui((<fmpz>z).val, ctx.mod.n)
81+
return 1
82+
q = any_as_fmpq(obj)
83+
if q is not NotImplemented:
84+
fmpz_init(t)
85+
fmpz_set_ui(t, ctx.mod.n)
86+
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
87+
val[0] = fmpz_get_ui(t)
88+
fmpz_clear(t)
89+
if not success:
90+
raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n))
91+
return 1
92+
return 0
93+
6894
def __repr__(self):
6995
return f"nmod_ctx({self.modulus()})"
7096

71-
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
72-
"""Convert an object to an nmod element."""
73-
return any_as_nmod(val, obj, self)
74-
7597
def modulus(self):
7698
"""Get the modulus of the context.
7799
@@ -112,15 +134,6 @@ cdef class nmod_ctx:
112134
"""
113135
return self(1)
114136

115-
def __hash__(self):
116-
return hash(self.mod)
117-
118-
def __eq__(self, other):
119-
if typecheck(other, nmod_ctx):
120-
return self.mod.n == (<nmod_ctx>other).mod.n
121-
else:
122-
return NotImplemented
123-
124137
def __str__(self):
125138
return f"Context for nmod with modulus: {self.modulus()}"
126139

@@ -146,32 +159,6 @@ cdef class nmod_ctx:
146159
return self._new(&v)
147160

148161

149-
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_ctx ctx) except -1:
150-
"""Convert an object to an nmod element."""
151-
cdef int success
152-
cdef fmpz_t t
153-
if typecheck(obj, nmod):
154-
if (<nmod>obj).ctx.mod.n != ctx.mod.n:
155-
raise ValueError("cannot coerce integers mod n with different n")
156-
val[0] = (<nmod>obj).val
157-
return 1
158-
z = any_as_fmpz(obj)
159-
if z is not NotImplemented:
160-
val[0] = fmpz_fdiv_ui((<fmpz>z).val, ctx.mod.n)
161-
return 1
162-
q = any_as_fmpq(obj)
163-
if q is not NotImplemented:
164-
fmpz_init(t)
165-
fmpz_set_ui(t, ctx.mod.n)
166-
success = fmpq_mod_fmpz(t, (<fmpq>q).val, t)
167-
val[0] = fmpz_get_ui(t)
168-
fmpz_clear(t)
169-
if not success:
170-
raise ZeroDivisionError("%s does not exist mod %i!" % (q, ctx.mod.n))
171-
return 1
172-
return 0
173-
174-
175162
@cython.no_gc
176163
cdef class nmod(flint_scalar):
177164
"""
@@ -185,7 +172,7 @@ cdef class nmod(flint_scalar):
185172
ctx = nmod_ctx.any_as_nmod_ctx(mod)
186173
if ctx is NotImplemented:
187174
raise TypeError("Invalid context/modulus for nmod: %s" % mod)
188-
if not any_as_nmod(&self.val, val, ctx):
175+
if not ctx.any_as_nmod(&self.val, val):
189176
raise TypeError("cannot create nmod from object of type %s" % type(val))
190177
self.ctx = ctx
191178

@@ -239,7 +226,7 @@ cdef class nmod(flint_scalar):
239226
cdef nmod r, s2
240227
cdef mp_limb_t val
241228
s2 = s
242-
if any_as_nmod(&val, t, s2.ctx):
229+
if s2.ctx.any_as_nmod(&val, t):
243230
r = nmod.__new__(nmod)
244231
r.ctx = s2.ctx
245232
r.val = nmod_add(val, s2.val, s2.ctx.mod)
@@ -250,7 +237,7 @@ cdef class nmod(flint_scalar):
250237
cdef nmod r, s2
251238
cdef mp_limb_t val
252239
s2 = s
253-
if any_as_nmod(&val, t, s2.ctx):
240+
if s2.ctx.any_as_nmod(&val, t):
254241
r = nmod.__new__(nmod)
255242
r.ctx = s2.ctx
256243
r.val = nmod_add(s2.val, val, s2.ctx.mod)
@@ -261,7 +248,7 @@ cdef class nmod(flint_scalar):
261248
cdef nmod r, s2
262249
cdef mp_limb_t val
263250
s2 = s
264-
if any_as_nmod(&val, t, s2.ctx):
251+
if s2.ctx.any_as_nmod(&val, t):
265252
r = nmod.__new__(nmod)
266253
r.ctx = s2.ctx
267254
r.val = nmod_sub(s2.val, val, s2.ctx.mod)
@@ -272,7 +259,7 @@ cdef class nmod(flint_scalar):
272259
cdef nmod r
273260
cdef mp_limb_t val
274261
s2 = s
275-
if any_as_nmod(&val, t, s2.ctx):
262+
if s2.ctx.any_as_nmod(&val, t):
276263
r = nmod.__new__(nmod)
277264
r.ctx = s2.ctx
278265
r.val = nmod_sub(val, s2.val, s2.ctx.mod)
@@ -283,7 +270,7 @@ cdef class nmod(flint_scalar):
283270
cdef nmod r, s2
284271
cdef mp_limb_t val
285272
s2 = s
286-
if any_as_nmod(&val, t, s2.ctx):
273+
if s2.ctx.any_as_nmod(&val, t):
287274
r = nmod.__new__(nmod)
288275
r.ctx = s2.ctx
289276
r.val = nmod_mul(val, s2.val, s2.ctx.mod)
@@ -294,7 +281,7 @@ cdef class nmod(flint_scalar):
294281
cdef nmod r, s2
295282
cdef mp_limb_t val
296283
s2 = s
297-
if any_as_nmod(&val, t, s2.ctx):
284+
if s2.ctx.any_as_nmod(&val, t):
298285
r = nmod.__new__(nmod)
299286
r.ctx = s2.ctx
300287
r.val = nmod_mul(s2.val, val, s2.ctx.mod)
@@ -312,13 +299,13 @@ cdef class nmod(flint_scalar):
312299
s2 = s
313300
ctx = s2.ctx
314301
sval = s2.val
315-
if not any_as_nmod(&tval, t, ctx):
302+
if not ctx.any_as_nmod(&tval, t):
316303
return NotImplemented
317304
else:
318305
t2 = t
319306
ctx = t2.ctx
320307
tval = t2.val
321-
if not any_as_nmod(&sval, s, ctx):
308+
if not ctx.any_as_nmod(&sval, s):
322309
return NotImplemented
323310

324311
if tval == 0:

0 commit comments

Comments
 (0)