Skip to content

Commit accffa9

Browse files
committed
Support comparisons with more literals
1 parent 51a572e commit accffa9

5 files changed

Lines changed: 53 additions & 19 deletions

File tree

src/flint/test/test_all.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,13 +2685,13 @@ def _all_mpolys():
26852685
(
26862686
flint.nmod_mpoly,
26872687
lambda *args, **kwargs: flint.nmod_mpoly_ctx.get_context(*args, **kwargs, modulus=101),
2688-
int,
2688+
lambda x: flint.nmod(x, 101),
26892689
True,
26902690
),
26912691
(
26922692
flint.nmod_mpoly,
26932693
lambda *args, **kwargs: flint.nmod_mpoly_ctx.get_context(*args, **kwargs, modulus=100),
2694-
int,
2694+
lambda x: flint.nmod(x, 100),
26952695
False,
26962696
),
26972697
]
@@ -2767,10 +2767,15 @@ def quick_poly():
27672767
assert (P(1, ctx=ctx) == P(2, ctx=ctx)) is False
27682768
assert (P(1, ctx=ctx) != P(2, ctx=ctx)) is True
27692769

2770-
assert (P(1, ctx=ctx) == 1) is False
2771-
assert (P(1, ctx=ctx) != 1) is True
2772-
assert (1 == P(1, ctx=ctx)) is False
2773-
assert (1 != P(1, ctx=ctx)) is True
2770+
assert (P(1, ctx=ctx) == 1) is True
2771+
assert (P(1, ctx=ctx) != 1) is False
2772+
assert (1 == P(1, ctx=ctx)) is True
2773+
assert (1 != P(1, ctx=ctx)) is False
2774+
2775+
assert (P(1, ctx=ctx) == S(1)) is True
2776+
assert (P(1, ctx=ctx) != S(1)) is False
2777+
assert (S(1) == P(1, ctx=ctx)) is True
2778+
assert (S(1) != P(1, ctx=ctx)) is False
27742779

27752780
assert (P(1, ctx=ctx) == P(1, ctx=ctx1)) is False
27762781
assert (P(1, ctx=ctx) != P(1, ctx=ctx1)) is True
@@ -2926,7 +2931,7 @@ def quick_poly():
29262931
else {k: ctx.modulus() + v for k, v in {(0, 0): -4, (0, 1): -4, (1, 0): -4, (2, 2): -4}.items()}
29272932
)
29282933

2929-
for T in [int, S, lambda x: P(x, ctx=ctx)]:
2934+
for T in [int, S, int, lambda x: P(x, ctx=ctx)]:
29302935
p = quick_poly()
29312936
p -= T(1)
29322937
q = quick_poly()
@@ -2955,7 +2960,7 @@ def quick_poly():
29552960
(0, 1): 6
29562961
})
29572962

2958-
for T in [int, S, lambda x: P(x, ctx=ctx)]:
2963+
for T in [int, S, int, lambda x: P(x, ctx=ctx)]:
29592964
p = quick_poly()
29602965
p *= T(2)
29612966
q = quick_poly()

src/flint/types/fmpq_mpoly.pyx

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ from flint.flintlib.fmpq_mpoly cimport (
2828
fmpq_mpoly_divides,
2929
fmpq_mpoly_divrem,
3030
fmpq_mpoly_equal,
31+
fmpq_mpoly_equal_fmpq,
32+
fmpq_mpoly_equal_fmpz,
3133
fmpq_mpoly_evaluate_all_fmpq,
3234
fmpq_mpoly_evaluate_one_fmpq,
3335
fmpq_mpoly_gcd,
@@ -256,11 +258,18 @@ cdef class fmpq_mpoly(flint_mpoly):
256258
return op == Py_NE
257259
elif typecheck(self, fmpq_mpoly) and typecheck(other, fmpq_mpoly):
258260
if (<fmpq_mpoly>self).ctx is (<fmpq_mpoly>other).ctx:
259-
return (op == Py_NE) ^ bool(
260-
fmpq_mpoly_equal((<fmpq_mpoly>self).val, (<fmpq_mpoly>other).val, (<fmpq_mpoly>self).ctx.val)
261-
)
261+
return (op == Py_NE) ^ <bint>fmpq_mpoly_equal(self.val, (<fmpq_mpoly>other).val, self.ctx.val)
262262
else:
263263
return op == Py_NE
264+
elif typecheck(other, fmpq):
265+
return (op == Py_NE) ^ <bint>fmpq_mpoly_equal_fmpq(self.val, (<fmpq>other).val, self.ctx.val)
266+
elif typecheck(other, fmpz):
267+
return (op == Py_NE) ^ <bint>fmpq_mpoly_equal_fmpz(self.val, (<fmpz>other).val, self.ctx.val)
268+
elif isinstance(other, int):
269+
other = any_as_fmpz(other)
270+
if other is NotImplemented:
271+
return NotImplemented
272+
return (op == Py_NE) ^ <bint>fmpq_mpoly_equal_fmpz(self.val, (<fmpz>other).val, self.ctx.val)
264273
else:
265274
return NotImplemented
266275

src/flint/types/fmpz_mod_mpoly.pyx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ from flint.flintlib.fmpz_mod_mpoly cimport (
2626
fmpz_mod_mpoly_divides,
2727
fmpz_mod_mpoly_divrem,
2828
fmpz_mod_mpoly_equal,
29+
fmpz_mod_mpoly_equal_fmpz,
2930
fmpz_mod_mpoly_evaluate_all_fmpz,
3031
fmpz_mod_mpoly_evaluate_one_fmpz,
3132
fmpz_mod_mpoly_gcd,
@@ -308,6 +309,13 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
308309
)
309310
else:
310311
return op == Py_NE
312+
elif typecheck(other, fmpz):
313+
return (op == Py_NE) ^ <bint>fmpz_mod_mpoly_equal_fmpz(self.val, (<fmpz>other).val, self.ctx.val)
314+
elif isinstance(other, int):
315+
other = any_as_fmpz(other)
316+
if other is NotImplemented:
317+
return NotImplemented
318+
return (op == Py_NE) ^ <bint>fmpz_mod_mpoly_equal_fmpz(self.val, (<fmpz>other).val, self.ctx.val)
311319
else:
312320
return NotImplemented
313321

src/flint/types/fmpz_mpoly.pyx

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ from flint.flintlib.fmpz_mpoly cimport (
2424
fmpz_mpoly_divides,
2525
fmpz_mpoly_divrem,
2626
fmpz_mpoly_equal,
27+
fmpz_mpoly_equal_fmpz,
2728
fmpz_mpoly_evaluate_all_fmpz,
2829
fmpz_mpoly_evaluate_one_fmpz,
2930
fmpz_mpoly_gcd,
3031
fmpz_mpoly_gen,
3132
fmpz_mpoly_get_coeff_fmpz_fmpz,
3233
fmpz_mpoly_get_str_pretty,
33-
fmpz_mpoly_get_term,
3434
fmpz_mpoly_get_term_coeff_fmpz,
3535
fmpz_mpoly_get_term_exp_fmpz,
3636
fmpz_mpoly_integral,
@@ -234,13 +234,18 @@ cdef class fmpz_mpoly(flint_mpoly):
234234
return NotImplemented
235235
elif other is None:
236236
return op == Py_NE
237-
elif typecheck(self, fmpz_mpoly) and typecheck(other, fmpz_mpoly):
237+
elif typecheck(other, fmpz_mpoly):
238238
if (<fmpz_mpoly>self).ctx is (<fmpz_mpoly>other).ctx:
239-
return (op == Py_NE) ^ bool(
240-
fmpz_mpoly_equal((<fmpz_mpoly>self).val, (<fmpz_mpoly>other).val, (<fmpz_mpoly>self).ctx.val)
241-
)
239+
return (op == Py_NE) ^ <bint>fmpz_mpoly_equal(self.val, (<fmpz_mpoly>other).val, self.ctx.val)
242240
else:
243241
return op == Py_NE
242+
elif typecheck(other, fmpz):
243+
return (op == Py_NE) ^ <bint>fmpz_mpoly_equal_fmpz(self.val, (<fmpz>other).val, self.ctx.val)
244+
elif isinstance(other, int):
245+
other = any_as_fmpz(other)
246+
if other is NotImplemented:
247+
return NotImplemented
248+
return (op == Py_NE) ^ <bint>fmpz_mpoly_equal_fmpz(self.val, (<fmpz>other).val, self.ctx.val)
244249
else:
245250
return NotImplemented
246251

src/flint/types/nmod_mpoly.pyx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ from flint.utils.flint_exceptions import DomainError, IncompatibleContextError
1212
from flint.types.fmpz cimport any_as_fmpz, fmpz
1313
from flint.types.fmpz_vec cimport fmpz_vec
1414

15+
from flint.types.nmod cimport nmod
16+
1517
from flint.flintlib.fmpz cimport fmpz_set
1618
from flint.flintlib.nmod_mpoly cimport (
1719
nmod_mpoly_add,
@@ -26,6 +28,7 @@ from flint.flintlib.nmod_mpoly cimport (
2628
nmod_mpoly_divides,
2729
nmod_mpoly_divrem,
2830
nmod_mpoly_equal,
31+
nmod_mpoly_equal_ui,
2932
nmod_mpoly_evaluate_all_ui,
3033
nmod_mpoly_evaluate_one_ui,
3134
nmod_mpoly_gcd,
@@ -283,11 +286,15 @@ cdef class nmod_mpoly(flint_mpoly):
283286
return op == Py_NE
284287
elif typecheck(self, nmod_mpoly) and typecheck(other, nmod_mpoly):
285288
if (<nmod_mpoly>self).ctx is (<nmod_mpoly>other).ctx:
286-
return (op == Py_NE) ^ bool(
287-
nmod_mpoly_equal((<nmod_mpoly>self).val, (<nmod_mpoly>other).val, (<nmod_mpoly>self).ctx.val)
288-
)
289+
return (op == Py_NE) ^ <bint>nmod_mpoly_equal(self.val, (<nmod_mpoly>other).val, self.ctx.val)
289290
else:
290291
return op == Py_NE
292+
elif typecheck(other, nmod):
293+
if other.modulus() != self.ctx.modulus():
294+
raise ValueError(f"cannot compare different modulus {other.modulus()} vs {self.ctx.modulus()}")
295+
return (op == Py_NE) ^ <bint>nmod_mpoly_equal_ui(self.val, int(other), self.ctx.val)
296+
elif isinstance(other, int):
297+
return (op == Py_NE) ^ <bint>nmod_mpoly_equal_ui(self.val, other, self.ctx.val)
291298
else:
292299
return NotImplemented
293300

0 commit comments

Comments
 (0)