Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
4fc7f23
Fix complex conjugation in ITF code
Krzmbrzl Jun 12, 2026
c36587a
Remove wrap_conj from generator interface
Krzmbrzl Jun 12, 2026
a267d65
Also handle Powers in release_used_terms
Krzmbrzl Jun 12, 2026
3cbcbeb
ITF export: Add indentation support
Krzmbrzl Jun 12, 2026
6619daa
ITF export: Implement index batching support
Krzmbrzl Jun 12, 2026
032e5ff
Use option struct
Krzmbrzl Jun 12, 2026
fc22fa4
Switch to using SeQuant Exception
Krzmbrzl Jun 12, 2026
e293348
external_interface: Implement first usage of index batching
Krzmbrzl Jun 12, 2026
387dd31
Use SEQUANT_ASSERT instead of assert
Krzmbrzl Jun 15, 2026
c58e772
Don't rely on defaults for tensor symmetries
Krzmbrzl Jun 15, 2026
1569489
Factor ExportTensorComparator into separate files
Krzmbrzl Jun 15, 2026
ecabe24
Support batching in text generator
Krzmbrzl Jun 15, 2026
f520884
Test batching support in generation optimizer
Krzmbrzl Jun 15, 2026
d16a709
Make GenerationOptimizer work in case of batched indices
Krzmbrzl Jun 15, 2026
828c796
Remove unneeded header
Krzmbrzl Jun 16, 2026
1398f71
Improve option handling
Krzmbrzl Jun 16, 2026
26b2784
Don't look for wchar in regular string
Krzmbrzl Jun 16, 2026
5de7d98
Remove unnecessary comment
Krzmbrzl Jun 16, 2026
95bd3eb
Add wchar overloads for string_to(…)
Krzmbrzl Jun 16, 2026
cf2b67a
Allow index construction from strings without underscore
Krzmbrzl Jun 16, 2026
fc39b34
Add additional test case
Krzmbrzl Jun 16, 2026
3edaddd
Implement batching control options
Krzmbrzl Jun 16, 2026
26c30d9
Introduce CSEOptions
Krzmbrzl Jun 16, 2026
226512f
Also use batching for CSE intermediates
Krzmbrzl Jun 16, 2026
25cb1a8
Make ExportTensorComparator independent of export module
Krzmbrzl Jun 17, 2026
4546c2b
Add missing header
Krzmbrzl Jun 17, 2026
86121b1
Make CSE aware of batched indices
Krzmbrzl Jun 17, 2026
614a948
Use batch-index aware CSE in external_interface
Krzmbrzl Jun 17, 2026
9966088
Sort after having assigned batching indices to CSEs
Krzmbrzl Jun 18, 2026
eafd6e9
Reverse order of sort importance for batched indices
Krzmbrzl Jun 18, 2026
abeb077
ITF: Implement batch loop fusion
Krzmbrzl Jun 18, 2026
7621d1a
ITF: Sort batched indices for max. parallelizability
Krzmbrzl Jun 18, 2026
ad10658
Move index sorting to context
Krzmbrzl Jun 18, 2026
37a681d
ITF: Also account for index loop fusability
Krzmbrzl Jun 18, 2026
2e56f0c
Fix using less batching indices than possible
Krzmbrzl Jun 18, 2026
ba7aac5
Add option to specify max number of batched indices
Krzmbrzl Jun 18, 2026
2168a5c
Break line that clang-format missed
Krzmbrzl Jun 19, 2026
7d9d858
Update examples to produce correct code again
Krzmbrzl Jun 19, 2026
34d7337
Add a test that verifies output of external_interface examples
Krzmbrzl Jun 19, 2026
8c0e4b1
Refactor duplicated code into function
Krzmbrzl Jun 19, 2026
98615f6
Avoid repeated re-construction of comparator object
Krzmbrzl Jun 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions SeQuant/core/eval/eval_node_compare.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <SeQuant/core/hash.hpp>
#include <SeQuant/core/utility/macros.hpp>
#include <SeQuant/core/utility/tensor.hpp>
#include <SeQuant/external/bliss/graph.hh>

#include <cstddef>
#include <unordered_map>
Expand Down Expand Up @@ -34,6 +35,10 @@ struct TreeNodeEqualityComparator {
/// Trait used by the C++ STL allowing heterogenous lookups
using is_transparent = void;

TreeNodeEqualityComparator() = default;
TreeNodeEqualityComparator(std::vector<Index> indices)
: block_comparator_(std::move(indices)) {}

bool operator()(const TreeNode *lhs, const TreeNode *rhs) const {
return (*this)(*lhs, *rhs);
}
Expand Down Expand Up @@ -84,8 +89,7 @@ struct TreeNodeEqualityComparator {
const Tensor &lhs_tensor = lhs->as_tensor();
const Tensor &rhs_tensor = rhs->as_tensor();

TensorBlockEqualComparator cmp;
if (!cmp(lhs_tensor, rhs_tensor)) {
if (!block_comparator_(lhs_tensor, rhs_tensor)) {
return false;
}
}
Expand Down Expand Up @@ -119,6 +123,9 @@ struct TreeNodeEqualityComparator {

return true;
}

private:
IndexSpecificTensorBlockEqualComparator block_comparator_;
};

/// A map between (sub)tree hashes and how often they have been found
Expand Down
27 changes: 16 additions & 11 deletions SeQuant/core/export/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ bool operator!=(Usage usage, UsageSet set) { return set != usage; }

ExportContext::ExportContext(TensorStrategyMap tensorMap,
VariableStrategyMap variableMap)
: m_tensorStrategies({{GLOBAL, std::move(tensorMap)}}),
m_variableStrategies({{GLOBAL, std::move(variableMap)}}) {}
: m_tensorStrategies({{ID_GLOBAL, std::move(tensorMap)}}),
m_variableStrategies({{ID_GLOBAL, std::move(variableMap)}}) {}

ExportContext::ExportContext(VariableStrategyMap map)
: m_variableStrategies({{GLOBAL, std::move(map)}}) {}
: m_variableStrategies({{ID_GLOBAL, std::move(map)}}) {}

ExportContext::~ExportContext() = default;

LoadStrategy ExportContext::loadStrategy(const Tensor &tensor) const {
if (auto map_iter = m_tensorStrategies.find(GLOBAL);
if (auto map_iter = m_tensorStrategies.find(ID_GLOBAL);
map_iter != m_tensorStrategies.end()) {
auto iter = map_iter->second.find(tensor);
if (iter != map_iter->second.end()) {
Expand All @@ -105,7 +105,7 @@ LoadStrategy ExportContext::loadStrategy(const Tensor &tensor) const {
}

LoadStrategy ExportContext::loadStrategy(const Variable &variable) const {
if (auto map_iter = m_variableStrategies.find(GLOBAL);
if (auto map_iter = m_variableStrategies.find(ID_GLOBAL);
map_iter != m_variableStrategies.end()) {
auto iter = map_iter->second.find(variable);
if (iter != map_iter->second.end()) {
Expand All @@ -130,7 +130,7 @@ void ExportContext::setLoadStrategy(
const Tensor &tensor, LoadStrategy strategy,
const std::optional<std::size_t> &expression_id) {
std::size_t id = expression_id.value_or(
has_current_expression_id() ? current_expression_id() : GLOBAL);
has_current_expression_id() ? current_expression_id() : ID_GLOBAL);

auto iter = m_tensorStrategies[id].find(tensor);

Expand All @@ -145,7 +145,7 @@ void ExportContext::setLoadStrategy(
const Variable &variable, LoadStrategy strategy,
const std::optional<std::size_t> &expression_id) {
std::size_t id = expression_id.value_or(
has_current_expression_id() ? current_expression_id() : GLOBAL);
has_current_expression_id() ? current_expression_id() : ID_GLOBAL);

auto iter = m_variableStrategies[id].find(variable);

Expand All @@ -157,7 +157,7 @@ void ExportContext::setLoadStrategy(
}

ZeroStrategy ExportContext::zeroStrategy(const Tensor &tensor) const {
if (auto map_iter = m_tensorStrategies.find(GLOBAL);
if (auto map_iter = m_tensorStrategies.find(ID_GLOBAL);
map_iter != m_tensorStrategies.end()) {
auto iter = map_iter->second.find(tensor);
if (iter != map_iter->second.end()) {
Expand All @@ -179,7 +179,7 @@ ZeroStrategy ExportContext::zeroStrategy(const Tensor &tensor) const {
}

ZeroStrategy ExportContext::zeroStrategy(const Variable &variable) const {
if (auto map_iter = m_variableStrategies.find(GLOBAL);
if (auto map_iter = m_variableStrategies.find(ID_GLOBAL);
map_iter != m_variableStrategies.end()) {
auto iter = map_iter->second.find(variable);
if (iter != map_iter->second.end()) {
Expand All @@ -204,7 +204,7 @@ void ExportContext::setZeroStrategy(
const Tensor &tensor, ZeroStrategy strategy,
const std::optional<std::size_t> &expression_id) {
std::size_t id = expression_id.value_or(
has_current_expression_id() ? current_expression_id() : GLOBAL);
has_current_expression_id() ? current_expression_id() : ID_GLOBAL);

auto iter = m_tensorStrategies[id].find(tensor);

Expand All @@ -219,7 +219,7 @@ void ExportContext::setZeroStrategy(
const Variable &variable, ZeroStrategy strategy,
const std::optional<std::size_t> &expression_id) {
std::size_t id = expression_id.value_or(
has_current_expression_id() ? current_expression_id() : GLOBAL);
has_current_expression_id() ? current_expression_id() : ID_GLOBAL);

auto iter = m_variableStrategies[id].find(variable);

Expand Down Expand Up @@ -264,4 +264,9 @@ void ExportContext::clear_current_expression_id() {
m_currentExpressionID.reset();
}

std::vector<Index> ExportContext::batch_indices(
std::optional<std::size_t>) const {
return {};
}

} // namespace sequant
14 changes: 12 additions & 2 deletions SeQuant/core/export/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
#define SEQUANT_CORE_EXPORT_CONTEXT_HPP

#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/utility/tensor.hpp>

#include <limits>
#include <map>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>

namespace sequant {

Expand Down Expand Up @@ -199,9 +201,17 @@ class ExportContext {
/// Resets the ID of the current expression
virtual void clear_current_expression_id();

private:
static constexpr std::size_t GLOBAL = std::numeric_limits<std::size_t>::max();
/// @param id The ID of the relevant expression
/// @returns The list of indices that the given expression should be batched
/// over
virtual std::vector<Index> batch_indices(
std::optional<std::size_t> id = {}) const;

protected:
static constexpr std::size_t ID_GLOBAL =
std::numeric_limits<std::size_t>::max();

private:
std::map<std::size_t, TensorStrategyMap> m_tensorStrategies;
std::map<std::size_t, VariableStrategyMap> m_variableStrategies;
std::optional<std::string> m_currentSection;
Expand Down
109 changes: 71 additions & 38 deletions SeQuant/core/export/export.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ namespace sequant {

namespace detail {

/// A collection of various meta-data that the preprocessing stage will collect
struct PreprocessResult {
std::unordered_map<std::size_t, ExprPtr> scalarFactors;
std::set<Index> indices;
std::map<Tensor, UsageSet, TensorBlockLessThanComparator> tensors;
std::map<Variable, UsageSet> variables;

std::map<Tensor, std::size_t, TensorBlockLessThanComparator> tensorReferences;
std::map<Variable, std::size_t> variableReferences;
};

/// Visitor objects that will steer code generation while visiting a given
/// expression/evaluation tree by triggering the corresponding callbacks in the
/// provided Generator objects.
Expand All @@ -44,7 +55,19 @@ class GenerationVisitor {
/// need to be multiplied with the result before storing the node
GenerationVisitor(Generator<Context> &generator, Context &ctx,
const std::unordered_map<NodeID, ExprPtr> &scalarFactors)
: m_generator(generator), m_ctx(ctx), m_scalarFactors(scalarFactors) {}
: m_generator(generator), m_ctx(ctx), m_scalarFactors(scalarFactors) {
if (m_generator.supports_index_batching()) {
// Make tensor comparator aware of the list of batched indices
std::vector<Index> batchIndices =
m_ctx.batch_indices(m_ctx.current_expression_id());
if (!batchIndices.empty()) {
IndexSpecificTensorBlockLessThanComparator cmp =
m_tensorUses.key_comp();
cmp.set_indices(std::move(batchIndices));
m_tensorUses = decltype(m_tensorUses)(std::move(cmp));
}
}
}

void operator()(const ExportNode<NodeData> &node, TreeTraversal context) {
// Note the context for leaf nodes is always TreeTraversal::Any
Expand Down Expand Up @@ -304,23 +327,13 @@ class GenerationVisitor {
Generator<Context> &m_generator;
Context &m_ctx;
const std::unordered_map<NodeID, ExprPtr> &m_scalarFactors;
std::map<Tensor, std::size_t, TensorBlockLessThanComparator> m_tensorUses;
std::map<Tensor, std::size_t, IndexSpecificTensorBlockLessThanComparator>
m_tensorUses;
std::map<Variable, std::size_t> m_variableUses;

std::optional<NodeID> m_rootID;
};

/// A collection of various meta-data that the preprocessing stage will collect
struct PreprocessResult {
std::unordered_map<std::size_t, ExprPtr> scalarFactors;
std::set<Index> indices;
std::map<Tensor, UsageSet, TensorBlockLessThanComparator> tensors;
std::map<Variable, UsageSet> variables;

std::map<Tensor, std::size_t, TensorBlockLessThanComparator> tensorReferences;
std::map<Variable, std::size_t> variableReferences;
};

/// Removes explicitly represented scalar factors from the provided tree and
/// instead stores them separately. This yields a much more compact tree and
/// makes subsequent visiting easier as scalar factors should simply be
Expand Down Expand Up @@ -409,8 +422,9 @@ bool prune_scalar_factor(ExportNode<T> &node, PreprocessResult &result,
/// Renames the given Tensor to a name that doesn't collide with any currently
/// loaded object. This may reuse previously used/declares tensors.
bool rename(Tensor &tensor, PreprocessResult &result);
/// Renames the given Variable to a name that doesn't collide with any currently
/// loaded object. This may reuse previously used/declares variables.
/// Renames the given Variable to a name that doesn't collide with any
/// currently loaded object. This may reuse previously used/declares
/// variables.
bool rename(Variable &variable, PreprocessResult &result);

/// Preprocesses the given expression
Expand All @@ -423,8 +437,9 @@ void preprocess(ExprType expr, ExportContext &ctx, Node &node,

bool storeExpr = false;

// TODO: find a way to pass usage information to this call so that indices of
// tensors that are only used as an intermediate can be more easily reordered
// TODO: find a way to pass usage information to this call so that indices
// of tensors that are only used as an intermediate can be more easily
// reordered
storeExpr |= ctx.rewrite(expr);

if (node.leaf()) {
Expand Down Expand Up @@ -532,8 +547,8 @@ void track_usage(const EvalNode<T> &node, PreprocessResult &result) {
}
}

/// @returns Whether the given node may be pruned from its parent in order to be
/// represented implicitly rather than by explicit occurrence in the tree
/// @returns Whether the given node may be pruned from its parent in order to
/// be represented implicitly rather than by explicit occurrence in the tree
template <typename T>
bool may_prune(const EvalNode<T> &tree) {
// Tree must represent a product and must itself not be a leaf (pruning that
Expand Down Expand Up @@ -565,17 +580,18 @@ bool may_prune(const EvalNode<T> &tree) {
///
/// Preprocesses the provided binary tree by
/// - removing explicit appearances of scalar leafs. We don't want them to
/// be represented in the tree. Instead, we keep track of them in a different
/// way in order to be able to give scalar factors alongside the actual
/// tensor contraction they are supposed to scale (this is necessary
/// be represented in the tree. Instead, we keep track of them in a
/// different way in order to be able to give scalar factors alongside the
/// actual tensor contraction they are supposed to scale (this is necessary
/// as there are backends which only support scaling in this context)
/// - rebalance the tree such that for any given non-leaf node, its left
/// subtree is always larger (or equally large) than its right one.
/// This ensures that we have to have the least amount of tensors loaded
/// at the same time, when generating code for a backend which only supports
/// stack-like memory allocations (e.g. when A is allocated before B, B
/// must be deleted before A can be deleted).
/// - Rename intermediate tensors that have the same name and describe the same
/// - Rename intermediate tensors that have the same name and describe the
/// same
/// tensor block, which are required as two separate entities at the same
/// time when evaluating the tree (thus a single tensor object is
/// insufficient).
Expand Down Expand Up @@ -603,8 +619,9 @@ class PreprocessVisitor {
prune_redundant_intermediate(tree);
}

// It is important to track_usage AFTER prune_redundant_intermediate as
// the latter might change the result expression on the current node
// It is important to track_usage AFTER prune_redundant_intermediate
// as the latter might change the result expression on the current
// node
track_usage(tree, m_result);
break;
case TreeTraversal::PostOrder:
Expand Down Expand Up @@ -729,25 +746,40 @@ class PreprocessVisitor {
}

void release_used_terms(ExportNode<T> &node) {
// Mark tensors/variables as no longer in use
if (node.left()->is_tensor()) {
const Tensor &tensor = node.left()->as_tensor();
auto handle_tensor = [&](const Tensor &tensor) {
SEQUANT_ASSERT(m_result.tensorReferences[tensor] > 0);
m_result.tensorReferences[tensor]--;
} else if (node.left()->is_variable()) {
const Variable &variable = node.left()->as_variable();
};
auto handle_variable = [&](const Variable &variable) {
SEQUANT_ASSERT(m_result.variableReferences[variable] > 0);
m_result.variableReferences[variable]--;
};

// Mark tensors/variables as no longer in use
if (node.left()->is_tensor()) {
handle_tensor(node.left()->as_tensor());
} else if (node.left()->is_variable()) {
handle_variable(node.left()->as_variable());
} else if (node.left()->is_power()) {
const Power &power = node.left()->as_power();
if (power.base().is<Tensor>()) {
handle_tensor(power.base().as<Tensor>());
} else if (power.base().is<Variable>()) {
handle_variable(power.base().as<Variable>());
}
}

if (node.right()->is_tensor()) {
const Tensor &tensor = node.right()->as_tensor();
SEQUANT_ASSERT(m_result.tensorReferences[tensor] > 0);
m_result.tensorReferences[tensor]--;
handle_tensor(node.right()->as_tensor());
} else if (node.right()->is_variable()) {
const Variable &variable = node.right()->as_variable();
SEQUANT_ASSERT(m_result.variableReferences[variable] > 0);
m_result.variableReferences[variable]--;
handle_variable(node.right()->as_variable());
} else if (node.right()->is_power()) {
const Power &power = node.right()->as_power();
if (power.base().is<Tensor>()) {
handle_tensor(power.base().as<Tensor>());
} else if (power.base().is<Variable>()) {
handle_variable(power.base().as<Variable>());
}
}
}

Expand Down Expand Up @@ -797,6 +829,7 @@ void export_expression(ExportNode<T> &expression, Generator<Context> &generator,

detail::GenerationVisitor<T, Context> visitor(generator, ctx,
pp_result.scalarFactors);

expression.visit(
[&visitor](const FullBinaryNode<T> &node, TreeTraversal context) {
visitor(node, context);
Expand Down Expand Up @@ -842,8 +875,8 @@ void declare_all(const Range &range, Generator<Context> &generator,
}

/// Combines the known T from the range of
/// PreprocessResults and clears the respective fields of the individual result
/// objects.
/// PreprocessResults and clears the respective fields of the individual
/// result objects.
/// @returns The combined set of known objects
template <typename T, typename Compare = std::less<T>, typename Range>
requires std::ranges::range<Range> &&
Expand Down
Loading
Loading