Skip to content

Commit 96b20a6

Browse files
committed
Refactor generation of cortical z-stack Zarr stores
1 parent 4721364 commit 96b20a6

6 files changed

Lines changed: 1569 additions & 793 deletions

File tree

src/xenium_analysis_tools/alignment/align_sections.py

Lines changed: 105 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,6 @@ def get_alignment_transforms(landmarks):
287287
xenium_lm = landmarks[['x', 'y', 'z']].values.astype(float)
288288

289289
# ── 4b. Normalize section landmarks to full-res (global) pixel space ──
290-
# Landmarks are stored in matched-pyramid-level pixel space, with the
291-
# 'global' transform (e.g. Scale([4, 4])) recording the level→full-res
292-
# scale factor. add_affine_to_element prepends that same scale as the
293-
# first step of every stored transform so that level-N element pixels are
294-
# upscaled to full-res before the Affine is applied. The Affine must
295-
# therefore be fit in *full-res* pixel space so the two match.
296-
#
297-
# Without this step the Affine is fit on level-2 coords (0..~2500) but
298-
# receives level-2 × 4 = full-res (0..~10000) at apply time → 4× wrong
299-
# output coordinates → section appears 4× too large in czstack_microns.
300290
try:
301291
lm_global_tf = get_transformation(landmarks, to_coordinate_system='global')
302292
if lm_global_tf is not None and not isinstance(lm_global_tf, Identity):
@@ -328,102 +318,127 @@ def get_alignment_transforms(landmarks):
328318
}
329319
return section_affines
330320

331-
def adjust_3d_images_z_scaling(sdata, sections_depth, elements_3d=['dapi_zstack'], center_z=True):
321+
def _rescale_z(tf, mps, z_offset=0.0):
322+
"""Return a new transform with the z-scale replaced by *mps* and z shifted by *z_offset*.
323+
324+
Works for Scale, Sequence, Identity, or Affine inputs. Does not mutate the
325+
input transform. The z-scale replacement and optional centering translation
326+
are applied in a single pass — no double-nesting of Sequences.
327+
"""
328+
def _replace_z_in_scale(t):
329+
axes = list(t.axes)
330+
vals = list(t.scale)
331+
vals[axes.index('z')] = mps
332+
return Scale(vals, axes=t.axes)
333+
334+
if isinstance(tf, Identity):
335+
result = Scale([mps, 1.0, 1.0], axes=('z', 'y', 'x'))
336+
elif isinstance(tf, Scale):
337+
if 'z' in tf.axes:
338+
result = _replace_z_in_scale(tf)
339+
else:
340+
result = Sequence([Scale([mps], axes=('z',)), tf])
341+
elif isinstance(tf, Sequence):
342+
tfs = list(tf.transformations)
343+
for i, t in enumerate(tfs):
344+
if isinstance(t, Scale) and 'z' in t.axes:
345+
tfs[i] = _replace_z_in_scale(t)
346+
break
347+
else:
348+
# No z-bearing Scale found — prepend one
349+
tfs = [Scale([mps], axes=('z',))] + tfs
350+
result = Sequence(tfs)
351+
else:
352+
# Affine or other: wrap with z-scale prepended
353+
result = Sequence([Scale([mps], axes=('z',)), tf])
354+
355+
# Append z centering as a single translation Affine (one extra step, not a second pass)
356+
if z_offset != 0.0:
357+
mat = np.eye(4)
358+
mat[2, 3] = z_offset # row 2 = z output, col 3 = translation
359+
z_center_tf = Affine(mat, input_axes=('x', 'y', 'z'), output_axes=('x', 'y', 'z'))
360+
if isinstance(result, Sequence):
361+
result = Sequence(list(result.transformations) + [z_center_tf])
362+
else:
363+
result = Sequence([result, z_center_tf])
364+
365+
return result
366+
367+
368+
def adjust_3d_images_z_scaling(sdata, sections_depth, elements_3d=None, center_z=True):
332369
"""
333370
Rescale the z-axis of 3D section images to match the known section thickness.
334371
335-
If center_z=True (default), the slab is further shifted so that z=0 in global/microns
336-
space corresponds to the section midplane. This is the physically correct convention
337-
because the 2D landmark affine was fit to a projection of the whole section volume,
338-
so z=0 should represent the center of that volume rather than its bottom edge.
372+
If center_z=True (default), the slab is further shifted so that z=0 corresponds
373+
to the section midplane (range: -sections_depth/2 .. +sections_depth/2 µm),
374+
matching the convention used by adjust_transcripts_z_scaling.
339375
340-
The z-step uses ``sections_depth / (z_planes - 1)`` so that the first and last planes
341-
sit exactly at ±sections_depth/2, matching the transcript z-range produced by
342-
``adjust_transcripts_z_scaling``. This prevents empty planes from appearing in napari
343-
after the transcripts have ended.
376+
The z-step is sections_depth / (n_z - 1) so the first and last planes sit exactly
377+
at ±sections_depth/2 — no empty planes appear in Napari after the transcripts end.
378+
379+
All coordinate systems already registered on each pyramid level are updated;
380+
there is no need to pass a coordinate-system list.
344381
"""
345-
for el in elements_3d:
346-
if el not in sdata:
382+
if elements_3d is None:
383+
elements_3d = ['dapi_zstack']
384+
385+
for el_name in elements_3d:
386+
if el_name not in sdata:
347387
continue
348-
349-
for scale in sdata[el].keys():
350-
img = sdata[el][scale].image if hasattr(sdata[el][scale], 'image') else sdata[el][scale]
351-
z_planes = img.sizes.get('z', img.shape[0])
352-
if z_planes < 2:
353-
continue # single-plane: nothing meaningful to rescale
354-
microns_per_slice = sections_depth / (z_planes - 1)
355-
356-
def update_z_scale(tf):
357-
if isinstance(tf, Identity):
358-
return Scale([microns_per_slice, 1.0, 1.0], axes=('z', 'y', 'x'))
359-
elif isinstance(tf, Scale):
360-
if 'z' in tf.axes:
361-
new_scale = list(tf.scale)
362-
new_scale[list(tf.axes).index('z')] = microns_per_slice
363-
return Scale(new_scale, axes=tf.axes)
364-
return Sequence([Scale([microns_per_slice], axes=('z',)), tf])
365-
elif isinstance(tf, Sequence):
366-
new_tfs = [
367-
Scale(
368-
[microns_per_slice if ax == 'z' else s for ax, s in zip(t.axes, t.scale)],
369-
axes=t.axes
370-
) if isinstance(t, Scale) and 'z' in t.axes else t
371-
for t in tf.transformations
372-
]
373-
if not any(isinstance(t, Scale) and 'z' in t.axes for t in new_tfs):
374-
new_tfs = [Scale([microns_per_slice, 1.0, 1.0], axes=('z', 'y', 'x'))] + new_tfs
375-
return Sequence(new_tfs)
376-
return Sequence([Scale([microns_per_slice], axes=('z',)), tf])
377-
378-
for cs in ['global', 'microns']:
379-
existing_tf = get_transformation(img, to_coordinate_system=cs)
380-
if existing_tf is not None:
381-
new_tf = update_z_scale(existing_tf)
382-
set_transformation(img, new_tf, to_coordinate_system=cs)
383-
384-
# Center the slab: shift z so the midplane (z_image=(z_planes-1)/2) maps to z=0.
385-
# With mps = sections_depth/(z_planes-1), plane 0 → -sections_depth/2 and
386-
# plane (z_planes-1) → +sections_depth/2, exactly matching transcript z range.
387-
if center_z:
388-
z_center_offset = -(z_planes - 1) / 2.0 * microns_per_slice
389-
center_mat = np.eye(4)
390-
center_mat[2, 3] = z_center_offset
391-
center_tf_3d = Affine(center_mat, input_axes=('x', 'y', 'z'), output_axes=('x', 'y', 'z'))
392-
for cs in ['global', 'microns']:
393-
existing_tf = get_transformation(img, to_coordinate_system=cs)
394-
if existing_tf is not None:
395-
set_transformation(img, Sequence([existing_tf, center_tf_3d]), to_coordinate_system=cs)
388+
el = sdata[el_name]
389+
390+
# Iterate over every pyramid level robustly
391+
if _is_multiscale(el):
392+
level_imgs = [sd.get_pyramid_levels(el, n=i) for i in range(len(el.keys()))]
393+
else:
394+
level_imgs = [el]
395+
396+
for img in level_imgs:
397+
# Robust z-dimension check — never falls back to shape[0] (= c dim)
398+
if not (hasattr(img, 'dims') and 'z' in img.dims):
399+
continue
400+
n_z = img.sizes['z']
401+
if n_z < 2:
402+
continue
403+
404+
mps = sections_depth / (n_z - 1) # microns per z-step at this level
405+
z_offset = -(n_z - 1) / 2.0 * mps if center_z else 0.0
406+
407+
# Update every coord system present — not just a hardcoded subset
408+
for cs, existing_tf in get_transformation(img, get_all=True).items():
409+
set_transformation(img, _rescale_z(existing_tf, mps, z_offset),
410+
to_coordinate_system=cs)
396411
return sdata
397412

413+
398414
def adjust_transcripts_z_scaling(sdata, sections_depth, center_z=True):
    """
    Rescale transcript z-coordinates so they span the known section thickness.

    If center_z=True (default), z=0 is the section midplane
    (range: -sections_depth/2 .. +sections_depth/2 µm), matching
    adjust_3d_images_z_scaling so transcripts and the DAPI z-stack share the
    same z=0 reference point.

    The raw z values are stashed in an 'original_z_coords' column the first
    time this runs, so repeated calls rescale from the originals and stay
    idempotent.
    """
    import warnings

    # Consistent with adjust_3d_images_z_scaling, which skips absent elements
    # instead of raising KeyError.
    if 'transcripts' not in sdata:
        warnings.warn("No 'transcripts' element in sdata; skipping z scaling.")
        return sdata

    tx = sdata['transcripts']
    if 'original_z_coords' not in tx.columns:
        tx['original_z_coords'] = tx['z']
    z = tx['original_z_coords']

    # Single compute call — avoids triggering the dask graph twice
    z_stats = z.agg(['min', 'max']).compute()
    z_min, z_max = float(z_stats['min']), float(z_stats['max'])
    z_span = z_max - z_min

    if z_span == 0:
        # Degenerate input: every transcript shares one z value, so the
        # stretch factor is undefined — leave coordinates untouched.
        warnings.warn("Transcript z-span is zero; skipping z scaling.")
        return sdata

    scaled_z = (z - z_min) * (sections_depth / z_span)
    if center_z:
        scaled_z = scaled_z - sections_depth / 2.0
    tx['z'] = scaled_z
    return sdata
428443

429444
def _shift_transform_origin_along_z(tf, z_offset):

0 commit comments

Comments
 (0)