diff --git a/arc/common.py b/arc/common.py index 6e8ef4f56b..ef54b430e8 100644 --- a/arc/common.py +++ b/arc/common.py @@ -36,6 +36,12 @@ logger = logging.getLogger('arc') logging.getLogger('matplotlib.font_manager').disabled = True +try: + from rdkit import RDLogger + RDLogger.DisableLog('rdApp.*') +except ImportError: + pass + # Absolute path to the ARC folder. ARC_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) ARC_TESTING_PATH = os.path.join(ARC_PATH, 'arc', 'testing') diff --git a/arc/family/family.py b/arc/family/family.py index 0eb2295009..13fb418bd7 100644 --- a/arc/family/family.py +++ b/arc/family/family.py @@ -24,6 +24,37 @@ logger = get_logger() +REACTION_FAMILY_CACHE: dict[tuple[str, bool], 'ReactionFamily'] = {} + +# Pre-compiled regex patterns +ENTRY_PATTERN = re.compile(r'entry\((.*?)\)', re.DOTALL) +LABEL_PATTERN = re.compile(r'label\s*=\s*(["\'])(.*?)\1|label\s*=\s*(\w+)') +GROUP_PATTERN = re.compile(r'group\s*=\s*(?:("""(.*?)"""|"(.*?)"|\'(.*?)\')|(OR\{.*?\}))', re.DOTALL) +REVERSIBLE_PATTERN = re.compile(r'reversible\s*=\s*(True|False)') +OWN_REVERSE_PATTERN = re.compile(r'ownReverse\s*=\s*(True|False)') +RECIPE_PATTERN = re.compile(r'recipe\((.*?)\)', re.DOTALL) +REACTANTS_PATTERN = re.compile(r'reactants\s*=\s*\[(.*?)\]', re.DOTALL) +PRODUCTS_PATTERN = re.compile(r'products\s*=\s*\[(.*?)\]', re.DOTALL) +ACTIONS_PATTERN = re.compile(r'actions\s*=\s*\[(.*?)\]', re.DOTALL) + + +def get_reaction_family(label: str, consider_arc_families: bool = True) -> 'ReactionFamily': + """ + A helper function for getting a cached ReactionFamily object. + + Args: + label (str): The reaction family label. + consider_arc_families (bool, optional): Whether to consider ARC's custom families. + + Returns: + ReactionFamily: The ReactionFamily object. + """ + key = (label, consider_arc_families) + if key not in REACTION_FAMILY_CACHE: + REACTION_FAMILY_CACHE[key] = ReactionFamily(label=label, consider_arc_families=consider_arc_families) + return REACTION_FAMILY_CACHE[key] + + def get_rmg_db_subpath(*parts: str, must_exist: bool = False) -> str: """Return a path under the RMG database, handling both source and packaged layouts.""" if RMG_DB_PATH is None: @@ -108,7 +139,17 @@ def __init__(self, self.groups_as_lines = read_groups_file_lines(label, consider_arc_families) self.reversible = is_reversible(self.groups_as_lines) self.own_reverse = is_own_reverse(self.groups_as_lines) - self.reactants = get_reactant_groups_from_template(self.groups_as_lines) + + reactant_labels = get_initial_reactant_labels_from_template(self.groups_as_lines) + all_necessary_entries = get_entries(self.groups_as_lines, entry_labels=reactant_labels, recursive=True) + self.reactants = get_reactant_groups_from_template(self.groups_as_lines, entries=all_necessary_entries) + self.entries = all_necessary_entries + + self.groups = {} + for reactant_group in self.reactants: + for label in reactant_group: + if label not in self.groups and label in self.entries: + self.groups[label] = Group().from_adjacency_list(self.entries[label]) self.reactant_num = self.get_reactant_num() self.product_num = get_product_num(self.groups_as_lines) entry_labels = list() @@ -156,6 +197,9 @@ def generate_products(self, for group_label in group_labels: group = self.groups_by_label[group_label] for mol in reactant.mol_list or [reactant.mol]: + if not any(a.atomtype for a in mol.atoms): + # Update atomtypes if they are missing (e.g., from SMILES) + mol.update_atomtypes(log_species=False, raise_exception=False) splits = group.split() if mol.is_subgraph_isomorphic(other=group, save_order=True) \ or len(splits) > 1 and any(mol.is_subgraph_isomorphic(other=g, save_order=True) for g in splits): @@ -297,9 +341,15 @@ def generate_bimolecular_products(self, group_2 = self.groups_by_label[reactant_to_group_map_2['subgroup']] isomorphic_subgraphs_1 = mol_1.find_subgraph_isomorphisms(other=group_1, save_order=True) isomorphic_subgraphs_2 = mol_2.find_subgraph_isomorphisms(other=group_2, save_order=True) + if len(isomorphic_subgraphs_1) and len(isomorphic_subgraphs_2): for isomorphic_subgraph_1 in isomorphic_subgraphs_1: for isomorphic_subgraph_2 in isomorphic_subgraphs_2: + # Create the combined isomorphic subgraph. + # Note: get_isomorphic_subgraph needs to know which subgraph corresponds to which template index. + # It assumes mol_1 corresponds to the first group match and mol_2 to the second. + # The labels are already inside the group_atom.label. + isomorphic_subgraph_dicts.append( {'mols': [mol_1, mol_2], 'subgroups': (reactant_to_group_map_1['subgroup'], @@ -422,7 +472,7 @@ def get_reactant_num(self) -> int: if match: return int(match.group(1)) if len(self.reactants) == 1: - group = Group().from_adjacency_list(get_group_adjlist(self.groups_as_lines, entry_label=self.reactants[0][0])) + group = self.groups[self.reactants[0][0]] groups = group.split() return len(groups) else: @@ -523,7 +573,7 @@ def determine_possible_reaction_products_from_family(rxn: ARCReaction, and whether the family's template also represents its own reverse. """ product_dicts = list() - family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families) + family = get_reaction_family(label=family_label, consider_arc_families=consider_arc_families) products = family.generate_products(reactants=rxn.get_reactants_and_products(return_copies=True)[0]) if products: for group_labels, product_lists in products.items(): @@ -765,11 +815,10 @@ def is_reversible(groups_as_lines: list[str]) -> bool: Returns: bool: Whether the reaction family is reversible. """ - for line in groups_as_lines: - if 'reversible = True' in line: - return True - if 'reversible = False' in line: - return False + groups_str = ''.join(groups_as_lines) + match = REVERSIBLE_PATTERN.search(groups_str) + if match: + return match.group(1) == 'True' return True @@ -780,15 +829,16 @@ def is_own_reverse(groups_as_lines: list[str]) -> bool: Returns: bool: Whether the reaction family's template also represents its own reverse. """ - for line in groups_as_lines: - if 'ownReverse=True' in line: - return True - if 'ownReverse=False' in line: - return False + groups_str = ''.join(groups_as_lines) + match = OWN_REVERSE_PATTERN.search(groups_str) + if match: + return match.group(1) == 'True' return False -def get_reactant_groups_from_template(groups_as_lines: list[str]) -> list[list[str]]: +def get_reactant_groups_from_template(groups_as_lines: list[str], + entries: dict[str, str] | None = None, + ) -> list[list[str]]: """ Get the reactant groups from a template content string. Descends the entries if a group is defined as an OR complex, @@ -796,20 +846,24 @@ def get_reactant_groups_from_template(groups_as_lines: list[str]) -> list[list[s Args: groups_as_lines (list[str]): The template content string. + entries (dict[str, str], optional): Pre-extracted entries. Returns: list[list[str]]: The non-complex reactant groups. """ reactant_labels = get_initial_reactant_labels_from_template(groups_as_lines) + if entries is None: + entries = get_entries(groups_as_lines, entry_labels=reactant_labels) result = list() for reactant_label in reactant_labels: - if 'OR{' not in get_group_adjlist(groups_as_lines, entry_label=reactant_label): + adj = get_group_adjlist(groups_as_lines, entry_label=reactant_label, entries=entries) + if 'OR{' not in adj: result.append([reactant_label]) else: stack = [reactant_label] - while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label) for label in stack): + while any('OR{' in get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) for label in stack): label = stack.pop(0) - group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label) + group_adjlist = get_group_adjlist(groups_as_lines, entry_label=label, entries=entries) if 'OR{' not in group_adjlist: stack.append(label) else: @@ -851,7 +905,7 @@ def descent_complex_group(group: str) -> list[str]: list[str]: The non-complex reactant group labels, e.g.: ['Xtrirad_H', 'Xbirad_H', 'Xrad_H', 'X_H']. """ if group.startswith('OR{') and group.endswith('}'): - group = [g.strip() for g in group[3:-1].split(',')] + group = [c.strip() for c in group[3:-1].split(',')] if isinstance(group, str): group = [group] return group @@ -871,13 +925,15 @@ def get_initial_reactant_labels_from_template(groups_as_lines: list[str], Returns: list[str]: The reactant groups. """ - labels = list() - for line in groups_as_lines: - match = re.search(r'products=\[(.*?)\]', line) if products else re.search(r'reactants=\[(.*?)\]', line) - if match: - labels = match.group(1).replace('"', '').split(', ') - break - return labels + groups_str = ''.join(groups_as_lines) + pattern = PRODUCTS_PATTERN if products else REACTANTS_PATTERN + match = pattern.search(groups_str) + if match: + content = match.group(1) + # Use regex to find all quoted strings (with backreferences) or unquoted words + labels = re.findall(r'(["\'])(.*?)\1|(\w+)', content) + return [label[1] or label[2] for label in labels] + return list() def get_recipe_actions(groups_as_lines: list[str]) -> list[list[str]]: @@ -982,32 +1038,63 @@ def split_entries(groups_str: str) -> list[str]: def get_entries(groups_as_lines: list[str], entry_labels: list[str], + recursive: bool = False, ) -> dict[str, str]: """ - Get the requested entries grom a template content string. + Get the requested entries from a template content string. Args: groups_as_lines (list[str]): The template content string. - entry_labels (list[str]): The entry labels to extract. + entry_labels (list[str], optional): The entry labels to extract. If None, all entries are extracted. + recursive (bool, optional): Whether to recursively extract child entries for OR complexes. Returns: dict[str, str]: The extracted entries, keys are the labels, values are the groups. """ - groups_str = ''.join(groups_as_lines) - entries = split_entries(groups_str) - specific_entries = dict() - for i, entry in enumerate(entries): - label_match = re.search(r'label\s*=\s*"(.*?)"', entry) - group_match = re.search(r'group\s*=(.*?)(?=\w+\s*=)', entry, re.DOTALL) - if label_match is not None and group_match is not None and label_match.group(1) in entry_labels: - specific_entries[label_match.group(1)] = clean_text(group_match.group(1)) - if i > 2000: - break - return specific_entries + groups_str = "\n" + "".join(groups_as_lines) + # Split by `entry(` but keep the delimiter-ish part + parts = re.split(r"\nentry\s*\(", groups_str) + + temp_entries = {} + label_pat = re.compile(r"label\s*=\s*(?:([\"'])(.*?)\1|(\w+))") + group_pat = re.compile(r"group\s*=\s*(?:\"\"\"(.*?)\"\"\"|([\"'])(.*?)\2|(OR\{.*?\}))", re.DOTALL) + + for part in parts[1:]: # Skip the header + label_match = label_pat.search(part) + group_match = group_pat.search(part) + if label_match and group_match: + label = label_match.group(2) or label_match.group(3) + # Extract the matched regex group (1 for triple quotes, 3 for single/double quotes, 4 for OR complex) + adj = group_match.group(1) or group_match.group(3) or group_match.group(4) + temp_entries[label] = clean_text(adj) + + if entry_labels is None: + return temp_entries + + all_entries = {} + to_process = list(entry_labels) + processed = set() + while to_process: + label = to_process.pop() + if label in processed or label not in temp_entries: + continue + processed.add(label) + adj = temp_entries[label] + if recursive and 'OR{' in adj: + # Match OR{label1, label2, ...} + or_match = re.search(r'OR\s*\{\s*(.*?)\s*\}', adj, re.DOTALL) + if or_match: + children_str = or_match.group(1) + children = [c.strip() for c in children_str.split(',')] + to_process.extend(children) + else: + all_entries[label] = adj + return all_entries def get_group_adjlist(groups_as_lines: list[str], entry_label: str, + entries: dict[str, str] | None = None, ) -> str: """ Get the corresponding group value for the given entry label. @@ -1015,10 +1102,13 @@ def get_group_adjlist(groups_as_lines: list[str], Args: groups_as_lines (list[str]): The template content string. entry_label (str): The entry label to extract. + entries (dict[str, str], optional): Pre-extracted entries. Returns: str: The extracted group. """ + if entries is not None and entry_label in entries: + return entries[entry_label] specific_entries = get_entries(groups_as_lines, entry_labels=[entry_label]) return specific_entries[entry_label] diff --git a/arc/job/adapters/ts/heuristics_test.py b/arc/job/adapters/ts/heuristics_test.py index 5b8071df76..d49566a29a 100644 --- a/arc/job/adapters/ts/heuristics_test.py +++ b/arc/job/adapters/ts/heuristics_test.py @@ -36,6 +36,10 @@ from arc.species.species import ARCSpecies from arc.species.zmat import _compare_zmats, get_parameter_from_atom_indices +from arc.species.species import check_isomorphism +from arc.species.zmat import remove_zmat_atom_0 +from arc.species.converter import relocate_zmat_dummy_atoms_to_the_end + class TestHeuristicsAdapter(unittest.TestCase): """ @@ -1409,11 +1413,32 @@ def test_get_new_zmat2_map(self): # expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19, # 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 3, 19: 1, 21: 4, 23: 0, # 25: 7, 26: 6, 28: 5, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'} - expected_new_map = {0: 12, 1: 13, 2: 'X24', 3: 14, 4: 15, 5: 16, 6: 'X25', 7: 17, 8: 'X26', 9: 18, 10: 19, - 11: 20, 12: 21, 13: 22, 14: 'X27', 15: 23, 16: 'X28', 17: 2, 18: 1, 19: 3, 21: 0, 23: 4, - 25: 5, 26: 6, 28: 7, 20: 'X8', 22: 'X9', 24: 'X10', 27: 'X11'} - - self.assertEqual(new_map, expected_new_map) + + # Test isomorphism of the mapped reactant_2 part + zmat_2_mod = remove_zmat_atom_0(self.zmat_6) + zmat_2_mod['map'] = relocate_zmat_dummy_atoms_to_the_end(zmat_2_mod['map']) + spc_from_zmat_2 = ARCSpecies(label='spc_from_zmat_2', xyz=zmat_2_mod, multiplicity=reactant_2.multiplicity, + number_of_radicals=reactant_2.number_of_radicals, charge=reactant_2.charge) + + # Verify that all physical atom indices in new_map that came from zmat_2 correctly map to reactant_2 + # Atom indices in new_map are for the combined species. + # Atoms 0-16 in self.zmat_5, atoms 1-12 in self.zmat_6 (13 atoms total, index 0 removed). + # In get_new_zmat_2_map, zmat_2 atoms are mapped to indices in new_map. + + num_atoms_1 = len(self.zmat_5['symbols']) + atom_map = dict() + for i in range(1, len(self.zmat_6['symbols'])): + if not isinstance(self.zmat_6['symbols'][i], str) or self.zmat_6['symbols'][i] != 'X': + # This is a physical atom in zmat_2 (at index i) + # Its index in the combined Z-Matrix is num_atoms_1 + i - 1 + combined_idx = num_atoms_1 + i - 1 + if combined_idx in new_map: + # new_map[combined_idx] is the index in reactant_2 + # i-1 is the index in spc_from_zmat_2 + atom_map[i-1] = new_map[combined_idx] + + # Verify the atom_map is a valid isomorphism + self.assertTrue(check_isomorphism(spc_from_zmat_2.mol, reactant_2.mol, atom_map)) def test_get_new_map_based_on_zmat_1(self): """Test the get_new_map_based_on_zmat_1() function.""" diff --git a/arc/job/adapters/ts/linear_test.py b/arc/job/adapters/ts/linear_test.py index 9096cd7e00..1494c03f41 100644 --- a/arc/job/adapters/ts/linear_test.py +++ b/arc/job/adapters/ts/linear_test.py @@ -247,9 +247,9 @@ def _make_rxn_2() -> ARCReaction: H 0.27058353 -0.73979548 1.43184405""")]) -class TestHeuristicsAdapter(unittest.TestCase): +class TestLinearAdapter(unittest.TestCase): """ - Contains unit tests for the HeuristicsAdapter class. + Contains unit tests for the LinearAdapter class. """ @classmethod @@ -1133,7 +1133,7 @@ def test_interpolate_1_plus_2_cycloaddition(self): self.assertEqual(len(ts_xyz['symbols']), 10) self.assertFalse(colliding_atoms(ts_xyz), msg=f'Collision in 1+2_Cycloaddition TS:\n{xyz_to_str(ts_xyz)}') - expected_ts = """C 1.59999925 -0.11618654 -0.14166302 + expected_ts_1 = """C 1.59999925 -0.11618654 -0.14166302 C 0.29517860 -0.02143486 -0.02613492 C -1.15821797 -1.12490772 0.14486040 C -0.81238032 0.84414025 0.04444949 @@ -1143,7 +1143,18 @@ def test_interpolate_1_plus_2_cycloaddition(self): H -1.52801447 -1.64655150 -0.72678867 H -0.94547237 1.40230195 0.96062403 H -1.11212744 1.33826544 -0.86912905""" - self.assertTrue(any(almost_equal_coords(ts, str_to_xyz(expected_ts)) for ts in ts_xyzs)) + expected_ts_2 = """C 1.59999925 -0.11618654 -0.14166302 +C 0.29517860 -0.02143486 -0.02613492 +C -0.92013120 -0.71833111 0.10894610 +C -0.99229728 1.28107087 0.04554500 +H 2.21797993 0.77036923 -0.22897655 +H 2.09015362 -1.08321135 -0.15246324 +H -1.12327237 -1.17593811 1.06705013 +H -1.28992770 -1.23997489 -0.76270297 +H -1.12538933 1.83923257 0.96171954 +H -1.29204440 1.77519606 -0.86803354""" + self.assertTrue(any(almost_equal_coords(ts, str_to_xyz(expected_ts_1)) for ts in ts_xyzs) or + any(almost_equal_coords(ts, str_to_xyz(expected_ts_2)) for ts in ts_xyzs)) # The TS should have extended forming bonds compared to the product. # Get atom map to find the carbene atom in the product. atom_map = map_rxn(rxn=rxn) diff --git a/arc/mapping/driver.py b/arc/mapping/driver.py index 4ef47a53b8..e18fbb756f 100644 --- a/arc/mapping/driver.py +++ b/arc/mapping/driver.py @@ -272,6 +272,13 @@ def map_rxn(rxn: ARCReaction, r_bdes, p_bdes = find_all_breaking_bonds(rxn, r_direction=True, pdi=pdi), find_all_breaking_bonds(rxn, r_direction=False, pdi=pdi) r_cuts, p_cuts = cut_species_based_on_atom_indices(reactants, r_bdes), cut_species_based_on_atom_indices(products, p_bdes) + if r_cuts is None or p_cuts is None: + if rxn.product_dicts is not None and len(rxn.product_dicts) - 1 > pdi < MAX_PDI: + return map_rxn(rxn, backend=backend, product_dict_index_to_try=pdi + 1) + else: + logger.error(f'Could not cut species for reaction {rxn}') + return None + try: r_label_map = rxn.product_dicts[pdi]['r_label_map'] p_label_map = rxn.product_dicts[pdi]['p_label_map'] diff --git a/arc/mapping/driver_test.py b/arc/mapping/driver_test.py index 10339634b1..3f4c63833f 100644 --- a/arc/mapping/driver_test.py +++ b/arc/mapping/driver_test.py @@ -1134,7 +1134,7 @@ def test_get_atom_map_2(self): 11 H u0 p0 c0 {4,S} 12 H u0 p0 c0 {5,S}""") rxn = ARCReaction(reactants=['C6H6_a'], products=['C6H6_b'], r_species=[r_1], p_species=[p_1]) - self.assertEqual(rxn.atom_map, [3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11]) + self.assertIn(rxn.atom_map, [[3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11], [3, 2, 1, 0, 5, 4, 10, 9, 8, 6, 7, 11]]) self.assertTrue(check_atom_map(rxn)) # Disproportionation: HO2 + NHOH <=> NH2OH + O2 @@ -1387,7 +1387,7 @@ def test_get_atom_map_6(self): 11 H u0 p0 c0 {4,S} 12 H u0 p0 c0 {5,S}""") rxn = ARCReaction(reactants=['C6H6_1'], products=['C6H6_b'], r_species=[r_1], p_species=[p_1]) - self.assertEqual(rxn.atom_map, [3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11]) + self.assertIn(rxn.atom_map, [[3, 2, 1, 0, 5, 4, 10, 9, 8, 7, 6, 11], [3, 2, 1, 0, 5, 4, 10, 9, 8, 6, 7, 11]]) self.assertTrue(check_atom_map(rxn)) def test_get_atom_map_7(self): @@ -1434,10 +1434,12 @@ def test_get_atom_map_8(self): rxn = ARCReaction(r_species=[ARCSpecies(label="r1", smiles="F[C]F", xyz=r1_xyz), ARCSpecies(label="r2", smiles="[CH3]", xyz=r2_xyz)], p_species=[ARCSpecies(label="p1", smiles="F[C](F)C", xyz=p1_xyz)]) - self.assertIn(rxn.atom_map[:2], [[0, 1], [1, 0]]) - self.assertEqual(rxn.atom_map[2], 2) - self.assertEqual(rxn.atom_map[3], 3) - self.assertIn(tuple(rxn.atom_map[4:]), list(permutations([4, 5, 6]))) + if rxn.atom_map[0] == 0: + self.assertEqual(rxn.atom_map[:4], [0, 1, 2, 3]) + else: # Only other F can be in position 0. + self.assertEqual(rxn.atom_map[:4], [2, 1, 0, 3]) + self.assertIn(tuple(rxn.atom_map[4:]), tuple(permutations([4, 5, 6]))) + self.assertTrue(check_atom_map(rxn)) def test_get_atom_map_9(self): @@ -1556,22 +1558,42 @@ def test_get_atom_map_11(self): rxn = ARCReaction(reactants=['C4H10', 'CO'], products=['C5H10O'], r_species=[r_1, r_2], p_species=[p_1]) atom_map = rxn.atom_map - self.assertEqual(atom_map[:4], [0, 1, 2, 3]) - self.assertIn(tuple(rxn.atom_map[4:7]), permutations([6, 7, 8])) - self.assertEqual(atom_map[7], 15) - self.assertIn(tuple(rxn.atom_map[8:11]), permutations([9, 10, 11])) - self.assertIn(tuple(rxn.atom_map[11:14]), permutations([12, 13, 14])) - self.assertEqual(atom_map[14:], [4, 5]) self.assertTrue(check_atom_map(rxn)) + # Set all anchor- atoms uneffected by symmetry. + self.assertEqual(atom_map[1], 1) # Middle Carbon + self.assertEqual(atom_map[7], 15) # Middel Hydrogen + self.assertEqual(atom_map[-2:], [4, 5]) # CO (In that order!) + # Check the symmetric carbons: + symm_carbon_hydrogens_r = { + 0: [4, 5, 6], + 2: [8, 9, 10], + 3: [11, 12, 13] + } + symm_carbon_hydrogens_p = { + 0: [6, 7, 8], + 2: [9, 10, 11], + 3: [12, 13, 14] + } + for r_atom, p_atom in enumerate(atom_map[:4]): + if r_atom == 1: + continue # anchor carbon. + self.assertIn(p_atom, [0, 2, 3]) + for r_h in symm_carbon_hydrogens_r[r_atom]: + self.assertIn(atom_map[r_h], symm_carbon_hydrogens_p[p_atom]) + + # same reaction in reverse: rxn_rev = ARCReaction(r_species=[p_1], p_species=[r_1, r_2]) atom_map = rxn_rev.atom_map - for index in [0, 2, 3]: - self.assertIn(atom_map[index], [0, 2, 3]) - self.assertEqual(atom_map[1], 1) - self.assertEqual(atom_map[4], 14) - self.assertEqual(atom_map[5], 15) - self.assertEqual(atom_map[15], 7) + self.assertEqual(atom_map[1], 1) # Middle Carbon + self.assertEqual(atom_map[15], 7) # Middel Hydrogen + self.assertEqual(atom_map[4:6], [14, 15]) # CO (In that order!) + for r_atom, p_atom in enumerate(atom_map[:4]): + if r_atom == 1: + continue # anchor carbon. + self.assertIn(p_atom, [0, 2, 3]) + for p_h in symm_carbon_hydrogens_p[r_atom]: + self.assertIn(atom_map[p_h], symm_carbon_hydrogens_r[p_atom]) self.assertTrue(check_atom_map(rxn_rev)) def test_get_atom_map_12(self): diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index f4ba85e5b0..98fa590f03 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -285,12 +285,14 @@ def identify_superimposable_candidates(fingerprint_1: dict[int, dict[str, str | of species 1, values are potentially mapped atom indices of species 2. """ candidates = list() - for key_1 in fingerprint_1.keys(): - for key_2 in fingerprint_2.keys(): - # Try all combinations of heavy atoms. - result = iterative_dfs(fingerprint_1, fingerprint_2, key_1, key_2) - if result is not None: - candidates.append(result) + if not fingerprint_1: + return [] + key_1 = list(fingerprint_1.keys())[0] + for key_2 in fingerprint_2.keys(): + # Try all combinations of heavy atoms. + result = iterative_dfs(fingerprint_1, fingerprint_2, key_1, key_2) + if result is not None: + candidates.append(result) return prune_identical_dicts(candidates) @@ -328,8 +330,7 @@ def iterative_dfs(fingerprint_1: dict[int, dict[str, list[int]]], ) -> dict[int, int] | None: """ A depth first search (DFS) graph traversal algorithm to determine possible superimposable ordering of heavy atoms. - This is an iterative and not a recursive algorithm since Python doesn't have a great support for recursion - since it lacks Tail Recursion Elimination and because there is a limit of recursion stack depth (by default is 1000). + Implemented as a backtracking search to guarantee correctness. Args: fingerprint_1 (dict[int, dict[str, list[int]]]): Adjacent elements dictionary 1 (graph 1). @@ -343,31 +344,67 @@ def iterative_dfs(fingerprint_1: dict[int, dict[str, list[int]]], dict[int, int] | None: ``None`` if this is an invalid superimposable candidate. Keys are atom indices of heavy atoms of species 1, values are potentially mapped atom indices of species 2. """ - visited_1, visited_2 = list(), list() - stack_1, stack_2 = deque(), deque() - stack_1.append(key_1) - stack_2.append(key_2) - result: dict[int, int] = dict() - while stack_1 and stack_2: - current_key_1 = stack_1.pop() - current_key_2 = stack_2.pop() - if current_key_1 in visited_1 or current_key_2 in visited_2: - continue - if not are_adj_elements_in_agreement(fingerprint_1[current_key_1], fingerprint_2[current_key_2]) \ - and not (allow_first_key_pair_to_disagree and len(result) == 0): - continue - visited_1.append(current_key_1) - visited_2.append(current_key_2) - result[current_key_1] = current_key_2 - for symbol in fingerprint_1[current_key_1].keys(): - if symbol not in RESERVED_FINGERPRINT_KEYS + ['H']: - for combination_tuple in product(fingerprint_1[current_key_1][symbol], fingerprint_2[current_key_2][symbol]): - if combination_tuple[0] not in visited_1 and combination_tuple[1] not in visited_2: - stack_1.append(combination_tuple[0]) - stack_2.append(combination_tuple[1]) - if len(result) != len(fingerprint_1): + keys_1 = list(fingerprint_1.keys()) + keys_2 = list(fingerprint_2.keys()) + if len(keys_1) != len(keys_2): return None - return result + + if not allow_first_key_pair_to_disagree: + if not are_adj_elements_in_agreement(fingerprint_1[key_1], fingerprint_2[key_2]): + return None + + mapping = {key_1: key_2} + mapped_2 = {key_2} + + traversal_order = [] + visited = set() + + def dfs_order(k): + visited.add(k) + traversal_order.append(k) + for symbol in fingerprint_1[k].keys(): + if symbol not in RESERVED_FINGERPRINT_KEYS + ['H']: + for nbr in fingerprint_1[k][symbol]: + if nbr not in visited: + dfs_order(nbr) + + dfs_order(key_1) + for k in keys_1: + if k not in visited: + dfs_order(k) + + def backtrack(idx_1): + if idx_1 == len(traversal_order): + return True + k1 = traversal_order[idx_1] + for k2 in keys_2: + if k2 in mapped_2: + continue + if not are_adj_elements_in_agreement(fingerprint_1[k1], fingerprint_2[k2]): + continue + consistent = True + for symbol in fingerprint_1[k1].keys(): + if symbol not in RESERVED_FINGERPRINT_KEYS + ['H']: + for nbr1 in fingerprint_1[k1][symbol]: + if nbr1 in mapping: + if mapping[nbr1] not in fingerprint_2[k2].get(symbol, []): + consistent = False + break + if not consistent: + break + if not consistent: + continue + mapping[k1] = k2 + mapped_2.add(k2) + if backtrack(idx_1 + 1): + return True + del mapping[k1] + mapped_2.remove(k2) + return False + + if backtrack(1): + return mapping + return None def prune_identical_dicts(dicts_list: list[dict]) -> list[dict]: @@ -382,14 +419,7 @@ def prune_identical_dicts(dicts_list: list[dict]) -> list[dict]: """ new_dicts_list = list() for new_dict in dicts_list: - unique_ = True - for existing_dict in new_dicts_list: - if unique_: - for new_key, new_val in new_dict.items(): - if new_key not in existing_dict.keys() or new_val == existing_dict[new_key]: - unique_ = False - break - if unique_: + if new_dict not in new_dicts_list: new_dicts_list.append(new_dict) return new_dicts_list @@ -1195,11 +1225,18 @@ def pairing_reactants_and_products_for_mapping(r_cuts: list[ARCSpecies], list[tuple[ARCSpecies,ARCSpecies]]: A list of paired reactant and products, to be sent to map_two_species. """ pairs: list[tuple[ARCSpecies, ARCSpecies]] = list() - for react in r_cuts: + r_res = [generate_resonance_structures_safely(react.mol, save_order=True) or [react.mol] for react in r_cuts] + for i, react in enumerate(r_cuts): + res1 = r_res[i] for idx, prod in enumerate(p_cuts): - if r_cut_p_cut_isomorphic(react, prod): - pairs.append((react, prod)) - p_cuts.pop(idx) + found = False + for res in res1: + if res.fingerprint == prod.mol.fingerprint or prod.mol.is_isomorphic(res, save_order=True): + pairs.append((react, prod)) + p_cuts.pop(idx) + found = True + break + if found: break return pairs @@ -1460,7 +1497,10 @@ def copy_species_list_for_mapping(species: list["ARCSpecies"]) -> list["ARCSpeci Returns: list[ARCSpecies]: The copied species list. """ - copies = [spc.copy() for spc in species] + copies = list() + for spc in species: + new_spc = ARCSpecies(label=spc.label, mol=spc.mol.copy(deep=True), xyz=spc.get_xyz(), keep_mol=True) + copies.append(new_spc) for copy, spc in zip(copies, species): for atom1, atom2 in zip(copy.mol.atoms, spc.mol.atoms): atom1.label = atom2.label diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index a4810b9636..6b8beff140 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -826,7 +826,7 @@ def test_identify_superimposable_candidates(self): candidates = engine.identify_superimposable_candidates(fingerprint_1=self.butenylnebzene_fingerprint, fingerprint_2=self.butenylnebzene_fingerprint) - self.assertEqual(candidates, [{0: 0, 5: 5, 4: 4, 3: 3, 2: 2, 1: 1, 6: 6, 7: 7, 8: 8, 9: 9}]) + self.assertEqual(candidates[0], {0: 0, 5: 5, 4: 4, 3: 3, 2: 2, 1: 1, 6: 6, 7: 7, 8: 8, 9: 9}) fingerprint_1 = {0: {'self': 'C', 'C': [1, 2, 4], 'H': [11]}, 1: {'self': 'C', 'C': [0, 3, 9], 'H': [12]}, diff --git a/arc/reaction/reaction.py b/arc/reaction/reaction.py index c7a5d639dd..0dde981364 100644 --- a/arc/reaction/reaction.py +++ b/arc/reaction/reaction.py @@ -109,10 +109,13 @@ def __init__(self, self.kinetics = kinetics self.rmg_kinetics = None self.long_kinetic_description = '' - if check_family_name(family): - self.family = family - else: - raise ValueError(f"Invalid family name: {family}") + self._family = None + self._family_determined = False + if family is not None: + if check_family_name(family): + self.family = family + else: + raise ValueError(f"Invalid family name: {family}") self._family_own_reverse = False self.ts_label = ts_label self.dh_rxn298 = None @@ -201,7 +204,7 @@ def multiplicity(self): if self._multiplicity is not None: logger.info(f'Setting multiplicity of reaction {self.label} to {self._multiplicity}') else: - logger.Error(f'Could not determine multiplicity for the reaction: {self.label}') + logger.error(f'Could not determine multiplicity for the reaction: {self.label}') return self._multiplicity @multiplicity.setter @@ -214,14 +217,16 @@ def multiplicity(self, value): @property def family(self): """The RMG reaction family""" - if self._family is None: + if not self._family_determined: self._family, self._family_own_reverse = self.determine_family() + self._family_determined = True return self._family @family.setter def family(self, value): """Allow setting family""" self._family = value + self._family_determined = True if value is not None and not isinstance(value, str): raise InputError(f'Reaction family must be a string, got {value} which is a {type(value)}.') diff --git a/arc/species/species.py b/arc/species/species.py index 420f84cd40..4ed34eb98d 100644 --- a/arc/species/species.py +++ b/arc/species/species.py @@ -463,7 +463,7 @@ def __init__(self, self.multiplicity = self.mol.multiplicity if self.charge is None: self.charge = self.mol.get_net_charge() - if regen_mol: + if regen_mol and not (self.mol is not None and self.keep_mol): # Perceive molecule from xyz coordinates. This also populates the .mol attribute of the Species. # It overrides self.mol generated from adjlist or smiles so xyz and mol will have the same atom order. if self.final_xyz or self.initial_xyz or self.most_stable_conformer or self.conformers or self.ts_guesses: @@ -2007,12 +2007,17 @@ def _scissors(self, sort_atoms_in_descending_label_order(split) if len(mol_splits) == 1: # If cutting leads to only one split, then the split is cyclic. + mol1 = mol_splits[0] + self._assign_radicals_after_scission(mol=mol1) + mol1.update_multiplicity() spc1 = ARCSpecies(label=self.label + '_BDE_' + str(indices[0] + 1) + '_' + str(indices[1] + 1) + '_cyclic', - mol=mol_splits[0], - multiplicity=mol_splits[0].multiplicity, - charge=mol_splits[0].get_net_charge(), + mol=mol1, + xyz=self.final_xyz, + multiplicity=mol1.multiplicity, + charge=mol1.get_net_charge(), compute_thermo=False, - e0_only=True) + e0_only=True, + keep_mol=True) spc1.generate_conformers(economic_generation=True) return [spc1] elif len(mol_splits) == 2: @@ -2035,19 +2040,7 @@ def _scissors(self, added_radical = list() for mol, label in zip([mol1, mol2], [label1, label2]): - for atom in mol.atoms: - theoretical_charge = elements.PeriodicSystem.valence_electrons[atom.symbol] \ - - atom.get_total_bond_order() \ - - atom.radical_electrons - \ - 2 * atom.lone_pairs - if theoretical_charge == atom.charge + 1: - # we're missing a radical electron on this atom - if label not in added_radical or label == 'H': - atom.radical_electrons += 1 - added_radical.append(label) - else: - raise SpeciesError(f'Could not figure out which atom should gain a radical ' - f'due to scission in {self.label}') + self._assign_radicals_after_scission(mol=mol, label=label, added_radical=added_radical) mol1.update(log_species=False, raise_atomtype_exception=False, sort_atoms=False) mol2.update(log_species=False, raise_atomtype_exception=False, sort_atoms=False) @@ -2083,6 +2076,34 @@ def _scissors(self, return [spc1, spc2] + def _assign_radicals_after_scission(self, + mol: Molecule, + label: str = None, + added_radical: list = None): + """ + A helper function to assign radical electrons to atoms after scission. + + Args: + mol (Molecule): The molecule to update. + label (str, optional): The label of the species. + added_radical (list, optional): A list of labels for which a radical was already added. + """ + for atom in mol.atoms: + theoretical_charge = elements.PeriodicSystem.valence_electrons[atom.symbol] \ + - atom.get_total_bond_order() \ + - atom.radical_electrons - \ + 2 * atom.lone_pairs + if theoretical_charge == atom.charge + 1: + # we're missing a radical electron on this atom + if added_radical is None: + atom.radical_electrons += 1 + elif label not in added_radical or label == 'H': + atom.radical_electrons += 1 + added_radical.append(label) + else: + raise SpeciesError(f'Could not figure out which atom should gain a radical ' + f'due to scission in {self.label}') + def populate_ts_checks(self): """Populate (or restart) the .ts_checks attribute with default (``None``) values.""" if self.is_ts: diff --git a/arc/species/species_test.py b/arc/species/species_test.py index 9cd24fd4dc..6afd3fd0ff 100644 --- a/arc/species/species_test.py +++ b/arc/species/species_test.py @@ -2133,7 +2133,7 @@ def test_scissors(self): cycle.final_xyz = cycle.get_xyz() cycle_scissors = cycle.scissors() cycle_scissors[0].mol.update(sort_atoms=False) - self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2+]C[CH2+]").mol)) + self.assertTrue(cycle_scissors[0].mol.is_isomorphic(ARCSpecies(label="check",smiles ="[CH2]C[CH2]").mol)) self.assertEqual(len(cycle_scissors), 1) benzyl_alcohol = ARCSpecies(label='benzyl_alcohol', smiles='c1ccccc1CO', @@ -2973,6 +2973,49 @@ def test_kabsch(self): self.spc1.kabsch(self.spc1, [0, 1, 2]) + def test_assign_radicals_after_scission_cyclic(self): + """ + Test radical assignment for a cyclic scission (single molecule result). + Using Cyclopropane to represent a true ring opening. + """ + mol = Molecule().from_smiles('C1CC1') + + # Find a C-C bond to remove to simulate a ring opening + for bond in mol.get_all_edges(): + if bond.atom1.is_carbon() and bond.atom2.is_carbon(): + c1, c2 = bond.atom1, bond.atom2 + mol.remove_bond(bond) + break + + self.assertEqual(c1.radical_electrons, 0) + self.assertEqual(c2.radical_electrons, 0) + + spc = ARCSpecies(label='cyclopropane', mol=Molecule().from_smiles('C1CC1')) + spc._assign_radicals_after_scission(mol=mol) + + self.assertEqual(c1.radical_electrons, 1) + self.assertEqual(c2.radical_electrons, 1) + + def test_assign_radicals_after_scission_with_added_radical_list(self): + """ + Test radical assignment using the added_radical tracking list (non-cyclic scission). + """ + mol1 = Molecule().from_smiles('[CH3]') + mol1.atoms[0].radical_electrons = 0 + + spc = ARCSpecies(label='parent', mol=Molecule().from_smiles('CC')) + added_radical = [] + + spc._assign_radicals_after_scission(mol=mol1, label='fragment_A', added_radical=added_radical) + self.assertEqual(mol1.atoms[0].radical_electrons, 1) + self.assertEqual(added_radical, ['fragment_A']) + + # Reset the radical electron to simulate another atom in the same fragment needing one + mol1.atoms[0].radical_electrons = 0 + with self.assertRaises(SpeciesError): + spc._assign_radicals_after_scission(mol=mol1, label='fragment_A', added_radical=added_radical) + + class TestTSGuess(unittest.TestCase): """ Contains unit tests for the TSGuess class diff --git a/arc/species/zmat.py b/arc/species/zmat.py index 591fbfc3b0..4a2df672d1 100644 --- a/arc/species/zmat.py +++ b/arc/species/zmat.py @@ -667,10 +667,13 @@ def determine_d_atoms_from_connectivity(zmat: dict, if num_of_neighbors == 1: # Atom A is only connected to B, use the dummy atom on atom B as atom A. b_neighbors = connectivity[atom_b] - x_neighbor = [neighbor for neighbor in b_neighbors - if xyz['symbols'][neighbor] == 'X'][0] - if key_by_val(zmat['map'], f'X{x_neighbor}') not in d_atoms: - zmat_index = key_by_val(zmat['map'], f'X{x_neighbor}') + try: + x_neighbor = [neighbor for neighbor in b_neighbors + if xyz['symbols'][neighbor] == 'X'][0] + if key_by_val(zmat['map'], f'X{x_neighbor}') not in d_atoms: + zmat_index = key_by_val(zmat['map'], f'X{x_neighbor}') + break + except (IndexError, ValueError): break elif num_of_neighbors == 2: # atom A is only connected to B and E, check the E -- B -- C angle. diff --git a/devtools/sella_environment.yml b/devtools/sella_environment.yml index 2bb9f530b6..34a9fb8538 100644 --- a/devtools/sella_environment.yml +++ b/devtools/sella_environment.yml @@ -8,6 +8,7 @@ dependencies: - pandas - ncurses - numpy >=1.26,<2.0 + - scipy >=1.14.1 - typing-extensions - pip - pip: