Skip to content

Commit 48136d1

Browse files
committed
More nmod interop
1 parent accffa9 commit 48136d1

2 files changed

Lines changed: 125 additions & 3 deletions

File tree

src/flint/test/test_all.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,6 +2994,12 @@ def quick_poly():
29942994
assert quick_poly() % 1 == P(ctx=ctx)
29952995
assert divmod(quick_poly(), 1) == (quick_poly(), P(ctx=ctx))
29962996

2997+
assert S(1) / P(1, ctx=ctx) == P(1, ctx=ctx)
2998+
assert quick_poly() / S(1) == quick_poly()
2999+
assert quick_poly() // S(1) == quick_poly()
3000+
assert quick_poly() % S(1) == P(ctx=ctx)
3001+
assert divmod(quick_poly(), S(1)) == (quick_poly(), P(ctx=ctx))
3002+
29973003
if is_field:
29983004
if (P is flint.fmpz_mod_mpoly or P is flint.nmod_mpoly):
29993005
assert quick_poly() / 3 == mpoly({(0, 0): S(34), (0, 1): S(68), (1, 0): S(1), (2, 2): S(35)})
@@ -3011,6 +3017,10 @@ def quick_poly():
30113017
assert 1 % quick_poly() == P(1, ctx=ctx)
30123018
assert divmod(1, quick_poly()) == (P(ctx=ctx), P(1, ctx=ctx))
30133019

3020+
assert S(1) // quick_poly() == P(ctx=ctx)
3021+
assert S(1) % quick_poly() == P(1, ctx=ctx)
3022+
assert divmod(S(1), quick_poly()) == (P(ctx=ctx), P(1, ctx=ctx))
3023+
30143024
assert raises(lambda: quick_poly() / None, TypeError)
30153025
assert raises(lambda: quick_poly() // None, TypeError)
30163026
assert raises(lambda: quick_poly() % None, TypeError)

src/flint/types/nmod_mpoly.pyx

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ cdef class nmod_mpoly(flint_mpoly):
291291
return op == Py_NE
292292
elif typecheck(other, nmod):
293293
if other.modulus() != self.ctx.modulus():
294-
raise ValueError(f"cannot compare different modulus {other.modulus()} vs {self.ctx.modulus()}")
294+
raise ValueError(f"cannot compare with different modulus {other.modulus()} vs {self.ctx.modulus()}")
295295
return (op == Py_NE) ^ <bint>nmod_mpoly_equal_ui(self.val, int(other), self.ctx.val)
296296
elif isinstance(other, int):
297297
return (op == Py_NE) ^ <bint>nmod_mpoly_equal_ui(self.val, other, self.ctx.val)
@@ -364,6 +364,12 @@ cdef class nmod_mpoly(flint_mpoly):
364364
res = create_nmod_mpoly(self.ctx)
365365
nmod_mpoly_add(res.val, (<nmod_mpoly>self).val, (<nmod_mpoly>other).val, res.ctx.val)
366366
return res
367+
elif typecheck(other, nmod):
368+
if other.modulus() != self.ctx.modulus():
369+
raise ValueError(f"cannot add with different modulus {other.modulus()} vs {self.ctx.modulus()}")
370+
res = create_nmod_mpoly(self.ctx)
371+
nmod_mpoly_add_ui(res.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
372+
return res
367373
elif isinstance(other, int):
368374
res = create_nmod_mpoly(self.ctx)
369375
nmod_mpoly_add_ui(res.val, (<nmod_mpoly>self).val, other, self.ctx.val)
@@ -377,6 +383,12 @@ cdef class nmod_mpoly(flint_mpoly):
377383
res = create_nmod_mpoly(self.ctx)
378384
nmod_mpoly_add_ui(res.val, (<nmod_mpoly>self).val, other, self.ctx.val)
379385
return res
386+
elif typecheck(other, nmod):
387+
if other.modulus() != self.ctx.modulus():
388+
raise ValueError(f"cannot add with different modulus {other.modulus()} vs {self.ctx.modulus()}")
389+
res = create_nmod_mpoly(self.ctx)
390+
nmod_mpoly_add_ui(res.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
391+
return res
380392
else:
381393
return NotImplemented
382394

@@ -388,6 +400,12 @@ cdef class nmod_mpoly(flint_mpoly):
388400
res = create_nmod_mpoly(self.ctx)
389401
nmod_mpoly_sub(res.val, (<nmod_mpoly>self).val, (<nmod_mpoly>other).val, res.ctx.val)
390402
return res
403+
elif typecheck(other, nmod):
404+
if other.modulus() != self.ctx.modulus():
405+
raise ValueError(f"cannot subtract with different modulus {other.modulus()} vs {self.ctx.modulus()}")
406+
res = create_nmod_mpoly(self.ctx)
407+
nmod_mpoly_sub_ui(res.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
408+
return res
391409
elif isinstance(other, int):
392410
res = create_nmod_mpoly(self.ctx)
393411
nmod_mpoly_sub_ui(res.val, (<nmod_mpoly>self).val, other, self.ctx.val)
@@ -401,6 +419,12 @@ cdef class nmod_mpoly(flint_mpoly):
401419
res = create_nmod_mpoly(self.ctx)
402420
nmod_mpoly_sub_ui(res.val, (<nmod_mpoly>self).val, other, res.ctx.val)
403421
return -res
422+
elif typecheck(other, nmod):
423+
if other.modulus() != self.ctx.modulus():
424+
raise ValueError(f"cannot subtract with different modulus {other.modulus()} vs {self.ctx.modulus()}")
425+
res = create_nmod_mpoly(self.ctx)
426+
nmod_mpoly_sub_ui(res.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
427+
return -res
404428
else:
405429
return NotImplemented
406430

@@ -412,6 +436,12 @@ cdef class nmod_mpoly(flint_mpoly):
412436
res = create_nmod_mpoly(self.ctx)
413437
nmod_mpoly_mul(res.val, (<nmod_mpoly>self).val, (<nmod_mpoly>other).val, res.ctx.val)
414438
return res
439+
elif typecheck(other, nmod):
440+
if other.modulus() != self.ctx.modulus():
441+
raise ValueError(f"cannot multiply with different modulus {other.modulus()} vs {self.ctx.modulus()}")
442+
res = create_nmod_mpoly(self.ctx)
443+
nmod_mpoly_scalar_mul_ui(res.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
444+
return res
415445
elif isinstance(other, int):
416446
res = create_nmod_mpoly(self.ctx)
417447
nmod_mpoly_scalar_mul_ui(res.val, (<nmod_mpoly>self).val, other, res.ctx.val)
@@ -425,6 +455,12 @@ cdef class nmod_mpoly(flint_mpoly):
425455
res = create_nmod_mpoly(self.ctx)
426456
nmod_mpoly_scalar_mul_ui(res.val, (<nmod_mpoly>self).val, other, res.ctx.val)
427457
return res
458+
elif typecheck(other, nmod):
459+
if other.modulus() != self.ctx.modulus():
460+
raise ValueError(f"cannot multiply with different modulus {other.modulus()} vs {self.ctx.modulus()}")
461+
res = create_nmod_mpoly(self.ctx)
462+
nmod_mpoly_scalar_mul_ui(res.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
463+
return res
428464
else:
429465
return NotImplemented
430466

@@ -464,6 +500,18 @@ cdef class nmod_mpoly(flint_mpoly):
464500
o = create_nmod_mpoly(self.ctx)
465501
nmod_mpoly_set_ui(o.val, other, o.ctx.val)
466502

503+
res = create_nmod_mpoly(self.ctx)
504+
res2 = create_nmod_mpoly(self.ctx)
505+
nmod_mpoly_divrem(res.val, res2.val, (<nmod_mpoly>self).val, o.val, res.ctx.val)
506+
return (res, res2)
507+
elif typecheck(other, nmod):
508+
if other.modulus() != self.ctx.modulus():
509+
raise ValueError(f"cannot divide with different modulus {other.modulus()} vs {self.ctx.modulus()}")
510+
elif not other:
511+
raise ZeroDivisionError("nmod_mpoly division by zero")
512+
o = create_nmod_mpoly(self.ctx)
513+
nmod_mpoly_set_ui(o.val, int(other), o.ctx.val)
514+
467515
res = create_nmod_mpoly(self.ctx)
468516
res2 = create_nmod_mpoly(self.ctx)
469517
nmod_mpoly_divrem(res.val, res2.val, (<nmod_mpoly>self).val, o.val, res.ctx.val)
@@ -481,6 +529,16 @@ cdef class nmod_mpoly(flint_mpoly):
481529
o = create_nmod_mpoly(self.ctx)
482530
nmod_mpoly_set_ui(o.val, other, o.ctx.val)
483531

532+
res = create_nmod_mpoly(self.ctx)
533+
res2 = create_nmod_mpoly(self.ctx)
534+
nmod_mpoly_divrem(res.val, res2.val, o.val, (<nmod_mpoly>self).val, res.ctx.val)
535+
return (res, res2)
536+
elif typecheck(other, nmod):
537+
if other.modulus() != self.ctx.modulus():
538+
raise ValueError(f"cannot divide with different modulus {other.modulus()} vs {self.ctx.modulus()}")
539+
o = create_nmod_mpoly(self.ctx)
540+
nmod_mpoly_set_ui(o.val, int(other), o.ctx.val)
541+
484542
res = create_nmod_mpoly(self.ctx)
485543
res2 = create_nmod_mpoly(self.ctx)
486544
nmod_mpoly_divrem(res.val, res2.val, o.val, (<nmod_mpoly>self).val, res.ctx.val)
@@ -506,6 +564,17 @@ cdef class nmod_mpoly(flint_mpoly):
506564
o = create_nmod_mpoly(self.ctx)
507565
nmod_mpoly_set_ui(o.val, other, o.ctx.val)
508566

567+
res = create_nmod_mpoly(self.ctx)
568+
nmod_mpoly_div(res.val, (<nmod_mpoly>self).val, o.val, res.ctx.val)
569+
return res
570+
elif typecheck(other, nmod):
571+
if other.modulus() != self.ctx.modulus():
572+
raise ValueError(f"cannot divide with different modulus {other.modulus()} vs {self.ctx.modulus()}")
573+
elif not other:
574+
raise ZeroDivisionError("nmod_mpoly division by zero")
575+
o = create_nmod_mpoly(self.ctx)
576+
nmod_mpoly_set_ui(o.val, int(other), o.ctx.val)
577+
509578
res = create_nmod_mpoly(self.ctx)
510579
nmod_mpoly_div(res.val, (<nmod_mpoly>self).val, o.val, res.ctx.val)
511580
return res
@@ -525,6 +594,15 @@ cdef class nmod_mpoly(flint_mpoly):
525594
res = create_nmod_mpoly(self.ctx)
526595
nmod_mpoly_div(res.val, o.val, self.val, res.ctx.val)
527596
return res
597+
elif typecheck(other, nmod):
598+
if other.modulus() != self.ctx.modulus():
599+
raise ValueError(f"cannot divide with different modulus {other.modulus()} vs {self.ctx.modulus()}")
600+
o = create_nmod_mpoly(self.ctx)
601+
nmod_mpoly_set_ui(o.val, int(other), o.ctx.val)
602+
603+
res = create_nmod_mpoly(self.ctx)
604+
nmod_mpoly_div(res.val, o.val, self.val, res.ctx.val)
605+
return res
528606
else:
529607
return NotImplemented
530608

@@ -555,6 +633,18 @@ cdef class nmod_mpoly(flint_mpoly):
555633
return res
556634
else:
557635
raise DomainError("nmod_mpoly division is not exact")
636+
elif typecheck(other, nmod):
637+
if other.modulus() != self.ctx.modulus():
638+
raise ValueError(f"cannot divide with different modulus {other.modulus()} vs {self.ctx.modulus()}")
639+
elif not other:
640+
raise ZeroDivisionError("nmod_mpoly division by zero")
641+
res = create_nmod_mpoly(self.ctx)
642+
div = create_nmod_mpoly(self.ctx)
643+
nmod_mpoly_set_ui(div.val, int(other), self.ctx.val)
644+
if nmod_mpoly_divides(res.val, self.val, div.val, self.ctx.val):
645+
return res
646+
else:
647+
raise DomainError("nmod_mpoly division is not exact")
558648
else:
559649
return NotImplemented
560650

@@ -572,6 +662,16 @@ cdef class nmod_mpoly(flint_mpoly):
572662
return res
573663
else:
574664
raise DomainError("nmod_mpoly division is not exact")
665+
elif typecheck(other, nmod):
666+
if other.modulus() != self.ctx.modulus():
667+
raise ValueError(f"cannot divide with different modulus {other.modulus()} vs {self.ctx.modulus()}")
668+
res = create_nmod_mpoly(self.ctx)
669+
div = create_nmod_mpoly(self.ctx)
670+
nmod_mpoly_set_ui(div.val, int(other), self.ctx.val)
671+
if nmod_mpoly_divides(res.val, div.val, self.val, self.ctx.val):
672+
return res
673+
else:
674+
raise DomainError("nmod_mpoly division is not exact")
575675
else:
576676
return NotImplemented
577677

@@ -625,6 +725,10 @@ cdef class nmod_mpoly(flint_mpoly):
625725
nmod_mpoly_add((<nmod_mpoly>self).val, (<nmod_mpoly>self).val, (<nmod_mpoly>other).val, self.ctx.val)
626726
elif isinstance(other, int):
627727
nmod_mpoly_add_ui((<nmod_mpoly>self).val, (<nmod_mpoly>self).val, other, self.ctx.val)
728+
elif typecheck(other, nmod):
729+
if other.modulus() != self.ctx.modulus():
730+
raise ValueError(f"cannot add with different modulus {other.modulus()} vs {self.ctx.modulus()}")
731+
nmod_mpoly_add_ui(self.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
628732
else:
629733
raise NotImplementedError(f"cannot add {type(other)} to nmod_mpoly")
630734

@@ -648,6 +752,10 @@ cdef class nmod_mpoly(flint_mpoly):
648752
nmod_mpoly_sub((<nmod_mpoly>self).val, (<nmod_mpoly>self).val, (<nmod_mpoly>other).val, self.ctx.val)
649753
elif isinstance(other, int):
650754
nmod_mpoly_sub_ui((<nmod_mpoly>self).val, (<nmod_mpoly>self).val, other, self.ctx.val)
755+
elif typecheck(other, nmod):
756+
if other.modulus() != self.ctx.modulus():
757+
raise ValueError(f"cannot subtract with different modulus {other.modulus()} vs {self.ctx.modulus()}")
758+
nmod_mpoly_sub_ui(self.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
651759
else:
652760
raise NotImplementedError(f"cannot subtract {type(other)} from nmod_mpoly")
653761

@@ -671,6 +779,10 @@ cdef class nmod_mpoly(flint_mpoly):
671779
nmod_mpoly_mul((<nmod_mpoly>self).val, (<nmod_mpoly>self).val, (<nmod_mpoly>other).val, self.ctx.val)
672780
elif isinstance(other, int):
673781
nmod_mpoly_scalar_mul_ui(self.val, (<nmod_mpoly>self).val, other, self.ctx.val)
782+
elif typecheck(other, nmod):
783+
if other.modulus() != self.ctx.modulus():
784+
raise ValueError(f"cannot multiply with different modulus {other.modulus()} vs {self.ctx.modulus()}")
785+
nmod_mpoly_scalar_mul_ui(self.val, (<nmod_mpoly>self).val, int(other), self.ctx.val)
674786
else:
675787
raise NotImplementedError(f"cannot multiple nmod_mpoly by {type(other)}")
676788

@@ -748,9 +860,9 @@ cdef class nmod_mpoly(flint_mpoly):
748860
nmod_mpoly res
749861
slong i
750862

751-
args = tuple((self.ctx.variable_to_index(k), v) for k, v in dict_args.items())
863+
args = tuple((self.ctx.variable_to_index(k), int(v)) for k, v in dict_args.items())
752864
for _, v in args:
753-
if not (isinstance(v, int) and v >= 0):
865+
if not v >= 0:
754866
raise TypeError("constants must be non-negative integers")
755867

756868
# Partial application with args in Z. We evaluate the polynomial one variable at a time

0 commit comments

Comments
 (0)