Skip to content

Commit 8612f6f

Browse files
committed
added plot utils, updated alignment images, xoa server functions
1 parent 9c2cfdf commit 8612f6f

4 files changed

Lines changed: 246 additions & 44 deletions

File tree

src/xenium_analysis_tools/alignment/generate_images.py

Lines changed: 117 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import spatialdata as sd
22
from xenium_analysis_tools.utils.sd_utils import add_micron_coord_sys
3-
from spatialdata.models import Image3DModel, Labels3DModel, PointsModel
3+
from spatialdata.models import Image2DModel, Image3DModel, Labels3DModel, Labels2DModel, PointsModel, ShapesModel, TableModel
4+
from spatialdata.transformations import get_transformation, set_transformation
5+
import anndata as ad
6+
from spatialdata import get_element_instances
47
from pathlib import Path
58
import pandas as pd
69
import numpy as np
@@ -173,6 +176,7 @@ def get_alignment_data_paths(dataset_id,
173176
"data_root": data_root,
174177
"scratch_root": scratch_root,
175178
"results_root": results_root,
179+
"xenium_dataset_name": dataset_config["xenium_name"],
176180
"sdata_path": data_root / f'{dataset_config["xenium_name"]}_processed',
177181
"confocal_path": data_root / dataset_config["confocal_name"],
178182
"zstack_path": data_root / dataset_config["zstack_name"],
@@ -187,43 +191,31 @@ def get_label_params(label_obj, id_name='cell'):
187191
props = regionprops(labels)
188192
data = [
189193
{f'{id_name}_id': p.label,
190-
'z': p.centroid[0],
191-
'y': p.centroid[1],
192-
'x': p.centroid[2],
194+
'z': p.centroid[0] if len(p.centroid)==3 else None,
195+
'y': p.centroid[1] if len(p.centroid)==3 else p.centroid[0],
196+
'x': p.centroid[2] if len(p.centroid)==3 else p.centroid[1],
193197
'area': p.area,
194198
'bbox': p.bbox}
195199
for p in props
196200
]
197201
df = pd.DataFrame(data)
198202
return df
199203

200-
def get_zstack_sdata(stack, zstack_masks=None, get_centroids_as_points=True):
204+
def get_zstack_sdata(stack, zstack_masks=None):
201205
# Create the z-stack image array
202-
num_channels = len(stack['zstack_channels'])
203-
chans = []
204-
if num_channels > 1:
205-
for ch_ind in range(num_channels):
206-
chan_array = create_zstack_array(tif_path=stack['channel_tifs'][ch_ind]['chan_tif_path'],
206+
chan_arrays = {}
207+
for ch_ind, chan_name in enumerate(stack['zstack_channels']):
208+
chan_img = create_zstack_array(tif_path=stack['channel_tifs'][ch_ind]['chan_tif_path'],
207209
fov_x_um=stack['zstack_size']['width'],
208210
fov_y_um=stack['zstack_size']['height'],
209211
fov_z_um=stack['zstack_size']['depth'])
210-
chans.append(chan_array)
211-
zstack_img = xr.concat(chans, dim='c')
212-
zstack_img['c'] = stack['zstack_channels']
213-
else:
214-
zstack_img = create_zstack_array(tif_path=stack['channel_tifs'][0]['chan_tif_path'],
215-
fov_x_um=stack['zstack_size']['width'],
216-
fov_y_um=stack['zstack_size']['height'],
217-
fov_z_um=stack['zstack_size']['depth'])
218-
zstack_img['c'] = stack['zstack_channels']
219-
220-
# Parse into Image3DModel
221-
zstack_img = Image3DModel.parse(
222-
zstack_img,
212+
chan_img = Image3DModel.parse(
213+
chan_img,
223214
dims=['c', 'z', 'y', 'x'],
224-
c_coords=stack['zstack_channels'],
215+
c_coords=chan_name,
225216
chunks='auto',
226217
)
218+
chan_arrays[chan_name] = chan_img
227219

228220
if zstack_masks is not None:
229221
zstack_labels = {}
@@ -242,28 +234,32 @@ def get_zstack_sdata(stack, zstack_masks=None, get_centroids_as_points=True):
242234
chunks='auto',
243235
)
244236
zstack_labels[f"{channel_name}_labels"] = zstack_label
245-
246-
if get_centroids_as_points:
247-
zstack_points = {}
248-
# Get label parameters add as points
249-
for label_name, labels_obj in zstack_labels.items():
250-
chan_name = label_name.replace('_labels','')
251-
cells_df = get_label_params(labels_obj, id_name=chan_name)
252-
print(f"# {chan_name} segmented cells: {len(cells_df)}")
253-
cells_df = PointsModel.parse(cells_df)
254-
zstack_points[f"{chan_name}_cells"] = cells_df
237+
238+
tables = {}
239+
for label_name, labels_obj in zstack_labels.items():
240+
chan_name = label_name.replace('_labels','')
241+
label_type_id = f'{chan_name}_id'
242+
chan_label_ids = get_element_instances(labels_obj).values
243+
obs = pd.DataFrame(chan_label_ids, columns=[label_type_id])
244+
cells_df = get_label_params(labels_obj, id_name=chan_name)
245+
cells_df['region'] = label_name
246+
obs = obs.merge(cells_df, left_on=label_type_id, right_on=chan_name+'_id', how='left')
247+
table = ad.AnnData(obs=obs, obsm={'spatial': obs[['z','y','x']].values})
248+
table = TableModel.parse(table, region=label_name, region_key='region', instance_key=label_type_id)
249+
tables[f'{chan_name}_cells'] = table
255250

256251
# Assemble SpatialData
257252
zstack_sdata = sd.SpatialData(
258-
images={'zstack': zstack_img},
253+
images={**chan_arrays},
259254
labels={**zstack_labels} if zstack_masks is not None else {},
260-
points={**zstack_points} if (zstack_masks is not None and get_centroids_as_points) else {}
255+
tables={**tables} if zstack_masks is not None else {}
261256
)
262257

263258
# Determine pixel sizes
264-
if zstack_sdata['zstack'].attrs['pixel_size_um_x'] == zstack_sdata['zstack'].attrs['pixel_size_um_y']:
265-
pixel_size = zstack_sdata['zstack'].attrs['pixel_size_um_x']
266-
if zstack_sdata['zstack'].attrs['fov_um_z']==zstack_sdata['zstack'].shape[1]:
259+
zstack_chan = zstack_sdata[stack['zstack_channels'][0]] # Use first channel for pixel size reference
260+
if zstack_chan.attrs['pixel_size_um_x'] == zstack_chan.attrs['pixel_size_um_y']:
261+
pixel_size = zstack_chan.attrs['pixel_size_um_x']
262+
if zstack_chan.attrs['fov_um_z']==zstack_chan.shape[1]:
267263
z_step_size = 1
268264

269265
# Add micron coordinate system if not already present
@@ -273,3 +269,85 @@ def get_zstack_sdata(stack, zstack_masks=None, get_centroids_as_points=True):
273269
print("Micron coordinate system already exists")
274270
return zstack_sdata
275271

272+
def get_alignment_spatial_elements(sdata, scale_from_level=2, channel_names=('dapi', 'boundary', 'rna', 'protein')):
    """Extract single-scale images and labels for alignment, carrying over their 'global' transforms.

    For each multi-scale element, the requested pyramid level is taken,
    re-parsed with the matching spatialdata model, and given the transformation
    that level had to the 'global' coordinate system.

    Technically only the morphology focus transforms should need replacing, but
    this is done for all elements just in case.

    Args:
        sdata: Object indexable by element name; must provide 'dapi_zstack',
            'morphology_focus', 'cell_labels' and 'nucleus_labels', each with a
            ``scale{N}`` level exposing ``.image``.
        scale_from_level: Pyramid level to extract (``scale{scale_from_level}``).
        channel_names: Names given to the morphology focus planes, in order.
            Plane ``i`` (index along the first axis, presumably the channel
            axis) is stored under ``channel_names[i]``.

    Returns:
        Tuple ``(images, labels)``: ``images`` maps 'dapi_zstack' plus each
        channel name to a parsed image; ``labels`` maps 'cell_labels' and
        'nucleus_labels' to parsed 2D labels.
    """
    scale_key = f'scale{scale_from_level}'

    # --- Images ---
    # DAPI z-stack
    dapi_level = sdata['dapi_zstack'][scale_key].image
    dapi_global_tf = get_transformation(dapi_level, to_coordinate_system='global')
    dapi_zstack = Image3DModel.parse(
        dapi_level,
        dims=['c', 'z', 'y', 'x'],
        c_coords=['DAPI'],
        chunks='auto',
    )
    set_transformation(dapi_zstack, dapi_global_tf, to_coordinate_system='global')

    # Morphology focus channels: one single-channel 2D image per name
    mf_level = sdata['morphology_focus'][scale_key].image
    mf_global_tf = get_transformation(mf_level, to_coordinate_system='global')
    chans_arrays = {}
    for chan_ind, chan in enumerate(channel_names):
        # Re-add the leading axis that integer indexing dropped
        plane = np.expand_dims(mf_level[chan_ind].data, axis=0)
        chan_img = Image2DModel.parse(
            plane,
            dims=['c', 'y', 'x'],
            c_coords=chan,
            chunks='auto',
        )
        set_transformation(chan_img, mf_global_tf, to_coordinate_system='global')
        chans_arrays[chan] = chan_img

    images = {'dapi_zstack': dapi_zstack, **chans_arrays}

    # --- Labels ---
    cell_labels_level = sdata['cell_labels'][scale_key].image
    cell_labels_tf = get_transformation(cell_labels_level, to_coordinate_system='global')
    cell_labels = Labels2DModel.parse(cell_labels_level, dims=['y', 'x'], chunks='auto')
    set_transformation(cell_labels, cell_labels_tf, to_coordinate_system='global')

    nucleus_labels_level = sdata['nucleus_labels'][scale_key].image
    nucleus_labels = Labels2DModel.parse(nucleus_labels_level, dims=['y', 'x'], chunks='auto')
    # NOTE(review): the nucleus labels reuse the *cell* labels transform, as in
    # the original code - presumably both label pyramids share one transform;
    # confirm, or fetch the nucleus transform separately.
    set_transformation(nucleus_labels, cell_labels_tf, to_coordinate_system='global')

    labels = {
        'cell_labels': cell_labels,
        'nucleus_labels': nucleus_labels,
    }

    return images, labels
317+
318+
def get_alignment_shapes_tables(sdata,
                                transcripts_qv_thresh=20,
                                annotate_spatial_elements='cell_boundaries',
                                cell_id_name='cell_id',
                                mask_id_name='cell_labels'):
    """Build the table, transcripts and shapes used for alignment, keyed by mask label.

    Side effects: the shapes element ``sdata[annotate_spatial_elements]`` is
    modified in place (two columns added, index replaced), and the parsed
    result is written back into ``sdata.shapes``.

    Args:
        sdata: Object providing 'table', 'transcripts' and the shapes element
            named by ``annotate_spatial_elements``.
        transcripts_qv_thresh: Transcripts with ``qv`` below this are dropped.
        annotate_spatial_elements: Name of the shapes element the table annotates.
        cell_id_name: Column holding the cell id in the table obs and transcripts.
        mask_id_name: Column holding the mask label in the table obs; also the
            name of the column added to transcripts and shapes.

    Returns:
        Tuple ``(table, transcripts, shapes)`` where ``shapes`` is ``sdata.shapes``.
    """
    # Lookup from cell id to mask label, taken from the table's obs
    table_obs = sdata['table'].obs
    id_to_label = dict(zip(table_obs[cell_id_name].values, table_obs[mask_id_name].values))

    # Transcripts: keep only those at or above the quality threshold
    transcripts = sdata['transcripts'].compute()
    passes_qv = transcripts['qv'] >= transcripts_qv_thresh
    transcripts = transcripts[passes_qv]
    # Transcripts whose cell id has no entry in the lookup get label 0
    transcripts[mask_id_name] = transcripts[cell_id_name].map(id_to_label).fillna(0).astype('int64')

    # Shapes: expose the current index as the cell id, attach the mask label,
    # then index by mask label (kept as a column too)
    shape_element = sdata[annotate_spatial_elements]
    shape_element[cell_id_name] = shape_element.index.values
    shape_element[mask_id_name] = shape_element[cell_id_name].map(id_to_label).values
    shape_element.set_index(mask_id_name, inplace=True, drop=False)

    # Table: point the annotation at the shapes element, keyed by mask label
    table = sdata['table'].copy()
    table.obs['region'] = annotate_spatial_elements
    table.obs['region'] = pd.Categorical(table.obs['region'])
    table.uns['spatialdata_attrs'].update({
        'region_key': 'region',
        'region': [annotate_spatial_elements],
        'instance_key': mask_id_name
    })

    # Validate everything through the spatialdata models
    shapes = sdata.shapes
    shapes[annotate_spatial_elements] = ShapesModel.parse(shape_element)
    table = TableModel.parse(table)
    transcripts = PointsModel.parse(transcripts)

    return table, transcripts, shapes

src/xenium_analysis_tools/utils/plot_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,22 @@ def get_vals_perc(img, chan, vmin_val=None, vmax_val=None, vmin_perc=None, vmax_
1616
elif vmax_perc is not None:
1717
vmax = np.percentile(ch_vals.values, vmax_perc)
1818

19-
return vmin, vmax
19+
return vmin, vmax
20+
21+
def get_channel_name(chan, print_chan_names_only=False):
    """Translate a channel alias into the canonical channel label.

    Matching is a case-insensitive substring test: the first canonical label
    with an alias contained in ``chan`` wins.

    Args:
        chan: Channel name or alias (e.g. ``'dapi'``, ``'RNA'``). Ignored when
            ``print_chan_names_only`` is true.
        print_chan_names_only: If true, print the canonical channel labels
            this function knows about and return ``None``.

    Returns:
        The canonical label for a recognised alias, ``chan`` unchanged when no
        alias matches, or ``None`` in print-only mode.
    """
    channel_aliases = {
        'DAPI': ['dapi', 'nuclear'],
        'ATP1A1/CD45/E-Cadherin': ['boundary'],
        # Was the single string 'rna, RNA', which could never match 'rna'
        '18S': ['rna', 'RNA'],
        'AlphaSMA/Vimentin': ['protein'],
    }
    if print_chan_names_only:
        # The previous implementation read undefined names here
        # (section_sdata / image_name); list the labels known locally instead.
        print('Available channel names:')
        for name in channel_aliases:
            print(f' - {name}')
        return None

    needle = chan.lower()
    for chan_label, aliases in channel_aliases.items():
        if any(alias.lower() in needle for alias in aliases):
            return chan_label
    return chan

src/xenium_analysis_tools/utils/sd_utils.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
import spatialdata as sd
2-
from spatialdata.transformations import Scale, Identity, set_transformation
1+
2+
from spatialdata.transformations import Scale, Translation, Sequence, Identity, set_transformation
3+
import xarray as xr
4+
import numpy as np
5+
import pandas as pd
36

47
def add_micron_coord_sys(sdata, pixel_size=None, z_step=None):
58
# Define the pixel scaling factor
@@ -29,7 +32,8 @@ def add_micron_coord_sys(sdata, pixel_size=None, z_step=None):
2932

3033
# --- Images ---
3134
for image_name in sdata.images:
32-
if 'z' in sdata[image_name].dims:
35+
dims = sdata[image_name].dims if not isinstance(sdata[image_name], xr.core.datatree.DataTree) else sdata[image_name]['scale0'].dims
36+
if 'z' in dims:
3337
set_transformation(
3438
sdata.images[image_name],
3539
scale_czyx,
@@ -66,4 +70,55 @@ def add_micron_coord_sys(sdata, pixel_size=None, z_step=None):
6670
identity,
6771
to_coordinate_system="microns"
6872
)
69-
return sdata
73+
return sdata
74+
75+
def _left_join_new_columns(target_df, source_df):
    """Left-join onto ``target_df`` the columns only ``source_df`` has; return ``None`` if there are none."""
    new_cols = np.setdiff1d(source_df.columns, target_df.columns)
    if len(new_cols) == 0:
        print("No new columns to add from mapped data")
        return None
    print(f"Adding {len(new_cols)} columns from mapped data: {new_cols}")
    # A left join keeps target_df's rows and their order. The previous outer
    # join sorts the keys and adds rows for keys found only in source_df, which
    # can silently misalign the annotation frame with the data matrix.
    return target_df.merge(
        source_df[new_cols],
        left_index=True,
        right_index=True,
        how='left'
    )


def add_mapped_cells_cols(sdata, mapped_h5ad_path):
    """Copy obs/var columns from a mapped ``.h5ad`` file into ``sdata['table']``.

    Only columns not already present in the table are added; rows are matched
    on the index. Table rows missing from the mapped file get NaN in the new
    columns, and mapped rows missing from the table are ignored.

    Args:
        sdata: Object whose ``sdata['table']`` has ``obs`` and ``var`` frames.
        mapped_h5ad_path: Path of the ``.h5ad`` file read with scanpy.

    Returns:
        The same ``sdata``, with its table's ``obs``/``var`` updated in place.
    """
    import scanpy as sc
    mapped_h5ad = sc.read_h5ad(mapped_h5ad_path)
    table = sdata['table']

    merged_obs = _left_join_new_columns(table.obs, mapped_h5ad.obs)
    if merged_obs is not None:
        table.obs = merged_obs

    merged_var = _left_join_new_columns(table.var, mapped_h5ad.var)
    if merged_var is not None:
        table.var = merged_var

    return sdata
101+
102+
def get_transcripts_bboxes(transcripts, id_col='cell_labels'):
    """Compute an integer 3D bounding box of the transcripts of each cell label.

    Args:
        transcripts: Table with columns ``id_col``, 'z', 'y' and 'x'. If it has
            a ``compute`` method (e.g. a lazy dataframe), that is called first.
        id_col: Column identifying which cell each transcript belongs to.

    Returns:
        Dict mapping each cell label to
        ``(z_min, y_min, x_min, z_max, y_max, x_max)``, with minima floored and
        maxima ceiled to ints. Label 0 (background / unmapped) is skipped. An
        empty input gives an empty dict.
    """
    if hasattr(transcripts, 'compute'):
        transcripts = transcripts.compute()
    if transcripts.shape[0] == 0:
        return {}

    # Per-label extent along each axis, rounded outwards in one vectorised pass
    per_label = transcripts.groupby(id_col)[['z', 'y', 'x']]
    lower = np.floor(per_label.min()).to_numpy().astype(int).tolist()
    upper = np.ceil(per_label.max()).to_numpy().astype(int).tolist()

    cell_label_bboxes = {}
    for cell_label, mins, maxs in zip(per_label.min().index, lower, upper):
        # Skip background / unmapped label if present
        if cell_label == 0:
            continue
        cell_label_bboxes[cell_label] = (*mins, *maxs)
    return cell_label_bboxes
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import http.server
2+
import socketserver
3+
import threading
4+
import socket
5+
from pathlib import Path
6+
7+
def find_free_port(start_port=8000, max_port=8100):
    """Return the first port in ``[start_port, max_port)`` that can be bound.

    Each candidate is probed by binding a throwaway TCP socket to it; the
    socket is closed again before returning.

    Raises:
        RuntimeError: If no candidate port could be bound.
    """
    candidate = start_port
    while candidate < max_port:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('', candidate))
        except OSError:
            # Port is taken (or otherwise unusable); try the next one
            candidate += 1
            continue
        return candidate
    raise RuntimeError(f"No free port found between {start_port} and {max_port}")
17+
18+
def start_server(directory, port=None):
    """Start a simple HTTP server for ``directory`` in a background thread.

    The directory is handed to the request handler directly, so the
    process-wide working directory is left untouched.

    Args:
        directory: Directory whose files are served.
        port: Port to listen on; if ``None`` a free one is picked with
            ``find_free_port``.

    Returns:
        The running ``socketserver.TCPServer``, or ``None`` if the server
        could not be started (the reason is printed).
    """
    import functools

    httpd = None
    # Broad catch kept on purpose: callers rely on "print and return None"
    # for any start-up failure.
    try:
        served_dir = Path(directory).resolve()
        if not served_dir.is_dir():
            raise NotADirectoryError(f"{served_dir} is not a directory")

        # Find a free port if none specified
        if port is None:
            port = find_free_port()

        handler = functools.partial(
            http.server.SimpleHTTPRequestHandler,
            directory=str(served_dir),
        )
        httpd = socketserver.TCPServer(("", port), handler)

        # Daemon thread so a forgotten server never blocks interpreter exit
        server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
        server_thread.start()
    except Exception as e:
        print(f"Server start failed: {e}")
        if httpd is not None:
            # Do not leave the port bound when a later step failed
            httpd.server_close()
        return None

    print(f"Server started at http://localhost:{port}")
    return httpd
46+
47+
def stop_server(httpd):
    """Stop the HTTP server and release its listening socket.

    Args:
        httpd: Server to stop; a falsy value (e.g. ``None``) is a no-op.
    """
    if not httpd:
        return
    # shutdown() only ends the serve_forever() loop; server_close() is what
    # closes the socket so the port can be bound again.
    httpd.shutdown()
    httpd.server_close()
    print("Server stopped.")

0 commit comments

Comments
 (0)