Skip to content

Commit 51b26d4

Browse files
committed
Add t_starts parameter to generate_recording and generate_sorting
1 parent 4c3a6f9 commit 51b26d4

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

src/spikeinterface/core/generate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def generate_recording(
3131
durations: list[float] = [5.0, 2.5],
3232
set_probe: bool | None = True,
3333
ndim: int | None = 2,
34+
t_starts: list[float] | None = None,
3435
seed: int | None = None,
3536
) -> BaseRecording:
3637
"""
@@ -50,6 +51,9 @@ def generate_recording(
5051
If true, attaches probe to the returned `Recording`
5152
ndim : int, default: 2
5253
The number of dimensions of the probe, default: 2. Set to 3 to make 3 dimensional probe.
54+
t_starts : list[float] | None, default: None
55+
The start time of each segment in seconds. If provided, must have the same
56+
length as `durations`.
5357
seed : int | None, default: None
5458
A seed for the np.ramdom.default_rng function
5559
@@ -71,6 +75,11 @@ def generate_recording(
7175
noise_block_size=int(sampling_frequency),
7276
)
7377

78+
if t_starts is not None:
79+
assert len(t_starts) == len(durations), "t_starts must have the same length as durations"
80+
for segment_index, t_start in enumerate(t_starts):
81+
recording.segments[segment_index].t_start = t_start
82+
7483
recording.annotate(is_filtered=True)
7584

7685
if set_probe:
@@ -95,6 +104,7 @@ def generate_sorting(
95104
add_spikes_on_borders=False,
96105
num_spikes_per_border=3,
97106
border_size_samples=20,
107+
t_starts=None,
98108
seed=None,
99109
):
100110
"""
@@ -122,6 +132,9 @@ def generate_sorting(
122132
The number of spikes to add close to the borders of the segments.
123133
border_size_samples : int, default: 20
124134
The size of the border in samples to add border spikes.
135+
t_starts : list[float] | None, default: None
136+
The start time of each segment in seconds. If provided, must have the same
137+
length as `durations`.
125138
seed : int, default: None
126139
The random seed.
127140
@@ -177,6 +190,11 @@ def generate_sorting(
177190

178191
sorting = NumpySorting(spikes, sampling_frequency, unit_ids)
179192

193+
if t_starts is not None:
194+
assert len(t_starts) == len(durations), "t_starts must have the same length as durations"
195+
for segment_index, t_start in enumerate(t_starts):
196+
sorting.segments[segment_index]._t_start = t_start
197+
180198
return sorting
181199

182200

0 commit comments

Comments
 (0)