1-
1+ from xenium_analysis_tools . utils . sd_utils import add_micron_coord_sys
22from spatialdata .transformations import Scale , Identity , Sequence , set_transformation , get_transformation
33from spatialdata .models import Image3DModel
44import spatialdata as sd
55import xarray as xr
66import dask .array as da
77import zarr
8- from xenium_analysis_tools .utils .sd_utils import add_micron_coord_sys
98import pandas as pd
109import numpy as np
11- # from bioio_sldy import Reader
12- import pandas as pd
1310import yaml
14- import os
11+ from pathlib import Path
1512
1613def get_confocal_image_sizes (img_name , cf_raw_path , overlap = 0.1 ):
1714 confocal_notes = pd .read_csv (cf_raw_path / 'notes.csv' )
@@ -33,15 +30,17 @@ def get_confocal_image_sizes(img_name, cf_raw_path, overlap=0.1):
3330 'sizeZ' : shape [0 ],
3431 'sizeY' : shape [1 ],
3532 'sizeX' : shape [2 ],
36- 'sizeC' : 1 , # Confocal captures are usually single channel per dir
33+ 'sizeC' : 1 ,
3734 'physical_pixel_size_x' : phys_x ,
38- 'physical_pixel_size_y' : phys_x , # Typically square
39- 'physical_pixel_size_z' : 1.0 # Placeholder if not in YAML
35+ 'physical_pixel_size_y' : phys_x ,
36+ 'physical_pixel_size_z' : 1.0
4037 }
4138
def generate_confocal_sdata(zarr_path, raw_confocal_path=None,
                            select_scales=('0', '1', '2', '3'),
                            chunk_size=(1, 64, 512, 512)):
    """Build a SpatialData object for one multiscale confocal zarr image.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Path to the multiscale ``.zarr`` store; its stem is used as both the
        image name and the channel name.
    raw_confocal_path : path-like, optional
        Root of the raw confocal capture. When given, per-image metadata from
        ``get_confocal_image_sizes`` is attached to the image attrs and a
        micron coordinate system is registered via ``add_micron_coord_sys``.
    select_scales : sequence of str or None, optional
        Pyramid levels to keep (``None`` keeps every level). An immutable
        tuple is used as the default to avoid the shared-mutable-default
        pitfall; it is forwarded unchanged to ``create_datatree_from_zarr``.
    chunk_size : tuple of int, optional
        Dask chunking as ``(c, z, y, x)``.

    Returns
    -------
    spatialdata.SpatialData
        SpatialData with a single multiscale 3D image.
    """
    cf_name = zarr_path.stem
    cf_dt = create_datatree_from_zarr(zarr_path, chan_name=cf_name,
                                      select_scales=select_scales,
                                      chunk_size=chunk_size)

    cf_sdata = sd.SpatialData(
        images={cf_name: cf_dt}
    )

    if raw_confocal_path:
        cf_sizes = get_confocal_image_sizes(cf_name, raw_confocal_path)
        cf_sdata[cf_name].attrs.update(cf_sizes)
        # Anisotropic pixel sizes are supported: pass (y, x) plus the z step.
        cf_sdata = add_micron_coord_sys(
            cf_sdata,
            pixel_size=[cf_sizes['physical_pixel_size_y'], cf_sizes['physical_pixel_size_x']],
            z_step=cf_sizes['physical_pixel_size_z'],
        )

    return cf_sdata
5955
def create_datatree_from_zarr(zarr_path, chan_name='chan',
                              select_scales=('0', '1', '2', '3'),
                              chunk_size=(1, 64, 512, 512)):
    """Load a multiscale zarr image pyramid into an ``xarray.DataTree``.

    Each kept pyramid level becomes a ``scale{N}`` child holding an ``image``
    DataArray parsed through ``Image3DModel``, with a spatialdata ``Scale``
    transformation relating it back to scale 0 in the "global" coordinate
    system.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Path to the zarr group whose children are the scale levels
        ('0', '1', ...).
    chan_name : str, optional
        Channel coordinate name for the single channel.
    select_scales : sequence of str or None, optional
        Scale levels to keep. ``None`` keeps all available levels. An
        immutable tuple default avoids the shared-mutable-default pitfall.
    chunk_size : tuple of int, optional
        Dask chunking as ``(c, z, y, x)``; the trailing ``(z, y, x)`` part is
        used when reading from zarr (the c dimension is added afterwards).

    Returns
    -------
    xarray.DataTree
        Tree with sequentially named ``scale0..scaleN`` children; when levels
        were renamed, each image records its ``original_scale_level`` attr.
    """
    root = zarr.open_group(zarr_path, mode='r')
    data_tree_obj = xr.DataTree()

    available_scales = sorted(list(root.keys()))
    if select_scales is not None:
        # Scale 0 is always needed to compute the scale factors of the
        # lower-resolution levels; it may be dropped again afterwards.
        scales_to_process = ['0'] + [s for s in select_scales if s != '0']
    else:
        scales_to_process = available_scales

    scale0_shape = None

    for scale_level in scales_to_process:
        if scale_level not in available_scales:
            continue

        if scale_level != '0':
            if select_scales is not None and scale_level not in select_scales:
                continue

        print(f"Adding scale level: {scale_level}")

        # Explicit chunking on load; skip the c entry of chunk_size since the
        # channel dimension is only added below.
        level_array = da.from_zarr(str(zarr_path / scale_level), chunks=chunk_size[1:])
        level_array = da.expand_dims(level_array, axis=0)  # Add c dimension

        # Remember the full-resolution shape for scale-factor computation.
        if scale_level == '0':
            scale0_shape = level_array.shape

        data_array = xr.DataArray(
            level_array,
            dims=['c', 'z', 'y', 'x']
        )

        parsed_array = Image3DModel.parse(
            data_array,
            dims=['c', 'z', 'y', 'x'],
            c_coords=[chan_name],
            chunks=chunk_size,
        )

        # Register the transformation *before* inserting into the tree so the
        # transformation attrs are guaranteed to be attached to the stored
        # variable (setting them on parsed_array after insertion relies on
        # xarray sharing the underlying Variable's attrs).
        if scale_level != '0' and scale0_shape is not None:
            scale_factors = np.array(scale0_shape) / np.array(level_array.shape)
            scale_transform = Scale(scale_factors, axes=parsed_array.dims)
            sequence = Sequence([scale_transform, Identity()])
            set_transformation(parsed_array, sequence, to_coordinate_system="global")
        else:
            set_transformation(parsed_array, Identity(), to_coordinate_system="global")

        data_tree_obj[f'scale{scale_level}'] = xr.Dataset({'image': parsed_array})

    # Drop scale0 if it was only loaded for the scale-factor computation.
    if select_scales is not None and '0' not in select_scales:
        print("Removing scale0 from data tree as it's not in select_scales")
        del data_tree_obj['scale0']

    # Rename remaining levels so keys are sequential scale0..scaleN, keeping
    # track of each level's original name in its attrs.
    if select_scales is not None:
        orig_keys = list(data_tree_obj.keys())
        desired_keys = [f"scale{idx}" for idx in range(len(orig_keys))]

        if orig_keys != desired_keys:
            print("Renaming scales to ensure sequential order")
            renamed_datasets = {}
            for n, old_key in enumerate(orig_keys):
                new_key = f"scale{n}"
                print(f"  Renaming {old_key} -> {new_key}")
                dataset = data_tree_obj[old_key].copy(deep=False)  # Shallow copy
                dataset.image.attrs['original_scale_level'] = old_key
                renamed_datasets[new_key] = dataset

            # Rebuild tree from dict
            data_tree_obj = xr.DataTree.from_dict(renamed_datasets)

    return data_tree_obj
189140
def get_confocal_sdata(confocal_zarr_path, raw_confocal_path,
                       select_scales=('0', '1', '2', '3'),
                       chunk_size=(1, 64, 512, 512)):
    """Assemble deep and/or surface confocal images into one SpatialData.

    Looks for ``deep.zarr`` and ``surface.zarr`` stores directly under
    ``confocal_zarr_path``, builds a SpatialData for each one found, and
    concatenates them with merged coordinate systems.

    Parameters
    ----------
    confocal_zarr_path : path-like
        Directory containing the ``*.zarr`` stores.
    raw_confocal_path : path-like
        Root of the raw confocal capture, forwarded to
        ``generate_confocal_sdata`` for metadata/micron registration.
    select_scales : sequence of str or None, optional
        Pyramid levels to keep; immutable tuple default avoids the
        shared-mutable-default pitfall.
    chunk_size : tuple of int, optional
        Dask chunking as ``(c, z, y, x)``.

    Returns
    -------
    spatialdata.SpatialData
        Concatenation of the available confocal images.

    Raises
    ------
    FileNotFoundError
        If neither ``deep.zarr`` nor ``surface.zarr`` exists under
        ``confocal_zarr_path`` (a bare ``sd.concatenate([])`` would fail
        with an opaque error otherwise).
    """
    sdatas = []

    confocal_path = Path(confocal_zarr_path)
    zarr_files = [f.stem for f in confocal_path.glob('*.zarr')]

    if 'deep' in zarr_files:
        print("Generating sdata for deep confocal...")
        deep_sdata = generate_confocal_sdata(
            zarr_path=confocal_path / 'deep.zarr',
            raw_confocal_path=raw_confocal_path,
            select_scales=select_scales,
            chunk_size=chunk_size
        )
        sdatas.append(deep_sdata)

    if 'surface' in zarr_files:
        print("Generating sdata for surface confocal...")
        surface_sdata = generate_confocal_sdata(
            zarr_path=confocal_path / 'surface.zarr',
            raw_confocal_path=raw_confocal_path,
            select_scales=select_scales,
            chunk_size=chunk_size
        )
        sdatas.append(surface_sdata)

    if not sdatas:
        raise FileNotFoundError(
            f"No 'deep.zarr' or 'surface.zarr' store found under {confocal_path}"
        )

    confocal_sdata = sd.concatenate(sdatas, merge_coord_systems=True)
    return confocal_sdata
0 commit comments