Skip to content

Commit 27e5e4c

Browse files
committed
Add nmod_poly_ctx
1 parent 293ec24 commit 27e5e4c

2 files changed

Lines changed: 115 additions & 55 deletions

File tree

src/flint/types/nmod_poly.pxd

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1-
from flint.flint_base.flint_base cimport flint_poly
2-
1+
from flint.flintlib.nmod cimport nmod_t
32
from flint.flintlib.nmod_poly cimport nmod_poly_t
43
from flint.flintlib.flint cimport mp_limb_t
54

5+
from flint.flint_base.flint_base cimport flint_poly
6+
7+
from flint.types.nmod cimport nmod_ctx
8+
9+
10+
cdef class nmod_poly_ctx:
11+
cdef nmod_ctx ctx
12+
cdef nmod_t mod
13+
cdef bint _is_prime
14+
15+
cdef nmod_poly_set_list(self, nmod_poly_t poly, list val)
16+
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1
17+
cdef any_as_nmod_poly(self, obj)
18+
19+
620
cdef class nmod_poly(flint_poly):
721
cdef nmod_poly_t val
22+
cdef nmod_poly_ctx ctx
23+
824
cpdef long length(self)
925
cpdef long degree(self)
1026
cpdef mp_limb_t modulus(self)

src/flint/types/nmod_poly.pyx

Lines changed: 97 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
55
from flint.types.fmpz_poly cimport any_as_fmpz_poly
66
from flint.types.fmpz_poly cimport fmpz_poly
77
from flint.types.nmod cimport any_as_nmod_ctx
8-
from flint.types.nmod cimport nmod
8+
from flint.types.nmod cimport nmod, nmod_ctx
99

1010
from flint.flintlib.nmod_vec cimport *
1111
from flint.flintlib.nmod_poly cimport *
1212
from flint.flintlib.nmod_poly_factor cimport *
1313
from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
14+
from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime
1415

1516
from flint.utils.flint_exceptions import DomainError
1617

1718

18-
cdef any_as_nmod_poly(obj, nmod_t mod):
19-
cdef nmod_poly r
20-
cdef mp_limb_t v
21-
# XXX: should check that modulus is the same here, and not all over the place
22-
if typecheck(obj, nmod_poly):
19+
_nmod_poly_ctx_cache = {}
20+
21+
22+
cdef nmod_ctx any_as_nmod_poly_ctx(obj):
23+
"""Convert an int to an nmod_ctx."""
24+
if typecheck(obj, nmod_poly_ctx):
2325
return obj
24-
if any_as_nmod(&v, obj, mod):
25-
r = nmod_poly.__new__(nmod_poly)
26-
nmod_poly_init(r.val, mod.n)
27-
nmod_poly_set_coeff_ui(r.val, 0, v)
28-
return r
29-
x = any_as_fmpz_poly(obj)
30-
if x is not NotImplemented:
31-
r = nmod_poly.__new__(nmod_poly)
32-
nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
33-
fmpz_poly_get_nmod_poly(r.val, (<fmpz_poly>x).val)
34-
return r
26+
if typecheck(obj, int):
27+
ctx = _nmod_poly_ctx_cache.get(obj)
28+
if ctx is None:
29+
ctx = nmod_poly_ctx(obj)
30+
_nmod_poly_ctx_cache[obj] = ctx
31+
return ctx
3532
return NotImplemented
3633

37-
cdef nmod_poly_set_list(nmod_poly_t poly, list val):
38-
cdef long i, n
39-
cdef nmod_t mod
40-
cdef mp_limb_t v
41-
nmod_init(&mod, nmod_poly_modulus(poly)) # XXX
42-
n = PyList_GET_SIZE(val)
43-
nmod_poly_fit_length(poly, n)
44-
for i from 0 <= i < n:
45-
if any_as_nmod(&v, val[i], mod):
46-
nmod_poly_set_coeff_ui(poly, i, v)
47-
else:
48-
raise TypeError("unsupported coefficient in list")
34+
35+
cdef class nmod_poly_ctx:
36+
"""
37+
Context object for creating :class:`~.nmod_poly` initalised
38+
with modulus :math:`N`.
39+
40+
>>> nmod_ctx(17)
41+
nmod_ctx(17)
42+
43+
"""
44+
def __init__(self, mod):
45+
cdef mp_limb_t m
46+
m = mod
47+
nmod_init(&self.mod, m)
48+
self.ctx = nmod_ctx(mod)
49+
self._is_prime = n_is_prime(m)
50+
51+
cdef int any_as_nmod(self, mp_limb_t * val, obj) except -1:
52+
return self.ctx.any_as_nmod(val, obj)
53+
54+
cdef any_as_nmod_poly(self, obj):
55+
cdef nmod_poly r
56+
cdef mp_limb_t v
57+
# XXX: should check that modulus is the same here, and not all over the place
58+
if typecheck(obj, nmod_poly):
59+
return obj
60+
if self.ctx.any_as_nmod(&v, obj):
61+
r = nmod_poly.__new__(nmod_poly)
62+
nmod_poly_init(r.val, self.mod.n)
63+
nmod_poly_set_coeff_ui(r.val, 0, v)
64+
return r
65+
x = any_as_fmpz_poly(obj)
66+
if x is not NotImplemented:
67+
r = nmod_poly.__new__(nmod_poly)
68+
nmod_poly_init(r.val, self.mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
69+
fmpz_poly_get_nmod_poly(r.val, (<fmpz_poly>x).val)
70+
return r
71+
return NotImplemented
72+
73+
cdef nmod_poly_set_list(self, nmod_poly_t poly, list val):
74+
cdef long i, n
75+
cdef mp_limb_t v
76+
n = PyList_GET_SIZE(val)
77+
nmod_poly_fit_length(poly, n)
78+
for i from 0 <= i < n:
79+
c = val[i]
80+
if self.any_as_nmod(&v, val[i]):
81+
nmod_poly_set_coeff_ui(poly, i, v)
82+
else:
83+
raise TypeError("unsupported coefficient in list")
84+
4985

5086
cdef class nmod_poly(flint_poly):
5187
"""
@@ -77,24 +113,32 @@ cdef class nmod_poly(flint_poly):
77113
def __dealloc__(self):
78114
nmod_poly_clear(self.val)
79115

80-
def __init__(self, val=None, ulong mod=0):
116+
def __init__(self, val=None, mod=0):
81117
cdef ulong m2
82118
cdef mp_limb_t v
119+
cdef nmod_poly_ctx ctx
120+
83121
if typecheck(val, nmod_poly):
84122
m2 = nmod_poly_modulus((<nmod_poly>val).val)
85123
if m2 != mod:
86124
raise ValueError("different moduli!")
87125
nmod_poly_init(self.val, m2)
88126
nmod_poly_set(self.val, (<nmod_poly>val).val)
127+
self.ctx = (<nmod_poly>val).ctx
89128
else:
90129
if mod == 0:
91130
raise ValueError("a nonzero modulus is required")
92-
nmod_poly_init(self.val, mod)
131+
ctx = any_as_nmod_poly_ctx(mod)
132+
if ctx is NotImplemented:
133+
raise TypeError("cannot create nmod_poly_ctx from input of type %s", type(mod))
134+
135+
self.ctx = ctx
136+
nmod_poly_init(self.val, ctx.mod.n)
93137
if typecheck(val, fmpz_poly):
94138
fmpz_poly_get_nmod_poly(self.val, (<fmpz_poly>val).val)
95139
elif typecheck(val, list):
96-
nmod_poly_set_list(self.val, val)
97-
elif any_as_nmod(&v, val, self.val.mod):
140+
ctx.nmod_poly_set_list(self.val, val)
141+
elif ctx.any_as_nmod(&v, val):
98142
nmod_poly_fit_length(self.val, 1)
99143
nmod_poly_set_coeff_ui(self.val, 0, v)
100144
else:
@@ -175,7 +219,7 @@ cdef class nmod_poly(flint_poly):
175219
cdef mp_limb_t v
176220
if i < 0:
177221
raise ValueError("cannot assign to index < 0 of polynomial")
178-
if any_as_nmod(&v, x, self.val.mod):
222+
if self.ctx.any_as_nmod(&v, x):
179223
nmod_poly_set_coeff_ui(self.val, i, v)
180224
else:
181225
raise TypeError("cannot set element of type %s" % type(x))
@@ -288,7 +332,7 @@ cdef class nmod_poly(flint_poly):
288332
9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
289333
"""
290334
cdef nmod_poly res
291-
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
335+
other = self.ctx.any_as_nmod_poly(other)
292336
if other is NotImplemented:
293337
raise TypeError("cannot convert input to nmod_poly")
294338
res = nmod_poly.__new__(nmod_poly)
@@ -313,11 +357,11 @@ cdef class nmod_poly(flint_poly):
313357
147*x^3 + 159*x^2 + 4*x + 7
314358
"""
315359
cdef nmod_poly res
316-
g = any_as_nmod_poly(other, self.val.mod)
360+
g = self.ctx.any_as_nmod_poly(other)
317361
if g is NotImplemented:
318362
raise TypeError(f"cannot convert other = {other} to nmod_poly")
319363

320-
h = any_as_nmod_poly(modulus, self.val.mod)
364+
h = self.any_as_nmod_poly(modulus, self.val.mod)
321365
if h is NotImplemented:
322366
raise TypeError(f"cannot convert modulus = {modulus} to nmod_poly")
323367

@@ -331,11 +375,11 @@ cdef class nmod_poly(flint_poly):
331375

332376
def __call__(self, other):
333377
cdef mp_limb_t c
334-
if any_as_nmod(&c, other, self.val.mod):
378+
if self.ctx.any_as_nmod(&c, other):
335379
v = nmod(0, self.modulus())
336380
(<nmod>v).val = nmod_poly_evaluate_nmod(self.val, c)
337381
return v
338-
t = any_as_nmod_poly(other, self.val.mod)
382+
t = self.ctx.any_as_nmod_poly(other)
339383
if t is not NotImplemented:
340384
r = nmod_poly.__new__(nmod_poly)
341385
nmod_poly_init_preinv((<nmod_poly>r).val, self.val.mod.n, self.val.mod.ninv)
@@ -366,7 +410,7 @@ cdef class nmod_poly(flint_poly):
366410

367411
def _add_(s, t):
368412
cdef nmod_poly r
369-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
413+
t = s.ctx.any_as_nmod_poly(t)
370414
if t is NotImplemented:
371415
return t
372416
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
@@ -392,20 +436,20 @@ cdef class nmod_poly(flint_poly):
392436
return r
393437

394438
def __sub__(s, t):
395-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
439+
t = s.ctx.any_as_nmod_poly(t)
396440
if t is NotImplemented:
397441
return t
398442
return s._sub_(t)
399443

400444
def __rsub__(s, t):
401-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
445+
t = s.any_as_nmod_poly(t)
402446
if t is NotImplemented:
403447
return t
404448
return t._sub_(s)
405449

406450
def _mul_(s, t):
407451
cdef nmod_poly r
408-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
452+
t = s.any_as_nmod_poly(t)
409453
if t is NotImplemented:
410454
return t
411455
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
@@ -422,7 +466,7 @@ cdef class nmod_poly(flint_poly):
422466
return s._mul_(t)
423467

424468
def __truediv__(s, t):
425-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
469+
t = s.any_as_nmod_poly(t)
426470
if t is NotImplemented:
427471
return t
428472
res, r = s._divmod_(t)
@@ -431,7 +475,7 @@ cdef class nmod_poly(flint_poly):
431475
return res
432476

433477
def __rtruediv__(s, t):
434-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
478+
t = s.any_as_nmod_poly(t)
435479
if t is NotImplemented:
436480
return t
437481
res, r = t._divmod_(s)
@@ -451,13 +495,13 @@ cdef class nmod_poly(flint_poly):
451495
return r
452496

453497
def __floordiv__(s, t):
454-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
498+
t = s.any_as_nmod_poly(t)
455499
if t is NotImplemented:
456500
return t
457501
return s._floordiv_(t)
458502

459503
def __rfloordiv__(s, t):
460-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
504+
t = s.any_as_nmod_poly(t)
461505
if t is NotImplemented:
462506
return t
463507
return t._floordiv_(s)
@@ -476,13 +520,13 @@ cdef class nmod_poly(flint_poly):
476520
return P, Q
477521

478522
def __divmod__(s, t):
479-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
523+
t = s.any_as_nmod_poly(t)
480524
if t is NotImplemented:
481525
return t
482526
return s._divmod_(t)
483527

484528
def __rdivmod__(s, t):
485-
t = any_as_nmod_poly(t, (<nmod_poly>s).val.mod)
529+
t = s.any_as_nmod_poly(t)
486530
if t is NotImplemented:
487531
return t
488532
return t._divmod_(s)
@@ -531,7 +575,7 @@ cdef class nmod_poly(flint_poly):
531575
if e < 0:
532576
raise ValueError("Exponent must be non-negative")
533577

534-
modulus = any_as_nmod_poly(modulus, (<nmod_poly>self).val.mod)
578+
modulus = self.ctx.any_as_nmod_poly(modulus)
535579
if modulus is NotImplemented:
536580
raise TypeError("cannot convert input to nmod_poly")
537581

@@ -553,7 +597,7 @@ cdef class nmod_poly(flint_poly):
553597

554598
# To optimise powering, we precompute the inverse of the reverse of the modulus
555599
if mod_rev_inv is not None:
556-
mod_rev_inv = any_as_nmod_poly(mod_rev_inv, (<nmod_poly>self).val.mod)
600+
mod_rev_inv = self.any_as_nmod_poly(mod_rev_inv)
557601
if mod_rev_inv is NotImplemented:
558602
raise TypeError(f"Cannot interpret {mod_rev_inv} as a polynomial")
559603
else:
@@ -582,7 +626,7 @@ cdef class nmod_poly(flint_poly):
582626
583627
"""
584628
cdef nmod_poly res
585-
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
629+
other = self.any_as_nmod_poly(other)
586630
if other is NotImplemented:
587631
raise TypeError("cannot convert input to nmod_poly")
588632
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
@@ -594,7 +638,7 @@ cdef class nmod_poly(flint_poly):
594638

595639
def xgcd(self, other):
596640
cdef nmod_poly res1, res2, res3
597-
other = any_as_nmod_poly(other, (<nmod_poly>self).val.mod)
641+
other = self.any_as_nmod_poly(other)
598642
if other is NotImplemented:
599643
raise TypeError("cannot convert input to fmpq_poly")
600644
res1 = nmod_poly.__new__(nmod_poly)

0 commit comments

Comments
 (0)