1212 vesin = None
1313
1414
15+ def _create_graph_from_connectivity (
16+ atoms : ase .Atoms , connectivity , charges
17+ ) -> nx .Graph :
18+ """Create NetworkX graph from explicit connectivity information."""
19+ graph = nx .Graph ()
20+ graph .graph ["pbc" ] = atoms .pbc
21+ graph .graph ["cell" ] = atoms .cell
22+
23+ for i , atom in enumerate (atoms ):
24+ graph .add_node (
25+ i ,
26+ position = atom .position ,
27+ atomic_number = atom .number ,
28+ original_index = atom .index ,
29+ charge = charges [i ],
30+ )
31+
32+ for i , j , bond_order in connectivity :
33+ graph .add_edge (i , j , bond_order = bond_order )
34+ return graph
35+
36+
37+ def _compute_connectivity_matrix (atoms : ase .Atoms , scale : float , pbc : bool ):
38+ """Compute connectivity matrix from distance-based cutoffs."""
39+ # non-bonding positive charged atoms / ions.
40+ non_bonding_atomic_numbers = {3 , 11 , 19 , 37 , 55 , 87 }
41+
42+ atomic_numbers = atoms .get_atomic_numbers ()
43+ excluded_mask = np .isin (atomic_numbers , list (non_bonding_atomic_numbers ))
44+
45+ atom_radii = np .array (natural_cutoffs (atoms , mult = scale ))
46+ pairwise_cutoffs = atom_radii [:, None ] + atom_radii [None , :]
47+ max_cutoff = np .max (pairwise_cutoffs )
48+
49+ if vesin is not None :
50+ i , j , d , s = vesin .ase_neighbor_list (
51+ "ijdS" , atoms , cutoff = max_cutoff , self_interaction = False
52+ )
53+ else :
54+ i , j , d , s = neighbor_list (
55+ "ijdS" , atoms , cutoff = max_cutoff , self_interaction = False
56+ )
57+
58+ # If pbc=False, filter out bonds that cross periodic boundaries
59+ if not pbc :
60+ non_periodic_mask = np .all (s == 0 , axis = 1 )
61+ i = i [non_periodic_mask ]
62+ j = j [non_periodic_mask ]
63+ d = d [non_periodic_mask ]
64+
65+ d_ij = np .full ((len (atoms ), len (atoms )), np .inf )
66+ d_ij [i , j ] = d
67+ np .fill_diagonal (d_ij , 0.0 )
68+
69+ # mask out non-bonding atoms
70+ d_ij [excluded_mask , :] = np .inf
71+ d_ij [:, excluded_mask ] = np .inf
72+
73+ connectivity_matrix = np .zeros ((len (atoms ), len (atoms )), dtype = int )
74+ np .fill_diagonal (d_ij , np .inf )
75+ connectivity_matrix [d_ij <= pairwise_cutoffs ] = 1
76+
77+ return connectivity_matrix , non_bonding_atomic_numbers
78+
79+
80+ def _add_node_properties (
81+ graph : nx .Graph , atoms : ase .Atoms , charges , non_bonding_atomic_numbers
82+ ):
83+ """Add node properties to the graph."""
84+ for i , atom in enumerate (atoms ):
85+ graph .nodes [i ]["position" ] = atom .position
86+ graph .nodes [i ]["atomic_number" ] = atom .number
87+ graph .nodes [i ]["original_index" ] = atom .index
88+ graph .nodes [i ]["charge" ] = float (charges [i ])
89+ if atom .number in non_bonding_atomic_numbers :
90+ graph .nodes [i ]["charge" ] = 1.0
91+
92+
1593def ase2networkx (
16- atoms : ase .Atoms , suggestions : list [str ] | None = None , pbc : bool = True
94+ atoms : ase .Atoms ,
95+ suggestions : list [str ] | None = None ,
96+ pbc : bool = True ,
97+ scale : float = 1.2 ,
1798) -> nx .Graph :
1899 """Convert an ASE Atoms object to a NetworkX graph with bonding information.
19100
@@ -33,6 +114,9 @@ def ase2networkx(
33114 Whether to consider periodic boundary conditions when calculating
34115 distances (default is True). If False, only connections within
35116 the unit cell are considered.
117+ scale : float, optional
118+ Scaling factor for the covalent radii when determining bond cutoffs
119+ (default is 1.2).
36120
37121 Returns
38122 -------
@@ -70,85 +154,25 @@ def ase2networkx(
70154 >>> len(graph.edges)
71155 2
72156 """
157+ if len (atoms ) == 0 :
158+ return nx .Graph ()
159+
73160 charges = atoms .get_initial_charges ()
74161
75162 if "connectivity" in atoms .info :
76- connectivity = atoms .info ["connectivity" ]
77- graph = nx .Graph ()
78-
79- graph .graph ["pbc" ] = atoms .pbc
80- graph .graph ["cell" ] = atoms .cell
81-
82- for i , atom in enumerate (atoms ):
83- graph .add_node (
84- i ,
85- position = atom .position ,
86- atomic_number = atom .number ,
87- original_index = atom .index ,
88- charge = charges [i ],
89- )
90-
91- for i , j , bond_order in connectivity :
92- graph .add_edge (
93- i ,
94- j ,
95- bond_order = bond_order ,
96- )
97- return graph
98-
99- # non-bonding positive charged atoms / ions.
100- non_bonding_atomic_numbers = {3 , 11 , 19 , 37 , 55 , 87 }
101-
102- atomic_numbers = atoms .get_atomic_numbers ()
103- excluded_mask = np .isin (atomic_numbers , list (non_bonding_atomic_numbers ))
104-
105- atom_radii = np .array (natural_cutoffs (atoms , mult = 1.2 ))
106- pairwise_cutoffs = atom_radii [:, None ] + atom_radii [None , :]
107-
108- max_cutoff = np .max (pairwise_cutoffs )
109-
110- if vesin is not None :
111- i , j , d , s = vesin .ase_neighbor_list (
112- "ijdS" , atoms , cutoff = max_cutoff , self_interaction = False
163+ return _create_graph_from_connectivity (
164+ atoms , atoms .info ["connectivity" ], charges
113165 )
114- else :
115- i , j , d , s = neighbor_list (
116- "ijdS" , atoms , cutoff = max_cutoff , self_interaction = False
117- )
118-
119- # If pbc=False, filter out bonds that cross periodic boundaries
120- if not pbc :
121- # Keep only bonds where all shift vectors are zero (no periodic wrapping)
122- non_periodic_mask = np .all (s == 0 , axis = 1 )
123- i = i [non_periodic_mask ]
124- j = j [non_periodic_mask ]
125- d = d [non_periodic_mask ]
126-
127- d_ij = np .full ((len (atoms ), len (atoms )), np .inf )
128- d_ij [i , j ] = d
129- np .fill_diagonal (d_ij , 0.0 )
130-
131- # mask out non-bonding atoms
132- d_ij [excluded_mask , :] = np .inf
133- d_ij [:, excluded_mask ] = np .inf
134166
135- connectivity_matrix = np .zeros ((len (atoms ), len (atoms )), dtype = int )
136-
137- np .fill_diagonal (d_ij , np .inf )
138-
139- connectivity_matrix [d_ij <= pairwise_cutoffs ] = 1
167+ connectivity_matrix , non_bonding_atomic_numbers = _compute_connectivity_matrix (
168+ atoms , scale , pbc
169+ )
140170
141171 graph = nx .from_numpy_array (connectivity_matrix , edge_attr = None )
142172 for u , v in graph .edges ():
143173 graph .edges [u , v ]["bond_order" ] = None
144174
145- for i , atom in enumerate (atoms ):
146- graph .nodes [i ]["position" ] = atom .position
147- graph .nodes [i ]["atomic_number" ] = atom .number
148- graph .nodes [i ]["original_index" ] = atom .index
149- graph .nodes [i ]["charge" ] = float (charges [i ])
150- if atom .number in non_bonding_atomic_numbers :
151- graph .nodes [i ]["charge" ] = 1.0
175+ _add_node_properties (graph , atoms , charges , non_bonding_atomic_numbers )
152176
153177 graph .graph ["pbc" ] = atoms .pbc
154178 graph .graph ["cell" ] = atoms .cell
@@ -188,6 +212,9 @@ def ase2rdkit(atoms: ase.Atoms, suggestions: list[str] | None = None) -> Chem.Mo
188212 >>> mol.GetNumAtoms()
189213 4
190214 """
215+ if len (atoms ) == 0 :
216+ return Chem .Mol ()
217+
191218 from rdkit2ase import ase2networkx , networkx2rdkit
192219
193220 graph = ase2networkx (atoms , suggestions = suggestions )
0 commit comments