@@ -28,6 +28,7 @@ def plot_probe(
2828 ylims : tuple | None = None ,
2929 zlims : tuple | None = None ,
3030 show_channel_on_click : bool = False ,
31+ add_to_axis : bool = True ,
3132):
3233 """Plot a Probe object.
3334 Generates a 2D or 3D axis, depending on Probe.ndim
@@ -64,6 +65,9 @@ def plot_probe(
6465 Limits for z dimension
6566 show_channel_on_click : bool, default: False
6667 If True, the channel information is shown upon click
68+ add_to_axis : bool, default: True
69+ If True, collections are added to the axis. If False, collections are
70+ only returned without being added to the axis.
6771
6872 Returns
6973 -------
@@ -79,14 +83,14 @@ def plot_probe(
7983 elif probe .ndim == 3 :
8084 from mpl_toolkits .mplot3d .art3d import Poly3DCollection
8185
82- if ax is None :
86+ if ax is None and add_to_axis :
8387 if probe .ndim == 2 :
8488 fig , ax = plt .subplots ()
8589 ax .set_aspect ("equal" )
8690 else :
8791 fig = plt .figure ()
8892 ax = fig .add_subplot (1 , 1 , 1 , projection = "3d" )
89- else :
93+ elif ax is not None :
9094 fig = ax .get_figure ()
9195
9296 _probe_shape_kwargs = dict (facecolor = "green" , edgecolor = "k" , lw = 0.5 , alpha = 0.3 )
@@ -107,16 +111,18 @@ def plot_probe(
107111 vertices = probe .get_contact_vertices ()
108112 if probe .ndim == 2 :
109113 poly = PolyCollection (vertices , color = contacts_colors , ** _contacts_kargs )
110- ax .add_collection (poly )
114+ if add_to_axis and ax is not None :
115+ ax .add_collection (poly )
111116 elif probe .ndim == 3 :
112117 poly = Poly3DCollection (vertices , color = contacts_colors , ** _contacts_kargs )
113- ax .add_collection3d (poly )
118+ if add_to_axis and ax is not None :
119+ ax .add_collection3d (poly )
114120
115121 if contacts_values is not None :
116122 poly .set_array (contacts_values )
117123 poly .set_cmap (cmap )
118124
119- if show_channel_on_click :
125+ if show_channel_on_click and add_to_axis :
120126 assert probe .ndim == 2 , "show_channel_on_click works only for ndim=2"
121127
122128 def on_press (event ):
@@ -126,61 +132,63 @@ def on_press(event):
126132 fig .canvas .mpl_connect ("button_release_event" , on_release )
127133
128134 # probe shape
135+ poly_contour = None
129136 planar_contour = probe .probe_planar_contour
130137 if planar_contour is not None :
131138 if probe .ndim == 2 :
132139 poly_contour = PolyCollection ([planar_contour ], ** _probe_shape_kwargs )
133- ax .add_collection (poly_contour )
140+ if add_to_axis and ax is not None :
141+ ax .add_collection (poly_contour )
134142 elif probe .ndim == 3 :
135143 poly_contour = Poly3DCollection ([planar_contour ], ** _probe_shape_kwargs )
136- ax .add_collection3d (poly_contour )
137- else :
138- poly_contour = None
139-
140- if text_on_contact is not None :
141- text_on_contact = np .asarray (text_on_contact )
142- assert text_on_contact .size == probe .get_contact_count ()
144+ if add_to_axis and ax is not None :
145+ ax .add_collection3d (poly_contour )
146+
147+ if add_to_axis and ax is not None :
148+ if text_on_contact is not None :
149+ text_on_contact = np .asarray (text_on_contact )
150+ assert text_on_contact .size == probe .get_contact_count ()
151+
152+ if with_contact_id or with_device_index or text_on_contact is not None :
153+ if probe .ndim == 3 :
154+ raise NotImplementedError ("Channel index is 2d only" )
155+ for i in range (n ):
156+ txt = []
157+ if with_contact_id and probe .contact_ids is not None :
158+ contact_id = probe .contact_ids [i ]
159+ txt .append (f"id{ contact_id } " )
160+ if with_device_index and probe .device_channel_indices is not None :
161+ chan_ind = probe .device_channel_indices [i ]
162+ txt .append (f"dev{ chan_ind } " )
163+ if text_on_contact is not None :
164+ txt .append (f"{ text_on_contact [i ]} " )
165+
166+ txt = "\n " .join (txt )
167+ x , y = probe .contact_positions [i ]
168+ ax .text (x , y , txt , ha = "center" , va = "center" , clip_on = True )
169+
170+ if xlims is None or ylims is None or (zlims is None and probe .ndim == 3 ):
171+ xlims , ylims , zlims = get_auto_lims (probe )
172+
173+ ax .set_xlim (* xlims )
174+ ax .set_ylim (* ylims )
175+
176+ if probe .si_units == "um" :
177+ unit_str = "($\\ mu m$)"
178+ else :
179+ unit_str = f"({ probe .si_units } )"
180+ ax .set_xlabel (f"x { unit_str } " , fontsize = 15 )
181+ ax .set_ylabel (f"y { unit_str } " , fontsize = 15 )
143182
144- if with_contact_id or with_device_index or text_on_contact is not None :
145183 if probe .ndim == 3 :
146- raise NotImplementedError ("Channel index is 2d only" )
147- for i in range (n ):
148- txt = []
149- if with_contact_id and probe .contact_ids is not None :
150- contact_id = probe .contact_ids [i ]
151- txt .append (f"id{ contact_id } " )
152- if with_device_index and probe .device_channel_indices is not None :
153- chan_ind = probe .device_channel_indices [i ]
154- txt .append (f"dev{ chan_ind } " )
155- if text_on_contact is not None :
156- txt .append (f"{ text_on_contact [i ]} " )
157-
158- txt = "\n " .join (txt )
159- x , y = probe .contact_positions [i ]
160- ax .text (x , y , txt , ha = "center" , va = "center" , clip_on = True )
161-
162- if xlims is None or ylims is None or (zlims is None and probe .ndim == 3 ):
163- xlims , ylims , zlims = get_auto_lims (probe )
164-
165- ax .set_xlim (* xlims )
166- ax .set_ylim (* ylims )
167-
168- if probe .si_units == "um" :
169- unit_str = "($\\ mu m$)"
170- else :
171- unit_str = f"({ probe .si_units } )"
172- ax .set_xlabel (f"x { unit_str } " , fontsize = 15 )
173- ax .set_ylabel (f"y { unit_str } " , fontsize = 15 )
184+ ax .set_zlim (zlims )
185+ ax .set_zlabel ("z" )
174186
175- if probe .ndim == 3 :
176- ax .set_zlim (zlims )
177- ax .set_zlabel ("z" )
178-
179- if probe .ndim == 2 :
180- ax .set_aspect ("equal" )
187+ if probe .ndim == 2 :
188+ ax .set_aspect ("equal" )
181189
182- if title :
183- ax .set_title (probe .get_title ())
190+ if title :
191+ ax .set_title (probe .get_title ())
184192
185193 return poly , poly_contour
186194
0 commit comments