from pathlib import Path
import tqdm
import os
import sys

# MapMyCells mapping pipeline
from cell_type_mapper.test_utils.cache_wrapper import AbcCacheWrapper

# Local imports
from xenium_analysis_tools.map_xenium.map_sections import (
    get_v1_merfish_subclasses,
    get_abc_paths,
    get_sections_to_process,
    map_single_section,
)
from xenium_analysis_tools.utils.io_utils import (
    load_config,
    setup_logging,
)

# Environment setup (limit threads for numpy operations)
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'

def map_sections(dataset_name: str, config_path: str,
                 select_sections: list[int] | None = None,
                 sections_parent_folder='data'):
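    """Map every processed section zarr of a Xenium dataset onto the ABC atlas
    taxonomy using the MapMyCells pipeline.

    Parameters
    ----------
    dataset_name:
        Dataset name; combined with the configured suffixes to locate the
        processed section zarrs and to name the mapped-output folder.
    config_path:
        Path to the config file providing the 'paths', 'processing_control'
        and 'mapping_config' sections.
    select_sections:
        Optional subset of section indices to process; all sections when None.
    sections_parent_folder:
        Prefix of the config path key holding the sections root, e.g. 'data'
        resolves to paths['data_root'].
    """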
    # Early validation
    config = load_config(config_path)
    processing_config = config['processing_control']
    paths = config['paths']
    sections_path = Path(paths[f'{sections_parent_folder}_root']) / f"{dataset_name}{processing_config['save_processed_dataset_suffix']}"

    try:
        section_zarrs = get_sections_to_process(sections_path, select_sections)
        print(f"Processing {len(section_zarrs)} sections: {[p.stem.split('_')[-1] for p in section_zarrs]}")
    except (FileNotFoundError, ValueError) as e:
        print(f"Error: {e}")
        sys.exit(1)

    # ---- Set up ----
    save_results_path = Path(paths[f"{processing_config['save_mapped_data_parent_folder']}_root"]) / f"{dataset_name}{processing_config['save_mapped_dataset_suffix']}"
    save_results_path.mkdir(parents=True, exist_ok=True)
    logger, log_file_path = setup_logging(save_results_path)
    logger.info(f"Running: {dataset_name}")

    # Paths
    abc_atlas_path = Path(paths['abc_path'])
    abc_cache = AbcCacheWrapper.from_local_cache(abc_atlas_path)
    precomputed_stats_path, mouse_markers_path, gene_mapper_db_path = get_abc_paths(abc_cache)

    # Get configurations
    mapping_config = config['mapping_config']
    mapping_params = mapping_config['mapping_params']
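
    # For reference, the mapping portion of the config is expected to look
    # roughly like the sketch below. Key names are taken from the lookups in
    # this function; the values shown are purely illustrative.
    #
    #   mapping_config:
    #     v1_subclass_cells_only: false
    #     v1_merfish_cells_path: <relative path under scratch_root>
    #     v1_visp_cluster_min_cells: 0
    #     var_filters: ...
    #     obs_filters: ...
    #     mapping_params:
    #       num_workers: max        # or an integer
    #       normalization: raw
    #       bootstrap_iteration: 100
    #       bootstrap_factor: 0.5
    #       n_runner_ups: 0
    #       chunk_size: 5000
    #       rng_seed: 42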

    # Section zarrs were already resolved during early validation above
    logger.info(f"Found {len(section_zarrs)} section zarrs to map.")

    # Filters
    var_filters = mapping_config.get('var_filters', None)
    obs_filters = mapping_config.get('obs_filters', None)

    # Determine nodes to drop based on V1 subclass cells only option
    if mapping_config.get('v1_subclass_cells_only', False):
        v1_cells_path = Path(paths['scratch_root']) / mapping_config.get('v1_merfish_cells_path', None)
        v1_min_cells = mapping_config.get('v1_visp_cluster_min_cells', 0)
        nodes_to_drop = get_v1_merfish_subclasses(abc_cache, output_path=v1_cells_path, visp_cluster_min_cells=v1_min_cells)
        logger.info(f"Number of nodes to drop: {len(nodes_to_drop) if nodes_to_drop else 0}")
    else:
        nodes_to_drop = None
    mapping_params['nodes_to_drop'] = nodes_to_drop

    # Number of workers
    num_workers = mapping_params.get('num_workers', None)
    if num_workers == 'max':
        num_workers = os.cpu_count()
    elif num_workers is None:
        num_workers = 4
    if num_workers > os.cpu_count():
        num_workers = os.cpu_count()

    # Type assignment parameters for mapper
    type_assignment = {
        'normalization': mapping_params.get('normalization', 'raw'),
        'bootstrap_iteration': int(mapping_params.get('bootstrap_iteration', 100)),
        'bootstrap_factor': float(mapping_params.get('bootstrap_factor', 0.5)),
        'n_runners_up': int(mapping_params.get('n_runner_ups', 0)),
        'chunk_size': int(mapping_params.get('chunk_size', 5000)),
        'n_processors': num_workers,  # use the resolved worker count (handles 'max' and caps at cpu_count)
        'rng_seed': int(mapping_params.get('rng_seed', 42))
    }

    # ----- Process sections -----
    successful_sections = []
    failed_sections = []

    logger.info(f"\n=== Starting processing of {len(section_zarrs)} sections ===")

    # Create progress bar
    with tqdm.tqdm(total=len(section_zarrs), desc="Processing sections",
                   unit="section", position=0, leave=True) as pbar:

        for idx, section_path in enumerate(section_zarrs, 1):
            section_name = section_path.stem

            # Update progress bar description with current section
            pbar.set_description(f"Processing {section_name}")

            logger.info(f"\n[{idx}/{len(section_zarrs)}] Processing {section_name}...")

            try:
                success = map_single_section(
                    section_path=section_path,
                    logger=logger,
                    save_results_path=save_results_path,
                    mapping_config=mapping_config,
                    mapping_params=mapping_params,
                    type_assignment=type_assignment,
                    precomputed_stats_path=precomputed_stats_path,
                    mouse_markers_path=mouse_markers_path,
                    gene_mapper_db_path=gene_mapper_db_path,
                    var_filters=var_filters,
                    obs_filters=obs_filters
                )
                if success:
                    successful_sections.append(section_name)
                    pbar.set_postfix({"✓": len(successful_sections), "✗": len(failed_sections)})
                else:
                    failed_sections.append((section_name, "Mapping validation failed"))
                    pbar.set_postfix({"✓": len(successful_sections), "✗": len(failed_sections)})

            except Exception as e:
                failed_sections.append((section_name, str(e)))
                logger.error(f"Failed to process {section_name}: {e}")
                pbar.set_postfix({"✓": len(successful_sections), "✗": len(failed_sections)})
                continue
            finally:
                # Always update progress bar
                pbar.update(1)

        # Final update
        pbar.set_description("Processing complete")

    # Final summary
    print("\n=== FINAL RESULTS ===")
    print(f"Successful: {len(successful_sections)}/{len(section_zarrs)}")
    if successful_sections:
        print(f"Successfully processed: {successful_sections}")
    if failed_sections:
        print("Failed sections:")
        for name, error in failed_sections:
            print(f" - {name}: {error}")

    logger.info(f"Pipeline completed. Success: {len(successful_sections)}, Failed: {len(failed_sections)}")
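

# Minimal usage sketch (illustrative only): the dataset name and config path
# below are placeholders, not values shipped with the repository; callers may
# instead import map_sections and supply their own arguments.
if __name__ == "__main__":
    map_sections(
        dataset_name="example_dataset",              # hypothetical dataset name
        config_path="configs/mapping_config.yaml",   # hypothetical config location
        select_sections=None,                        # map every available section
    )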