diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index d87cf3ff04..89fd0db4cf 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -531,26 +531,21 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, "ranks match on the arguments"); // - // Illustration of steps by an example. + // Strategy. Consider A(ijpab;xy) * B(jiqba;yx) -> C(ipjq), inner xy fully + // contracted. We reduce the contracted-outer indices ab and the contracted- + // inner indices xy together with a single ToT x ToT -> ToT contraction + // whose inner product is annotated to leave a *phantom unit* inner mode + // (⊗₁) on the result, so the inner cell is a genuine (≥ order-1) unit + // tensor rather than the unsupported order-0: // - // Consider the evaluation: A(ijpab;xy) * B(jiqba;yx) -> C(ipjq). + // C0(ipjq; ⊗₁) = A(ijpab; xy) * B(jiqba; yx,⊗₁) // - // Note for the outer indices: - // - Hadamard: 'ij' - // - External A: 'p' - // - External B: 'q' - // - Contracted: 'ab' - // - // Now C is evaluated in the following steps. - // Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy) - // Step II: C0(ijpqab;xy) -> C1(ijpqab) - // Step III: C1(ijpqab) -> C2(ijpq) - // Step IV: C2(ijpq) -> C(ipjq) - - // Build a "denested" tile: one scalar per outer index, summed over the - // inner tile. The result tile's outer type is TA::Tensor (inner tile - // types like btas::Tensor are only valid as the innermost tile and don't - // expose the range+lambda ctor used here). + // ⊗₁ is appended to B's inner annotation only; B's inner *tensor* is + // unchanged (⊗₁ is phantom unit -- ContEngine recognizes it and realizes + // the inner product as a flat dot into a [1] cell, never requiring B to + // physically carry the extra mode). Each [1] inner cell is then unwrapped + // to a scalar. This never materializes the uncontracted product, and is + // correct when an inner extent depends on a contracted-outer index. 
auto sum_tot_2_tos = [](auto const &tot) { using tot_t = std::remove_reference_t; using numeric_type = typename tot_t::numeric_type; @@ -566,30 +561,23 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, return result; }; - auto const oixs = TensorOpIndices(a, b, c); + // U+2297 CIRCLED TIMES + U+2081 SUBSCRIPT ONE: a reserved phantom-unit + // inner annotator (see is_phantom_unit_label). + const std::string phantom_unit = "⊗₁"; - struct { - std::string C0, C1, C2; - } const Cn_annot{ - std::string(oixs.ix_C_canon() + oixs.contracted()) + inner.a, - {oixs.ix_C_canon() + oixs.contracted()}, - {oixs.ix_C_canon()}}; + auto a_annot = std::string(a) + inner.a; // e.g. "ijpab;xy" + auto b_annot = + std::string(b) + inner.b + "," + phantom_unit; // e.g. "jiqba;yx,⊗₁" + auto c_annot = std::string(c) + ";" + phantom_unit; // e.g. "ipjq;⊗₁" - // Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy) - auto C0 = einsum(A, B, Cn_annot.C0); + // C0(c; ⊗₁) = A(a; inner.A) * B(b; inner.B,⊗₁) + auto C0 = einsum(A.array()(a_annot), B.array()(b_annot), c_annot); - // Step II: C0(ijpqab;xy) -> C1(ijpqab) - auto C1 = TA::foreach( + // unwrap unit-extent inner cells to scalars + ArrayC C = TA::foreach( C0, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) { out_tile = sum_tot_2_tos(in_tile); }); - - // Step III: C1(ijpqab) -> C2(ijpq) - auto C2 = reduce_modes(C1, oixs.contracted().size()); - - // Step IV: C2(ijpq) -> C(ipjq) - ArrayC C; - C(c) = C2(Cn_annot.C2); return C; } else { diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 9091445f4e..7597479d74 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -629,22 +629,6 @@ class ContEngine : public BinaryEngine { // element_return_op_ left null: a view cell cannot be // value-returned (see the init_struct precondition check). 
} else if (inner_prod == TensorProduct::Contraction) { - using op_type = TiledArray::detail::ContractReduce< - result_tile_element_type, left_tile_element_type, - right_tile_element_type, scalar_type>; - // The inner op is built *perm-free* on purpose. factor_ is absorbed - // into element_nonreturn_op_; operand inner transposes are folded - // into the inner GEMM via left_/right_inner_permtype_. A non-identity - // inner *result* permutation is NOT placed on this op - // (make_fused_contraction_lambda asserts a perm-free op); it is - // applied downstream instead -- by op_'s post-processing permute for - // a contraction outer product, or by arena_hadamard_inner_contract's - // slab-level post-pass for a Hadamard outer product. - auto contrreduce_op = op_type( - to_cblas_op(this->left_inner_permtype_), - to_cblas_op(this->right_inner_permtype_), this->factor_, - inner_size(this->indices_), inner_size(this->left_indices_), - inner_size(this->right_indices_)); constexpr bool arena_eligible = TiledArray::detail::is_contraction_arena_tot_v< result_tile_type, left_tile_type, right_tile_type>; @@ -653,42 +637,115 @@ class ContEngine : public BinaryEngine { "nested contraction on view inner tiles is supported only " "for arena-backed tensors-of-tensors"); } else { - // perm-free per-cell in-place contraction; used by both outer - // regimes below - this->element_nonreturn_op_ = - TiledArray::detail::make_fused_contraction_lambda< - result_tile_element_type, left_tile_element_type, - right_tile_element_type>(contrreduce_op); - if (this->product_type() == TensorProduct::Contraction) { - // outer contraction: the SUMMA result is shaped from operand - // inner cells by arena_plan_; op_'s post-processing permute - // applies the (outer + inner) result permutation. 
- this->arena_plan_ = - TiledArray::detail::make_contraction_arena_plan< - result_tile_type, left_tile_type, right_tile_type>( - TiledArray::detail::ArenaInnerShapeKind:: - gemm_result_range, - std::make_optional(contrreduce_op.gemm_helper()), - Permutation{}); - if (!bool(this->arena_plan_)) - TA_EXCEPTION( - "nested contraction on view inner tiles: the arena fast " - "path was inactive (arena disabled)"); + // Phantom-unit denest: result inner indices all phantom (⊗ₙ) -- the + // real inner modes are fully contracted, so the inner product is a + // flat dot into a unit-extent [1]^phantom_rank cell. The arena plan + // shapes the [1] result cells (unit_range); the per-cell op fills + // the lone element via the dot. Operands are read flat, so no view + // cell carries the phantom mode and no GEMM rank match is required. + const auto result_inner = inner(this->indices_); + bool result_inner_all_phantom = result_inner.size() > 0; + for (std::size_t m = 0; m < result_inner.size(); ++m) + if (!TiledArray::detail::is_phantom_unit_label(result_inner[m])) { + result_inner_all_phantom = false; + break; + } + if (result_inner_all_phantom) { + const scalar_type factor = this->factor_; + this->element_nonreturn_op_ = + [factor](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (left.empty() || right.empty()) return; + using Numeric = + typename result_tile_element_type::numeric_type; + const std::size_t n = left.range().volume(); + TA_ASSERT(n == right.range().volume()); + const auto* lp = left.data(); + const auto* rp = right.data(); + Numeric acc{0}; + for (std::size_t j = 0; j < n; ++j) acc += lp[j] * rp[j]; + // result cell is pre-shaped [1] by the unit_range plan. 
+ result.data()[0] += static_cast(factor) * acc; + }; + if (this->product_type() == TensorProduct::Contraction) { + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + TiledArray::detail::ArenaInnerShapeKind::unit_range, + std::nullopt, Permutation{}, result_inner.size()); + if (!bool(this->arena_plan_)) + TA_EXCEPTION( + "phantom-unit denest on view inner tiles: the arena fast " + "path was inactive (arena disabled)"); + } else { + // outer Hadamard: a whole-tile arena op that shapes each + // result outer cell as a unit-extent [1] cell and fills it via + // the phantom-dot per-cell op. + this->arena_hadamard_tile_op_ = + [cell_op = this->element_nonreturn_op_, + phantom_rank = result_inner.size()]( + const left_tile_type& l, + const right_tile_type& r) -> result_tile_type { + return TiledArray::detail::arena_hadamard_phantom_dot< + result_tile_type>(l, r, phantom_rank, cell_op); + }; + } } else { - // outer Hadamard: MultEngine builds a binary tile op, which - // cannot use a value-returning per-cell op. Supply a whole-tile - // arena op that shapes the result from per-cell inner GEMMs and - // fills it in place; the inner result permutation is a - // slab-level post-pass inside the kernel. - this->arena_hadamard_tile_op_ = - [cell_op = this->element_nonreturn_op_, - inner_gh = contrreduce_op.gemm_helper(), - inner_perm = inner(this->perm_)]( - const left_tile_type& l, - const right_tile_type& r) -> result_tile_type { - return TiledArray::detail::arena_hadamard_inner_contract< - result_tile_type>(l, r, inner_gh, cell_op, inner_perm); - }; + using op_type = TiledArray::detail::ContractReduce< + result_tile_element_type, left_tile_element_type, + right_tile_element_type, scalar_type>; + // The inner op is built *perm-free* on purpose. factor_ is + // absorbed into element_nonreturn_op_; operand inner transposes + // are folded into the inner GEMM via left_/right_inner_permtype_. 
+ // A non-identity inner *result* permutation is NOT placed on this + // op (make_fused_contraction_lambda asserts a perm-free op); it + // is applied downstream instead -- by op_'s post-processing + // permute for a contraction outer product, or by + // arena_hadamard_inner_contract's slab-level post-pass for a + // Hadamard outer product. + auto contrreduce_op = op_type( + to_cblas_op(this->left_inner_permtype_), + to_cblas_op(this->right_inner_permtype_), this->factor_, + inner_size(this->indices_), inner_size(this->left_indices_), + inner_size(this->right_indices_)); + // perm-free per-cell in-place contraction; used by both outer + // regimes below + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_contraction_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(contrreduce_op); + if (this->product_type() == TensorProduct::Contraction) { + // outer contraction: the SUMMA result is shaped from operand + // inner cells by arena_plan_; op_'s post-processing permute + // applies the (outer + inner) result permutation. + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + TiledArray::detail::ArenaInnerShapeKind:: + gemm_result_range, + std::make_optional(contrreduce_op.gemm_helper()), + Permutation{}); + if (!bool(this->arena_plan_)) + TA_EXCEPTION( + "nested contraction on view inner tiles: the arena fast " + "path was inactive (arena disabled)"); + } else { + // outer Hadamard: MultEngine builds a binary tile op, which + // cannot use a value-returning per-cell op. Supply a whole-tile + // arena op that shapes the result from per-cell inner GEMMs and + // fills it in place; the inner result permutation is a + // slab-level post-pass inside the kernel. 
+ this->arena_hadamard_tile_op_ = + [cell_op = this->element_nonreturn_op_, + inner_gh = contrreduce_op.gemm_helper(), + inner_perm = inner(this->perm_)]( + const left_tile_type& l, + const right_tile_type& r) -> result_tile_type { + return TiledArray::detail::arena_hadamard_inner_contract< + result_tile_type>(l, r, inner_gh, cell_op, inner_perm); + }; + } } } // element_return_op_ left null: a view cell cannot be @@ -721,44 +778,99 @@ class ContEngine : public BinaryEngine { if (inner_prod == TensorProduct::Contraction) { TA_ASSERT(tot_x_tot); if constexpr (tot_x_tot) { - using op_type = TiledArray::detail::ContractReduce< - result_tile_element_type, left_tile_element_type, - right_tile_element_type, scalar_type>; - // factor_ is absorbed into inner_tile_nonreturn_op_ - auto contrreduce_op = - (inner_target_indices != inner(this->indices_)) - ? op_type(to_cblas_op(this->left_inner_permtype_), - to_cblas_op(this->right_inner_permtype_), - this->factor_, inner_size(this->indices_), - inner_size(this->left_indices_), - inner_size(this->right_indices_), - (!this->implicit_permute_inner_ ? inner(this->perm_) - : Permutation{})) - : op_type(to_cblas_op(this->left_inner_permtype_), - to_cblas_op(this->right_inner_permtype_), - this->factor_, inner_size(this->indices_), - inner_size(this->left_indices_), - inner_size(this->right_indices_)); - constexpr bool arena_eligible = - TiledArray::detail::is_contraction_arena_tot_v< - result_tile_type, left_tile_type, right_tile_type>; - if constexpr (arena_eligible) { - if (this->product_type() == TensorProduct::Contraction) { - this->arena_plan_ = - TiledArray::detail::make_contraction_arena_plan< - result_tile_type, left_tile_type, right_tile_type>( - TiledArray::detail::ArenaInnerShapeKind:: - gemm_result_range, - std::make_optional(contrreduce_op.gemm_helper()), - inner(this->perm_)); + // Phantom-unit denest: every result inner index is a phantom unit + // (⊗ₙ), i.e. the real inner modes are fully contracted. 
The inner + // product is a flat (non-conjugating) dot of the operand cells + // accumulated into the lone element of a unit-extent [1]^phantom_rank + // result cell. Operands are read flat, so neither carries the phantom + // mode -- no GEMM, no ContractReduce rank match. element_return_op_ + // (built below) wraps this for the outer-Hadamard regime. + const auto result_inner = inner(this->indices_); + bool result_inner_all_phantom = result_inner.size() > 0; + for (std::size_t m = 0; m < result_inner.size(); ++m) + if (!TiledArray::detail::is_phantom_unit_label(result_inner[m])) { + result_inner_all_phantom = false; + break; } - } - if constexpr (arena_eligible) { - if (this->arena_plan_) { - this->element_nonreturn_op_ = - TiledArray::detail::make_fused_contraction_lambda< - result_tile_element_type, left_tile_element_type, - right_tile_element_type>(contrreduce_op); + if (result_inner_all_phantom) { + const std::size_t phantom_rank = result_inner.size(); + const scalar_type factor = this->factor_; + this->element_nonreturn_op_ = + [phantom_rank, factor](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (left.empty() || right.empty()) return; + using Numeric = + typename result_tile_element_type::numeric_type; + const std::size_t n = left.range().volume(); + TA_ASSERT(n == right.range().volume()); + const auto* lp = left.data(); + const auto* rp = right.data(); + Numeric acc{0}; + for (std::size_t j = 0; j < n; ++j) acc += lp[j] * rp[j]; + acc *= static_cast(factor); + if (TA::empty(result)) { + using R = typename result_tile_element_type::range_type; + TiledArray::container::svector ext( + phantom_rank, 1); + result = result_tile_element_type(R(ext), Numeric{0}); + } + result.data()[0] += acc; + }; + } else { + using op_type = TiledArray::detail::ContractReduce< + result_tile_element_type, left_tile_element_type, + right_tile_element_type, scalar_type>; + // factor_ is absorbed into 
inner_tile_nonreturn_op_ + auto contrreduce_op = + (inner_target_indices != inner(this->indices_)) + ? op_type( + to_cblas_op(this->left_inner_permtype_), + to_cblas_op(this->right_inner_permtype_), + this->factor_, inner_size(this->indices_), + inner_size(this->left_indices_), + inner_size(this->right_indices_), + (!this->implicit_permute_inner_ ? inner(this->perm_) + : Permutation{})) + : op_type(to_cblas_op(this->left_inner_permtype_), + to_cblas_op(this->right_inner_permtype_), + this->factor_, inner_size(this->indices_), + inner_size(this->left_indices_), + inner_size(this->right_indices_)); + constexpr bool arena_eligible = + TiledArray::detail::is_contraction_arena_tot_v< + result_tile_type, left_tile_type, right_tile_type>; + if constexpr (arena_eligible) { + if (this->product_type() == TensorProduct::Contraction) { + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + TiledArray::detail::ArenaInnerShapeKind:: + gemm_result_range, + std::make_optional(contrreduce_op.gemm_helper()), + inner(this->perm_)); + } + } + if constexpr (arena_eligible) { + if (this->arena_plan_) { + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_contraction_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(contrreduce_op); + } else { + this->element_nonreturn_op_ = + [contrreduce_op, + permute_inner = + this->product_type() != TensorProduct::Contraction]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + contrreduce_op(result, left, right); + // permutations of result are applied as "postprocessing" + if (permute_inner && !TA::empty(result)) + result = contrreduce_op(result); + }; + } } else { this->element_nonreturn_op_ = [contrreduce_op, permute_inner = this->product_type() != @@ -772,18 +884,6 @@ class ContEngine : public BinaryEngine { result = contrreduce_op(result); }; } - } else { - 
this->element_nonreturn_op_ = - [contrreduce_op, permute_inner = this->product_type() != - TensorProduct::Contraction]( - result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - contrreduce_op(result, left, right); - // permutations of result are applied as "postprocessing" - if (permute_inner && !TA::empty(result)) - result = contrreduce_op(result); - }; } } // ToT x ToT } else if (inner_prod == TensorProduct::Hadamard) { diff --git a/src/TiledArray/tensor/arena_einsum.h b/src/TiledArray/tensor/arena_einsum.h index 7e917e2bf8..8d7b3a578a 100644 --- a/src/TiledArray/tensor/arena_einsum.h +++ b/src/TiledArray/tensor/arena_einsum.h @@ -10,6 +10,7 @@ #include "TiledArray/tensor/arena_kernels.h" #include "TiledArray/tensor/kernels.h" #include "TiledArray/tensor/type_traits.h" +#include "TiledArray/util/annotation.h" #include #include @@ -26,15 +27,19 @@ namespace TiledArray::detail { /// Specifies how an inner-cell range is derived from operand inner cells. enum class ArenaInnerShapeKind { - left_range, // Hadamard inner; Scale tot_x_t - right_range, // Scale t_x_tot - gemm_result_range // inner Contraction (uses inner_gh) + left_range, // Hadamard inner; Scale tot_x_t + right_range, // Scale t_x_tot + gemm_result_range, // inner Contraction (uses inner_gh) + unit_range // phantom-unit denest: a unit-extent [1]^phantom_rank cell + // independent of operand inner ranges (the inner product is a + // flat dot; see RegimeAInnerKind::phantom_dot) }; /// Inner-shape derivation plan: kind + (optional) inner GemmHelper. struct ArenaInnerShapePlan { ArenaInnerShapeKind kind; std::optional inner_gh; // only for gemm_result_range + std::size_t phantom_rank = 0; // only for unit_range /// Derives one result inner range from operand inner cells. 
template @@ -48,6 +53,10 @@ struct ArenaInnerShapePlan { TA_ASSERT(inner_gh.has_value()); return inner_gh->template make_result_range( l.range(), r.range()); + case ArenaInnerShapeKind::unit_range: { + container::vector ext(phantom_rank, std::size_t{1}); + return ResultInnerRange(ext); + } } TA_ASSERT(false); return ResultInnerRange{}; @@ -125,7 +134,8 @@ using arena_plan_storage_t = template auto make_contraction_arena_plan(ArenaInnerShapeKind inner_kind, std::optional inner_gh, - const Permutation& inner_perm) + const Permutation& inner_perm, + std::size_t phantom_rank = 0) -> std::optional> { if (arena_disabled()) return std::nullopt; if constexpr (!is_contraction_arena_tot_v) { @@ -137,7 +147,8 @@ auto make_contraction_arena_plan(ArenaInnerShapeKind inner_kind, else if (!inner_gh.has_value()) return std::nullopt; return std::optional>( - std::in_place, ArenaInnerShapePlan{inner_kind, std::move(inner_gh)}); + std::in_place, + ArenaInnerShapePlan{inner_kind, std::move(inner_gh), phantom_rank}); } } @@ -366,6 +377,36 @@ Result arena_hadamard_inner_contract(const Left& left, const Right& right, return result; } +/// Hadamard-outer, phantom-unit-denest-inner ToT x ToT product into a fresh +/// arena tile. Like arena_hadamard_inner_contract, but each result outer cell +/// is a unit-extent [1]^phantom_rank cell (the inner product is a full +/// contraction = a flat dot; there are no real result inner modes). `cell_op` +/// (the phantom-dot per-cell op) fills the pre-shaped unit cell. No inner +/// permutation: phantom modes are all unit-extent. 
+template +Result arena_hadamard_phantom_dot(const Left& left, const Right& right, + std::size_t phantom_rank, + const CellOp& cell_op) { + using inner_range_t = typename Result::value_type::range_type; + TA_ASSERT(left.range().volume() == right.range().volume()); + TA_ASSERT(left.nbatch() == right.nbatch()); + const std::size_t N_cells = left.range().volume() * left.nbatch(); + const container::vector unit_ext(phantom_rank, std::size_t{1}); + auto range_fn = [&left, &right, &unit_ext](std::size_t ord) -> inner_range_t { + const auto& lc = left.data()[ord]; + const auto& rc = right.data()[ord]; + if (lc.empty() || rc.empty()) return inner_range_t{}; + return inner_range_t(unit_ext); + }; + Result result = + arena_outer_init(left.range(), left.nbatch(), range_fn); + for (std::size_t ord = 0; ord < N_cells; ++ord) { + if (result.data()[ord].empty()) continue; + cell_op(result.data()[ord], left.data()[ord], right.data()[ord]); + } + return result; +} + /// Creates a fused Hadamard callback. template auto make_fused_hadamard_lambda() { @@ -402,8 +443,13 @@ auto make_fused_scale_t_x_tot_lambda() { enum class RegimeAInnerKind { hadamard, contraction, - scale_left, // ToT × plain T → ToT (right operand contributes scalars) - scale_right // plain T × ToT → ToT (left operand contributes scalars) + scale_left, // ToT × plain T → ToT (right operand contributes scalars) + scale_right, // plain T × ToT → ToT (left operand contributes scalars) + phantom_dot // full inner contraction (dot) into a unit-extent result cell; + // the result keeps only phantom-unit inner modes (see + // is_phantom_unit_label). Operand cells are read flat, so no + // operand carries the phantom mode and no GEMM rank match is + // required. Realizes the ToT×ToT→plain-T (DeNest) inner product. 
}; /// Permute the extents of `src` by `perm` and materialize a range of type @@ -439,6 +485,10 @@ struct RegimeAArenaPlan { std::optional> h_plan{}; std::optional> c_plan{}; + // For kind == phantom_dot: the number of phantom-unit result modes (the rank + // of the unit-extent result inner cell, e.g. 1 for `⊗₁`). + std::size_t phantom_rank = 0; + /// Derives the result inner range from a non-empty input-cell pair. template InnerRange derive_inner_range(const LRange& l_range, @@ -468,6 +518,12 @@ struct RegimeAArenaPlan { return InnerRange(l_range); case RegimeAInnerKind::scale_right: return InnerRange(r_range); + case RegimeAInnerKind::phantom_dot: { + // The result keeps only phantom-unit modes: a rank-`phantom_rank`, + // all-unit-extent cell (e.g. [1] for `⊗₁`). + container::vector ext(phantom_rank, std::size_t{1}); + return InnerRange(ext); + } } TA_ASSERT(false && "RegimeAInnerKind: unhandled kind"); return InnerRange{}; @@ -520,6 +576,27 @@ struct RegimeAArenaPlan { } return; } + case RegimeAInnerKind::phantom_dot: { + if constexpr (is_arena_inner_cell_v && + is_arena_inner_cell_v) { + if (l.empty() || rr.empty()) return; + // Full inner contraction with only phantom-unit modes surviving: a + // flat (non-conjugating) dot of the operand cells -- the same value a + // GEMM with M=N=1,K=vol would compute -- accumulated into the lone + // element of the unit-extent result cell. Reads operands flat, so no + // operand need carry the phantom mode and no rank match is required; + // uniform for TA::Tensor and ArenaTensor cells. 
+ using Numeric = typename std::remove_cv_t::numeric_type; + const std::size_t n = l.range().volume(); + TA_ASSERT(n == rr.range().volume()); + const auto* MADNESS_RESTRICT lp = l.data(); + const auto* MADNESS_RESTRICT rp = rr.data(); + Numeric acc{0}; + for (std::size_t j = 0; j < n; ++j) acc += lp[j] * rp[j]; + r.data()[0] += acc; + } + return; + } } } }; @@ -564,6 +641,19 @@ auto make_regime_a_arena_plan(const A& a, const B& b, const Inner& inner, // run_regime_a_arena hoists each operand inner permutation to a // slab-level rewrite (arena_inner_permute) so both operands reach // C-layout before the per-cell flat r += l * rr. No need to bail. + } else if (bool(inner.C) && [&] { + for (const auto& lbl : inner.C) + if (!::TiledArray::detail::is_phantom_unit_label(lbl)) + return false; + return true; + }()) { + // Phantom-unit denest: every surviving result inner mode is a phantom + // unit (⊗ₙ), i.e. the real inner modes are fully contracted. Realize + // the inner product as a flat dot into a unit-extent result cell -- + // operands are read flat, so neither needs to carry the phantom mode + // (no GEMM, no TensorContractionPlan rank match). See accumulate(). + plan.kind = RegimeAInnerKind::phantom_dot; + plan.phantom_rank = inner.C.size(); } else { plan.kind = RegimeAInnerKind::contraction; plan.c_plan.emplace(inner.A, inner.B, inner.C); diff --git a/src/TiledArray/util/annotation.h b/src/TiledArray/util/annotation.h index 8fd30569cc..8d43deae02 100644 --- a/src/TiledArray/util/annotation.h +++ b/src/TiledArray/util/annotation.h @@ -142,6 +142,28 @@ inline bool is_tot_index(const std::string& idx) { return idx.find(";") != std::string::npos; } +/// The reserved prefix marking a *phantom unit* inner-mode annotator. A phantom +/// unit mode (spelled `⊗₁`, `⊗₂`, … = U+2297 CIRCLED TIMES followed by a +/// subscript) denotes a unit-extent inner mode that is *not* physically present +/// in the annotated tensor. 
It is used by the ToT×ToT→plain-T (DeNest) path to +/// express a full inner contraction (a dot) as a contraction whose result keeps +/// a unit inner mode, so the result inner cell is a genuine (≥ order-1) tensor +/// rather than the unsupported order-0. The einsum machinery recognizes such +/// labels and realizes the inner product as a flat dot into a unit-extent cell, +/// without requiring the operand to carry the extra mode. +inline const char* phantom_unit_label_prefix() { + return "\xE2\x8A\x97"; // UTF-8 for U+2297 (⊗) +} + +/// \param[in] label a single (already split) index label +/// \return true if \p label is a phantom-unit annotator (see +/// phantom_unit_label_prefix) +inline bool is_phantom_unit_label(const std::string& label) { + const std::string prefix = phantom_unit_label_prefix(); + return label.size() >= prefix.size() && + label.compare(0, prefix.size(), prefix) == 0; +} + /// Splits and sanitizes a string labeling a tensor's modes. /// /// This function encapsulates TiledArray's string index parsing. It is a free diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 8bae61cf1f..7ef67f176d 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -414,6 +414,16 @@ BOOST_AUTO_TEST_CASE(nested_rank_reduction) { {{0, 2, 4}, {0, 4}}, // {3, 2}, // {3, 2}))); + // external indices present (branch-2 generalized contraction): Hadamard i, + // external p (from A) and q (from B), contracted-outer k, inner ab fully + // contracted. This is the shape that motivated the rewrite (cf. the CSV-CC + // term g*C*t*...). Exercises the ContEngine inner phantom-dot path. + BOOST_REQUIRE((check_manual_eval( + "ipk;ab,iqk;ab->ipq", // + {{0, 2}, {0, 2, 3}, {0, 2, 4}}, // A outer: i, p, k + {{0, 2}, {0, 3}, {0, 2, 4}}, // B outer: i, q, k + {3, 2}, // A inner: ab + {3, 2}))); // B inner: ab } BOOST_AUTO_TEST_CASE(corner_cases) {