Skip to content

Commit e196c23

Browse files
terapyonterapyon
authored andcommitted
fix ruff format
1 parent 8a6eb8d commit e196c23

4 files changed

Lines changed: 79 additions & 86 deletions

File tree

net_vis/adapters/networkx_adapter.py

Lines changed: 49 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def _detect_graph_type(graph: Any) -> str:
3030
# Check class name to determine type
3131
class_name = type(graph).__name__.lower()
3232

33-
if 'multidigraph' in class_name:
34-
return 'multidigraph'
35-
elif 'multigraph' in class_name:
36-
return 'multigraph'
37-
elif 'digraph' in class_name:
38-
return 'digraph'
33+
if "multidigraph" in class_name:
34+
return "multidigraph"
35+
elif "multigraph" in class_name:
36+
return "multigraph"
37+
elif "digraph" in class_name:
38+
return "digraph"
3939
else:
40-
return 'graph'
40+
return "graph"
4141

4242
@staticmethod
4343
def _extract_nodes(
@@ -82,7 +82,7 @@ def _extract_nodes(
8282
y=float(y),
8383
color=color,
8484
label=label,
85-
metadata=node_attrs
85+
metadata=node_attrs,
8686
)
8787

8888
nodes.append(node)
@@ -106,9 +106,9 @@ def _extract_edges(
106106
# Detect graph type and dispatch to appropriate extractor
107107
graph_type = NetworkXAdapter._detect_graph_type(graph)
108108

109-
if graph_type in ('multigraph', 'multidigraph'):
109+
if graph_type in ("multigraph", "multidigraph"):
110110
return NetworkXAdapter._expand_multigraph_edges(graph, edge_label)
111-
elif graph_type == 'digraph':
111+
elif graph_type == "digraph":
112112
return NetworkXAdapter._extract_edges_digraph(graph, edge_label)
113113
else:
114114
# Basic Graph type
@@ -142,12 +142,7 @@ def _extract_edges_simple(
142142
label = NetworkXAdapter._map_edge_label(edge_attrs, edge_label)
143143

144144
# Create Edge object
145-
edge = Edge(
146-
source=source_str,
147-
target=target_str,
148-
label=label,
149-
metadata=edge_attrs
150-
)
145+
edge = Edge(source=source_str, target=target_str, label=label, metadata=edge_attrs)
151146

152147
edges.append(edge)
153148

@@ -178,18 +173,13 @@ def _extract_edges_digraph(
178173
edge_attrs = dict(graph[source][target]) if graph[source][target] else {}
179174

180175
# Add direction indicator to metadata for DiGraph
181-
edge_attrs['directed'] = True
176+
edge_attrs["directed"] = True
182177

183178
# Apply label mapping
184179
label = NetworkXAdapter._map_edge_label(edge_attrs, edge_label)
185180

186181
# Create Edge object
187-
edge = Edge(
188-
source=source_str,
189-
target=target_str,
190-
label=label,
191-
metadata=edge_attrs
192-
)
182+
edge = Edge(source=source_str, target=target_str, label=label, metadata=edge_attrs)
193183

194184
edges.append(edge)
195185

@@ -216,7 +206,7 @@ def _expand_multigraph_edges(
216206

217207
# Check if this is a directed multigraph
218208
graph_type = NetworkXAdapter._detect_graph_type(graph)
219-
is_directed = graph_type == 'multidigraph'
209+
is_directed = graph_type == "multidigraph"
220210

221211
# MultiGraph.edges() returns (source, target, key) tuples
222212
for source, target, key in graph.edges(keys=True):
@@ -228,22 +218,17 @@ def _expand_multigraph_edges(
228218
edge_attrs = dict(graph[source][target][key]) if graph[source][target][key] else {}
229219

230220
# Preserve edge key in metadata
231-
edge_attrs['edge_key'] = key
221+
edge_attrs["edge_key"] = key
232222

233223
# Add direction indicator for MultiDiGraph
234224
if is_directed:
235-
edge_attrs['directed'] = True
225+
edge_attrs["directed"] = True
236226

237227
# Apply label mapping
238228
label = NetworkXAdapter._map_edge_label(edge_attrs, edge_label)
239229

240230
# Create Edge object
241-
edge = Edge(
242-
source=source_str,
243-
target=target_str,
244-
label=label,
245-
metadata=edge_attrs
246-
)
231+
edge = Edge(source=source_str, target=target_str, label=label, metadata=edge_attrs)
247232

248233
edges.append(edge)
249234

@@ -264,8 +249,8 @@ def _get_existing_positions(graph: Any) -> dict[Any, Any] | None:
264249

265250
for node_id in graph.nodes():
266251
node_data = graph.nodes[node_id]
267-
if 'pos' in node_data:
268-
positions[node_id] = node_data['pos']
252+
if "pos" in node_data:
253+
positions[node_id] = node_data["pos"]
269254
has_positions = True
270255

271256
return positions if has_positions else None
@@ -299,8 +284,7 @@ def _apply_kamada_kawai_layout(graph: Any) -> dict[Any, Any]:
299284
import scipy # type: ignore[import-not-found] # noqa: F401
300285
except ImportError:
301286
raise ImportError(
302-
"Layout 'kamada_kawai' requires scipy. "
303-
"Install with: pip install net_vis[full]"
287+
"Layout 'kamada_kawai' requires scipy. Install with: pip install net_vis[full]"
304288
)
305289

306290
return nx.kamada_kawai_layout(graph)
@@ -322,8 +306,7 @@ def _apply_spectral_layout(graph: Any) -> dict[Any, Any]:
322306
import scipy # type: ignore[import-not-found] # noqa: F401
323307
except ImportError:
324308
raise ImportError(
325-
"Layout 'spectral' requires scipy. "
326-
"Install with: pip install net_vis[full]"
309+
"Layout 'spectral' requires scipy. Install with: pip install net_vis[full]"
327310
)
328311

329312
return nx.spectral_layout(graph)
@@ -383,10 +366,7 @@ def _validate_positions(positions: dict[Any, Any]) -> bool:
383366
return True
384367

385368
@staticmethod
386-
def _compute_layout(
387-
graph: Any,
388-
layout: str | Callable | None = None
389-
) -> dict[Any, Any]:
369+
def _compute_layout(graph: Any, layout: str | Callable | None = None) -> dict[Any, Any]:
390370
"""Compute node positions using specified layout algorithm.
391371
392372
Args:
@@ -423,15 +403,15 @@ def _compute_layout(
423403
# Named layout algorithm
424404
layout_str = str(layout).lower()
425405
try:
426-
if layout_str == 'spring':
406+
if layout_str == "spring":
427407
positions = NetworkXAdapter._apply_spring_layout(graph)
428-
elif layout_str == 'kamada_kawai':
408+
elif layout_str == "kamada_kawai":
429409
positions = NetworkXAdapter._apply_kamada_kawai_layout(graph)
430-
elif layout_str == 'spectral':
410+
elif layout_str == "spectral":
431411
positions = NetworkXAdapter._apply_spectral_layout(graph)
432-
elif layout_str == 'circular':
412+
elif layout_str == "circular":
433413
positions = NetworkXAdapter._apply_circular_layout(graph)
434-
elif layout_str == 'random':
414+
elif layout_str == "random":
435415
positions = NetworkXAdapter._apply_random_layout(graph)
436416
else:
437417
warnings.warn(f"Unknown layout '{layout}', using spring layout")
@@ -442,7 +422,9 @@ def _compute_layout(
442422

443423
# Validate positions
444424
if not NetworkXAdapter._validate_positions(positions):
445-
warnings.warn("Layout produced invalid positions (NaN/inf), falling back to random layout")
425+
warnings.warn(
426+
"Layout produced invalid positions (NaN/inf), falling back to random layout"
427+
)
446428
positions = NetworkXAdapter._apply_random_layout(graph)
447429

448430
return positions
@@ -495,13 +477,15 @@ def convert_graph(
495477
layer_id="", # Will be set by Plotter
496478
nodes=nodes,
497479
edges=edges,
498-
metadata={"graph_type": graph_type}
480+
metadata={"graph_type": graph_type},
499481
)
500482

501483
return layer
502484

503485
@staticmethod
504-
def _map_node_color(node_id: Any, node_data: dict, mapping: str | Callable | None) -> str | None:
486+
def _map_node_color(
487+
node_id: Any, node_data: dict, mapping: str | Callable | None
488+
) -> str | None:
505489
"""Map node attribute to color value.
506490
507491
Args:
@@ -528,7 +512,9 @@ def _map_node_color(node_id: Any, node_data: dict, mapping: str | Callable | Non
528512
return str(value) if value is not None else None
529513

530514
@staticmethod
531-
def _map_node_label(node_id: Any, node_data: dict, mapping: str | Callable | None) -> str | None:
515+
def _map_node_label(
516+
node_id: Any, node_data: dict, mapping: str | Callable | None
517+
) -> str | None:
532518
"""Map node attribute to label value.
533519
534520
Args:
@@ -602,8 +588,8 @@ def _detect_color_type(values: list) -> str:
602588

603589
# If majority are numeric, treat as numeric
604590
if total_count > 0 and numeric_count / total_count > 0.5:
605-
return 'numeric'
606-
return 'categorical'
591+
return "numeric"
592+
return "categorical"
607593

608594
@staticmethod
609595
def _apply_continuous_color_scale(value: float, min_val: float, max_val: float) -> str:
@@ -645,8 +631,16 @@ def _apply_categorical_color_palette(category: str) -> str:
645631
"""
646632
# D3.js Category10 palette
647633
palette = [
648-
"#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
649-
"#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
634+
"#1f77b4",
635+
"#ff7f0e",
636+
"#2ca02c",
637+
"#d62728",
638+
"#9467bd",
639+
"#8c564b",
640+
"#e377c2",
641+
"#7f7f7f",
642+
"#bcbd22",
643+
"#17becf",
650644
]
651645

652646
# Use hash of category string to select color

net_vis/plotter.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,8 @@ def add_networkx(
141141
- NaN/inf positions trigger automatic fallback to random layout
142142
"""
143143
# Validate input is a NetworkX graph
144-
if not hasattr(graph, 'nodes') or not hasattr(graph, 'edges'):
145-
raise TypeError(
146-
f"Expected NetworkX graph object, got {type(graph).__name__}"
147-
)
144+
if not hasattr(graph, "nodes") or not hasattr(graph, "edges"):
145+
raise TypeError(f"Expected NetworkX graph object, got {type(graph).__name__}")
148146

149147
# Generate layer ID if not provided
150148
if layer_id is None:
@@ -187,8 +185,6 @@ def _repr_mimebundle_(self, include=None, exclude=None) -> dict:
187185
scene_dict = self._scene.to_dict()
188186

189187
return {
190-
"application/vnd.netvis+json": {
191-
"data": json.dumps(scene_dict)
192-
},
193-
"text/plain": f"<Plotter with {len(self._scene.layers)} layer(s)>"
188+
"application/vnd.netvis+json": {"data": json.dumps(scene_dict)},
189+
"text/plain": f"<Plotter with {len(self._scene.layers)} layer(s)>",
194190
}

0 commit comments

Comments
 (0)