1212import dask .dataframe as dd
1313import matplotlib .pyplot as plt
1414import matplotlib .colors as mcolors
15+ import napari
1516
1617def filter_labels (sdata , label_elements = 'cell_labels' , table = 'table' , key_col = 'cell_labels' ):
1718 # Get all label elements that match the specified prefix
@@ -138,22 +139,43 @@ def _is_2d_element(el):
138139 fov_data .points = filtered_points
139140 return fov_data
140141
141- def filter_transcripts (sdata ,
142- genes_to_show = 'all' ,
143- filter_is_gene = False ,
144- filter_assigned_to_cell = False ,
142+ def filter_transcripts (sdata ,
143+ genes_to_show = 'all' ,
144+ filter_is_gene = True ,
145+ filter_assigned_to_cell = True ,
145146 min_qv = 20 ,
146- filter_transcripts_to_cells = False ,
147+ filter_transcripts_to_cells = True ,
147148 sections = [],
148149 add_prefix = '' ,
150+ source_el = None ,
149151 return_only = False ):
152+ """
153+ Filter transcript point elements in sdata.
154+
155+ If add_prefix is empty, the original elements are updated in place.
156+ If add_prefix is set, the original elements are never modified; filtered
157+ copies are written under f"{add_prefix}_{source_el_name}".
158+
159+ Parameters
160+ ----------
161+ source_el : str or None
162+ When add_prefix is set, filter only this specific element instead of
163+ iterating all transcripts* elements. The source element is never
164+ modified. When None, all transcripts* elements are processed.
165+ """
166+ if add_prefix and source_el is not None :
167+ candidates = [source_el ]
168+ else :
169+ candidates = [k for k in sdata .points .keys () if k .startswith ('transcripts' )]
150170
151171 if return_only :
152172 tx_els = {}
153- for tx_el in list (sdata .points .keys ()):
154- print (f'Filtering { tx_el } ...' )
155- if not tx_el .startswith ('transcripts' ):
173+
174+ for tx_el in candidates :
175+ if tx_el not in sdata .points :
176+ print (f"Warning: { tx_el !r} not found in sdata.points, skipping" )
156177 continue
178+ print (f'Filtering { tx_el } ...' )
157179
158180 tx = sdata .points [tx_el ]
159181
@@ -194,13 +216,13 @@ def filter_transcripts(sdata,
194216 cell_table = cell_table [cell_table .obs ['section' ] == s_n ]
195217 tx = tx [tx ['cell_id' ].isin (cell_table .obs ['cell_id' ])]
196218
197- print ( f"{ tx_el } : filters applied (lazy — will compute on render)" )
198- if add_prefix :
199- tx_el = f" { add_prefix } _ { tx_el } "
219+ dest_el = f"{ add_prefix } _ { tx_el } " if add_prefix else tx_el
220+ print ( f" { tx_el } → { dest_el !r } : filters applied (lazy — will compute on render)" )
221+
200222 if return_only :
201- tx_els [tx_el ] = tx
223+ tx_els [dest_el ] = tx
202224 else :
203- sdata .points [tx_el ] = tx
225+ sdata .points [dest_el ] = tx
204226
205227 if return_only :
206228 return tx_els
@@ -215,7 +237,7 @@ def get_sample_val(series):
215237 return series .head (1 ).iloc [0 ]
216238 return series .iloc [0 ]
217239
218- def filter_cells (sdata , cell_filters = []):
240+ def filter_cells (sdata , el = 'table' , cell_filters = []):
219241 import operator as op_module
220242
221243 # Map operator strings to functions
@@ -232,7 +254,7 @@ def filter_cells(sdata, cell_filters=[]):
232254 kept_cell_ids = None # will be populated if cell filters are applied
233255
234256 if cell_filters :
235- tbl = sdata .tables ['table' ]
257+ tbl = sdata .tables [el ]
236258 obs = tbl .obs .copy ()
237259
238260 # Build combined boolean mask
@@ -262,8 +284,7 @@ def filter_cells(sdata, cell_filters=[]):
262284 filtered_tbl .obs [region_key ] = filtered_tbl .obs [region_key ].cat .remove_unused_categories ()
263285 attrs ['region' ] = filtered_tbl .obs [region_key ].unique ().tolist ()
264286
265- sdata .tables ['table' ] = filtered_tbl
266- kept_cell_ids = set (filtered_tbl .obs [instance_key ].astype (str ).tolist ())
287+ sdata .tables [el ] = filtered_tbl
267288 print (f"\n { mask .sum ()} / { len (obs )} cells kept after all filters." )
268289 return sdata
269290
@@ -877,6 +898,54 @@ def expand_for_napari(
877898
878899 return plot_sdata
879900
901+ def set_solid_label_color (sdata , table_key , color , col_name = 'label_color_group' ):
902+ table = sdata [table_key ]
903+ table .obs [col_name ] = pd .Categorical (['all' ] * table .n_obs )
904+ table .uns [f'{ col_name } _colors' ] = np .array ([color ])
905+
906+ def apply_layer_style (layer , layer_styles ):
907+ def _find_style (layer_name ):
908+ """Return params for the longest matching key, or None."""
909+ matches = [k for k in layer_styles if k in layer_name ]
910+ return layer_styles [max (matches , key = len )] if matches else None
911+
912+ params = _find_style (layer .name )
913+ if params is None :
914+ return
915+
916+ THUMBNAIL_PROPS = {'colormap' , 'contrast_limits' , 'blending' , 'opacity' }
917+ TYPE_GATES = {
918+ 'colormap' : napari .layers .Image ,
919+ 'contour' : napari .layers .Labels ,
920+ 'face_color' : napari .layers .Points ,
921+ }
922+ # Keys handled elsewhere — skip silently
923+ SKIP = {'label_color' , 'face_color_column' }
924+
925+ thumbnail_updates = {k : v for k , v in params .items ()
926+ if k in THUMBNAIL_PROPS and k not in SKIP }
927+ other_updates = {k : v for k , v in params .items ()
928+ if k not in THUMBNAIL_PROPS and k not in SKIP }
929+
930+ try :
931+ for prop , val in thumbnail_updates .items ():
932+ gate = TYPE_GATES .get (prop )
933+ if gate and not isinstance (layer , gate ):
934+ continue
935+ if hasattr (layer , prop ):
936+ setattr (layer , prop , val )
937+ except RuntimeError :
938+ pass # 3D layers fail _update_thumbnail; settings apply on render
939+
940+ for prop , val in other_updates .items ():
941+ gate = TYPE_GATES .get (prop )
942+ if gate and not isinstance (layer , gate ):
943+ continue
944+ if hasattr (layer , prop ):
945+ try :
946+ setattr (layer , prop , val )
947+ except Exception :
948+ pass
880949
881950def add_napari_colormaps (
882951 sdata ,
0 commit comments