Skip to content

Commit 55cb9ef

Browse files
committed
added coregistration helpers
1 parent 96b20a6 commit 55cb9ef

2 files changed

Lines changed: 237 additions & 1 deletion

File tree

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
import numpy as np
2+
import pandas as pd
3+
import spatialdata as sd
4+
import tifffile
5+
from IPython.display import display
6+
from tqdm.notebook import tqdm
7+
8+
from xenium_analysis_tools.utils.sd_utils import (
9+
add_micron_coord_sys,
10+
_is_multiscale,
11+
)
12+
13+
from spatialdata.transformations import (
14+
get_transformation,
15+
Scale,
16+
Identity,
17+
Sequence
18+
)
19+
20+
from xenium_analysis_tools.alignment.format_for_napari import (
21+
add_mapped_cells_cols,
22+
filter_cells,
23+
filter_labels,
24+
)
25+
26+
from spatialdata.models import TableModel
27+
28+
29+
def _rename_channel_coord(element_obj, channel_name_map=None):
30+
if channel_name_map is None:
31+
channel_name_map = {
32+
'DAPI': 'dapi',
33+
'ATP1A1/CD45/E-Cadherin': 'boundary',
34+
'18S': 'rna',
35+
'AlphaSMA/Vimentin': 'protein'
36+
}
37+
38+
if not hasattr(element_obj, 'coords'):
39+
return element_obj
40+
if 'c' not in element_obj.coords:
41+
return element_obj
42+
43+
old_names = [str(ch) for ch in element_obj.coords['c'].values]
44+
new_names = [channel_name_map.get(ch, ch) for ch in old_names]
45+
46+
if old_names == new_names:
47+
return element_obj
48+
49+
return element_obj.assign_coords(c=new_names)
50+
51+
52+
def add_mapping_results(sdata, mapping_output_path, table_el='table', section_n=None):
53+
import anndata as ad
54+
from xenium_analysis_tools.map_xenium.format_mapping import (
55+
add_broad_types,
56+
)
57+
mapped_data = ad.read_h5ad(mapping_output_path)
58+
if section_n is not None:
59+
mapped_data = mapped_data[mapped_data.obs['section'] == section_n]
60+
mapped_data.obs['cell_id'] = mapped_data.obs['original_cell_id']
61+
mapped_data.obs.rename_axis('cell_section_id', inplace=True)
62+
mapped_data.var.set_index('gene_symbol', inplace=True, drop=False)
63+
sdata[table_el] = add_mapped_cells_cols(sdata[table_el], mapped_data, verbose=True)
64+
sdata[table_el] = add_broad_types(sdata[table_el])
65+
return sdata
66+
67+
68+
def get_cell_labels(section_sdata, table_el='table', labels_el='cell_labels', multiscale_level=2, cell_filters=None):
69+
if 'microns' not in section_sdata.coordinate_systems:
70+
section_sdata = add_micron_coord_sys(section_sdata)
71+
72+
if _is_multiscale(section_sdata[labels_el]):
73+
label_da = sd.get_pyramid_levels(section_sdata[labels_el], n=multiscale_level)
74+
else:
75+
label_da = section_sdata[labels_el]
76+
77+
# Compute to numpy so filter_labels' in-place assignment works
78+
# (dask arrays silently ignore item assignment)
79+
import dask.array as da
80+
label_da = label_da.copy()
81+
label_da.data = da.from_array(label_da.data.compute())
82+
83+
table = section_sdata[table_el].copy()
84+
table.obs['region'] = pd.Categorical(['cell_labels'] * len(table), categories=['cell_labels'])
85+
table = TableModel.parse(table, region='cell_labels', region_key='region', instance_key='cell_labels', overwrite_metadata=True)
86+
87+
cell_labels_sd = sd.SpatialData(tables={table_el: table}, labels={labels_el: label_da})
88+
if cell_filters is not None:
89+
cell_labels_sd = filter_cells(cell_labels_sd, cell_filters=cell_filters)
90+
cell_labels_sd = filter_labels(cell_labels_sd)
91+
92+
return cell_labels_sd
93+
94+
95+
def extract_bigwarp_images_for_section(section_sdata,
96+
section_n,
97+
bigwarp_projects_folder,
98+
el_name = 'morphology_focus',
99+
subset_channels = 'all',
100+
multiscale_level = 2,
101+
dtype='uint16',
102+
normalize=False,
103+
z_step_um=None,
104+
resunit = 'cm',
105+
microns_coord_sys_name = 'microns',
106+
return_sdata=True):
107+
section_bigwarp_folder = bigwarp_projects_folder / f'section_{section_n}'
108+
section_bigwarp_folder.mkdir(exist_ok=True, parents=True)
109+
if not isinstance(section_sdata, sd.SpatialData):
110+
section_sdata = sd.read_zarr(section_sdata)
111+
section_sdata = add_micron_coord_sys(section_sdata)
112+
if _is_multiscale(section_sdata[el_name]):
113+
mf_element = sd.get_pyramid_levels(section_sdata[el_name], n=multiscale_level)
114+
else:
115+
mf_element = section_sdata[el_name]
116+
mf_element = _rename_channel_coord(mf_element)
117+
if subset_channels == 'all':
118+
subset_channels = mf_element.coords['c'].values
119+
120+
for ch in tqdm(subset_channels, desc=f'Extracting channels for section {section_n}'):
121+
out_path = section_bigwarp_folder / f'{ch}.tif'
122+
ch_el = mf_element.sel(c=ch)
123+
dims = ch_el.dims
124+
microns_tf = get_transformation(ch_el, to_coordinate_system=microns_coord_sys_name)
125+
126+
# Get scale factors for Y and X dimensions
127+
if isinstance(microns_tf, Scale):
128+
pixel_size_yx = [microns_tf.scale[microns_tf.axes.index(d)] for d in ['y', 'x']]
129+
elif isinstance(microns_tf, Identity):
130+
pixel_size_yx = [1.0, 1.0]
131+
elif isinstance(microns_tf, Sequence):
132+
py, px = 1.0, 1.0
133+
for tf in microns_tf.transformations:
134+
if isinstance(tf, Scale):
135+
py *= tf.scale[tf.axes.index('y')]
136+
px *= tf.scale[tf.axes.index('x')]
137+
pixel_size_yx = [py, px]
138+
else:
139+
pixel_size_yx = None
140+
print(f' Warning: unhandled transform type {type(microns_tf)} for channel {ch}, skipping calibration')
141+
142+
143+
arr = ch_el.data.compute()
144+
if np.issubdtype(np.dtype(dtype), np.integer) and normalize:
145+
finite = arr[np.isfinite(arr)]
146+
lo, hi = finite.min(), finite.max()
147+
if hi > lo:
148+
arr = (arr.astype(np.float64) - lo) / (hi - lo) * np.iinfo(dtype).max
149+
arr = np.nan_to_num(arr, nan=0.0)
150+
arr = np.clip(arr, 0, np.iinfo(dtype).max)
151+
arr = arr.astype(dtype, copy=False)
152+
153+
ij_meta = {'axes': ''.join(d.upper() for d in dims)}
154+
if 'z' in dims and z_step_um is not None:
155+
ij_meta['spacing'] = z_step_um
156+
if pixel_size_yx is not None:
157+
ij_meta['unit'] = 'um'
158+
159+
# Resolution tag: pixels per micron (for XY)
160+
kwargs = dict(imagej=True, metadata=ij_meta)
161+
if pixel_size_yx is not None:
162+
py, px = pixel_size_yx
163+
resolution_um = (1.0 / px, 1.0 / py)
164+
resolution_cm = (1e4 / px, 1e4 / py)
165+
kwargs['resolution'] = resolution_um if resunit == 'um' else resolution_cm
166+
kwargs['resolutionunit'] = tifffile.RESUNIT.MICROMETER if resunit == 'um' else tifffile.RESUNIT.CENTIMETER
167+
168+
tifffile.imwrite(str(out_path), arr, **kwargs)
169+
170+
if return_sdata:
171+
return section_sdata
172+
173+
def extract_bigwarp_label_for_section(
174+
labels_el,
175+
labels_name,
176+
section_n,
177+
bigwarp_projects_folder,
178+
microns_coord_sys_name='microns',
179+
dtype='uint8',
180+
binary=True,
181+
z_step_um=None,
182+
resunit='cm'
183+
):
184+
section_bigwarp_folder = bigwarp_projects_folder / f'section_{section_n}'
185+
out_path = section_bigwarp_folder / f'{labels_name}.tif'
186+
187+
# Get the labels DataArray — handle SpatialData objects and raw DataArrays
188+
if isinstance(labels_el, sd.SpatialData):
189+
el_name = list(labels_el.labels.keys())[0]
190+
el = labels_el.labels[el_name]
191+
else:
192+
el = labels_el
193+
if _is_multiscale(el):
194+
el = sd.get_pyramid_levels(el, n=2)
195+
196+
dims = el.dims
197+
microns_tf = get_transformation(el, to_coordinate_system=microns_coord_sys_name)
198+
199+
if isinstance(microns_tf, Scale):
200+
pixel_size_yx = [microns_tf.scale[microns_tf.axes.index(d)] for d in ['y', 'x']]
201+
elif isinstance(microns_tf, Identity):
202+
pixel_size_yx = [1.0, 1.0]
203+
elif isinstance(microns_tf, Sequence):
204+
py, px = 1.0, 1.0
205+
for tf in microns_tf.transformations:
206+
if isinstance(tf, Scale):
207+
py *= tf.scale[tf.axes.index('y')]
208+
px *= tf.scale[tf.axes.index('x')]
209+
pixel_size_yx = [py, px]
210+
else:
211+
pixel_size_yx = None
212+
print(f' Warning: unhandled transform type {type(microns_tf)}, skipping calibration')
213+
214+
arr = el.data.compute()
215+
if binary:
216+
arr = np.where(arr > 0, np.iinfo('uint16').max, 0).astype('uint16')
217+
else:
218+
arr = arr.astype(dtype, copy=False)
219+
220+
ij_meta = {'axes': ''.join(d.upper() for d in dims)}
221+
if 'z' in dims and z_step_um is not None:
222+
ij_meta['spacing'] = z_step_um
223+
if pixel_size_yx is not None:
224+
ij_meta['unit'] = 'um'
225+
226+
kwargs = dict(imagej=True, metadata=ij_meta)
227+
if pixel_size_yx is not None:
228+
py, px = pixel_size_yx
229+
resolution_cm = (1e4 / px, 1e4 / py)
230+
resolution_um = (1.0 / px, 1.0 / py)
231+
kwargs['resolution'] = resolution_um if resunit == 'um' else resolution_cm
232+
kwargs['resolutionunit'] = tifffile.RESUNIT.MICROMETER if resunit == 'um' else tifffile.RESUNIT.CENTIMETER
233+
234+
section_bigwarp_folder.mkdir(exist_ok=True, parents=True)
235+
tifffile.imwrite(str(out_path), arr, **kwargs)
236+
print(f' Wrote: {out_path.name} shape={arr.shape} binary={binary}')

src/xenium_analysis_tools/alignment/format_for_napari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import dask.dataframe as dd
1313
import matplotlib.pyplot as plt
1414
import matplotlib.colors as mcolors
15-
import napari
1615

1716
def filter_labels(sdata, label_elements='cell_labels', table='table', key_col='cell_labels'):
1817
# Get all label elements that match the specified prefix
@@ -904,6 +903,7 @@ def set_solid_label_color(sdata, table_key, color, col_name='label_color_group')
904903
table.uns[f'{col_name}_colors'] = np.array([color])
905904

906905
def apply_layer_style(layer, layer_styles):
906+
import napari
907907
def _find_style(layer_name):
908908
"""Return params for the longest matching key, or None."""
909909
matches = [k for k in layer_styles if k in layer_name]

0 commit comments

Comments
 (0)