Skip to content

Commit 02e0e97

Browse files
Compute missing spikes on the right tail of amplitude distributions (#4353)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f732780 commit 02e0e97

2 files changed

Lines changed: 39 additions & 38 deletions

File tree

src/spikeinterface/metrics/quality/misc_metrics.py

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -873,21 +873,16 @@ def compute_amplitude_cutoffs(
873873

874874
all_fraction_missing = {}
875875

876-
if sorting_analyzer.has_extension("spike_amplitudes"):
877-
extension = sorting_analyzer.get_extension("spike_amplitudes")
878-
all_amplitudes = extension.get_data()
879-
invert_amplitudes = np.median(all_amplitudes) > 0
880-
elif sorting_analyzer.has_extension("amplitude_scalings"):
881-
# amplitude scalings are positive, we need to invert them
882-
invert_amplitudes = True
883-
extension = sorting_analyzer.get_extension("amplitude_scalings")
884-
876+
available_extension = (
877+
"spike_amplitudes" if sorting_analyzer.has_extension("spike_amplitudes") else "amplitude_scalings"
878+
)
879+
extension = sorting_analyzer.get_extension(available_extension)
885880
amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods)
886881

887882
for unit_id in unit_ids:
888883
amplitudes = amplitudes_by_units[unit_id]
889884

890-
if invert_amplitudes:
885+
if np.median(amplitudes) < 0: # amplitude_cutoff expects positive amplitudes
891886
amplitudes = -amplitudes
892887
all_fraction_missing[unit_id] = amplitude_cutoff(
893888
amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio
@@ -903,7 +898,7 @@ class AmplitudeCutoff(BaseMetric):
903898
metric_name = "amplitude_cutoff"
904899
metric_function = compute_amplitude_cutoffs
905900
metric_params = {
906-
"num_histogram_bins": 100,
901+
"num_histogram_bins": 200,
907902
"histogram_smoothing_value": 3,
908903
"amplitudes_bins_min_ratio": 5,
909904
}
@@ -1012,15 +1007,10 @@ def compute_noise_cutoffs(
10121007
noise_cutoff_dict = {}
10131008
noise_ratio_dict = {}
10141009

1015-
if sorting_analyzer.has_extension("spike_amplitudes"):
1016-
extension = sorting_analyzer.get_extension("spike_amplitudes")
1017-
all_amplitudes = extension.get_data()
1018-
invert_amplitudes = np.median(all_amplitudes) > 0
1019-
elif sorting_analyzer.has_extension("amplitude_scalings"):
1020-
# amplitude scalings are positive, we need to invert them
1021-
invert_amplitudes = True
1022-
extension = sorting_analyzer.get_extension("amplitude_scalings")
1023-
1010+
available_extension = (
1011+
"spike_amplitudes" if sorting_analyzer.has_extension("spike_amplitudes") else "amplitude_scalings"
1012+
)
1013+
extension = sorting_analyzer.get_extension(available_extension)
10241014
amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods)
10251015

10261016
for unit_id in unit_ids:
@@ -1029,7 +1019,7 @@ def compute_noise_cutoffs(
10291019
cutoff, ratio = np.nan, np.nan
10301020
continue
10311021

1032-
if invert_amplitudes:
1022+
if np.median(amplitudes) < 0: # _noise_cutoff expects positive amplitudes
10331023
amplitudes = -amplitudes
10341024

10351025
cutoff, ratio = _noise_cutoff(amplitudes, high_quantile=high_quantile, low_quantile=low_quantile, n_bins=n_bins)
@@ -1520,6 +1510,9 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val
15201510
"""
15211511
Calculate approximate fraction of spikes missing from a distribution of amplitudes.
15221512
1513+
Find the missing spikes from the left tail of the distribution. Assumes cutoff happens at spikes
1514+
with lower amplitudes.
1515+
15231516
See compute_amplitude_cutoffs for additional documentation
15241517
15251518
Parameters
@@ -1544,27 +1537,20 @@ def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_val
15441537
if len(amplitudes) / num_histogram_bins < amplitudes_bins_min_ratio:
15451538
return np.nan
15461539
else:
1547-
h, b = np.histogram(amplitudes, num_histogram_bins, density=True)
1548-
1549-
# TODO : use something better than scipy.ndimage.gaussian_filter1d
15501540
from scipy.ndimage import gaussian_filter1d
15511541

1552-
pdf = gaussian_filter1d(h, histogram_smoothing_value)
1553-
support = b[:-1]
1554-
bin_size = np.mean(np.diff(support))
1555-
peak_index = np.argmax(pdf)
1542+
# Approximate amplitude pdf with np.histogram
1543+
h = np.histogram(amplitudes, num_histogram_bins)[0]
1544+
pdf = gaussian_filter1d(h, histogram_smoothing_value, mode="nearest")
15561545

1557-
pdf_above = np.abs(pdf[peak_index:] - pdf[0])
1558-
1559-
if len(np.where(pdf_above == pdf_above.min())[0]) > 1:
1560-
warnings.warn(
1561-
"Amplitude PDF does not have a unique minimum! More spikes might be required for a correct "
1562-
"amplitude_cutoff computation!"
1563-
)
1546+
# Find number of missed spikes
1547+
cutoff_point = pdf[0] # >> pdf[-1] if spikes were cutoff (at lower amplitudes)
1548+
G = np.where(pdf >= cutoff_point)[0][-1] # last occurence where pdf was greater than cutoff
1549+
num_missed_spikes = np.sum(pdf[G + 1 :]) # theoretically missing spikes on the left side
15641550

1565-
G = np.argmin(pdf_above) + peak_index
1566-
fraction_missing = np.sum(pdf[G:]) * bin_size
1567-
fraction_missing = np.min([fraction_missing, 0.5])
1551+
# Compute fraction of missed spikes
1552+
fraction_missing = num_missed_spikes / (len(amplitudes) + num_missed_spikes)
1553+
fraction_missing = min(fraction_missing, 0.5)
15681554

15691555
return fraction_missing
15701556

src/spikeinterface/metrics/quality/tests/test_metrics_functions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
compute_sd_ratio,
3131
_noise_cutoff,
3232
_get_synchrony_counts,
33+
amplitude_cutoff,
3334
)
3435

3536
from spikeinterface.metrics.quality.pca_metrics import (
@@ -130,6 +131,20 @@ def test_noise_cutoff():
130131
assert ratio1 <= ratio2
131132

132133

134+
def test_amplitude_cutoff():
135+
"""
136+
Generate two artificial gaussians, one truncated and one not. Check the metrics are higher for the truncated one.
137+
"""
138+
np.random.seed(1)
139+
amps = np.random.normal(0, 1, 1000)
140+
amps_trunc = amps[amps > -1]
141+
142+
fraction_missing_1 = amplitude_cutoff(amplitudes=amps, num_histogram_bins=20)
143+
fraction_missing_2 = amplitude_cutoff(amplitudes=amps_trunc, num_histogram_bins=20)
144+
145+
assert fraction_missing_1 < fraction_missing_2
146+
147+
133148
def test_synchrony_counts_no_sync():
134149

135150
spike_times, spike_units = synthesize_random_firings(num_units=1, duration=1, firing_rates=1.0)

0 commit comments

Comments
 (0)