Skip to content

Commit c400d90

Browse files
authored
Refactor filter_transcripts function parameters
1 parent d4b3424 commit c400d90

1 file changed

Lines changed: 97 additions & 7 deletions

File tree

src/xenium_analysis_tools/alignment/format_for_napari.py

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,28 @@ def get_fov_sdata(
110110

111111
def filter_transcripts(sdata,
112112
genes_to_show='all',
113-
filter_is_gene=True,
114-
filter_assigned_to_cell=True,
115-
min_qv=0,
116-
filter_transcripts_to_cells=True):
113+
filter_is_gene=False,
114+
filter_assigned_to_cell=False,
115+
min_qv=20,
116+
filter_transcripts_to_cells=False,
117+
sections=[],
118+
add_prefix='',
119+
return_only=False):
120+
121+
if return_only:
122+
tx_els = {}
117123
for tx_el in list(sdata.points.keys()):
118124
if not tx_el.startswith('transcripts'):
119125
continue
120126

121127
tx = sdata.points[tx_el]
122128

129+
if sections:
130+
tx_sec = np.unique(plot_sdata_fov['transcripts-3']['section'].compute())[0]
131+
if tx_sec not in sections:
132+
print(f"Skipping {tx_el} (section {tx_sec} not in filter list)")
133+
continue
134+
123135
# 1. Filter by gene name
124136
if genes_to_show != 'all' and 'feature_name' in tx.columns:
125137
tx = tx[tx['feature_name'].isin(genes_to_show)]
@@ -151,10 +163,17 @@ def filter_transcripts(sdata,
151163
cell_table = cell_table[cell_table.obs['section'] == s_n]
152164
tx = tx[tx['cell_id'].isin(cell_table.obs['cell_id'])]
153165

154-
sdata.points[tx_el] = tx
155166
print(f"{tx_el}: filters applied (lazy — will compute on render)")
167+
if add_prefix:
168+
tx_el = f"{add_prefix}_{tx_el}"
169+
if return_only:
170+
tx_els[tx_el] = tx
156171

157-
return sdata
172+
if return_only:
173+
return tx_els
174+
else:
175+
sdata.points[tx_el] = tx
176+
return sdata
158177

159178
def is_dask(df):
160179
return isinstance(df, dd.DataFrame)
@@ -987,4 +1006,75 @@ def _resolve_color_map(color_spec, categories):
9871006
f"{len(color_map)} gene colours"
9881007
)
9891008

990-
return sdata
1009+
return sdata
1010+
1011+
def make_column_colormap(
1012+
source,
1013+
column_name,
1014+
colors=None,
1015+
colormap_name='tab20',
1016+
default_color='#808080',
1017+
add_to_uns=False,
1018+
):
1019+
import anndata as ad
1020+
1021+
# --- Extract the series ---
1022+
if isinstance(source, ad.AnnData):
1023+
series = source.obs[column_name]
1024+
elif isinstance(source, pd.Series):
1025+
series = source
1026+
else:
1027+
# dask or pandas DataFrame
1028+
series = source[column_name]
1029+
1030+
# --- Get unique categories ---
1031+
if hasattr(series, 'compute'):
1032+
# dask: use cat.as_known() if categorical, else compute unique
1033+
if hasattr(series, 'cat'):
1034+
categories = series.cat.as_known().cat.categories.tolist()
1035+
else:
1036+
categories = sorted(series.unique().compute().dropna().tolist())
1037+
else:
1038+
if hasattr(series, 'cat'):
1039+
categories = series.cat.categories.tolist()
1040+
else:
1041+
categories = sorted(series.dropna().unique().tolist())
1042+
1043+
# --- Build color mapping ---
1044+
cmap = plt.get_cmap(colormap_name, len(categories))
1045+
auto_colors = {cat: mcolors.to_hex(cmap(i)) for i, cat in enumerate(categories)}
1046+
1047+
# Override with any explicitly provided colors
1048+
color_map = {cat: colors.get(cat, auto_colors.get(cat, default_color))
1049+
for cat in categories}
1050+
if colors:
1051+
color_map.update({k: v for k, v in colors.items() if k in color_map})
1052+
1053+
# --- Optionally write back to AnnData uns ---
1054+
if add_to_uns:
1055+
if not isinstance(source, ad.AnnData):
1056+
raise ValueError("add_to_uns requires an AnnData source")
1057+
if not hasattr(series, 'cat'):
1058+
source.obs[column_name] = pd.Categorical(source.obs[column_name])
1059+
source.uns[f'{column_name}_colors'] = [
1060+
color_map[cat] for cat in source.obs[column_name].cat.categories
1061+
]
1062+
1063+
return color_map
1064+
1065+
def recolor_tx_layer(viewer, el_name, sdata, color_col, colors_dict=None, cmap='tab20'):
1066+
tx_layer = viewer.layers[el_name]
1067+
tx_colors = make_column_colormap(
1068+
sdata[el_name],
1069+
color_col,
1070+
colors=colors_dict,
1071+
colormap_name=cmap,
1072+
)
1073+
tx_data = sdata[el_name][color_col].compute()
1074+
tx_fns = tx_data.values
1075+
tx_layer.properties = {
1076+
**tx_layer.properties,
1077+
color_col: tx_fns,
1078+
}
1079+
tx_layer.face_color_cycle = tx_colors
1080+
tx_layer.face_color = color_col

0 commit comments

Comments
 (0)