@@ -110,16 +110,28 @@ def get_fov_sdata(
110110
111111def filter_transcripts (sdata ,
112112 genes_to_show = 'all' ,
113- filter_is_gene = True ,
114- filter_assigned_to_cell = True ,
115- min_qv = 0 ,
116- filter_transcripts_to_cells = True ):
113+ filter_is_gene = False ,
114+ filter_assigned_to_cell = False ,
115+ min_qv = 20 ,
116+ filter_transcripts_to_cells = False ,
117+ sections = [],
118+ add_prefix = '' ,
119+ return_only = False ):
120+
121+ if return_only :
122+ tx_els = {}
117123 for tx_el in list (sdata .points .keys ()):
118124 if not tx_el .startswith ('transcripts' ):
119125 continue
120126
121127 tx = sdata .points [tx_el ]
122128
129+ if sections :
130+ tx_sec = np .unique (plot_sdata_fov ['transcripts-3' ]['section' ].compute ())[0 ]
131+ if tx_sec not in sections :
132+ print (f"Skipping { tx_el } (section { tx_sec } not in filter list)" )
133+ continue
134+
123135 # 1. Filter by gene name
124136 if genes_to_show != 'all' and 'feature_name' in tx .columns :
125137 tx = tx [tx ['feature_name' ].isin (genes_to_show )]
@@ -151,10 +163,17 @@ def filter_transcripts(sdata,
151163 cell_table = cell_table [cell_table .obs ['section' ] == s_n ]
152164 tx = tx [tx ['cell_id' ].isin (cell_table .obs ['cell_id' ])]
153165
154- sdata .points [tx_el ] = tx
155166 print (f"{ tx_el } : filters applied (lazy — will compute on render)" )
167+ if add_prefix :
168+ tx_el = f"{ add_prefix } _{ tx_el } "
169+ if return_only :
170+ tx_els [tx_el ] = tx
156171
157- return sdata
172+ if return_only :
173+ return tx_els
174+ else :
175+ sdata .points [tx_el ] = tx
176+ return sdata
158177
159178def is_dask (df ):
160179 return isinstance (df , dd .DataFrame )
@@ -987,4 +1006,75 @@ def _resolve_color_map(color_spec, categories):
9871006 f"{ len (color_map )} gene colours"
9881007 )
9891008
990- return sdata
1009+ return sdata
1010+
1011+ def make_column_colormap (
1012+ source ,
1013+ column_name ,
1014+ colors = None ,
1015+ colormap_name = 'tab20' ,
1016+ default_color = '#808080' ,
1017+ add_to_uns = False ,
1018+ ):
1019+ import anndata as ad
1020+
1021+ # --- Extract the series ---
1022+ if isinstance (source , ad .AnnData ):
1023+ series = source .obs [column_name ]
1024+ elif isinstance (source , pd .Series ):
1025+ series = source
1026+ else :
1027+ # dask or pandas DataFrame
1028+ series = source [column_name ]
1029+
1030+ # --- Get unique categories ---
1031+ if hasattr (series , 'compute' ):
1032+ # dask: use cat.as_known() if categorical, else compute unique
1033+ if hasattr (series , 'cat' ):
1034+ categories = series .cat .as_known ().cat .categories .tolist ()
1035+ else :
1036+ categories = sorted (series .unique ().compute ().dropna ().tolist ())
1037+ else :
1038+ if hasattr (series , 'cat' ):
1039+ categories = series .cat .categories .tolist ()
1040+ else :
1041+ categories = sorted (series .dropna ().unique ().tolist ())
1042+
1043+ # --- Build color mapping ---
1044+ cmap = plt .get_cmap (colormap_name , len (categories ))
1045+ auto_colors = {cat : mcolors .to_hex (cmap (i )) for i , cat in enumerate (categories )}
1046+
1047+ # Override with any explicitly provided colors
1048+ color_map = {cat : colors .get (cat , auto_colors .get (cat , default_color ))
1049+ for cat in categories }
1050+ if colors :
1051+ color_map .update ({k : v for k , v in colors .items () if k in color_map })
1052+
1053+ # --- Optionally write back to AnnData uns ---
1054+ if add_to_uns :
1055+ if not isinstance (source , ad .AnnData ):
1056+ raise ValueError ("add_to_uns requires an AnnData source" )
1057+ if not hasattr (series , 'cat' ):
1058+ source .obs [column_name ] = pd .Categorical (source .obs [column_name ])
1059+ source .uns [f'{ column_name } _colors' ] = [
1060+ color_map [cat ] for cat in source .obs [column_name ].cat .categories
1061+ ]
1062+
1063+ return color_map
1064+
1065+ def recolor_tx_layer (viewer , el_name , sdata , color_col , colors_dict = None , cmap = 'tab20' ):
1066+ tx_layer = viewer .layers [el_name ]
1067+ tx_colors = make_column_colormap (
1068+ sdata [el_name ],
1069+ color_col ,
1070+ colors = colors_dict ,
1071+ colormap_name = cmap ,
1072+ )
1073+ tx_data = sdata [el_name ][color_col ].compute ()
1074+ tx_fns = tx_data .values
1075+ tx_layer .properties = {
1076+ ** tx_layer .properties ,
1077+ color_col : tx_fns ,
1078+ }
1079+ tx_layer .face_color_cycle = tx_colors
1080+ tx_layer .face_color = color_col
0 commit comments