@@ -40,13 +40,15 @@ from flint.flintlib.fmpz_mod_mpoly cimport (
4040 fmpz_mod_mpoly_get_str_pretty,
4141 fmpz_mod_mpoly_get_term_coeff_fmpz,
4242 fmpz_mod_mpoly_get_term_exp_fmpz,
43+ fmpz_mod_mpoly_inflate,
4344 fmpz_mod_mpoly_is_one,
4445 fmpz_mod_mpoly_is_zero,
4546 fmpz_mod_mpoly_length,
4647 fmpz_mod_mpoly_mul,
4748 fmpz_mod_mpoly_neg,
4849 fmpz_mod_mpoly_pow_fmpz,
4950 fmpz_mod_mpoly_push_term_fmpz_ffmpz,
51+ fmpz_mod_mpoly_push_term_ui_ffmpz,
5052 fmpz_mod_mpoly_resultant,
5153 fmpz_mod_mpoly_scalar_mul_fmpz,
5254 fmpz_mod_mpoly_set,
@@ -1074,27 +1076,139 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
10741076 fmpz_mod_mpoly_derivative(res.val, self .val, i, self .ctx.val)
10751077 return res
10761078
1077- def deflation (self ) :
1079+ def inflate (self , N: list[int]) -> fmpz_mod_mpoly :
10781080 """
1079- Compute the deflation of ``self``. See Flint documentation for
1080- details. Returns deflated polynomial and the stride vector.
1081+ Compute the inflation of ``self`` for a provided ``N``, that is return ``q``
1082+ such that ``q(X ) = p(X^N )``.
1083+
1084+ >>> from flint import Ordering
1085+ >>> ctx = fmpz_mod_mpoly_ctx.get_context(2 , Ordering.lex, 11 , nametup = (' x' , ' y' ))
1086+ >>> x , y = ctx.gens()
1087+ >>> f = x + y + 1
1088+ >>> f.inflate([2, 3])
1089+ x^2 + y^3 + 1
1090+ """
1091+
1092+ cdef nvars = self .ctx.nvars()
1093+
1094+ if nvars != len(N ):
1095+ raise ValueError (f" expected list of length {nvars}, got {len(N)}" )
1096+ elif any (n < 0 for n in N):
1097+ raise ValueError (" all inflate strides must be non-negative" )
1098+
1099+ cdef:
1100+ fmpz_vec shift = fmpz_vec(nvars)
1101+ fmpz_vec stride = fmpz_vec(N)
1102+ fmpz_mod_mpoly res = create_fmpz_mod_mpoly(self .ctx)
1103+
1104+ fmpz_mod_mpoly_inflate(res.val, self .val, shift.val, stride.val, self .ctx.val)
1105+ return res
1106+
1107+ def deflate (self , N: list[int]) -> fmpz_mod_mpoly:
1108+ """
1109+ Compute the deflation of ``self`` for a provided ``N``, that is return ``q``
1110+ such that ``q(X ) = p(X^(1/N ))``.
10811111
10821112 >>> from flint import Ordering
10831113 >>> ctx = fmpz_mod_mpoly_ctx.get_context(2 , Ordering.lex, 11 , nametup = (' x' , ' y' ))
10841114 >>> x , y = ctx.gens()
10851115 >>> f = x** 3 * y + x * y** 4 + x * y
1086- >>> f.deflation( )
1087- ( x + y + 1, fmpz_vec(['2', '3'], 2))
1116+ >>> f.deflate([2, 3] )
1117+ x + y + 1
10881118 """
1119+ cdef slong nvars = self .ctx.nvars()
1120+
1121+ if nvars != len(N ):
1122+ raise ValueError (f" expected list of length {nvars}, got {len(N)}" )
1123+
10891124 cdef:
1090- fmpz_vec shift = fmpz_vec(self .ctx. nvars() )
1091- fmpz_vec stride = fmpz_vec(self .ctx.nvars() )
1125+ fmpz_vec shift = fmpz_vec(nvars)
1126+ fmpz_vec stride = fmpz_vec(N )
10921127 fmpz_mod_mpoly res = create_fmpz_mod_mpoly(self .ctx)
10931128
1094- fmpz_mod_mpoly_deflation(shift.val, stride.val, self .val, self .ctx.val)
10951129 fmpz_mod_mpoly_deflate(res.val, self .val, shift.val, stride.val, self .ctx.val)
1130+ return res
1131+
1132+ def deflation (self ) -> tuple[fmpz_mod_mpoly , list[int]]:
1133+ """
1134+ Compute the deflation of ``self``, that is ``p(X^(1/N ))`` for maximal
1135+ N. Returns ``q , N`` such that ``self == q.inflate(N )``.
1136+
1137+ >>> from flint import Ordering
1138+ >>> ctx = fmpz_mod_mpoly_ctx.get_context(2 , Ordering.lex, 11 , nametup = (' x' , ' y' ))
1139+ >>> x , y = ctx.gens()
1140+ >>> f = x** 3 * y + x * y** 4 + x * y
1141+ >>> q , N = f.deflation()
1142+ >>> q , N
1143+ (x + y + 1, [2, 3])
1144+ """
1145+ cdef:
1146+ fmpz_vec _shift = fmpz_vec(self .ctx.nvars())
1147+ fmpz_vec stride = fmpz_vec(self .ctx.nvars())
1148+ fmpz_mod_mpoly res = create_fmpz_mod_mpoly(self .ctx)
1149+
1150+ fmpz_mod_mpoly_deflation(_shift.val , stride.val , self.val , self.ctx.val )
1151+
1152+ cdef fmpz_vec zero_shift = fmpz_vec(self .ctx.nvars())
1153+ fmpz_mod_mpoly_deflate(res.val , self.val , zero_shift.val , stride.val , self.ctx.val )
1154+
1155+ return res , list(stride )
1156+
1157+ def deflation_monom(self ) -> tuple[list[int], fmpz_mod_mpoly]:
1158+ """
1159+ Compute the exponent vector ``N`` and monomial ``m`` such that ``p(X^(1/N ))
1160+ = m * q(X^N )`` for maximal N. Importantly the deflation itself is not computed
1161+ here. The returned monomial allows the undo-ing of the deflation.
1162+
1163+ >>> from flint import Ordering
1164+ >>> ctx = fmpz_mod_mpoly_ctx.get_context(2 , Ordering.lex, 11 , nametup = (' x' , ' y' ))
1165+ >>> x , y = ctx.gens()
1166+ >>> f = x** 3 * y + x * y** 4 + x * y
1167+ >>> N , m = f.deflation_monom()
1168+ >>> N , m
1169+ ([2, 3], x*y )
1170+ >>> f_deflated = f.deflate(N)
1171+ >>> f_deflated
1172+ x + y + 1
1173+ >>> m * f_deflated.inflate(N )
1174+ x^3*y + x*y^4 + x*y
1175+ """
1176+ cdef fmpz_mod_mpoly monom = create_fmpz_mod_mpoly(self .ctx)
10961177
1097- return res, stride
1178+ stride , _shift = self .deflation_index()
1179+
1180+ fmpz_mod_mpoly_push_term_ui_ffmpz(monom.val , 1, fmpz_vec(_shift ).val , self.ctx.val )
1181+ return list(stride ), monom
1182+
1183+ def deflation_index(self ) -> tuple[list[int], list[int]]:
1184+ """
1185+ Compute the exponent vectors ``N`` and ``I`` such that ``p(X^(1/N )) = X^I *
1186+ q(X^N )`` for maximal N. Importantly the deflation itself is not computed
1187+ here. The returned exponent vector ``I`` is the shift that was applied to the
1188+ exponents. It is the exponent vector of the monomial returned by
1189+ ``deflation_monom``.
1190+
1191+ >>> from flint import Ordering
1192+ >>> ctx = fmpz_mod_mpoly_ctx.get_context(2 , Ordering.lex, 11 , nametup = (' x' , ' y' ))
1193+ >>> x , y = ctx.gens()
1194+ >>> f = x** 3 * y + x * y** 4 + x * y
1195+ >>> N , I = f.deflation_index()
1196+ >>> N , I
1197+ ([2, 3], [1, 1])
1198+ >>> f_deflated = f.deflate(N)
1199+ >>> f_deflated
1200+ x + y + 1
1201+ >>> m = ctx.term(exp_vec = I)
1202+ >>> m * f_deflated.inflate(N )
1203+ x^3*y + x*y^4 + x*y
1204+ """
1205+ cdef:
1206+ slong nvars = self .ctx.nvars()
1207+ fmpz_vec shift = fmpz_vec(nvars)
1208+ fmpz_vec stride = fmpz_vec(nvars)
1209+
1210+ fmpz_mod_mpoly_deflation(shift.val , stride.val , self.val , self.ctx.val )
1211+ return list(stride ), list(shift )
10981212
10991213
11001214cdef class fmpz_mod_mpoly_vec:
0 commit comments