Skip to content

Commit ae4de01

Browse files
committed
updated mapping sections
1 parent a2469aa commit ae4de01

3 files changed

Lines changed: 345 additions & 230 deletions

File tree

src/xenium_analysis_tools/map_xenium/map_dataset_sections.py

Lines changed: 91 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -6,113 +6,157 @@
66
from cell_type_mapper.test_utils.cache_wrapper import AbcCacheWrapper
77

88
# Local imports
9-
from xenium_analysis_tools.map_xenium.map_sections import (
10-
get_v1_merfish_subclasses,
9+
from xenium_analysis_tools.map_xenium.map_sections import (
1110
get_abc_paths,
1211
get_sections_to_process,
12+
get_v1_merfish_cells,
13+
get_nodes_to_drop,
1314
map_single_section,
1415
)
1516
from xenium_analysis_tools.utils.io_utils import (
1617
load_config,
1718
setup_logging,
19+
get_partial_dataset,
20+
is_complete_mapping_results,
1821
)
1922

2023
# Environment setup (limit threads for numpy operations)
2124
os.environ['NUMEXPR_NUM_THREADS'] = '1'
2225
os.environ['MKL_NUM_THREADS'] = '1'
2326
os.environ['OMP_NUM_THREADS'] = '1'
2427

25-
def map_sections(dataset_name: str, config_path: str, select_sections: list[int]|None = None, sections_parent_folder='data'):
26-
# Early validation
28+
def map_sections(dataset_name: str, config_path: str=None, select_sections: list[int]|None = None, sections_parent_folder='data'):
29+
# ---- Set up ----
2730
config = load_config(config_path)
28-
processing_config = config['processing_control']
31+
32+
# Paths/directories
2933
paths = config['paths']
30-
sections_path = Path(paths[f'{sections_parent_folder}_root']) / f"{dataset_name}{processing_config['save_processed_dataset_suffix']}"
34+
processing_config = config['processing_control']
35+
mapping_config = config['mapping_config']
36+
sections_sd_path = Path(paths[f'{sections_parent_folder}_root']) / f"{dataset_name}{processing_config['save_processed_dataset_suffix']}"
37+
save_mapped_sections_parent_folder = processing_config['save_mapped_data_parent_folder']
38+
save_mapped_sections_path = Path(paths[f'{save_mapped_sections_parent_folder}_root']) / f"{dataset_name}{processing_config['save_mapped_dataset_suffix']}"
39+
save_mapped_sections_path.mkdir(parents=True, exist_ok=True)
40+
41+
# Logger
42+
logger, log_file_path = setup_logging(save_mapped_sections_path)
3143

44+
# Print out where sections are being saved
45+
logger.info(f"Dataset Name: {dataset_name}")
46+
logger.info(f"Configuration loaded from {config_path}")
47+
logger.info(f"Sections are being loaded from: {sections_sd_path}")
48+
logger.info(f"Mapped sections will be saved to: {save_mapped_sections_path}")
49+
50+
# If specified, copy sections from data folder instead of re-generating
51+
if processing_config['check_data_folder_mapped']:
52+
logger.info("Checking and copying sections from data folder if exists...")
53+
data_folder_sections_path = Path(paths['data_root']) / f'{dataset_name}{processing_config["save_mapped_dataset_suffix"]}'
54+
get_partial_dataset(source_path=data_folder_sections_path,
55+
dest_path=save_mapped_sections_path,
56+
pattern='section_*',
57+
subset_ids=select_sections,
58+
is_complete_func=is_complete_mapping_results,
59+
func_args={'input_folder_name': mapping_config.get('input_data_folder_name', 'input_data'),
60+
'mapped_folder_name': mapping_config.get('mapped_data_folder_name', 'mapped_data'),
61+
'input_data_files': [mapping_config.get('input_h5ad_name', 'input_cellxgene.h5ad')],
62+
'mapped_data_files': [mapping_config.get('basic_results_name', 'basic_results.csv'),
63+
mapping_config.get('extended_results_name', 'extended_results.json'),
64+
mapping_config.get('mapped_data_h5ad_name', 'mapped_cellxgene.h5ad')]
65+
}
66+
)
67+
68+
# Get sections to map
3269
try:
33-
section_zarrs = get_sections_to_process(sections_path, select_sections)
34-
print(f"Processing {len(section_zarrs)} sections: {[p.stem.split('_')[-1] for p in section_zarrs]}")
70+
section_zarrs = get_sections_to_process(sections_sd_path, select_sections)
71+
print(f"\nProcessing {len(section_zarrs)} sections: {[p.stem.split('_')[-1] for p in section_zarrs]}")
3572
except (FileNotFoundError, ValueError) as e:
3673
print(f"Error: {e}")
3774
sys.exit(1)
38-
39-
# ---- Set up ----
40-
save_results_path = Path(paths[f'{processing_config['save_mapped_data_parent_folder']}_root']) / f"{dataset_name}{processing_config['save_mapped_dataset_suffix']}"
41-
save_results_path.mkdir(parents=True, exist_ok=True)
42-
logger, log_file_path = setup_logging(save_results_path)
43-
logger.info(f"Running: {dataset_name}")
44-
45-
# Paths
75+
76+
# ABC Atlas paths
4677
abc_atlas_path = Path(paths['abc_path'])
4778
abc_cache = AbcCacheWrapper.from_local_cache(abc_atlas_path)
4879
precomputed_stats_path, mouse_markers_path, gene_mapper_db_path = get_abc_paths(abc_cache)
4980

50-
# Get configurations
51-
mapping_config = config['mapping_config']
52-
mapping_params = mapping_config['mapping_params']
53-
54-
# Get section zarrs
55-
section_zarrs = get_sections_to_process(sections_path, select_sections)
56-
57-
logger.info(f"Found {len(section_zarrs)} section zarrs to map.")
58-
59-
# Filters
81+
# ----- Filtering -----
82+
# Cell and gene filters
6083
var_filters = mapping_config.get('var_filters', None)
6184
obs_filters = mapping_config.get('obs_filters', None)
85+
if obs_filters:
86+
logger.info("Applying filters to cells:")
87+
for col,filt in obs_filters.items():
88+
logger.info(f"{col}: {filt}")
89+
if var_filters:
90+
logger.info("Applying filters to genes:")
91+
for col,filt in var_filters.items():
92+
logger.info(f"{col}: {filt}")
6293

63-
# Determine nodes to drop based on V1 subclass cells only option
64-
if mapping_config.get('v1_subclass_cells_only',False):
65-
v1_cells_path = Path(paths['scratch_root']) / mapping_config.get('v1_merfish_cells_path', None)
66-
v1_min_cells = mapping_config.get('v1_visp_cluster_min_cells', 0)
67-
nodes_to_drop = get_v1_merfish_subclasses(abc_cache, output_path=v1_cells_path, visp_cluster_min_cells=v1_min_cells)
68-
logger.info(f"Number of nodes to drop: {len(nodes_to_drop) if nodes_to_drop else 0}")
69-
else:
70-
nodes_to_drop = None
94+
# Taxonomy filters - nodes to drop for mapping
95+
nodes_to_drop=[]
96+
# If specified any specific nodes to drop in config
97+
drop_nodes_dict = mapping_config.get('drop_nodes_dict', None)
98+
if drop_nodes_dict:
99+
for h_level in drop_nodes_dict:
100+
nodes_to_drop.extend([(h_level, cl) for cl in drop_nodes_dict[h_level]])
101+
logger.info(f"Dropping {len(nodes_to_drop)} nodes based on drop_nodes_dict.")
102+
# Filter to only V1 cells
103+
filter_v1_types_config = mapping_config.get('filter_mapping_v1_types', None)
104+
if filter_v1_types_config and filter_v1_types_config.get('enabled', False):
105+
h_level = filter_v1_types_config.get('h_level', 'subclass')
106+
min_cells = filter_v1_types_config.get('min_cells', 0)
107+
v1_types_df_name = filter_v1_types_config.get('saved_df_name', None)
108+
if v1_types_df_name:
109+
v1_types_path = Path(paths['scratch_root']) / v1_types_df_name
110+
else:
111+
v1_types_path = None
112+
v1_merfish_cells = get_v1_merfish_cells(abc_cache, output_path=v1_types_path)
113+
v1_nodes_to_drop = get_nodes_to_drop(v1_merfish_cells, abc_cache, h_level=h_level, min_cells=min_cells)
114+
logger.info(f"Dropping {len(v1_nodes_to_drop)} {h_level} nodes not present in V1 MERFISH data with at least {min_cells} cells.")
115+
nodes_to_drop.extend(v1_nodes_to_drop)
116+
117+
# ----- Mapper parameters -----
118+
mapping_params = mapping_config.get('mapping_params', {})
71119
mapping_params['nodes_to_drop'] = nodes_to_drop
72120

73-
# Number of workers
121+
# n_processors for mapper
74122
num_workers = mapping_params.get('num_workers', None)
75123
if num_workers == 'max':
76124
num_workers = os.cpu_count()
77125
elif num_workers is None:
78126
num_workers = 4
79127
if num_workers > os.cpu_count():
80128
num_workers = os.cpu_count()
81-
129+
82130
# Type assignment parameters for mapper
83131
type_assignment = {
84132
'normalization': mapping_params.get('normalization', 'raw'),
85133
'bootstrap_iteration': int(mapping_params.get('bootstrap_iteration', 100)),
86134
'bootstrap_factor': float(mapping_params.get('bootstrap_factor', 0.5)),
87135
'n_runners_up': int(mapping_params.get('n_runner_ups', 0)),
88136
'chunk_size': int(mapping_params.get('chunk_size', 5000)),
89-
'n_processors': int(mapping_params.get('num_workers', 4)),
137+
'n_processors': num_workers,
90138
'rng_seed': int(mapping_params.get('rng_seed', 42))
91139
}
92140

93141
# ----- Process sections -----
94142
successful_sections = []
95143
failed_sections = []
96-
97144
logger.info(f"\n=== Starting processing of {len(section_zarrs)} sections ===")
98-
99-
# Create progress bar
100145
with tqdm.tqdm(total=len(section_zarrs), desc="Processing sections",
101-
unit="section", position=0, leave=True) as pbar:
102-
146+
unit="section", position=0, leave=True) as pbar:
147+
103148
for idx, section_path in enumerate(section_zarrs, 1):
104149
section_name = section_path.stem
105150

106151
# Update progress bar description with current section
107152
pbar.set_description(f"Processing {section_name}")
108-
109153
logger.info(f"\n[{idx}/{len(section_zarrs)}] Processing {section_name}...")
110154

111155
try:
112156
success = map_single_section(
113157
section_path=section_path,
114158
logger=logger,
115-
save_results_path=save_results_path,
159+
save_results_path=save_mapped_sections_path,
116160
mapping_config=mapping_config,
117161
mapping_params=mapping_params,
118162
type_assignment=type_assignment,
@@ -126,7 +170,7 @@ def map_sections(dataset_name: str, config_path: str, select_sections: list[int]
126170
successful_sections.append(section_name)
127171
pbar.set_postfix({"✓": len(successful_sections), "✗": len(failed_sections)})
128172
else:
129-
failed_sections.append((section_name, "Mapping validation failed"))
173+
failed_sections.append((section_name, "Mapping section failed"))
130174
pbar.set_postfix({"✓": len(successful_sections), "✗": len(failed_sections)})
131175

132176
except Exception as e:
@@ -140,7 +184,7 @@ def map_sections(dataset_name: str, config_path: str, select_sections: list[int]
140184

141185
# Final update
142186
pbar.set_description("Processing complete")
143-
187+
144188
# Final summary
145189
print(f"\n=== FINAL RESULTS ===")
146190
print(f"Successful: {len(successful_sections)}/{len(section_zarrs)}")
@@ -150,5 +194,5 @@ def map_sections(dataset_name: str, config_path: str, select_sections: list[int]
150194
print("Failed sections:")
151195
for name, error in failed_sections:
152196
print(f" - {name}: {error}")
153-
197+
154198
logger.info(f"Pipeline completed. Success: {len(successful_sections)}, Failed: {len(failed_sections)}")

0 commit comments

Comments (0)