Skip to content

Commit 2314465

Browse files
committed
fig2 notebook cleanup
1 parent 2c8e65a commit 2314465

9 files changed

Lines changed: 945 additions & 862 deletions

File tree

dca/analysis.py

Lines changed: 150 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import scipy
12
import numpy as np
23
from sklearn.linear_model import LinearRegression as LR
34
from sklearn.decomposition import PCA
45
from scipy.stats import special_ortho_group as sog
56

6-
from .data_util import CrossValidate, form_lag_matrix
7+
from .base import init_coef
8+
from .cov_util import calc_pi_from_cross_cov_mats, form_lag_matrix
9+
from .data_util import CrossValidate
710
from .methods_comparison import SlowFeatureAnalysis as SFA
8-
from .dca import DynamicalComponentsAnalysis
11+
from .dca import DynamicalComponentsAnalysis as DCA, DynamicalComponentsAnalysisFFT as DCAFFT
12+
913

1014
import warnings
1115
warnings.simplefilter(action='ignore', category=FutureWarning)
@@ -113,7 +117,7 @@ def run_analysis(X, Y, T_pi_vals, dim_vals, offset_vals, num_cv_folds, decoding_
113117
Y_test_ctd = Y_test - Y_mean
114118

115119
# compute cross-cov mats for DCA
116-
dca_model = DynamicalComponentsAnalysis(T=np.max(T_pi_vals))
120+
dca_model = DCA(T=np.max(T_pi_vals))
117121
dca_model.estimate_data_statistics(X_train_ctd)
118122

119123
# do PCA/SFA
@@ -220,7 +224,7 @@ def run_dim_analysis_dca(X, Y, T_pi, dim_vals, offset, num_cv_folds, decoding_wi
220224

221225
# make DCA object
222226
# compute cross-cov mats for DCA
223-
dca_model = DynamicalComponentsAnalysis(T=T_pi)
227+
dca_model = DCA(T=T_pi)
224228
dca_model.estimate_data_statistics(X_train_ctd)
225229

226230
# loop over dimensionalities
@@ -250,3 +254,145 @@ def run_dim_analysis_dca(X, Y, T_pi, dim_vals, offset, num_cv_folds, decoding_wi
250254
null_results[fold_idx, dim_idx, ii] = r2_dca
251255

252256
return results, null_results
257+
258+
259+
def gen_pi_heatmap(calc_pi_fn, N=100):
260+
theta_vals = np.linspace(0, np.pi, N)
261+
phi_vals = np.linspace(0, np.pi, N)
262+
heatmap = np.zeros((N, N))
263+
for theta_idx in range(N):
264+
if theta_idx % 10 == 0:
265+
print("theta_idx =", theta_idx)
266+
for phi_idx in range(N):
267+
theta, phi = theta_vals[theta_idx], phi_vals[phi_idx]
268+
x = np.cos(phi) * np.sin(theta)
269+
y = np.sin(phi) * np.sin(theta)
270+
z = np.cos(theta)
271+
V = np.array([x, y, z]).reshape((3, 1))
272+
heatmap[theta_idx, phi_idx] = calc_pi_fn(V)
273+
return heatmap
274+
275+
276+
def make_pi_fn_gp(cross_cov_mats):
277+
def calc_pi_fn_gp(V):
278+
pi = calc_pi_from_cross_cov_mats(cross_cov_mats, proj=V)
279+
return pi
280+
return calc_pi_fn_gp
281+
282+
283+
def make_pi_fn_knn(X, T_pi, n_jobs=-1):
284+
from info_measures.continuous import kraskov_stoegbauer_grassberger as ksg
285+
286+
def calc_pi_fn_knn(V):
287+
X_proj = np.dot(X, V)
288+
X_proj_lags = form_lag_matrix(X_proj, 2 * T_pi)
289+
mi = ksg.MutualInformation(X_proj_lags[:, :T_pi], X_proj_lags[:, T_pi:], add_noise=True)
290+
pi = mi.mutual_information(n_jobs=n_jobs)
291+
return pi
292+
return calc_pi_fn_knn
293+
294+
295+
def random_proj_pi_comparison(calc_pi_fn_1, cal_pi_fn_2, N, d=1,
296+
n_samples=10000, seed=20210412):
297+
rng = np.random.RandomState(seed)
298+
pi_1, pi_2 = np.zeros(n_samples), np.zeros(n_samples)
299+
for i in range(n_samples):
300+
if i % 100 == 0:
301+
print("sample {} of {}".format(i, n_samples))
302+
V = init_coef(N, d, rng=rng, init='random_ortho')
303+
pi_1[i] = calc_pi_fn_1(V)
304+
pi_2[i] = cal_pi_fn_2(V)
305+
pi_12 = np.vstack((pi_1, pi_2)).T # (n_samples, 2)
306+
return pi_12
307+
308+
309+
def gp_knn_trajectories(num_traj, cross_cov_mats, X, T_pi, d):
310+
f_gp = make_pi_fn_gp(cross_cov_mats)
311+
f_knn = make_pi_fn_knn(X, T_pi=T_pi)
312+
trajectories = []
313+
for traj_idx in range(num_traj):
314+
print("traj_idx =", traj_idx)
315+
opt = DCA(d=d, T=T_pi)
316+
opt.cross_covs = cross_cov_mats
317+
opt.fit_projection(record_V=True)
318+
V_seq = opt.V_seq
319+
num_dca_iter = len(V_seq)
320+
pi_gp_knn_traj = np.zeros((num_dca_iter, 2))
321+
for i in range(num_dca_iter):
322+
if i % 50 == 0:
323+
print("{} of {}".format(i, num_dca_iter))
324+
V = V_seq[i]
325+
pi_gp_knn_traj[i, 0] = f_gp(V)
326+
pi_gp_knn_traj[i, 1] = f_knn(V)
327+
trajectories.append(pi_gp_knn_traj)
328+
return trajectories
329+
330+
331+
def dca_deflation(cross_cov_mats, n_proj, n_init=1):
332+
N = cross_cov_mats.shape[1]
333+
T = cross_cov_mats.shape[0] // 2
334+
F = np.eye(N)
335+
cov_proj = np.copy(cross_cov_mats)
336+
basis = np.zeros((N, n_proj))
337+
opt = DCA(T=T)
338+
for i in range(n_proj):
339+
if i % 10 == 0:
340+
print(i)
341+
# run DCA
342+
opt.cross_covs = cov_proj
343+
opt.fit_projection(d=1, n_init=n_init)
344+
v = opt.coef_.flatten()
345+
# get full-dim v
346+
v_full = np.dot(F, v)
347+
basis[:, i] = v_full
348+
# update U, F, cov_proj
349+
U = scipy.linalg.orth(np.eye(N - i) - np.outer(v, v))
350+
F = np.dot(F, U)
351+
cov_proj = np.array([U.T.dot(C).dot(U) for C in cov_proj])
352+
return basis
353+
354+
355+
def dca_fft_deflation(X, T, n_proj, n_init=1):
356+
N = X.shape[1]
357+
F = np.eye(N)
358+
X_proj = np.copy(X)
359+
basis = np.zeros((N, n_proj))
360+
opt = DCAFFT(T=T, d=1)
361+
for i in range(n_proj):
362+
if i % 10 == 0:
363+
print(i)
364+
# run DCA
365+
opt.fit(X_proj, n_init=n_init)
366+
v = opt.coef_.flatten()
367+
# get full-dim v
368+
v_full = np.dot(F, v)
369+
basis[:, i] = v_full
370+
# update U, F, X
371+
U = scipy.linalg.orth(np.eye(N - i) - np.outer(v, v))
372+
F = np.dot(F, U)
373+
X_proj = np.dot(X_proj, U)
374+
return basis
375+
376+
377+
def dca_full(cross_cov_mats, n_proj, n_init=1):
378+
T = cross_cov_mats.shape[0] // 2
379+
opt = DCA(T=T)
380+
opt.cross_covs = cross_cov_mats
381+
V_seq = []
382+
for i in range(n_proj):
383+
if i % 10 == 0:
384+
print(i)
385+
opt.fit_projection(d=i + 1, n_init=n_init)
386+
V = opt.coef_
387+
V_seq.append(V)
388+
return V_seq
389+
390+
391+
def calc_pi_vs_dim(cross_cov_mats, V=None, V_seq=None):
392+
if V_seq is None:
393+
V_seq = [V[:, :i + 1] for i in range(V.shape[1])]
394+
pi_vals = np.zeros(len(V_seq))
395+
for i in range(len(V_seq)):
396+
V = V_seq[i]
397+
pi_vals[i] = calc_pi_from_cross_cov_mats(cross_cov_mats, proj=V)
398+
return pi_vals

dca/dca.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from scipy.signal.windows import hann
66

77
import torch
8+
import torch.fft
89
import torch.nn.functional as F
910

1011
from .base import SingleProjectionComponentsAnalysis, ortho_reg_fn, init_coef, ObjectiveWrapper
@@ -292,7 +293,7 @@ def make_cepts2(X, T_pi):
292293

293294
# Compute the power spectral density
294295
window = torch.Tensor(hann(Y.shape[-1])[np.newaxis, np.newaxis]).type(Y.dtype)
295-
Yf = torch.fft.rfft(Y * window, dim=1)
296+
Yf = torch.fft.rfft(Y * window, dim=-1)
296297
spect = abs(Yf)**2
297298
spect = spect.mean(dim=1)
298299
spect = torch.cat([torch.flip(spect[:, 1:], dims=(1,)), spect], dim=1)

dca/plotting/__init__.py

Whitespace-only changes.
File renamed without changes.

dca/plotting/fig2.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
4+
5+
from .. import style
6+
7+
8+
def make_axes(fig_width, wpad_edge=0, wpad_mid=0.05, hpad_top=0.05, hpad_bottom=0.05,
9+
small_sq_width=0.07):
10+
sq_width = (1. - 2 * wpad_edge - small_sq_width - 4 * wpad_mid) / 4.
11+
sq_height = 1. - hpad_top - hpad_bottom
12+
fig_height = sq_width * fig_width / sq_height
13+
small_sq_height = small_sq_width * fig_width / fig_height
14+
fig = plt.figure(figsize=(fig_width, fig_height))
15+
# 2 small squares
16+
ax2 = fig.add_axes((wpad_edge, hpad_bottom, small_sq_width, small_sq_height))
17+
ax1 = fig.add_axes((wpad_edge, 1. - hpad_top - small_sq_height,
18+
small_sq_width, small_sq_height))
19+
# 3 big squares
20+
ax3 = fig.add_axes((wpad_edge + small_sq_width + wpad_mid, hpad_bottom, sq_width, sq_height))
21+
ax4 = fig.add_axes((wpad_edge + small_sq_width + 2 * wpad_mid + sq_width, hpad_bottom,
22+
sq_width, sq_height))
23+
ax5 = fig.add_axes((wpad_edge + small_sq_width + 3 * wpad_mid + 2 * sq_width, hpad_bottom,
24+
sq_width, sq_height))
25+
ax6 = fig.add_axes((wpad_edge + small_sq_width + 4 * wpad_mid + 3 * sq_width, hpad_bottom,
26+
sq_width, sq_height))
27+
axes = [ax1, ax2, ax3, ax4, ax5, ax6]
28+
for ax in axes:
29+
ax.set_xticks([])
30+
ax.set_yticks([])
31+
32+
label_dx = -0.02
33+
label_dy = 0.05
34+
label_y = hpad_bottom + sq_height + label_dy
35+
fig.text(wpad_edge + label_dx, label_y,
36+
"A", va="bottom", ha="right", color="black",
37+
**style.panel_letter_fontstyle)
38+
fig.text(wpad_edge + small_sq_width + wpad_mid, label_y,
39+
"B", va="bottom", ha="center", color="black",
40+
**style.panel_letter_fontstyle)
41+
fig.text(wpad_edge + small_sq_width + 2 * wpad_mid + sq_width, label_y,
42+
"C", va="bottom", ha="center", color="black",
43+
**style.panel_letter_fontstyle)
44+
fig.text(wpad_edge + small_sq_width + 3 * wpad_mid + 2 * sq_width, label_y,
45+
"D", va="bottom", ha="center", color="black",
46+
**style.panel_letter_fontstyle)
47+
fig.text(wpad_edge + small_sq_width + 4 * wpad_mid + 3 * sq_width, label_y,
48+
"E", va="bottom", ha="center", color="black",
49+
**style.panel_letter_fontstyle)
50+
return axes
51+
52+
53+
def disp_heatmap(ax, heatmap, show_xlabels=True, show_ylabels=True, title=None):
54+
N_theta, N_phi = heatmap.shape
55+
ax.imshow(heatmap, origin="lower left", cmap="gray", aspect="equal")
56+
if show_xlabels:
57+
ax.set_xlabel("$\phi$", fontsize=style.axis_label_fontsize, labelpad=-8.5)
58+
ax.set_xticks([0, N_phi - 1])
59+
ax.set_xticklabels(["0", "$\pi$"], fontsize=style.ticklabel_fontsize)
60+
else:
61+
ax.set_xticks([])
62+
if show_ylabels:
63+
ax.set_ylabel("$\\theta$", fontsize=style.axis_label_fontsize, labelpad=-8.5)
64+
ax.set_yticks([0, N_theta - 1])
65+
ax.set_yticklabels(["0", "$\pi$"], fontsize=style.ticklabel_fontsize)
66+
else:
67+
ax.set_yticks([])
68+
ax.set_xlim([0, heatmap.shape[1] - 1])
69+
ax.set_ylim([0, heatmap.shape[0] - 1])
70+
if title is not None:
71+
ax.set_title(title, fontsize=style.axis_label_fontsize * 0.8, pad=1)
72+
73+
74+
def disp_scatter(ax, pi_gp, pi_knn, trajectories=None, diag_text=False,
75+
arrow=True, xlabel="full PI", ylabel="Gaussian PI"):
76+
# Note that gp=y and knn=x, but 0 index is gp and 1 is knn in data arrays!
77+
traj_color = "#C63F3A"
78+
all_gp_vals = [pi_gp]
79+
all_knn_vals = [pi_knn]
80+
if trajectories is not None:
81+
all_gp_vals += [traj[:, 0] for traj in trajectories]
82+
all_knn_vals += [traj[:, 1] for traj in trajectories]
83+
all_gp_vals = np.concatenate(all_gp_vals)
84+
all_knn_vals = np.concatenate(all_knn_vals)
85+
min_gp, max_gp = all_gp_vals.min(), all_gp_vals.max()
86+
range_gp = max_gp - min_gp
87+
min_knn, max_knn = all_knn_vals.min(), all_knn_vals.max()
88+
range_knn = max_knn - min_knn
89+
pi_gp_norm = (pi_gp - min_gp) / range_gp
90+
pi_knn_norm = (pi_knn - min_knn) / range_knn
91+
ax.hexbin(pi_knn_norm, pi_gp_norm, gridsize=50, extent=(0, 1, 0, 1),
92+
cmap="gray_r", bins="log", linewidth=0.05)
93+
if trajectories is not None:
94+
for traj_idx in range(len(trajectories)):
95+
traj = np.copy(trajectories[traj_idx])
96+
traj[:, 0] = (traj[:, 0] - min_gp) / range_gp
97+
traj[:, 1] = (traj[:, 1] - min_knn) / range_knn
98+
ax.plot(traj[:, 1], traj[:, 0], linewidth=0.5, color=traj_color)
99+
ax.set_xlim([0, 1.025])
100+
ax.set_ylim([0, 1.025])
101+
ax.set_xticks([0, 1])
102+
ax.set_xticklabels([0, 1], fontsize=style.ticklabel_fontsize)
103+
ax.set_yticks([0, 1])
104+
ax.set_yticklabels([0, 1], fontsize=style.ticklabel_fontsize)
105+
ax.spines['left'].set_bounds(0, 1)
106+
ax.spines['bottom'].set_bounds(0, 1)
107+
ax.spines['right'].set_visible(False)
108+
ax.spines['top'].set_visible(False)
109+
ax.set_xlabel(xlabel, fontsize=style.axis_label_fontsize, labelpad=-9.5)
110+
ax.set_ylabel(ylabel, fontsize=style.axis_label_fontsize, labelpad=-8)
111+
112+
theta_deg = 47
113+
if diag_text:
114+
ax.text(0.5, 0.65, "DCA trajectories", fontsize=style.ticklabel_fontsize * 0.8,
115+
rotation=theta_deg, rotation_mode="anchor", ha="center", va="center",
116+
color=traj_color)
117+
if arrow:
118+
len_x = np.cos(np.deg2rad(theta_deg))
119+
len_y = np.sin(np.deg2rad(theta_deg))
120+
mag = 0.425
121+
ax.quiver(0.475, 0.40, mag * len_x, mag * len_y,
122+
angles='xy', scale_units='xy', scale=1, width=0.015,
123+
color=traj_color)
124+
125+
126+
def plot_deflation_results(ax, pi_regular, pi_def, pi_fft):
127+
dim_vals = np.arange(len(pi_def) + 1)
128+
pi_vals = (pi_regular, pi_def, pi_fft)
129+
labels = ["DCA", "deflation", "FFT deflation"]
130+
markersize = 1.5
131+
colors = ["#C63F3A", "gray", "black"]
132+
for i in range(len(pi_vals)):
133+
if i < 2:
134+
ax.plot(dim_vals, [0] + list(pi_vals[i]), label=labels[i],
135+
linewidth=0.85, color=colors[i], linestyle="-")
136+
else:
137+
ax.plot(dim_vals, [0] + list(pi_vals[i]), label=labels[i],
138+
linewidth=0, marker=".", markersize=markersize,
139+
color=colors[i])
140+
141+
ax.legend(fontsize=style.ticklabel_fontsize * 0.8, frameon=False,
142+
labelspacing=0.1, bbox_to_anchor=(0.2, 0, 1, 1))
143+
ax.set_xlabel("dimension", fontsize=style.axis_label_fontsize, labelpad=-9.5)
144+
ax.set_ylabel("PI (nats)", fontsize=style.axis_label_fontsize, labelpad=-13)
145+
ax.spines["right"].set_visible(False)
146+
ax.spines["top"].set_visible(False)
147+
148+
max_dim = len(pi_regular)
149+
max_dim_padded = max_dim * 1.025
150+
ax.set_xticks([0, max_dim])
151+
ax.set_xticklabels([0, max_dim], fontsize=style.ticklabel_fontsize)
152+
ax.set_xlim([0, max_dim_padded])
153+
ax.spines["bottom"].set_bounds(0, max_dim)
154+
155+
max_pi = np.max(np.concatenate(pi_vals))
156+
max_pi_padded = max_pi * 1.025
157+
ax.set_yticks([0, max_pi])
158+
ax.set_yticklabels([0, np.round(max_pi, 1)], fontsize=style.ticklabel_fontsize)
159+
ax.set_ylim([0, max_pi_padded])
160+
ax.spines["left"].set_bounds(0, max_pi)

notebooks/Fig1.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"from dca import DynamicalComponentsAnalysis as DCA, style\n",
1414
"from dca.cov_util import calc_cross_cov_mats_from_data\n",
1515
"from dca.synth_data import embedded_lorenz_cross_cov_mats, gen_lorenz_data, random_basis, median_subspace\n",
16-
"from dca.plotting import lorenz_fig_axes, plot_3d, plot_lorenz_3d, plot_traces, plot_dca_demo, plot_r2, plot_cov\n",
16+
"from dca.plotting.fig1 import lorenz_fig_axes, plot_3d, plot_lorenz_3d, plot_traces, plot_dca_demo, plot_r2, plot_cov\n",
1717
"\n",
1818
"RESULTS_FILENAME = \"lorenz_results.hdf5\""
1919
]
@@ -283,5 +283,5 @@
283283
}
284284
},
285285
"nbformat": 4,
286-
"nbformat_minor": 2
286+
"nbformat_minor": 4
287287
}

0 commit comments

Comments
 (0)