First-class general tensor products (fused + contracted + free indices) in the expression layer #562

Open
evaleev wants to merge 11 commits into master from evaleev/feature/general-product-expr

evaleev commented Jun 11, 2026

Member

Summary

Adds native support for general binary tensor products — fused (Hadamard), contracted, and free indices coexisting, e.g. C("b,i,k") = A("b,i,j") * B("b,j,k") — to the expression layer (MultEngine/ContEngine) and evaluates them with a batched Summa: one distributed task graph in one World. einsum() now routes its generalized-contraction branch through this path by default, eliminating its per-Hadamard-slab decomposition (one MPI_Comm_split + sub-World + make_array + fence per slab).
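
For orientation, the example product can be written out elementwise. This is only a restatement of the annotation above in index notation, with each index labeled by its role:

```latex
C_{b,i,k} \;=\; \sum_{j} A_{b,i,j}\, B_{b,j,k}
\qquad
\begin{cases}
b    & \text{fused (Hadamard): present in } A,\ B \text{ and } C \\
j    & \text{contracted: present in } A \text{ and } B,\ \text{summed out of } C \\
i,\,k & \text{free: present in one argument and in } C
\end{cases}
```

For each fixed value of b this is an ordinary matrix product, which is why the fused index can be treated as a batch dimension.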

On the motivating workload (n-hexane PNO-CCSD/cc-pVDZ in MPQC, block-sparse arena tensor-of-tensors), the per-run sub-World count drops 1412 → 0 and the einsum-region attribution becomes pure evaluation (was ~50% machinery: retile/make_array 25%, per-slab fences 30%, harvest+teardown 8%).

Design

  • compute_product_type(left, right, target) now returns the (previously unreachable) TensorProduct::General when a shared index survives into the target alongside contracted/free indices; the bottom-up 2-arg overload is unchanged (shared ⇒ contracted).
  • New GeneralPermutationOptimizer: canonical layouts A(h,e_A,c), B(h,c,e_B), C(h,e_A,e_B) — the GEMM-canonical layout with the fused modes leading, so the tile op folds them into the tile batch dimension by zero-copy reshape.
  • Summa generalized in place to batched contractions: optional slab count nh (default 1 = exactly the prior behavior); iteration in steps s = h*K + k, where K is the per-slab contracted tile count; slab-offset tile ordinals; (h,k)-keyed sparse masks/groups; per-slab reduce tasks. The owner of a tile is independent of its slab (SlabbedPmap), so each slab's contraction is fully distributed over the same 2-d grid and slabs overlap in the task pipeline with no inter-slab barriers.
  • BatchedContractReduce adapts a folded-rank ContractReduce to fused-mode-carrying tiles (Tensor::gemm's nbatch loop and the arena ToT kernels already speak this convention).
  • SparseShape::gemm_batched: slab-batched norm contraction for the result shape.
  • ToT composition: the inner-tile-op builders classify the outer regime via outer_product_uses_summa() (Contraction or General ⇒ ContractReduce semantics), so arena plans and the strided-DGEMM kernels install as for a pure contraction.
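
The index classification and canonical layouts listed above can be illustrated with a small standalone sketch. It is not the TiledArray implementation: `IndexPartition`, `partition_indices`, `is_general` and the `canonical_*` helpers are hypothetical names, indices are plain strings, and fused indices are kept in the order they appear in the left argument (the real optimizer may order them differently).

```cpp
#include <algorithm>
#include <cassert>
#include <initializer_list>
#include <string>
#include <vector>

using Indices = std::vector<std::string>;

// Hypothetical holder for the four index classes (not a TiledArray type).
struct IndexPartition {
  Indices fused;       // h:   in left, right and target
  Indices contracted;  // c:   in left and right, absent from target
  Indices free_left;   // e_A: in left and target only
  Indices free_right;  // e_B: in right and target only
};

inline bool contains(const Indices& v, const std::string& x) {
  return std::find(v.begin(), v.end(), x) != v.end();
}

// Classify every index of left * right -> target (top-down: target is known).
inline IndexPartition partition_indices(const Indices& left,
                                        const Indices& right,
                                        const Indices& target) {
  IndexPartition p;
  for (const auto& x : left) {
    const bool in_right = contains(right, x);
    const bool in_target = contains(target, x);
    if (in_right && in_target) {
      p.fused.push_back(x);
    } else if (in_right) {
      p.contracted.push_back(x);
    } else {
      assert(in_target);  // a left-only index must survive (no implicit reduction)
      p.free_left.push_back(x);
    }
  }
  for (const auto& x : right) {
    if (contains(left, x)) continue;  // shared index: classified above
    assert(contains(target, x));      // a right-only index must survive too
    p.free_right.push_back(x);
  }
  return p;
}

// General: a fused index coexists with contracted and/or free indices.
inline bool is_general(const IndexPartition& p) {
  return !p.fused.empty() && (!p.contracted.empty() || !p.free_left.empty() ||
                              !p.free_right.empty());
}

// Concatenate index groups in the given order.
inline Indices concat(std::initializer_list<Indices> groups) {
  Indices out;
  for (const auto& g : groups) out.insert(out.end(), g.begin(), g.end());
  return out;
}

// Canonical layouts: fused modes lead, then the GEMM-canonical order.
inline Indices canonical_left(const IndexPartition& p) {
  return concat({p.fused, p.free_left, p.contracted});  // (h, e_A, c)
}
inline Indices canonical_right(const IndexPartition& p) {
  return concat({p.fused, p.contracted, p.free_right});  // (h, c, e_B)
}
inline Indices canonical_result(const IndexPartition& p) {
  return concat({p.fused, p.free_left, p.free_right});  // (h, e_A, e_B)
}
```

For `A("b,i,j") * B("b,j,k") -> C("b,i,k")` this yields h = {b}, c = {j}, e_A = {i}, e_B = {k}. Dropping the target from the classification makes b indistinguishable from a contracted index, which is the reason the 2-argument overload cannot produce General.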

einsum cutover & differential harness

Three-way runtime control:

  • default: expression route;
  • TA_EINSUM_LEGACY_SUBWORLD (or detail::einsum_legacy_subworld()): forces the legacy per-slab sub-World path, retained indefinitely as the reference implementation;
  • TA_EINSUM_DIFFERENTIAL: evaluates every general product by both routes, compares norms, and reports mismatching contractions with per-tile forensics — this is how the bug below was found.

Also adds TA_EINSUM_INSTRUMENT, a runtime-gated attribution profiler for the einsum region (per-call time buckets: retile, comm-split, contract+fence, harvest, …).
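
The idea behind the differential mode (evaluate by both routes, compare norms, flag mismatches) can be sketched in isolation. The following is a toy model, not the harness in this PR: results are flat `std::vector<double>` buffers, the two routes are arbitrary callables, and the relative tolerance `rel_tol` is an assumed parameter.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <vector>

using Data = std::vector<double>;

// Squared Frobenius norm of a flat buffer.
inline double squared_norm(const Data& t) {
  double s = 0.0;
  for (double x : t) s += x * x;
  return s;
}

// Hypothetical report type: both squared norms plus the verdict.
struct DifferentialReport {
  double norm2_reference;
  double norm2_candidate;
  bool mismatch;
};

// Evaluate the same product by two routes and compare squared norms.
// The reference route plays the role of the legacy path.
inline DifferentialReport run_differential(
    const std::function<Data()>& reference_route,
    const std::function<Data()>& candidate_route, double rel_tol = 1e-10) {
  assert(rel_tol >= 0.0);
  const double nr = squared_norm(reference_route());
  const double nc = squared_norm(candidate_route());
  const double scale = std::max(1.0, std::max(nr, nc));
  return {nr, nc, std::fabs(nr - nc) > rel_tol * scale};
}
```

A norm comparison is cheap and order-independent, but it cannot see errors that preserve the norm (for example a wrongly permuted result), so a flagged mismatch is a starting point for per-tile inspection rather than a complete equality check.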

Drive-by bug fix (pre-existing)

The differential harness exposed a latent bug in Tensor::gemm's ToT×scalar strided scale paths: the per-row cleanliness probe stops at the first absent cell, and the subsequent A <= 0 ⇒ empty row shortcut dropped the entire row's contributions even when later cells were present. The legacy einsum route dodged it by accident (its canonical layout fails the path's NoTranspose gate). Fixed for both orientations + regression test.
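
A simplified model of the failure mode, for readers who want to see it without the tensor machinery. Everything here is hypothetical scaffolding (`Row`, `probe_row`, `RowPath`, `buggy_path`, `fixed_path`): a row is a list of cells that are either absent or carry a positive inner extent, and the function only decides which code path a row would take.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Inner extent of each cell in a row; std::nullopt marks an absent cell.
using Row = std::vector<std::optional<int>>;

enum class RowPath { Skip, Strided, PerCell };

struct Probe {
  int inner_size;  // -1 if no present cell was seen before the probe stopped
  bool clean;      // true if every cell was present with the same extent
};

// The cleanliness probe: stops at the first absent (or mismatching) cell.
inline Probe probe_row(const Row& row) {
  Probe p{-1, true};
  for (const auto& cell : row) {
    if (!cell) {
      p.clean = false;
      break;  // later cells are never inspected
    }
    assert(*cell > 0);
    if (p.inner_size < 0) {
      p.inner_size = *cell;
    } else if (*cell != p.inner_size) {
      p.clean = false;
      break;
    }
  }
  return p;
}

// Pre-fix logic: "A <= 0 => empty row" drops rows whose LEADING cell is absent.
inline RowPath buggy_path(const Row& row) {
  const Probe p = probe_row(row);
  if (p.inner_size <= 0) return RowPath::Skip;
  return p.clean ? RowPath::Strided : RowPath::PerCell;
}

// Post-fix logic: only a fully absent row is skipped.
inline RowPath fixed_path(const Row& row) {
  const Probe p = probe_row(row);
  if (p.inner_size <= 0) {
    const bool any_present = std::any_of(
        row.begin(), row.end(), [](const auto& c) { return c.has_value(); });
    return any_present ? RowPath::PerCell : RowPath::Skip;
  }
  return p.clean ? RowPath::Strided : RowPath::PerCell;
}
```

With a row such as `{absent, 3, 3}` the pre-fix logic returns `Skip` and silently loses two cells' contributions, while the fixed logic falls back to per-cell processing.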

Known, intentional route differences

  • The legacy path derives the result shape from harvested tile norms and thus implicitly hard-zeroes sub-threshold result tiles; the expression route keeps them (standard estimate-derived contraction shape). Per the TA screening philosophy norms are trusted as genuine and no implicit truncation is performed — call truncate() explicitly if desired. Downstream consumers may see ~1e-7-scale shifts vs legacy-einsum-derived baselines.
  • General products at inner nodes of an expression tree (e.g. THC-style X("p,r1") * X("q,r1") * Z("r1,r2") * …) cannot be classified bottom-up; they now produce an informative error suggesting explicit intermediates (top-down index-set deduction is future work). Targets that interleave fused and free modes are supported through einsum() (canonicalize + permute); native engine support is future work.

Testing

New general_product_suite (23 cases): classification, optimizer layouts, and differential tests against the legacy einsum oracle — dense/block-sparse, plain/ToT/mixed ToT×T, owning and arena (view) inner cells, variable inner extents, screened (absent) cells, non-leading fused indices, interleaved targets, batched outer products, THC gating + workaround; np = 1–4. All existing suites pass unchanged with the new default (einsum suites validated against their reference data). End-to-end validated in MPQC PNO-CCSD via the differential mode.

evaleev added 9 commits June 11, 2026 01:47
Buckets per einsum call: entry_fence / setup / commsplit+world /
retile/make_array / contract+fence / harvest / local_kernel / teardown,
keyed by branch (hadamard-reduction-local, generalized-subworld,
generalized-inner-perm-recurse) and contraction annotation; dumped to
stderr at exit. Zero overhead when disabled. Establishes the baseline
attribution for replacing the per-Hadamard-tile sub-World decomposition
with first-class general-product (h+e+c) support.

PNO-CCSD c6h14/cc-pVDZ baseline (np=1, 3 CC iters): 17.9 s einsum-region,
1412 Hadamard slices = 1412 sub-Worlds; retile/make_array 25.5% +
harvest 3.0% + teardown 4.2% non-numeric, contract+fence 30.5%.
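
A minimal sketch of a runtime-gated bucket profiler in the spirit of the commit above. It is not the code in einsum_instrument.h: `BucketProfiler`, its `Scope` timer and the rule that an environment value other than "0" enables it are all assumptions made for illustration.

```cpp
#include <cassert>
#include <chrono>
#include <cstdlib>
#include <map>
#include <string>
#include <utility>

class BucketProfiler {
 public:
  explicit BucketProfiler(bool enabled) : enabled_(enabled) {}

  // Assumed convention: the variable must be set to something other than "0".
  static bool enabled_from_env(const char* name) {
    assert(name != nullptr);
    const char* value = std::getenv(name);
    return value != nullptr && std::string(value) != "0";
  }

  // RAII timer: adds the elapsed time to one bucket when it leaves scope.
  class Scope {
   public:
    Scope(BucketProfiler& profiler, std::string bucket)
        : profiler_(profiler.enabled_ ? &profiler : nullptr),
          bucket_(std::move(bucket)) {
      if (profiler_) start_ = Clock::now();  // no clock read when disabled
    }
    ~Scope() {
      if (!profiler_) return;
      const std::chrono::duration<double> dt = Clock::now() - start_;
      profiler_->seconds_[bucket_] += dt.count();
      ++profiler_->calls_[bucket_];
    }
    Scope(const Scope&) = delete;
    Scope& operator=(const Scope&) = delete;

   private:
    using Clock = std::chrono::steady_clock;
    BucketProfiler* profiler_;
    std::string bucket_;
    Clock::time_point start_{};
  };

  const std::map<std::string, double>& seconds() const { return seconds_; }
  const std::map<std::string, long>& calls() const { return calls_; }

 private:
  bool enabled_;
  std::map<std::string, double> seconds_;
  std::map<std::string, long> calls_;
};
```

Typical use would be `BucketProfiler prof(BucketProfiler::enabled_from_env("TA_EINSUM_INSTRUMENT"));` followed by `BucketProfiler::Scope t(prof, "retile/make_array");` around each phase. In this sketch the disabled case still costs one branch and a string move per scope, so it is cheap rather than strictly free.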
…ral)

Phase A of first-class general-product (fused + contracted + free indices)
support in the expression layer (target: PNO-CC batched contractions,
replacing einsum's per-Hadamard-tile sub-World decomposition):

- compute_product_type(left, right, target) now returns
  TensorProduct::General when a shared index survives into the target
  (fused) alongside contracted and/or free indices, incl. the
  Hadamard-reduction case (args related by permutation, target drops
  indices). The 2-arg overload is unchanged (bottom-up convention:
  shared => contracted).
- new GeneralPermutationOptimizer: canonical layouts
  left (h, e_A, c), right (h, c, e_B), result (h, e_A, e_B) -- the
  GEMM-canonical layout with fused indices prepended so a consuming
  batched-GEMM op can fold them into the tile batch dimension by
  reshape; exposes the h/c/e_A/e_B partition for engine consumption.
  Requires target indices (fused-vs-contracted is undecidable
  bottom-up); validates against implicit reductions.
- BinaryEngine::init_indices_ and MultEngine/ScalMultEngine route
  General through the new optimizer; ContEngine::product_type()
  accessor admits General.
- evaluation is gated with an informative exception (use
  TiledArray::einsum() meanwhile) until the batched-Summa DistEval
  lands (Phase B); previously such expressions misclassified as pure
  contractions and died in target-permutation resolution.
- unit tests: classification, optimizer layouts/partitions/errors, and
  the end-to-end expression gate (tests/general_product.cpp).

Phase B step 1 of general-product support: Summa gains an optional slab
count nh (default 1 = exactly the prior, unbatched behavior). For nh > 1
the operands and result carry the fused (Hadamard) modes as leading
dimensions (left = (h,i,k), right = (h,k,j), result = (h,i,j)); the
contraction runs as nh independent SUMMA slabs over ONE shared 2-d
process grid and ONE task graph:

- iteration space becomes steps s = h*k_ + k; the step-task chain,
  depth control, and sparse step iteration (iterate_{row,col,sparse},
  skipped-range broadcasts) operate in step space
- every argument/result tile ordinal is offset by its slab base; the
  owner of a tile is independent of h (block-cyclic phase restarts per
  slab), so broadcast roots and the 2-d grid logic are unchanged
- per-step sparse broadcast groups are keyed by step (col: s, row:
  s + nsteps); the static dense groups use keys 2*nsteps, 2*nsteps+1;
  tile broadcast keys (global ordinals) are unique across slabs as-is
- reduce tasks: one per local result tile per slab
  (reduce_tasks_[h*local_size + i*local_cols + j]); initialize/finalize
  loop slab-by-slab
- sparse row/col masks take the slab index; get_tile owner computation
  mods out the slab

No caller passes nh yet (that lands with the General-product ContEngine
wiring); all existing suites pass unchanged.
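
The index arithmetic described in this commit can be written down as a few free functions. This is a sketch under assumptions, not the Summa code: the function names are hypothetical, and `local_size` is taken to be `local_rows * local_cols`, which the commit message does not state explicitly.

```cpp
#include <cassert>
#include <cstddef>

// A (slab, contraction-step) pair.
struct SlabStep {
  std::size_t h;  // slab (fused-index tile) number
  std::size_t k;  // contraction step within the slab
};

// Flatten (h, k) into step space: s = h*K + k, K = steps per slab.
inline std::size_t to_step(std::size_t h, std::size_t k, std::size_t K) {
  assert(k < K);
  return h * K + k;
}

// Recover (h, k) from a flat step index.
inline SlabStep from_step(std::size_t s, std::size_t K) {
  assert(K > 0);
  return {s / K, s % K};
}

// Slab-offset tile ordinal: slab base plus the ordinal within the slab.
inline std::size_t slab_ordinal(std::size_t h, std::size_t in_slab_ordinal,
                                std::size_t slab_size) {
  assert(in_slab_ordinal < slab_size);
  return h * slab_size + in_slab_ordinal;
}

// Index of the reduce task for local result tile (i, j) of slab h,
// following reduce_tasks_[h*local_size + i*local_cols + j].
inline std::size_t reduce_task_index(std::size_t h, std::size_t i,
                                     std::size_t j, std::size_t local_rows,
                                     std::size_t local_cols) {
  assert(i < local_rows && j < local_cols);
  const std::size_t local_size = local_rows * local_cols;  // assumed
  return h * local_size + i * local_cols + j;
}
```

With nh = 1 the only slab is h = 0, so `to_step(0, k, K) == k` and every ordinal is unchanged, which matches the statement that the default reproduces the unbatched behavior.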
Phase B step 2: dense (DensePolicy) general products now evaluate
natively in the expression layer, end-to-end:
C("b,i,k") = A("b,i,j") * B("b,j,k") runs as ONE distributed batched
Summa in one World -- no per-Hadamard-tile sub-Worlds.

- SlabbedPmap: replicates a base pmap over a leading slab dimension
  (owner of a tile is independent of its fused-index slab), used for the
  SUMMA phase maps of the arguments and the result pmap
- BatchedContractReduce: adapts a folded (fused-mode-free)
  ContractReduce to tiles carrying leading fused modes; folds them into
  the tile batch dimension by zero-copy reshape (modes lead => layout
  preserved), allocates the result with its full range up front, and
  lets Tensor::gemm's per-batch loop do the work; TA::Tensor tiles only
- ContEngine: init_struct_general / make_trange_general /
  make_shape_general / init_distribution_general / make_dist_eval_general
  -- fused-mode-prefixed result structure, per-slab 2-d process grid,
  batched Summa construction (nh = product of fused-mode tile extents,
  K = per-slab contracted tile count)
- MultEngine routes General to these; ScalMultEngine still gates
- not yet supported (clear errors): block-sparse shapes (per-slab shape
  gemm TODO), tensors-of-tensors, targets interleaving fused and free
  modes
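
The slab-independent ownership rule is easy to state in code. The sketch below is illustrative only: `BlockCyclic2D` is a hypothetical stand-in for whatever base process map is in use, and `SlabbedMap` shows just the one property that matters here, namely that the owner is computed after the slab has been modded out.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical base map: cyclic ownership of a row-major tile matrix
// with tile_cols columns over a proc_rows x proc_cols process grid.
struct BlockCyclic2D {
  std::size_t tile_cols;
  std::size_t proc_rows;
  std::size_t proc_cols;

  std::size_t owner(std::size_t ordinal) const {
    assert(tile_cols > 0 && proc_rows > 0 && proc_cols > 0);
    const std::size_t r = ordinal / tile_cols;
    const std::size_t c = ordinal % tile_cols;
    return (r % proc_rows) * proc_cols + (c % proc_cols);
  }
};

// Replicates the base map over a leading slab dimension:
// global ordinal = h * slab_size + in-slab ordinal.
struct SlabbedMap {
  BlockCyclic2D base;
  std::size_t slab_size;  // number of tiles in one slab

  std::size_t owner(std::size_t global_ordinal) const {
    assert(slab_size > 0);
    return base.owner(global_ordinal % slab_size);  // slab index drops out
  }
};
```

Because tile (i, j) has the same owner in every slab, the broadcast roots and the 2-d grid logic of a single-slab SUMMA carry over unchanged.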

Differential-tested against the einsum free function (norm(diff) <=
1e-10): multi-tile uneven dims, permuted argument layouts, batched outer
product; np = 1, 2, 3, 4. All existing suites unchanged.
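
Why leading fused modes allow a zero-copy fold can be seen on plain row-major buffers: slab h of an (h,i,k) tile is a contiguous I x K matrix starting at offset h*I*K, so a batched multiply only needs pointer offsets. The sketch below uses a naive kernel and hypothetical names (`gemm_slab`, `batched_gemm`); it is not Tensor::gemm and makes no claim about its internals.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// c(I x J) += a(I x K) * b(K x J), all row-major and contiguous.
inline void gemm_slab(const double* a, const double* b, double* c,
                      std::size_t I, std::size_t K, std::size_t J) {
  for (std::size_t i = 0; i < I; ++i)
    for (std::size_t k = 0; k < K; ++k)
      for (std::size_t j = 0; j < J; ++j)
        c[i * J + j] += a[i * K + k] * b[k * J + j];
}

// C(h,i,j) += sum_k A(h,i,k) * B(h,k,j) with the fused mode h leading.
// No data is copied or permuted: each slab is addressed by offset only.
inline void batched_gemm(const std::vector<double>& A,
                         const std::vector<double>& B, std::vector<double>& C,
                         std::size_t nbatch, std::size_t I, std::size_t K,
                         std::size_t J) {
  assert(A.size() == nbatch * I * K);
  assert(B.size() == nbatch * K * J);
  assert(C.size() == nbatch * I * J);
  for (std::size_t h = 0; h < nbatch; ++h)
    gemm_slab(A.data() + h * I * K, B.data() + h * K * J,
              C.data() + h * I * J, I, K, J);
}
```

If the fused mode were not leading, the elements of one slab would be strided through the buffer and a permutation (copy) would be needed first, which is the reason the optimizer places the fused modes in front.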
Phase B3: SparsePolicy general products (fused + contracted + free
indices) now evaluate natively in the expression layer.

- SparseShape::gemm_batched(other, factor, gemm_helper, nfused): the
  batched analogue of gemm. The leading nfused modes of both shapes and
  the result are fused; each fused-index slab is contracted exactly as
  in gemm with the *folded* (fused-mode-free) GEMM helper. The
  contracted-mode size vector is slab-invariant (contracted modes follow
  the fused modes), the norm scaling loops extend over slabs naturally,
  and the per-slab norm GEMMs run as one batched Tensor::gemm via the
  zero-copy fused-modes-into-nbatch reshape; same hard-zero threshold
  pass and outer-product (k_rank == 0) handling as gemm.
- ContEngine::make_shape_general routes SparseShape to it (the dense
  branch is unchanged); this removes the last block-sparse gate, so the
  batched Summa's sparse path ((h,k)-keyed masks, groups, and step
  iteration, landed earlier) is now reachable.
- tests: block-sparse differential tests vs einsum (batched contraction
  and batched outer product, deterministic block-sparsity patterns);
  pass at np = 1, 2, 3, 4.
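
The per-slab norm contraction rests on a standard bound for tile norms. Written for Frobenius norms and ignoring the size-based norm scaling mentioned above (a simplification made here, not a description of the actual scaling), the estimate for result tile (i,j) of slab h is:

```latex
\bigl\| C_{h;\,i,j} \bigr\|_F
\;=\; \Bigl\| \sum_{k=1}^{K} A_{h;\,i,k}\, B_{h;\,k,j} \Bigr\|_F
\;\le\; \sum_{k=1}^{K} \bigl\| A_{h;\,i,k} \bigr\|_F \, \bigl\| B_{h;\,k,j} \bigr\|_F ,
\qquad h = 1, \dots, n_h .
```

The right-hand side is itself a matrix product of the two norm matrices of slab h, so the n_h estimates form a batched GEMM over norms; result tiles whose estimate falls below the threshold are then hard-zeroed.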
Phase C: ToT general products (fused + contracted + free outer indices,
nested inner product) now evaluate natively in the expression layer via
the batched Summa.

- the inner-tile-op builders (init_inner_tile_op and the owning-cell
  variant) now classify the outer regime via outer_product_uses_summa()
  (pure contraction OR general product): for both, the tile op is a
  ContractReduce consumed by a (batched) SUMMA, so the per-cell ops
  accumulate in place, no per-cell result permutation is applied, and
  the arena plans / strided-DGEMM ops install as for a pure contraction
- the strided-DGEMM install gates derive the outer-contracted rank from
  the fused-mode-free outer sizes (n_fused_outer_modes() helper)
- init_struct_general gains the ToT arm, mirroring init_struct: builds
  the folded-rank ContractReduce with the inner element op and the
  arena plan, installs the strided ce+e / ce+ce ops; a non-identity
  inner result permutation is gated (the batched op must be perm-free)
- BatchedContractReduce now admits ToT tiles: the folded result is
  allocated by the wrapped op itself (engaging its tile-type-specific
  construction, e.g. the arena reserve) and unfolded by a zero-copy
  reshape; this also gives plain tiles the beta=0 first-accumulation
  fast path
- MultEngine initializes the inner tile op before init_struct_general

Differential-tested against einsum (inner Hadamard and inner
contraction, owning cells) at np = 1, 2, 3; all existing suites
unchanged. Arena (view-cell) general products compile via the same
paths; their end-to-end validation comes with the mpqc/einsum cutover.

Phase D (partial): einsum can evaluate its generalized-contraction branch
through the expression layer's native general-product support (one batched
Summa in one World) instead of the legacy per-Hadamard-slab sub-World
decomposition. The engine receives the canonical (fused..., left-free...,
right-free...) result layout; arbitrary einsum targets are reached by a
final permutation assignment.

Three-way runtime control (detail::einsum_legacy_subworld /
detail::einsum_differential, env TA_EINSUM_LEGACY_SUBWORLD /
TA_EINSUM_DIFFERENTIAL):
- legacy (DEFAULT for now, see below)
- expression route (TA_EINSUM_LEGACY_SUBWORLD=0)
- differential: evaluates BOTH routes per general product, compares
  squared norms, reports mismatching contractions (with annotations) to
  stderr, returns the legacy result. The legacy path is retained
  indefinitely as the reference implementation for such testing.

Status: with the expression route, PNO-CCSD (c6h14/cc-pVDZ) runs with ZERO
sub-Worlds (legacy: 1412) and the einsum-region attribution collapses from
17.9 s to 12.4 s of pure evaluation -- but the energy is WRONG (-238.09 vs
-236.35). TA_EINSUM_DIFFERENTIAL isolates the mismatching shapes:
 (1) ToT x T (inner Scale) with a non-leading fused index and interleaved
     target, e.g. (i4,i1,mu;a) * (mu,i4,K) -> (i1,i4,K;a)
 (2) phantom-unit (denest-internal) general products, e.g.
     (mu,i1,i4;a) * (i1,i4,K;a,phantom) -> (mu,i1,K;phantom)
Synthetic unit reproductions of (1) with fixed inner extents PASS, so the
trigger involves CSV specifics (variable per-block inner extents and/or
arena cell layout); under investigation. Until resolved the legacy path is
the default and the expression route is opt-in.

All unit suites green in both modes; einsum suites were also validated
green against their reference data with the expression route as default
before the flip.

The GEMM-based ToT x scalar scale paths of Tensor::gemm (and the
T x ToT mirror) probe each row (column) for cleanliness; the presence
probe stops at the first ABSENT cell, leaving the probed inner size
A == -1 when the leading cell is absent. The subsequent 'A <= 0 =>
empty row, nothing to do' shortcut then dropped the ENTIRE row's
contributions even when later cells were present. Rows whose leading
contracted cell is absent (common for screened tensor-of-tensors, e.g.
PNO-CC CSV intermediates) silently lost their contraction.

Fix: when the probe ends with A <= 0, scan the full row (column) for
any present cell; only a fully absent row is skipped, anything else
takes the per-cell AXPY fallback. Also guard the engine's scale
fallback element op against absent cells.

This bug predates the general-product work but was masked on the
legacy einsum route, whose canonical operand layout fails the
NoTranspose gate of the strided path; the expression route's
GEMM-canonical layout exposed it. Found with TA_EINSUM_DIFFERENTIAL on
c6h14 PNO-CCSD: the opt-in expression-route energy error drops from
1.7 Eh to 6.9e-7 (the small residual is a screening-semantics
difference -- the legacy route hard-zeroes sub-threshold result tiles
that the engine route keeps -- plus a small systematic difference in
phantom-unit denest products, both under review).

Adds the CSV-like reproduction test (arena view cells, SparsePolicy,
variable inner extents, screened cells, non-leading fused index,
interleaved target): expression route and einsum routes agree to
1e-10, deterministically.

With the strided-scale-path fix in place, the TA_EINSUM_DIFFERENTIAL
audit of c6h14 PNO-CCSD shows the two routes agree except for:
- sub-threshold result tiles that the legacy path implicitly hard-zeroes
  (its result shape derives from the harvested tile norms) while the
  expression route keeps them (standard estimate-derived contraction
  shape). Per the TA screening philosophy, norms are trusted as genuine
  and no implicit truncation is performed; users wanting the tighter
  shape call truncate() explicitly.
- floating-point summation-order noise in tiny, heavily-cancelling
  tensors (absolute tile-norm^2 differences <= 1e-9, no structural
  pattern).

Neither is a defect, so general products in einsum now default to the
expression-layer evaluation (TensorProduct::General -> batched Summa:
one task graph in one World, ZERO per-slab sub-Worlds). The legacy path
remains available via TA_EINSUM_LEGACY_SUBWORLD (or
detail::einsum_legacy_subworld()) as the reference implementation for
differential testing.

All suites pass with the new default (einsum suites against their
reference data; the 2 pre-existing assign_subblock_block_base1 failures
are unrelated); np = 1, 2, 3.

Copilot AI left a comment


Pull request overview

This PR adds first-class support for general binary tensor products (fused + contracted + free indices coexisting) in the expression layer by introducing a batched-SUMMA evaluation path, and updates einsum() to route generalized contractions through this new path by default (retaining the legacy sub-World implementation for reference and differential testing).

Changes:

  • Added TensorProduct::General classification/layout support (new GeneralPermutationOptimizer) and integrated it into expression engines (MultEngine/ContEngine).
  • Generalized SUMMA to support batched (slabbed) contractions and introduced supporting utilities (SlabbedPmap, BatchedContractReduce, SparseShape::gemm_batched).
  • Updated einsum() routing + added runtime-gated instrumentation/differential modes, plus a new comprehensive general_product_suite test suite.

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.

| File | Description |
| --- | --- |
| tests/general_product.cpp | New test suite covering classification, optimizer layouts, expression-vs-legacy routing, and sparse/ToT/arena scenarios. |
| tests/CMakeLists.txt | Adds the new general_product.cpp test target source. |
| src/TiledArray/tile_op/batched_contract_reduce.h | New tile-op adapter to fold fused leading modes into a batch dimension for GEMM-based contraction/reduction. |
| src/TiledArray/tensor/tensor.h | Fixes a ToT×scalar strided-scale “empty row/col” probe bug by correctly scanning for later non-empty cells. |
| src/TiledArray/sparse_shape.h | Adds SparseShape::gemm_batched to compute slab-batched contraction shapes for general products. |
| src/TiledArray/pmap/slabbed_pmap.h | New pmap that replicates a base mapping across a slab dimension (slab-independent ownership). |
| src/TiledArray/expressions/product.h | Enables 3-arg classification to return TensorProduct::General when target keeps shared indices. |
| src/TiledArray/expressions/permopt.h | Adds GeneralPermutationOptimizer and routes TensorProduct::General through it. |
| src/TiledArray/expressions/mult_engine.h | Integrates general-product routing, inner-node gating checks, and general distribution/eval hooks. |
| src/TiledArray/expressions/cont_engine.h | Implements general-product structure/distribution/evaluator (batched SUMMA + batched tile op). |
| src/TiledArray/expressions/binary_engine.h | Extends index initialization template to allow TensorProduct::General optimizer selection. |
| src/TiledArray/einsum/tiledarray.h | Adds routing toggles, differential mode, instrumentation hooks, and routes generalized contraction via expression layer by default. |
| src/TiledArray/einsum/einsum_instrument.h | New lightweight, runtime-gated einsum attribution profiler. |
| src/TiledArray/dist_eval/contraction_eval.h | Generalizes SUMMA implementation to batched slabs (step space expanded to nh * k). |

Comment thread src/TiledArray/einsum/tiledarray.h
Comment thread src/TiledArray/tile_op/batched_contract_reduce.h
Comment thread src/TiledArray/pmap/slabbed_pmap.h
Comment thread src/TiledArray/expressions/cont_engine.h
Comment thread src/TiledArray/sparse_shape.h
Comment thread tests/general_product.cpp Outdated
Comment thread tests/general_product.cpp Outdated
Comment thread tests/general_product.cpp Outdated
evaleev added 2 commits June 11, 2026 18:43
…gruence checks

- einsum/tiledarray.h, tile_op/batched_contract_reduce.h, pmap/slabbed_pmap.h:
  include what is used (<cstdlib>, <string_view>, util/vector.h, <memory>,
  <utility>) instead of relying on transitive includes
- cont_engine.h: re-initialize K_ in init_distribution_general() (defensive;
  engines are single-use, but mirrors the n_slabs_ reset)
- sparse_shape.h: gemm_batched() now TA_ASSERTs that the argument ranks match
  the folded gemm ranks plus the fused modes and that the fused and contracted
  mode extents of the two shapes are congruent (the batched analogue of the
  checks GemmHelper::compute_matrix_sizes performs for plain gemm)

ScopedEinsumRoute restores the previous einsum_legacy_subworld() value on
scope exit (ForceLegacyEinsum is now the legacy=true special case), so a
throwing TA::einsum can no longer leak the toggle into later test cases.

Also restores einsum_expression_route_matches_legacy to its intent: it was
written when the legacy sub-World path was einsum's default, so after the
default flip (ef0066c) its "legacy" reference silently took the expression
route (vacuous comparison) and the trailing manual toggle left the legacy
path enabled for the rest of the test module.
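
The restore-on-scope-exit behavior described for ScopedEinsumRoute is the usual RAII guard pattern. The sketch below shows that pattern on a stand-in flag: `legacy_route_flag()` and `ScopedRoute` are hypothetical names used for illustration and are not the test-suite code.

```cpp
#include <cassert>
#include <stdexcept>

// Stand-in for a process-wide routing toggle (hypothetical).
inline bool& legacy_route_flag() {
  static bool flag = false;
  return flag;
}

// Sets the toggle for the lifetime of the guard and restores the PREVIOUS
// value on destruction, including during stack unwinding after a throw.
class ScopedRoute {
 public:
  explicit ScopedRoute(bool legacy) : previous_(legacy_route_flag()) {
    legacy_route_flag() = legacy;
  }
  ~ScopedRoute() { legacy_route_flag() = previous_; }
  ScopedRoute(const ScopedRoute&) = delete;
  ScopedRoute& operator=(const ScopedRoute&) = delete;

 private:
  bool previous_;
};

// Simulates a test body whose evaluation throws while the toggle is forced.
inline void evaluate_then_throw() {
  ScopedRoute guard(true);
  assert(legacy_route_flag());
  throw std::runtime_error("evaluation failed");
}
```

Restoring the saved value rather than resetting to a fixed default is what makes guards nest correctly and keeps a failing test case from changing the route seen by later cases.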