Skip to content

Commit e86c35d

Browse files
committed
arb_mat: add typing and tests
1 parent ee382ab commit e86c35d

6 files changed

Lines changed: 425 additions & 46 deletions

File tree

src/flint/test/helpers.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Callable, Sequence
44

5-
from flint import acb, acb_poly, acb_series, arb, arb_poly, arb_series
5+
from flint import acb, acb_poly, acb_series, arb, arb_mat, arb_poly, arb_series
66

77

88
def raises(f: Callable[[], object], exception: type[Exception]) -> bool:
@@ -107,6 +107,26 @@ def is_close_arb_series(
107107
return True
108108

109109

110+
def is_close_arb_mat(
111+
x: arb_mat,
112+
y: arb_mat | Sequence[Sequence[int | float | str | arb]],
113+
*,
114+
tol: int | float | str | arb = 1e-10,
115+
rel_tol: int | float | str | arb = 1e-10,
116+
max_width: int | float | str | arb = 1e-10,
117+
) -> bool:
118+
if not isinstance(x, arb_mat):
119+
return False
120+
y = arb_mat(y)
121+
if x.nrows() != y.nrows() or x.ncols() != y.ncols():
122+
return False
123+
for i in range(x.nrows()):
124+
for j in range(x.ncols()):
125+
if not is_close_arb(x[i, j], y[i, j], tol=tol, rel_tol=rel_tol, max_width=max_width):
126+
return False
127+
return True
128+
129+
110130
def is_close_acb_series(
111131
x: acb_series,
112132
y: acb_series | Sequence[int | float | complex | acb],

src/flint/test/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pyfiles = [
99
'test_acb_poly.py',
1010
'test_acb_series.py',
1111
'test_arb.py',
12+
'test_arb_mat.py',
1213
'test_arb_series.py',
1314
'test_arb_poly.py',
1415
'test_fmpq_vec.py',

src/flint/test/test_arb_mat.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
from __future__ import annotations
2+
3+
from flint import acb_mat, arb, arb_mat, ctx, fmpq_mat, fmpz_mat
4+
from flint.test.helpers import is_close_arb, is_close_arb_mat as is_close, raises
5+
6+
7+
class _DummyMatrix:
8+
rows = 2
9+
cols = 2
10+
11+
def __getitem__(self, index: tuple[int, int]) -> float:
12+
i, j = index
13+
return float(i + 2 * j + 1)
14+
15+
16+
def test_arb_mat_constructor() -> None:
17+
z = fmpz_mat([[1, 2], [3, 4]])
18+
q = fmpq_mat([[1, 2], [3, 4]])
19+
a = arb_mat([[1, 2], [3, 4]])
20+
b = arb_mat(a)
21+
assert is_close(b, [[1, 2], [3, 4]])
22+
assert is_close(arb_mat(z), [[1, 2], [3, 4]])
23+
assert is_close(arb_mat(q), [[1, 2], [3, 4]])
24+
assert is_close(arb_mat(_DummyMatrix()), [[1, 3], [2, 4]])
25+
assert isinstance(arb_mat.convert(z), arb_mat)
26+
assert isinstance(arb_mat.convert(q), arb_mat)
27+
assert raises(lambda: arb_mat.convert(object()), TypeError) # type: ignore[arg-type]
28+
29+
assert is_close(arb_mat(2, 2), [[0, 0], [0, 0]])
30+
assert is_close(arb_mat(2, 2, [1, 2, 3, 4]), [[1, 2], [3, 4]])
31+
assert is_close(arb_mat(3, 3, 5), [[5, 0, 0], [0, 5, 0], [0, 0, 5]])
32+
33+
assert raises(lambda: arb_mat([1, 2]), TypeError) # type: ignore[arg-type,list-item]
34+
assert raises(lambda: arb_mat([[1], [2, 3]]), ValueError)
35+
assert raises(lambda: arb_mat(object()), TypeError) # type: ignore[call-overload]
36+
assert raises(lambda: arb_mat(2, 2, [1, 2, 3]), ValueError)
37+
assert raises(lambda: arb_mat(1, 2, 3, 4), ValueError) # type: ignore[call-overload]
38+
39+
40+
def test_arb_mat_basics_and_indexing() -> None:
41+
a = arb_mat([[1, 2], [3, 4]])
42+
assert a.nrows() == 2
43+
assert a.ncols() == 2
44+
assert is_close_arb(a[0, 1], 2)
45+
a[0, 1] = 1.5
46+
assert is_close_arb(a[0, 1], 1.5)
47+
assert raises(lambda: a[2, 0], ValueError) # type: ignore[index]
48+
assert raises(lambda: a[0, 2], ValueError) # type: ignore[index]
49+
def set_oob_1() -> None:
50+
a[2, 0] = 1
51+
def set_oob_2() -> None:
52+
a[0, 2] = 1
53+
def set_bad() -> None:
54+
a[0, 0] = object() # type: ignore[assignment]
55+
assert raises(set_oob_1, ValueError)
56+
assert raises(set_oob_2, ValueError)
57+
assert raises(set_bad, TypeError)
58+
59+
assert is_close(a.transpose(), [[1, 3], [1.5, 4]])
60+
assert is_close(+a, a)
61+
assert is_close(-a, [[-1, -1.5], [-3, -4]])
62+
assert raises(lambda: bool(a), NotImplementedError)
63+
64+
oldpretty = ctx.pretty
65+
try:
66+
ctx.pretty = False
67+
assert "arb_mat(" in repr(a)
68+
finally:
69+
ctx.pretty = oldpretty
70+
assert "[" in str(a)
71+
72+
73+
def test_arb_mat_add_sub() -> None:
74+
a = arb_mat([[1, 2], [3, 4]])
75+
b = arb_mat([[4, 5], [6, 7]])
76+
z = fmpz_mat([[4, 5], [6, 7]])
77+
q = fmpq_mat([[4, 5], [6, 7]])
78+
79+
assert is_close(a + b, [[5, 7], [9, 11]])
80+
assert is_close(a + z, [[5, 7], [9, 11]])
81+
assert is_close(a + q, [[5, 7], [9, 11]])
82+
assert is_close(z + a, [[5, 7], [9, 11]])
83+
assert is_close(q + a, [[5, 7], [9, 11]])
84+
85+
assert is_close(a - b, [[-3, -3], [-3, -3]])
86+
assert is_close(a - z, [[-3, -3], [-3, -3]])
87+
assert is_close(z - a, [[3, 3], [3, 3]])
88+
89+
assert is_close(a + 2, [[3, 2], [3, 6]])
90+
assert is_close(2 + a, [[3, 2], [3, 6]])
91+
assert is_close(a - 2, [[-1, 2], [3, 2]])
92+
assert is_close(2 - a, [[1, -2], [-3, -2]])
93+
94+
c = a + 1j
95+
d = 1j + a
96+
e = a - 1j
97+
f = 1j - a
98+
assert isinstance(c, acb_mat)
99+
assert isinstance(d, acb_mat)
100+
assert isinstance(e, acb_mat)
101+
assert isinstance(f, acb_mat)
102+
103+
assert raises(lambda: a + arb_mat([[1, 2, 3]]), ValueError)
104+
assert raises(lambda: a - arb_mat([[1, 2, 3]]), ValueError)
105+
assert raises(lambda: a + object(), TypeError) # type: ignore[operator]
106+
assert raises(lambda: object() + a, TypeError) # type: ignore[operator]
107+
assert raises(lambda: a - object(), TypeError) # type: ignore[operator]
108+
assert raises(lambda: object() - a, TypeError) # type: ignore[operator]
109+
110+
111+
def test_arb_mat_mul_div() -> None:
112+
a = arb_mat([[1, 2], [3, 4]])
113+
b = arb_mat([[4, 5], [6, 7]])
114+
z = fmpz_mat([[4, 5], [6, 7]])
115+
q = fmpq_mat([[4, 5], [6, 7]])
116+
117+
assert is_close(a * b, [[16, 19], [36, 43]])
118+
assert is_close(a * z, [[16, 19], [36, 43]])
119+
assert is_close(a * q, [[16, 19], [36, 43]])
120+
assert is_close(z * a, [[19, 28], [27, 40]])
121+
assert is_close(q * a, [[19, 28], [27, 40]])
122+
123+
assert is_close(a * 2, [[2, 4], [6, 8]])
124+
assert is_close(2 * a, [[2, 4], [6, 8]])
125+
assert is_close(a * 0.5, [[0.5, 1], [1.5, 2]])
126+
assert is_close(a / 2, [[0.5, 1], [1.5, 2]])
127+
128+
c = a * (1 + 2j)
129+
d = (1 + 2j) * a
130+
assert isinstance(c, acb_mat)
131+
assert isinstance(d, acb_mat)
132+
133+
assert raises(lambda: a * arb_mat([[1, 2, 3]]), ValueError)
134+
assert raises(lambda: a * object(), TypeError) # type: ignore[operator]
135+
assert raises(lambda: object() * a, TypeError) # type: ignore[operator]
136+
nan_mat = a / 0
137+
assert nan_mat[0, 0].is_nan() is True
138+
assert nan_mat[1, 1].is_nan() is True
139+
assert raises(lambda: a / object(), TypeError) # type: ignore[operator]
140+
141+
142+
def test_arb_mat_pow_inv_solve() -> None:
143+
a = arb_mat([[1, 2], [3, 4]])
144+
assert is_close(a**2, [[7, 10], [15, 22]])
145+
assert raises(lambda: pow(a, 2, 3), NotImplementedError) # type: ignore[misc]
146+
assert raises(lambda: arb_mat([[1, 2, 3]])**2, ValueError)
147+
148+
ai = a.inv()
149+
eye = a * ai
150+
assert is_close(eye, [[1, 0], [0, 1]], tol=1e-8, rel_tol=1e-8, max_width=1e-8)
151+
assert raises(lambda: arb_mat([[1, 2], [2, 4]]).inv(), ZeroDivisionError)
152+
assert raises(lambda: arb_mat([[1, 2, 3]]).inv(), ValueError)
153+
inv_ns = arb_mat([[1, 2], [2, 4]]).inv(nonstop=True)
154+
assert inv_ns[0, 0].is_nan() is True
155+
assert inv_ns[1, 1].is_nan() is True
156+
157+
x = arb_mat([[1], [2]])
158+
b = a * x
159+
assert is_close(a.solve(b), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
160+
assert is_close(a.solve(b, algorithm="lu"), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
161+
assert is_close(a.solve(b, algorithm="precond"), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
162+
assert is_close(a.solve(b, algorithm="approx"), x, tol=1e-8, rel_tol=1e-8, max_width=1e-8)
163+
assert raises(lambda: a.solve(b, algorithm="bad"), ValueError) # type: ignore[arg-type]
164+
assert raises(lambda: arb_mat([[1, 2], [2, 4]]).solve(b), ZeroDivisionError)
165+
solve_ns = arb_mat([[1, 2], [2, 4]]).solve(b, nonstop=True)
166+
assert solve_ns[0, 0].is_nan() is True
167+
assert solve_ns[1, 0].is_nan() is True
168+
assert raises(lambda: arb_mat([[1, 2, 3]]).solve(arb_mat([[1], [2], [3]])), ValueError)
169+
assert raises(lambda: a.solve([[1], [2]]), TypeError) # type: ignore[arg-type]
170+
171+
172+
def test_arb_mat_special_methods() -> None:
173+
a = arb_mat([[1, 2], [3, 4]])
174+
assert is_close_arb(a.det(), -2)
175+
assert is_close_arb(a.trace(), 5)
176+
assert is_close(a.mid(), a)
177+
assert is_close(
178+
a.exp(),
179+
[[51.9689561987050, 74.7365645670032], [112.104846850505, 164.073803049210]],
180+
tol=1e-7,
181+
rel_tol=1e-7,
182+
max_width=1e-7,
183+
)
184+
assert raises(lambda: arb_mat([[1, 2, 3]]).det(), ValueError)
185+
assert raises(lambda: arb_mat([[1, 2, 3]]).trace(), ValueError)
186+
assert raises(lambda: arb_mat([[1, 2, 3]]).exp(), ValueError)
187+
188+
p = a.charpoly()
189+
assert is_close_arb(p[0], -2)
190+
assert is_close_arb(p[1], -5)
191+
assert is_close_arb(p[2], 1)
192+
assert raises(lambda: arb_mat([[1, 2, 3]]).charpoly(), ValueError)
193+
194+
h = arb_mat.hilbert(2, 2)
195+
assert is_close(h, [[1, 1/2], [1/2, 1/3]], tol=1e-12, rel_tol=1e-12, max_width=1e-12)
196+
197+
ps = arb_mat.pascal(3, 4)
198+
assert is_close(ps, [[1, 1, 1, 1], [1, 2, 3, 4], [1, 3, 6, 10]])
199+
pu = arb_mat.pascal(3, 4, 1)
200+
assert is_close(pu, [[1, 1, 1, 1], [0, 1, 2, 3], [0, 0, 1, 3]])
201+
pl = arb_mat.pascal(3, 4, -1)
202+
assert is_close(pl, [[1, 0, 0, 0], [1, 1, 0, 0], [1, 2, 1, 0]])
203+
204+
st0 = arb_mat.stirling(4, 3, 0)
205+
assert is_close(st0, [[1, 0, 0], [0, 1, 0], [0, 1, 1], [0, 2, 3]])
206+
st1 = arb_mat.stirling(4, 3, 1)
207+
assert is_close(st1, [[1, 0, 0], [0, 1, 0], [0, -1, 1], [0, 2, -3]])
208+
st2 = arb_mat.stirling(4, 3, 2)
209+
assert is_close(st2, [[1, 0, 0], [0, 1, 0], [0, 1, 1], [0, 1, 3]])
210+
assert raises(lambda: arb_mat.stirling(2, 2, 5), ValueError)
211+
212+
d = arb_mat.dct(2)
213+
assert d.nrows() == 2
214+
assert d.ncols() == 2
215+
d2 = arb_mat.dct(2, 3)
216+
assert d2.nrows() == 2
217+
assert d2.ncols() == 3
218+
219+
220+
def test_arb_mat_contains_overlap_chop_cmp_eig() -> None:
221+
a = arb_mat([[1, 2], [3, 4]])
222+
b = (a / 3) * 3
223+
assert b.contains(a) is True
224+
assert a.contains(b) is False
225+
assert b.contains(fmpz_mat([[1, 2], [3, 4]])) is True
226+
assert (a / 3).contains(fmpq_mat([[1, 2], [3, 4]]) / 3) is True
227+
assert raises(lambda: a.contains(object()), TypeError) # type: ignore[arg-type]
228+
229+
assert b.overlaps(a) is True
230+
assert (a + 100).overlaps(a) is False
231+
232+
c = arb_mat([[1e-20, 2], [3, -1e-20]])
233+
chopped = c.chop(1e-10)
234+
assert is_close(chopped, [[0, 2], [3, 0]])
235+
236+
assert (a == arb_mat([[1, 2], [3, 4]])) is True
237+
assert (a != arb_mat([[1, 2], [3, 4]])) is False
238+
assert (a == fmpz_mat([[1, 2], [3, 4]])) is True
239+
assert (a != fmpz_mat([[1, 2], [3, 5]])) is True
240+
assert raises(lambda: a < arb_mat([[1, 2], [3, 4]]), ValueError) # type: ignore[operator]
241+
assert (a == object()) is False
242+
assert (a != object()) is True
243+
244+
eigvals = arb_mat([[1, 0], [0, 2]]).eig()
245+
assert len(eigvals) == 2
246+
assert any(v.real.contains(1) for v in eigvals)
247+
assert any(v.real.contains(2) for v in eigvals)
248+
249+
250+
def test_is_close_arb_mat() -> None:
251+
x = arb_mat([[1, 2], [3, 4]])
252+
assert is_close(x, [[1, 2], [3, 4]]) is True
253+
assert is_close(x, arb_mat([[1, 2], [3, 4]])) is True
254+
assert is_close(x, [[1, 2, 3]]) is False
255+
assert is_close(x, [[1, 2], [3, 5]]) is False
256+
assert is_close(object(), [[1]]) is False # type: ignore[arg-type]

0 commit comments

Comments
 (0)