Skip to content

Graph support in AMS with simple GNN algebraic test.#194

Merged
lpottier merged 11 commits into
developfrom
yohann/graph-API-refactor
Jun 23, 2026
Merged

Graph support in AMS with simple GNN algebraic test.#194
lpottier merged 11 commits into
developfrom
yohann/graph-API-refactor

Conversation

@YohannDudouit

Copy link
Copy Markdown
Collaborator

No description provided.

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cpp-linter Review

Used clang-format v18.1.8

Click here for the full clang-format patch
diff --git a/src/AMSlib/AMSGraph.cpp b/src/AMSlib/AMSGraph.cpp
index 58e73d3..b5a6769 100644
--- a/src/AMSlib/AMSGraph.cpp
+++ b/src/AMSlib/AMSGraph.cpp
@@ -48,2 +48,2 @@ static void requireRank(const AMSTensor& tensor,
-    throw std::runtime_error("AMSHomogeneousGraph " + name +
-                             " must be rank " + std::to_string(rank) + ".");
+    throw std::runtime_error("AMSHomogeneousGraph " + name + " must be rank " +
+                             std::to_string(rank) + ".");
@@ -78,2 +78 @@ AMSTensor* AMSTensorFieldMap::find(const std::string& name) noexcept
-const AMSTensor* AMSTensorFieldMap::find(
-    const std::string& name) const noexcept
+const AMSTensor* AMSTensorFieldMap::find(const std::string& name) const noexcept
diff --git a/src/AMSlib/include/AMSGraph.hpp b/src/AMSlib/include/AMSGraph.hpp
index c8ba873..bda6d40 100644
--- a/src/AMSlib/include/AMSGraph.hpp
+++ b/src/AMSlib/include/AMSGraph.hpp
@@ -16 +16,2 @@ using AMSTensorMap = std::unordered_map<std::string, AMSTensor>;
-class AMSTensorFieldMap {
+class AMSTensorFieldMap
+{
@@ -174,4 +175,3 @@ struct AMSHeterogeneousGraphFields {
-  AMSHeterogeneousGraphFields(AMSHeterogeneousGraphFields&&) noexcept =
-      default;
-  AMSHeterogeneousGraphFields& operator=(AMSHeterogeneousGraphFields&&)
-      noexcept = default;
+  AMSHeterogeneousGraphFields(AMSHeterogeneousGraphFields&&) noexcept = default;
+  AMSHeterogeneousGraphFields& operator=(
+      AMSHeterogeneousGraphFields&&) noexcept = default;
diff --git a/src/AMSlib/wf/interface.cpp b/src/AMSlib/wf/interface.cpp
index 162a69b..11ee4d2 100644
--- a/src/AMSlib/wf/interface.cpp
+++ b/src/AMSlib/wf/interface.cpp
@@ -141,2 +141,2 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
-      rm.copy(src.data_ptr<float>(), rType, out.data<float>(), rType,
-              src.numel());
+      rm.copy(
+          src.data_ptr<float>(), rType, out.data<float>(), rType, src.numel());
@@ -147 +147,4 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
-      rm.copy(src.data_ptr<double>(), rType, out.data<double>(), rType,
+      rm.copy(src.data_ptr<double>(),
+              rType,
+              out.data<double>(),
+              rType,
@@ -153 +156,4 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
-      rm.copy(src.data_ptr<int32_t>(), rType, out.data<int32_t>(), rType,
+      rm.copy(src.data_ptr<int32_t>(),
+              rType,
+              out.data<int32_t>(),
+              rType,
@@ -159 +165,4 @@ static ams::AMSTensor torchToAMSTensorCopy(const torch::Tensor& tensor)
-      rm.copy(src.data_ptr<int64_t>(), rType, out.data<int64_t>(), rType,
+      rm.copy(src.data_ptr<int64_t>(),
+              rType,
+              out.data<int64_t>(),
+              rType,
@@ -239,2 +248 @@ static torch::Tensor amsTensorToTorchModelInput(const ams::AMSTensor& tensor,
-  if (out.device().type() != model_device ||
-      out.scalar_type() != dtype) {
+  if (out.device().type() != model_device || out.scalar_type() != dtype) {
@@ -257,4 +265,3 @@ static void requireOutputFirstDim(const torch::Tensor& tensor,
-                             entity +
-                             " fields has first dimension " +
-                             std::to_string(tensor.sizes()[0]) +
-                             ", expected " + std::to_string(expected) + ".");
+                             entity + " fields has first dimension " +
+                             std::to_string(tensor.sizes()[0]) + ", expected " +
+                             std::to_string(expected) + ".");
@@ -524,2 +531,4 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-    auto torch_graph = amsToTorchHomogeneousGraph(
-        graph, executor->MLModel->torch_device, executor->MLModel->torch_dtype);
+    auto torch_graph =
+        amsToTorchHomogeneousGraph(graph,
+                                   executor->MLModel->torch_device,
+                                   executor->MLModel->torch_dtype);
@@ -542,4 +551,5 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-        throw std::runtime_error(
-            "Malformed homogeneous graph output key '" + key +
-            "'. Expected 'node:<field>', 'edge:<field>', or "
-            "'global:<field>'.");
+        throw std::runtime_error("Malformed homogeneous graph output key '" +
+                                 key +
+                                 "'. Expected 'node:<field>', 'edge:<field>', "
+                                 "or "
+                                 "'global:<field>'.");
@@ -559,3 +569,4 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-        throw std::runtime_error(
-            "Malformed homogeneous graph output key '" + key +
-            "'. Expected entity prefix 'node', 'edge', or 'global'.");
+        throw std::runtime_error("Malformed homogeneous graph output key '" +
+                                 key +
+                                 "'. Expected entity prefix 'node', 'edge', or "
+                                 "'global'.");
@@ -567,2 +578,2 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-    throw std::runtime_error(std::string("Homogeneous graph surrogate failed: ") +
-                             e.what());
+    throw std::runtime_error(
+        std::string("Homogeneous graph surrogate failed: ") + e.what());
@@ -602,3 +613,3 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-          throw std::runtime_error(
-              "Heterogeneous graph output key '" + key +
-              "' references an unknown or empty node store.");
+          throw std::runtime_error("Heterogeneous graph output key '" + key +
+                                   "' references an unknown or empty node "
+                                   "store.");
@@ -608,3 +619,3 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-          throw std::runtime_error(
-              "Heterogeneous graph output key '" + key +
-              "' cannot infer node count from a scalar input field.");
+          throw std::runtime_error("Heterogeneous graph output key '" + key +
+                                   "' cannot infer node count from a scalar "
+                                   "input field.");
@@ -614,4 +625,5 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-        outputs.getOrCreateNodeStore(parts[1]).insert(
-            parts[2], torchToAMSTensorCopy(tensor));
-      } else if (parts.size() == 3 && parts[0] == "edge" &&
-                 !parts[1].empty() && !parts[2].empty()) {
+        outputs.getOrCreateNodeStore(parts[1]).insert(parts[2],
+                                                      torchToAMSTensorCopy(
+                                                          tensor));
+      } else if (parts.size() == 3 && parts[0] == "edge" && !parts[1].empty() &&
+                 !parts[2].empty()) {
@@ -621,3 +633,2 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-          throw std::runtime_error(
-              "Heterogeneous graph output key '" + key +
-              "' references an unknown edge store.");
+          throw std::runtime_error("Heterogeneous graph output key '" + key +
+                                   "' references an unknown edge store.");
@@ -627,3 +638,4 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-          throw std::runtime_error(
-              "Heterogeneous graph edge output key '" + key +
-              "' requires an input edge_index tensor with shape [2, E].");
+          throw std::runtime_error("Heterogeneous graph edge output key '" +
+                                   key +
+                                   "' requires an input edge_index tensor with "
+                                   "shape [2, E].");
@@ -632,2 +644,3 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-        outputs.getOrCreateEdgeStore(edge_type).insert(
-            parts[2], torchToAMSTensorCopy(tensor));
+        outputs.getOrCreateEdgeStore(edge_type).insert(parts[2],
+                                                       torchToAMSTensorCopy(
+                                                           tensor));
@@ -639,4 +652,5 @@ bool tryGraphSurrogate(AMSWorkflow* executor,
-        throw std::runtime_error(
-            "Malformed heterogeneous graph output key '" + key +
-            "'. Expected 'node:<node_type>:<field>', "
-            "'edge:<src>__<rel>__<dst>:<field>', or 'global:<field>'.");
+        throw std::runtime_error("Malformed heterogeneous graph output key '" +
+                                 key +
+                                 "'. Expected 'node:<node_type>:<field>', "
+                                 "'edge:<src>__<rel>__<dst>:<field>', or "
+                                 "'global:<field>'.");
diff --git a/tests/AMSlib/ams_interface/test_graph_fallback.cpp b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
index 3078b77..578035e 100644
--- a/tests/AMSlib/ams_interface/test_graph_fallback.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_fallback.cpp
@@ -2 +1,0 @@
-
@@ -128,2 +127,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-  CATCH_REQUIRE_NOTHROW(AMSHomogeneousGraph(
-      makeNodeFeatures(), makeEdgeIndex32(), makeEdgeFeatures()));
+  CATCH_REQUIRE_NOTHROW(AMSHomogeneousGraph(makeNodeFeatures(),
+                                            makeEdgeIndex32(),
+                                            makeEdgeFeatures()));
@@ -131,4 +131,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-  CATCH_REQUIRE_THROWS_AS(AMSHomogeneousGraph(
-                              makeTensor<int64_t>({3, 2}),
-                              makeEdgeIndex64(),
-                              makeEdgeFeatures()),
+  CATCH_REQUIRE_THROWS_AS(AMSHomogeneousGraph(makeTensor<int64_t>({3, 2}),
+                                              makeEdgeIndex64(),
+                                              makeEdgeFeatures()),
@@ -137,2 +136,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeEdgeFeatures()),
+                                              makeEdgeIndex64(),
+                                              makeEdgeFeatures()),
@@ -141,2 +140,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeTensor<float>({2, 2}),
-                                             makeEdgeFeatures()),
+                                              makeTensor<float>({2, 2}),
+                                              makeEdgeFeatures()),
@@ -145,2 +144,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeTensor<int64_t>({2}),
-                                             makeEdgeFeatures()),
+                                              makeTensor<int64_t>({2}),
+                                              makeEdgeFeatures()),
@@ -149,2 +148,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeTensor<int64_t>({3, 2}),
-                                             makeEdgeFeatures()),
+                                              makeTensor<int64_t>({3, 2}),
+                                              makeEdgeFeatures()),
@@ -153,2 +152,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeTensor<float>({2})),
+                                              makeEdgeIndex64(),
+                                              makeTensor<float>({2})),
@@ -157,2 +156,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeEdgeFeatures(3)),
+                                              makeEdgeIndex64(),
+                                              makeEdgeFeatures(3)),
@@ -161,2 +160,2 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeTensor<int64_t>({2, 1})),
+                                              makeEdgeIndex64(),
+                                              makeTensor<int64_t>({2, 1})),
@@ -165,3 +164,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeEdgeFeatures(),
-                                             makeTensor<float>({2})),
+                                              makeEdgeIndex64(),
+                                              makeEdgeFeatures(),
+                                              makeTensor<float>({2})),
@@ -170,3 +169,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeEdgeFeatures(),
-                                             makeTensor<float>({2, 1})),
+                                              makeEdgeIndex64(),
+                                              makeEdgeFeatures(),
+                                              makeTensor<float>({2, 1})),
@@ -175,3 +174,3 @@ CATCH_TEST_CASE("AMSHomogeneousGraph validates construction",
-                                             makeEdgeIndex64(),
-                                             makeEdgeFeatures(),
-                                             makeTensor<int64_t>({1, 1})),
+                                              makeEdgeIndex64(),
+                                              makeEdgeFeatures(),
+                                              makeTensor<int64_t>({1, 1})),
@@ -191,13 +190,13 @@ CATCH_TEST_CASE("AMSExecute homogeneous graph fallback path", "[wf][graph]")
-  HomogeneousGraphDomainFn callback =
-      [&](const AMSHomogeneousGraph& g, AMSHomogeneousGraphFields& outputs) {
-        callback_invoked = true;
-        CATCH_REQUIRE(g.node_features.shape()[0] == 3);
-        CATCH_REQUIRE(g.edge_index.shape()[0] == 2);
-
-        auto out = makeTensor<float>({3, 1});
-        float* out_data = out.data<float>();
-        out_data[0] = 2.0f;
-        out_data[1] = 4.0f;
-        out_data[2] = 6.0f;
-        outputs.node_fields.set("prediction", std::move(out));
-      };
+  HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph& g,
+                                          AMSHomogeneousGraphFields& outputs) {
+    callback_invoked = true;
+    CATCH_REQUIRE(g.node_features.shape()[0] == 3);
+    CATCH_REQUIRE(g.edge_index.shape()[0] == 2);
+
+    auto out = makeTensor<float>({3, 1});
+    float* out_data = out.data<float>();
+    out_data[0] = 2.0f;
+    out_data[1] = 4.0f;
+    out_data[2] = 6.0f;
+    outputs.node_fields.set("prediction", std::move(out));
+  };
diff --git a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
index f95e20e..0f11c61 100644
--- a/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
+++ b/tests/AMSlib/ams_interface/test_graph_surrogate.cpp
@@ -3 +2,0 @@
-
@@ -120,2 +119,4 @@ static void runHomogeneousSurrogate(const char* domain_name)
-  auto model = AMSRegisterAbstractModel(
-      domain_name, 0.5, HOMOGENEOUS_GRAPH_MODEL_PATH, false);
+  auto model = AMSRegisterAbstractModel(domain_name,
+                                        0.5,
+                                        HOMOGENEOUS_GRAPH_MODEL_PATH,
+                                        false);
@@ -126,5 +127,5 @@ static void runHomogeneousSurrogate(const char* domain_name)
-  HomogeneousGraphDomainFn callback =
-      [&](const AMSHomogeneousGraph&, AMSHomogeneousGraphFields& outputs) {
-        callback_invoked = true;
-        outputs.node_fields.set("prediction", makeTensor<float>({4, 1}));
-      };
+  HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph&,
+                                          AMSHomogeneousGraphFields& outputs) {
+    callback_invoked = true;
+    outputs.node_fields.set("prediction", makeTensor<float>({4, 1}));
+  };
@@ -204,7 +205,7 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
-  HomogeneousGraphDomainFn callback =
-      [&](const AMSHomogeneousGraph&, AMSHomogeneousGraphFields& outputs) {
-        callback_invoked = true;
-        auto out = makeTensor<float>({4, 1});
-        out.data<float>()[0] = 42.0f;
-        outputs.node_fields.set("prediction", std::move(out));
-      };
+  HomogeneousGraphDomainFn callback = [&](const AMSHomogeneousGraph&,
+                                          AMSHomogeneousGraphFields& outputs) {
+    callback_invoked = true;
+    auto out = makeTensor<float>({4, 1});
+    out.data<float>()[0] = 42.0f;
+    outputs.node_fields.set("prediction", std::move(out));
+  };
@@ -216,2 +217 @@ CATCH_TEST_CASE("Graph surrogate with no model triggers fallback",
-  CATCH_REQUIRE(outputs.node_fields.at("prediction").data<float>()[0] ==
-                42.0f);
+  CATCH_REQUIRE(outputs.node_fields.at("prediction").data<float>()[0] == 42.0f);
@@ -225,2 +225,2 @@ CATCH_TEST_CASE("Malformed homogeneous graph surrogate outputs fail loudly",
-  HomogeneousGraphDomainFn callback =
-      [](const AMSHomogeneousGraph&, AMSHomogeneousGraphFields&) {};
+  HomogeneousGraphDomainFn callback = [](const AMSHomogeneousGraph&,
+                                         AMSHomogeneousGraphFields&) {};

Have any feedback or feature suggestions? Share it here.

Comment thread src/AMSlib/AMSGraph.cpp Outdated
Comment thread src/AMSlib/AMSGraph.cpp Outdated
Comment thread src/AMSlib/include/AMSGraph.hpp Outdated
Comment thread src/AMSlib/include/AMSGraph.hpp Outdated
Comment thread src/AMSlib/wf/interface.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated
Comment thread tests/AMSlib/ams_interface/test_graph_surrogate.cpp Outdated

@lpottier lpottier left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

YohannDudouit and others added 11 commits June 18, 2026 11:12
- Conversion operation from/to AMS from/to Torch graph representation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
- Add tests for homogeneous and heterogenous graph surrogate models.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@YohannDudouit YohannDudouit force-pushed the yohann/graph-API-refactor branch from 9e6355a to 1ebed27 Compare June 18, 2026 18:12
@lpottier lpottier self-assigned this Jun 23, 2026
@lpottier lpottier merged commit 17a6bd2 into develop Jun 23, 2026
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants