7777
7878from neo import logging_handler
7979
80+ from .utils import get_memmap_chunk_from_opened_file
81+
8082
8183possible_raw_modes = [
8284 "one-file" ,
@@ -182,6 +184,15 @@ def __init__(self, use_cache: bool = False, cache_path: str = "same_as_resource"
182184 self .header = None
183185 self .is_header_parsed = False
184186
187+ self ._has_buffer_description_api = False
188+
189+ def has_buffer_description_api (self ) -> bool :
190+ """
191+ Return if the reader handle the buffer API.
192+ If True then the reader support internally `get_analogsignal_buffer_description()`
193+ """
194+ return self ._has_buffer_description_api
195+
185196 def parse_header (self ):
186197 """
187198 Parses the header of the file(s) to allow for faster computations
@@ -191,6 +202,7 @@ def parse_header(self):
191202 # this must create
192203 # self.header['nb_block']
193204 # self.header['nb_segment']
205+ # self.header['signal_buffers']
194206 # self.header['signal_streams']
195207 # self.header['signal_channels']
196208 # self.header['spike_channels']
@@ -663,6 +675,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int |
663675
664676 """
665677 stream_index = self ._get_stream_index_from_arg (stream_index )
678+
666679 return self ._get_signal_size (block_index , seg_index , stream_index )
667680
668681 def get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int | None = None ):
@@ -1311,7 +1324,6 @@ def _get_analogsignal_chunk(
13111324 -------
13121325 array of samples, with each requested channel in a column
13131326 """
1314-
13151327 raise (NotImplementedError )
13161328
13171329 ###
@@ -1350,6 +1362,150 @@ def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype
13501362 def _rescale_epoch_duration (self , raw_duration : np .ndarray , dtype : np .dtype ):
13511363 raise (NotImplementedError )
13521364
1365+ ###
1366+ # buffer api zone
1367+ # must be implemented if has_buffer_description_api=True
1368+ def get_analogsignal_buffer_description (self , block_index : int = 0 , seg_index : int = 0 , buffer_id : str = None ):
1369+ if not self .has_buffer_description_api :
1370+ raise ValueError ("This reader do not support buffer_description API" )
1371+ descr = self ._get_analogsignal_buffer_description (block_index , seg_index , buffer_id )
1372+ return descr
1373+
1374+ def _get_analogsignal_buffer_description (self , block_index , seg_index , buffer_id ):
1375+ raise (NotImplementedError )
1376+
1377+
1378+
1379+ class BaseRawWithBufferApiIO (BaseRawIO ):
1380+ """
1381+ Generic class for reader that support "buffer api".
1382+
1383+ In short reader that are internally based on:
1384+
1385+ * np.memmap
1386+ * hdf5
1387+
1388+ In theses cases _get_signal_size and _get_analogsignal_chunk are totaly generic and do not need to be implemented in the class.
1389+
1390+ For this class sub classes must implements theses two dict:
1391+ * self._buffer_descriptions[block_index][seg_index] = buffer_description
1392+ * self._stream_buffer_slice[buffer_id] = None or slicer o indices
1393+
1394+ """
1395+
1396+ def __init__ (self , * arg , ** kwargs ):
1397+ super ().__init__ (* arg , ** kwargs )
1398+ self ._has_buffer_description_api = True
1399+
1400+ def _get_signal_size (self , block_index , seg_index , stream_index ):
1401+ buffer_id = self .header ["signal_streams" ][stream_index ]["buffer_id" ]
1402+ buffer_desc = self .get_analogsignal_buffer_description (block_index , seg_index , buffer_id )
1403+ # some hdf5 revert teh buffer
1404+ time_axis = buffer_desc .get ("time_axis" , 0 )
1405+ return buffer_desc ['shape' ][time_axis ]
1406+
1407+ def _get_analogsignal_chunk (
1408+ self ,
1409+ block_index : int ,
1410+ seg_index : int ,
1411+ i_start : int | None ,
1412+ i_stop : int | None ,
1413+ stream_index : int ,
1414+ channel_indexes : list [int ] | None ,
1415+ ):
1416+
1417+ stream_id = self .header ["signal_streams" ][stream_index ]["id" ]
1418+ buffer_id = self .header ["signal_streams" ][stream_index ]["buffer_id" ]
1419+
1420+ buffer_slice = self ._stream_buffer_slice [stream_id ]
1421+
1422+
1423+ buffer_desc = self .get_analogsignal_buffer_description (block_index , seg_index , buffer_id )
1424+
1425+ i_start = i_start or 0
1426+ i_stop = i_stop or buffer_desc ['shape' ][0 ]
1427+
1428+ if buffer_desc ['type' ] == "raw" :
1429+
1430+ # open files on demand and keep reference to opened file
1431+ if not hasattr (self , '_memmap_analogsignal_buffers' ):
1432+ self ._memmap_analogsignal_buffers = {}
1433+ if block_index not in self ._memmap_analogsignal_buffers :
1434+ self ._memmap_analogsignal_buffers [block_index ] = {}
1435+ if seg_index not in self ._memmap_analogsignal_buffers [block_index ]:
1436+ self ._memmap_analogsignal_buffers [block_index ][seg_index ] = {}
1437+ if buffer_id not in self ._memmap_analogsignal_buffers [block_index ][seg_index ]:
1438+ fid = open (buffer_desc ['file_path' ], mode = 'rb' )
1439+ self ._memmap_analogsignal_buffers [block_index ][seg_index ][buffer_id ] = fid
1440+ else :
1441+ fid = self ._memmap_analogsignal_buffers [block_index ][seg_index ][buffer_id ]
1442+
1443+ num_channels = buffer_desc ['shape' ][1 ]
1444+
1445+ raw_sigs = get_memmap_chunk_from_opened_file (fid , num_channels , i_start , i_stop , np .dtype (buffer_desc ['dtype' ]), file_offset = buffer_desc ['file_offset' ])
1446+
1447+
1448+ elif buffer_desc ['type' ] == 'hdf5' :
1449+
1450+ # open files on demand and keep reference to opened file
1451+ if not hasattr (self , '_hdf5_analogsignal_buffers' ):
1452+ self ._hdf5_analogsignal_buffers = {}
1453+ if block_index not in self ._hdf5_analogsignal_buffers :
1454+ self ._hdf5_analogsignal_buffers [block_index ] = {}
1455+ if seg_index not in self ._hdf5_analogsignal_buffers [block_index ]:
1456+ self ._hdf5_analogsignal_buffers [block_index ][seg_index ] = {}
1457+ if buffer_id not in self ._hdf5_analogsignal_buffers [block_index ][seg_index ]:
1458+ import h5py
1459+ h5file = h5py .File (buffer_desc ['file_path' ], mode = "r" )
1460+ self ._hdf5_analogsignal_buffers [block_index ][seg_index ][buffer_id ] = h5file
1461+ else :
1462+ h5file = self ._hdf5_analogsignal_buffers [block_index ][seg_index ][buffer_id ]
1463+
1464+ hdf5_path = buffer_desc ["hdf5_path" ]
1465+ full_raw_sigs = h5file [hdf5_path ]
1466+
1467+ time_axis = buffer_desc .get ("time_axis" , 0 )
1468+ if time_axis == 0 :
1469+ raw_sigs = full_raw_sigs [i_start :i_stop , :]
1470+ elif time_axis == 1 :
1471+ raw_sigs = full_raw_sigs [:, i_start :i_stop ].T
1472+ else :
1473+ raise RuntimeError ("Should never happen" )
1474+
1475+ if buffer_slice is not None :
1476+ raw_sigs = raw_sigs [:, buffer_slice ]
1477+
1478+
1479+
1480+ else :
1481+ raise NotImplementedError ()
1482+
1483+ # this is a pre slicing when the stream do not contain all channels (for instance spikeglx when load_sync_channel=False)
1484+ if buffer_slice is not None :
1485+ raw_sigs = raw_sigs [:, buffer_slice ]
1486+
1487+ # channel slice requested
1488+ if channel_indexes is not None :
1489+ raw_sigs = raw_sigs [:, channel_indexes ]
1490+
1491+
1492+ return raw_sigs
1493+
1494+ def __del__ (self ):
1495+ if hasattr (self , '_memmap_analogsignal_buffers' ):
1496+ for block_index in self ._memmap_analogsignal_buffers .keys ():
1497+ for seg_index in self ._memmap_analogsignal_buffers [block_index ].keys ():
1498+ for buffer_id , fid in self ._memmap_analogsignal_buffers [block_index ][seg_index ].items ():
1499+ fid .close ()
1500+ del self ._memmap_analogsignal_buffers
1501+
1502+ if hasattr (self , '_hdf5_analogsignal_buffers' ):
1503+ for block_index in self ._hdf5_analogsignal_buffers .keys ():
1504+ for seg_index in self ._hdf5_analogsignal_buffers [block_index ].keys ():
1505+ for buffer_id , h5_file in self ._hdf5_analogsignal_buffers [block_index ][seg_index ].items ():
1506+ h5_file .close ()
1507+ del self ._hdf5_analogsignal_buffers
1508+
13531509
13541510def pprint_vector (vector , lim : int = 8 ):
13551511 vector = np .asarray (vector )
0 commit comments