Skip to content

Commit 723d10d

Browse files
committed
added alignment functions
1 parent 626e7e6 commit 723d10d

3 files changed

Lines changed: 292 additions & 15 deletions

File tree

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import spatialdata as sd
2+
from xenium_analysis_tools.utils.sd_utils import add_micron_coord_sys
3+
from spatialdata.models import Image3DModel, Labels3DModel
4+
from pathlib import Path
5+
import pandas as pd
6+
import numpy as np
7+
import xarray as xr
8+
import tifffile
9+
import json
10+
import re
11+
12+
def create_zstack_array(tif_path,
13+
add_chan=True,
14+
fov_z_um=450.0,
15+
fov_x_um=400.0,
16+
fov_y_um=400.0):
17+
18+
data = tifffile.imread(tif_path)
19+
pixels_z, pixels_y, pixels_x = data.shape
20+
21+
# Pixel sizes
22+
pixel_size_z = fov_z_um / pixels_z
23+
pixel_size_y = fov_y_um / pixels_y
24+
pixel_size_x = fov_x_um / pixels_x
25+
26+
# Create coordinate arrays with proper spacing
27+
z_coords = np.arange(pixels_z) * pixel_size_z
28+
y_coords = np.arange(pixels_y) * pixel_size_y
29+
x_coords = np.arange(pixels_x) * pixel_size_x
30+
coords = {"z": z_coords,
31+
"y": y_coords,
32+
"x": x_coords}
33+
34+
if add_chan:
35+
data = np.expand_dims(data, axis=0)
36+
coords["c"] = np.arange(data.shape[0])
37+
dims = ("c", "z", "y", "x")
38+
else:
39+
dims = ("z", "y", "x")
40+
41+
# Create xarray DataArray with improved metadata
42+
da = xr.DataArray(
43+
data,
44+
coords=coords,
45+
dims=dims,
46+
attrs={
47+
"pixel_size_um_z": pixel_size_z,
48+
"pixel_size_um_y": pixel_size_y,
49+
"pixel_size_um_x": pixel_size_x,
50+
"fov_um_z": fov_z_um,
51+
"fov_um_y": fov_y_um,
52+
"fov_um_x": fov_x_um,
53+
"units": "micrometers",
54+
}
55+
)
56+
return da
57+
58+
def get_zstacks_dict(zstacks_folder, channels=['gcamp', 'dextran']):
59+
zstacks_dict = {}
60+
61+
# Process only directories
62+
stack_dirs = [d for d in zstacks_folder.iterdir() if d.is_dir()]
63+
64+
for stack_ind, stack_folder in enumerate(stack_dirs):
65+
stack_info = {
66+
'stack_name': stack_folder.name,
67+
'stack_size': _extract_stack_size(stack_folder.name),
68+
'stack_channels': [ch for ch in channels if ch in stack_folder.name.lower()],
69+
'metadata_jsons': {'registration': None, 'roi_groups': None, 'scanimage': None},
70+
'channel_tifs': {}
71+
}
72+
73+
# Process files in stack folder
74+
chan_ind = 0
75+
for item in sorted(stack_folder.iterdir()):
76+
if item.is_file() and item.suffix.lower() == '.json':
77+
# Categorize JSON metadata files
78+
json_type = _categorize_json_file(item.name.lower())
79+
if json_type:
80+
stack_info['metadata_jsons'][json_type] = item
81+
82+
elif item.is_dir() and 'channel' in item.name.lower():
83+
# Process channel directories
84+
tif_files = [f for f in item.iterdir() if f.suffix.lower() == '.tif']
85+
stack_info['channel_tifs'][chan_ind] = {
86+
'chan_name': item.name,
87+
'chan_tif_path': tif_files[0] if len(tif_files) == 1 else tif_files
88+
}
89+
chan_ind += 1
90+
91+
zstacks_dict[stack_ind] = stack_info
92+
93+
return zstacks_dict
94+
95+
def _extract_stack_size(stack_name):
96+
"""Extract width x height x depth from stack name."""
97+
size_pattern = re.search(r'(\d+)x(\d+)x(\d+)', stack_name)
98+
if size_pattern:
99+
width, height, depth = map(int, size_pattern.groups())
100+
return {"width": width, "height": height, "depth": depth}
101+
return {"width": None, "height": None, "depth": None}
102+
103+
def _categorize_json_file(filename_lower):
104+
"""Categorize JSON file by its name."""
105+
if 'registration' in filename_lower:
106+
return 'registration'
107+
elif 'roi_groups' in filename_lower:
108+
return 'roi_groups'
109+
elif 'scanimage' in filename_lower:
110+
return 'scanimage'
111+
return None
112+
113+
def get_zstack(zstacks_dict, zstack_ind=None, zstack_name=None, stack_size=None, channels=None):
114+
if zstack_ind is not None:
115+
if zstack_ind not in zstacks_dict:
116+
raise ValueError(f"Z-stack index {zstack_ind} not found in zstacks_dict.")
117+
return zstacks_dict[zstack_ind]
118+
119+
# Helper function to find matches
120+
def _find_matches(criterion_func, criterion_name, criterion_value):
121+
matches = [i for i, stack in zstacks_dict.items() if criterion_func(stack)]
122+
123+
if not matches:
124+
raise ValueError(f"{criterion_name} {criterion_value} not found in zstacks_dict.")
125+
126+
if len(matches) == 1:
127+
return zstacks_dict[matches[0]]
128+
129+
# Handle multiple matches with optional channel filtering
130+
if channels is not None:
131+
channel_matches = [
132+
i for i in matches
133+
if set(zstacks_dict[i]['stack_channels']) == set(channels)
134+
]
135+
if len(channel_matches) == 1:
136+
return zstacks_dict[channel_matches[0]]
137+
elif len(channel_matches) > 1:
138+
raise ValueError(f"Multiple z-stacks found with {criterion_name} {criterion_value} and channels {channels}. Found {len(channel_matches)} matches.")
139+
else:
140+
raise ValueError(f"No z-stack found with {criterion_name} {criterion_value} and channels {channels}.")
141+
142+
raise ValueError(f"Multiple z-stacks found with {criterion_name} {criterion_value}. Found {len(matches)} matches. Consider specifying channels parameter.")
143+
144+
if zstack_name is not None:
145+
return _find_matches(
146+
lambda stack: stack['stack_name'] == zstack_name,
147+
"Z-stack name", zstack_name
148+
)
149+
150+
if stack_size is not None:
151+
return _find_matches(
152+
lambda stack: (
153+
stack['stack_size']['width'] == stack_size['width'] and
154+
stack['stack_size']['height'] == stack_size['height'] and
155+
stack['stack_size']['depth'] == stack_size['depth']
156+
),
157+
"Stack size", stack_size
158+
)
159+
160+
raise ValueError("Either zstack_ind, zstack_name, or stack_size must be provided.")
161+
162+
def get_alignment_data_paths(dataset_id,
163+
data_root=Path('/root/capsule/data'),
164+
scratch_root=Path('/root/capsule/scratch'),
165+
results_root=Path('/root/capsule/results'),
166+
code_root=Path('/root/capsule/code')):
167+
datasets_naming_dict_path = code_root / 'datasets_names_dict.json'
168+
with open(datasets_naming_dict_path) as f:
169+
datasets_naming_dict = json.load(f)
170+
dataset_id = str(dataset_id) # Ensure string format
171+
dataset_config = datasets_naming_dict[dataset_id]
172+
173+
paths = {
174+
"data_root": data_root,
175+
"scratch_root": scratch_root,
176+
"results_root": results_root,
177+
"sdata_path": data_root / f'{dataset_config["xenium_name"]}_processed',
178+
"confocal_path": data_root / dataset_config["confocal_name"],
179+
"zstack_path": data_root / dataset_config["zstack_name"],
180+
"zstack_masks": data_root / dataset_config["zstack_masks_name"]
181+
}
182+
183+
return paths
184+
185+
def get_zstack_sdata(stack, stack_masks=None, use_shared_coords=True):
186+
# Create the z-stack image array
187+
num_channels = len(stack['stack_channels'])
188+
chans = []
189+
if num_channels > 1:
190+
for ch_ind in range(num_channels):
191+
chan_array = create_zstack_array(tif_path=stack['channel_tifs'][ch_ind]['chan_tif_path'],
192+
fov_x_um=stack['stack_size']['width'],
193+
fov_y_um=stack['stack_size']['height'],
194+
fov_z_um=stack['stack_size']['depth'])
195+
chans.append(chan_array)
196+
zstack_img = xr.concat(chans, dim='c')
197+
zstack_img['c'] = stack['stack_channels']
198+
else:
199+
zstack_img = create_zstack_array(tif_path=stack['channel_tifs'][0]['chan_tif_path'],
200+
fov_x_um=stack['stack_size']['width'],
201+
fov_y_um=stack['stack_size']['height'],
202+
fov_z_um=stack['stack_size']['depth'])
203+
zstack_img['c'] = stack['stack_channels']
204+
205+
if use_shared_coords:
206+
reg_json_path = stack['metadata_jsons']['registration']
207+
with open(reg_json_path) as f:
208+
reg_json = json.load(f)
209+
if 'z_steps' in reg_json.keys() and len(reg_json['z_steps'])==zstack_img.sizes['z']:
210+
print("Using shared z coordinates for images")
211+
zstack_img.coords['z'] = reg_json['z_steps']
212+
213+
# Parse into Image3DModel
214+
zstack_img = Image3DModel.parse(
215+
zstack_img,
216+
dims=['c', 'z', 'y', 'x'],
217+
c_coords=stack['stack_channels'],
218+
chunks='auto',
219+
)
220+
221+
# Make the SpatialData object
222+
zstack_sdata = sd.SpatialData(
223+
images={'zstack': zstack_img},
224+
)
225+
226+
if stack_masks is not None:
227+
# Get labels for each channel
228+
for mask_ind, masks in zstack_masks['channel_tifs'].items():
229+
channel_name = zstack_masks['stack_channels'][mask_ind]
230+
zstack_label = create_zstack_array(tif_path=masks['chan_tif_path'],
231+
fov_x_um=stack_masks['stack_size']['width'],
232+
fov_y_um=stack_masks['stack_size']['height'],
233+
fov_z_um=stack_masks['stack_size']['depth'],
234+
add_chan=False)
235+
236+
if use_shared_coords:
237+
if 'z_steps' in reg_json.keys() and len(reg_json['z_steps'])==zstack_label.sizes['z']:
238+
print("Using shared z coordinates for labels")
239+
zstack_label.coords['z'] = reg_json['z_steps']
240+
241+
zstack_label = Labels3DModel.parse(
242+
zstack_label,
243+
dims=['z', 'y', 'x'],
244+
chunks='auto',
245+
)
246+
zstack_sdata.labels[f"{channel_name}_labels"] = zstack_label
247+
248+
# Determine pixel sizes
249+
if zstack_sdata['zstack'].attrs['pixel_size_um_x'] == zstack_sdata['zstack'].attrs['pixel_size_um_y']:
250+
pixel_size = zstack_sdata['zstack'].attrs['pixel_size_um_x']
251+
if zstack_sdata['zstack'].attrs['fov_um_z']==zstack_sdata['zstack'].shape[1]:
252+
z_step_size = 1
253+
254+
# Add micron coordinate system if not already present
255+
if 'microns' not in zstack_sdata.coordinate_systems:
256+
zstack_sdata = add_micron_coord_sys(zstack_sdata, pixel_size=pixel_size, z_step=z_step_size)
257+
else:
258+
print("Micron coordinate system already exists")
259+
return zstack_sdata
260+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
def get_vals_perc(img, chan, vmin_val=None, vmax_val=None, vmin_perc=None, vmax_perc=None, res_level=4):
2+
if res_level > len(sd.get_pyramid_levels(img))-1:
3+
res_level = len(sd.get_pyramid_levels(img))-1
4+
print(f'Using resolution level: {res_level}')
5+
if vmin_perc is not None or vmax_perc is not None:
6+
ch_vals = sd.get_pyramid_levels(img, n=res_level).sel(c=chan)
7+
vmin = None
8+
if vmin_val is not None:
9+
vmin = vmin_val
10+
elif vmin_perc is not None:
11+
vmin = np.percentile(ch_vals.values, vmin_perc)
12+
13+
vmax = None
14+
if vmax_val is not None:
15+
vmax = vmax_val
16+
elif vmax_perc is not None:
17+
vmax = np.percentile(ch_vals.values, vmax_perc)
18+
19+
return vmin, vmax

src/xenium_analysis_tools/utils/sd_utils.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,19 @@ def add_micron_coord_sys(sdata, pixel_size=None, z_step=None):
2828
identity = Identity()
2929

3030
# --- Images ---
31-
# dapi_zstack is (c, z, y, x)
32-
if 'dapi_zstack' in sdata.images:
33-
set_transformation(
34-
sdata.images['dapi_zstack'],
35-
scale_czyx,
36-
to_coordinate_system="microns"
37-
)
38-
39-
# morphology_focus is (c, y, x)
40-
if 'morphology_focus' in sdata.images:
41-
set_transformation(
42-
sdata.images['morphology_focus'],
43-
scale_cyx,
44-
to_coordinate_system="microns"
45-
)
31+
for image_name in sdata.images:
32+
if 'z' in sdata[image_name].dims:
33+
set_transformation(
34+
sdata.images[image_name],
35+
scale_czyx,
36+
to_coordinate_system="microns"
37+
)
38+
else:
39+
set_transformation(
40+
sdata.images[image_name],
41+
scale_cyx,
42+
to_coordinate_system="microns"
43+
)
4644

4745
# --- Labels ---
4846
# Both labels are (y, x)

0 commit comments

Comments
 (0)