First-class general tensor products (fused + contracted + free indices) in the expression layer#562
Open
evaleev wants to merge 11 commits into
Open
First-class general tensor products (fused + contracted + free indices) in the expression layer#562evaleev wants to merge 11 commits into
evaleev wants to merge 11 commits into
Conversation
Buckets per einsum call: entry_fence / setup / commsplit+world / retile/make_array / contract+fence / harvest / local_kernel / teardown, keyed by branch (hadamard-reduction-local, generalized-subworld, generalized-inner-perm-recurse) and contraction annotation; dumped to stderr at exit. Zero overhead when disabled. Establishes the baseline attribution for replacing the per-Hadamard-tile sub-World decomposition with first-class general-product (h+e+c) support. PNO-CCSD c6h14/cc-pVDZ baseline (np=1, 3 CC iters): 17.9 s einsum-region, 1412 Hadamard slices = 1412 sub-Worlds; retile/make_array 25.5% + harvest 3.0% + teardown 4.2% non-numeric, contract+fence 30.5%.
…ral) Phase A of first-class general-product (fused + contracted + free indices) support in the expression layer (target: PNO-CC batched contractions, replacing einsum's per-Hadamard-tile sub-World decomposition): - compute_product_type(left, right, target) now returns TensorProduct::General when a shared index survives into the target (fused) alongside contracted and/or free indices, incl. the Hadamard-reduction case (args related by permutation, target drops indices). The 2-arg overload is unchanged (bottom-up convention: shared => contracted). - new GeneralPermutationOptimizer: canonical layouts left (h, e_A, c), right (h, c, e_B), result (h, e_A, e_B) -- the GEMM-canonical layout with fused indices prepended so a consuming batched-GEMM op can fold them into the tile batch dimension by reshape; exposes the h/c/e_A/e_B partition for engine consumption. Requires target indices (fused-vs-contracted is undecidable bottom-up); validates against implicit reductions. - BinaryEngine::init_indices_ and MultEngine/ScalMultEngine route General through the new optimizer; ContEngine::product_type() accessor admits General. - evaluation is gated with an informative exception (use TiledArray::einsum() meanwhile) until the batched-Summa DistEval lands (Phase B); previously such expressions misclassified as pure contractions and died in target-permutation resolution. - unit tests: classification, optimizer layouts/partitions/errors, and the end-to-end expression gate (tests/general_product.cpp).
Phase B step 1 of general-product support: Summa gains an optional slab
count nh (default 1 = exactly the prior, unbatched behavior). For nh > 1
the operands and result carry the fused (Hadamard) modes as leading
dimensions (left = (h,i,k), right = (h,k,j), result = (h,i,j)); the
contraction runs as nh independent SUMMA slabs over ONE shared 2-d
process grid and ONE task graph:
- iteration space becomes steps s = h*k_ + k; the step-task chain,
depth control, and sparse step iteration (iterate_{row,col,sparse},
skipped-range broadcasts) operate in step space
- every argument/result tile ordinal is offset by its slab base; the
owner of a tile is independent of h (block-cyclic phase restarts per
slab), so broadcast roots and the 2-d grid logic are unchanged
- per-step sparse broadcast groups are keyed by step (col: s, row:
s + nsteps); the static dense groups use keys 2*nsteps, 2*nsteps+1;
tile broadcast keys (global ordinals) are unique across slabs as-is
- reduce tasks: one per local result tile per slab
(reduce_tasks_[h*local_size + i*local_cols + j]); initialize/finalize
loop slab-by-slab
- sparse row/col masks take the slab index; get_tile owner computation
mods out the slab
No caller passes nh yet (that lands with the General-product ContEngine
wiring); all existing suites pass unchanged.
Phase B step 2: dense (DensePolicy) general products now evaluate
natively in the expression layer, end-to-end:
C("b,i,k") = A("b,i,j") * B("b,j,k") runs as ONE distributed batched
Summa in one World -- no per-Hadamard-tile sub-Worlds.
- SlabbedPmap: replicates a base pmap over a leading slab dimension
(owner of a tile is independent of its fused-index slab), used for the
SUMMA phase maps of the arguments and the result pmap
- BatchedContractReduce: adapts a folded (fused-mode-free)
ContractReduce to tiles carrying leading fused modes; folds them into
the tile batch dimension by zero-copy reshape (modes lead => layout
preserved), allocates the result with its full range up front, and
lets Tensor::gemm's per-batch loop do the work; TA::Tensor tiles only
- ContEngine: init_struct_general / make_trange_general /
make_shape_general / init_distribution_general / make_dist_eval_general
-- fused-mode-prefixed result structure, per-slab 2-d process grid,
batched Summa construction (nh = product of fused-mode tile extents,
K = per-slab contracted tile count)
- MultEngine routes General to these; ScalMultEngine still gates
- not yet supported (clear errors): block-sparse shapes (per-slab shape
gemm TODO), tensors-of-tensors, targets interleaving fused and free
modes
Differential-tested against the einsum free function (norm(diff) <=
1e-10): multi-tile uneven dims, permuted argument layouts, batched outer
product; np = 1, 2, 3, 4. All existing suites unchanged.
Phase B3: SparsePolicy general products (fused + contracted + free indices) now evaluate natively in the expression layer. - SparseShape::gemm_batched(other, factor, gemm_helper, nfused): the batched analogue of gemm. The leading nfused modes of both shapes and the result are fused; each fused-index slab is contracted exactly as in gemm with the *folded* (fused-mode-free) GEMM helper. The contracted-mode size vector is slab-invariant (contracted modes follow the fused modes), the norm scaling loops extend over slabs naturally, and the per-slab norm GEMMs run as one batched Tensor::gemm via the zero-copy fused-modes-into-nbatch reshape; same hard-zero threshold pass and outer-product (k_rank == 0) handling as gemm. - ContEngine::make_shape_general routes SparseShape to it (the dense branch is unchanged); this removes the last block-sparse gate, so the batched Summa's sparse path ((h,k)-keyed masks, groups, and step iteration, landed earlier) is now reachable. - tests: block-sparse differential tests vs einsum (batched contraction and batched outer product, deterministic block-sparsity patterns); pass at np = 1, 2, 3, 4.
Phase C: ToT general products (fused + contracted + free outer indices, nested inner product) now evaluate natively in the expression layer via the batched Summa. - the inner-tile-op builders (init_inner_tile_op and the owning-cell variant) now classify the outer regime via outer_product_uses_summa() (pure contraction OR general product): for both, the tile op is a ContractReduce consumed by a (batched) SUMMA, so the per-cell ops accumulate in place, no per-cell result permutation is applied, and the arena plans / strided-DGEMM ops install as for a pure contraction - the strided-DGEMM install gates derive the outer-contracted rank from the fused-mode-free outer sizes (n_fused_outer_modes() helper) - init_struct_general gains the ToT arm, mirroring init_struct: builds the folded-rank ContractReduce with the inner element op and the arena plan, installs the strided ce+e / ce+ce ops; a non-identity inner result permutation is gated (the batched op must be perm-free) - BatchedContractReduce now admits ToT tiles: the folded result is allocated by the wrapped op itself (engaging its tile-type-specific construction, e.g. the arena reserve) and unfolded by a zero-copy reshape; this also gives plain tiles the beta=0 first-accumulation fast path - MultEngine initializes the inner tile op before init_struct_general Differential-tested against einsum (inner Hadamard and inner contraction, owning cells) at np = 1, 2, 3; all existing suites unchanged. Arena (view-cell) general products compile via the same paths; their end-to-end validation comes with the mpqc/einsum cutover.
Phase D (partial): einsum can evaluate its generalized-contraction branch
through the expression layer's native general-product support (one batched
Summa in one World) instead of the legacy per-Hadamard-slab sub-World
decomposition. The engine receives the canonical (fused..., left-free...,
right-free...) result layout; arbitrary einsum targets are reached by a
final permutation assignment.
Three-way runtime control (detail::einsum_legacy_subworld /
detail::einsum_differential, env TA_EINSUM_LEGACY_SUBWORLD /
TA_EINSUM_DIFFERENTIAL):
- legacy (DEFAULT for now, see below)
- expression route (TA_EINSUM_LEGACY_SUBWORLD=0)
- differential: evaluates BOTH routes per general product, compares
squared norms, reports mismatching contractions (with annotations) to
stderr, returns the legacy result. The legacy path is retained
indefinitely as the reference implementation for such testing.
Status: with the expression route, PNO-CCSD (c6h14/cc-pVDZ) runs with ZERO
sub-Worlds (legacy: 1412) and the einsum-region attribution collapses from
17.9 s to 12.4 s of pure evaluation -- but the energy is WRONG (-238.09 vs
-236.35). TA_EINSUM_DIFFERENTIAL isolates the mismatching shapes:
(1) ToT x T (inner Scale) with a non-leading fused index and interleaved
target, e.g. (i4,i1,mu;a) * (mu,i4,K) -> (i1,i4,K;a)
(2) phantom-unit (denest-internal) general products, e.g.
(mu,i1,i4;a) * (i1,i4,K;a,phantom) -> (mu,i1,K;phantom)
Synthetic unit reproductions of (1) with fixed inner extents PASS, so the
trigger involves CSV specifics (variable per-block inner extents and/or
arena cell layout); under investigation. Until resolved the legacy path is
the default and the expression route is opt-in.
All unit suites green in both modes; einsum suites were also validated
green against their reference data with the expression route as default
before the flip.
The GEMM-based ToT x scalar scale paths of Tensor::gemm (and the T x ToT mirror) probe each row (column) for cleanliness; the presence probe stops at the first ABSENT cell, leaving the probed inner size A == -1 when the leading cell is absent. The subsequent 'A <= 0 => empty row, nothing to do' shortcut then dropped the ENTIRE row's contributions even when later cells were present. Rows whose leading contracted cell is absent (common for screened tensor-of-tensors, e.g. PNO-CC CSV intermediates) silently lost their contraction. Fix: when the probe ends with A <= 0, scan the full row (column) for any present cell; only a fully absent row is skipped, anything else takes the per-cell AXPY fallback. Also guard the engine's scale fallback element op against absent cells. This bug predates the general-product work but was masked on the legacy einsum route, whose canonical operand layout fails the NoTranspose gate of the strided path; the expression route's GEMM-canonical layout exposed it. Found with TA_EINSUM_DIFFERENTIAL on c6h14 PNO-CCSD: the opt-in expression-route energy error drops from 1.7 Eh to 6.9e-7 (the small residual is a screening-semantics difference -- the legacy route hard-zeroes sub-threshold result tiles that the engine route keeps -- plus a small systematic difference in phantom-unit denest products, both under review). Adds the CSV-like reproduction test (arena view cells, SparsePolicy, variable inner extents, screened cells, non-leading fused index, interleaved target): expression route and einsum routes agree to 1e-10, deterministically.
With the strided-scale-path fix in place, the TA_EINSUM_DIFFERENTIAL audit of c6h14 PNO-CCSD shows the two routes agree except for: - sub-threshold result tiles that the legacy path implicitly hard-zeroes (its result shape derives from the harvested tile norms) while the expression route keeps them (standard estimate-derived contraction shape). Per the TA screening philosophy, norms are trusted as genuine and no implicit truncation is performed; users wanting the tighter shape call truncate() explicitly. - floating-point summation-order noise in tiny, heavily-cancelling tensors (absolute tile-norm^2 differences <= 1e-9, no structural pattern). Neither is a defect, so general products in einsum now default to the expression-layer evaluation (TensorProduct::General -> batched Summa: one task graph in one World, ZERO per-slab sub-Worlds). The legacy path remains available via TA_EINSUM_LEGACY_SUBWORLD (or detail::einsum_legacy_subworld()) as the reference implementation for differential testing. All suites pass with the new default (einsum suites against their reference data; the 2 pre-existing assign_subblock_block_base1 failures are unrelated); np = 1, 2, 3.
There was a problem hiding this comment.
Pull request overview
This PR adds first-class support for general binary tensor products (fused + contracted + free indices coexisting) in the expression layer by introducing a batched-SUMMA evaluation path, and updates einsum() to route generalized contractions through this new path by default (retaining the legacy sub-World implementation for reference and differential testing).
Changes:
- Added
TensorProduct::Generalclassification/layout support (newGeneralPermutationOptimizer) and integrated it into expression engines (MultEngine/ContEngine). - Generalized SUMMA to support batched (slabbed) contractions and introduced supporting utilities (
SlabbedPmap,BatchedContractReduce,SparseShape::gemm_batched). - Updated
einsum()routing + added runtime-gated instrumentation/differential modes, plus a new comprehensivegeneral_product_suitetest suite.
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
tests/general_product.cpp |
New test suite covering classification, optimizer layouts, expression-vs-legacy routing, and sparse/ToT/arena scenarios. |
tests/CMakeLists.txt |
Adds the new general_product.cpp test target source. |
src/TiledArray/tile_op/batched_contract_reduce.h |
New tile-op adapter to fold fused leading modes into a batch dimension for GEMM-based contraction/reduction. |
src/TiledArray/tensor/tensor.h |
Fixes a ToT×scalar strided-scale “empty row/col” probe bug by correctly scanning for later non-empty cells. |
src/TiledArray/sparse_shape.h |
Adds SparseShape::gemm_batched to compute slab-batched contraction shapes for general products. |
src/TiledArray/pmap/slabbed_pmap.h |
New pmap that replicates a base mapping across a slab dimension (slab-independent ownership). |
src/TiledArray/expressions/product.h |
Enables 3-arg classification to return TensorProduct::General when target keeps shared indices. |
src/TiledArray/expressions/permopt.h |
Adds GeneralPermutationOptimizer and routes TensorProduct::General through it. |
src/TiledArray/expressions/mult_engine.h |
Integrates general-product routing, inner-node gating checks, and general distribution/eval hooks. |
src/TiledArray/expressions/cont_engine.h |
Implements general-product structure/distribution/evaluator (batched SUMMA + batched tile op). |
src/TiledArray/expressions/binary_engine.h |
Extends index initialization template to allow TensorProduct::General optimizer selection. |
src/TiledArray/einsum/tiledarray.h |
Adds routing toggles, differential mode, instrumentation hooks, and routes generalized contraction via expression layer by default. |
src/TiledArray/einsum/einsum_instrument.h |
New lightweight, runtime-gated einsum attribution profiler. |
src/TiledArray/dist_eval/contraction_eval.h |
Generalizes SUMMA implementation to batched slabs (step space expanded to nh * k). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…gruence checks - einsum/tiledarray.h, tile_op/batched_contract_reduce.h, pmap/slabbed_pmap.h: include what is used (<cstdlib>, <string_view>, util/vector.h, <memory>, <utility>) instead of relying on transitive includes - cont_engine.h: re-initialize K_ in init_distribution_general() (defensive; engines are single-use, but mirrors the n_slabs_ reset) - sparse_shape.h: gemm_batched() now TA_ASSERTs that the argument ranks match the folded gemm ranks plus the fused modes and that the fused and contracted mode extents of the two shapes are congruent (the batched analogue of the checks GemmHelper::compute_matrix_sizes performs for plain gemm)
ScopedEinsumRoute restores the previous einsum_legacy_subworld() value on scope exit (ForceLegacyEinsum is now the legacy=true special case), so a throwing TA::einsum can no longer leak the toggle into later test cases. Also restores einsum_expression_route_matches_legacy to its intent: it was written when the legacy sub-World path was einsum's default, so after the default flip (ef0066c) its "legacy" reference silently took the expression route (vacuous comparison) and the trailing manual toggle left the legacy path enabled for the rest of the test module.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds native support for general binary tensor products — fused (Hadamard), contracted, and free indices coexisting, e.g.
C("b,i,k") = A("b,i,j") * B("b,j,k")— to the expression layer (MultEngine/ContEngine) and evaluates them with a batched Summa: one distributed task graph in one World.einsum()now routes its generalized-contraction branch through this path by default, eliminating its per-Hadamard-slab decomposition (oneMPI_Comm_split+ sub-World +make_array+ fence per slab).On the motivating workload (n-hexane PNO-CCSD/cc-pVDZ in MPQC, block-sparse arena tensor-of-tensors), the per-run sub-World count drops 1412 → 0 and the einsum-region attribution becomes pure evaluation (was ~50% machinery: retile/make_array 25%, per-slab fences 30%, harvest+teardown 8%).
Design
compute_product_type(left, right, target)now returns the (previously unreachable)TensorProduct::Generalwhen a shared index survives into the target alongside contracted/free indices; the bottom-up 2-arg overload is unchanged (shared ⇒ contracted).GeneralPermutationOptimizer: canonical layoutsA(h,e_A,c),B(h,c,e_B),C(h,e_A,e_B)— the GEMM-canonical layout with the fused modes leading, so the tile op folds them into the tile batch dimension by zero-copy reshape.Summageneralized in place to batched contractions: optional slab countnh(default 1 = exactly the prior behavior); iteration in stepss = h*k + k; slab-offset tile ordinals;(h,k)-keyed sparse masks/groups; per-slab reduce tasks. The owner of a tile is independent of its slab (SlabbedPmap), so one slab's contraction is fully distributed over the same 2-d grid and slabs overlap in the task pipeline with no inter-slab barriers.BatchedContractReduceadapts a folded-rankContractReduceto fused-mode-carrying tiles (Tensor::gemm's nbatch loop and the arena ToT kernels already speak this convention).SparseShape::gemm_batched: slab-batched norm contraction for the result shape.outer_product_uses_summa()(Contraction or General ⇒ ContractReduce semantics), so arena plans and the strided-DGEMM kernels install as for a pure contraction.einsum cutover & differential harness
Three-way runtime control:
TA_EINSUM_LEGACY_SUBWORLD(ordetail::einsum_legacy_subworld()): forces the legacy per-slab sub-World path, retained indefinitely as the reference implementation;TA_EINSUM_DIFFERENTIAL: evaluates every general product by both routes, compares norms, and reports mismatching contractions with per-tile forensics — this is how the bug below was found.Also adds
TA_EINSUM_INSTRUMENT, a runtime-gated attribution profiler for the einsum region (per-call time buckets: retile, comm-split, contract+fence, harvest, …).Drive-by bug fix (pre-existing)
The differential harness exposed a latent bug in
Tensor::gemm's ToT×scalar strided scale paths: the per-row cleanliness probe stops at the first absent cell, and the subsequentA <= 0 ⇒ empty rowshortcut dropped the entire row's contributions even when later cells were present. The legacy einsum route dodged it by accident (its canonical layout fails the path's NoTranspose gate). Fixed for both orientations + regression test.Known, intentional route differences
truncate()explicitly if desired. Downstream consumers may see ~1e-7-scale shifts vs legacy-einsum-derived baselines.X("p,r1") * X("q,r1") * Z("r1,r2") * …) cannot be classified bottom-up; they now produce an informative error suggesting explicit intermediates (top-down index-set deduction is future work). Targets that interleave fused and free modes are supported througheinsum()(canonicalize + permute); native engine support is future work.Testing
New
general_product_suite(23 cases): classification, optimizer layouts, and differential tests against the legacy einsum oracle — dense/block-sparse, plain/ToT/mixed ToT×T, owning and arena (view) inner cells, variable inner extents, screened (absent) cells, non-leading fused indices, interleaved targets, batched outer products, THC gating + workaround; np = 1–4. All existing suites pass unchanged with the new default (einsum suites validated against their reference data). End-to-end validated in MPQC PNO-CCSD via the differential mode.