Skip to content

Commit 05af8e6

Browse files
committed
updated alignment module
1 parent 4356a5a commit 05af8e6

3 files changed

Lines changed: 140 additions & 69 deletions

File tree

src/xenium_analysis_tools/alignment/confocal_alignment.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,15 @@ def get_confocal_image_sizes(img_name, cf_raw_path, overlap=0.1):
3939
'physical_pixel_size_z': 1.0 # Placeholder if not in YAML
4040
}
4141

42-
def generate_confocal_sdata(zarr_path, raw_confocal_path=None):
42+
def generate_confocal_sdata(zarr_path, raw_confocal_path=None, select_scales=None):
4343
cf_name = zarr_path.stem
44-
cf_dt = create_datatree_from_zarr(zarr_path, chan_name=cf_name)
44+
cf_dt = create_datatree_from_zarr(zarr_path, chan_name=cf_name, select_scales=select_scales)
4545

4646
cf_sdata = sd.SpatialData(
4747
images={cf_name: cf_dt}
4848
)
4949

5050
if raw_confocal_path:
51-
# cf_sizes = get_confocal_image_sizes(cf_name, raw_confocal_path)
5251
cf_sizes = get_confocal_image_sizes(cf_name, raw_confocal_path)
5352
cf_sdata[cf_name].attrs.update(cf_sizes)
5453
if not cf_sizes['physical_pixel_size_x']==cf_sizes['physical_pixel_size_y']:
@@ -58,11 +57,14 @@ def generate_confocal_sdata(zarr_path, raw_confocal_path=None):
5857

5958
return cf_sdata
6059

61-
def create_datatree_from_zarr(zarr_path, chan_name='chan'):
60+
def create_datatree_from_zarr(zarr_path, chan_name='chan', select_scales=None):
6261
root = zarr.open_group(zarr_path, mode='r')
6362
data_tree_obj = xr.DataTree()
6463

6564
for scale_level in sorted(list(root.keys())):
65+
if scale_level != '0': #Have to have scale0 to determine sizes, but can drop later
66+
if select_scales is not None and scale_level not in select_scales:
67+
continue
6668
# Load the image data at this scale level
6769
level_array = da.from_zarr(str(zarr_path / scale_level))
6870
level_array = np.expand_dims(level_array, axis=0) # Add c dimension
@@ -87,9 +89,16 @@ def create_datatree_from_zarr(zarr_path, chan_name='chan'):
8789
set_transformation(data_tree_obj[scale_key].image, sequence, to_coordinate_system="global")
8890
else:
8991
set_transformation(data_tree_obj[scale_key].image, Identity(), to_coordinate_system="global")
90-
return data_tree_obj
9192

93+
if select_scales is not None and '0' not in select_scales:
94+
del data_tree_obj['scale0']
95+
if select_scales is not None: # rename scales
96+
for n, scale_level in enumerate(list(data_tree_obj.keys())):
97+
data_tree_obj[f'scale{n}'] = data_tree_obj[scale_level]
98+
data_tree_obj[f'scale{n}'].image.attrs[f'original_scale_level'] = scale_level
99+
del data_tree_obj[scale_level]
92100

101+
return data_tree_obj
93102

94103
# Copied from capsule 4 to keep track of overlap blending code
95104
def generate_fused_confocal_images(data_asset, overlap=0.1, img_layers=6):
@@ -176,4 +185,25 @@ def generate_fused_confocal_images(data_asset, overlap=0.1, img_layers=6):
176185
# Define a scaler for creating the image pyramid
177186
scaler = Scaler(method='nearest', max_layer=img_layers) # Create 4 levels in the pyramid
178187
# Write the image data with pyramid
179-
write_image(image, root, scaler=scaler, axes = 'zyx')
188+
write_image(image, root, scaler=scaler, axes = 'zyx')
189+
190+
def get_confocal_sdata(confocal_zarr_path, raw_confocal_path, select_scales=None):
191+
sdatas = []
192+
if 'deep' in [zarrs.stem for zarrs in list(confocal_zarr_path.iterdir()) if zarrs.suffix == '.zarr']:
193+
print("Generating sdata for deep confocal...")
194+
deep_sdata = generate_confocal_sdata(
195+
zarr_path = confocal_zarr_path / 'deep.zarr',
196+
raw_confocal_path = raw_confocal_path,
197+
select_scales=select_scales
198+
)
199+
sdatas.append(deep_sdata)
200+
if 'surface' in [zarrs.stem for zarrs in list(confocal_zarr_path.iterdir()) if zarrs.suffix == '.zarr']:
201+
print("Generating sdata for surface confocal...")
202+
surface_sdata = generate_confocal_sdata(
203+
zarr_path = confocal_zarr_path / 'surface.zarr',
204+
raw_confocal_path = raw_confocal_path,
205+
select_scales=select_scales
206+
)
207+
sdatas.append(surface_sdata)
208+
confocal_sdata = sd.concatenate(sdatas, merge_coord_systems=True)
209+
return confocal_sdata

src/xenium_analysis_tools/alignment/zstack_alignment.py

Lines changed: 99 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -26,77 +26,118 @@ def create_zstack_da(tif_path, name, add_chan=True, dims=("z", "y", "x"), fov_um
2626
da.attrs |= {f"scale_{d}": s for d, s in zip(dims, pixel_sizes)}
2727
return da
2828

29-
def get_zstack_sdata(zstack_path, zstack_masks_path, target_size, channel_names=['gcamp', 'dextran']):
30-
def get_matching_metas(root):
31-
matches = []
32-
for d in Path(root).iterdir():
33-
if not d.is_dir(): continue
34-
meta = _parse_stack_metadata(d) # Extracts size and tif paths
35-
if meta['size'] == target_size:
36-
matches.append(meta)
37-
return sorted(matches, key=lambda x: x['name']) # Sort ensures Channel 0 -> gcamp
38-
39-
img_metas = get_matching_metas(zstack_path)
40-
mask_metas = get_matching_metas(zstack_masks_path)
29+
def parse_stack_metadata(folder, lookup_chans=['gcamp', 'dextran']):
30+
"""Extracts size and specific channel name from naming convention."""
31+
# Pattern: matches dimensions like '400x400x450' (width x height x depth)
32+
pattern = r'(\d+)x(\d+)x(\d+)'
33+
match = re.search(pattern, folder.name)
34+
35+
if match:
36+
width, height, depth = match.groups()
37+
size = {"width": int(width), "height": int(height), "depth": int(depth)}
38+
# Detect which known channel names (lookup_chans) appear in the folder name
39+
detected_channels = [ch for ch in lookup_chans if ch in folder.name.lower()]
40+
else:
41+
size = {"width": None, "height": None, "depth": None}
42+
detected_channels = folder.name # Fallback
4143

42-
sz = img_metas[0]['size']
44+
tifs = {d.name.lower(): list(d.glob("*.tif"))[0]
45+
for d in folder.iterdir() if d.is_dir() and "channel" in d.name.lower()}
46+
jsons = {re.sub(r'.*_(registration|roi_groups|scanimage).*', r'\1', f.stem): f
47+
for f in folder.glob("*.json")}
48+
49+
return {"size": size, "detected_channels": detected_channels, "tifs": tifs, "jsons": jsons, "name": folder.name}
50+
51+
def get_zstack_elements(stack_folder, masks_folder, return_tables=False, chan_mapping=None, add_size_suffix=True):
52+
if chan_mapping is None:
53+
chan_mapping = {
54+
'channel_0_ref_0': 'gcamp',
55+
'channel_1_ref_1': 'dextran'
56+
}
57+
58+
# Get metadata for stacks
59+
stack_meta = parse_stack_metadata(stack_folder, lookup_chans=list(chan_mapping.values()))
60+
sz = stack_meta['size']
4361
fov = (sz['depth'], sz['height'], sz['width'])
44-
images, labels, tables = {}, {}, {}
4562

46-
for i, img_meta in enumerate(img_metas):
47-
chan_name = channel_names[i] if i < len(channel_names) else f"channel_{i}"
48-
49-
# Process Image
50-
img_tif = next(iter(img_meta['tifs'].values()))
51-
img_da = create_zstack_da(img_tif, chan_name, add_chan=True, fov_um=fov)
52-
images[chan_name] = Image3DModel.parse(img_da, chunks='auto')
63+
# Get stack channel images
64+
images = {}
65+
for chan, tif_path in stack_meta['tifs'].items():
66+
chan_name = chan_mapping.get(chan, chan)
67+
img_da = create_zstack_da(tif_path, chan_name, add_chan=True, fov_um=fov)
68+
img_da.attrs.update()
69+
images[chan_name] = Image3DModel.parse(img_da,
70+
c_coords=[chan_name],
71+
chunks='auto')
72+
73+
# If specified to keep names unique, add size suffix
74+
if add_size_suffix:
75+
size_suffix = f"{fov[0]}x{fov[1]}x{fov[2]}"
76+
images = {f"{name}_{size_suffix}": img for name, img in images.items()}
77+
78+
# Use name of stack to get corresponding masks
79+
img_name = stack_meta['name'].split('_registered')[0]
80+
all_masks = list(masks_folder.iterdir())
81+
matched_masks_path = [m for m in all_masks if img_name in m.name]
82+
if matched_masks_path:
83+
matched_masks_path = matched_masks_path[0] if len(matched_masks_path) == 1 else None
84+
if matched_masks_path is None:
85+
print(f"No matching mask found for {img_name} in {masks_folder}")
5386

54-
# Process Labels & Tables
55-
if i < len(mask_metas):
56-
mask_meta = mask_metas[i]
57-
mask_tif = next(iter(mask_meta['tifs'].values()))
58-
label_key = f"{chan_name}_labels"
59-
60-
mask_da = create_zstack_da(mask_tif, label_key, add_chan=False, fov_um=fov)
61-
labels[label_key] = Labels3DModel.parse(mask_da, chunks='auto')
62-
63-
# Table logic
87+
# Get mask metadata
88+
masks_meta = parse_stack_metadata(matched_masks_path, lookup_chans=list(chan_mapping.values()))
89+
90+
# Get labels and tables for each mask channel
91+
labels = {}
92+
tables = {}
93+
for chan, tif_path in masks_meta['tifs'].items():
94+
chan_name = chan_mapping.get(chan, chan)
95+
labels_name = f"{chan_name}_labels"
96+
if add_size_suffix:
97+
labels_name = f"{labels_name}_{size_suffix}"
98+
mask_da = create_zstack_da(tif_path, labels_name, add_chan=False, fov_um=fov)
99+
labels[labels_name] = Labels3DModel.parse(mask_da, chunks='auto')
100+
if return_tables:
101+
# Corresponding table
64102
unique_ids = np.unique(mask_da.values)
65103
unique_ids = unique_ids[unique_ids > 0]
66104
obs = pd.DataFrame(unique_ids, columns=[f"{chan_name}_id"], index=unique_ids.astype(str))
67-
obs['region'] = label_key
105+
obs['region'] = labels_name
68106
ann = ad.AnnData(obs=obs)
69-
tables[f"{chan_name}_cells"] = TableModel.parse(ann, region=label_key, region_key='region', instance_key=f"{chan_name}_id")
107+
table_name = f"{chan_name}_table"
108+
if add_size_suffix:
109+
table_name = f"{table_name}_{size_suffix}"
110+
tables[table_name] = TableModel.parse(ann, region=labels_name, region_key='region', instance_key=f"{chan_name}_id")
111+
112+
return images, labels, tables
70113

71-
sdata = sd.SpatialData(images=images, labels=labels, tables=tables)
114+
def get_zstacks_sdata(stacks_folder, masks_folder, return_tables=False, chan_mapping=None):
115+
if chan_mapping is None:
116+
chan_mapping = {
117+
'channel_0_ref_0': 'gcamp',
118+
'channel_1_ref_1': 'dextran'
119+
}
120+
all_stacks = list(stacks_folder.iterdir())
121+
combined_images = {}
122+
combined_labels = {}
123+
combined_tables = {}
124+
if len(all_stacks) > 1:
125+
add_size_suffix = True
126+
else:
127+
add_size_suffix = False
128+
for zstack_folder in all_stacks:
129+
images, labels, tables = get_zstack_elements(zstack_folder, masks_folder, return_tables=return_tables, chan_mapping=chan_mapping, add_size_suffix=add_size_suffix)
130+
combined_images.update(images)
131+
combined_labels.update(labels)
132+
if tables:
133+
combined_tables.update(tables)
72134

135+
# Combine into SpatialData
136+
sdata = sd.SpatialData(images=combined_images, labels=combined_labels, tables=combined_tables)
73137
# Apply Transformations
74138
for el_type in ['images', 'labels']:
75139
for name, el in getattr(sdata, el_type).items():
76140
set_transformation(el, Identity(), "global")
77141
scale = Scale([el.attrs[f"scale_{d}"] for d in ['z', 'y', 'x']], axes=('z', 'y', 'x'))
78142
set_transformation(el, scale, "microns")
79-
80-
return sdata
81-
82-
def _parse_stack_metadata(folder):
83-
"""Extracts size and specific channel name from the Allen Institute folder naming convention."""
84-
# Pattern: Matches '400x400x450' and captures the following word (e.g., GCaMP)
85-
pattern = r'(\d+)x(\d+)x(\d+)-([^_]+)'
86-
match = re.search(pattern, folder.name)
87-
88-
if match:
89-
width, height, depth, channel = match.groups()
90-
size = {"width": int(width), "height": int(height), "depth": int(depth)}
91-
# Normalize channel name (e.g., GCaMP -> gcamp)
92-
detected_channel = channel.lower()
93-
else:
94-
size = {"width": None, "height": None, "depth": None}
95-
detected_channel = folder.name # Fallback
96-
97-
tifs = {d.name.lower(): list(d.glob("*.tif"))[0]
98-
for d in folder.iterdir() if d.is_dir() and "channel" in d.name.lower()}
99-
jsons = {re.sub(r'.*_(registration|roi_groups|scanimage).*', r'\1', f.stem): f
100-
for f in folder.glob("*.json")}
101-
102-
return {"size": size, "detected_channel": detected_channel, "tifs": tifs, "jsons": jsons, "name": folder.name}
143+
return sdata

src/xenium_analysis_tools/utils/sd_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ def get_dataset_paths(dataset_id,
4040
"scratch_root": scratch_root,
4141
"results_root": results_root,
4242
"xenium_dataset_name": dataset_config.get("xenium_name", None),
43-
"sdata_path": data_root / f'{dataset_config.get("xenium_name", None)}_processed',
44-
"confocal_path": data_root / dataset_config.get("confocal_name", None),
45-
"raw_confocal_path": data_root / dataset_config.get("raw_confocal_name", None),
46-
"zstack_path": data_root / dataset_config.get("zstack_name", None),
47-
"zstack_masks": data_root / dataset_config.get("zstack_masks_name", None),
43+
"sdata_path": data_root / f'{dataset_config["xenium_name"]}_processed' if dataset_config.get("xenium_name") else None,
44+
"confocal_path": data_root / dataset_config["confocal_name"] if dataset_config.get("confocal_name") else None,
45+
"raw_confocal_path": data_root / dataset_config["raw_confocal_name"] if dataset_config.get("raw_confocal_name") else None,
46+
"zstack_path": data_root / dataset_config["zstack_name"] if dataset_config.get("zstack_name") else None,
47+
"zstack_masks": data_root / dataset_config["zstack_masks_name"] if dataset_config.get("zstack_masks_name") else None,
4848
}
4949

5050
return paths

0 commit comments

Comments
 (0)