Skip to content

Commit 52e400e

Browse files
authored
Update plotting.py
1 parent abbd4a6 commit 52e400e

1 file changed

Lines changed: 199 additions & 0 deletions

File tree

libs/surface_viewer/plotting.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import matplotlib.pyplot as plt
44
import numpy as np
5+
import pandas as pd
56

67
from .calibration import channel_to_keV, get_energy_cal_from_dataset, make_energy_axis_from_length
78
from .peaks import preprocess
@@ -198,3 +199,201 @@ def plot_identified_elements_confident(
198199
"assign_scored": assign if assign is not None else None,
199200
"label_df": label_df,
200201
}
202+
203+
204+
def _select_overlay_indices(n_cells: int, n_show: int) -> np.ndarray:
205+
if n_cells <= 0:
206+
return np.array([], dtype=int)
207+
n_show = min(max(int(n_show), 1), n_cells)
208+
if n_cells <= n_show:
209+
return np.arange(n_cells, dtype=int)
210+
return np.unique(np.linspace(0, n_cells - 1, n_show, dtype=int))
211+
212+
213+
def plot_overlaid_cell_spectra(
214+
aggregate_results,
215+
peak_summary,
216+
band_start,
217+
band_end,
218+
band_mode="channels",
219+
cal=None,
220+
show_energy_top_axis=True,
221+
n_overlay_spectra=50,
222+
peak_top_n_labels=6,
223+
peak_fwhm_mn_ev=67.8,
224+
show_peak_crosses=False,
225+
show_element_lines=True,
226+
band_color="gray",
227+
band_alpha=0.15,
228+
line_color="black",
229+
overlay_color="C0",
230+
):
231+
"""Overlay up to N cell spectra for each group with optional band shading and element-line annotations."""
232+
from .spectra import band_label_text, resolve_band_to_channels
233+
234+
band_start_ch, band_end_ch = resolve_band_to_channels(
235+
band_start,
236+
band_end,
237+
band_mode=band_mode,
238+
cal=cal,
239+
)
240+
band_label = band_label_text(band_start, band_end, band_mode)
241+
242+
for name, res in aggregate_results.items():
243+
stack = np.asarray(res["stack"], dtype=float)
244+
if stack.ndim != 2 or stack.shape[0] == 0:
245+
print(f"Skipping {name}: no spectra available.")
246+
continue
247+
248+
show_idx = _select_overlay_indices(stack.shape[0], n_overlay_spectra)
249+
x = np.arange(stack.shape[1], dtype=float)
250+
251+
fig, ax = plt.subplots(figsize=(11, 5))
252+
for idx in show_idx:
253+
ax.plot(x, stack[idx], lw=0.8, alpha=0.25, color=overlay_color)
254+
255+
ax.axvspan(band_start_ch, band_end_ch, color=band_color, alpha=band_alpha)
256+
257+
ymax = float(np.nanmax(stack[show_idx])) if show_idx.size else 1.0
258+
if not np.isfinite(ymax) or ymax <= 0:
259+
ymax = 1.0
260+
261+
if show_peak_crosses:
262+
peaks_df_overlay = peak_summary.get(name, {}).get("peaks_df", pd.DataFrame())
263+
if peaks_df_overlay is not None and not peaks_df_overlay.empty and "x" in peaks_df_overlay.columns:
264+
x_peaks = peaks_df_overlay["x"].to_numpy(dtype=float)
265+
x_peaks = x_peaks[(x_peaks >= 0) & (x_peaks <= stack.shape[1] - 1)]
266+
if x_peaks.size:
267+
ax.scatter(
268+
x_peaks,
269+
np.full_like(x_peaks, 0.98 * ymax),
270+
marker="x",
271+
s=28,
272+
color=line_color,
273+
zorder=5,
274+
)
275+
276+
assign_df_overlay = peak_summary.get(name, {}).get("assign_df", pd.DataFrame())
277+
if show_element_lines and cal is not None and assign_df_overlay is not None and not assign_df_overlay.empty:
278+
assign = assign_df_overlay.copy()
279+
280+
if "height" not in assign.columns or assign["height"].isna().all():
281+
assign["height"] = assign.get("area", 0.0)
282+
283+
if "prominence" not in assign.columns or assign["prominence"].isna().all():
284+
assign["prominence"] = assign["height"]
285+
286+
if "delta_keV" not in assign.columns and {"lib_energy_keV", "energy_keV"}.issubset(assign.columns):
287+
assign["delta_keV"] = assign["lib_energy_keV"] - assign["energy_keV"]
288+
289+
if "score" not in assign.columns:
290+
if "delta_keV" in assign.columns:
291+
sigma_E = max((float(peak_fwhm_mn_ev) / 1000.0) / 2.355, 1e-6)
292+
assign["z_mismatch"] = assign["delta_keV"].abs() / sigma_E
293+
else:
294+
assign["z_mismatch"] = 0.0
295+
assign["score"] = assign["prominence"] / (1.0 + assign["z_mismatch"])
296+
297+
assign = assign.sort_values("score", ascending=False).head(int(peak_top_n_labels))
298+
assign = assign.sort_values("lib_energy_keV").reset_index(drop=True)
299+
300+
levels = np.linspace(0.93, 0.68, num=max(len(assign), 1))
301+
start_eV = float(cal["start_eV"])
302+
eV_per_ch = float(cal["eV_per_ch"])
303+
304+
for i, (_, row) in enumerate(assign.iterrows()):
305+
e_keV = float(row["lib_energy_keV"])
306+
ch = (e_keV * 1000.0 - start_eV) / eV_per_ch
307+
if 0 <= ch <= stack.shape[1] - 1:
308+
label = str(row.get("label", f"{row.get('element', '')} {row.get('line', '')}")).strip()
309+
ax.axvline(ch, ls="--", lw=1.0, alpha=0.55, color=line_color)
310+
ax.text(
311+
ch,
312+
levels[min(i, len(levels) - 1)] * ymax,
313+
label,
314+
rotation=90,
315+
va="top",
316+
ha="center",
317+
fontsize=8,
318+
color=line_color,
319+
bbox=dict(facecolor="white", alpha=0.75, edgecolor="none", pad=0.6),
320+
)
321+
322+
ax.set_xlabel("Channel")
323+
ax.set_ylabel("Counts")
324+
ax.set_title(f"{name}: {len(show_idx)} overlaid cell spectra ({band_label}; shaded gray)")
325+
326+
if show_energy_top_axis and cal is not None:
327+
add_energy_top_axis(ax, cal=cal, n=stack.shape[1])
328+
329+
fig.tight_layout()
330+
plt.show()
331+
332+
333+
def plot_stacked_band_histograms(
334+
aggregate_results,
335+
shared_bins,
336+
*,
337+
band_label,
338+
global_min=None,
339+
global_max=None,
340+
figsize_per_row=3.2,
341+
):
342+
"""Plot one band-sum histogram per group in a vertical stack with a shared x-axis."""
343+
n = len(aggregate_results)
344+
fig, axes = plt.subplots(n, 1, figsize=(10, figsize_per_row * n), sharex=True, squeeze=False)
345+
axes = axes[:, 0]
346+
347+
if global_min is None or global_max is None:
348+
all_band_vals = np.concatenate([
349+
res["df"]["band_value"].dropna().to_numpy(dtype=float)
350+
for res in aggregate_results.values()
351+
if len(res["df"]) > 0
352+
])
353+
global_min = float(np.min(all_band_vals))
354+
global_max = float(np.max(all_band_vals))
355+
356+
for i, (ax, (name, res)) in enumerate(zip(axes, aggregate_results.items())):
357+
vals = res["df"]["band_value"].dropna().to_numpy(dtype=float)
358+
ax.hist(vals, bins=shared_bins)
359+
ax.set_xlim(global_min, global_max)
360+
ax.set_ylabel("Number of cells")
361+
ax.set_title(f"{name}\n{band_label}")
362+
if i < n - 1:
363+
ax.tick_params(axis="x", labelbottom=False)
364+
365+
axes[-1].set_xlabel("Raw band sum")
366+
fig.tight_layout()
367+
plt.show()
368+
369+
370+
def plot_overlay_band_histograms(
371+
aggregate_results,
372+
shared_bins,
373+
*,
374+
band_label,
375+
vmin=None,
376+
vmax=None,
377+
log_y=True,
378+
shade_selected_range=True,
379+
legend_loc="upper left",
380+
):
381+
"""Overlay band-sum histograms across groups, optionally shading a selected value range."""
382+
plt.figure(figsize=(10, 5))
383+
384+
for name, res in aggregate_results.items():
385+
vals = res["df"]["band_value"].dropna().to_numpy(dtype=float)
386+
plt.hist(vals, bins=shared_bins, alpha=0.40, label=name)
387+
388+
if shade_selected_range and vmin is not None and vmax is not None:
389+
plt.axvspan(vmin, vmax, color="gray", alpha=0.15, label=f"selected range: {vmin}{vmax}")
390+
391+
if log_y:
392+
plt.yscale("log")
393+
394+
plt.xlabel("Raw band sum")
395+
plt.ylabel("Number of cells")
396+
plt.title(f"Band-sum comparison, {band_label}")
397+
plt.legend(frameon=True, facecolor="white", framealpha=1.0, edgecolor="lightgray", loc=legend_loc)
398+
plt.tight_layout()
399+
plt.show()

0 commit comments

Comments
 (0)