diff --git a/src/dlstbx/services/archiver.py b/src/dlstbx/services/archiver.py
index 2f93e9c2a..cc989cfbc 100644
--- a/src/dlstbx/services/archiver.py
+++ b/src/dlstbx/services/archiver.py
@@ -6,10 +6,13 @@
 import os.path
 import xml.etree.ElementTree as ET
 from datetime import datetime
+from pathlib import Path
 
 import workflows.recipe
 from workflows.services.common_service import CommonService
 
+from dlstbx.util import INDUSTRIAL_CODES
+
 
 class Dropfile:
     """A class encapsulating the XML dropfile tree as it is built up."""
@@ -132,6 +135,16 @@ def rangifier(numbers):
             group = list(group)
             yield group[0][1], group[-1][1]
 
+    def visit_is_archivable(
+        self, visit: str, allowed_industrial_visit_codes: tuple[str, ...]
+    ) -> bool:
+        """Return false if visit has an industrial code that is not in the list of allowed industrial visit codes."""
+        if visit.startswith(tuple(INDUSTRIAL_CODES)) and not visit.startswith(
+            allowed_industrial_visit_codes
+        ):
+            return False
+        return True
+
     def archive_dcid(self, rw, header, message):
         """Archive collected datafiles connected to a data collection."""
 
@@ -141,6 +154,21 @@ def archive_dcid(self, rw, header, message):
 
         # Extract parameters from the recipe
         params = rw.recipe_step["parameters"]
+
+        allowed_industrial_visit_codes = tuple(
+            self.config.storage.get(
+                "zocalo.archiver.allowed_industrial_visit_codes", []
+            )
+        )
+        if not self.visit_is_archivable(
+            params["visit"], allowed_industrial_visit_codes
+        ):
+            self.log.info(
+                f"Skipping archiving of {params['pattern']} because it is from a forbidden visit"
+            )
+            self._transport.transaction_commit(txn)
+            return
+
         self.log.info("Attempting to archive %s", params["pattern"])
         settings = params.copy()
 
@@ -151,10 +179,12 @@ def archive_dcid(self, rw, header, message):
 
         file_range_limit = int(settings.get("limit-files", 0))
 
-        filepaths = params["pattern"].split("/")
-        _, _, beamline, _, _, visit_id = filepaths[0:6]
+        filepath = Path(params["pattern"])
+        dataset_name = Path(*filepath.parts[6:-1]).as_posix() or "topdir"
+        beamline = params["beamline"]
+        visit_id = params["visit"]
 
-        df = Dropfile(visit_id.upper(), beamline, "/".join(filepaths[6:-1]) or "topdir")
+        df = Dropfile(visit_id.upper(), beamline, dataset_name)
 
         message_out = {"success": 0, "failed": 0}
         files_not_found = []
@@ -287,26 +317,28 @@ def archive_filelist(self, rw, header, message):
         )
         file_range_limit = int(params.get("limit-files", 0))
 
-        filepaths = filelist[0].split("/")
-        beamline = "unknown"
-        visit_id = "unknown"
-        try:
-            if filepaths[1:3] == ["dls", "mx"]:
-                beamline = "i02-2"  # VMXi currently only beamline with new visit path structure
-            else:
-                beamline = filepaths[2]
-            visit_id = filepaths[5]
-        except IndexError:
-            pass
-        visit_id = params.get("visit", visit_id)
-        beamline = params.get("beamline", beamline)
-
         # Conditionally acknowledge receipt of the message
         txn = self._transport.transaction_begin(subscription_id=header["subscription"])
         self._transport.ack(header, transaction=txn)
 
+        visit_id = params["visit"]
+        beamline = params["beamline"]
+        allowed_industrial_visit_codes = tuple(
+            self.config.storage.get(
+                "zocalo.archiver.allowed_industrial_visit_codes", []
+            )
+        )
+        if not self.visit_is_archivable(visit_id, allowed_industrial_visit_codes):
+            self.log.info(
+                f"Skipping archiving of {filelist} because it is from a forbidden visit"
+            )
+            self._transport.transaction_commit(txn)
+            return
+
+        filepath = Path(filelist[0])
+        dataset_name = Path(*filepath.parts[6:-1]).as_posix() or "topdir"
         # Archive files
-        df = Dropfile(visit_id.upper(), beamline, "/".join(filepaths[6:-1]) or "topdir")
+        df = Dropfile(visit_id.upper(), beamline, dataset_name)
 
         message_out = {"success": 0, "failed": 0}
         files_not_found = []
diff --git a/tests/services/test_archiver.py b/tests/services/test_archiver.py
new file mode 100644
index 000000000..e04d53dad
--- /dev/null
+++ b/tests/services/test_archiver.py
@@ -0,0 +1,519 @@
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from workflows.recipe.wrapper import RecipeWrapper
+from workflows.transport.offline_transport import OfflineTransport
+
+from dlstbx.services.archiver import DLSArchiver, Dropfile
+
+
+@pytest.fixture
+def mock_transport():
+    """Create a mock transport object."""
+    transport = mock.Mock()
+    txn = mock.Mock()
+    transport.transaction_begin.return_value = txn
+    return transport, txn
+
+
+@pytest.fixture
+def archiver_service(mock_transport):
+    """Create a DLSArchiver service instance with mocked transport and config."""
+    transport, _ = mock_transport
+    service = DLSArchiver()
+    service._transport = transport
+    service._environment = {
+        "config": mock.Mock(
+            storage={"zocalo.archiver.allowed_industrial_visit_codes": ["sw"]},
+        )
+    }
+    return service
+
+
+def create_test_files(tmp_path, file_list):
+    """Helper to create test files in tmp_path."""
+    files = []
+    for filename in file_list:
+        filepath = tmp_path / filename
+        filepath.parent.mkdir(parents=True, exist_ok=True)
+        filepath.write_text(f"Content of {filename}")
+        files.append(str(filepath))
+    return files
+
+
+def create_recipe_message_filelist(parameters, files):
+    """Create a recipe message for filelist archiving."""
+    message = {
+        "recipe": {
+            "1": {
+                "service": "DLS Archiver",
+                "queue": "archive.filelist",
+                "parameters": parameters,
+            },
+        },
+        "recipe-pointer": "1",
+        "recipe-path": [],
+        "environment": {
+            "ID": mock.sentinel.GUID,
+            "source": mock.sentinel.source,
+            "timestamp": mock.sentinel.timestamp,
+        },
+        "payload": {"filelist": files},
+    }
+    return message
+
+
+def create_recipe_message_dcid(parameters):
+    """Create a recipe message for dcid archiving."""
+    message = {
+        "recipe": {
+            "1": {
+                "service": "DLS Archiver",
+                "queue": "archive.pattern",
+                "parameters": parameters,
+            },
+        },
+        "recipe-pointer": "1",
+        "recipe-path": [],
+        "environment": {
+            "ID": mock.sentinel.GUID,
+            "source": mock.sentinel.source,
+            "timestamp": mock.sentinel.timestamp,
+        },
+        "payload": {},
+    }
+    return message
+
+
+class TestArchiverFilelist:
+    """Tests for the archive_filelist method."""
+
+    def test_archive_filelist_creates_dropfile(
+        self, archiver_service, mock_transport, tmp_path
+    ):
+        """Test that a dropfile is created when valid files are provided."""
+        transport, txn = mock_transport
+        archiver_service._transport = transport
+
+        # Create test files
+        files = create_test_files(tmp_path, ["file1.txt", "file2.txt"])
+
+        # Create recipe parameters
+        dropfile_path = tmp_path / "dropfile.xml"
+        parameters = {
+            "visit": "cm00001-1",
+            "beamline": "i03",
+            "dropfile": str(dropfile_path),
+            "filelist": files,
+        }
+
+        # Create message
+        message = create_recipe_message_filelist(parameters, files)
+        header = {
+            "message-id": "test-message-id",
+            "subscription": "test-subscription",
+        }
+
+        # Create wrapper and call method
+        rw = RecipeWrapper(message=message, transport=OfflineTransport())
+        rw._transport = OfflineTransport()
+        archiver_service.archive_filelist(rw, header, message.get("payload"))
+
+        # Verify dropfile was created
+        assert dropfile_path.exists(), "Dropfile should be created"
+
+        # Verify content is XML
+        content = dropfile_path.read_text()
+        assert "' in xml_str
+        assert "CM00001-1" in xml_str
+        assert "i03" in xml_str
+        assert "testdata.h5" in xml_str
+        assert f"{str(testfile)}" in xml_str
+
+    def test_dropfile_multiple_files(self, tmp_path):
+        """Test dropfile with multiple files."""
+        # Create test files
+        testfiles = []
+        for i in range(3):
+            testfile = tmp_path / f"data_{i}.h5"
+            testfile.write_text(f"test data {i}")
+            testfiles.append(str(testfile))
+
+        # Create dropfile
+        df = Dropfile("cm00001-1", "i03", "testdir")
+        for testfile in testfiles:
+            df.add(testfile)
+        df.close()
+
+        # Get XML string
+        xml_str = df.to_string().decode("latin-1")
+
+        # Verify all files are in the XML
+        for testfile in testfiles:
+            assert f"data_{testfiles.index(testfile)}.h5" in xml_str
+
+    def test_dropfile_visit_uppercase_conversion(self):
+        """Test that visit codes are converted to uppercase in dropfile."""
+        df = Dropfile("cm00001-1", "i03", "testdir")
+        df.close()
+
+        xml_str = df.to_string().decode("latin-1")
+        assert "CM00001-1" in xml_str