|
8 | 8 | from . import utils |
9 | 9 |
|
10 | 10 |
|
11 | | -def _backtrack( |
12 | | - f_t, f_grad, x_t, d_t, g_t, L_t, |
13 | | - gamma_max=1, ratio_increase=2., ratio_decrease=0.999, |
14 | | - max_iter=100): |
15 | | - # could be included inside minimize_FW |
16 | | - d2_t = splinalg.norm(d_t) ** 2 |
17 | | - for i in range(max_iter): |
18 | | - step_size = min(g_t / (d2_t * L_t), gamma_max) |
19 | | - rhs = f_t - step_size * g_t + 0.5 * (step_size**2) * L_t * d2_t |
20 | | - f_next, grad_next = f_grad(x_t + step_size * d_t) |
21 | | - if f_next <= rhs: |
22 | | - if i == 0: |
23 | | - L_t *= ratio_decrease |
24 | | - break |
25 | | - else: |
26 | | - L_t *= ratio_increase |
27 | | - return step_size, L_t, f_next, grad_next |
28 | | - |
29 | | - |
30 | 11 | def minimize_FW( |
31 | 12 | f_grad, lmo, x0, L=None, max_iter=1000, tol=1e-12, |
32 | 13 | backtracking=True, callback=None, verbose=0): |
@@ -55,11 +36,21 @@ def minimize_FW( |
55 | 36 | g_t = g_t[0] |
56 | 37 | if g_t <= tol: |
57 | 38 | break |
| 39 | + d2_t = splinalg.norm(d_t) ** 2 |
58 | 40 | if backtracking: |
59 | | - step_size, L_t, f_next, grad_next = _backtrack( |
60 | | - f_t, f_grad, x_t, d_t, g_t, L_t) |
| 41 | + ratio_decrease = 0.999 |
| 42 | + ratio_increase = 2 |
| 43 | + for i in range(max_iter): |
| 44 | + step_size = min(g_t / (d2_t * L_t), 1) |
| 45 | + rhs = f_t - step_size * g_t + 0.5 * (step_size**2) * L_t * d2_t |
| 46 | + f_next, grad_next = f_grad(x_t + step_size * d_t) |
| 47 | + if f_next <= rhs + 1e-6: |
| 48 | + if i == 0: |
| 49 | + L_t *= ratio_decrease |
| 50 | + break |
| 51 | + else: |
| 52 | + L_t *= ratio_increase |
61 | 53 | else: |
62 | | - d2_t = splinalg.norm(d_t) ** 2 |
63 | 54 | step_size = min(g_t / (d2_t * L_t), 1) |
64 | 55 | f_next, grad_next = f_grad(x_t + step_size * d_t) |
65 | 56 | x_t += step_size * d_t |
|
0 commit comments