Skip to content

Commit 4721364

Browse files
committed
fixed transcript filtering logic
1 parent a6081bd commit 4721364

1 file changed

Lines changed: 86 additions & 17 deletions

File tree

src/xenium_analysis_tools/alignment/format_for_napari.py

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import dask.dataframe as dd
1313
import matplotlib.pyplot as plt
1414
import matplotlib.colors as mcolors
15+
import napari
1516

1617
def filter_labels(sdata, label_elements='cell_labels', table='table', key_col='cell_labels'):
1718
# Get all label elements that match the specified prefix
@@ -138,22 +139,43 @@ def _is_2d_element(el):
138139
fov_data.points = filtered_points
139140
return fov_data
140141

141-
def filter_transcripts(sdata,
142-
genes_to_show='all',
143-
filter_is_gene=False,
144-
filter_assigned_to_cell=False,
142+
def filter_transcripts(sdata,
143+
genes_to_show='all',
144+
filter_is_gene=True,
145+
filter_assigned_to_cell=True,
145146
min_qv=20,
146-
filter_transcripts_to_cells=False,
147+
filter_transcripts_to_cells=True,
147148
sections=[],
148149
add_prefix='',
150+
source_el=None,
149151
return_only=False):
152+
"""
153+
Filter transcript point elements in sdata.
154+
155+
If add_prefix is empty, the original elements are updated in place.
156+
If add_prefix is set, the original elements are never modified; filtered
157+
copies are written under f"{add_prefix}_{source_el_name}".
158+
159+
Parameters
160+
----------
161+
source_el : str or None
162+
When add_prefix is set, filter only this specific element instead of
163+
iterating all transcripts* elements. The source element is never
164+
modified. When None, all transcripts* elements are processed.
165+
"""
166+
if add_prefix and source_el is not None:
167+
candidates = [source_el]
168+
else:
169+
candidates = [k for k in sdata.points.keys() if k.startswith('transcripts')]
150170

151171
if return_only:
152172
tx_els = {}
153-
for tx_el in list(sdata.points.keys()):
154-
print(f'Filtering {tx_el}...')
155-
if not tx_el.startswith('transcripts'):
173+
174+
for tx_el in candidates:
175+
if tx_el not in sdata.points:
176+
print(f"Warning: {tx_el!r} not found in sdata.points, skipping")
156177
continue
178+
print(f'Filtering {tx_el}...')
157179

158180
tx = sdata.points[tx_el]
159181

@@ -194,13 +216,13 @@ def filter_transcripts(sdata,
194216
cell_table = cell_table[cell_table.obs['section'] == s_n]
195217
tx = tx[tx['cell_id'].isin(cell_table.obs['cell_id'])]
196218

197-
print(f"{tx_el}: filters applied (lazy — will compute on render)")
198-
if add_prefix:
199-
tx_el = f"{add_prefix}_{tx_el}"
219+
dest_el = f"{add_prefix}_{tx_el}" if add_prefix else tx_el
220+
print(f"{tx_el}{dest_el!r}: filters applied (lazy — will compute on render)")
221+
200222
if return_only:
201-
tx_els[tx_el] = tx
223+
tx_els[dest_el] = tx
202224
else:
203-
sdata.points[tx_el] = tx
225+
sdata.points[dest_el] = tx
204226

205227
if return_only:
206228
return tx_els
@@ -215,7 +237,7 @@ def get_sample_val(series):
215237
return series.head(1).iloc[0]
216238
return series.iloc[0]
217239

218-
def filter_cells(sdata, cell_filters=[]):
240+
def filter_cells(sdata, el='table', cell_filters=[]):
219241
import operator as op_module
220242

221243
# Map operator strings to functions
@@ -232,7 +254,7 @@ def filter_cells(sdata, cell_filters=[]):
232254
kept_cell_ids = None # will be populated if cell filters are applied
233255

234256
if cell_filters:
235-
tbl = sdata.tables['table']
257+
tbl = sdata.tables[el]
236258
obs = tbl.obs.copy()
237259

238260
# Build combined boolean mask
@@ -262,8 +284,7 @@ def filter_cells(sdata, cell_filters=[]):
262284
filtered_tbl.obs[region_key] = filtered_tbl.obs[region_key].cat.remove_unused_categories()
263285
attrs['region'] = filtered_tbl.obs[region_key].unique().tolist()
264286

265-
sdata.tables['table'] = filtered_tbl
266-
kept_cell_ids = set(filtered_tbl.obs[instance_key].astype(str).tolist())
287+
sdata.tables[el] = filtered_tbl
267288
print(f"\n{mask.sum()} / {len(obs)} cells kept after all filters.")
268289
return sdata
269290

@@ -877,6 +898,54 @@ def expand_for_napari(
877898

878899
return plot_sdata
879900

901+
def set_solid_label_color(sdata, table_key, color, col_name='label_color_group'):
902+
table = sdata[table_key]
903+
table.obs[col_name] = pd.Categorical(['all'] * table.n_obs)
904+
table.uns[f'{col_name}_colors'] = np.array([color])
905+
906+
def apply_layer_style(layer, layer_styles):
907+
def _find_style(layer_name):
908+
"""Return params for the longest matching key, or None."""
909+
matches = [k for k in layer_styles if k in layer_name]
910+
return layer_styles[max(matches, key=len)] if matches else None
911+
912+
params = _find_style(layer.name)
913+
if params is None:
914+
return
915+
916+
THUMBNAIL_PROPS = {'colormap', 'contrast_limits', 'blending', 'opacity'}
917+
TYPE_GATES = {
918+
'colormap': napari.layers.Image,
919+
'contour': napari.layers.Labels,
920+
'face_color': napari.layers.Points,
921+
}
922+
# Keys handled elsewhere — skip silently
923+
SKIP = {'label_color', 'face_color_column'}
924+
925+
thumbnail_updates = {k: v for k, v in params.items()
926+
if k in THUMBNAIL_PROPS and k not in SKIP}
927+
other_updates = {k: v for k, v in params.items()
928+
if k not in THUMBNAIL_PROPS and k not in SKIP}
929+
930+
try:
931+
for prop, val in thumbnail_updates.items():
932+
gate = TYPE_GATES.get(prop)
933+
if gate and not isinstance(layer, gate):
934+
continue
935+
if hasattr(layer, prop):
936+
setattr(layer, prop, val)
937+
except RuntimeError:
938+
pass # 3D layers fail _update_thumbnail; settings apply on render
939+
940+
for prop, val in other_updates.items():
941+
gate = TYPE_GATES.get(prop)
942+
if gate and not isinstance(layer, gate):
943+
continue
944+
if hasattr(layer, prop):
945+
try:
946+
setattr(layer, prop, val)
947+
except Exception:
948+
pass
880949

881950
def add_napari_colormaps(
882951
sdata,

0 commit comments

Comments
 (0)