Skip to content

Commit cc8c736

Browse files
committed
first pass on github actions
1 parent 05a9dac commit cc8c736

5 files changed

Lines changed: 48 additions & 33 deletions

File tree

.github/workflows/dca_tests.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: DCA tests
2+
on:
3+
push:
4+
branches:
5+
- main
6+
pull_request:
7+
branches:
8+
- main
9+
jobs:
10+
run-tests:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.8", "3.8", "3.9"]
15+
steps:
16+
- name: Test DCA
17+
- uses actions/checkout@v3
18+
- name: Set up Python ${{ matrix.python-version }}
19+
uses: actions/setup-python@v3
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
- name: Install dependencies
23+
run: |
24+
python -m pip install --upgrade pip
25+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
26+
python -m pip install flake8 codecov pytest-cov sphinx_rtd_theme
27+
- name: Lint with flake8
28+
run: |
29+
flake8 dca tests
30+
- name: Test with pytest
31+
run: |
32+
pytest -sv --cov=./ tests
33+
- name: Build docs
34+
run: |
35+
sphinx-build -W -b html docs/source docs/build
36+
- name: Codecov
37+
run: |
38+
codecov

.travis.yml

Lines changed: 0 additions & 26 deletions
This file was deleted.

dca/data_util.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def load_weather_data(filename):
7979
"""
8080

8181

82-
def load_sabes_data(filename, bin_width_s=.05, preprocess=True):
82+
def load_sabes_data(filename, bin_width_s=.05, high_pass=True, sqrt=True, thresh=5000,
83+
zscore_pos=True):
8384
# Load MATLAB file
8485
with h5py.File(filename, "r") as f:
8586
# Get channel names (e.g. M1 001 or S1 001)
@@ -119,10 +120,10 @@ def load_sabes_data(filename, bin_width_s=.05, preprocess=True):
119120
unique_idxs, counts = np.unique(bin_idx, return_counts=True)
120121
# make sure to ignore the hash here...
121122
binned_spikes[unique_idxs, chan_idx * n_sorted_units + unit_idx - 1] += counts
122-
binned_spikes = binned_spikes[:, binned_spikes.sum(axis=0) > 0]
123-
if preprocess:
124-
binned_spikes = binned_spikes[:, binned_spikes.sum(axis=0) > 5000]
123+
binned_spikes = binned_spikes[:, binned_spikes.sum(axis=0) > thresh]
124+
if sqrt:
125125
binned_spikes = np.sqrt(binned_spikes)
126+
if high_pass:
126127
binned_spikes = moving_center(binned_spikes, n=600)
127128
result[region] = binned_spikes
128129
# Get cursor position
@@ -131,7 +132,7 @@ def load_sabes_data(filename, bin_width_s=.05, preprocess=True):
131132
t_mid_bin = np.arange(len(binned_spikes)) * bin_width_s + bin_width_s / 2
132133
cursor_pos_interp = interp1d(t - t[0], cursor_pos, axis=0)
133134
cursor_interp = cursor_pos_interp(t_mid_bin)
134-
if preprocess:
135+
if zscore_pos:
135136
cursor_interp -= cursor_interp.mean(axis=0, keepdims=True)
136137
cursor_interp /= cursor_interp.std(axis=0, keepdims=True)
137138
result["cursor"] = cursor_interp

dca/dca.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ def score(self, X=None):
284284
If X is given, calcuate the PI of X for the learned projections.
285285
"""
286286
T = self.T_fit
287+
if T is None:
288+
T = self.T
287289
if X is None:
288290
cross_covs = self.cross_covs.cpu().numpy()
289291
else:

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ h5py
44
pandas
55
scikit-learn
66
matplotlib
7-
-f https://download.pytorch.org/whl/torch_stable.html
8-
torch>=1.7.0+cpu
7+
--extra-index-url https://download.pytorch.org/whl/cpu
8+
torch

0 commit comments

Comments
 (0)