Skip to content

Commit bcf7593

Browse files
committed
Add first implementation of scaling functions
1 parent 1161d04 commit bcf7593

2 files changed

Lines changed: 419 additions & 0 deletions

File tree

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import numpy as np
2+
import pytest
3+
from foqus_lib.framework.surrogate.scaling import (
4+
scale_linear,
5+
unscale_linear,
6+
scale_log,
7+
unscale_log,
8+
scale_log2,
9+
unscale_log2,
10+
scale_power,
11+
unscale_power,
12+
scale_power2,
13+
unscale_power2,
14+
validate_for_scaling,
15+
)
16+
17+
from hypothesis.extra.numpy import arrays as arrays_strat, array_shapes
18+
from hypothesis import given, example, assume
19+
from contextlib import contextmanager
20+
21+
POSITIVE_VALS_ONLY = {scale_log}
22+
23+
24+
@contextmanager
25+
def does_not_raise():
26+
yield
27+
28+
29+
def test_scale_linear():
30+
# Test case 1: Basic scaling
31+
input_array = np.array([1, 2, 3, 4, 5])
32+
scaled_array = scale_linear(input_array)
33+
assert np.all(scaled_array >= 0)
34+
assert np.all(scaled_array <= 1)
35+
assert np.allclose(scaled_array, [0.0, 0.25, 0.5, 0.75, 1.0])
36+
37+
# Test case 2: Custom range scaling
38+
input_array = np.array([10, 20, 30, 40, 50])
39+
scaled_array = scale_linear(input_array, lo=10, hi=50)
40+
assert np.all(scaled_array >= 0)
41+
assert np.all(scaled_array <= 1)
42+
assert np.allclose(scaled_array, [0.0, 0.25, 0.5, 0.75, 1.0])
43+
44+
# Test case 3: Scaling with negative values
45+
input_array = np.array([-5, 0, 5])
46+
scaled_array = scale_linear(input_array)
47+
assert np.all(scaled_array >= 0)
48+
assert np.all(scaled_array <= 1)
49+
assert np.allclose(scaled_array, [0.0, 0.5, 1.0])
50+
51+
# Test case 4: Scaling with repeated values
52+
input_array = np.array([2, 2, 2, 2])
53+
scaled_array = scale_linear(input_array)
54+
assert np.all(scaled_array >= 0)
55+
assert np.all(scaled_array <= 1)
56+
assert np.allclose(scaled_array, [0.0, 0.0, 0.0, 0.0])
57+
58+
59+
def test_unscale_linear():
60+
# Test case 1: Basic unscaling
61+
input_array = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
62+
unscaled_array = unscale_linear(input_array, lo=1, hi=5)
63+
assert np.allclose(unscaled_array, [1, 2, 3, 4, 5])
64+
65+
# Test case 2: Custom range unscaling
66+
input_array = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
67+
unscaled_array = unscale_linear(input_array, lo=10, hi=50)
68+
assert np.allclose(unscaled_array, [10, 20, 30, 40, 50])
69+
70+
# Test case 3: Unscaling with negative values
71+
input_array = np.array([0.0, 0.5, 1.0])
72+
unscaled_array = unscale_linear(input_array, lo=-5, hi=5)
73+
assert np.allclose(unscaled_array, [-5, 0, 5])
74+
75+
# Test case 4: Unscaling with repeated values
76+
input_array = np.array([0.0, 0.0, 0.0, 0.0])
77+
unscaled_array = unscale_linear(input_array, lo=0, hi=5)
78+
assert np.allclose(unscaled_array, [0, 0, 0, 0])
79+
80+
81+
def test_scale_log():
82+
# Test case 1: Basic log scaling
83+
input_array = np.array([1, 2, 3, 4, 5])
84+
scaled_array = scale_log(input_array)
85+
assert np.all(scaled_array >= 0)
86+
assert np.all(scaled_array <= 1)
87+
assert np.allclose(scaled_array, [0.0, 0.43067656, 0.68260619, 0.86135312, 1.0])
88+
89+
# Test case 2: Custom range log scaling
90+
input_array = np.array([10, 20, 30, 40, 50])
91+
scaled_array = scale_log(input_array, lo=10, hi=50)
92+
assert np.all(scaled_array >= 0)
93+
assert np.all(scaled_array <= 1)
94+
assert np.allclose(scaled_array, [0.0, 0.43067656, 0.68260619, 0.86135312, 1.0])
95+
96+
97+
def test_scale_log2():
98+
# Test case 1: Basic log2 scaling
99+
input_array = np.array([1, 2, 3, 4, 5])
100+
scaled_array = scale_log2(input_array)
101+
assert np.all(scaled_array >= 0)
102+
assert np.all(scaled_array <= 1)
103+
assert np.allclose(scaled_array, [0.0, 0.51188336, 0.74036269, 0.8893017, 1.0])
104+
105+
# Test case 2: Custom range log2 scaling
106+
input_array = np.array([10, 20, 30, 40, 50])
107+
scaled_array = scale_log2(input_array, lo=10, hi=50)
108+
assert np.all(scaled_array >= 0)
109+
assert np.all(scaled_array <= 1)
110+
assert np.allclose(scaled_array, [0.0, 0.51188336, 0.74036269, 0.8893017, 1.0])
111+
112+
113+
def test_scale_power():
114+
# Test case 1: Basic power scaling
115+
input_array = np.array([1, 2, 3, 4, 5])
116+
scaled_array = scale_power(input_array)
117+
assert np.all(scaled_array >= 0)
118+
assert np.all(scaled_array <= 1)
119+
assert np.allclose(
120+
scaled_array,
121+
[0.00000000e00, 9.00090009e-04, 9.90099010e-03, 9.99099910e-02, 1.00000000e00],
122+
)
123+
124+
# Test case 2: Custom range power scaling
125+
input_array = np.array([1.0, 4.7, 4.8, 4.999, 5.0])
126+
scaled_array = scale_power(input_array)
127+
print(scaled_array)
128+
assert np.all(scaled_array >= 0)
129+
assert np.all(scaled_array <= 1)
130+
assert np.allclose(scaled_array, [0.0, 0.50113735, 0.63092044, 0.99769983, 1.0])
131+
132+
133+
def test_scale_power2():
134+
# Test case 1: Basic power scaling
135+
input_array = np.array([1, 2, 3, 4, 5])
136+
scaled_array = scale_power2(input_array)
137+
assert np.all(scaled_array >= 0)
138+
assert np.all(scaled_array <= 1)
139+
assert np.allclose(scaled_array, [0.0, 0.08647549, 0.24025307, 0.51371258, 1.0])
140+
141+
# Test case 2: Custom range power scaling
142+
input_array = np.array([1.0, 4.7, 4.8, 4.999, 5.0])
143+
scaled_array = scale_power2(input_array)
144+
assert np.all(scaled_array >= 0)
145+
assert np.all(scaled_array <= 1)
146+
assert np.allclose(scaled_array, [0.0, 0.82377238, 0.87916771, 0.99936058, 1.0])
147+
148+
149+
# @pytest.mark.xfail(reason="function formula is wrong", strict=True)
150+
def test_unscale_log():
151+
input_array = np.array([0.0, 0.43067656, 0.68260619, 0.86135312, 1.0])
152+
unscaled_array = unscale_log(input_array, lo=1, hi=5)
153+
assert np.allclose(unscaled_array, [1, 2, 3, 4, 5])
154+
155+
input_array = np.array([0.0, 0.43067656, 0.68260619, 0.86135312, 1.0])
156+
unscaled_array = unscale_log(input_array, lo=10, hi=50)
157+
assert np.allclose(unscaled_array, [10, 20, 30, 40, 50])
158+
159+
160+
def test_unscale_log2():
161+
input_array = np.array([0.0, 0.51188336, 0.74036269, 0.8893017, 1.0])
162+
unscaled_array = unscale_log2(input_array, lo=1, hi=5)
163+
assert np.allclose(unscaled_array, [1, 2, 3, 4, 5])
164+
165+
input_array = np.array([0.0, 0.51188336, 0.74036269, 0.8893017, 1.0])
166+
unscaled_array = unscale_log2(input_array, lo=10, hi=50)
167+
assert np.allclose(unscaled_array, [10, 20, 30, 40, 50])
168+
169+
170+
def test_unscale_power():
171+
input_array = np.array(
172+
[0.00000000e00, 9.00090009e-04, 9.90099010e-03, 9.99099910e-02, 1.00000000e00]
173+
)
174+
unscaled_array = unscale_power(input_array, lo=1, hi=5)
175+
assert np.allclose(unscaled_array, [1, 2, 3, 4, 5])
176+
177+
input_array = np.array([0.0, 0.50113735, 0.63092044, 0.99769983, 1.0])
178+
unscaled_array = unscale_power(input_array, lo=1.0, hi=5.0)
179+
assert np.allclose(unscaled_array, [1.0, 4.7, 4.8, 4.999, 5.0])
180+
181+
182+
def test_unscale_power2():
183+
input_array = np.array([0.0, 0.08647549, 0.24025307, 0.51371258, 1.0])
184+
unscaled_array = unscale_power2(input_array, lo=1, hi=5)
185+
assert np.allclose(unscaled_array, [1, 2, 3, 4, 5])
186+
187+
input_array = np.array([0.0, 0.82377238, 0.87916771, 0.99936058, 1.0])
188+
unscaled_array = unscale_power2(input_array, lo=1.0, hi=5.0)
189+
assert np.allclose(unscaled_array, [1.0, 4.7, 4.8, 4.999, 5.0])
190+
191+
192+
# fill in with more cases, parameters, functions
193+
@pytest.mark.parametrize("x", [np.array([1, 2, 3, 4, 5]), np.array([0, 7, 9, 10, 12])])
194+
# @given(x=arrays_strat(np.float32, array_shapes()))
195+
@pytest.mark.parametrize(
196+
"scale,unscale",
197+
[
198+
(scale_linear, unscale_linear),
199+
(scale_log, unscale_log),
200+
(scale_log2, unscale_log2),
201+
(scale_power, unscale_power),
202+
(scale_power2, unscale_power2),
203+
],
204+
)
205+
def test_roundtrip(x, scale, unscale):
206+
207+
lo = np.min(x)
208+
hi = np.max(x)
209+
if not passes_validation(x, lo, hi):
210+
expected_failure = pytest.raises(ValueError)
211+
elif lo <= 0 and scale in POSITIVE_VALS_ONLY:
212+
expected_failure = pytest.raises(ValueError, match="All values must be > 0.*")
213+
else:
214+
expected_failure = does_not_raise()
215+
with expected_failure:
216+
scaled = scale(x, lo=lo, hi=hi)
217+
unscaled = unscale(scaled, lo=lo, hi=hi)
218+
assert np.allclose(x, unscaled)
219+
220+
221+
def passes_validation(array_in, lo, hi):
222+
try:
223+
validate_for_scaling(array_in, lo, hi)
224+
except Exception:
225+
return False
226+
else:
227+
return True
228+
229+
230+
# Run the tests
231+
if __name__ == "__main__":
232+
pytest.main()

0 commit comments

Comments
 (0)