Graph support in AMS with simple GNN algebraic test.#194
Merged
Conversation
Contributor
There was a problem hiding this comment.
Cpp-linter Review
Used clang-format v18.1.8
Click here for the full clang-format patch
diff --git a/src/AMSlib/AMSGraph.cpp b/src/AMSlib/AMSGraph.cpp
index 58e73d3..b5a6769 100644
--- a/src/AMSlib/AMSGraph.cpp
+++ b/src/AMSlib/AMSGraph.cpp
@@ -48,2 +48,2 @@ static void requireRank(const AMSTensor& tensor,
- throw std::runtime_error("AMSHomogeneousGraph " + name +
- " must be rank " + std::to_string(rank) + ".");
+ throw std::runtime_error("AMSHomogeneousGraph " + name + " must be rank " +
+ std::to_string(rank) + ".");
@@ -78,2 +78 @@ AMSTensor* AMSTensorFieldMap::find(const std::string& name) noexcept
-const AMSTensor* AMSTensorFieldMap::find(
- const std::string& name) const noexcept
+const AMSTensor* AMSTensorFieldMap::find(const std::string& name) const noexcept
diff --git a/src/AMSlib/include/AMSGraph.hpp b/src/AMSlib/include/AMSGraph.hpp
index c8ba873..bda6d40 100644
--- a/src/AMSlib/include/AMSGraph.hpp
+++ b/src/AMSlib/include/AMSGraph.hpp
@@ -16 +16,2 @@ using AMSTensorMap = std::unordered_map<std::string, AMSTensor>;
-class AMSTensorFieldMap {
+class AMSTensorFieldMap
+{
@@ -174,4 +175,3 @@ struct AMSHeterogeneousGraphFields {
- AMSHeterogeneousGraphFields(AMSHeterogeneousGraphFields&&) noexcept =
- default;
- AMSHeterogeneousGraphFields& operator=(AMSHeterogeneousGraphFields&&)
- noexcept = default;
+ AMSHeterogeneousGraphFields(AMSHeterogeneousGraphFields&&) noexcept = default;
+ AMSHeterogeneousGraphFields& operator=(
+ AMSHeterogeneousGraphFields&&) noexcept = default;
diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp
index 162a69b..11ee4d2 100644
--- a/src/AMSlib/wf/interface.cpp
+++ b/src/AMSlib/wf/interface.cpp
@@ -141,2 +141,2 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
- rm.copy(src.data_ptr<float>(), rType, out.data<float>(), rType,
- src.numel());
+ rm.copy(
+ src.data_ptr<float>(), rType, out.data<float>(), rType, src.numel());
@@ -147 +147,4 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
- rm.copy(src.data_ptr<double>(), rType, out.data<double>(), rType,
+ rm.copy(src.data_ptr<double>(),
+ rType,
+ out.data<double>(),
+ rType,
@@ -153 +156,4 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
- rm.copy(src.data_ptr<int32_t>(), rType, out.data<int32_t>(), rType,
+ rm.copy(src.data_ptr<int32_t>(),
+ rType,
+ out.data<int32_t>(),
+ rType,
@@ -159 +165,4 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
- rm.copy(src.data_ptr<int64_t>(), rType, out.data<int64_t>(), rType,
+ rm.copy(src.data_ptr<int64_t>(),
+ rType,
+ out.data<int64_t>(),
+ rType,
@@ -239,2 +248 @@ static torch::Tensor amsTensorToTorchModelInput(const ams::AMSTensor& tensor,
- if (out.device().type() != model_device ||
- out.scalar_type() != dtype) {
+ if (out.device().type() != model_device || out.scalar_type() != dtype) {
@@ -257,4 +265,3 @@ static void requireOutputFirstDim(const torch::Tensor& tensor,
- entity +
- " fields has first dimension " +
- std::to_string(tensor.sizes()[0]) +
- ", expected " + std::to_string(expected) + ".");
+ entity + " fields has first dimension " +
+ std::to_string(tensor.sizes()[0]) + ", expected " +
+ std::to_string(expected) + ".");
@@ -524,2 +531,4 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- auto torch_graph = amsToTorchHomogeneousGraph(
- graph, executor->MLModel->torch_device, executor->MLModel->torch_dtype);
+ auto torch_graph =
+ amsToTorchHomogeneousGraph(graph,
+ executor->MLModel->torch_device,
+ executor->MLModel->torch_dtype);
@@ -542,4 +551,5 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Malformed homogeneous graph output key '" + key +
- "'. Expected 'node:<field>', 'edge:<field>', or "
- "'global:<field>'.");
+ throw std::runtime_error("Malformed homogeneous graph output key '" +
+ key +
+ "'. Expected 'node:<field>', 'edge:<field>', "
+ "or "
+ "'global:<field>'.");
@@ -559,3 +569,4 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Malformed homogeneous graph output key '" + key +
- "'. Expected entity prefix 'node', 'edge', or 'global'.");
+ throw std::runtime_error("Malformed homogeneous graph output key '" +
+ key +
+ "'. Expected entity prefix 'node', 'edge', or "
+ "'global'.");
@@ -567,2 +578,2 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(std::string("Homogeneous graph surrogate failed: ") +
- e.what());
+ throw std::runtime_error(
+ std::string("Homogeneous graph surrogate failed: ") + e.what());
@@ -602,3 +613,3 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Heterogeneous graph output key '" + key +
- "' references an unknown or empty node store.");
+ throw std::runtime_error("Heterogeneous graph output key '" + key +
+ "' references an unknown or empty node "
+ "store.");
@@ -608,3 +619,3 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Heterogeneous graph output key '" + key +
- "' cannot infer node count from a scalar input field.");
+ throw std::runtime_error("Heterogeneous graph output key '" + key +
+ "' cannot infer node count from a scalar "
+ "input field.");
@@ -614,4 +625,5 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- outputs.getOrCreateNodeStore(parts[1]).insert(
- parts[2], torchToAMSTensorCopy(tensor));
- } else if (parts.size() == 3 && parts[0] == "edge" &&
- !parts[1].empty() && !parts[2].empty()) {
+ outputs.getOrCreateNodeStore(parts[1]).insert(parts[2],
+ torchToAMSTensorCopy(
+ tensor));
+ } else if (parts.size() == 3 && parts[0] == "edge" && !parts[1].empty() &&
+ !parts[2].empty()) {
@@ -621,3 +633,2 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Heterogeneous graph output key '" + key +
- "' references an unknown edge store.");
+ throw std::runtime_error("Heterogeneous graph output key '" + key +
+ "' references an unknown edge store.");
@@ -627,3 +638,4 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Heterogeneous graph edge output key '" + key +
- "' requires an input edge_index tensor with shape [2, E].");
+ throw std::runtime_error("Heterogeneous graph edge output key '" +
+ key +
+ "' requires an input edge_index tensor with "
+ "shape [2, E].");
@@ -632,2 +644,3 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- outputs.getOrCreateEdgeStore(edge_type).insert(
- parts[2], torchToAMSTensorCopy(tensor));
+ outputs.getOrCreateEdgeStore(edge_type).insert(parts[2],
+ torchToAMSTensorCopy(
+ tensor));
@@ -639,4 +652,5 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
- throw std::runtime_error(
- "Malformed heterogeneous graph output key '" + key +
- "'. Expected 'node:<node_type>:<field>', "
- "'edge:<src>__<rel>__<dst>:<field>', or 'global:<field>'.");
+ throw std::runtime_error("Malformed heterogeneous graph output key '" +
+ key +
+ "'. Expected 'node:<node_type>:<field>', "
+ "'edge:<src>__<rel>__<dst>:<field>', or "
+ "'global:<field>'.");
diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
index 3078b77..578035e 100644
--- a/tests/AMSlib/ams_interface/test_graph_fallback.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
@@ -2 +1,0 @@
-
@@ -128,2 +127,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- CATCH_REQUIRE_NOTHROW(AMSHomogeneousGraph(
- makeNodeFeatures(), makeEdgeIndex32(), makeEdgeFeatures()));
+ CATCH_REQUIRE_NOTHROW(AMSHomogeneousGraph(makeNodeFeatures(),
+ makeEdgeIndex32(),
+ makeEdgeFeatures()));
@@ -131,4 +131,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- CATCH_REQUIRE_THROWS_AS(AMSHomogeneousGraph(
- makeTensor<int64_t>({3, 2}),
- makeEdgeIndex64(),
- makeEdgeFeatures()),
+ CATCH_REQUIRE_THROWS_AS(AMSHomogeneousGraph(makeTensor<int64_t>({3, 2}),
+ makeEdgeIndex64(),
+ makeEdgeFeatures()),
@@ -137,2 +136,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeEdgeFeatures()),
+ makeEdgeIndex64(),
+ makeEdgeFeatures()),
@@ -141,2 +140,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeTensor<float>({2, 2}),
- makeEdgeFeatures()),
+ makeTensor<float>({2, 2}),
+ makeEdgeFeatures()),
@@ -145,2 +144,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeTensor<int64_t>({2}),
- makeEdgeFeatures()),
+ makeTensor<int64_t>({2}),
+ makeEdgeFeatures()),
@@ -149,2 +148,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeTensor<int64_t>({3, 2}),
- makeEdgeFeatures()),
+ makeTensor<int64_t>({3, 2}),
+ makeEdgeFeatures()),
@@ -153,2 +152,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeTensor<float>({2})),
+ makeEdgeIndex64(),
+ makeTensor<float>({2})),
@@ -157,2 +156,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeEdgeFeatures(3)),
+ makeEdgeIndex64(),
+ makeEdgeFeatures(3)),
@@ -161,2 +160,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeTensor<int64_t>({2, 1})),
+ makeEdgeIndex64(),
+ makeTensor<int64_t>({2, 1})),
@@ -165,3 +164,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeEdgeFeatures(),
- makeTensor<float>({2})),
+ makeEdgeIndex64(),
+ makeEdgeFeatures(),
+ makeTensor<float>({2})),
@@ -170,3 +169,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeEdgeFeatures(),
- makeTensor<float>({2, 1})),
+ makeEdgeIndex64(),
+ makeEdgeFeatures(),
+ makeTensor<float>({2, 1})),
@@ -175,3 +174,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
- makeEdgeIndex64(),
- makeEdgeFeatures(),
- makeTensor<int64_t>({1, 1})),
+ makeEdgeIndex64(),
+ makeEdgeFeatures(),
+ makeTensor<int64_t>({1, 1})),
@@ -191,13 +190,13 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]")
- HomogeneousGraphDomainFn callback =
- [&](const AMSHomogeneousGraph& g, AMSHomogeneousGraphFields& outputs) {
- callback_invoked = true;
- CATCH_REQUIRE(g.node_features.shape()[0] == 3);
- CATCH_REQUIRE(g.edge_index.shape()[0] == 2);
-
- auto out = makeTensor<float>({3, 1});
- float* out_data = out.data<float>();
- out_data[0] = 2.0f;
- out_data[1] = 4.0f;
- out_data[2] = 6.0f;
- outputs.node_fields.set("prediction", std::move(out));
- };
+ HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g,
+ AMSHomogeneousGraphFields& outputs) {
+ callback_invoked = true;
+ CATCH_REQUIRE(g.node_features.shape()[0] == 3);
+ CATCH_REQUIRE(g.edge_index.shape()[0] == 2);
+
+ auto out = makeTensor<float>({3, 1});
+ float* out_data = out.data<float>();
+ out_data[0] = 2.0f;
+ out_data[1] = 4.0f;
+ out_data[2] = 6.0f;
+ outputs.node_fields.set("prediction", std::move(out));
+ };
diff --git a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
index f95e20e..0f11c61 100644
--- a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
@@ -3 +2,0 @@
-
@@ -120,2 +119,4 @@ static void runHomogeneousSurrogate(const char* domain_name)
- auto model = AMSRegisterAbstractModel(
- domain_name, 0.5, HOMOGENEOUS_GRAPH_MODEL_PATH, false);
+ auto model = AMSRegisterAbstractModel(domain_name,
+ 0.5,
+ HOMOGENEOUS_GRAPH_MODEL_PATH,
+ false);
@@ -126,5 +127,5 @@ static void runHomogeneousSurrogate(const char* domain_name)
- HomogeneousGraphDomainFn callback =
- [&](const AMSHomogeneousGraph&, AMSHomogeneousGraphFields& outputs) {
- callback_invoked = true;
- outputs.node_fields.set("prediction", makeTensor<float>({4, 1}));
- };
+ HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph&,
+ AMSHomogeneousGraphFields& outputs) {
+ callback_invoked = true;
+ outputs.node_fields.set("prediction", makeTensor<float>({4, 1}));
+ };
@@ -204,7 +205,7 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
- HomogeneousGraphDomainFn callback =
- [&](const AMSHomogeneousGraph&, AMSHomogeneousGraphFields& outputs) {
- callback_invoked = true;
- auto out = makeTensor<float>({4, 1});
- out.data<float>()[0] = 42.0f;
- outputs.node_fields.set("prediction", std::move(out));
- };
+ HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph&,
+ AMSHomogeneousGraphFields& outputs) {
+ callback_invoked = true;
+ auto out = makeTensor<float>({4, 1});
+ out.data<float>()[0] = 42.0f;
+ outputs.node_fields.set("prediction", std::move(out));
+ };
@@ -216,2 +217 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
- CATCH_REQUIRE(outputs.node_fields.at("prediction").data<float>()[0] ==
- 42.0f);
+ CATCH_REQUIRE(outputs.node_fields.at("prediction").data<float>()[0] == 42.0f);
@@ -225,2 +225,2 @@ CATCH_TEST_CASE("Malformed homogeneous graph surrogate outputs fail loudly",
- HomogeneousGraphDomainFn callback =
- [](const AMSHomogeneousGraph&, AMSHomogeneousGraphFields&) {};
+ HomogeneousGraphDomainFn callback = [](const AMSHomogeneousGraph&,
+ AMSHomogeneousGraphFields&) {};
Have any feedback or feature suggestions? Share it here.
This was referenced Jun 11, 2026
240d9f6 to
3327f50
Compare
3327f50 to
9e6355a
Compare
- Conversion operation from/to AMS from/to Torch graph representation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Add tests for homogeneous and heterogenous graph surrogate models.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
9e6355a to
1ebed27
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.