diff --git a/src/murfey/client/contexts/fib.py b/src/murfey/client/contexts/fib.py index fff544e20..684741835 100644 --- a/src/murfey/client/contexts/fib.py +++ b/src/murfey/client/contexts/fib.py @@ -4,8 +4,7 @@ import re import threading import xml.etree.ElementTree as ET -from dataclasses import dataclass -from datetime import datetime +from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Type, TypeVar @@ -25,12 +24,6 @@ lock = threading.Lock() -@dataclass -class MillingImage: - file: Path - timestamp: float - - def _number_from_name(name: str) -> int: """ In the AutoTEM and Maps workflows for the FIB, the sites and images are @@ -170,10 +163,10 @@ def _parse_boolean(text: str): # Map class attribute to element name # Paths are relative to the "Site" node "preparation": "PreparationSiteLocation/StagePosition/StagePosition", - "chunk_coincidence": "Parameters/ChunkCoincidenceStagePosition/StagePosition", "chunk": "ChunkSiteLocation/StagePosition/StagePosition", - "thinning_1": "Parameters/ThinningStagePosition/StagePosition", - "thinning_2": "ThinningSiteLocation/StagePosition/StagePosition", + "thinning_1": "ThinningSiteLocation/StagePosition/StagePosition", + "chunk_coincidence": "Parameters/ChunkCoincidenceStagePosition/StagePosition", + "thinning_2": "Parameters/ThinningStagePosition/StagePosition", } @@ -233,6 +226,13 @@ def _file_transferred_to( return destination +@dataclass +class FIBImage: + images: list[Path] = field(default_factory=list) + output_file: Path | None = None + is_submitted: bool = False + + class FIBContext(Context): def __init__( self, @@ -245,7 +245,7 @@ def __init__( self._basepath = basepath self._machine_config = machine_config self._site_info: dict[int, LamellaSiteInfo] = {} - self._drift_correction_images: dict[int, list[MillingImage]] = {} + self._drift_correction_images: dict[int, FIBImage] = {} def post_transfer( self, @@ -262,7 +262,6 @@ def post_transfer( # AutoTEM # ----------------------------------------------------------------------------- if self._acquisition_software == "autotem": - parts = transferred_file.parts if transferred_file.name == "ProjectData.dat": logger.info(f"Found metadata file {transferred_file} for parsing") @@ -289,82 +288,32 @@ def post_transfer( # Update existing dict self._site_info[site_num] = site_info_new logger.info(f"Updating metadata for site {site_num}") - return None - elif "DCImages" in parts and transferred_file.suffix == ".png": - lamella_name = parts[parts.index("Sites") + 1] - lamella_number = _number_from_name(lamella_name) - time_from_name = transferred_file.name.split("-")[:6] - timestamp = datetime.timestamp( - datetime( - year=int(time_from_name[0]), - month=int(time_from_name[1]), - day=int(time_from_name[2]), - hour=int(time_from_name[3]), - minute=int(time_from_name[4]), - second=int(time_from_name[5]), - ) - ) - if not (source := _get_source(transferred_file, environment)): - logger.warning(f"No source found for file {transferred_file}") - return - if not ( - destination_file := _file_transferred_to( - environment=environment, - source=source, - file_path=transferred_file, - rsync_basepath=Path( - self._machine_config.get("rsync_basepath", "") - ), - ) - ): - logger.warning( - f"File {transferred_file.name!r} not found on storage system" - ) - return - if not self._drift_correction_images.get(lamella_number): - self._drift_correction_images[lamella_number] = [ - MillingImage( - timestamp=timestamp, - file=destination_file, - ) - ] - else: - self._drift_correction_images[lamella_number].append( - MillingImage( - timestamp=timestamp, - file=destination_file, - ) - ) - gif_list = [ - l.file - for l in sorted( - self._drift_correction_images[lamella_number], - key=lambda x: x.timestamp, - ) - ] - raw_directory = Path( - environment.default_destinations[self._basepath] - ).name - # Submit job to backend to construct a GIF - capture_post( - base_url=str(environment.url.geturl()), - router_name="workflow.correlative_router", - function_name="make_gif", - token=self._token, - instrument_name=environment.instrument_name, - data={ - "lamella_number": lamella_number, - "images": [str(file) for file in gif_list], - "raw_directory": raw_directory, - }, - # Endpoint kwargs - year=datetime.now().year, - visit_name=environment.visit, - session_id=environment.murfey_session, - ) + # Post drift correction GIF request if it hasn't already been done + if ( + (fib_image := self._drift_correction_images.get(site_num, None)) + is not None + and not fib_image.is_submitted + and fib_image.output_file is not None + ): + if self._make_gif( + environment=environment, + lamella_number=site_num, + images=sorted(fib_image.images), + output_file=fib_image.output_file, + ): + with lock: + self._drift_correction_images[ + site_num + ].is_submitted = True return None + elif ( + "DCImages" in transferred_file.parts + and transferred_file.suffix == ".png" + ): + self._make_drift_correction_gif(transferred_file, environment) + # ----------------------------------------------------------------------------- # Maps # ----------------------------------------------------------------------------- @@ -491,9 +440,9 @@ def _parse_autotem_metadata(self, file: Path): ) # Iteratively update fields in the MillingSteps model it's not None - for field, path, func in ACTIVITY_FIELD_MAP: + for field_name, path, func in ACTIVITY_FIELD_MAP: if (value := _parse_xml_text(activity, path, func)) is not None: - step_info.__setattr__(field, value) + step_info.__setattr__(field_name, value) # Add info for current step to the site info model site_info.steps.__setattr__( @@ -506,6 +455,158 @@ def _parse_autotem_metadata(self, file: Path): logger.info(f"Successfully extracted AutoTEM metadata from file {file}") return all_site_info + def _make_drift_correction_gif( + self, + file: Path, + environment: MurfeyInstanceEnvironment, + ): + """ + Helper function to create GIFs using the drift correction images seen by the + FIBContext class. The function uses the metadata returned + """ + parts = file.parts + try: + lamella_name = parts[parts.index("Sites") + 1] + lamella_number = _number_from_name(lamella_name) + except Exception: + logger.warning( + f"Could not extract metadata from file {file}", exc_info=True + ) + return None + if not (source := _get_source(file, environment)): + logger.warning(f"No source found for file {file}") + return + if not ( + destination_file := _file_transferred_to( + environment=environment, + source=source, + file_path=file, + rsync_basepath=Path(self._machine_config.get("rsync_basepath", "")), + ) + ): + logger.warning(f"File {file.name!r} not found on storage system") + return + + # Create FIBImage instance for this lamella site, or update existing one + if not self._drift_correction_images.get(lamella_number): + with lock: + self._drift_correction_images[lamella_number] = FIBImage( + images=[destination_file] + ) + else: + with lock: + self._drift_correction_images[lamella_number].images.append( + destination_file + ) + self._drift_correction_images[lamella_number].is_submitted = False + + # Determine the output directory to save the milling image to + output_file = self._drift_correction_images[lamella_number].output_file + if output_file is None: + # Early exits if data for creating output image path is absent + # No site info + if (site_info := self._site_info.get(lamella_number)) is None: + logger.debug(f"No metadata found for site {lamella_number} yet") + return None + # No project name + if (project_name := site_info.project_name) is None: + logger.warning(f"No project name associated with site {lamella_number}") + return None + # No stage position information + if all( + getattr(site_info.stage_info, stage_name, None) is None + for stage_name in STAGE_POSITION_NAMES.keys() + ): + logger.warning( + f"No stage position information associated with site {lamella_number}" + ) + return None + # Determine the slot number + slot_number: int | None = None + for stage_name in reversed(STAGE_POSITION_NAMES.keys()): + if ( + stage_info := getattr(site_info.stage_info, stage_name, None) + ) is None: + continue + if stage_info.slot_number is None: + continue + else: + slot_number = stage_info.slot_number + break + # Early exit if no slot number + if slot_number is None: + logger.warning( + f"Could not determine slot number of site {lamella_number}" + ) + return None + # Determine the path to save the GIF to + try: + visit_index = destination_file.parts.index(environment.visit) + visit_dir = list(reversed(destination_file.parents))[visit_index] + output_file = ( + visit_dir + / "processed" + / project_name + / f"grid_{slot_number}" + / "drift_correction" + / f"lamella_{lamella_number}.gif" + ) + with lock: + self._drift_correction_images[ + lamella_number + ].output_file = output_file + except Exception: + logger.error( + f"Could not construct drift correction GIF output path for site {lamella_number}" + ) + return None + + # Submit job to backend to construct a GIF + if self._make_gif( + environment=environment, + lamella_number=lamella_number, + images=sorted(self._drift_correction_images[lamella_number].images), + output_file=output_file, + ): + # Mark this dataset as having been submitted + with lock: + self._drift_correction_images[lamella_number].is_submitted = True + logger.info( + f"Submitted request to create drift correction GIF for site {lamella_number}" + ) + return None + + def _make_gif( + self, + environment: MurfeyInstanceEnvironment, + lamella_number: int, + images: list[Path], + output_file: Path, + ): + """ + Submits a POST request to the backend server to create a GIF using the + JSON payload provided. The payload will contain + """ + try: + capture_post( + base_url=str(environment.url.geturl()), + router_name="workflow_fib.router", + function_name="make_gif", + token=self._token, + instrument_name=environment.instrument_name, + data={ + "lamella_number": lamella_number, + "images": [str(file) for file in images], + "output_file": str(output_file), + }, + # Endpoint kwargs + session_id=environment.murfey_session, + ) + return True + except Exception: + logger.error(f"Could not submit GIF for site {lamella_number}") + return False + def _register_atlas(self, file: Path, environment: MurfeyInstanceEnvironment): """ Constructs the URL and dictionary to be posted to the server, which then triggers diff --git a/src/murfey/server/api/workflow.py b/src/murfey/server/api/workflow.py index ce63b6e84..4eb7906ba 100644 --- a/src/murfey/server/api/workflow.py +++ b/src/murfey/server/api/workflow.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional -import numpy as np import sqlalchemy from fastapi import APIRouter, Depends from ispyb.sqlalchemy import ( @@ -21,11 +20,6 @@ from sqlmodel import col, select from werkzeug.utils import secure_filename -try: - from PIL import Image -except ImportError: - Image = None - try: from smartem_backend.api_client import SmartEMAPIClient from smartem_common.schemas import ( @@ -1194,71 +1188,3 @@ def register_sample_image( if _transport_object: return _transport_object.do_insert_sample_image(record) return {"success": False} - - -class MillingParameters(BaseModel): - lamella_number: int - images: List[str] - raw_directory: str - - -@correlative_router.post( - "/year/{year}/visits/{visit_name}/sessions/{session_id}/make_milling_gif" -) -async def make_gif( - year: int, - visit_name: str, - session_id: int, - gif_params: MillingParameters, - db=murfey_db, -): - instrument_name = ( - db.exec(select(Session).where(Session.id == session_id)).one().instrument_name - ) - machine_config = get_machine_config(instrument_name=instrument_name)[ - instrument_name - ] - output_dir = ( - (machine_config.rsync_basepath or Path("")).resolve() - / secure_filename(str(year)) - / secure_filename(visit_name) - / "processed" - ) - output_dir.mkdir(exist_ok=True) - os.chmod(output_dir, mode=machine_config.mkdir_chmod) - output_dir = output_dir / secure_filename(gif_params.raw_directory) - output_dir.mkdir(exist_ok=True) - os.chmod(output_dir, mode=machine_config.mkdir_chmod) - output_path = output_dir / f"lamella_{gif_params.lamella_number}_milling.gif" - - if Image is not None: - images = [Image.open(f) for f in gif_params.images] - else: - images = [] - for im in images: - im.thumbnail((512, 512)) - - # Normalize and convert individual frames to 8-bit - arr: list[np.ndarray] = [] - for im in images: - frame = np.array(im).astype(np.float32) - vmin, vmax = np.percentile(frame, (0.5, 99.5)) - scale = 255 / ((vmax - vmin) or 1) - np.clip(frame, a_min=vmin, a_max=vmax, out=frame) - np.subtract(frame, vmin, out=frame) - np.multiply(frame, scale, out=frame) - arr.append(frame.astype(np.uint8)) - arr = np.array(arr).astype(np.uint8) - - # Convert back to Image objects and save as GIF - converted = [Image.fromarray(arr[f], mode="L") for f in range(len(images))] - converted[0].save( - output_path, - format="GIF", - append_images=converted[1:], - save_all=True, - duration=30, - loop=0, - ) - - return {"output_gif": str(output_path)} diff --git a/src/murfey/server/api/workflow_fib.py b/src/murfey/server/api/workflow_fib.py index 444067575..ed961aa04 100644 --- a/src/murfey/server/api/workflow_fib.py +++ b/src/murfey/server/api/workflow_fib.py @@ -1,14 +1,19 @@ import json import logging +import os from importlib.metadata import entry_points from pathlib import Path +import numpy as np +import PIL.Image from fastapi import APIRouter, Depends from pydantic import BaseModel -from sqlmodel import Session +from sqlmodel import Session, select +import murfey.util.db as MurfeyDB from murfey.server.api.auth import validate_instrument_token from murfey.server.murfey_db import murfey_db +from murfey.util.config import get_machine_config from murfey.util.models import LamellaSiteInfo logger = logging.getLogger("murfey.server.api.workflow_fib") @@ -57,3 +62,73 @@ def register_fib_milling_progress( "Received the following FIB metadata for registration:\n" f"{json.dumps(site_info.model_dump(exclude_none=True), indent=2, default=str)}" ) + + +class FIBGIFParameters(BaseModel): + lamella_number: int + images: list[Path] + output_file: Path + + +@router.post("/sessions/{session_id}/make_gif") +async def make_gif( + session_id: int, + gif_params: FIBGIFParameters, + db=murfey_db, +): + # Load machine config and session info + session_entry = db.exec( + select(MurfeyDB.Session).where(MurfeyDB.Session.id == session_id) + ).one() + instrument_name = session_entry.instrument_name + visit_name = session_entry.visit + machine_config = get_machine_config(instrument_name=instrument_name)[ + instrument_name + ] + + # Create the directory structure + if not (output_dir := gif_params.output_file.parent).exists(): + output_dir.mkdir(parents=True) + logger.debug(f"Created output directory {output_dir}") + visit_index = output_dir.parts.index(visit_name) + # Change permissions for folders in the visit directory and onwards + for current_path in list(reversed(output_dir.parents))[visit_index + 1 :]: + try: + os.chmod(current_path, mode=machine_config.mkdir_chmod) + except PermissionError: + logger.warning( + f"Insufficient permissions to modify directory {current_path}" + ) + continue + + if PIL.Image is not None: + images = [PIL.Image.open(f) for f in gif_params.images] + else: + images = [] + for im in images: + im.thumbnail((512, 512)) + + # Normalize and convert individual frames to 8-bit + arr: list[np.ndarray] = [] + for im in images: + frame = np.array(im).astype(np.float32) + vmin, vmax = np.percentile(frame, (0.5, 99.5)) + scale = 255 / ((vmax - vmin) or 1) + np.clip(frame, a_min=vmin, a_max=vmax, out=frame) + np.subtract(frame, vmin, out=frame) + np.multiply(frame, scale, out=frame) + arr.append(frame.astype(np.uint8)) + arr = np.array(arr).astype(np.uint8) + + # Convert back to PIL.Image objects and save as GIF + converted = [PIL.Image.fromarray(arr[f], mode="L") for f in range(len(images))] + converted[0].save( + gif_params.output_file, + format="GIF", + append_images=converted[1:], + save_all=True, + duration=30, + loop=0, + ) + logger.info(f"Created GIF file {gif_params.output_file}") + return {"output_gif": str(gif_params.output_file)} diff --git a/src/murfey/util/models.py b/src/murfey/util/models.py index 11d1b8516..9b05fca89 100644 --- a/src/murfey/util/models.py +++ b/src/murfey/util/models.py @@ -128,20 +128,22 @@ class StagePositionInfo(BaseModel): "ChunkCoincidenceStagePosition" currently correspond to. """ + # Top-level values preparation: StagePositionValues | None = ( None # PreparationSiteLocation/StagePosition/StagePosition ) - chunk_coincidence: StagePositionValues | None = ( - None # Parameters/ChunkCoincidenceStagePosition/StagePosition - ) chunk: StagePositionValues | None = ( None # ChunkSiteLocation/StagePosition/StagePosition ) thinning_1: StagePositionValues | None = ( - None # Parameters/ThinningStagePosition/StagePosition + None # ThinningSiteLocation/StagePosition/StagePosition + ) + # Stored under Parameters + chunk_coincidence: StagePositionValues | None = ( + None # Parameters/ChunkCoincidenceStagePosition/StagePosition ) thinning_2: StagePositionValues | None = ( - None # ThinningSiteLocation/StagePosition/StagePosition + None # Parameters/ThinningStagePosition/StagePosition ) diff --git a/src/murfey/util/route_manifest.yaml b/src/murfey/util/route_manifest.yaml index 953bed23c..10ace7fd3 100644 --- a/src/murfey/util/route_manifest.yaml +++ b/src/murfey/util/route_manifest.yaml @@ -1270,17 +1270,6 @@ murfey.server.api.workflow.correlative_router: type: str methods: - POST - - path: /workflow/correlative/year/{year}/visits/{visit_name}/sessions/{session_id}/make_milling_gif - function: make_gif - path_params: - - name: year - type: int - - name: visit_name - type: str - - name: session_id - type: int - methods: - - POST murfey.server.api.workflow.router: - path: /workflow/visits/{visit_name}/sessions/{session_id}/register_data_collection_group function: register_dc_group @@ -1447,3 +1436,10 @@ murfey.server.api.workflow_fib.router: type: int methods: - POST + - path: /workflow/fib/sessions/{session_id}/make_gif + function: make_gif + path_params: + - name: session_id + type: int + methods: + - POST diff --git a/tests/client/contexts/test_fib.py b/tests/client/contexts/test_fib.py index 91b9b7665..58161bbb1 100644 --- a/tests/client/contexts/test_fib.py +++ b/tests/client/contexts/test_fib.py @@ -16,6 +16,7 @@ _number_from_name, _parse_boolean, ) +from murfey.util.models import LamellaSiteInfo # Mock session values num_lamellae = 5 @@ -595,27 +596,49 @@ def test_fib_autotem_context_projectdata( @pytest.mark.parametrize( "test_params", - ( # Use environment? | Find source? | Find destination? - (True, True, True), - (False, True, True), - (True, False, True), - (True, True, False), + ( + # Early exits + # No MurfeyInstanceEnvironment + (False, True, True, True, True, True, True), + # No source + (True, False, True, True, True, True, True), + # No destination + (True, True, False, True, True, True, True), + # No site info + (True, True, True, False, True, True, True), + # No project name + (True, True, True, True, False, True, True), + # No stage position + (True, True, True, True, True, False, True), + # No stage position values + (True, True, True, True, True, True, False), + # Successful case + (True, True, True, True, True, True, True), ), ) def test_fib_autotem_context_drift_correction_images( mocker: MockerFixture, - test_params: tuple[bool, bool, bool], + test_params: tuple[bool, bool, bool, bool, bool, bool, bool], tmp_path: Path, visit_dir: Path, fib_autotem_dc_images: list[Path], ): # Unpack test params - use_env, find_source, find_dst = test_params + ( + use_env, + find_source, + find_dst, + has_site_info, + has_project_name, + has_stage_position, + has_stage_values, + ) = test_params # Mock the environment mock_environment = None if use_env: mock_environment = MagicMock() + mock_environment.visit = visit_name # Mock the logger to check if specific logs are triggered mock_logger = mocker.patch("murfey.client.contexts.fib.logger") @@ -649,6 +672,23 @@ def test_fib_autotem_context_drift_correction_images( token="", ) + # Create the Pydantic model for each site and add metadata + for i in range(num_lamellae): + lamella_num = i + 1 + metadata_dict = { + "site_name": f"Lamella ({lamella_num})", + "site_number": lamella_num, + } + if has_project_name: + metadata_dict["project_name"] = project_name + if has_stage_position: + stage_dict: dict[str, dict] = {"preparation": {}} + if has_stage_values: + stage_dict["preparation"] = {"x": 0.003} + metadata_dict["stage_info"] = stage_dict + if has_site_info: + context._site_info[lamella_num] = LamellaSiteInfo(**metadata_dict) + # Parse images one-by-one and check that expected calls were made for file in fib_autotem_dc_images: context.post_transfer(file, environment=mock_environment) @@ -660,6 +700,22 @@ def test_fib_autotem_context_drift_correction_images( mock_logger.warning.assert_called_with( f"File {file.name!r} not found on storage system" ) + elif not has_site_info: + mock_logger.debug.assert_called_with( + f"No metadata found for site {lamella_num} yet" + ) + elif not has_project_name: + mock_logger.warning.assert_called_with( + f"No project name associated with site {lamella_num}" + ) + elif not has_stage_position: + mock_logger.warning.assert_called_with( + f"No stage position information associated with site {lamella_num}" + ) + elif not has_stage_values: + mock_logger.warning.assert_called_with( + f"Could not determine slot number of site {lamella_num}" + ) else: mock_get_source.assert_called_with(file, mock_environment) mock_file_transferred_to.assert_called_with( @@ -668,9 +724,34 @@ def test_fib_autotem_context_drift_correction_images( file_path=file, rsync_basepath=Path(""), ) - assert mock_capture_post.call_count == len(fib_autotem_dc_images) assert len(context._drift_correction_images) == num_lamellae + for i in range(num_lamellae): + lamella_num = i + 1 + # The '_site_info' attribute should now be populated + assert ( + context._site_info[lamella_num].stage_info.preparation.slot_number == 2 + ) + + # The output file should point to 'grid_2' for a positive x stage position + output_file = ( + tmp_path + / "fib" + / "data" + / "current_year" + / visit_name + / "processed" + / project_name + / "grid_2" + / "drift_correction" + / f"lamella_{lamella_num}.gif" + ) + assert ( + context._drift_correction_images[lamella_num].output_file == output_file + ) + # 'capture_post' should be called for every image + assert mock_capture_post.call_count == len(destination_files) + def test_fib_maps_context( mocker: MockerFixture, diff --git a/tests/server/api/test_workflow.py b/tests/server/api/test_workflow.py index 1f6ee0c54..a1ec8336c 100644 --- a/tests/server/api/test_workflow.py +++ b/tests/server/api/test_workflow.py @@ -1,17 +1,9 @@ -from pathlib import Path from unittest import mock -from unittest.mock import MagicMock -import numpy as np -import PIL.Image -import pytest -from pytest_mock import MockerFixture from sqlmodel import Session, select from murfey.server.api.workflow import ( DCGroupParameters, - MillingParameters, - make_gif, register_dc_group, ) from murfey.util.db import DataCollectionGroup, SearchMap @@ -438,75 +430,3 @@ def test_register_dc_group_new_atlas_with_searchmaps( murfey_db_session, close_db=False, ) - - -@pytest.mark.asyncio -async def test_make_gif( - mocker: MockerFixture, - tmp_path: Path, -): - # Set up test variables - session_id = 10 - instrument_name = "test_instrument" - rsync_basepath = tmp_path / "data" - visit_name = "cm12345-6" - year = 2020 - visit_dir = rsync_basepath / str(year) / visit_name - lamella_num = 12 - lamella_folder = "Lamella" - if lamella_num > 1: - lamella_folder += f" ({lamella_num})" - raw_directory = "autotem" - - # Create a list of test image file paths - raw_images = [ - visit_dir - / "autotem" - / visit_name - / "Sites" - / lamella_folder - / "DCImages/DCM_asdfjkl/asdfjkl-Polishing-dc_rescan-image-.png" - ] * 5 - # Mock the output of PIL.Image.open to always return a NumPY array - mocker.patch( - "murfey.server.api.workflow.Image.open", - return_value=PIL.Image.fromarray(np.ones((512, 512), dtype=np.uint16)), - ) - - # Create the Pydantic model - params = MillingParameters( - lamella_number=lamella_num, - images=[str(f) for f in raw_images], - raw_directory=raw_directory, - ) - - # Mock the database query - mock_db = MagicMock() - mock_db.exec.return_value.one.return_value.instrument_name = instrument_name - - # Mock the machine config and 'get_machine_config' - mock_machine_config = MagicMock() - mock_machine_config.rsync_basepath = rsync_basepath - mock_machine_config.mkdir_chmod = 0o775 - mocker.patch( - "murfey.server.api.workflow.get_machine_config", - return_value={ - instrument_name: mock_machine_config, - }, - ) - - # Create the save directory directory - save_dir = visit_dir / "processed" / raw_directory - save_dir.mkdir(parents=True, exist_ok=True) - - # Run the function and check that the expected outputs are there - result = await make_gif( - year=year, - visit_name=visit_name, - session_id=session_id, - gif_params=params, - db=mock_db, - ) - image_path = save_dir / f"lamella_{lamella_num}_milling.gif" - assert image_path.exists() - assert result.get("output_gif") == str(image_path) diff --git a/tests/server/api/test_workflow_fib.py b/tests/server/api/test_workflow_fib.py index e72deb57e..6baadcd09 100644 --- a/tests/server/api/test_workflow_fib.py +++ b/tests/server/api/test_workflow_fib.py @@ -1,10 +1,17 @@ from pathlib import Path from unittest.mock import MagicMock +import numpy as np +import PIL.Image import pytest from pytest_mock import MockerFixture -from murfey.server.api.workflow_fib import FIBAtlasInfo, register_fib_atlas +from murfey.server.api.workflow_fib import ( + FIBAtlasInfo, + FIBGIFParameters, + make_gif, + register_fib_atlas, +) def test_register_fib_atlas( @@ -52,3 +59,75 @@ def test_register_fib_atlas_no_entry_point( fib_atlas_info=fib_atlas_info, db=mock_db, ) + + +@pytest.mark.asyncio +async def test_make_gif( + mocker: MockerFixture, + tmp_path: Path, +): + # Set up test variables + session_id = 10 + instrument_name = "test_instrument" + rsync_basepath = tmp_path / "data" + visit_name = "cm12345-6" + year = 2020 + visit_dir = rsync_basepath / str(year) / visit_name + lamella_num = 12 + lamella_folder = "Lamella" + if lamella_num > 1: + lamella_folder += f" ({lamella_num})" + output_file = ( + visit_dir + / "processed" + / "project_name" + / "grid_1" + / "drift_correction" + / f"lamella_{lamella_num}.gif" + ) + + # Create a list of test image file paths + raw_images = [ + visit_dir + / "autotem" + / visit_name + / "Sites" + / lamella_folder + / "DCImages/DCM_asdfjkl/asdfjkl-Polishing-dc_rescan-image-.png" + ] * 5 + # Mock the output of PIL.Image.open to always return a NumPY array + mocker.patch( + "murfey.server.api.workflow_fib.PIL.Image.open", + return_value=PIL.Image.fromarray(np.ones((512, 512), dtype=np.uint16)), + ) + + # Create the Pydantic model + params = FIBGIFParameters( + lamella_number=lamella_num, + images=[str(f) for f in raw_images], + output_file=output_file, + ) + + # Mock the database query + mock_db = MagicMock() + mock_db.exec.return_value.one.return_value.instrument_name = instrument_name + mock_db.exec.return_value.one.return_value.visit = visit_name + + # Mock the machine config and 'get_machine_config' + mock_machine_config = MagicMock() + mock_machine_config.mkdir_chmod = 0o775 + mocker.patch( + "murfey.server.api.workflow_fib.get_machine_config", + return_value={ + instrument_name: mock_machine_config, + }, + ) + + # Run the function and check that the expected outputs are there + result = await make_gif( + session_id=session_id, + gif_params=params, + db=mock_db, + ) + assert output_file.exists() + assert result.get("output_gif") == str(output_file)