|
| 1 | +import numpy as np |
| 2 | +import pandas as pd |
| 3 | +import spatialdata as sd |
| 4 | +import tifffile |
| 5 | +from IPython.display import display |
| 6 | +from tqdm.notebook import tqdm |
| 7 | + |
| 8 | +from xenium_analysis_tools.utils.sd_utils import ( |
| 9 | + add_micron_coord_sys, |
| 10 | + _is_multiscale, |
| 11 | +) |
| 12 | + |
| 13 | +from spatialdata.transformations import ( |
| 14 | + get_transformation, |
| 15 | + Scale, |
| 16 | + Identity, |
| 17 | + Sequence |
| 18 | +) |
| 19 | + |
| 20 | +from xenium_analysis_tools.alignment.format_for_napari import ( |
| 21 | + add_mapped_cells_cols, |
| 22 | + filter_cells, |
| 23 | + filter_labels, |
| 24 | +) |
| 25 | + |
| 26 | +from spatialdata.models import TableModel |
| 27 | + |
| 28 | + |
| 29 | +def _rename_channel_coord(element_obj, channel_name_map=None): |
| 30 | + if channel_name_map is None: |
| 31 | + channel_name_map = { |
| 32 | + 'DAPI': 'dapi', |
| 33 | + 'ATP1A1/CD45/E-Cadherin': 'boundary', |
| 34 | + '18S': 'rna', |
| 35 | + 'AlphaSMA/Vimentin': 'protein' |
| 36 | + } |
| 37 | + |
| 38 | + if not hasattr(element_obj, 'coords'): |
| 39 | + return element_obj |
| 40 | + if 'c' not in element_obj.coords: |
| 41 | + return element_obj |
| 42 | + |
| 43 | + old_names = [str(ch) for ch in element_obj.coords['c'].values] |
| 44 | + new_names = [channel_name_map.get(ch, ch) for ch in old_names] |
| 45 | + |
| 46 | + if old_names == new_names: |
| 47 | + return element_obj |
| 48 | + |
| 49 | + return element_obj.assign_coords(c=new_names) |
| 50 | + |
| 51 | + |
| 52 | +def add_mapping_results(sdata, mapping_output_path, table_el='table', section_n=None): |
| 53 | + import anndata as ad |
| 54 | + from xenium_analysis_tools.map_xenium.format_mapping import ( |
| 55 | + add_broad_types, |
| 56 | + ) |
| 57 | + mapped_data = ad.read_h5ad(mapping_output_path) |
| 58 | + if section_n is not None: |
| 59 | + mapped_data = mapped_data[mapped_data.obs['section'] == section_n] |
| 60 | + mapped_data.obs['cell_id'] = mapped_data.obs['original_cell_id'] |
| 61 | + mapped_data.obs.rename_axis('cell_section_id', inplace=True) |
| 62 | + mapped_data.var.set_index('gene_symbol', inplace=True, drop=False) |
| 63 | + sdata[table_el] = add_mapped_cells_cols(sdata[table_el], mapped_data, verbose=True) |
| 64 | + sdata[table_el] = add_broad_types(sdata[table_el]) |
| 65 | + return sdata |
| 66 | + |
| 67 | + |
| 68 | +def get_cell_labels(section_sdata, table_el='table', labels_el='cell_labels', multiscale_level=2, cell_filters=None): |
| 69 | + if 'microns' not in section_sdata.coordinate_systems: |
| 70 | + section_sdata = add_micron_coord_sys(section_sdata) |
| 71 | + |
| 72 | + if _is_multiscale(section_sdata[labels_el]): |
| 73 | + label_da = sd.get_pyramid_levels(section_sdata[labels_el], n=multiscale_level) |
| 74 | + else: |
| 75 | + label_da = section_sdata[labels_el] |
| 76 | + |
| 77 | + # Compute to numpy so filter_labels' in-place assignment works |
| 78 | + # (dask arrays silently ignore item assignment) |
| 79 | + import dask.array as da |
| 80 | + label_da = label_da.copy() |
| 81 | + label_da.data = da.from_array(label_da.data.compute()) |
| 82 | + |
| 83 | + table = section_sdata[table_el].copy() |
| 84 | + table.obs['region'] = pd.Categorical(['cell_labels'] * len(table), categories=['cell_labels']) |
| 85 | + table = TableModel.parse(table, region='cell_labels', region_key='region', instance_key='cell_labels', overwrite_metadata=True) |
| 86 | + |
| 87 | + cell_labels_sd = sd.SpatialData(tables={table_el: table}, labels={labels_el: label_da}) |
| 88 | + if cell_filters is not None: |
| 89 | + cell_labels_sd = filter_cells(cell_labels_sd, cell_filters=cell_filters) |
| 90 | + cell_labels_sd = filter_labels(cell_labels_sd) |
| 91 | + |
| 92 | + return cell_labels_sd |
| 93 | + |
| 94 | + |
| 95 | +def extract_bigwarp_images_for_section(section_sdata, |
| 96 | + section_n, |
| 97 | + bigwarp_projects_folder, |
| 98 | + el_name = 'morphology_focus', |
| 99 | + subset_channels = 'all', |
| 100 | + multiscale_level = 2, |
| 101 | + dtype='uint16', |
| 102 | + normalize=False, |
| 103 | + z_step_um=None, |
| 104 | + resunit = 'cm', |
| 105 | + microns_coord_sys_name = 'microns', |
| 106 | + return_sdata=True): |
| 107 | + section_bigwarp_folder = bigwarp_projects_folder / f'section_{section_n}' |
| 108 | + section_bigwarp_folder.mkdir(exist_ok=True, parents=True) |
| 109 | + if not isinstance(section_sdata, sd.SpatialData): |
| 110 | + section_sdata = sd.read_zarr(section_sdata) |
| 111 | + section_sdata = add_micron_coord_sys(section_sdata) |
| 112 | + if _is_multiscale(section_sdata[el_name]): |
| 113 | + mf_element = sd.get_pyramid_levels(section_sdata[el_name], n=multiscale_level) |
| 114 | + else: |
| 115 | + mf_element = section_sdata[el_name] |
| 116 | + mf_element = _rename_channel_coord(mf_element) |
| 117 | + if subset_channels == 'all': |
| 118 | + subset_channels = mf_element.coords['c'].values |
| 119 | + |
| 120 | + for ch in tqdm(subset_channels, desc=f'Extracting channels for section {section_n}'): |
| 121 | + out_path = section_bigwarp_folder / f'{ch}.tif' |
| 122 | + ch_el = mf_element.sel(c=ch) |
| 123 | + dims = ch_el.dims |
| 124 | + microns_tf = get_transformation(ch_el, to_coordinate_system=microns_coord_sys_name) |
| 125 | + |
| 126 | + # Get scale factors for Y and X dimensions |
| 127 | + if isinstance(microns_tf, Scale): |
| 128 | + pixel_size_yx = [microns_tf.scale[microns_tf.axes.index(d)] for d in ['y', 'x']] |
| 129 | + elif isinstance(microns_tf, Identity): |
| 130 | + pixel_size_yx = [1.0, 1.0] |
| 131 | + elif isinstance(microns_tf, Sequence): |
| 132 | + py, px = 1.0, 1.0 |
| 133 | + for tf in microns_tf.transformations: |
| 134 | + if isinstance(tf, Scale): |
| 135 | + py *= tf.scale[tf.axes.index('y')] |
| 136 | + px *= tf.scale[tf.axes.index('x')] |
| 137 | + pixel_size_yx = [py, px] |
| 138 | + else: |
| 139 | + pixel_size_yx = None |
| 140 | + print(f' Warning: unhandled transform type {type(microns_tf)} for channel {ch}, skipping calibration') |
| 141 | + |
| 142 | + |
| 143 | + arr = ch_el.data.compute() |
| 144 | + if np.issubdtype(np.dtype(dtype), np.integer) and normalize: |
| 145 | + finite = arr[np.isfinite(arr)] |
| 146 | + lo, hi = finite.min(), finite.max() |
| 147 | + if hi > lo: |
| 148 | + arr = (arr.astype(np.float64) - lo) / (hi - lo) * np.iinfo(dtype).max |
| 149 | + arr = np.nan_to_num(arr, nan=0.0) |
| 150 | + arr = np.clip(arr, 0, np.iinfo(dtype).max) |
| 151 | + arr = arr.astype(dtype, copy=False) |
| 152 | + |
| 153 | + ij_meta = {'axes': ''.join(d.upper() for d in dims)} |
| 154 | + if 'z' in dims and z_step_um is not None: |
| 155 | + ij_meta['spacing'] = z_step_um |
| 156 | + if pixel_size_yx is not None: |
| 157 | + ij_meta['unit'] = 'um' |
| 158 | + |
| 159 | + # Resolution tag: pixels per micron (for XY) |
| 160 | + kwargs = dict(imagej=True, metadata=ij_meta) |
| 161 | + if pixel_size_yx is not None: |
| 162 | + py, px = pixel_size_yx |
| 163 | + resolution_um = (1.0 / px, 1.0 / py) |
| 164 | + resolution_cm = (1e4 / px, 1e4 / py) |
| 165 | + kwargs['resolution'] = resolution_um if resunit == 'um' else resolution_cm |
| 166 | + kwargs['resolutionunit'] = tifffile.RESUNIT.MICROMETER if resunit == 'um' else tifffile.RESUNIT.CENTIMETER |
| 167 | + |
| 168 | + tifffile.imwrite(str(out_path), arr, **kwargs) |
| 169 | + |
| 170 | + if return_sdata: |
| 171 | + return section_sdata |
| 172 | + |
| 173 | +def extract_bigwarp_label_for_section( |
| 174 | + labels_el, |
| 175 | + labels_name, |
| 176 | + section_n, |
| 177 | + bigwarp_projects_folder, |
| 178 | + microns_coord_sys_name='microns', |
| 179 | + dtype='uint8', |
| 180 | + binary=True, |
| 181 | + z_step_um=None, |
| 182 | + resunit='cm' |
| 183 | +): |
| 184 | + section_bigwarp_folder = bigwarp_projects_folder / f'section_{section_n}' |
| 185 | + out_path = section_bigwarp_folder / f'{labels_name}.tif' |
| 186 | + |
| 187 | + # Get the labels DataArray — handle SpatialData objects and raw DataArrays |
| 188 | + if isinstance(labels_el, sd.SpatialData): |
| 189 | + el_name = list(labels_el.labels.keys())[0] |
| 190 | + el = labels_el.labels[el_name] |
| 191 | + else: |
| 192 | + el = labels_el |
| 193 | + if _is_multiscale(el): |
| 194 | + el = sd.get_pyramid_levels(el, n=2) |
| 195 | + |
| 196 | + dims = el.dims |
| 197 | + microns_tf = get_transformation(el, to_coordinate_system=microns_coord_sys_name) |
| 198 | + |
| 199 | + if isinstance(microns_tf, Scale): |
| 200 | + pixel_size_yx = [microns_tf.scale[microns_tf.axes.index(d)] for d in ['y', 'x']] |
| 201 | + elif isinstance(microns_tf, Identity): |
| 202 | + pixel_size_yx = [1.0, 1.0] |
| 203 | + elif isinstance(microns_tf, Sequence): |
| 204 | + py, px = 1.0, 1.0 |
| 205 | + for tf in microns_tf.transformations: |
| 206 | + if isinstance(tf, Scale): |
| 207 | + py *= tf.scale[tf.axes.index('y')] |
| 208 | + px *= tf.scale[tf.axes.index('x')] |
| 209 | + pixel_size_yx = [py, px] |
| 210 | + else: |
| 211 | + pixel_size_yx = None |
| 212 | + print(f' Warning: unhandled transform type {type(microns_tf)}, skipping calibration') |
| 213 | + |
| 214 | + arr = el.data.compute() |
| 215 | + if binary: |
| 216 | + arr = np.where(arr > 0, np.iinfo('uint16').max, 0).astype('uint16') |
| 217 | + else: |
| 218 | + arr = arr.astype(dtype, copy=False) |
| 219 | + |
| 220 | + ij_meta = {'axes': ''.join(d.upper() for d in dims)} |
| 221 | + if 'z' in dims and z_step_um is not None: |
| 222 | + ij_meta['spacing'] = z_step_um |
| 223 | + if pixel_size_yx is not None: |
| 224 | + ij_meta['unit'] = 'um' |
| 225 | + |
| 226 | + kwargs = dict(imagej=True, metadata=ij_meta) |
| 227 | + if pixel_size_yx is not None: |
| 228 | + py, px = pixel_size_yx |
| 229 | + resolution_cm = (1e4 / px, 1e4 / py) |
| 230 | + resolution_um = (1.0 / px, 1.0 / py) |
| 231 | + kwargs['resolution'] = resolution_um if resunit == 'um' else resolution_cm |
| 232 | + kwargs['resolutionunit'] = tifffile.RESUNIT.MICROMETER if resunit == 'um' else tifffile.RESUNIT.CENTIMETER |
| 233 | + |
| 234 | + section_bigwarp_folder.mkdir(exist_ok=True, parents=True) |
| 235 | + tifffile.imwrite(str(out_path), arr, **kwargs) |
| 236 | + print(f' Wrote: {out_path.name} shape={arr.shape} binary={binary}') |
0 commit comments