@@ -5,47 +5,83 @@ from flint.types.fmpz cimport fmpz, any_as_fmpz
55from flint.types.fmpz_poly cimport any_as_fmpz_poly
66from flint.types.fmpz_poly cimport fmpz_poly
77from flint.types.nmod cimport any_as_nmod_ctx
8- from flint.types.nmod cimport nmod
8+ from flint.types.nmod cimport nmod, nmod_ctx
99
1010from flint.flintlib.nmod_vec cimport *
1111from flint.flintlib.nmod_poly cimport *
1212from flint.flintlib.nmod_poly_factor cimport *
1313from flint.flintlib.fmpz_poly cimport fmpz_poly_get_nmod_poly
14+ from flint.flintlib.ulong_extras cimport n_gcdinv, n_is_prime
1415
1516from flint.utils.flint_exceptions import DomainError
1617
1718
18- cdef any_as_nmod_poly(obj, nmod_t mod):
19- cdef nmod_poly r
20- cdef mp_limb_t v
21- # XXX: should check that modulus is the same here, and not all over the place
22- if typecheck(obj, nmod_poly):
19+ _nmod_poly_ctx_cache = {}
20+
21+
22+ cdef nmod_ctx any_as_nmod_poly_ctx(obj):
23+ """ Convert an int to an nmod_ctx."""
24+ if typecheck(obj, nmod_poly_ctx):
2325 return obj
24- if any_as_nmod(& v, obj, mod):
25- r = nmod_poly.__new__ (nmod_poly)
26- nmod_poly_init(r.val, mod.n)
27- nmod_poly_set_coeff_ui(r.val, 0 , v)
28- return r
29- x = any_as_fmpz_poly(obj)
30- if x is not NotImplemented :
31- r = nmod_poly.__new__ (nmod_poly)
32- nmod_poly_init(r.val, mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
33- fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
34- return r
26+ if typecheck(obj, int ):
27+ ctx = _nmod_poly_ctx_cache.get(obj)
28+ if ctx is None :
29+ ctx = nmod_poly_ctx(obj)
30+ _nmod_poly_ctx_cache[obj] = ctx
31+ return ctx
3532 return NotImplemented
3633
37- cdef nmod_poly_set_list(nmod_poly_t poly, list val):
38- cdef long i, n
39- cdef nmod_t mod
40- cdef mp_limb_t v
41- nmod_init(& mod, nmod_poly_modulus(poly)) # XXX
42- n = PyList_GET_SIZE(val)
43- nmod_poly_fit_length(poly, n)
44- for i from 0 <= i < n:
45- if any_as_nmod(& v, val[i], mod):
46- nmod_poly_set_coeff_ui(poly, i, v)
47- else :
48- raise TypeError (" unsupported coefficient in list" )
34+
35+ cdef class nmod_poly_ctx:
36+ """
37+ Context object for creating :class:`~.nmod_poly` initalised
38+ with modulus :math:`N`.
39+
40+ >>> nmod_ctx(17)
41+ nmod_ctx(17)
42+
43+ """
44+ def __init__ (self , mod ):
45+ cdef mp_limb_t m
46+ m = mod
47+ nmod_init(& self .mod, m)
48+ self .ctx = nmod_ctx(mod)
49+ self ._is_prime = n_is_prime(m)
50+
51+ cdef int any_as_nmod(self , mp_limb_t * val, obj) except - 1 :
52+ return self .ctx.any_as_nmod(val, obj)
53+
54+ cdef any_as_nmod_poly(self , obj):
55+ cdef nmod_poly r
56+ cdef mp_limb_t v
57+ # XXX: should check that modulus is the same here, and not all over the place
58+ if typecheck(obj, nmod_poly):
59+ return obj
60+ if self .ctx.any_as_nmod(& v, obj):
61+ r = nmod_poly.__new__ (nmod_poly)
62+ nmod_poly_init(r.val, self .mod.n)
63+ nmod_poly_set_coeff_ui(r.val, 0 , v)
64+ return r
65+ x = any_as_fmpz_poly(obj)
66+ if x is not NotImplemented :
67+ r = nmod_poly.__new__ (nmod_poly)
68+ nmod_poly_init(r.val, self .mod.n) # XXX: create flint _nmod_poly_set_modulus for this?
69+ fmpz_poly_get_nmod_poly(r.val, (< fmpz_poly> x).val)
70+ return r
71+ return NotImplemented
72+
73+ cdef nmod_poly_set_list(self , nmod_poly_t poly, list val):
74+ cdef long i, n
75+ cdef mp_limb_t v
76+ n = PyList_GET_SIZE(val)
77+ nmod_poly_fit_length(poly, n)
78+ for i from 0 <= i < n:
79+ c = val[i]
80+ if self .any_as_nmod(& v, val[i]):
81+ nmod_poly_set_coeff_ui(poly, i, v)
82+ else :
83+ raise TypeError (" unsupported coefficient in list" )
84+
4985
5086cdef class nmod_poly(flint_poly):
5187 """
@@ -77,24 +113,32 @@ cdef class nmod_poly(flint_poly):
77113 def __dealloc__ (self ):
78114 nmod_poly_clear(self .val)
79115
80- def __init__ (self , val = None , ulong mod = 0 ):
116+ def __init__ (self , val = None , mod = 0 ):
81117 cdef ulong m2
82118 cdef mp_limb_t v
119+ cdef nmod_poly_ctx ctx
120+
83121 if typecheck(val, nmod_poly):
84122 m2 = nmod_poly_modulus((< nmod_poly> val).val)
85123 if m2 != mod:
86124 raise ValueError (" different moduli!" )
87125 nmod_poly_init(self .val, m2)
88126 nmod_poly_set(self .val, (< nmod_poly> val).val)
127+ self .ctx = (< nmod_poly> val).ctx
89128 else :
90129 if mod == 0 :
91130 raise ValueError (" a nonzero modulus is required" )
92- nmod_poly_init(self .val, mod)
131+ ctx = any_as_nmod_poly_ctx(mod)
132+ if ctx is NotImplemented :
133+ raise TypeError (" cannot create nmod_poly_ctx from input of type %s " , type (mod))
134+
135+ self .ctx = ctx
136+ nmod_poly_init(self .val, ctx.mod.n)
93137 if typecheck(val, fmpz_poly):
94138 fmpz_poly_get_nmod_poly(self .val, (< fmpz_poly> val).val)
95139 elif typecheck(val, list ):
96- nmod_poly_set_list(self .val, val)
97- elif any_as_nmod(& v, val, self .val.mod ):
140+ ctx. nmod_poly_set_list(self .val, val)
141+ elif ctx. any_as_nmod(& v, val):
98142 nmod_poly_fit_length(self .val, 1 )
99143 nmod_poly_set_coeff_ui(self .val, 0 , v)
100144 else :
@@ -175,7 +219,7 @@ cdef class nmod_poly(flint_poly):
175219 cdef mp_limb_t v
176220 if i < 0 :
177221 raise ValueError (" cannot assign to index < 0 of polynomial" )
178- if any_as_nmod(& v, x, self .val.mod ):
222+ if self .ctx. any_as_nmod(& v, x):
179223 nmod_poly_set_coeff_ui(self .val, i, v)
180224 else :
181225 raise TypeError (" cannot set element of type %s " % type (x))
@@ -288,7 +332,7 @@ cdef class nmod_poly(flint_poly):
288332 9*x^4 + 12*x^3 + 10*x^2 + 4*x + 1
289333 """
290334 cdef nmod_poly res
291- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
335+ other = self .ctx.any_as_nmod_poly(other )
292336 if other is NotImplemented :
293337 raise TypeError (" cannot convert input to nmod_poly" )
294338 res = nmod_poly.__new__ (nmod_poly)
@@ -313,11 +357,11 @@ cdef class nmod_poly(flint_poly):
313357 147* x^ 3 + 159* x^ 2 + 4* x + 7
314358 """
315359 cdef nmod_poly res
316- g = any_as_nmod_poly(other, self .val.mod )
360+ g = self .ctx.any_as_nmod_poly(other )
317361 if g is NotImplemented :
318362 raise TypeError (f" cannot convert other = {other} to nmod_poly" )
319363
320- h = any_as_nmod_poly(modulus, self .val.mod)
364+ h = self . any_as_nmod_poly(modulus, self .val.mod)
321365 if h is NotImplemented :
322366 raise TypeError (f" cannot convert modulus = {modulus} to nmod_poly" )
323367
@@ -331,11 +375,11 @@ cdef class nmod_poly(flint_poly):
331375
332376 def __call__ (self , other ):
333377 cdef mp_limb_t c
334- if any_as_nmod(& c, other, self .val.mod ):
378+ if self .ctx. any_as_nmod(& c, other):
335379 v = nmod(0 , self .modulus())
336380 (< nmod> v).val = nmod_poly_evaluate_nmod(self .val, c)
337381 return v
338- t = any_as_nmod_poly(other, self .val.mod )
382+ t = self .ctx.any_as_nmod_poly(other )
339383 if t is not NotImplemented :
340384 r = nmod_poly.__new__ (nmod_poly)
341385 nmod_poly_init_preinv((< nmod_poly> r).val, self .val.mod.n, self .val.mod.ninv)
@@ -366,7 +410,7 @@ cdef class nmod_poly(flint_poly):
366410
367411 def _add_ (s , t ):
368412 cdef nmod_poly r
369- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
413+ t = s.ctx.any_as_nmod_poly(t )
370414 if t is NotImplemented :
371415 return t
372416 if (< nmod_poly> s).val.mod.n != (< nmod_poly> t).val.mod.n:
@@ -392,20 +436,20 @@ cdef class nmod_poly(flint_poly):
392436 return r
393437
394438 def __sub__ (s , t ):
395- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
439+ t = s.ctx.any_as_nmod_poly(t )
396440 if t is NotImplemented :
397441 return t
398442 return s._sub_(t)
399443
400444 def __rsub__ (s , t ):
401- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
445+ t = s. any_as_nmod_poly(t)
402446 if t is NotImplemented :
403447 return t
404448 return t._sub_(s)
405449
406450 def _mul_ (s , t ):
407451 cdef nmod_poly r
408- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
452+ t = s. any_as_nmod_poly(t)
409453 if t is NotImplemented :
410454 return t
411455 if (< nmod_poly> s).val.mod.n != (< nmod_poly> t).val.mod.n:
@@ -422,7 +466,7 @@ cdef class nmod_poly(flint_poly):
422466 return s._mul_(t)
423467
424468 def __truediv__ (s , t ):
425- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
469+ t = s. any_as_nmod_poly(t)
426470 if t is NotImplemented :
427471 return t
428472 res, r = s._divmod_(t)
@@ -431,7 +475,7 @@ cdef class nmod_poly(flint_poly):
431475 return res
432476
433477 def __rtruediv__ (s , t ):
434- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
478+ t = s. any_as_nmod_poly(t)
435479 if t is NotImplemented :
436480 return t
437481 res, r = t._divmod_(s)
@@ -451,13 +495,13 @@ cdef class nmod_poly(flint_poly):
451495 return r
452496
453497 def __floordiv__ (s , t ):
454- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
498+ t = s. any_as_nmod_poly(t)
455499 if t is NotImplemented :
456500 return t
457501 return s._floordiv_(t)
458502
459503 def __rfloordiv__ (s , t ):
460- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
504+ t = s. any_as_nmod_poly(t)
461505 if t is NotImplemented :
462506 return t
463507 return t._floordiv_(s)
@@ -476,13 +520,13 @@ cdef class nmod_poly(flint_poly):
476520 return P, Q
477521
478522 def __divmod__ (s , t ):
479- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
523+ t = s. any_as_nmod_poly(t)
480524 if t is NotImplemented :
481525 return t
482526 return s._divmod_(t)
483527
484528 def __rdivmod__ (s , t ):
485- t = any_as_nmod_poly(t, ( < nmod_poly > s).val.mod )
529+ t = s. any_as_nmod_poly(t)
486530 if t is NotImplemented :
487531 return t
488532 return t._divmod_(s)
@@ -531,7 +575,7 @@ cdef class nmod_poly(flint_poly):
531575 if e < 0 :
532576 raise ValueError (" Exponent must be non-negative" )
533577
534- modulus = any_as_nmod_poly(modulus, ( < nmod_poly > self ).val.mod )
578+ modulus = self .ctx.any_as_nmod_poly(modulus )
535579 if modulus is NotImplemented :
536580 raise TypeError (" cannot convert input to nmod_poly" )
537581
@@ -553,7 +597,7 @@ cdef class nmod_poly(flint_poly):
553597
554598 # To optimise powering, we precompute the inverse of the reverse of the modulus
555599 if mod_rev_inv is not None :
556- mod_rev_inv = any_as_nmod_poly(mod_rev_inv, ( < nmod_poly > self ).val.mod )
600+ mod_rev_inv = self . any_as_nmod_poly(mod_rev_inv)
557601 if mod_rev_inv is NotImplemented :
558602 raise TypeError (f" Cannot interpret {mod_rev_inv} as a polynomial" )
559603 else :
@@ -582,7 +626,7 @@ cdef class nmod_poly(flint_poly):
582626
583627 """
584628 cdef nmod_poly res
585- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
629+ other = self . any_as_nmod_poly(other)
586630 if other is NotImplemented :
587631 raise TypeError (" cannot convert input to nmod_poly" )
588632 if self .val.mod.n != (< nmod_poly> other).val.mod.n:
@@ -594,7 +638,7 @@ cdef class nmod_poly(flint_poly):
594638
595639 def xgcd (self , other ):
596640 cdef nmod_poly res1, res2, res3
597- other = any_as_nmod_poly(other, ( < nmod_poly > self ).val.mod )
641+ other = self . any_as_nmod_poly(other)
598642 if other is NotImplemented :
599643 raise TypeError (" cannot convert input to fmpq_poly" )
600644 res1 = nmod_poly.__new__ (nmod_poly)
0 commit comments