|
2 | 2 |
|
3 | 3 | import matplotlib.pyplot as plt |
4 | 4 | import numpy as np |
| 5 | +import pandas as pd |
5 | 6 |
|
6 | 7 | from .calibration import channel_to_keV, get_energy_cal_from_dataset, make_energy_axis_from_length |
7 | 8 | from .peaks import preprocess |
@@ -198,3 +199,201 @@ def plot_identified_elements_confident( |
198 | 199 | "assign_scored": assign if assign is not None else None, |
199 | 200 | "label_df": label_df, |
200 | 201 | } |
| 202 | + |
| 203 | + |
def _select_overlay_indices(n_cells: int, n_show: int) -> np.ndarray:
    """Choose which rows of an ``n_cells``-row stack to draw.

    Args:
        n_cells: Number of available rows.
        n_show: Requested number of rows; coerced with ``int`` and clamped
            to the range ``[1, n_cells]``.

    Returns:
        An integer array of row indices in ascending order. It is empty when
        ``n_cells`` is not positive, contains every index when the request
        covers all rows, and otherwise holds indices spread evenly from the
        first row to the last.
    """
    if n_cells <= 0:
        return np.array([], dtype=int)

    # A request below one still yields a single spectrum.
    requested = int(n_show)
    if requested < 1:
        requested = 1

    # Asking for everything (or more) simply returns every row.
    if requested >= n_cells:
        return np.arange(n_cells, dtype=int)

    evenly_spaced = np.linspace(0, n_cells - 1, requested, dtype=int)
    return np.unique(evenly_spaced)
| 211 | + |
| 212 | + |
def _rank_line_assignments(assign_df, peak_fwhm_mn_ev, top_n):
    """Return the ``top_n`` best-scoring line assignments ordered by library energy.

    Works on a copy of ``assign_df``. Missing ``height``/``prominence``
    columns are filled from ``area``/``height``; a missing ``score`` column is
    computed as ``prominence / (1 + z_mismatch)``. ``assign_df`` must contain
    a ``lib_energy_keV`` column.
    """
    assign = assign_df.copy()

    if "height" not in assign.columns or assign["height"].isna().all():
        assign["height"] = assign.get("area", 0.0)

    if "prominence" not in assign.columns or assign["prominence"].isna().all():
        assign["prominence"] = assign["height"]

    if "delta_keV" not in assign.columns and {"lib_energy_keV", "energy_keV"}.issubset(assign.columns):
        assign["delta_keV"] = assign["lib_energy_keV"] - assign["energy_keV"]

    if "score" not in assign.columns:
        if "delta_keV" in assign.columns:
            # FWHM given in eV -> keV, then divided by 2.355 (FWHM ~= 2.355 sigma
            # for a Gaussian); floored so the division below stays finite.
            sigma_keV = max((float(peak_fwhm_mn_ev) / 1000.0) / 2.355, 1e-6)
            assign["z_mismatch"] = assign["delta_keV"].abs() / sigma_keV
        else:
            assign["z_mismatch"] = 0.0
        assign["score"] = assign["prominence"] / (1.0 + assign["z_mismatch"])

    # A negative count would make ``head`` return "all but the last n" rows.
    keep = max(int(top_n), 0)
    best = assign.sort_values("score", ascending=False).head(keep)
    return best.sort_values("lib_energy_keV").reset_index(drop=True)


def plot_overlaid_cell_spectra(
    aggregate_results,
    peak_summary,
    band_start,
    band_end,
    band_mode="channels",
    cal=None,
    show_energy_top_axis=True,
    n_overlay_spectra=50,
    peak_top_n_labels=6,
    peak_fwhm_mn_ev=67.8,
    show_peak_crosses=False,
    show_element_lines=True,
    band_color="gray",
    band_alpha=0.15,
    line_color="black",
    overlay_color="C0",
):
    """Overlay up to N cell spectra for each group with optional band shading and element-line annotations.

    One figure is created and shown per entry of ``aggregate_results``.
    Groups whose stack is not 2-D or holds no data are skipped with a
    printed message.

    Args:
        aggregate_results: Mapping of group name to a mapping that provides a
            ``"stack"`` entry convertible to a 2-D float array (one row per
            spectrum).
        peak_summary: Mapping of group name to a mapping that may provide
            ``"peaks_df"`` (needs an ``"x"`` column) and ``"assign_df"``
            (needs a ``"lib_energy_keV"`` column). ``None`` or missing
            entries simply disable the corresponding annotations.
        band_start: Lower band limit, forwarded to ``resolve_band_to_channels``.
        band_end: Upper band limit, forwarded to ``resolve_band_to_channels``.
        band_mode: Forwarded to ``resolve_band_to_channels`` and
            ``band_label_text``.
        cal: Calibration mapping; ``"start_eV"`` and ``"eV_per_ch"`` are read
            to place element lines. Without it no element lines or energy
            axis are drawn.
        show_energy_top_axis: Add the energy top axis when ``cal`` is given.
        n_overlay_spectra: Maximum number of spectra drawn per group.
        peak_top_n_labels: Maximum number of element lines drawn per group.
        peak_fwhm_mn_ev: Peak FWHM in eV used to scale the energy mismatch
            when a score has to be computed.
        show_peak_crosses: Mark the ``"x"`` positions from ``"peaks_df"``
            with crosses near the top of the plot.
        show_element_lines: Draw labelled vertical lines for the best
            assignments.
        band_color: Colour of the shaded band.
        band_alpha: Opacity of the shaded band.
        line_color: Colour of crosses, element lines and their labels.
        overlay_color: Colour of the overlaid spectra.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    from .spectra import band_label_text, resolve_band_to_channels

    band_start_ch, band_end_ch = resolve_band_to_channels(
        band_start,
        band_end,
        band_mode=band_mode,
        cal=cal,
    )
    band_label = band_label_text(band_start, band_end, band_mode)

    # Describe the shading with the colour actually used; non-string colour
    # specs (e.g. RGB tuples) would read poorly in a title, so omit them.
    shade_note = f"shaded {band_color}" if isinstance(band_color, str) else "shaded"
    summaries = peak_summary if peak_summary is not None else {}

    for name, res in aggregate_results.items():
        stack = np.asarray(res["stack"], dtype=float)
        # ``stack.size == 0`` also covers a stack with rows but zero channels,
        # on which ``np.nanmax`` would raise.
        if stack.ndim != 2 or stack.size == 0:
            print(f"Skipping {name}: no spectra available.")
            continue

        n_channels = stack.shape[1]
        group_summary = summaries.get(name) or {}

        show_idx = _select_overlay_indices(stack.shape[0], n_overlay_spectra)
        x = np.arange(n_channels, dtype=float)

        fig, ax = plt.subplots(figsize=(11, 5))
        for idx in show_idx:
            ax.plot(x, stack[idx], lw=0.8, alpha=0.25, color=overlay_color)

        ax.axvspan(band_start_ch, band_end_ch, color=band_color, alpha=band_alpha)

        # Reference height for annotations; fall back to 1.0 when the shown
        # spectra have no usable positive maximum.
        ymax = float(np.nanmax(stack[show_idx])) if show_idx.size else 1.0
        if not np.isfinite(ymax) or ymax <= 0:
            ymax = 1.0

        if show_peak_crosses:
            peaks_df = group_summary.get("peaks_df")
            if peaks_df is not None and not peaks_df.empty and "x" in peaks_df.columns:
                x_peaks = peaks_df["x"].to_numpy(dtype=float)
                x_peaks = x_peaks[(x_peaks >= 0) & (x_peaks <= n_channels - 1)]
                if x_peaks.size:
                    ax.scatter(
                        x_peaks,
                        np.full_like(x_peaks, 0.98 * ymax),
                        marker="x",
                        s=28,
                        color=line_color,
                        zorder=5,
                    )

        assign_df = group_summary.get("assign_df")
        if (
            show_element_lines
            and cal is not None
            and assign_df is not None
            and not assign_df.empty
            and "lib_energy_keV" in assign_df.columns
        ):
            assign = _rank_line_assignments(assign_df, peak_fwhm_mn_ev, peak_top_n_labels)

            # Stagger label heights so neighbouring labels overlap less.
            levels = np.linspace(0.93, 0.68, num=max(len(assign), 1))
            start_eV = float(cal["start_eV"])
            eV_per_ch = float(cal["eV_per_ch"])

            for level, (_, row) in zip(levels, assign.iterrows()):
                e_keV = float(row["lib_energy_keV"])
                ch = (e_keV * 1000.0 - start_eV) / eV_per_ch
                # Lines outside the plotted channels (or NaN energies) are skipped.
                if 0 <= ch <= n_channels - 1:
                    label = row.get("label")
                    if label is None or pd.isna(label):
                        # Missing or null label: build it from element and line.
                        label = f"{row.get('element', '')} {row.get('line', '')}"
                    label = str(label).strip()
                    ax.axvline(ch, ls="--", lw=1.0, alpha=0.55, color=line_color)
                    ax.text(
                        ch,
                        level * ymax,
                        label,
                        rotation=90,
                        va="top",
                        ha="center",
                        fontsize=8,
                        color=line_color,
                        bbox=dict(facecolor="white", alpha=0.75, edgecolor="none", pad=0.6),
                    )

        ax.set_xlabel("Channel")
        ax.set_ylabel("Counts")
        ax.set_title(f"{name}: {len(show_idx)} overlaid cell spectra ({band_label}; {shade_note})")

        if show_energy_top_axis and cal is not None:
            add_energy_top_axis(ax, cal=cal, n=n_channels)

        fig.tight_layout()
        plt.show()
| 331 | + |
| 332 | + |
def plot_stacked_band_histograms(
    aggregate_results,
    shared_bins,
    *,
    band_label,
    global_min=None,
    global_max=None,
    figsize_per_row=3.2,
):
    """Plot one band-sum histogram per group in a vertical stack with a shared x-axis.

    Args:
        aggregate_results: Mapping of group name to a mapping whose ``"df"``
            entry is a DataFrame with a ``"band_value"`` column.
        shared_bins: Passed as ``bins`` to every histogram so all groups use
            the same binning.
        band_label: Text shown on the second line of each subplot title.
        global_min: Lower x-axis limit. When ``None`` it is taken from the
            pooled non-null band values of all groups.
        global_max: Upper x-axis limit. When ``None`` it is taken from the
            pooled non-null band values of all groups.
        figsize_per_row: Figure height allotted to each subplot row.

    Returns:
        None. The figure is displayed with ``plt.show()``; nothing is drawn
        when ``aggregate_results`` is empty.
    """
    n = len(aggregate_results)
    if n == 0:
        # ``plt.subplots`` rejects zero rows, so there is nothing to draw.
        print("No groups to plot.")
        return

    per_group_vals = [
        res["df"]["band_value"].dropna().to_numpy(dtype=float)
        for res in aggregate_results.values()
    ]

    # Fill in only the limit(s) the caller left unset, and only when there is
    # data to derive them from; a remaining ``None`` leaves that side of the
    # axis to matplotlib's autoscaling.
    if global_min is None or global_max is None:
        non_empty = [vals for vals in per_group_vals if vals.size]
        if non_empty:
            pooled = np.concatenate(non_empty)
            if global_min is None:
                global_min = float(np.min(pooled))
            if global_max is None:
                global_max = float(np.max(pooled))

    fig, axes = plt.subplots(n, 1, figsize=(10, figsize_per_row * n), sharex=True, squeeze=False)
    axes = axes[:, 0]

    for i, (ax, name, vals) in enumerate(zip(axes, aggregate_results, per_group_vals)):
        ax.hist(vals, bins=shared_bins)
        ax.set_xlim(global_min, global_max)
        ax.set_ylabel("Number of cells")
        ax.set_title(f"{name}\n{band_label}")
        if i < n - 1:
            # Only the bottom subplot keeps its x tick labels.
            ax.tick_params(axis="x", labelbottom=False)

    axes[-1].set_xlabel("Raw band sum")
    fig.tight_layout()
    plt.show()
| 368 | + |
| 369 | + |
def plot_overlay_band_histograms(
    aggregate_results,
    shared_bins,
    *,
    band_label,
    vmin=None,
    vmax=None,
    log_y=True,
    shade_selected_range=True,
    legend_loc="upper left",
):
    """Draw every group's band-sum histogram on one set of axes.

    Args:
        aggregate_results: Mapping of group name to a mapping whose ``"df"``
            entry is a DataFrame with a ``"band_value"`` column.
        shared_bins: Passed as ``bins`` to every histogram.
        band_label: Text appended to the plot title.
        vmin: Lower edge of the optional shaded range.
        vmax: Upper edge of the optional shaded range.
        log_y: Use a logarithmic y-axis.
        shade_selected_range: Shade ``vmin``..``vmax`` when both are given.
        legend_loc: Legend location passed to matplotlib.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    for group_name, group in aggregate_results.items():
        band_values = group["df"]["band_value"].dropna().to_numpy(dtype=float)
        ax.hist(band_values, bins=shared_bins, alpha=0.40, label=group_name)

    # The shaded span needs both edges; a single edge is ignored.
    range_given = vmin is not None and vmax is not None
    if shade_selected_range and range_given:
        ax.axvspan(vmin, vmax, color="gray", alpha=0.15, label=f"selected range: {vmin}–{vmax}")

    if log_y:
        ax.set_yscale("log")

    ax.set_xlabel("Raw band sum")
    ax.set_ylabel("Number of cells")
    ax.set_title(f"Band-sum comparison, {band_label}")
    ax.legend(
        frameon=True,
        facecolor="white",
        framealpha=1.0,
        edgecolor="lightgray",
        loc=legend_loc,
    )
    fig.tight_layout()
    plt.show()
0 commit comments