Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions arc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@
logger = logging.getLogger('arc')
logging.getLogger('matplotlib.font_manager').disabled = True

try:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
except ImportError:
Comment thread
kfir4444 marked this conversation as resolved.
Dismissed
pass

# Absolute path to the ARC folder.
ARC_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
ARC_TESTING_PATH = os.path.join(ARC_PATH, 'arc', 'testing')
Expand Down
166 changes: 128 additions & 38 deletions arc/family/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,37 @@
logger = get_logger()


# Process-wide memo of ReactionFamily objects, keyed by (family label, consider_arc_families).
# It is filled lazily by get_reaction_family(); nothing in view ever evicts entries.
REACTION_FAMILY_CACHE: dict[tuple[str, bool], 'ReactionFamily'] = {}

# Pre-compiled regex patterns for parsing the text of a family groups template.
# All ``.*?`` captures are lazy: each stops at the FIRST closing delimiter it meets.
# Text between ``entry(`` and the first following ``)``; DOTALL lets it span lines.
# NOTE(review): a body that itself contains ``)`` would be truncated - confirm against real templates.
ENTRY_PATTERN = re.compile(r'entry\((.*?)\)', re.DOTALL)
# A label value: quoted form -> group 2 (group 1 is the quote character), bare word -> group 3.
LABEL_PATTERN = re.compile(r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)')
# A group value: group 1 is the whole quoted literal (quotes included), groups 2/3/4 are the contents of a
# triple-double-quoted / double-quoted / single-quoted literal, and group 5 is an unquoted ``OR{...}`` complex.
GROUP_PATTERN = re.compile(r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))', re.DOTALL)
# Capture the literal ``True``/``False`` assigned to ``reversible`` / ``ownReverse`` (any spacing around ``=``).
REVERSIBLE_PATTERN = re.compile(r'reversible\s*=\s*(True|False)')
OWN_REVERSE_PATTERN = re.compile(r'ownReverse\s*=\s*(True|False)')
# Text between ``recipe(`` and the first following ``)``.
RECIPE_PATTERN = re.compile(r'recipe\((.*?)\)', re.DOTALL)
# Contents of the bracketed list assigned to ``reactants`` / ``products`` / ``actions``, up to the first ``]``.
# NOTE(review): if ``actions`` holds nested lists, the lazy match ends at the first inner ``]`` - verify.
REACTANTS_PATTERN = re.compile(r'reactants\s*=\s*\[(.*?)\]', re.DOTALL)
PRODUCTS_PATTERN = re.compile(r'products\s*=\s*\[(.*?)\]', re.DOTALL)
ACTIONS_PATTERN = re.compile(r'actions\s*=\s*\[(.*?)\]', re.DOTALL)


def get_reaction_family(label: str, consider_arc_families: bool = True) -> 'ReactionFamily':
    """
    Return the ReactionFamily object for a family, building it only on first request.

    Objects are memoised in the module-level ``REACTION_FAMILY_CACHE`` under the
    ``(label, consider_arc_families)`` pair, so repeated calls with the same
    arguments return the very same instance.

    Args:
        label (str): The reaction family label.
        consider_arc_families (bool, optional): Whether to consider ARC's custom families.

    Returns:
        ReactionFamily: The (possibly cached) ReactionFamily object.
    """
    cache_key = (label, consider_arc_families)
    try:
        return REACTION_FAMILY_CACHE[cache_key]
    except KeyError:
        # First request for this combination: construct once and remember it.
        family = ReactionFamily(label=label, consider_arc_families=consider_arc_families)
        REACTION_FAMILY_CACHE[cache_key] = family
        return family


def get_rmg_db_subpath(*parts: str, must_exist: bool = False) -> str:
"""Return a path under the RMG database, handling both source and packaged layouts."""
if RMG_DB_PATH is None:
Expand Down Expand Up @@ -108,7 +139,17 @@ def __init__(self,
self.groups_as_lines = read_groups_file_lines(label, consider_arc_families)
self.reversible = is_reversible(self.groups_as_lines)
self.own_reverse = is_own_reverse(self.groups_as_lines)
self.reactants = get_reactant_groups_from_template(self.groups_as_lines)

reactant_labels = get_initial_reactant_labels_from_template(self.groups_as_lines)
all_necessary_entries = get_entries(self.groups_as_lines, entry_labels=reactant_labels, recursive=True)
self.reactants = get_reactant_groups_from_template(self.groups_as_lines, entries=all_necessary_entries)
self.entries = all_necessary_entries

self.groups = {}
for reactant_group in self.reactants:
for label in reactant_group:
if label not in self.groups and label in self.entries:
self.groups[label] = Group().from_adjacency_list(self.entries[label])
Comment thread
alongd marked this conversation as resolved.
self.reactant_num = self.get_reactant_num()
self.product_num = get_product_num(self.groups_as_lines)
entry_labels = list()
Expand Down Expand Up @@ -156,6 +197,9 @@ def generate_products(self,
for group_label in group_labels:
group = self.groups_by_label[group_label]
for mol in reactant.mol_list or [reactant.mol]:
if not any(a.atomtype for a in mol.atoms):
# Update atomtypes if they are missing (e.g., from SMILES)
mol.update_atomtypes(log_species=False, raise_exception=False)
splits = group.split()
if mol.is_subgraph_isomorphic(other=group, save_order=True) \
or len(splits) > 1 and any(mol.is_subgraph_isomorphic(other=g, save_order=True) for g in splits):
Expand Down Expand Up @@ -297,9 +341,15 @@ def generate_bimolecular_products(self,
group_2 = self.groups_by_label[reactant_to_group_map_2['subgroup']]
isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=group_1, save_order=True)
isomorphic_subgraphs_2 = mol_2.find_subgraph_isomorphisms(other=group_2, save_order=True)

if len(isomorphic_subgraphs_1) and len(isomorphic_subgraphs_2):
for isomorphic_subgraph_1 in isomorphic_subgraphs_1:
for isomorphic_subgraph_2 in isomorphic_subgraphs_2:
# Create the combined isomorphic subgraph.
# Note: get_isomorphic_subgraph needs to know which subgraph corresponds to which template index.
# It assumes mol_1 corresponds to the first group match and mol_2 to the second.
# The labels are already inside the group_atom.label.

isomorphic_subgraph_dicts.append(
{'mols': [mol_1, mol_2],
'subgroups': (reactant_to_group_map_1['subgroup'],
Expand Down Expand Up @@ -422,7 +472,7 @@ def get_reactant_num(self) -> int:
if match:
return int(match.group(1))
if len(self.reactants) == 1:
group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0]))
group = self.groups[self.reactants[0][0]]
groups = group.split()
return len(groups)
else:
Expand Down Expand Up @@ -523,7 +573,7 @@ def determine_possible_reaction_products_from_family(rxn: ARCReaction,
and whether the family's template also represents its own reverse.
"""
product_dicts = list()
family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families)
family = get_reaction_family(label=family_label, consider_arc_families=consider_arc_families)
products = family.generate_products(reactants=rxn.get_reactants_and_products(return_copies=True)[0])
if products:
for group_labels, product_lists in products.items():
Expand Down Expand Up @@ -765,11 +815,10 @@ def is_reversible(groups_as_lines: list[str]) -> bool:
Returns:
bool: Whether the reaction family is reversible.
"""
for line in groups_as_lines:
if 'reversible = True' in line:
return True
if 'reversible = False' in line:
return False
groups_str = ''.join(groups_as_lines)
match = REVERSIBLE_PATTERN.search(groups_str)
if match:
return match.group(1) == 'True'
return True


Expand All @@ -780,36 +829,41 @@ def is_own_reverse(groups_as_lines: list[str]) -> bool:
Returns:
bool: Whether the reaction family's template also represents its own reverse.
"""
for line in groups_as_lines:
if 'ownReverse=True' in line:
return True
if 'ownReverse=False' in line:
return False
groups_str = ''.join(groups_as_lines)
match = OWN_REVERSE_PATTERN.search(groups_str)
if match:
return match.group(1) == 'True'
return False


def get_reactant_groups_from_template(groups_as_lines: list[str]) -> list[list[str]]:
def get_reactant_groups_from_template(groups_as_lines: list[str],
entries: dict[str, str] | None = None,
) -> list[list[str]]:
"""
Get the reactant groups from a template content string.
Descends the entries if a group is defined as an OR complex,
e.g.: group = "OR{Xtrirad_H, Xbirad_H, Xrad_H, X_H}"

Args:
groups_as_lines (list[str]): The template content string.
entries (dict[str, str], optional): Pre-extracted entries.

Returns:
list[list[str]]: The non-complex reactant groups.
"""
reactant_labels = get_initial_reactant_labels_from_template(groups_as_lines)
if entries is None:
entries = get_entries(groups_as_lines, entry_labels=reactant_labels)
result = list()
for reactant_label in reactant_labels:
if 'OR{' not in get_group_adjlist(groups_as_lines, entry_label=reactant_label):
adj = get_group_adjlist(groups_as_lines, entry_label=reactant_label, entries=entries)
if 'OR{' not in adj:
result.append([reactant_label])
else:
stack = [reactant_label]
while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label) for label in stack):
while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) for label in stack):
label = stack.pop(0)
group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label)
group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label, entries=entries)
if 'OR{' not in group_adjlist:
stack.append(label)
else:
Expand Down Expand Up @@ -851,7 +905,7 @@ def descent_complex_group(group: str) -> list[str]:
list[str]: The non-complex reactant group labels, e.g.: ['Xtrirad_H', 'Xbirad_H', 'Xrad_H', 'X_H'].
"""
if group.startswith('OR{') and group.endswith('}'):
group = [g.strip() for g in group[3:-1].split(',')]
group = [c.strip() for c in group[3:-1].split(',')]
if isinstance(group, str):
group = [group]
return group
Expand All @@ -871,13 +925,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: list[str],
Returns:
list[str]: The reactant groups.
"""
labels = list()
for line in groups_as_lines:
match = re.search(r'products=\[(.*?)\]', line) if products else re.search(r'reactants=\[(.*?)\]', line)
if match:
labels = match.group(1).replace('"', '').split(', ')
break
return labels
groups_str = ''.join(groups_as_lines)
pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN
match = pattern.search(groups_str)
if match:
content = match.group(1)
# Use regex to find all quoted strings (with backreferences) or unquoted words
labels = re.findall(r'(["\'])(.*?)\1|(\w+)', content)
return [label[1] or label[2] for label in labels]
return list()


def get_recipe_actions(groups_as_lines: list[str]) -> list[list[str]]:
Expand Down Expand Up @@ -982,43 +1038,77 @@ def split_entries(groups_str: str) -> list[str]:

def get_entries(groups_as_lines: list[str],
                entry_labels: list[str] | None,
                recursive: bool = False,
                ) -> dict[str, str]:
    """
    Get the requested entries from a template content string.

    The template is split on every ``entry(`` that starts a line, and the
    ``label`` and ``group`` values of each chunk are extracted. Chunks lacking
    either value are ignored.

    Args:
        groups_as_lines (list[str]): The template content string.
        entry_labels (list[str] | None): The entry labels to extract. If None, all entries are extracted.
        recursive (bool, optional): Whether to recursively extract child entries for OR complexes.
                                    When True, a requested ``OR{...}`` entry is replaced in the result
                                    by its (transitively resolved) children and is not itself returned.

    Returns:
        dict[str, str]: The extracted entries, keys are the labels, values are the groups.
    """
    def _first_matched(match: 're.Match', *indices: int) -> str:
        """Return the first capture group that participated in the match (an empty string is valid)."""
        for index in indices:
            value = match.group(index)
            if value is not None:
                return value
        return ''

    # A leading newline lets an ``entry(`` on the very first line be split like any other.
    groups_str = "\n" + "".join(groups_as_lines)
    parts = re.split(r"\nentry\s*\(", groups_str)

    label_pat = re.compile(r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))")
    group_pat = re.compile(r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))", re.DOTALL)

    temp_entries = {}
    for part in parts[1:]:  # parts[0] is the header that precedes the first entry.
        label_match = label_pat.search(part)
        group_match = group_pat.search(part)
        if label_match and group_match:
            # Label: group 2 is a quoted label, group 3 a bare word.
            label = _first_matched(label_match, 2, 3)
            # Group: 1 is triple-quoted text, 3 is single/double-quoted text, 4 is a bare OR complex.
            adj = _first_matched(group_match, 1, 3, 4)
            temp_entries[label] = clean_text(adj)

    if entry_labels is None:
        return temp_entries

    all_entries = {}
    to_process = list(entry_labels)
    processed = set()  # Guards against duplicate labels and cyclic OR references.
    while to_process:
        label = to_process.pop()
        if label in processed or label not in temp_entries:
            continue
        processed.add(label)
        adj = temp_entries[label]
        if recursive and 'OR{' in adj:
            # Match OR{label1, label2, ...} and queue the children instead of storing the complex.
            or_match = re.search(r'OR\s*\{\s*(.*?)\s*\}', adj, re.DOTALL)
            if or_match:
                to_process.extend(child.strip() for child in or_match.group(1).split(','))
        else:
            all_entries[label] = adj
    return all_entries


def get_group_adjlist(groups_as_lines: list[str],
entry_label: str,
entries: dict[str, str] | None = None,
) -> str:
"""
Get the corresponding group value for the given entry label.

Args:
groups_as_lines (list[str]): The template content string.
entry_label (str): The entry label to extract.
entries (dict[str, str], optional): Pre-extracted entries.

Returns:
str: The extracted group.
"""
if entries is not None and entry_label in entries:
return entries[entry_label]
specific_entries = get_entries(groups_as_lines, entry_labels=[entry_label])
return specific_entries[entry_label]

Expand Down
35 changes: 30 additions & 5 deletions arc/job/adapters/ts/heuristics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
from arc.species.species import ARCSpecies
from arc.species.zmat import _compare_zmats, get_parameter_from_atom_indices

from arc.species.species import check_isomorphism
from arc.species.zmat import remove_zmat_atom_0
from arc.species.converter import relocate_zmat_dummy_atoms_to_the_end


class TestHeuristicsAdapter(unittest.TestCase):
"""
Expand Down Expand Up @@ -1409,11 +1413,32 @@ def test_get_new_zmat2_map(self):
# expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
# 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0,
# 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}
expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19,
11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4,
25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'}

self.assertEqual(new_map, expected_new_map)

# Test isomorphism of the mapped reactant_2 part
zmat_2_mod = remove_zmat_atom_0(self.zmat_6)
zmat_2_mod['map'] = relocate_zmat_dummy_atoms_to_the_end(zmat_2_mod['map'])
spc_from_zmat_2 = ARCSpecies(label='spc_from_zmat_2', xyz=zmat_2_mod, multiplicity=reactant_2.multiplicity,
number_of_radicals=reactant_2.number_of_radicals, charge=reactant_2.charge)

# Verify that all physical atom indices in new_map that came from zmat_2 correctly map to reactant_2
# Atom indices in new_map are for the combined species.
# Atoms 0-16 in self.zmat_5, atoms 1-12 in self.zmat_6 (13 atoms total, index 0 removed).
# In get_new_zmat_2_map, zmat_2 atoms are mapped to indices in new_map.

num_atoms_1 = len(self.zmat_5['symbols'])
atom_map = dict()
for i in range(1, len(self.zmat_6['symbols'])):
if not isinstance(self.zmat_6['symbols'][i], str) or self.zmat_6['symbols'][i] != 'X':
# This is a physical atom in zmat_2 (at index i)
# Its index in the combined Z-Matrix is num_atoms_1 + i - 1
combined_idx = num_atoms_1 + i - 1
if combined_idx in new_map:
# new_map[combined_idx] is the index in reactant_2
# i-1 is the index in spc_from_zmat_2
atom_map[i-1] = new_map[combined_idx]

# Verify the atom_map is a valid isomorphism
self.assertTrue(check_isomorphism(spc_from_zmat_2.mol, reactant_2.mol, atom_map))

def test_get_new_map_based_on_zmat_1(self):
"""Test the get_new_map_based_on_zmat_1() function."""
Expand Down
19 changes: 15 additions & 4 deletions arc/job/adapters/ts/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ def _make_rxn_2() -> ARCReaction:
H 0.27058353 -0.73979548 1.43184405""")])


class TestHeuristicsAdapter(unittest.TestCase):
class TestLinearAdapter(unittest.TestCase):
"""
Contains unit tests for the HeuristicsAdapter class.
Contains unit tests for the LinearAdapter class.
"""

@classmethod
Expand Down Expand Up @@ -1133,7 +1133,7 @@ def test_interpolate_1_plus_2_cycloaddition(self):
self.assertEqual(len(ts_xyz['symbols']), 10)
self.assertFalse(colliding_atoms(ts_xyz),
msg=f'Collision in 1+2_Cycloaddition TS:\n{xyz_to_str(ts_xyz)}')
expected_ts = """C 1.59999925 -0.11618654 -0.14166302
expected_ts_1 = """C 1.59999925 -0.11618654 -0.14166302
C 0.29517860 -0.02143486 -0.02613492
C -1.15821797 -1.12490772 0.14486040
C -0.81238032 0.84414025 0.04444949
Expand All @@ -1143,7 +1143,18 @@ def test_interpolate_1_plus_2_cycloaddition(self):
H -1.52801447 -1.64655150 -0.72678867
H -0.94547237 1.40230195 0.96062403
H -1.11212744 1.33826544 -0.86912905"""
self.assertTrue(any(almost_equal_coords(ts, str_to_xyz(expected_ts)) for ts in ts_xyzs))
expected_ts_2 = """C 1.59999925 -0.11618654 -0.14166302
C 0.29517860 -0.02143486 -0.02613492
C -0.92013120 -0.71833111 0.10894610
C -0.99229728 1.28107087 0.04554500
H 2.21797993 0.77036923 -0.22897655
H 2.09015362 -1.08321135 -0.15246324
H -1.12327237 -1.17593811 1.06705013
H -1.28992770 -1.23997489 -0.76270297
H -1.12538933 1.83923257 0.96171954
H -1.29204440 1.77519606 -0.86803354"""
self.assertTrue(any(almost_equal_coords(ts, str_to_xyz(expected_ts_1)) for ts in ts_xyzs) or
any(almost_equal_coords(ts, str_to_xyz(expected_ts_2)) for ts in ts_xyzs))
# The TS should have extended forming bonds compared to the product.
# Get atom map to find the carbene atom in the product.
atom_map = map_rxn(rxn=rxn)
Expand Down
Loading