Skip to content

Commit ad81966

Browse files
committed
fix: check for non-prime modulus in nmod_poly
1 parent 7d18b7a commit ad81966

2 files changed

Lines changed: 111 additions & 59 deletions

File tree

src/flint/test/test_all.py

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,7 +2101,7 @@ def test_fmpz_mod_poly():
21012101
assert pow(f, 2**60, g) == pow(pow(f, 2**30, g), 2**30, g)
21022102
assert pow(R_gen, 2**60, g) == pow(pow(R_gen, 2**30, g), 2**30, g)
21032103

2104-
# Check other typechecks for pow_mod
2104+
# Check other typechecks for pow_mod
21052105
assert raises(lambda: pow(f, -2, g), ValueError)
21062106
assert raises(lambda: pow(f, 1, "A"), TypeError)
21072107
assert raises(lambda: pow(f, "A", g), TypeError)
@@ -2536,10 +2536,6 @@ def test_polys():
25362536
for P, S, is_field, characteristic in _all_polys():
25372537

25382538
composite_characteristic = characteristic != 0 and not characteristic.is_prime()
2539-
# nmod_poly crashes for many operations with non-prime modulus
2540-
# https://github.com/flintlib/python-flint/issues/124
2541-
# so we can't even test it...
2542-
nmod_poly_will_crash = type(P(1)) is flint.nmod_poly and composite_characteristic
25432539

25442540
assert P([S(1)]) == P([1]) == P(P([1])) == P(1)
25452541

@@ -2684,30 +2680,58 @@ def setbad(obj, i, val):
26842680
assert raises(lambda: P([1, 2, 3]) * None, TypeError)
26852681
assert raises(lambda: None * P([1, 2, 3]), TypeError)
26862682

2687-
assert P([1, 2, 1]) // P([1, 1]) == P([1, 1])
2688-
assert P([1, 2, 1]) % P([1, 1]) == P([0])
2689-
assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0]))
2683+
if composite_characteristic and type(P(1)) is flint.nmod_poly:
2684+
# Z/nZ for n not prime
2685+
#
2686+
# fmpz_mod_poly and nmod_poly can sometimes compute division with
2687+
# composite characteristic, but it is not guaranteed to work. For
2688+
# fmpz_mod_poly, we can detect the failure and raise an exception.
2689+
# For nmod_poly, we cannot detect the failure and calling e.g.
2690+
# nmod_poly_divrem would crash the process so for nmod_poly we
2691+
# raise an exception in all cases if the modulus is not prime.
2692+
assert raises(lambda: P([1, 2, 1]) // P([1, 1]), DomainError)
2693+
assert raises(lambda: P([1, 2, 1]) % P([1, 1]), DomainError)
2694+
assert raises(lambda: divmod(P([1, 2, 1]), P([1, 1])), DomainError)
2695+
2696+
assert raises(lambda: 1 // P([1, 1]), DomainError)
2697+
assert raises(lambda: 1 % P([1, 1]), DomainError)
2698+
assert raises(lambda: divmod(1, P([1, 1])), DomainError)
2699+
else:
2700+
assert P([1, 2, 1]) // P([1, 1]) == P([1, 1])
2701+
assert P([1, 2, 1]) % P([1, 1]) == P([0])
2702+
assert divmod(P([1, 2, 1]), P([1, 1])) == (P([1, 1]), P([0]))
2703+
2704+
assert 1 // P([1, 1]) == P([0])
2705+
assert 1 % P([1, 1]) == P([1])
2706+
assert divmod(1, P([1, 1])) == (P([0]), P([1]))
2707+
2708+
assert P([1, 2, 1]) / P([1, 1]) == P([1, 1])
2709+
2710+
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
2711+
2712+
assert raises(lambda: 1 / P([1, 1]), DomainError)
2713+
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError)
26902714

26912715
if is_field:
26922716
assert P([1, 1]) // 2 == P([S(1)/2, S(1)/2])
26932717
assert P([1, 1]) % 2 == P([0])
2718+
assert P([2, 2]) / 2 == P([1, 1])
2719+
assert P([1, 2]) / 2 == P([S(1)/2, 1])
26942720
elif characteristic == 0:
26952721
assert P([1, 1]) // 2 == P([0, 0])
26962722
assert P([1, 1]) % 2 == P([1, 1])
2697-
elif nmod_poly_will_crash:
2698-
pass
2723+
assert P([2, 2]) / 2 == P([1, 1])
2724+
assert raises(lambda: P([1, 2]) / 2, DomainError)
26992725
else:
27002726
# Z/nZ for n not prime
27012727
if characteristic % 2 == 0:
27022728
assert raises(lambda: P([1, 1]) // 2, DomainError)
27032729
assert raises(lambda: P([1, 1]) % 2, DomainError)
2730+
assert raises(lambda: P([2, 2]) / 2, DomainError)
2731+
assert raises(lambda: P([1, 2]) / 2, DomainError)
27042732
else:
27052733
1/0
27062734

2707-
assert 1 // P([1, 1]) == P([0])
2708-
assert 1 % P([1, 1]) == P([1])
2709-
assert divmod(1, P([1, 1])) == (P([0]), P([1]))
2710-
27112735
assert raises(lambda: P([1, 2, 1]) // None, TypeError)
27122736
assert raises(lambda: P([1, 2, 1]) % None, TypeError)
27132737
assert raises(lambda: divmod(P([1, 2, 1]), None), TypeError)
@@ -2724,50 +2748,43 @@ def setbad(obj, i, val):
27242748
assert raises(lambda: P([1, 2, 1]) % P([0]), ZeroDivisionError)
27252749
assert raises(lambda: divmod(P([1, 2, 1]), P([0])), ZeroDivisionError)
27262750

2727-
# Exact/field scalar division
2728-
if is_field:
2729-
assert P([2, 2]) / 2 == P([1, 1])
2730-
assert P([1, 2]) / 2 == P([S(1)/2, 1])
2731-
elif characteristic == 0:
2732-
assert P([2, 2]) / 2 == P([1, 1])
2733-
assert raises(lambda: P([1, 2]) / 2, DomainError)
2734-
elif nmod_poly_will_crash:
2735-
pass
2736-
else:
2737-
# Z/nZ for n not prime
2738-
assert raises(lambda: P([2, 2]) / 2, DomainError)
2739-
assert raises(lambda: P([1, 2]) / 2, DomainError)
2740-
2741-
assert raises(lambda: P([1, 2]) / 0, ZeroDivisionError)
2742-
2743-
if not nmod_poly_will_crash:
2744-
assert P([1, 2, 1]) / P([1, 1]) == P([1, 1])
2745-
assert raises(lambda: 1 / P([1, 1]), DomainError)
2746-
assert raises(lambda: P([1, 2, 1]) / P([1, 2]), DomainError)
2747-
27482751
assert P([1, 1]) ** 0 == P([1])
27492752
assert P([1, 1]) ** 1 == P([1, 1])
27502753
assert P([1, 1]) ** 2 == P([1, 2, 1])
27512754
assert raises(lambda: P([1, 1]) ** -1, ValueError)
27522755
assert raises(lambda: P([1, 1]) ** None, TypeError)
2753-
2754-
# XXX: Not sure what this should do in general:
2756+
2757+
# 3-arg pow: (x^2 + 1)**3 mod x-1
2758+
2759+
pow3_types = [
2760+
# flint.fmpq_poly, XXX
2761+
flint.nmod_poly,
2762+
flint.fmpz_mod_poly,
2763+
flint.fq_default_poly
2764+
]
2765+
27552766
p = P([1, 1])
27562767
mod = P([1, 1])
2757-
if type(p) not in [flint.fmpz_mod_poly, flint.nmod_poly, flint.fq_default_poly]:
2768+
2769+
if type(p) not in pow3_types:
27582770
assert raises(lambda: pow(p, 2, mod), NotImplementedError)
2771+
assert p * p % mod == 0
2772+
elif composite_characteristic and type(p) == flint.nmod_poly:
2773+
# nmod_poly does not support % with composite characteristic
2774+
assert pow(p, 2, mod) == 0
2775+
assert raises(lambda: p * p % mod, DomainError)
27592776
else:
2777+
# Should be for any is_field including fmpq_poly. Works also in
2778+
# some cases for fmpz_mod_poly with non-prime modulus.
27602779
assert p * p % mod == pow(p, 2, mod)
27612780

27622781
if not composite_characteristic:
27632782
assert P([1, 2, 1]).gcd(P([1, 1])) == P([1, 1])
2764-
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
2765-
elif nmod_poly_will_crash:
2766-
pass
27672783
else:
27682784
# Z/nZ for n not prime
27692785
assert raises(lambda: P([1, 2, 1]).gcd(P([1, 1])), DomainError)
2770-
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
2786+
2787+
assert raises(lambda: P([1, 2, 1]).gcd(None), TypeError)
27712788

27722789
if is_field:
27732790
p1 = P([1, 0, 1])
@@ -2778,19 +2795,16 @@ def setbad(obj, i, val):
27782795

27792796
if not composite_characteristic:
27802797
assert P([1, 2, 1]).factor() == (S(1), [(P([1, 1]), 2)])
2781-
elif nmod_poly_will_crash:
2782-
pass
27832798
else:
27842799
assert raises(lambda: P([1, 2, 1]).factor(), DomainError)
27852800

27862801
if not composite_characteristic:
27872802
assert P([1, 2, 1]).sqrt() == P([1, 1])
2788-
assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError)
2789-
elif nmod_poly_will_crash:
2790-
pass
27912803
else:
27922804
assert raises(lambda: P([1, 2, 1]).sqrt(), DomainError)
27932805

2806+
assert raises(lambda: P([1, 2, 2]).sqrt(), DomainError)
2807+
27942808
if P == flint.fmpq_poly:
27952809
assert raises(lambda: P([1, 2, 1], 3).sqrt(), ValueError)
27962810
assert P([1, 2, 1], 4).sqrt() == P([1, 1], 2)
@@ -3362,13 +3376,6 @@ def factor_sqf(p):
33623376
for P, S, [x, y], is_field, characteristic in _all_polys_mpolys():
33633377

33643378
if characteristic != 0 and not characteristic.is_prime():
3365-
# nmod_poly crashes for many operations with non-prime modulus
3366-
# https://github.com/flintlib/python-flint/issues/124
3367-
# so we can't even test it...
3368-
nmod_poly_will_crash = type(x) is flint.nmod_poly
3369-
if nmod_poly_will_crash:
3370-
continue
3371-
33723379
try:
33733380
S(4).sqrt() ** 2 == S(4)
33743381
except DomainError:
@@ -4128,7 +4135,7 @@ def test_fq_default():
41284135

41294136
# p must be prime
41304137
assert raises(lambda: flint.fq_default_ctx(10), ValueError)
4131-
4138+
41324139
# degree must be positive
41334140
assert raises(lambda: flint.fq_default_ctx(11, -1), ValueError)
41344141

src/flint/types/nmod_poly.pyx

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -504,10 +504,14 @@ cdef class nmod_poly(flint_poly):
504504

505505
def _floordiv_(s, t):
506506
cdef nmod_poly r
507+
507508
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
508509
raise ValueError("cannot divide nmod_polys with different moduli")
509510
if nmod_poly_is_zero((<nmod_poly>t).val):
510511
raise ZeroDivisionError("polynomial division by zero")
512+
if not s.ctx._is_prime:
513+
raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime")
514+
511515
r = nmod_poly_new_init_preinv(s.ctx)
512516
nmod_poly_div(r.val, (<nmod_poly>s).val, (<nmod_poly>t).val)
513517
r.ctx = s.ctx
@@ -527,10 +531,14 @@ cdef class nmod_poly(flint_poly):
527531

528532
def _divmod_(s, t):
529533
cdef nmod_poly P, Q
534+
530535
if (<nmod_poly>s).val.mod.n != (<nmod_poly>t).val.mod.n:
531536
raise ValueError("cannot divide nmod_polys with different moduli")
532537
if nmod_poly_is_zero((<nmod_poly>t).val):
533538
raise ZeroDivisionError("polynomial division by zero")
539+
if not s.ctx._is_prime:
540+
raise DomainError("nmod_poly divmod: modulus {self.ctx.mod.n} is not prime")
541+
534542
P = nmod_poly_new_init_preinv(s.ctx)
535543
Q = nmod_poly_new_init_preinv(s.ctx)
536544
nmod_poly_divrem(P.val, Q.val, (<nmod_poly>s).val, (<nmod_poly>t).val)
@@ -640,27 +648,53 @@ cdef class nmod_poly(flint_poly):
640648
>>> (A * B).gcd(B) * 5
641649
5*x^2 + x + 4
642650
651+
The modulus must be prime.
643652
"""
644653
cdef nmod_poly res
654+
645655
other = self.ctx.any_as_nmod_poly(other)
646656
if other is NotImplemented:
647657
raise TypeError("cannot convert input to nmod_poly")
648658
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
649659
raise ValueError("moduli must be the same")
660+
if not self.ctx._is_prime:
661+
raise DomainError("nmod_poly gcd: modulus {self.ctx.mod.n} is not prime")
662+
650663
res = nmod_poly_new_init_preinv(self.ctx)
651664
nmod_poly_gcd(res.val, self.val, (<nmod_poly>other).val)
652665
res.ctx = self.ctx
653666
return res
654667

655668
def xgcd(self, other):
669+
r"""
670+
Computes the extended gcd of self and other: (`G`, `S`, `T`)
671+
where `G` is the ``gcd(self, other)`` and `S`, `T` are such that:
672+
673+
:math:`G = \textrm{self}*S + \textrm{other}*T`
674+
675+
>>> f = nmod_poly([143, 19, 37, 138, 102, 127, 95], 163)
676+
>>> g = nmod_poly([139, 9, 35, 154, 87, 120, 24], 163)
677+
>>> f.xgcd(g)
678+
(x^3 + 128*x^2 + 123*x + 91, 17*x^2 + 49*x + 104, 21*x^2 + 5*x + 25)
679+
680+
The modulus must be prime.
681+
"""
656682
cdef nmod_poly res1, res2, res3
683+
657684
other = self.ctx.any_as_nmod_poly(other)
658685
if other is NotImplemented:
659686
raise TypeError("cannot convert input to fmpq_poly")
687+
if self.val.mod.n != (<nmod_poly>other).val.mod.n:
688+
raise ValueError("moduli must be the same")
689+
if not self.ctx._is_prime:
690+
raise DomainError("nmod_poly xgcd: modulus {self.ctx.mod.n} is not prime")
691+
660692
res1 = nmod_poly_new_init(self.ctx)
661693
res2 = nmod_poly_new_init(self.ctx)
662694
res3 = nmod_poly_new_init(self.ctx)
695+
663696
nmod_poly_xgcd(res1.val, res2.val, res3.val, self.val, (<nmod_poly>other).val)
697+
664698
return (res1, res2, res3)
665699

666700
def factor(self, algorithm=None):
@@ -685,11 +719,14 @@ cdef class nmod_poly(flint_poly):
685719
>>> nmod_poly([3,2,1,2,3], 7).factor(algorithm='cantor-zassenhaus')
686720
(3, [(x + 4, 1), (x + 2, 1), (x^2 + 4*x + 1, 1)])
687721
722+
The modulus must be prime.
688723
"""
689724
if algorithm is None:
690725
algorithm = 'irreducible'
691726
elif algorithm not in ('berlekamp', 'cantor-zassenhaus'):
692727
raise ValueError(f"unknown factorization algorithm: {algorithm}")
728+
if not self.ctx._is_prime:
729+
raise DomainError(f"nmod_poly factor: modulus {self.ctx.mod.n} is not prime")
693730
return self._factor(algorithm)
694731

695732
def factor_squarefree(self):
@@ -708,6 +745,8 @@ cdef class nmod_poly(flint_poly):
708745
(2, [(x, 2), (x + 5, 2), (x + 1, 3)])
709746
710747
"""
748+
if not self.ctx._is_prime:
749+
raise DomainError(f"nmod_poly factor_squarefree: modulus {self.ctx.mod.n} is not prime")
711750
return self._factor('squarefree')
712751

713752
def _factor(self, factor_type):
@@ -748,12 +787,18 @@ cdef class nmod_poly(flint_poly):
748787

749788
def sqrt(nmod_poly self):
750789
"""Return exact square root or ``None``. """
751-
cdef nmod_poly res = nmod_poly_new_init_preinv(self.ctx)
752-
if nmod_poly_sqrt(res.val, self.val):
753-
return res
754-
else:
790+
cdef nmod_poly
791+
792+
if not self.ctx._is_prime:
793+
raise DomainError(f"nmod_poly sqrt: modulus {self.ctx.mod.n} is not prime")
794+
795+
res = nmod_poly_new_init_preinv(self.ctx)
796+
797+
if not nmod_poly_sqrt(res.val, self.val):
755798
raise DomainError(f"Cannot compute square root of {self}")
756799

800+
return res
801+
757802
def deflation(self):
758803
cdef nmod_poly v
759804
cdef ulong n

0 commit comments

Comments
 (0)