11import spatialdata as sd
22from xenium_analysis_tools .utils .sd_utils import add_micron_coord_sys
3- from spatialdata .models import Image3DModel , Labels3DModel , PointsModel
3+ from spatialdata .models import Image2DModel , Image3DModel , Labels3DModel , Labels2DModel , PointsModel , ShapesModel , TableModel
4+ from spatialdata .transformations import get_transformation , set_transformation
5+ import anndata as ad
6+ from spatialdata import get_element_instances
47from pathlib import Path
58import pandas as pd
69import numpy as np
@@ -173,6 +176,7 @@ def get_alignment_data_paths(dataset_id,
173176 "data_root" : data_root ,
174177 "scratch_root" : scratch_root ,
175178 "results_root" : results_root ,
179+ "xenium_dataset_name" : dataset_config ["xenium_name" ],
176180 "sdata_path" : data_root / f'{ dataset_config ["xenium_name" ]} _processed' ,
177181 "confocal_path" : data_root / dataset_config ["confocal_name" ],
178182 "zstack_path" : data_root / dataset_config ["zstack_name" ],
@@ -187,43 +191,31 @@ def get_label_params(label_obj, id_name='cell'):
187191 props = regionprops (labels )
188192 data = [
189193 {f'{ id_name } _id' : p .label ,
190- 'z' : p .centroid [0 ],
191- 'y' : p .centroid [1 ],
192- 'x' : p .centroid [2 ],
194+ 'z' : p .centroid [0 ] if len ( p . centroid ) == 3 else None ,
195+ 'y' : p .centroid [1 ] if len ( p . centroid ) == 3 else p . centroid [ 0 ],
196+ 'x' : p .centroid [2 ] if len ( p . centroid ) == 3 else p . centroid [ 1 ] ,
193197 'area' : p .area ,
194198 'bbox' : p .bbox }
195199 for p in props
196200 ]
197201 df = pd .DataFrame (data )
198202 return df
199203
200- def get_zstack_sdata (stack , zstack_masks = None , get_centroids_as_points = True ):
204+ def get_zstack_sdata (stack , zstack_masks = None ):
201205 # Create the z-stack image array
202- num_channels = len (stack ['zstack_channels' ])
203- chans = []
204- if num_channels > 1 :
205- for ch_ind in range (num_channels ):
206- chan_array = create_zstack_array (tif_path = stack ['channel_tifs' ][ch_ind ]['chan_tif_path' ],
206+ chan_arrays = {}
207+ for ch_ind , chan_name in enumerate (stack ['zstack_channels' ]):
208+ chan_img = create_zstack_array (tif_path = stack ['channel_tifs' ][ch_ind ]['chan_tif_path' ],
207209 fov_x_um = stack ['zstack_size' ]['width' ],
208210 fov_y_um = stack ['zstack_size' ]['height' ],
209211 fov_z_um = stack ['zstack_size' ]['depth' ])
210- chans .append (chan_array )
211- zstack_img = xr .concat (chans , dim = 'c' )
212- zstack_img ['c' ] = stack ['zstack_channels' ]
213- else :
214- zstack_img = create_zstack_array (tif_path = stack ['channel_tifs' ][0 ]['chan_tif_path' ],
215- fov_x_um = stack ['zstack_size' ]['width' ],
216- fov_y_um = stack ['zstack_size' ]['height' ],
217- fov_z_um = stack ['zstack_size' ]['depth' ])
218- zstack_img ['c' ] = stack ['zstack_channels' ]
219-
220- # Parse into Image3DModel
221- zstack_img = Image3DModel .parse (
222- zstack_img ,
212+ chan_img = Image3DModel .parse (
213+ chan_img ,
223214 dims = ['c' , 'z' , 'y' , 'x' ],
224- c_coords = stack [ 'zstack_channels' ] ,
215+ c_coords = chan_name ,
225216 chunks = 'auto' ,
226217 )
218+ chan_arrays [chan_name ] = chan_img
227219
228220 if zstack_masks is not None :
229221 zstack_labels = {}
@@ -242,28 +234,32 @@ def get_zstack_sdata(stack, zstack_masks=None, get_centroids_as_points=True):
242234 chunks = 'auto' ,
243235 )
244236 zstack_labels [f"{ channel_name } _labels" ] = zstack_label
245-
246- if get_centroids_as_points :
247- zstack_points = {}
248- # Get label parameters add as points
249- for label_name , labels_obj in zstack_labels .items ():
250- chan_name = label_name .replace ('_labels' ,'' )
251- cells_df = get_label_params (labels_obj , id_name = chan_name )
252- print (f"# { chan_name } segmented cells: { len (cells_df )} " )
253- cells_df = PointsModel .parse (cells_df )
254- zstack_points [f"{ chan_name } _cells" ] = cells_df
237+
238+ tables = {}
239+ for label_name , labels_obj in zstack_labels .items ():
240+ chan_name = label_name .replace ('_labels' ,'' )
241+ label_type_id = f'{ chan_name } _id'
242+ chan_label_ids = get_element_instances (labels_obj ).values
243+ obs = pd .DataFrame (chan_label_ids , columns = [label_type_id ])
244+ cells_df = get_label_params (labels_obj , id_name = chan_name )
245+ cells_df ['region' ] = label_name
246+ obs = obs .merge (cells_df , left_on = label_type_id , right_on = chan_name + '_id' , how = 'left' )
247+ table = ad .AnnData (obs = obs , obsm = {'spatial' : obs [['z' ,'y' ,'x' ]].values })
248+ table = TableModel .parse (table , region = label_name , region_key = 'region' , instance_key = label_type_id )
249+ tables [f'{ chan_name } _cells' ] = table
255250
256251 # Assemble SpatialData
257252 zstack_sdata = sd .SpatialData (
258- images = {'zstack' : zstack_img },
253+ images = {** chan_arrays },
259254 labels = {** zstack_labels } if zstack_masks is not None else {},
260- points = {** zstack_points } if ( zstack_masks is not None and get_centroids_as_points ) else {}
255+ tables = {** tables } if zstack_masks is not None else {}
261256 )
262257
263258 # Determine pixel sizes
264- if zstack_sdata ['zstack' ].attrs ['pixel_size_um_x' ] == zstack_sdata ['zstack' ].attrs ['pixel_size_um_y' ]:
265- pixel_size = zstack_sdata ['zstack' ].attrs ['pixel_size_um_x' ]
266- if zstack_sdata ['zstack' ].attrs ['fov_um_z' ]== zstack_sdata ['zstack' ].shape [1 ]:
259+ zstack_chan = zstack_sdata [stack ['zstack_channels' ][0 ]] # Use first channel for pixel size reference
260+ if zstack_chan .attrs ['pixel_size_um_x' ] == zstack_chan .attrs ['pixel_size_um_y' ]:
261+ pixel_size = zstack_chan .attrs ['pixel_size_um_x' ]
262+ if zstack_chan .attrs ['fov_um_z' ]== zstack_chan .shape [1 ]:
267263 z_step_size = 1
268264
269265 # Add micron coordinate system if not already present
@@ -273,3 +269,85 @@ def get_zstack_sdata(stack, zstack_masks=None, get_centroids_as_points=True):
273269 print ("Micron coordinate system already exists" )
274270 return zstack_sdata
275271
272+ def get_alignment_spatial_elements (sdata , scale_from_level = 2 , channel_names = ['dapi' , 'boundary' , 'rna' , 'protein' ]):
273+ # Technically should only need to replace morphology focus transforms, but doing for all elements just in case
274+ # For elements, get at a specific scale level (if multi-scale) and set transform to global coordinate system
275+ # Images
276+ # Dapi z-stack
277+ dapi_zstack_level = sdata ['dapi_zstack' ][f'scale{ scale_from_level } ' ].image
278+ dapi_zstack_global_tf = get_transformation (dapi_zstack_level , to_coordinate_system = 'global' )
279+ dapi_zstack = Image3DModel .parse (sdata ['dapi_zstack' ][f'scale{ scale_from_level } ' ].image ,
280+ dims = ['c' , 'z' , 'y' , 'x' ],
281+ c_coords = ['DAPI' ],
282+ chunks = 'auto' ,
283+ )
284+ set_transformation (dapi_zstack , dapi_zstack_global_tf , to_coordinate_system = 'global' )
285+
286+ # Morphology focus channels
287+ mf_chans_level = sdata ['morphology_focus' ][f'scale{ scale_from_level } ' ].image
288+ mf_img_global_tf = get_transformation (mf_chans_level , to_coordinate_system = 'global' )
289+ chans_arrays = {}
290+ for chan_ind , chan in enumerate (channel_names ):
291+ chan_img = sdata ['morphology_focus' ][f'scale{ scale_from_level } ' ].image [chan_ind ]
292+ chan_img = np .expand_dims (chan_img .data , axis = 0 )
293+ chans_arrays [chan ] = Image2DModel .parse (chan_img ,
294+ dims = ['c' , 'y' , 'x' ],
295+ c_coords = chan ,
296+ chunks = 'auto' ,
297+ )
298+ set_transformation (chans_arrays [chan ], mf_img_global_tf , to_coordinate_system = 'global' )
299+
300+ images = {'dapi_zstack' : dapi_zstack , ** chans_arrays }
301+
302+ # Labels
303+ cell_labels_level = sdata ['cell_labels' ][f'scale{ scale_from_level } ' ].image
304+ cell_labels_tf = get_transformation (cell_labels_level , to_coordinate_system = 'global' )
305+ cell_labels = Labels2DModel .parse (cell_labels_level , dims = ['y' , 'x' ], chunks = 'auto' )
306+ set_transformation (cell_labels , cell_labels_tf , to_coordinate_system = 'global' )
307+ nucleus_labels_level = sdata ['nucleus_labels' ][f'scale{ scale_from_level } ' ].image
308+ nucleus_labels = Labels2DModel .parse (nucleus_labels_level , dims = ['y' , 'x' ], chunks = 'auto' )
309+ set_transformation (nucleus_labels , cell_labels_tf , to_coordinate_system = 'global' )
310+
311+ labels = {
312+ 'cell_labels' : cell_labels ,
313+ 'nucleus_labels' : nucleus_labels ,
314+ }
315+
316+ return images , labels
317+
318+ def get_alignment_shapes_tables (sdata ,
319+ transcripts_qv_thresh = 20 ,
320+ annotate_spatial_elements = 'cell_boundaries' ,
321+ cell_id_name = 'cell_id' ,
322+ mask_id_name = 'cell_labels' ):
323+ # Make cell_id to cell_label mapping dictionary
324+ cell_id_label_dict = dict (zip (sdata ['table' ].obs [cell_id_name ].values , sdata ['table' ].obs [mask_id_name ].values ))
325+ transcripts = sdata ['transcripts' ].compute ()
326+ # Drop transcripts not included in counts
327+ transcripts = transcripts [transcripts ['qv' ]>= transcripts_qv_thresh ]
328+ # Add cell_labels to transcripts based on cell_id
329+ transcripts [mask_id_name ] = transcripts [cell_id_name ].map (cell_id_label_dict ).fillna (0 ).astype ('int64' )
330+ # Annotate spatial elements (e.g., cell_boundaries) with cell_labels
331+ sdata [annotate_spatial_elements ][cell_id_name ] = sdata [annotate_spatial_elements ].index .values
332+ sdata [annotate_spatial_elements ][mask_id_name ] = sdata [annotate_spatial_elements ][cell_id_name ].map (cell_id_label_dict ).values
333+ sdata [annotate_spatial_elements ].set_index (mask_id_name , inplace = True , drop = False )
334+ # Update annotation regions
335+ table = sdata ['table' ].copy ()
336+ table .obs ['region' ] = annotate_spatial_elements
337+ table .obs ['region' ] = pd .Categorical (table .obs ['region' ])
338+ table .uns ['spatialdata_attrs' ].update ({
339+ 'region_key' : 'region' ,
340+ 'region' : [annotate_spatial_elements ],
341+ 'instance_key' : mask_id_name
342+ })
343+
344+ # Parse shapes
345+ annotated_shape = ShapesModel .parse (sdata [annotate_spatial_elements ])
346+ shapes = sdata .shapes
347+ shapes [annotate_spatial_elements ] = annotated_shape
348+ # Parse table
349+ table = TableModel .parse (table )
350+ # Parse transcripts
351+ transcripts = PointsModel .parse (transcripts )
352+
353+ return table , transcripts , shapes
0 commit comments