From ba062059d3424d5aa83c2e8dbcebc29dd7bf7751 Mon Sep 17 00:00:00 2001 From: architd Date: Mon, 15 Jun 2026 14:39:46 -0700 Subject: [PATCH 1/3] fix: preserve system IDs during inflight refill Signed-off-by: architd --- nvalchemi/dynamics/base.py | 11 +++++++++++ test/dynamics/test_inflight.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/nvalchemi/dynamics/base.py b/nvalchemi/dynamics/base.py index 690ee2cc..1833bfa9 100644 --- a/nvalchemi/dynamics/base.py +++ b/nvalchemi/dynamics/base.py @@ -2047,6 +2047,17 @@ def refill_check(self, batch: Batch, exit_status: int) -> Batch | None: device = result.device for key, default_fn in self._bookkeeping_keys.items(): new_tensor = default_fn(n_total, device) + # Preserve values already carried by appended replacements before + # restoring the prefix for systems that stayed active. + result_vals = getattr(result, key, None) + if result_vals is not None: + result_vals = ( + result_vals.unsqueeze(-1) + if result_vals.dim() == 1 + else result_vals + ) + if result_vals.shape == new_tensor.shape: + new_tensor.copy_(result_vals) remaining_vals = getattr(batch, key, None) if remaining_vals is not None and n_remaining > 0: src = remaining_vals[remaining_indices] diff --git a/test/dynamics/test_inflight.py b/test/dynamics/test_inflight.py index 7a7da0ef..218f325f 100644 --- a/test/dynamics/test_inflight.py +++ b/test/dynamics/test_inflight.py @@ -635,6 +635,22 @@ def test_refill_writes_bookkeeping_to_storage(self) -> None: assert "status" in result + def test_refill_preserves_replacement_system_id(self) -> None: + """Replacement systems keep sampler-assigned system IDs after refill.""" + dataset = MockDataset([(1, 0)] * 3) + sampler = SizeAwareSampler(dataset, max_atoms=2) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1] + + batch["status"] = torch.tensor([[1], [1]]) + result = dynamics.refill_check(batch, exit_status=1) + + assert result is not None + assert result.system_id.view(-1).tolist() == [2] + assert result.status.view(-1).tolist() == [0] + def test_refill_partial_replacement(self) -> None: """When sampler has fewer replacements than graduated, batch shrinks. From dbcef64e40acfb82498aa297e2c6151d499a6f30 Mon Sep 17 00:00:00 2001 From: architd Date: Mon, 15 Jun 2026 15:56:37 -0700 Subject: [PATCH 2/3] fix: validate refill bookkeeping and cover mixed IDs Signed-off-by: architd --- nvalchemi/dynamics/base.py | 5 +++++ test/dynamics/test_inflight.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/nvalchemi/dynamics/base.py b/nvalchemi/dynamics/base.py index 1833bfa9..082c0d12 100644 --- a/nvalchemi/dynamics/base.py +++ b/nvalchemi/dynamics/base.py @@ -2058,6 +2058,11 @@ def refill_check(self, batch: Batch, exit_status: int) -> Batch | None: ) if result_vals.shape == new_tensor.shape: new_tensor.copy_(result_vals) + else: + raise RuntimeError( + f"Bookkeeping key '{key}' has shape {result_vals.shape} " + f"after refill, expected {new_tensor.shape}." + ) remaining_vals = getattr(batch, key, None) if remaining_vals is not None and n_remaining > 0: src = remaining_vals[remaining_indices] diff --git a/test/dynamics/test_inflight.py b/test/dynamics/test_inflight.py index 218f325f..df8d115d 100644 --- a/test/dynamics/test_inflight.py +++ b/test/dynamics/test_inflight.py @@ -651,6 +651,22 @@ def test_refill_preserves_replacement_system_id(self) -> None: assert result.system_id.view(-1).tolist() == [2] assert result.status.view(-1).tolist() == [0] + def test_refill_preserves_mixed_system_ids(self) -> None: + """Remaining and replacement systems both keep their system IDs.""" + dataset = MockDataset([(1, 0)] * 3) + sampler = SizeAwareSampler(dataset, max_atoms=2) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1] + + batch["status"] = torch.tensor([[1], [0]]) + result = dynamics.refill_check(batch, exit_status=1) + + assert result is not None + assert result.system_id.view(-1).tolist() == [1, 2] + assert result.status.view(-1).tolist() == [0, 0] + def test_refill_partial_replacement(self) -> None: """When sampler has fewer replacements than graduated, batch shrinks. From 4e7d30aa52614b9f6a6f80e119e0b976229b72c6 Mon Sep 17 00:00:00 2001 From: architd Date: Wed, 24 Jun 2026 18:17:56 -0700 Subject: [PATCH 3/3] test: cover refill with complex status bookkeeping Add refill coverage for a multistage-like status case where only the status-2 system is replaced while lower-status systems keep their IDs. Add coverage for simultaneous replacement of multiple graduated systems, preserving the remaining system ID and assigning new sampler IDs to both replacements. Add coverage for replacement-carried bookkeeping by verifying a dataset-provided nonzero status survives refill. Signed-off-by: architd --- test/dynamics/test_inflight.py | 60 ++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/dynamics/test_inflight.py b/test/dynamics/test_inflight.py index df8d115d..a40e35be 100644 --- a/test/dynamics/test_inflight.py +++ b/test/dynamics/test_inflight.py @@ -667,6 +667,66 @@ def test_refill_preserves_mixed_system_ids(self) -> None: assert result.system_id.view(-1).tolist() == [1, 2] assert result.status.view(-1).tolist() == [0, 0] + def test_refill_preserves_system_ids_with_multistage_statuses(self) -> None: + """A status-2 system is replaced while lower-status systems remain.""" + dataset = MockDataset([(1, 0)] * 5) + sampler = SizeAwareSampler(dataset, max_atoms=3) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1, 2] + + batch["status"] = torch.tensor([[0], [1], [2]]) + result = dynamics.refill_check(batch, exit_status=2) + + assert result is not None + assert result.system_id.view(-1).tolist() == [0, 1, 3] + assert result.status.view(-1).tolist() == [0, 1, 0] + + def test_refill_preserves_system_ids_with_multiple_replacements(self) -> None: + """Multiple graduated systems are replaced in a single refill.""" + dataset = MockDataset([(1, 0)] * 5) + sampler = SizeAwareSampler(dataset, max_atoms=3) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1, 2] + + batch["status"] = torch.tensor([[1], [0], [1]]) + result = dynamics.refill_check(batch, exit_status=1) + + assert result is not None + assert result.system_id.view(-1).tolist() == [1, 3, 4] + assert result.status.view(-1).tolist() == [0, 0, 0] + + def test_refill_preserves_replacement_status_from_dataset(self) -> None: + """Replacement systems keep nonzero dataset-provided entry status.""" + + class StatusDataset(MockDataset): + def __getitem__(self, index: int) -> tuple[AtomicData, dict]: + data, metadata = super().__getitem__(index) + data.add_system_property( + "status", torch.tensor([[1]], dtype=torch.long) + ) + return data, metadata + + dataset = StatusDataset([(1, 0)] * 5) + sampler = SizeAwareSampler(dataset, max_atoms=3) + dynamics = BaseDynamics(model=self.model, sampler=sampler, device_type="cpu") + + batch = sampler.build_initial_batch() + assert batch.system_id.view(-1).tolist() == [0, 1, 2] + # Initial batching resets status to zero, but refill preserves + # replacement status already carried by the appended batch. + assert batch.status.view(-1).tolist() == [0, 0, 0] + + batch["status"] = torch.tensor([[0], [2], [1]]) + result = dynamics.refill_check(batch, exit_status=2) + + assert result is not None + assert result.system_id.view(-1).tolist() == [0, 2, 3] + assert result.status.view(-1).tolist() == [0, 1, 1] + def test_refill_partial_replacement(self) -> None: """When sampler has fewer replacements than graduated, batch shrinks.