forked from arrayfire/arrayfire-py
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_linear_algebra.py
More file actions
123 lines (81 loc) · 3.92 KB
/
test_linear_algebra.py
File metadata and controls
123 lines (81 loc) · 3.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pytest
import arrayfire as af
from tests._helpers import create_from_2d_nested
# Test dot
@pytest.fixture
def real_vector_1() -> af.Array:
    """Supply the real-valued vector (1, 2, 3), the first operand of the dot tests."""
    components = [1.0, 2.0, 3.0]
    return af.Array(components)
@pytest.fixture
def real_vector_2() -> af.Array:
    """Supply the real-valued vector (4, 5, 6), the second operand of the dot tests."""
    components = [4.0, 5.0, 6.0]
    return af.Array(components)
@pytest.fixture
def float_vector_1() -> af.Array:
    """Supply the non-integer vector (1.5, 2.5, 3.5) for the float dot test."""
    components = [1.5, 2.5, 3.5]
    return af.Array(components)
@pytest.fixture
def float_vector_2() -> af.Array:
    """Supply the non-integer vector (4.5, 5.5, 6.5) for the float dot test."""
    components = [4.5, 5.5, 6.5]
    return af.Array(components)
def test_dot_real_vectors(real_vector_1: af.Array, real_vector_2: af.Array) -> None:
    """Check that ``af.dot`` of (1, 2, 3) and (4, 5, 6) equals 1*4 + 2*5 + 3*6 = 32."""
    expected = 32
    # Without ``return_scalar=True`` the result is presumably an ``af.Array``,
    # so ``==`` produces an element-wise boolean array. Reduce it explicitly
    # instead of relying on the truthiness of an array.
    result = af.dot(real_vector_1, real_vector_2)
    assert af.all_true(result == expected), f"Expected {expected}, got {result}"
def test_dot_float_vectors(float_vector_1: af.Array, float_vector_2: af.Array) -> None:
    """Check ``af.dot`` on non-integer vectors.

    1.5*4.5 + 2.5*5.5 + 3.5*6.5 = 6.75 + 13.75 + 22.75 = 43.25
    (the previous expectation of 61.5 was arithmetically wrong).
    """
    expected = 43.25
    # Request a plain Python number so a tolerance-based float comparison
    # can be applied directly.
    result = af.dot(float_vector_1, float_vector_2, return_scalar=True)
    assert result == pytest.approx(expected), f"Expected {expected}, got {result}"
def test_dot_return_scalar(real_vector_1: af.Array, real_vector_2: af.Array) -> None:
    """With ``return_scalar=True``, ``af.dot`` must hand back a plain Python number."""
    python_number_types = (int, float)
    scalar = af.dot(real_vector_1, real_vector_2, return_scalar=True)
    assert isinstance(scalar, python_number_types), "Result is not a scalar"
# Test gemm
@pytest.fixture
def matrix_a() -> af.Array:
    """Left gemm operand built by ``create_from_2d_nested`` from 1, 2, 3, 4.

    Presumably the 2x2 matrix [[1, 2], [3, 4]]; confirm against ``tests._helpers``.
    """
    entries = (1, 2, 3, 4)
    return create_from_2d_nested(*entries)
@pytest.fixture
def matrix_b() -> af.Array:
    """Right gemm operand built by ``create_from_2d_nested`` from 5, 6, 7, 8.

    Presumably the 2x2 matrix [[5, 6], [7, 8]]; confirm against ``tests._helpers``.
    """
    entries = (5, 6, 7, 8)
    return create_from_2d_nested(*entries)
def test_gemm_basic(matrix_a: af.Array, matrix_b: af.Array) -> None:
    """Check plain ``af.gemm`` (no scaling, no transposes) against a hand-computed product.

    Reading the helper's arguments row-major:
    [[1, 2], [3, 4]] @ [[5, 6], [7, 8]] = [[19, 22], [43, 50]].
    """
    result = af.gemm(matrix_a, matrix_b)
    expected = create_from_2d_nested(19.0, 22.0, 43.0, 50.0)
    # ``==`` between arrays is element-wise; reduce it explicitly with
    # ``all_true`` rather than relying on the truthiness of a boolean array.
    assert af.all_true(result == expected), f"Expected {expected}, got {result}"
def test_gemm_alpha_beta(matrix_a: af.Array, matrix_b: af.Array) -> None:
    """Check scaled ``af.gemm``: ``alpha * (A @ B) + beta * accum`` with ``accum = A``.

    0.5 * [[19, 22], [43, 50]] + 2.0 * [[1, 2], [3, 4]] = [[11.5, 15.0], [27.5, 33.0]]
    (matrices read row-major, as the expected values assume).
    """
    alpha = 0.5
    beta = 2.0
    result = af.gemm(matrix_a, matrix_b, alpha=alpha, beta=beta, accum=matrix_a)
    expected = create_from_2d_nested(11.5, 15.0, 27.5, 33.0)
    # Every expected value is exactly representable in binary floating point,
    # so exact element-wise equality is sound; reduce it explicitly.
    assert af.all_true(result == expected), f"Expected {expected}, got {result}"
def test_gemm_transpose_options(matrix_a: af.Array, matrix_b: af.Array) -> None:
    """Check ``af.gemm`` with both operands transposed.

    A^T @ B^T = (B @ A)^T = [[23, 31], [34, 46]] (matrices read row-major).
    """
    result = af.gemm(matrix_a, matrix_b, lhs_opts=af.MatProp.TRANS, rhs_opts=af.MatProp.TRANS)
    expected = create_from_2d_nested(23.0, 31.0, 34.0, 46.0)
    # Element-wise comparison reduced explicitly instead of relying on the
    # truthiness of a boolean array.
    assert af.all_true(result == expected), f"Expected {expected}, got {result}"
# Test matmul
def test_basic_matrix_multiplication() -> None:
    """A (3, 2) matrix times a (2, 4) matrix must yield a (3, 4) matrix."""
    lhs = af.randu((3, 2), dtype=af.float32)
    rhs = af.randu((2, 4), dtype=af.float32)
    product = af.matmul(lhs, rhs)
    assert product.shape == (3, 4), "Output dimensions should be 3x4."
def test_matrix_multiplication_with_lhs_transposed() -> None:
    """Transposing a (2, 3) lhs gives (3, 2); times a (2, 4) rhs that must yield (3, 4)."""
    lhs = af.randu((2, 3), dtype=af.float32)
    rhs = af.randu((2, 4), dtype=af.float32)
    product = af.matmul(lhs, rhs, lhs_opts=af.MatProp.TRANS)
    assert product.shape == (3, 4), "Output dimensions should be 3x4 when lhs is transposed."
def test_matrix_multiplication_with_both_transposed() -> None:
    """Transposed (4, 3) is (3, 4) and transposed (6, 4) is (4, 6), so the product is (3, 6)."""
    lhs = af.randu((4, 3), dtype=af.float32)
    rhs = af.randu((6, 4), dtype=af.float32)
    product = af.matmul(lhs, rhs, lhs_opts=af.MatProp.TRANS, rhs_opts=af.MatProp.TRANS)
    assert product.shape == (3, 6), "Output dimensions should be 3x6 with both matrices transposed."
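# BUG: the tests below are disabled (commented out) because of a known issue; the reason is not recorded here. TODO(review): document the failure and re-enable.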
# def test_incompatible_dimensions() -> None:
# A = af.randu((3, 5), dtype=af.float32)
# B = af.randu((4, 6), dtype=af.float32)
# with pytest.raises(ValueError):
# C = af.matmul(A, B)
# def test_unsupported_data_type() -> None:
# A = af.Array([1, 2, 3], dtype=af.uint32) # Assuming unsupported data type like unsigned int
# B = af.Array([4, 5, 6], dtype=af.uint32)
# with pytest.raises(TypeError):
# C = af.matmul(A, B)
# def test_multiplication_result_verification() -> None:
# A = create_from_2d_nested(1, 2, 3, 4)
# B = create_from_2d_nested(5, 6, 7, 8)
# C = af.matmul(A, B)
# expected = create_from_2d_nested(19, 22, 43, 50)
# assert af.all_true(C == expected), "The multiplication result is incorrect."