Skip to content

Commit e8a296a

Browse files
committed
acb_mat: add type stubs
1 parent e86c35d commit e8a296a

6 files changed

Lines changed: 489 additions & 53 deletions

File tree

src/flint/test/helpers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Callable, Sequence
44

5-
from flint import acb, acb_poly, acb_series, arb, arb_mat, arb_poly, arb_series
5+
from flint import acb, acb_mat, acb_poly, acb_series, arb, arb_mat, arb_poly, arb_series
66

77

88
def raises(f: Callable[[], object], exception: type[Exception]) -> bool:
@@ -127,6 +127,26 @@ def is_close_arb_mat(
127127
return True
128128

129129

130+
def is_close_acb_mat(
131+
x: acb_mat,
132+
y: acb_mat | Sequence[Sequence[int | float | complex | str | acb]],
133+
*,
134+
tol: int | float | str | arb = 1e-10,
135+
rel_tol: int | float | str | arb = 1e-10,
136+
max_width: int | float | str | arb = 1e-10,
137+
) -> bool:
138+
if not isinstance(x, acb_mat):
139+
return False
140+
y = acb_mat(y)
141+
if x.nrows() != y.nrows() or x.ncols() != y.ncols():
142+
return False
143+
for i in range(x.nrows()):
144+
for j in range(x.ncols()):
145+
if not is_close_acb(x[i, j], y[i, j], tol=tol, rel_tol=rel_tol, max_width=max_width):
146+
return False
147+
return True
148+
149+
130150
def is_close_acb_series(
131151
x: acb_series,
132152
y: acb_series | Sequence[int | float | complex | acb],

src/flint/test/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pyfiles = [
66
'helpers.py',
77
'test_all.py',
88
'test_acb.py',
9+
'test_acb_mat.py',
910
'test_acb_poly.py',
1011
'test_acb_series.py',
1112
'test_arb.py',

src/flint/test/test_acb_mat.py

Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from unittest.mock import patch
5+
6+
from flint import acb, acb_mat, arb_mat, ctx, fmpq_mat, fmpz_mat
7+
from flint.test.helpers import is_close_acb, is_close_acb_mat as is_close, is_close_arb_mat, raises
8+
9+
10+
class _DummyMatrix:
11+
rows = 2
12+
cols = 2
13+
14+
def __getitem__(self, index: tuple[int, int]) -> complex:
15+
i, j = index
16+
return complex(i + 2 * j + 1, i - j)
17+
18+
19+
def test_acb_mat_constructor() -> None:
20+
z = fmpz_mat([[1, 2], [3, 4]])
21+
q = fmpq_mat([[1, 2], [3, 4]])
22+
a = arb_mat([[1, 2], [3, 4]])
23+
b = acb_mat([[1, 2], [3, 4]])
24+
25+
assert is_close(acb_mat(b), [[1, 2], [3, 4]])
26+
assert is_close(acb_mat(a), [[1, 2], [3, 4]])
27+
assert is_close(acb_mat(z), [[1, 2], [3, 4]])
28+
assert is_close(acb_mat(q), [[1, 2], [3, 4]])
29+
assert is_close(acb_mat(_DummyMatrix()), [[1, 3 - 1j], [2 + 1j, 4]])
30+
31+
assert isinstance(acb_mat.convert(z), acb_mat)
32+
assert isinstance(acb_mat.convert(q), acb_mat)
33+
assert isinstance(acb_mat.convert(a), acb_mat)
34+
assert raises(lambda: acb_mat.convert(object()), TypeError) # type: ignore[arg-type]
35+
36+
assert is_close(acb_mat(2, 2), [[0, 0], [0, 0]])
37+
assert is_close(acb_mat(2, 2, [1, 2, 3, 4]), [[1, 2], [3, 4]])
38+
assert is_close(acb_mat(3, 3, 5), [[5, 0, 0], [0, 5, 0], [0, 0, 5]])
39+
40+
assert raises(lambda: acb_mat([1, 2]), TypeError) # type: ignore[arg-type,list-item]
41+
assert raises(lambda: acb_mat([[1], [2, 3]]), ValueError)
42+
assert raises(lambda: acb_mat(object()), TypeError) # type: ignore[call-overload]
43+
assert raises(lambda: acb_mat(2, 2, [1, 2, 3]), ValueError)
44+
assert raises(lambda: acb_mat(1, 2, 3, 4), ValueError) # type: ignore[call-overload]
45+
46+
47+
def test_acb_mat_basics_and_indexing() -> None:
48+
a = acb_mat([[1, 2], [3, 4]])
49+
assert a.nrows() == 2
50+
assert a.ncols() == 2
51+
assert is_close_acb(a[0, 1], 2)
52+
a[0, 1] = 1 + 2j
53+
assert is_close_acb(a[0, 1], 1 + 2j)
54+
55+
assert raises(lambda: a[2, 0], ValueError) # type: ignore[index]
56+
assert raises(lambda: a[0, 2], ValueError) # type: ignore[index]
57+
58+
def set_oob_1() -> None:
59+
a[2, 0] = 1
60+
61+
def set_oob_2() -> None:
62+
a[0, 2] = 1
63+
64+
def set_bad() -> None:
65+
a[0, 0] = object() # type: ignore[assignment]
66+
67+
assert raises(set_oob_1, ValueError)
68+
assert raises(set_oob_2, ValueError)
69+
assert raises(set_bad, TypeError)
70+
71+
assert is_close(a.transpose(), [[1, 3], [1 + 2j, 4]])
72+
assert is_close(a.conjugate(), [[1, 1 - 2j], [3, 4]])
73+
assert is_close(+a, a)
74+
assert is_close(-a, [[-1, -1 - 2j], [-3, -4]])
75+
assert raises(lambda: bool(a), NotImplementedError)
76+
77+
oldpretty = ctx.pretty
78+
try:
79+
ctx.pretty = False
80+
assert "acb_mat(" in repr(a)
81+
finally:
82+
ctx.pretty = oldpretty
83+
assert "[" in str(a)
84+
85+
86+
def test_acb_mat_add_sub() -> None:
87+
a = acb_mat([[1, 2], [3, 4]])
88+
b = acb_mat([[4, 5], [6, 7]])
89+
z = fmpz_mat([[4, 5], [6, 7]])
90+
q = fmpq_mat([[4, 5], [6, 7]])
91+
r = arb_mat([[4, 5], [6, 7]])
92+
93+
assert is_close(a + b, [[5, 7], [9, 11]])
94+
assert is_close(a + z, [[5, 7], [9, 11]])
95+
assert is_close(a + q, [[5, 7], [9, 11]])
96+
assert is_close(a + r, [[5, 7], [9, 11]])
97+
assert is_close(z + a, [[5, 7], [9, 11]])
98+
assert is_close(q + a, [[5, 7], [9, 11]])
99+
assert is_close(r + a, [[5, 7], [9, 11]])
100+
101+
assert is_close(a - b, [[-3, -3], [-3, -3]])
102+
assert is_close(a - z, [[-3, -3], [-3, -3]])
103+
assert is_close(a - q, [[-3, -3], [-3, -3]])
104+
assert is_close(a - r, [[-3, -3], [-3, -3]])
105+
assert is_close(z - a, [[3, 3], [3, 3]])
106+
assert is_close(q - a, [[3, 3], [3, 3]])
107+
assert is_close(r - a, [[3, 3], [3, 3]])
108+
109+
assert is_close(a + 2, [[3, 2], [3, 6]])
110+
assert is_close(2 + a, [[3, 2], [3, 6]])
111+
assert is_close(a - 2, [[-1, 2], [3, 2]])
112+
assert is_close(2 - a, [[1, -2], [-3, -2]])
113+
114+
assert is_close(a + (1 + 2j), [[2 + 2j, 2], [3, 5 + 2j]])
115+
assert is_close((1 + 2j) + a, [[2 + 2j, 2], [3, 5 + 2j]])
116+
assert is_close(a - (1 + 2j), [[-2j, 2], [3, 3 - 2j]])
117+
assert is_close((1 + 2j) - a, [[2j, -2], [-3, -3 + 2j]])
118+
119+
assert raises(lambda: a + acb_mat([[1, 2, 3]]), ValueError)
120+
assert raises(lambda: a - acb_mat([[1, 2, 3]]), ValueError)
121+
assert raises(lambda: a + object(), TypeError) # type: ignore[operator]
122+
assert raises(lambda: object() + a, TypeError) # type: ignore[operator]
123+
assert raises(lambda: a - object(), TypeError) # type: ignore[operator]
124+
assert raises(lambda: object() - a, TypeError) # type: ignore[operator]
125+
126+
127+
def test_acb_mat_mul_div() -> None:
128+
a = acb_mat([[1, 2], [3, 4]])
129+
b = acb_mat([[4, 5], [6, 7]])
130+
z = fmpz_mat([[4, 5], [6, 7]])
131+
q = fmpq_mat([[4, 5], [6, 7]])
132+
r = arb_mat([[4, 5], [6, 7]])
133+
134+
assert is_close(a * b, [[16, 19], [36, 43]])
135+
assert is_close(a * z, [[16, 19], [36, 43]])
136+
assert is_close(a * q, [[16, 19], [36, 43]])
137+
assert is_close(a * r, [[16, 19], [36, 43]])
138+
assert is_close(z * a, [[19, 28], [27, 40]])
139+
assert is_close(q * a, [[19, 28], [27, 40]])
140+
assert is_close(r * a, [[19, 28], [27, 40]])
141+
142+
assert is_close(a * 2, [[2, 4], [6, 8]])
143+
assert is_close(2 * a, [[2, 4], [6, 8]])
144+
assert is_close(a * 0.5, [[0.5, 1], [1.5, 2]])
145+
assert is_close(a / 2, [[0.5, 1], [1.5, 2]])
146+
assert is_close(a * (1 + 2j), [[1 + 2j, 2 + 4j], [3 + 6j, 4 + 8j]])
147+
assert is_close((1 + 2j) * a, [[1 + 2j, 2 + 4j], [3 + 6j, 4 + 8j]])
148+
149+
assert raises(lambda: a * acb_mat([[1, 2, 3]]), ValueError)
150+
assert raises(lambda: a * object(), TypeError) # type: ignore[operator]
151+
assert raises(lambda: object() * a, TypeError) # type: ignore[operator]
152+
assert raises(lambda: a / object(), TypeError) # type: ignore[operator]
153+
154+
155+
def test_acb_mat_pow_inv_solve() -> None:
156+
a = acb_mat([[1, 2], [3, 4]])
157+
assert is_close(a**2, [[7, 10], [15, 22]])
158+
assert raises(lambda: pow(a, 2, 3), TypeError) # type: ignore[misc]
159+
assert raises(lambda: acb_mat([[1, 2, 3]]) ** 2, ValueError)
160+
161+
ai = a.inv()
162+
assert is_close(a * ai, [[1, 0], [0, 1]], tol=1e-8, rel_tol=1e-8, max_width=1e-8)
163+
assert raises(lambda: acb_mat([[1, 2], [2, 4]]).inv(), ZeroDivisionError)
164+
assert raises(lambda: acb_mat([[1, 2, 3]]).inv(), ValueError)
165+
166+
inv_ns = acb_mat([[1, 2], [2, 4]]).inv(nonstop=True)
167+
assert inv_ns[0, 0].is_finite() is False
168+
assert inv_ns[1, 1].is_finite() is False
169+
170+
x = acb_mat([[1], [2]])
171+
b = a * x
172+
assert is_close(a.solve(b), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
173+
assert is_close(a.solve(b, algorithm="lu"), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
174+
assert is_close(a.solve(b, algorithm="precond"), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
175+
assert is_close(a.solve(b, algorithm="approx"), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
176+
177+
assert raises(lambda: a.solve(b, algorithm="bad"), ValueError) # type: ignore[arg-type]
178+
assert raises(lambda: acb_mat([[1, 2], [2, 4]]).solve(b), ZeroDivisionError)
179+
180+
solve_ns = acb_mat([[1, 2], [2, 4]]).solve(b, nonstop=True)
181+
assert solve_ns[0, 0].is_finite() is False
182+
assert solve_ns[1, 0].is_finite() is False
183+
184+
assert raises(lambda: acb_mat([[1, 2, 3]]).solve(acb_mat([[1], [2], [3]])), ValueError)
185+
assert raises(lambda: a.solve([[1], [2]]), TypeError) # type: ignore[arg-type]
186+
187+
188+
def test_acb_mat_special_methods() -> None:
189+
a = acb_mat([[1, 2], [3, 4]])
190+
assert is_close_acb(a.det(), -2)
191+
assert is_close_acb(a.trace(), 5)
192+
assert is_close(a.mid(), a)
193+
194+
e = acb_mat(2, 2, [1, 4, -2, 1]).exp()
195+
assert is_close(
196+
e,
197+
[
198+
[-2.58607310345045, 1.18429895089106],
199+
[-0.592149475445530, -2.58607310345045],
200+
],
201+
tol=1e-12,
202+
rel_tol=1e-12,
203+
max_width=1e-12,
204+
)
205+
206+
assert raises(lambda: acb_mat([[1, 2, 3]]).det(), ValueError)
207+
assert raises(lambda: acb_mat([[1, 2, 3]]).trace(), ValueError)
208+
assert raises(lambda: acb_mat([[1, 2, 3]]).exp(), ValueError)
209+
210+
p = acb_mat([[1, 1], [1, 0]]).charpoly()
211+
assert is_close_acb(p[0], -1)
212+
assert is_close_acb(p[1], -1)
213+
assert is_close_acb(p[2], 1)
214+
assert raises(lambda: acb_mat([[1, 2, 3]]).charpoly(), ValueError)
215+
216+
d = acb_mat.dft(3)
217+
assert d.nrows() == 3
218+
assert d.ncols() == 3
219+
assert is_close_acb(d[0, 0], 0.5773502691896257)
220+
assert is_close_acb(d[1, 1], acb(-0.28867513459481287, -0.5), tol=1e-12, rel_tol=1e-12)
221+
222+
d2 = acb_mat.dft(2, 3)
223+
assert d2.nrows() == 2
224+
assert d2.ncols() == 3
225+
226+
227+
def test_acb_mat_contains_overlap_chop_cmp_real_imag() -> None:
228+
a = acb_mat([[1, 2], [3, 4]])
229+
b = (a / 3) * 3
230+
231+
assert b.contains(a) is True
232+
assert a.contains(b) is False
233+
assert b.contains(fmpz_mat([[1, 2], [3, 4]])) is True
234+
assert (a / 3).contains(fmpq_mat([[1, 2], [3, 4]]) / 3) is True
235+
assert ((a / 3) * 3).contains(arb_mat([[1, 2], [3, 4]])) is True
236+
assert raises(lambda: a.contains(object()), TypeError) # type: ignore[arg-type]
237+
238+
assert b.overlaps(a) is True
239+
assert (a + 100).overlaps(a) is False
240+
241+
c = acb_mat([[1e-20 + 1e-20j, 2], [3j, -1e-20 + 1e-20j]])
242+
chopped = c.chop(1e-10)
243+
assert is_close(chopped, [[0, 2], [3j, 0]])
244+
245+
assert (a == acb_mat([[1, 2], [3, 4]])) is True
246+
assert (a != acb_mat([[1, 2], [3, 4]])) is False
247+
assert (a == fmpz_mat([[1, 2], [3, 4]])) is True
248+
assert (a != fmpz_mat([[1, 2], [3, 5]])) is True
249+
assert raises(lambda: a < acb_mat([[1, 2], [3, 4]]), ValueError) # type: ignore[operator]
250+
assert raises(lambda: a <= acb_mat([[1, 2], [3, 4]]), ValueError) # type: ignore[operator]
251+
assert raises(lambda: a > acb_mat([[1, 2], [3, 4]]), ValueError) # type: ignore[operator]
252+
assert raises(lambda: a >= acb_mat([[1, 2], [3, 4]]), ValueError) # type: ignore[operator]
253+
assert (a == object()) is False
254+
assert (a != object()) is True
255+
256+
d = acb_mat.dft(3)
257+
assert is_close_arb_mat(d.real, [[0.5773502691896257, 0.5773502691896257, 0.5773502691896257], [0.5773502691896257, -0.28867513459481287, -0.28867513459481287], [0.5773502691896257, -0.28867513459481287, -0.28867513459481287]])
258+
assert is_close_arb_mat(d.imag, [[0, 0, 0], [0, -0.5, 0.5], [0, 0.5, -0.5]])
259+
260+
261+
def test_acb_mat_eig_theta_and_helper() -> None:
262+
a = acb_mat([[1, 0], [0, 2]])
263+
264+
vals = a.eig()
265+
assert len(vals) == 2
266+
assert any(v.real.contains(1) for v in vals)
267+
assert any(v.real.contains(2) for v in vals)
268+
269+
vals_l, left = a.eig(left=True)
270+
assert len(vals_l) == 2
271+
assert isinstance(left, acb_mat)
272+
273+
vals_r, right = a.eig(right=True)
274+
assert len(vals_r) == 2
275+
assert isinstance(right, acb_mat)
276+
277+
vals_lr, left2, right2 = a.eig(left=True, right=True)
278+
assert len(vals_lr) == 2
279+
assert isinstance(left2, acb_mat)
280+
assert isinstance(right2, acb_mat)
281+
282+
vals_approx = acb_mat.dft(4).eig(algorithm="approx")
283+
assert len(vals_approx) == 4
284+
285+
vals_rump = a.eig(algorithm="rump")
286+
assert len(vals_rump) == 2
287+
288+
vals_vm = a.eig(algorithm="vdhoeven_mourrain")
289+
assert len(vals_vm) == 2
290+
291+
vals_tol = a.eig(tol=1e-12)
292+
assert len(vals_tol) == 2
293+
294+
vals_multi = acb_mat.dft(4).eig(multiple=True)
295+
assert len(vals_multi) == 4
296+
297+
vals_multi_rump = acb_mat.dft(4).eig(multiple=True, algorithm="rump")
298+
assert len(vals_multi_rump) == 4
299+
300+
assert raises(lambda: acb_mat.dft(4).eig(multiple=True, right=True), NotImplementedError)
301+
assert raises(lambda: acb_mat.dft(4).eig(), ValueError)
302+
assert raises(lambda: acb_mat([[1, 2, 3]]).eig(), ValueError)
303+
304+
assert acb_mat(0, 0).eig() == []
305+
306+
tau = acb_mat([[1j]])
307+
theta_vals = acb_mat.theta(tau, acb_mat([[0]]))
308+
assert isinstance(theta_vals, acb_mat)
309+
assert theta_vals.nrows() == 1
310+
assert theta_vals.ncols() == 4
311+
312+
with patch.dict(sys.modules, {"flint.types.acb_theta": None}):
313+
assert raises(lambda: acb_mat.theta(tau, acb_mat([[0]])), NotImplementedError)
314+
315+
assert is_close(a, [[1, 0], [0, 2]]) is True
316+
assert is_close(a, acb_mat([[1, 0], [0, 2]])) is True
317+
assert is_close(a, [[1, 0, 0], [0, 2, 0]]) is False
318+
assert is_close(a, [[1, 0], [0, 3]]) is False
319+
assert is_close(object(), [[1]]) is False # type: ignore[arg-type]

0 commit comments

Comments
 (0)