Skip to content

Commit b08dd8b

Browse files
committed
Allow plot_probe not to plot on axes, but just return polycollections
1 parent 35d2b19 commit b08dd8b

1 file changed

Lines changed: 58 additions & 50 deletions

File tree

src/probeinterface/plotting.py

Lines changed: 58 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def plot_probe(
2828
ylims: tuple | None = None,
2929
zlims: tuple | None = None,
3030
show_channel_on_click: bool = False,
31+
add_to_axis: bool = True,
3132
):
3233
"""Plot a Probe object.
3334
Generates a 2D or 3D axis, depending on Probe.ndim
@@ -64,6 +65,9 @@ def plot_probe(
6465
Limits for z dimension
6566
show_channel_on_click : bool, default: False
6667
If True, the channel information is shown upon click
68+
add_to_axis : bool, default: True
69+
If True, collections are added to the axis. If False, collections are
70+
only returned without being added to the axis.
6771
6872
Returns
6973
-------
@@ -79,14 +83,14 @@ def plot_probe(
7983
elif probe.ndim == 3:
8084
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8185

82-
if ax is None:
86+
if ax is None and add_to_axis:
8387
if probe.ndim == 2:
8488
fig, ax = plt.subplots()
8589
ax.set_aspect("equal")
8690
else:
8791
fig = plt.figure()
8892
ax = fig.add_subplot(1, 1, 1, projection="3d")
89-
else:
93+
elif ax is not None:
9094
fig = ax.get_figure()
9195

9296
_probe_shape_kwargs = dict(facecolor="green", edgecolor="k", lw=0.5, alpha=0.3)
@@ -107,16 +111,18 @@ def plot_probe(
107111
vertices = probe.get_contact_vertices()
108112
if probe.ndim == 2:
109113
poly = PolyCollection(vertices, color=contacts_colors, **_contacts_kargs)
110-
ax.add_collection(poly)
114+
if add_to_axis and ax is not None:
115+
ax.add_collection(poly)
111116
elif probe.ndim == 3:
112117
poly = Poly3DCollection(vertices, color=contacts_colors, **_contacts_kargs)
113-
ax.add_collection3d(poly)
118+
if add_to_axis and ax is not None:
119+
ax.add_collection3d(poly)
114120

115121
if contacts_values is not None:
116122
poly.set_array(contacts_values)
117123
poly.set_cmap(cmap)
118124

119-
if show_channel_on_click:
125+
if show_channel_on_click and add_to_axis:
120126
assert probe.ndim == 2, "show_channel_on_click works only for ndim=2"
121127

122128
def on_press(event):
@@ -126,61 +132,63 @@ def on_press(event):
126132
fig.canvas.mpl_connect("button_release_event", on_release)
127133

128134
# probe shape
135+
poly_contour = None
129136
planar_contour = probe.probe_planar_contour
130137
if planar_contour is not None:
131138
if probe.ndim == 2:
132139
poly_contour = PolyCollection([planar_contour], **_probe_shape_kwargs)
133-
ax.add_collection(poly_contour)
140+
if add_to_axis and ax is not None:
141+
ax.add_collection(poly_contour)
134142
elif probe.ndim == 3:
135143
poly_contour = Poly3DCollection([planar_contour], **_probe_shape_kwargs)
136-
ax.add_collection3d(poly_contour)
137-
else:
138-
poly_contour = None
139-
140-
if text_on_contact is not None:
141-
text_on_contact = np.asarray(text_on_contact)
142-
assert text_on_contact.size == probe.get_contact_count()
144+
if add_to_axis and ax is not None:
145+
ax.add_collection3d(poly_contour)
146+
147+
if add_to_axis and ax is not None:
148+
if text_on_contact is not None:
149+
text_on_contact = np.asarray(text_on_contact)
150+
assert text_on_contact.size == probe.get_contact_count()
151+
152+
if with_contact_id or with_device_index or text_on_contact is not None:
153+
if probe.ndim == 3:
154+
raise NotImplementedError("Channel index is 2d only")
155+
for i in range(n):
156+
txt = []
157+
if with_contact_id and probe.contact_ids is not None:
158+
contact_id = probe.contact_ids[i]
159+
txt.append(f"id{contact_id}")
160+
if with_device_index and probe.device_channel_indices is not None:
161+
chan_ind = probe.device_channel_indices[i]
162+
txt.append(f"dev{chan_ind}")
163+
if text_on_contact is not None:
164+
txt.append(f"{text_on_contact[i]}")
165+
166+
txt = "\n".join(txt)
167+
x, y = probe.contact_positions[i]
168+
ax.text(x, y, txt, ha="center", va="center", clip_on=True)
169+
170+
if xlims is None or ylims is None or (zlims is None and probe.ndim == 3):
171+
xlims, ylims, zlims = get_auto_lims(probe)
172+
173+
ax.set_xlim(*xlims)
174+
ax.set_ylim(*ylims)
175+
176+
if probe.si_units == "um":
177+
unit_str = "($\\mu m$)"
178+
else:
179+
unit_str = f"({probe.si_units})"
180+
ax.set_xlabel(f"x {unit_str}", fontsize=15)
181+
ax.set_ylabel(f"y {unit_str}", fontsize=15)
143182

144-
if with_contact_id or with_device_index or text_on_contact is not None:
145183
if probe.ndim == 3:
146-
raise NotImplementedError("Channel index is 2d only")
147-
for i in range(n):
148-
txt = []
149-
if with_contact_id and probe.contact_ids is not None:
150-
contact_id = probe.contact_ids[i]
151-
txt.append(f"id{contact_id}")
152-
if with_device_index and probe.device_channel_indices is not None:
153-
chan_ind = probe.device_channel_indices[i]
154-
txt.append(f"dev{chan_ind}")
155-
if text_on_contact is not None:
156-
txt.append(f"{text_on_contact[i]}")
157-
158-
txt = "\n".join(txt)
159-
x, y = probe.contact_positions[i]
160-
ax.text(x, y, txt, ha="center", va="center", clip_on=True)
161-
162-
if xlims is None or ylims is None or (zlims is None and probe.ndim == 3):
163-
xlims, ylims, zlims = get_auto_lims(probe)
164-
165-
ax.set_xlim(*xlims)
166-
ax.set_ylim(*ylims)
167-
168-
if probe.si_units == "um":
169-
unit_str = "($\\mu m$)"
170-
else:
171-
unit_str = f"({probe.si_units})"
172-
ax.set_xlabel(f"x {unit_str}", fontsize=15)
173-
ax.set_ylabel(f"y {unit_str}", fontsize=15)
184+
ax.set_zlim(zlims)
185+
ax.set_zlabel("z")
174186

175-
if probe.ndim == 3:
176-
ax.set_zlim(zlims)
177-
ax.set_zlabel("z")
178-
179-
if probe.ndim == 2:
180-
ax.set_aspect("equal")
187+
if probe.ndim == 2:
188+
ax.set_aspect("equal")
181189

182-
if title:
183-
ax.set_title(probe.get_title())
190+
if title:
191+
ax.set_title(probe.get_title())
184192

185193
return poly, poly_contour
186194

0 commit comments

Comments
 (0)