@@ -30,14 +30,14 @@ def _detect_graph_type(graph: Any) -> str:
3030 # Check class name to determine type
3131 class_name = type (graph ).__name__ .lower ()
3232
33- if ' multidigraph' in class_name :
34- return ' multidigraph'
35- elif ' multigraph' in class_name :
36- return ' multigraph'
37- elif ' digraph' in class_name :
38- return ' digraph'
33+ if " multidigraph" in class_name :
34+ return " multidigraph"
35+ elif " multigraph" in class_name :
36+ return " multigraph"
37+ elif " digraph" in class_name :
38+ return " digraph"
3939 else :
40- return ' graph'
40+ return " graph"
4141
4242 @staticmethod
4343 def _extract_nodes (
@@ -82,7 +82,7 @@ def _extract_nodes(
8282 y = float (y ),
8383 color = color ,
8484 label = label ,
85- metadata = node_attrs
85+ metadata = node_attrs ,
8686 )
8787
8888 nodes .append (node )
@@ -106,9 +106,9 @@ def _extract_edges(
106106 # Detect graph type and dispatch to appropriate extractor
107107 graph_type = NetworkXAdapter ._detect_graph_type (graph )
108108
109- if graph_type in (' multigraph' , ' multidigraph' ):
109+ if graph_type in (" multigraph" , " multidigraph" ):
110110 return NetworkXAdapter ._expand_multigraph_edges (graph , edge_label )
111- elif graph_type == ' digraph' :
111+ elif graph_type == " digraph" :
112112 return NetworkXAdapter ._extract_edges_digraph (graph , edge_label )
113113 else :
114114 # Basic Graph type
@@ -142,12 +142,7 @@ def _extract_edges_simple(
142142 label = NetworkXAdapter ._map_edge_label (edge_attrs , edge_label )
143143
144144 # Create Edge object
145- edge = Edge (
146- source = source_str ,
147- target = target_str ,
148- label = label ,
149- metadata = edge_attrs
150- )
145+ edge = Edge (source = source_str , target = target_str , label = label , metadata = edge_attrs )
151146
152147 edges .append (edge )
153148
@@ -178,18 +173,13 @@ def _extract_edges_digraph(
178173 edge_attrs = dict (graph [source ][target ]) if graph [source ][target ] else {}
179174
180175 # Add direction indicator to metadata for DiGraph
181- edge_attrs [' directed' ] = True
176+ edge_attrs [" directed" ] = True
182177
183178 # Apply label mapping
184179 label = NetworkXAdapter ._map_edge_label (edge_attrs , edge_label )
185180
186181 # Create Edge object
187- edge = Edge (
188- source = source_str ,
189- target = target_str ,
190- label = label ,
191- metadata = edge_attrs
192- )
182+ edge = Edge (source = source_str , target = target_str , label = label , metadata = edge_attrs )
193183
194184 edges .append (edge )
195185
@@ -216,7 +206,7 @@ def _expand_multigraph_edges(
216206
217207 # Check if this is a directed multigraph
218208 graph_type = NetworkXAdapter ._detect_graph_type (graph )
219- is_directed = graph_type == ' multidigraph'
209+ is_directed = graph_type == " multidigraph"
220210
221211 # MultiGraph.edges() returns (source, target, key) tuples
222212 for source , target , key in graph .edges (keys = True ):
@@ -228,22 +218,17 @@ def _expand_multigraph_edges(
228218 edge_attrs = dict (graph [source ][target ][key ]) if graph [source ][target ][key ] else {}
229219
230220 # Preserve edge key in metadata
231- edge_attrs [' edge_key' ] = key
221+ edge_attrs [" edge_key" ] = key
232222
233223 # Add direction indicator for MultiDiGraph
234224 if is_directed :
235- edge_attrs [' directed' ] = True
225+ edge_attrs [" directed" ] = True
236226
237227 # Apply label mapping
238228 label = NetworkXAdapter ._map_edge_label (edge_attrs , edge_label )
239229
240230 # Create Edge object
241- edge = Edge (
242- source = source_str ,
243- target = target_str ,
244- label = label ,
245- metadata = edge_attrs
246- )
231+ edge = Edge (source = source_str , target = target_str , label = label , metadata = edge_attrs )
247232
248233 edges .append (edge )
249234
@@ -264,8 +249,8 @@ def _get_existing_positions(graph: Any) -> dict[Any, Any] | None:
264249
265250 for node_id in graph .nodes ():
266251 node_data = graph .nodes [node_id ]
267- if ' pos' in node_data :
268- positions [node_id ] = node_data [' pos' ]
252+ if " pos" in node_data :
253+ positions [node_id ] = node_data [" pos" ]
269254 has_positions = True
270255
271256 return positions if has_positions else None
@@ -299,8 +284,7 @@ def _apply_kamada_kawai_layout(graph: Any) -> dict[Any, Any]:
299284 import scipy # type: ignore[import-not-found] # noqa: F401
300285 except ImportError :
301286 raise ImportError (
302- "Layout 'kamada_kawai' requires scipy. "
303- "Install with: pip install net_vis[full]"
287+ "Layout 'kamada_kawai' requires scipy. Install with: pip install net_vis[full]"
304288 )
305289
306290 return nx .kamada_kawai_layout (graph )
@@ -322,8 +306,7 @@ def _apply_spectral_layout(graph: Any) -> dict[Any, Any]:
322306 import scipy # type: ignore[import-not-found] # noqa: F401
323307 except ImportError :
324308 raise ImportError (
325- "Layout 'spectral' requires scipy. "
326- "Install with: pip install net_vis[full]"
309+ "Layout 'spectral' requires scipy. Install with: pip install net_vis[full]"
327310 )
328311
329312 return nx .spectral_layout (graph )
@@ -383,10 +366,7 @@ def _validate_positions(positions: dict[Any, Any]) -> bool:
383366 return True
384367
385368 @staticmethod
386- def _compute_layout (
387- graph : Any ,
388- layout : str | Callable | None = None
389- ) -> dict [Any , Any ]:
369+ def _compute_layout (graph : Any , layout : str | Callable | None = None ) -> dict [Any , Any ]:
390370 """Compute node positions using specified layout algorithm.
391371
392372 Args:
@@ -423,15 +403,15 @@ def _compute_layout(
423403 # Named layout algorithm
424404 layout_str = str (layout ).lower ()
425405 try :
426- if layout_str == ' spring' :
406+ if layout_str == " spring" :
427407 positions = NetworkXAdapter ._apply_spring_layout (graph )
428- elif layout_str == ' kamada_kawai' :
408+ elif layout_str == " kamada_kawai" :
429409 positions = NetworkXAdapter ._apply_kamada_kawai_layout (graph )
430- elif layout_str == ' spectral' :
410+ elif layout_str == " spectral" :
431411 positions = NetworkXAdapter ._apply_spectral_layout (graph )
432- elif layout_str == ' circular' :
412+ elif layout_str == " circular" :
433413 positions = NetworkXAdapter ._apply_circular_layout (graph )
434- elif layout_str == ' random' :
414+ elif layout_str == " random" :
435415 positions = NetworkXAdapter ._apply_random_layout (graph )
436416 else :
437417 warnings .warn (f"Unknown layout '{ layout } ', using spring layout" )
@@ -442,7 +422,9 @@ def _compute_layout(
442422
443423 # Validate positions
444424 if not NetworkXAdapter ._validate_positions (positions ):
445- warnings .warn ("Layout produced invalid positions (NaN/inf), falling back to random layout" )
425+ warnings .warn (
426+ "Layout produced invalid positions (NaN/inf), falling back to random layout"
427+ )
446428 positions = NetworkXAdapter ._apply_random_layout (graph )
447429
448430 return positions
@@ -495,13 +477,15 @@ def convert_graph(
495477 layer_id = "" , # Will be set by Plotter
496478 nodes = nodes ,
497479 edges = edges ,
498- metadata = {"graph_type" : graph_type }
480+ metadata = {"graph_type" : graph_type },
499481 )
500482
501483 return layer
502484
503485 @staticmethod
504- def _map_node_color (node_id : Any , node_data : dict , mapping : str | Callable | None ) -> str | None :
486+ def _map_node_color (
487+ node_id : Any , node_data : dict , mapping : str | Callable | None
488+ ) -> str | None :
505489 """Map node attribute to color value.
506490
507491 Args:
@@ -528,7 +512,9 @@ def _map_node_color(node_id: Any, node_data: dict, mapping: str | Callable | Non
528512 return str (value ) if value is not None else None
529513
530514 @staticmethod
531- def _map_node_label (node_id : Any , node_data : dict , mapping : str | Callable | None ) -> str | None :
515+ def _map_node_label (
516+ node_id : Any , node_data : dict , mapping : str | Callable | None
517+ ) -> str | None :
532518 """Map node attribute to label value.
533519
534520 Args:
@@ -602,8 +588,8 @@ def _detect_color_type(values: list) -> str:
602588
603589 # If majority are numeric, treat as numeric
604590 if total_count > 0 and numeric_count / total_count > 0.5 :
605- return ' numeric'
606- return ' categorical'
591+ return " numeric"
592+ return " categorical"
607593
608594 @staticmethod
609595 def _apply_continuous_color_scale (value : float , min_val : float , max_val : float ) -> str :
@@ -645,8 +631,16 @@ def _apply_categorical_color_palette(category: str) -> str:
645631 """
646632 # D3.js Category10 palette
647633 palette = [
648- "#1f77b4" , "#ff7f0e" , "#2ca02c" , "#d62728" , "#9467bd" ,
649- "#8c564b" , "#e377c2" , "#7f7f7f" , "#bcbd22" , "#17becf"
634+ "#1f77b4" ,
635+ "#ff7f0e" ,
636+ "#2ca02c" ,
637+ "#d62728" ,
638+ "#9467bd" ,
639+ "#8c564b" ,
640+ "#e377c2" ,
641+ "#7f7f7f" ,
642+ "#bcbd22" ,
643+ "#17becf" ,
650644 ]
651645
652646 # Use hash of category string to select color
0 commit comments