Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 24 additions & 36 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,26 +531,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
"ranks match on the arguments");

//
// Illustration of steps by an example.
// Strategy. Consider A(ijpab;xy) * B(jiqba;yx) -> C(ipjq), inner xy fully
// contracted. We reduce the contracted-outer indices ab and the contracted-
// inner indices xy together with a single ToT x ToT -> ToT contraction
// whose inner product is annotated to leave a *phantom unit* inner mode
// (⊗₁) on the result, so the inner cell is a genuine (≥ order-1) unit
// tensor rather than the unsupported order-0:
//
// Consider the evaluation: A(ijpab;xy) * B(jiqba;yx) -> C(ipjq).
// C0(ipjq; ⊗₁) = A(ijpab; xy) * B(jiqba; xy,⊗₁)
//
// Note for the outer indices:
// - Hadamard: 'ij'
// - External A: 'p'
// - External B: 'q'
// - Contracted: 'ab'
//
// Now C is evaluated in the following steps.
// Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
// Step II: C0(ijpqab;xy) -> C1(ijpqab)
// Step III: C1(ijpqab) -> C2(ijpq)
// Step IV: C2(ijpq) -> C(ipjq)

// Build a "denested" tile: one scalar per outer index, summed over the
// inner tile. The result tile's outer type is TA::Tensor (inner tile
// types like btas::Tensor are only valid as the innermost tile and don't
// expose the range+lambda ctor used here).
// ⊗₁ is appended to B's inner annotation only; B's inner *tensor* is
// unchanged (⊗₁ is phantom unit -- ContEngine recognizes it and realizes
// the inner product as a flat dot into a [1] cell, never requiring B to
// physically carry the extra mode). Each [1] inner cell is then unwrapped
// to a scalar. This never materializes the uncontracted product, and is
// correct when an inner extent depends on a contracted-outer index.
auto sum_tot_2_tos = [](auto const &tot) {
using tot_t = std::remove_reference_t<decltype(tot)>;
using numeric_type = typename tot_t::numeric_type;
Expand All @@ -566,30 +561,23 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
return result;
};

auto const oixs = TensorOpIndices(a, b, c);
// U+2297 CIRCLED TIMES + U+2081 SUBSCRIPT ONE: a reserved phantom-unit
// inner annotator (see is_phantom_unit_label).
const std::string phantom_unit = "⊗₁";

struct {
std::string C0, C1, C2;
} const Cn_annot{
std::string(oixs.ix_C_canon() + oixs.contracted()) + inner.a,
{oixs.ix_C_canon() + oixs.contracted()},
{oixs.ix_C_canon()}};
auto a_annot = std::string(a) + inner.a; // e.g. "ijpab;xy"
auto b_annot =
std::string(b) + inner.b + "," + phantom_unit; // e.g. "jiqba;yx,⊗₁"
auto c_annot = std::string(c) + ";" + phantom_unit; // e.g. "ipjq;⊗₁"

// Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
auto C0 = einsum(A, B, Cn_annot.C0);
// C0(c; ⊗₁) = A(a; inner.A) * B(b; inner.B,⊗₁)
auto C0 = einsum(A.array()(a_annot), B.array()(b_annot), c_annot);

// Step II: C0(ijpqab;xy) -> C1(ijpqab)
auto C1 = TA::foreach<typename ArrayC::value_type>(
// unwrap unit-extent inner cells to scalars
ArrayC C = TA::foreach<typename ArrayC::value_type>(
C0, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
out_tile = sum_tot_2_tos(in_tile);
});

// Step III: C1(ijpqab) -> C2(ijpq)
auto C2 = reduce_modes(C1, oixs.contracted().size());

// Step IV: C2(ijpq) -> C(ipjq)
ArrayC C;
C(c) = C2(Cn_annot.C2);
return C;

} else {
Expand Down
Loading
Loading