Skip to content

Commit 2ccc8a7

Browse files
v0.2.11 (#55)
* bump version * empty molecule edge cases * rename `unwrap_structures` * update docstring * add `scale` parameter * add TODO * simplify unwrap code * remove tests that did nothing * test cyclic molecules * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address comments * include node mapping for non-sequential data * add test * add one more test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 89779c4 commit 2ccc8a7

11 files changed

Lines changed: 386 additions & 177 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "rdkit2ase"
3-
version = "0.1.10"
3+
version = "0.1.11"
44
description = "Interface between rdkit and ASE"
55
readme = "README.md"
66
license = "Apache-2.0"

rdkit2ase/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from rdkit2ase.rdkit2x import rdkit2ase, rdkit2networkx
66
from rdkit2ase.smiles2x import smiles2atoms, smiles2conformers
77
from rdkit2ase.substructure import get_substructures, iter_fragments, match_substructure
8+
from rdkit2ase.utils import unwrap_structures
89

910
__all__ = [
1011
"ase2rdkit",
@@ -15,6 +16,7 @@
1516
"match_substructure",
1617
"get_substructures",
1718
"iter_fragments",
19+
"unwrap_structures",
1820
#
1921
"ase2networkx",
2022
#

rdkit2ase/ase2x.py

Lines changed: 97 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,89 @@
1212
vesin = None
1313

1414

15+
def _create_graph_from_connectivity(
16+
atoms: ase.Atoms, connectivity, charges
17+
) -> nx.Graph:
18+
"""Create NetworkX graph from explicit connectivity information."""
19+
graph = nx.Graph()
20+
graph.graph["pbc"] = atoms.pbc
21+
graph.graph["cell"] = atoms.cell
22+
23+
for i, atom in enumerate(atoms):
24+
graph.add_node(
25+
i,
26+
position=atom.position,
27+
atomic_number=atom.number,
28+
original_index=atom.index,
29+
charge=charges[i],
30+
)
31+
32+
for i, j, bond_order in connectivity:
33+
graph.add_edge(i, j, bond_order=bond_order)
34+
return graph
35+
36+
37+
def _compute_connectivity_matrix(atoms: ase.Atoms, scale: float, pbc: bool):
38+
"""Compute connectivity matrix from distance-based cutoffs."""
39+
# non-bonding positive charged atoms / ions.
40+
non_bonding_atomic_numbers = {3, 11, 19, 37, 55, 87}
41+
42+
atomic_numbers = atoms.get_atomic_numbers()
43+
excluded_mask = np.isin(atomic_numbers, list(non_bonding_atomic_numbers))
44+
45+
atom_radii = np.array(natural_cutoffs(atoms, mult=scale))
46+
pairwise_cutoffs = atom_radii[:, None] + atom_radii[None, :]
47+
max_cutoff = np.max(pairwise_cutoffs)
48+
49+
if vesin is not None:
50+
i, j, d, s = vesin.ase_neighbor_list(
51+
"ijdS", atoms, cutoff=max_cutoff, self_interaction=False
52+
)
53+
else:
54+
i, j, d, s = neighbor_list(
55+
"ijdS", atoms, cutoff=max_cutoff, self_interaction=False
56+
)
57+
58+
# If pbc=False, filter out bonds that cross periodic boundaries
59+
if not pbc:
60+
non_periodic_mask = np.all(s == 0, axis=1)
61+
i = i[non_periodic_mask]
62+
j = j[non_periodic_mask]
63+
d = d[non_periodic_mask]
64+
65+
d_ij = np.full((len(atoms), len(atoms)), np.inf)
66+
d_ij[i, j] = d
67+
np.fill_diagonal(d_ij, 0.0)
68+
69+
# mask out non-bonding atoms
70+
d_ij[excluded_mask, :] = np.inf
71+
d_ij[:, excluded_mask] = np.inf
72+
73+
connectivity_matrix = np.zeros((len(atoms), len(atoms)), dtype=int)
74+
np.fill_diagonal(d_ij, np.inf)
75+
connectivity_matrix[d_ij <= pairwise_cutoffs] = 1
76+
77+
return connectivity_matrix, non_bonding_atomic_numbers
78+
79+
80+
def _add_node_properties(
81+
graph: nx.Graph, atoms: ase.Atoms, charges, non_bonding_atomic_numbers
82+
):
83+
"""Add node properties to the graph."""
84+
for i, atom in enumerate(atoms):
85+
graph.nodes[i]["position"] = atom.position
86+
graph.nodes[i]["atomic_number"] = atom.number
87+
graph.nodes[i]["original_index"] = atom.index
88+
graph.nodes[i]["charge"] = float(charges[i])
89+
if atom.number in non_bonding_atomic_numbers:
90+
graph.nodes[i]["charge"] = 1.0
91+
92+
1593
def ase2networkx(
16-
atoms: ase.Atoms, suggestions: list[str] | None = None, pbc: bool = True
94+
atoms: ase.Atoms,
95+
suggestions: list[str] | None = None,
96+
pbc: bool = True,
97+
scale: float = 1.2,
1798
) -> nx.Graph:
1899
"""Convert an ASE Atoms object to a NetworkX graph with bonding information.
19100
@@ -33,6 +114,9 @@ def ase2networkx(
33114
Whether to consider periodic boundary conditions when calculating
34115
distances (default is True). If False, only connections within
35116
the unit cell are considered.
117+
scale : float, optional
118+
Scaling factor for the covalent radii when determining bond cutoffs
119+
(default is 1.2).
36120
37121
Returns
38122
-------
@@ -70,85 +154,25 @@ def ase2networkx(
70154
>>> len(graph.edges)
71155
2
72156
"""
157+
if len(atoms) == 0:
158+
return nx.Graph()
159+
73160
charges = atoms.get_initial_charges()
74161

75162
if "connectivity" in atoms.info:
76-
connectivity = atoms.info["connectivity"]
77-
graph = nx.Graph()
78-
79-
graph.graph["pbc"] = atoms.pbc
80-
graph.graph["cell"] = atoms.cell
81-
82-
for i, atom in enumerate(atoms):
83-
graph.add_node(
84-
i,
85-
position=atom.position,
86-
atomic_number=atom.number,
87-
original_index=atom.index,
88-
charge=charges[i],
89-
)
90-
91-
for i, j, bond_order in connectivity:
92-
graph.add_edge(
93-
i,
94-
j,
95-
bond_order=bond_order,
96-
)
97-
return graph
98-
99-
# non-bonding positive charged atoms / ions.
100-
non_bonding_atomic_numbers = {3, 11, 19, 37, 55, 87}
101-
102-
atomic_numbers = atoms.get_atomic_numbers()
103-
excluded_mask = np.isin(atomic_numbers, list(non_bonding_atomic_numbers))
104-
105-
atom_radii = np.array(natural_cutoffs(atoms, mult=1.2))
106-
pairwise_cutoffs = atom_radii[:, None] + atom_radii[None, :]
107-
108-
max_cutoff = np.max(pairwise_cutoffs)
109-
110-
if vesin is not None:
111-
i, j, d, s = vesin.ase_neighbor_list(
112-
"ijdS", atoms, cutoff=max_cutoff, self_interaction=False
163+
return _create_graph_from_connectivity(
164+
atoms, atoms.info["connectivity"], charges
113165
)
114-
else:
115-
i, j, d, s = neighbor_list(
116-
"ijdS", atoms, cutoff=max_cutoff, self_interaction=False
117-
)
118-
119-
# If pbc=False, filter out bonds that cross periodic boundaries
120-
if not pbc:
121-
# Keep only bonds where all shift vectors are zero (no periodic wrapping)
122-
non_periodic_mask = np.all(s == 0, axis=1)
123-
i = i[non_periodic_mask]
124-
j = j[non_periodic_mask]
125-
d = d[non_periodic_mask]
126-
127-
d_ij = np.full((len(atoms), len(atoms)), np.inf)
128-
d_ij[i, j] = d
129-
np.fill_diagonal(d_ij, 0.0)
130-
131-
# mask out non-bonding atoms
132-
d_ij[excluded_mask, :] = np.inf
133-
d_ij[:, excluded_mask] = np.inf
134166

135-
connectivity_matrix = np.zeros((len(atoms), len(atoms)), dtype=int)
136-
137-
np.fill_diagonal(d_ij, np.inf)
138-
139-
connectivity_matrix[d_ij <= pairwise_cutoffs] = 1
167+
connectivity_matrix, non_bonding_atomic_numbers = _compute_connectivity_matrix(
168+
atoms, scale, pbc
169+
)
140170

141171
graph = nx.from_numpy_array(connectivity_matrix, edge_attr=None)
142172
for u, v in graph.edges():
143173
graph.edges[u, v]["bond_order"] = None
144174

145-
for i, atom in enumerate(atoms):
146-
graph.nodes[i]["position"] = atom.position
147-
graph.nodes[i]["atomic_number"] = atom.number
148-
graph.nodes[i]["original_index"] = atom.index
149-
graph.nodes[i]["charge"] = float(charges[i])
150-
if atom.number in non_bonding_atomic_numbers:
151-
graph.nodes[i]["charge"] = 1.0
175+
_add_node_properties(graph, atoms, charges, non_bonding_atomic_numbers)
152176

153177
graph.graph["pbc"] = atoms.pbc
154178
graph.graph["cell"] = atoms.cell
@@ -188,6 +212,9 @@ def ase2rdkit(atoms: ase.Atoms, suggestions: list[str] | None = None) -> Chem.Mo
188212
>>> mol.GetNumAtoms()
189213
4
190214
"""
215+
if len(atoms) == 0:
216+
return Chem.Mol()
217+
191218
from rdkit2ase import ase2networkx, networkx2rdkit
192219

193220
graph = ase2networkx(atoms, suggestions=suggestions)

rdkit2ase/bond_order.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import networkx as nx
22
from networkx.algorithms import isomorphism
33

4-
from rdkit2ase.utils import rdkit_determine_bonds, suggestions2networkx, unwrap_molecule
4+
from rdkit2ase.utils import (
5+
rdkit_determine_bonds,
6+
suggestions2networkx,
7+
unwrap_structures,
8+
)
59

610

711
def sort_templates(graphs: list[nx.Graph]) -> list[nx.Graph]:
@@ -99,7 +103,7 @@ def update_bond_order_determine(graph: nx.Graph) -> None:
99103
if missing > 0:
100104
# Unwrapping could be made nicer, by utilizing the connectivity
101105
atoms = networkx2ase(subgraph)
102-
atoms = unwrap_molecule(atoms)
106+
atoms = unwrap_structures(atoms)
103107
rdkit_mol = rdkit_determine_bonds(atoms)
104108
rdkit_graph = rdkit2networkx(rdkit_mol)
105109
# update the bond order in the original graph

rdkit2ase/networkx2x.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def networkx2ase(graph: nx.Graph) -> ase.Atoms:
4343
>>> len(atoms)
4444
2
4545
"""
46+
# Create mapping from original node indices to new sequential indices
47+
node_mapping = {node: i for i, node in enumerate(graph.nodes)}
48+
4649
positions = np.array([graph.nodes[n]["position"] for n in graph.nodes])
4750
numbers = np.array([graph.nodes[n]["atomic_number"] for n in graph.nodes])
4851
charges = np.array([graph.nodes[n]["charge"] for n in graph.nodes])
@@ -58,7 +61,10 @@ def networkx2ase(graph: nx.Graph) -> ase.Atoms:
5861
connectivity = []
5962
for u, v, data in graph.edges(data=True):
6063
bond_order = data["bond_order"]
61-
connectivity.append((u, v, bond_order))
64+
# Map original node indices to new sequential indices
65+
new_u = node_mapping[u]
66+
new_v = node_mapping[v]
67+
connectivity.append((new_u, new_v, bond_order))
6268

6369
atoms.info["connectivity"] = connectivity
6470

rdkit2ase/rdkit2x.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def rdkit2ase(mol: Chem.Mol, seed: int = 42) -> ase.Atoms:
3333
>>> len(atoms)
3434
9
3535
"""
36+
if mol.GetNumAtoms() == 0:
37+
return ase.Atoms()
38+
3639
smiles = Chem.MolToSmiles(mol)
3740
mol = Chem.AddHs(mol)
3841
charges = [atom.GetFormalCharge() for atom in mol.GetAtoms()]

0 commit comments

Comments
 (0)