Skip to content

Commit 6f04cdd

Browse files
committed
fix: check for non-prime modulus in nmod_poly
1 parent affe462 commit 6f04cdd

2 files changed

Lines changed: 111 additions & 59 deletions

File tree

src/flint/test/test_all.py

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,7 +2104,7 @@ def test_fmpz_mod_poly():
21042104
assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g)
21052105
assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g)
21062106

2107-
# Check other typechecks for pow_mod
2107+
# Check other typechecks for pow_mod
21082108
assert raises(lambda: pow(f, -2, g), ValueError)
21092109
assert raises(lambda: pow(f, 1, "A"), TypeError)
21102110
assert raises(lambda: pow(f, "A", g), TypeError)
@@ -2542,10 +2542,6 @@ def test_polys():
25422542
for P, S, is_field, characteristic in _all_polys():
25432543

25442544
composite_characteristic = characteristic != 0 and not characteristic.is_prime()
2545-
# nmod_poly crashes for many operations with non-prime modulus
2546-
# https://github.com/flintlib/python-flint/issues/124
2547-
# so we can't even test it...
2548-
nmod_poly_will_crash = type(P(1)) is flint.nmod_poly and composite_characteristic
25492545

25502546
assert P([S(1)]) == P([1]) == P(P([1])) == P(1)
25512547

@@ -2690,30 +2686,58 @@ def setbad(obj, i, val):
26902686
assert raises(lambda: P([1, 2, 3]) * None, TypeError)
26912687
assert raises(lambda: None * P([1, 2, 3]), TypeError)
26922688

2693-
assert P([1, 2, 1]) // P([1, 1]) == P([1, 1])
2694-
assert P([1, 2, 1]) % P([1, 1]) == P([0])
2695-
assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0]))
2689+
if composite_characteristic and type(P(1)) is flint.nmod_poly:
2690+
# Z/nZ for n not prime
2691+
#
2692+
# fmpz_mod_poly and nmod_poly can sometimes compute division with
2693+
# composite characteristic, but it is not guaranteed to work. For
2694+
# fmpz_mod_poly, we can detect the failure and raise an exception.
2695+
# For nmod_poly, we cannot detect the failure and calling e.g.
2696+
# nmod_poly_divrem would crash the process so for nmod_poly we
2697+
# raise an exception in all cases if the modulus is not prime.
2698+
assert raises(lambda: P([1, 2, 1]) // P([1, 1]), DomainError)
2699+
assert raises(lambda: P([1, 2, 1]) % P([1, 1]), DomainError)
2700+
assert raises(lambda: divmod(P([1, 2, 1]), P([1, 1])), DomainError)
2701+
2702+
assert raises(lambda: 1 // P([1, 1]), DomainError)
2703+
assert raises(lambda: 1 % P([1, 1]), DomainError)
2704+
assert raises(lambda: divmod(1, P([1, 1])), DomainError)
2705+
else:
2706+
assert P([1, 2, 1]) // P([1, 1]) == P([1, 1])
2707+
assert P([1, 2, 1]) % P([1, 1]) == P([0])
2708+
assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0]))
2709+
2710+
assert 1 // P([1, 1]) == P([0])
2711+
assert 1 % P([1, 1]) == P([1])
2712+
assert divmod(1, P([1, 1])) == (P([0]), P([1]))
2713+
2714+
assert P([1, 2, 1]) / P([1, 1]) == P([1, 1])
2715+
2716+
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
2717+
2718+
assert raises(lambda: 1 / P([1, 1]), DomainError)
2719+
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError)
26962720

26972721
if is_field:
26982722
assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2])
26992723
assert P([1, 1]) % 2 == P([0])
2724+
assert P([2, 2]) / 2 == P([1, 1])
2725+
assert P([1, 2]) / 2 == P([S(1)/2, 1])
27002726
elif characteristic == 0:
27012727
assert P([1, 1]) // 2 == P([0, 0])
27022728
assert P([1, 1]) % 2 == P([1, 1])
2703-
elif nmod_poly_will_crash:
2704-
pass
2729+
assert P([2, 2]) / 2 == P([1, 1])
2730+
assert raises(lambda: P([1, 2]) / 2, DomainError)
27052731
else:
27062732
# Z/nZ for n not prime
27072733
if characteristic % 2 == 0:
27082734
assert raises(lambda: P([1, 1]) // 2, DomainError)
27092735
assert raises(lambda: P([1, 1]) % 2, DomainError)
2736+
assert raises(lambda: P([2, 2]) / 2, DomainError)
2737+
assert raises(lambda: P([1, 2]) / 2, DomainError)
27102738
else:
27112739
1/0
27122740

2713-
assert 1 // P([1, 1]) == P([0])
2714-
assert 1 % P([1, 1]) == P([1])
2715-
assert divmod(1, P([1, 1])) == (P([0]), P([1]))
2716-
27172741
assert raises(lambda: P([1, 2, 1]) // None, TypeError)
27182742
assert raises(lambda: P([1, 2, 1]) % None, TypeError)
27192743
assert raises(lambda: divmod(P([1, 2, 1]), None), TypeError)
@@ -2730,50 +2754,43 @@ def setbad(obj, i, val):
27302754
assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError)
27312755
assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError)
27322756

2733-
# Exact/field scalar division
2734-
if is_field:
2735-
assert P([2, 2]) / 2 == P([1, 1])
2736-
assert P([1, 2]) / 2 == P([S(1)/2, 1])
2737-
elif characteristic == 0:
2738-
assert P([2, 2]) / 2 == P([1, 1])
2739-
assert raises(lambda: P([1, 2]) / 2, DomainError)
2740-
elif nmod_poly_will_crash:
2741-
pass
2742-
else:
2743-
# Z/nZ for n not prime
2744-
assert raises(lambda: P([2, 2]) / 2, DomainError)
2745-
assert raises(lambda: P([1, 2]) / 2, DomainError)
2746-
2747-
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
2748-
2749-
if not nmod_poly_will_crash:
2750-
assert P([1, 2, 1]) / P([1, 1]) == P([1, 1])
2751-
assert raises(lambda: 1 / P([1, 1]), DomainError)
2752-
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError)
2753-
27542757
assert P([1, 1]) ** 0 == P([1])
27552758
assert P([1, 1]) ** 1 == P([1, 1])
27562759
assert P([1, 1]) ** 2 == P([1, 2, 1])
27572760
assert raises(lambda: P([1, 1]) ** -1, ValueError)
27582761
assert raises(lambda: P([1, 1]) ** None, TypeError)
2759-
2760-
# XXX: Not sure what this should do in general:
2762+
2763+
# 3-arg pow: (x^2 + 1)**3 mod x-1
2764+
2765+
pow3_types = [
2766+
# flint.fmpq_poly, XXX
2767+
flint.nmod_poly,
2768+
flint.fmpz_mod_poly,
2769+
flint.fq_default_poly
2770+
]
2771+
27612772
p = P([1, 1])
27622773
mod = P([1, 1])
2763-
if type(p) not in [flint.fmpz_mod_poly, flint.nmod_poly, flint.fq_default_poly]:
2774+
2775+
if type(p) not in pow3_types:
27642776
assert raises(lambda: pow(p, 2, mod), NotImplementedError)
2777+
assert p * p % mod == 0
2778+
elif composite_characteristic and type(p) == flint.nmod_poly:
2779+
# nmod_poly does not support % with composite characteristic
2780+
assert pow(p, 2, mod) == 0
2781+
assert raises(lambda: p * p % mod, DomainError)
27652782
else:
2783+
# Should be for any is_field including fmpq_poly. Works also in
2784+
# some cases for fmpz_mod_poly with non-prime modulus.
27662785
assert p * p % mod == pow(p, 2, mod)
27672786

27682787
if not composite_characteristic:
27692788
assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
2770-
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
2771-
elif nmod_poly_will_crash:
2772-
pass
27732789
else:
27742790
# Z/nZ for n not prime
27752791
assert raises(lambda: P([1, 2, 1]).gcd(P([1, 1])), DomainError)
2776-
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
2792+
2793+
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
27772794

27782795
if is_field:
27792796
p1 = P([1, 0, 1])
@@ -2784,19 +2801,16 @@ def setbad(obj, i, val):
27842801

27852802
if not composite_characteristic:
27862803
assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)])
2787-
elif nmod_poly_will_crash:
2788-
pass
27892804
else:
27902805
assert raises(lambda: P([1, 2, 1]).factor(), DomainError)
27912806

27922807
if not composite_characteristic:
27932808
assert P([1, 2, 1]).sqrt() == P([1, 1])
2794-
assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError)
2795-
elif nmod_poly_will_crash:
2796-
pass
27972809
else:
27982810
assert raises(lambda: P([1, 2, 1]).sqrt(), DomainError)
27992811

2812+
assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError)
2813+
28002814
if P == flint.fmpq_poly:
28012815
assert raises(lambda: P([1, 2, 1], 3).sqrt(), ValueError)
28022816
assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2)
@@ -3424,13 +3438,6 @@ def factor_sqf(p):
34243438
for P, S, [x, y], is_field, characteristic in _all_polys_mpolys():
34253439

34263440
if characteristic != 0 and not characteristic.is_prime():
3427-
# nmod_poly crashes for many operations with non-prime modulus
3428-
# https://github.com/flintlib/python-flint/issues/124
3429-
# so we can't even test it...
3430-
nmod_poly_will_crash = type(x) is flint.nmod_poly
3431-
if nmod_poly_will_crash:
3432-
continue
3433-
34343441
try:
34353442
S(4).sqrt() ** 2 == S(4)
34363443
except DomainError:
@@ -4190,7 +4197,7 @@ def test_fq_default():
41904197

41914198
# p must be prime
41924199
assert raises(lambda: flint.fq_default_ctx(10), ValueError)
4193-
4200+
41944201
# degree must be positive
41954202
assert raises(lambda: flint.fq_default_ctx(11, -1), ValueError)
41964203

src/flint/types/nmod_poly.pyx

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -503,10 +503,14 @@ cdef class nmod_poly(flint_poly):
503503

504504
def _floordiv_(s, t):
505505
cdef nmod_poly r
506+
506507
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
507508
raise ValueError("cannot divide nmod_polys with different moduli")
508509
if nmod_poly_is_zero((<nmod_poly>t).val):
509510
raise ZeroDivisionError("polynomial division by zero")
511+
if not s.ctx._is_prime:
512+
raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime")
513+
510514
r = nmod_poly_new_init_preinv(s.ctx)
511515
nmod_poly_div(r.val, (<nmod_poly>s).val, (<nmod_poly>t).val)
512516
r.ctx = s.ctx
@@ -526,10 +530,14 @@ cdef class nmod_poly(flint_poly):
526530

527531
def _divmod_(s, t):
528532
cdef nmod_poly P, Q
533+
529534
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
530535
raise ValueError("cannot divide nmod_polys with different moduli")
531536
if nmod_poly_is_zero((<nmod_poly>t).val):
532537
raise ZeroDivisionError("polynomial division by zero")
538+
if not s.ctx._is_prime:
539+
raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime")
540+
533541
P = nmod_poly_new_init_preinv(s.ctx)
534542
Q = nmod_poly_new_init_preinv(s.ctx)
535543
nmod_poly_divrem(P.val, Q.val, (<nmod_poly>s).val, (<nmod_poly>t).val)
@@ -639,27 +647,53 @@ cdef class nmod_poly(flint_poly):
639647
>>> (A * B).gcd(B) * 5
640648
5*x^2 + x + 4
641649
650+
The modulus must be prime.
642651
"""
643652
cdef nmod_poly res
653+
644654
other = self.ctx.any_as_nmod_poly(other)
645655
if other is NotImplemented:
646656
raise TypeError("cannot convert input to nmod_poly")
647657
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
648658
raise ValueError("moduli must be the same")
659+
if not self.ctx._is_prime:
660+
raise DomainError("nmod_poly gcd: modulus {self.ctx.mod.n} is not prime")
661+
649662
res = nmod_poly_new_init_preinv(self.ctx)
650663
nmod_poly_gcd(res.val, self.val, (<nmod_poly>other).val)
651664
res.ctx = self.ctx
652665
return res
653666

654667
def xgcd(self, other):
668+
r"""
669+
Computes the extended gcd of self and other: (`G`, `S`, `T`)
670+
where `G` is the ``gcd(self, other)`` and `S`, `T` are such that:
671+
672+
:math:`G = \textrm{self}*S + \textrm{other}*T`
673+
674+
>>> f = nmod_poly([143, 19, 37, 138, 102, 127, 95], 163)
675+
>>> g = nmod_poly([139, 9, 35, 154, 87, 120, 24], 163)
676+
>>> f.xgcd(g)
677+
(x^3 + 128*x^2 + 123*x + 91, 17*x^2 + 49*x + 104, 21*x^2 + 5*x + 25)
678+
679+
The modulus must be prime.
680+
"""
655681
cdef nmod_poly res1, res2, res3
682+
656683
other = self.ctx.any_as_nmod_poly(other)
657684
if other is NotImplemented:
658685
raise TypeError("cannot convert input to fmpq_poly")
686+
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
687+
raise ValueError("moduli must be the same")
688+
if not self.ctx._is_prime:
689+
raise DomainError("nmod_poly xgcd: modulus {self.ctx.mod.n} is not prime")
690+
659691
res1 = nmod_poly_new_init(self.ctx)
660692
res2 = nmod_poly_new_init(self.ctx)
661693
res3 = nmod_poly_new_init(self.ctx)
694+
662695
nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (<nmod_poly>other).val)
696+
663697
return (res1, res2, res3)
664698

665699
def factor(self, algorithm=None):
@@ -684,11 +718,14 @@ cdef class nmod_poly(flint_poly):
684718
>>> nmod_poly([3,2,1,2,3], 7).factor(algorithm='cantor-zassenhaus')
685719
(3, [(x + 4, 1), (x + 2, 1), (x^2 + 4*x + 1, 1)])
686720
721+
The modulus must be prime.
687722
"""
688723
if algorithm is None:
689724
algorithm = 'irreducible'
690725
elif algorithm not in ('berlekamp', 'cantor-zassenhaus'):
691726
raise ValueError(f"unknown factorization algorithm: {algorithm}")
727+
if not self.ctx._is_prime:
728+
raise DomainError(f"nmod_poly factor: modulus {self.ctx.mod.n} is not prime")
692729
return self._factor(algorithm)
693730

694731
def factor_squarefree(self):
@@ -707,6 +744,8 @@ cdef class nmod_poly(flint_poly):
707744
(2, [(x, 2), (x + 5, 2), (x + 1, 3)])
708745
709746
"""
747+
if not self.ctx._is_prime:
748+
raise DomainError(f"nmod_poly factor_squarefree: modulus {self.ctx.mod.n} is not prime")
710749
return self._factor('squarefree')
711750

712751
def _factor(self, factor_type):
@@ -747,12 +786,18 @@ cdef class nmod_poly(flint_poly):
747786

748787
def sqrt(nmod_poly self):
749788
"""Return exact square root or ``None``. """
750-
cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx)
751-
if nmod_poly_sqrt(res.val, self.val):
752-
return res
753-
else:
789+
cdef nmod_poly
790+
791+
if not self.ctx._is_prime:
792+
raise DomainError(f"nmod_poly sqrt: modulus {self.ctx.mod.n} is not prime")
793+
794+
res = nmod_poly_new_init_preinv(self.ctx)
795+
796+
if not nmod_poly_sqrt(res.val, self.val):
754797
raise DomainError(f"Cannot compute square root of {self}")
755798

799+
return res
800+
756801
def deflation(self):
757802
cdef nmod_poly v
758803
cdef ulong n

0 commit comments

Comments
 (0)