-
Notifications
You must be signed in to change notification settings - Fork 259
Expand file tree
/
Copy path: test_benchmark_peak_detection.py
More file actions
72 lines (55 loc) · 2.38 KB
/
test_benchmark_peak_detection.py
File metadata and controls
72 lines (55 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pytest
import shutil
from pathlib import Path
from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
@pytest.mark.skip()
def test_benchmark_peak_detection(create_cache_folder):
    """Benchmark peak-detection methods against ground-truth spikes on a toy dataset.

    Builds a ``PeakDetectionStudy`` comparing the ``locally_exclusive`` and
    ``matched_filtering`` detection methods, runs it, reloads it from disk to
    check persistence, and exercises the study's plotting API.

    Parameters
    ----------
    create_cache_folder : pathlib.Path
        Cache folder (pytest fixture) under which the study folder is created.
    """
    cache_folder = create_cache_folder
    job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")

    # The analyzer returned by make_dataset is not needed here; the leading
    # underscore marks it as deliberately unused.
    recording, gt_sorting, _gt_analyzer = make_dataset(job_kwargs)

    # create study
    study_folder = cache_folder / "study_peak_detection"
    datasets = {"toy": (recording, gt_sorting)}

    cases = {}
    peaks = {}
    # Iterate items() directly instead of keys() + indexing (single lookup).
    for dataset, (recording, gt_sorting) in datasets.items():
        # Templates are computed only to locate each unit's extremum channel,
        # which lets a channel index be attached to every ground-truth spike.
        sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
        sorting_analyzer.compute("random_spikes")
        sorting_analyzer.compute("templates", **job_kwargs)
        extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
        spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
        peaks[dataset] = spikes

    for method in ["locally_exclusive", "matched_filtering"]:
        cases[method] = {
            "label": f"{method} on toy",
            "dataset": "toy",
            "init_kwargs": {"gt_peaks": peaks["toy"]},
            "params": {"method": method, "method_kwargs": {}},
        }

    # Start from a clean study folder so repeated runs do not collide.
    if study_folder.exists():
        shutil.rmtree(study_folder)
    study = PeakDetectionStudy.create(study_folder, datasets=datasets, cases=cases)
    print(study)

    # this study needs analyzer
    study.compute_metrics(**job_kwargs)

    # run and result
    study.run(**job_kwargs)
    study.compute_results(**job_kwargs)

    # load study to check persistency
    study = PeakDetectionStudy(study_folder)
    study.plot_agreements_by_channels()
    study.plot_agreements_by_units()
    study.plot_deltas_per_cells()
    study.plot_detected_amplitudes()
    study.plot_performances_vs_snr()
    study.plot_template_similarities()
    study.plot_run_times()

    import matplotlib.pyplot as plt

    plt.show()
if __name__ == "__main__":
    # Mirror the pytest fixture by pointing at the repo-level benchmarks
    # cache folder (four directories above this file).
    benchmarks_cache = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks"
    test_benchmark_peak_detection(benchmarks_cache)