Skip to content

Commit f1dfe9d

Browse files
authored
[PBC Resources Estimates 1/4] Add k-point THC factorization (#821)
* Add k-point THC code. * Add k-thc notebook. * Add utilities. * Add reference data. * Add missing __init__ * Add missing init / skipifs. * Fix formatting. * Resolve review comments. * Address comments. * Remove utils. * No more utils. * More review comments. * Fix formatting. * Formatting + add map to gvec_logic. * Fix import issues. * Mark slow tests and catch imports. * Fix checks. * Fix test failures. * Add ase to resources requirements.
1 parent e693270 commit f1dfe9d

20 files changed

Lines changed: 4734 additions & 0 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pyscf
22
jax
33
jaxlib
4+
ase

dev_tools/requirements/resource_estimates.env.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#
55
# pip-compile --output-file=resource_estimates.env.txt deps/resource_estimates.txt pytest.env.txt
66
#
7+
ase==3.22.1
8+
# via -r deps/resource_estimates.txt
79
attrs==23.1.0
810
# via
911
# -r pytest.env.txt
@@ -83,6 +85,7 @@ kiwisolver==1.4.4
8385
matplotlib==3.7.1
8486
# via
8587
# -r pytest.env.txt
88+
# ase
8689
# cirq-core
8790
ml-dtypes==0.2.0
8891
# via
@@ -101,6 +104,7 @@ networkx==2.8.8
101104
numpy==1.23.5
102105
# via
103106
# -r pytest.env.txt
107+
# ase
104108
# cirq-core
105109
# contourpy
106110
# h5py
@@ -174,6 +178,7 @@ requests==2.31.0
174178
scipy==1.9.3
175179
# via
176180
# -r pytest.env.txt
181+
# ase
177182
# cirq-core
178183
# jax
179184
# jaxlib
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# coverage: ignore
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# coverage: ignore
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
from openfermion.resource_estimates import HAVE_DEPS_FOR_RESOURCE_ESTIMATES
15+
16+
if HAVE_DEPS_FOR_RESOURCE_ESTIMATES:
17+
from .hamiltonian import (build_hamiltonian,
18+
build_momentum_transfer_mapping,
19+
cholesky_from_df_ints)
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# coverage: ignore
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
from dataclasses import dataclass, asdict
14+
from typing import Tuple
15+
import h5py
16+
import numpy as np
17+
import numpy.typing as npt
18+
19+
from pyscf import lib
20+
from pyscf.ao2mo import _ao2mo
21+
from pyscf.lib import logger
22+
from pyscf.pbc.df import df
23+
from pyscf.pbc.lib.kpts_helper import gamma_point
24+
from pyscf.pbc.mp.kmp2 import _add_padding
25+
from pyscf.pbc import mp, scf, gto
26+
27+
28+
@dataclass
29+
class HamiltonianProperties:
30+
"""Lighweight descriptive data class to hold return values from
31+
compute_lambda functions.
32+
33+
Attributes:
34+
lambda_total: Total lambda value (norm) of Hamiltonian.
35+
lambda_one_body: One-body lambda value (norm) of Hamiltonian.
36+
lambda_two_body: Two-body lambda value (norm) of Hamiltonian.
37+
"""
38+
39+
lambda_total: float
40+
lambda_one_body: float
41+
lambda_two_body: float
42+
43+
dict = asdict
44+
45+
46+
def build_hamiltonian(mf: "scf.KRHF") -> Tuple[npt.NDArray, npt.NDArray]:
47+
"""Utility function to build one- and two-electron matrix elements from mean
48+
field object.
49+
50+
Arguments:
51+
mf: pyscf KRHF object.
52+
53+
Returns:
54+
hcore_mo: one-body Hamiltonian in MO basis.
55+
chol: 3-index RSGDF density fitted integrals.
56+
"""
57+
# Build temporary mp2 object so MO coeffs can potentially be padded if mean
58+
# field solution yields different number of MOs per k-point.
59+
tmp_mp2 = mp.KMP2(mf)
60+
mo_coeff_padded = _add_padding(tmp_mp2, tmp_mp2.mo_coeff,
61+
tmp_mp2.mo_energy)[0]
62+
hcore_mo = np.asarray([
63+
C.conj().T @ hk @ C for (C, hk) in zip(mo_coeff_padded, mf.get_hcore())
64+
])
65+
chol = cholesky_from_df_ints(tmp_mp2)
66+
return hcore_mo, chol
67+
68+
69+
def cholesky_from_df_ints(mp2_inst, pad_mos_with_zeros=True) -> npt.NDArray:
70+
"""Compute 3-center electron repulsion integrals, i.e. (L|ov),
71+
where `L` denotes DF auxiliary basis functions and `o` and `v` occupied and
72+
virtual canonical crystalline orbitals. Note that `o` and `v` contain kpt
73+
indices `ko` and `kv`, and the third kpt index `kL` is determined by
74+
the conservation of momentum.
75+
76+
Note that if the number of mos differs at each k-point then this function
77+
will pad MOs with zeros to ensure contiguity.
78+
79+
Args:
80+
mp2_inst: pyscf KMP2 instance.
81+
82+
Returns:
83+
Lchol: 3-center DF ints, with shape (nkpts, nkpts, naux, nmo, nmo)
84+
"""
85+
86+
log = logger.Logger(mp2_inst.stdout, mp2_inst.verbose)
87+
88+
if mp2_inst._scf.with_df._cderi is None:
89+
mp2_inst._scf.with_df.build()
90+
91+
cell = mp2_inst._scf.cell
92+
if cell.dimension == 2:
93+
# 2D ERIs are not positive definite. The 3-index tensors are stored in
94+
# two part. One corresponds to the positive part and one corresponds
95+
# to the negative part. The negative part is not considered in the
96+
# DF-driven CCSD implementation.
97+
raise NotImplementedError
98+
99+
# nvir = nmo - nocc
100+
nao = cell.nao_nr()
101+
102+
mo_coeff = mp2_inst._scf.mo_coeff
103+
kpts = mp2_inst.kpts
104+
if pad_mos_with_zeros:
105+
mo_coeff = _add_padding(mp2_inst, mp2_inst.mo_coeff,
106+
mp2_inst.mo_energy)[0]
107+
nmo = mp2_inst.nmo
108+
else:
109+
nmo = nao
110+
num_mo_per_kpt = np.array([C.shape[-1] for C in mo_coeff])
111+
if not (num_mo_per_kpt == nmo).all():
112+
log.info("Number of MOs differs at each k-point or is not the same "
113+
"as the number of AOs.")
114+
nkpts = len(kpts)
115+
if gamma_point(kpts):
116+
dtype = np.double
117+
else:
118+
dtype = np.complex128
119+
dtype = np.result_type(dtype, *mo_coeff)
120+
Lchol = np.empty((nkpts, nkpts), dtype=object)
121+
122+
cput0 = (logger.process_clock(), logger.perf_counter())
123+
124+
bra_start = 0
125+
bra_end = nmo
126+
ket_start = nmo
127+
ket_end = 2 * nmo
128+
with h5py.File(mp2_inst._scf.with_df._cderi, "r") as f:
129+
kptij_lst = f["j3c-kptij"][:]
130+
tao = []
131+
ao_loc = None
132+
for ki, kpti in enumerate(kpts):
133+
for kj, kptj in enumerate(kpts):
134+
kpti_kptj = np.array((kpti, kptj))
135+
Lpq_ao = np.asarray(df._getitem(f, "j3c", kpti_kptj, kptij_lst))
136+
137+
mo = np.hstack((mo_coeff[ki], mo_coeff[kj]))
138+
mo = np.asarray(mo, dtype=dtype, order="F")
139+
if dtype == np.double:
140+
out = _ao2mo.nr_e2(
141+
Lpq_ao,
142+
mo,
143+
(bra_start, bra_end, ket_start, ket_end),
144+
aosym="s2",
145+
)
146+
else:
147+
# Note: Lpq.shape[0] != naux if linear dependency is found
148+
# in auxbasis
149+
if Lpq_ao[0].size != nao**2: # aosym = 's2'
150+
Lpq_ao = lib.unpack_tril(Lpq_ao).astype(np.complex128)
151+
out = _ao2mo.r_e2(
152+
Lpq_ao,
153+
mo,
154+
(bra_start, bra_end, ket_start, ket_end),
155+
tao,
156+
ao_loc,
157+
)
158+
Lchol[ki, kj] = out.reshape(-1, nmo, nmo)
159+
160+
log.timer_debug1("transforming DF-AO integrals to MO", *cput0)
161+
162+
return Lchol
163+
164+
165+
def build_momentum_transfer_mapping(cell: gto.Cell,
166+
kpoints: np.ndarray) -> np.ndarray:
167+
# Define mapping momentum_transfer_map[Q][k1] = k2 that satisfies
168+
# k1 - k2 + G = Q.
169+
a = cell.lattice_vectors() / (2 * np.pi)
170+
delta_k1_k2_Q = (kpoints[:, None, None, :] - kpoints[None, :, None, :] -
171+
kpoints[None, None, :, :])
172+
delta_k1_k2_Q += kpoints[0][None, None, None, :] # shift to center
173+
delta_dot_a = np.einsum("wx,kpQx->kpQw", a, delta_k1_k2_Q)
174+
int_delta_dot_a = np.rint(delta_dot_a)
175+
# Should be zero if transfer is statisfied (2*pi*n)
176+
mapping = np.where(
177+
np.sum(np.abs(delta_dot_a - int_delta_dot_a), axis=3) < 1e-10)
178+
num_kpoints = len(kpoints)
179+
momentum_transfer_map = np.zeros((num_kpoints,) * 2, dtype=np.int32)
180+
# Note index flip due to Q being first index in map but broadcasted last..
181+
momentum_transfer_map[mapping[1], mapping[0]] = mapping[2]
182+
183+
return momentum_transfer_map
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# coverage: ignore
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
import itertools
14+
15+
import numpy as np
16+
import pytest
17+
18+
from openfermion.resource_estimates import HAVE_DEPS_FOR_RESOURCE_ESTIMATES
19+
20+
if HAVE_DEPS_FOR_RESOURCE_ESTIMATES:
21+
from pyscf.pbc import cc, mp
22+
23+
from openfermion.resource_estimates.pbc.hamiltonian import (
24+
build_hamiltonian, cholesky_from_df_ints)
25+
from openfermion.resource_estimates.pbc.testing import make_diamond_113_szv
26+
27+
28+
@pytest.mark.skipif(not HAVE_DEPS_FOR_RESOURCE_ESTIMATES,
29+
reason='pyscf and/or jax not installed.')
30+
def test_build_hamiltonian():
31+
mf = make_diamond_113_szv()
32+
nmo = mf.mo_coeff[0].shape[-1]
33+
naux = 108
34+
hcore, chol = build_hamiltonian(mf)
35+
nkpts = len(mf.mo_coeff)
36+
assert hcore.shape == (nkpts, nmo, nmo)
37+
assert chol.shape == (nkpts, nkpts)
38+
assert chol[0, 0].shape == (naux, nmo, nmo)
39+
40+
41+
@pytest.mark.skipif(not HAVE_DEPS_FOR_RESOURCE_ESTIMATES,
42+
reason='pyscf and/or jax not installed.')
43+
def test_pyscf_chol_from_df():
44+
mf = make_diamond_113_szv()
45+
mymp = mp.KMP2(mf)
46+
nmo = mymp.nmo
47+
nocc = mymp.nocc
48+
nvir = nmo - nocc
49+
Luv = cholesky_from_df_ints(mymp)
50+
51+
# 1. Test that the DF integrals give the correct SCF energy (oo block)
52+
mf.exxdiv = None # exclude ewald exchange correction
53+
Eref = mf.energy_elec()[1]
54+
Eout = 0.0j
55+
nkpts = len(mf.mo_coeff)
56+
for ik, jk in itertools.product(range(nkpts), repeat=2):
57+
Lii = Luv[ik, ik][:, :nocc, :nocc]
58+
Ljj = Luv[jk, jk][:, :nocc, :nocc]
59+
Lij = Luv[ik, jk][:, :nocc, :nocc]
60+
Lji = Luv[jk, ik][:, :nocc, :nocc]
61+
oooo_d = np.einsum("Lij,Lkl->ijkl", Lii, Ljj) / nkpts
62+
oooo_x = np.einsum("Lij,Lkl->ijkl", Lij, Lji) / nkpts
63+
Eout += 2.0 * np.einsum("iijj->", oooo_d)
64+
Eout -= np.einsum("ijji->", oooo_x)
65+
assert abs(Eout / nkpts - Eref) < 1e-12
66+
67+
# 2. Test that the DF integrals agree with those from MP2 (ov block)
68+
from pyscf.pbc.mp.kmp2 import _init_mp_df_eris
69+
70+
Ltest = _init_mp_df_eris(mymp)
71+
for ik, jk in itertools.product(range(nkpts), repeat=2):
72+
assert np.allclose(Luv[ik, jk][:, :nocc, nocc:],
73+
Ltest[ik, jk],
74+
atol=1e-12)
75+
76+
# 3. Test that the DF integrals have correct vvvv block (vv)
77+
Ivvvv = np.zeros((nkpts, nkpts, nkpts, nvir, nvir, nvir, nvir),
78+
dtype=np.complex128)
79+
for ik, jk, kk in itertools.product(range(nkpts), repeat=3):
80+
lk = mymp.khelper.kconserv[ik, jk, kk]
81+
Lij = Luv[ik, jk][:, nocc:, nocc:]
82+
Lkl = Luv[kk, lk][:, nocc:, nocc:]
83+
Imo = np.einsum("Lij,Lkl->ijkl", Lij, Lkl)
84+
Ivvvv[ik, jk, kk] = Imo / nkpts
85+
86+
mycc = cc.KRCCSD(mf)
87+
eris = mycc.ao2mo()
88+
assert np.allclose(eris.vvvv,
89+
Ivvvv.transpose(0, 2, 1, 3, 5, 4, 6),
90+
atol=1e-12)

0 commit comments

Comments
 (0)