|
| 1 | +import spatialdata as sd |
| 2 | +from xenium_analysis_tools.utils.sd_utils import add_micron_coord_sys |
| 3 | +from spatialdata.models import Image3DModel, Labels3DModel |
| 4 | +from pathlib import Path |
| 5 | +import pandas as pd |
| 6 | +import numpy as np |
| 7 | +import xarray as xr |
| 8 | +import tifffile |
| 9 | +import json |
| 10 | +import re |
| 11 | + |
def create_zstack_array(tif_path,
                        add_chan=True,
                        fov_z_um=450.0,
                        fov_x_um=400.0,
                        fov_y_um=400.0):
    """Load a z-stack TIFF into an ``xarray.DataArray`` with micron-spaced coordinates.

    Parameters
    ----------
    tif_path : path-like
        TIFF file read with ``tifffile.imread``. The array must be 3D and is
        interpreted in (z, y, x) order.
    add_chan : bool, default True
        When True, a leading channel axis of length 1 is prepended and the
        result has dims ``("c", "z", "y", "x")``; otherwise ``("z", "y", "x")``.
    fov_z_um, fov_y_um, fov_x_um : float
        Field-of-view extent along each axis. The pixel size of an axis is the
        extent divided by the number of pixels along that axis.

    Returns
    -------
    xarray.DataArray
        The image data with coordinates ``index * pixel_size`` on z, y and x,
        and the pixel sizes and field-of-view extents stored in ``attrs``.

    Raises
    ------
    ValueError
        If the TIFF does not decode to a 3D array.
    """
    data = tifffile.imread(tif_path)

    # Fail with an explicit message instead of an opaque tuple-unpacking error
    # when the file holds a 2D image or a stack with extra axes.
    if data.ndim != 3:
        raise ValueError(
            f"Expected a 3D (z, y, x) stack in {tif_path}, "
            f"got an array with shape {data.shape}."
        )
    pixels_z, pixels_y, pixels_x = data.shape

    # Pixel sizes
    pixel_size_z = fov_z_um / pixels_z
    pixel_size_y = fov_y_um / pixels_y
    pixel_size_x = fov_x_um / pixels_x

    # Coordinates start at 0 and advance by one pixel size per index
    coords = {
        "z": np.arange(pixels_z) * pixel_size_z,
        "y": np.arange(pixels_y) * pixel_size_y,
        "x": np.arange(pixels_x) * pixel_size_x,
    }

    if add_chan:
        data = np.expand_dims(data, axis=0)
        coords["c"] = np.arange(data.shape[0])
        dims = ("c", "z", "y", "x")
    else:
        dims = ("z", "y", "x")

    return xr.DataArray(
        data,
        coords=coords,
        dims=dims,
        attrs={
            "pixel_size_um_z": pixel_size_z,
            "pixel_size_um_y": pixel_size_y,
            "pixel_size_um_x": pixel_size_x,
            "fov_um_z": fov_z_um,
            "fov_um_y": fov_y_um,
            "fov_um_x": fov_x_um,
            "units": "micrometers",
        },
    )
| 57 | + |
def get_zstacks_dict(zstacks_folder, channels=('gcamp', 'dextran')):
    """Index the z-stack sub-folders of ``zstacks_folder``.

    Parameters
    ----------
    zstacks_folder : str or pathlib.Path
        Folder whose immediate sub-directories are each treated as one z-stack.
    channels : iterable of str
        Candidate channel names. A name is listed for a stack when it occurs
        in the lower-cased stack folder name.

    Returns
    -------
    dict
        Maps an integer index to a dict with these keys:

        - ``'stack_name'``: the folder name.
        - ``'stack_size'``: width/height/depth parsed from the folder name.
        - ``'stack_channels'``: the matching channel names.
        - ``'metadata_jsons'``: paths to the registration / roi_groups /
          scanimage JSON files, or None when absent.
        - ``'channel_tifs'``: maps a channel index to ``'chan_name'`` and
          ``'chan_tif_path'``. The latter is a single path when the channel
          folder holds exactly one ``.tif``, otherwise the list of ``.tif``
          paths found (possibly empty).
    """
    zstacks_folder = Path(zstacks_folder)
    zstacks_dict = {}

    # iterdir() yields entries in arbitrary order; sort so that a stack keeps
    # the same index on every run and on every filesystem.
    stack_dirs = sorted(d for d in zstacks_folder.iterdir() if d.is_dir())

    for stack_ind, stack_folder in enumerate(stack_dirs):
        folder_name_lower = stack_folder.name.lower()
        stack_info = {
            'stack_name': stack_folder.name,
            'stack_size': _extract_stack_size(stack_folder.name),
            'stack_channels': [ch for ch in channels if ch in folder_name_lower],
            'metadata_jsons': {'registration': None, 'roi_groups': None, 'scanimage': None},
            'channel_tifs': {}
        }

        # Channel indices follow the sorted order of the channel directories
        chan_ind = 0
        for item in sorted(stack_folder.iterdir()):
            if item.is_file() and item.suffix.lower() == '.json':
                # Categorize JSON metadata files; unrecognized ones are ignored
                json_type = _categorize_json_file(item.name.lower())
                if json_type:
                    stack_info['metadata_jsons'][json_type] = item

            elif item.is_dir() and 'channel' in item.name.lower():
                # Sorted so that a multi-file result is deterministic as well
                tif_files = sorted(f for f in item.iterdir() if f.suffix.lower() == '.tif')
                stack_info['channel_tifs'][chan_ind] = {
                    'chan_name': item.name,
                    'chan_tif_path': tif_files[0] if len(tif_files) == 1 else tif_files
                }
                chan_ind += 1

        zstacks_dict[stack_ind] = stack_info

    return zstacks_dict
| 94 | + |
def _extract_stack_size(stack_name):
    """Parse the first ``<int>x<int>x<int>`` token of a stack name.

    Returns a dict with ``width``, ``height`` and ``depth`` as ints, taken in
    that order from the token. All three values are None when the name
    contains no such token.
    """
    keys = ("width", "height", "depth")
    found = re.search(r'(\d+)x(\d+)x(\d+)', stack_name)
    if found is None:
        return dict.fromkeys(keys)
    return {key: int(digits) for key, digits in zip(keys, found.groups())}
| 102 | + |
def _categorize_json_file(filename_lower):
    """Return the metadata category named in a lower-cased JSON file name.

    The categories are checked in the order 'registration', 'roi_groups',
    'scanimage'; the first one found as a substring is returned. None is
    returned when the name contains none of them.
    """
    for json_kind in ('registration', 'roi_groups', 'scanimage'):
        if json_kind in filename_lower:
            return json_kind
    return None
| 112 | + |
def get_zstack(zstacks_dict, zstack_ind=None, zstack_name=None, stack_size=None, channels=None):
    """Select one z-stack entry from ``zstacks_dict``.

    The selectors are tried in the order ``zstack_ind``, ``zstack_name``,
    ``stack_size``; the first one that is not None decides the lookup.
    ``channels`` is consulted only when a name or size lookup matches more
    than one stack, in which case the stack whose ``'stack_channels'`` equal
    ``channels`` as a set is returned.

    Raises
    ------
    ValueError
        If no selector is given, nothing matches, or the match stays ambiguous.
    """
    # Direct lookup by index
    if zstack_ind is not None:
        if zstack_ind in zstacks_dict:
            return zstacks_dict[zstack_ind]
        raise ValueError(f"Z-stack index {zstack_ind} not found in zstacks_dict.")

    # Pick the matching rule and the labels used in the error messages
    if zstack_name is not None:
        label, wanted = "Z-stack name", zstack_name

        def is_match(entry):
            return entry['stack_name'] == zstack_name
    elif stack_size is not None:
        label, wanted = "Stack size", stack_size

        def is_match(entry):
            return all(
                entry['stack_size'][axis] == stack_size[axis]
                for axis in ('width', 'height', 'depth')
            )
    else:
        raise ValueError("Either zstack_ind, zstack_name, or stack_size must be provided.")

    candidates = [key for key, entry in zstacks_dict.items() if is_match(entry)]

    if not candidates:
        raise ValueError(f"{label} {wanted} not found in zstacks_dict.")
    if len(candidates) == 1:
        return zstacks_dict[candidates[0]]

    # Several stacks match: channels are the only remaining tie-breaker
    if channels is None:
        raise ValueError(f"Multiple z-stacks found with {label} {wanted}. Found {len(candidates)} matches. Consider specifying channels parameter.")

    narrowed = [
        key for key in candidates
        if set(zstacks_dict[key]['stack_channels']) == set(channels)
    ]
    if len(narrowed) == 1:
        return zstacks_dict[narrowed[0]]
    if narrowed:
        raise ValueError(f"Multiple z-stacks found with {label} {wanted} and channels {channels}. Found {len(narrowed)} matches.")
    raise ValueError(f"No z-stack found with {label} {wanted} and channels {channels}.")
| 161 | + |
def get_alignment_data_paths(dataset_id,
                             data_root=Path('/root/capsule/data'),
                             scratch_root=Path('/root/capsule/scratch'),
                             results_root=Path('/root/capsule/results'),
                             code_root=Path('/root/capsule/code')):
    """Resolve the data locations used for aligning one dataset.

    The dataset's entry is looked up (by ``str(dataset_id)``) in
    ``datasets_names_dict.json`` under ``code_root``. The returned dict holds
    the three root folders as given, plus ``sdata_path``, ``confocal_path``,
    ``zstack_path`` and ``zstack_masks``, each built by joining ``data_root``
    with a name taken from that entry.

    Raises
    ------
    KeyError
        If the dataset id, or one of the expected name fields, is missing
        from the JSON file.
    """
    with open(code_root / 'datasets_names_dict.json') as handle:
        naming_table = json.load(handle)

    # JSON object keys are strings, so numeric ids are converted before lookup
    entry = naming_table[str(dataset_id)]

    paths = dict(
        data_root=data_root,
        scratch_root=scratch_root,
        results_root=results_root,
    )
    paths["sdata_path"] = data_root / f'{entry["xenium_name"]}_processed'
    paths["confocal_path"] = data_root / entry["confocal_name"]
    paths["zstack_path"] = data_root / entry["zstack_name"]
    paths["zstack_masks"] = data_root / entry["zstack_masks_name"]
    return paths
| 184 | + |
def get_zstack_sdata(stack, stack_masks=None, use_shared_coords=True):
    """Assemble a SpatialData object from a z-stack entry and optional label masks.

    Parameters
    ----------
    stack : dict
        Stack entry providing ``'stack_channels'``, ``'channel_tifs'``,
        ``'stack_size'`` and ``'metadata_jsons'`` (the layout built by
        ``get_zstacks_dict``). One TIFF is loaded per channel and the channels
        are concatenated along ``c``.
    stack_masks : dict, optional
        Entry with the same layout whose channel TIFFs hold label volumes.
        Each one is added to the result as ``"<channel name>_labels"``.
    use_shared_coords : bool, default True
        When True, the z coordinates of the image and labels are replaced by
        the ``'z_steps'`` list of the stack's registration JSON, provided that
        list has one value per z plane. If the stack has no registration JSON
        the computed coordinates are kept.

    Returns
    -------
    spatialdata.SpatialData
        Object with the image ``'zstack'``, any label elements, and the result
        of ``add_micron_coord_sys`` unless a ``'microns'`` coordinate system is
        already present.

    Raises
    ------
    ValueError
        If a micron coordinate system has to be added but the x and y pixel
        sizes differ, since a single in-plane pixel size is required.
    """
    channel_names = stack['stack_channels']
    fov = {
        'fov_x_um': stack['stack_size']['width'],
        'fov_y_um': stack['stack_size']['height'],
        'fov_z_um': stack['stack_size']['depth'],
    }

    # Create the z-stack image array (one TIFF per channel)
    if len(channel_names) > 1:
        zstack_img = xr.concat(
            [
                create_zstack_array(tif_path=stack['channel_tifs'][ch_ind]['chan_tif_path'], **fov)
                for ch_ind in range(len(channel_names))
            ],
            dim='c',
        )
    else:
        zstack_img = create_zstack_array(tif_path=stack['channel_tifs'][0]['chan_tif_path'], **fov)
    zstack_img['c'] = channel_names

    # Load the shared z positions once; they are reused for the labels below
    z_steps = None
    if use_shared_coords:
        reg_json_path = stack['metadata_jsons']['registration']
        if reg_json_path is None:
            print("No registration JSON found; keeping computed z coordinates")
        else:
            with open(reg_json_path) as f:
                z_steps = json.load(f).get('z_steps')

    if z_steps is not None and len(z_steps) == zstack_img.sizes['z']:
        print("Using shared z coordinates for images")
        zstack_img.coords['z'] = z_steps

    # Parse into Image3DModel
    zstack_img = Image3DModel.parse(
        zstack_img,
        dims=['c', 'z', 'y', 'x'],
        c_coords=channel_names,
        chunks='auto',
    )

    # Make the SpatialData object
    zstack_sdata = sd.SpatialData(
        images={'zstack': zstack_img},
    )

    if stack_masks is not None:
        # Get labels for each channel
        for mask_ind, masks in stack_masks['channel_tifs'].items():
            channel_name = stack_masks['stack_channels'][mask_ind]
            zstack_label = create_zstack_array(tif_path=masks['chan_tif_path'],
                                               fov_x_um=stack_masks['stack_size']['width'],
                                               fov_y_um=stack_masks['stack_size']['height'],
                                               fov_z_um=stack_masks['stack_size']['depth'],
                                               add_chan=False)

            if z_steps is not None and len(z_steps) == zstack_label.sizes['z']:
                print("Using shared z coordinates for labels")
                zstack_label.coords['z'] = z_steps

            zstack_label = Labels3DModel.parse(
                zstack_label,
                dims=['z', 'y', 'x'],
                chunks='auto',
            )
            zstack_sdata.labels[f"{channel_name}_labels"] = zstack_label

    # Add micron coordinate system if not already present
    if 'microns' in zstack_sdata.coordinate_systems:
        print("Micron coordinate system already exists")
        return zstack_sdata

    # Determine pixel sizes
    img = zstack_sdata['zstack']
    pixel_size = img.attrs['pixel_size_um_x']
    if pixel_size != img.attrs['pixel_size_um_y']:
        raise ValueError(
            f"Anisotropic in-plane pixel size (x={pixel_size}, "
            f"y={img.attrs['pixel_size_um_y']}) is not supported when adding the "
            "micron coordinate system."
        )
    if img.attrs['fov_um_z'] == img.shape[1]:
        z_step_size = 1
    else:
        # NOTE(review): falls back to the nominal z pixel size (fov / planes);
        # confirm this is the intended step when shared z coordinates are used.
        z_step_size = img.attrs['pixel_size_um_z']

    return add_micron_coord_sys(zstack_sdata, pixel_size=pixel_size, z_step=z_step_size)
| 260 | + |
0 commit comments