Skip to content

Commit 4f3c04e

Browse files
authored
Replace int() truncation with shared ms_to_samples() (#4484)
1 parent 6b47603 commit 4f3c04e

25 files changed

Lines changed: 89 additions & 65 deletions

src/spikeinterface/benchmark/tests/common_benchmark_testing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
estimate_templates,
1717
Templates,
1818
create_sorting_analyzer,
19+
ms_to_samples,
1920
)
2021
from spikeinterface.generation import generate_drifting_recording
2122

@@ -54,8 +55,8 @@ def make_dataset(job_kwargs={}):
5455
def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_in_uV=False, **job_kwargs):
5556
spikes = gt_sorting.to_spike_vector() # [spike_indices]
5657
fs = recording.sampling_frequency
57-
nbefore = int(ms_before * fs / 1000)
58-
nafter = int(ms_after * fs / 1000)
58+
nbefore = ms_to_samples(ms_before, fs)
59+
nafter = ms_to_samples(ms_after, fs)
5960
templates_array = estimate_templates(
6061
recording,
6162
spikes,

src/spikeinterface/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
read_python,
9191
write_python,
9292
normal_pdf,
93+
ms_to_samples,
9394
)
9495
from .job_tools import (
9596
get_best_job_kwargs,

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .template import Templates
2222
from .sorting_tools import random_spikes_selection, select_sorting_periods_mask, spike_vector_to_indices
2323
from .job_tools import fix_job_kwargs, split_job_kwargs
24+
from .core_tools import ms_to_samples
2425

2526

2627
class ComputeRandomSpikes(AnalyzerExtension):
@@ -170,11 +171,11 @@ class ComputeWaveforms(AnalyzerExtension):
170171

171172
@property
172173
def nbefore(self):
173-
return int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
174+
return ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency)
174175

175176
@property
176177
def nafter(self):
177-
return int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
178+
return ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency)
178179

179180
def _run(self, verbose=False, **job_kwargs):
180181
self.data.clear()
@@ -540,12 +541,12 @@ def _compute_and_append_from_waveforms(self, operators):
540541

541542
@property
542543
def nbefore(self):
543-
nbefore = int(self.params["ms_before"] * self.sorting_analyzer.sampling_frequency / 1000.0)
544+
nbefore = ms_to_samples(self.params["ms_before"], self.sorting_analyzer.sampling_frequency)
544545
return nbefore
545546

546547
@property
547548
def nafter(self):
548-
nafter = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
549+
nafter = ms_to_samples(self.params["ms_after"], self.sorting_analyzer.sampling_frequency)
549550
return nafter
550551

551552
def _select_extension_data(self, unit_ids):

src/spikeinterface/core/core_tools.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,3 +757,8 @@ def is_path_remote(path: str | Path) -> bool:
757757
Whether the path is a remote path.
758758
"""
759759
return "s3://" in str(path) or "gcs://" in str(path)
760+
761+
762+
def ms_to_samples(ms: float, sampling_frequency: float) -> int:
763+
"""Convert a duration in milliseconds to the nearest number of samples."""
764+
return round(ms * sampling_frequency / 1000.0)

src/spikeinterface/core/generate.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting
1414
from .snippets_tools import snippets_from_sorting
15-
from .core_tools import define_function_from_class
15+
from .core_tools import define_function_from_class, ms_to_samples
1616

1717

1818
def _ensure_seed(seed):
@@ -1598,8 +1598,8 @@ def generate_single_fake_waveform(
15981598
assert ms_after > depolarization_ms + repolarization_ms
15991599
assert ms_before > depolarization_ms
16001600

1601-
nbefore = int(sampling_frequency * ms_before / 1000.0)
1602-
nafter = int(sampling_frequency * ms_after / 1000.0)
1601+
nbefore = ms_to_samples(ms_before, sampling_frequency)
1602+
nafter = ms_to_samples(ms_after, sampling_frequency)
16031603
width = nbefore + nafter
16041604
wf = np.zeros(width, dtype=dtype)
16051605

@@ -1776,8 +1776,8 @@ def generate_templates(
17761776

17771777
num_units = units_locations.shape[0]
17781778
num_channels = channel_locations.shape[0]
1779-
nbefore = int(sampling_frequency * ms_before / 1000.0)
1780-
nafter = int(sampling_frequency * ms_after / 1000.0)
1779+
nbefore = ms_to_samples(ms_before, sampling_frequency)
1780+
nafter = ms_to_samples(ms_after, sampling_frequency)
17811781
width = nbefore + nafter
17821782

17831783
if upsample_factor is not None:
@@ -2451,8 +2451,8 @@ def generate_ground_truth_recording(
24512451
upsample_factor = templates.shape[3]
24522452
upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)
24532453

2454-
nbefore = int(ms_before * sampling_frequency / 1000.0)
2455-
nafter = int(ms_after * sampling_frequency / 1000.0)
2454+
nbefore = ms_to_samples(ms_before, sampling_frequency)
2455+
nafter = ms_to_samples(ms_after, sampling_frequency)
24562456
assert (nbefore + nafter) == templates.shape[1]
24572457

24582458
# construct recording

src/spikeinterface/core/node_pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from spikeinterface.core import BaseRecording, get_chunk_with_margin
1515
from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc
1616
from spikeinterface.core import get_channel_distances
17+
from spikeinterface.core.core_tools import ms_to_samples
1718

1819

1920
class PipelineNode:
@@ -314,8 +315,8 @@ def __init__(
314315
PipelineNode.__init__(self, recording=recording, parents=parents, return_output=return_output)
315316
self.ms_before = ms_before
316317
self.ms_after = ms_after
317-
self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)
318-
self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0)
318+
self.nbefore = ms_to_samples(ms_before, recording.get_sampling_frequency())
319+
self.nafter = ms_to_samples(ms_after, recording.get_sampling_frequency())
319320
self.neighbours_mask = None
320321

321322

src/spikeinterface/core/sparsity.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .sorting_tools import random_spikes_selection
88
from .job_tools import _shared_job_kwargs_doc
99
from .waveform_tools import estimate_templates_with_accumulator
10+
from .core_tools import ms_to_samples
1011

1112
_sparsity_doc = """
1213
method : str
@@ -784,8 +785,8 @@ def estimate_sparsity(
784785
probe = recording.create_dummy_probe_from_locations(chan_locs)
785786

786787
if method != "by_property":
787-
nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
788-
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
788+
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
789+
nafter = ms_to_samples(ms_after, recording.sampling_frequency)
789790

790791
num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())]
791792
random_spikes_indices = random_spikes_selection(

src/spikeinterface/core/tests/test_loading.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
generate_ground_truth_recording,
66
create_sorting_analyzer,
77
load,
8+
ms_to_samples,
89
SortingAnalyzer,
910
Templates,
1011
aggregate_channels,
@@ -71,7 +72,7 @@ def generate_templates_object():
7172
templates = Templates(
7273
templates_array=templates_arr,
7374
sampling_frequency=sampling_frequency,
74-
nbefore=int(ms_before * sampling_frequency / 1000),
75+
nbefore=ms_to_samples(ms_before, sampling_frequency),
7576
probe=probe,
7677
)
7778
return templates

src/spikeinterface/core/tests/test_waveform_tools.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77

8-
from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording
8+
from spikeinterface.core import generate_recording, generate_sorting, generate_ground_truth_recording, ms_to_samples
99
from spikeinterface.core.waveform_tools import (
1010
extract_waveforms_to_buffers,
1111
extract_waveforms_to_single_buffer,
@@ -56,8 +56,8 @@ def test_waveform_tools(create_cache_folder):
5656
recording, sorting = get_dataset()
5757
sampling_frequency = recording.sampling_frequency
5858

59-
nbefore = int(3.0 * sampling_frequency / 1000.0)
60-
nafter = int(4.0 * sampling_frequency / 1000.0)
59+
nbefore = ms_to_samples(3.0, sampling_frequency)
60+
nafter = ms_to_samples(4.0, sampling_frequency)
6161

6262
dtype = recording.get_dtype()
6363
# return_in_uV = False
@@ -164,8 +164,8 @@ def test_estimate_templates_with_accumulator():
164164
ms_before = 1.0
165165
ms_after = 1.5
166166

167-
nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
168-
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
167+
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
168+
nafter = ms_to_samples(ms_after, recording.sampling_frequency)
169169

170170
spikes = sorting.to_spike_vector()
171171
# take one spikes every 10
@@ -218,8 +218,8 @@ def test_estimate_templates():
218218
ms_before = 1.0
219219
ms_after = 1.5
220220

221-
nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
222-
nafter = int(ms_after * recording.sampling_frequency / 1000.0)
221+
nbefore = ms_to_samples(ms_before, recording.sampling_frequency)
222+
nafter = ms_to_samples(ms_after, recording.sampling_frequency)
223223

224224
spikes = sorting.to_spike_vector()
225225
# take one spikes every 10

src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from .sparsity import ChannelSparsity
2222
from .sortinganalyzer import SortingAnalyzer, load_sorting_analyzer
2323
from .loading import load
24+
from .core_tools import ms_to_samples
2425
from .analyzer_extension_core import ComputeRandomSpikes, ComputeWaveforms, ComputeTemplates
2526

2627
_backwards_compatibility_msg = """####
@@ -162,12 +163,12 @@ def unit_ids(self) -> np.ndarray:
162163
@property
163164
def nbefore(self) -> int:
164165
ms_before = self.sorting_analyzer.get_extension("waveforms").params["ms_before"]
165-
return int(ms_before * self.sampling_frequency / 1000.0)
166+
return ms_to_samples(ms_before, self.sampling_frequency)
166167

167168
@property
168169
def nafter(self) -> int:
169170
ms_after = self.sorting_analyzer.get_extension("waveforms").params["ms_after"]
170-
return int(ms_after * self.sampling_frequency / 1000.0)
171+
return ms_to_samples(ms_after, self.sampling_frequency)
171172

172173
@property
173174
def nsamples(self) -> int:
@@ -522,8 +523,8 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
522523
else:
523524
max_num_channel = np.max(np.sum(sparsity.mask, axis=1))
524525

525-
nbefore = int(params["ms_before"] * sorting.sampling_frequency / 1000.0)
526-
nafter = int(params["ms_after"] * sorting.sampling_frequency / 1000.0)
526+
nbefore = ms_to_samples(params["ms_before"], sorting.sampling_frequency)
527+
nafter = ms_to_samples(params["ms_after"], sorting.sampling_frequency)
527528

528529
waveforms = np.zeros((num_spikes, nbefore + nafter, max_num_channel), dtype=params["dtype"])
529530
# then read waveforms per units

0 commit comments

Comments
 (0)