Skip to content

Commit 7844a9e

Browse files
committed
updated alignment functions
1 parent 05af8e6 commit 7844a9e

4 files changed

Lines changed: 191 additions & 228 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ dynamic = ["version"]
1919
dependencies = [
2020
"pyarrow>=21.0.0",
2121
"scanpy>=1.11.5",
22-
"spatialdata[extra]>=0.6.1",
22+
"spatialdata==0.6.1",
23+
"spatialdata-plot",
24+
"spatialdata-io",
2325
"numpy",
2426
"pandas"
2527
]
Lines changed: 97 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
1+
from xenium_analysis_tools.utils.sd_utils import add_micron_coord_sys
22
from spatialdata.transformations import Scale, Identity, Sequence, set_transformation, get_transformation
33
from spatialdata.models import Image3DModel
44
import spatialdata as sd
55
import xarray as xr
66
import dask.array as da
77
import zarr
8-
from xenium_analysis_tools.utils.sd_utils import add_micron_coord_sys
98
import pandas as pd
109
import numpy as np
11-
# from bioio_sldy import Reader
12-
import pandas as pd
1310
import yaml
14-
import os
11+
from pathlib import Path
1512

1613
def get_confocal_image_sizes(img_name, cf_raw_path, overlap=0.1):
1714
confocal_notes = pd.read_csv(cf_raw_path / 'notes.csv')
@@ -33,15 +30,17 @@ def get_confocal_image_sizes(img_name, cf_raw_path, overlap=0.1):
3330
'sizeZ': shape[0],
3431
'sizeY': shape[1],
3532
'sizeX': shape[2],
36-
'sizeC': 1, # Confocal captures are usually single channel per dir
33+
'sizeC': 1,
3734
'physical_pixel_size_x': phys_x,
38-
'physical_pixel_size_y': phys_x, # Typically square
39-
'physical_pixel_size_z': 1.0 # Placeholder if not in YAML
35+
'physical_pixel_size_y': phys_x,
36+
'physical_pixel_size_z': 1.0
4037
}
4138

42-
def generate_confocal_sdata(zarr_path, raw_confocal_path=None, select_scales=('0', '1', '2', '3'),
                            chunk_size=(1, 64, 512, 512)):
    """Build a SpatialData object holding one multiscale confocal image.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Path to the image's ``.zarr`` store; its stem becomes the image /
        channel name inside the SpatialData object.
    raw_confocal_path : pathlib.Path, optional
        Raw capture directory (contains ``notes.csv``). When given, size and
        pixel-size metadata is attached and a micron coordinate system added.
    select_scales : sequence of str, optional
        Pyramid levels to keep; default keeps '0'-'3'. An immutable tuple is
        used instead of a list literal to avoid the shared-mutable-default
        pitfall (membership/iteration semantics are unchanged for callers).
    chunk_size : tuple, optional
        (c, z, y, x) chunking forwarded to the zarr loader.

    Returns
    -------
    spatialdata.SpatialData
    """
    cf_name = zarr_path.stem
    cf_dt = create_datatree_from_zarr(zarr_path, chan_name=cf_name,
                                      select_scales=select_scales, chunk_size=chunk_size)

    cf_sdata = sd.SpatialData(
        images={cf_name: cf_dt}
    )

    if raw_confocal_path:
        cf_sizes = get_confocal_image_sizes(cf_name, raw_confocal_path)
        cf_sdata[cf_name].attrs.update(cf_sizes)
        # pixel_size is passed as [y, x] so anisotropic captures are handled;
        # z_step carries the axial spacing.
        cf_sdata = add_micron_coord_sys(
            cf_sdata,
            pixel_size=[cf_sizes['physical_pixel_size_y'],
                        cf_sizes['physical_pixel_size_x']],
            z_step=cf_sizes['physical_pixel_size_z'],
        )

    return cf_sdata
5955

60-
def create_datatree_from_zarr(zarr_path, chan_name='chan', select_scales=('0', '1', '2', '3'),
                              chunk_size=(1, 64, 512, 512)):
    """Load a multiscale zarr pyramid into an ``xr.DataTree`` of 3D images.

    Scale level '0' is always loaded first — its shape anchors the downsample
    transforms — and is dropped afterwards if not in ``select_scales``.
    Remaining levels are renamed to a contiguous ``scale0..scaleN`` sequence;
    each renamed node records its source level in
    ``image.attrs['original_scale_level']``.

    Parameters
    ----------
    zarr_path : pathlib.Path
        Root of the multiscale zarr store; each child group is one level.
    chan_name : str
        Name assigned to the single channel coordinate.
    select_scales : sequence of str, optional
        Levels to keep (default '0'-'3'); ``None`` keeps every level found.
        Tuple default avoids the mutable-default-argument pitfall.
    chunk_size : tuple
        (c, z, y, x) chunks; the (z, y, x) part is used when reading zarr,
        the full tuple when parsing with ``Image3DModel``.

    Returns
    -------
    xr.DataTree
    """
    root = zarr.open_group(zarr_path, mode='r')
    data_tree_obj = xr.DataTree()

    available_scales = sorted(list(root.keys()))
    if select_scales is not None:
        # Scale0 must be processed first to establish the reference shape;
        # it is filtered back out below if the caller did not ask for it.
        scales_to_process = ['0'] + [s for s in select_scales if s != '0']
    else:
        scales_to_process = available_scales

    scale0_shape = None

    for scale_level in scales_to_process:
        if scale_level not in available_scales:
            continue

        print(f"Adding scale level: {scale_level}")

        # Explicit (z, y, x) chunking on read; the channel axis is added next.
        level_array = da.from_zarr(str(zarr_path / scale_level), chunks=chunk_size[1:])
        level_array = da.expand_dims(level_array, axis=0)  # add c dimension

        if scale_level == '0':
            scale0_shape = level_array.shape

        data_array = xr.DataArray(level_array, dims=['c', 'z', 'y', 'x'])

        parsed_array = Image3DModel.parse(
            data_array,
            dims=['c', 'z', 'y', 'x'],
            c_coords=[chan_name],
            chunks=chunk_size,
        )

        # Attach the coordinate transform BEFORE inserting into the tree:
        # Dataset construction may rewrap the DataArray, so setting the
        # transform on the local object afterwards is not guaranteed to reach
        # the stored node (the previous revision set it on the tree's copy).
        if scale_level != '0' and scale0_shape is not None:
            scale_factors = np.array(scale0_shape) / np.array(level_array.shape)
            scale_transform = Scale(scale_factors, axes=parsed_array.dims)
            set_transformation(parsed_array, Sequence([scale_transform, Identity()]),
                               to_coordinate_system="global")
        else:
            set_transformation(parsed_array, Identity(), to_coordinate_system="global")

        data_tree_obj[f'scale{scale_level}'] = xr.Dataset({'image': parsed_array})

    # Drop the bootstrap scale0 if the caller did not request it.
    if select_scales is not None and '0' not in select_scales:
        print("Removing scale0 from data tree as it's not in select_scales")
        del data_tree_obj['scale0']

    # Rename surviving levels to a contiguous scale0..scaleN sequence.
    if select_scales is not None:
        orig_keys = list(data_tree_obj.keys())
        desired_keys = [f"scale{idx}" for idx in range(len(orig_keys))]

        if orig_keys != desired_keys:
            print("Renaming scales to ensure sequential order")
            renamed_datasets = {}
            for n, old_key in enumerate(orig_keys):
                new_key = f"scale{n}"
                print(f"  Renaming {old_key} -> {new_key}")
                # Shallow copy: only the container is duplicated, the backing
                # dask arrays are shared.
                dataset = data_tree_obj[old_key].copy(deep=False)
                dataset.image.attrs['original_scale_level'] = old_key
                renamed_datasets[new_key] = dataset

            data_tree_obj = xr.DataTree.from_dict(renamed_datasets)

    return data_tree_obj
189140

190-
def get_confocal_sdata(confocal_zarr_path, raw_confocal_path, select_scales=('0', '1', '2', '3'),
                       chunk_size=(1, 64, 512, 512)):
    """Assemble one SpatialData from the 'deep' and/or 'surface' confocal zarrs.

    Parameters
    ----------
    confocal_zarr_path : path-like
        Directory expected to contain ``deep.zarr`` and/or ``surface.zarr``.
    raw_confocal_path : pathlib.Path
        Raw capture directory forwarded for size / pixel-size metadata.
    select_scales : sequence of str, optional
        Forwarded to :func:`generate_confocal_sdata`; tuple default avoids
        the mutable-default-argument pitfall.
    chunk_size : tuple, optional
        Forwarded to :func:`generate_confocal_sdata`.

    Returns
    -------
    spatialdata.SpatialData
        Concatenation (with merged coordinate systems) of the images found.

    Raises
    ------
    FileNotFoundError
        If neither ``deep.zarr`` nor ``surface.zarr`` exists — previously
        this fell through to an opaque failure in ``sd.concatenate([])``.
    """
    confocal_path = Path(confocal_zarr_path)
    zarr_stems = {f.stem for f in confocal_path.glob('*.zarr')}

    sdatas = []
    # Both captures run through the identical pipeline; loop instead of
    # duplicating the call site per capture name.
    for capture in ('deep', 'surface'):
        if capture not in zarr_stems:
            continue
        print(f"Generating sdata for {capture} confocal...")
        sdatas.append(generate_confocal_sdata(
            zarr_path=confocal_path / f'{capture}.zarr',
            raw_confocal_path=raw_confocal_path,
            select_scales=select_scales,
            chunk_size=chunk_size,
        ))

    if not sdatas:
        raise FileNotFoundError(
            f"No 'deep.zarr' or 'surface.zarr' found under {confocal_path}"
        )

    confocal_sdata = sd.concatenate(sdatas, merge_coord_systems=True)
    return confocal_sdata

0 commit comments

Comments
 (0)