-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Enable global coordinates in spatial crop transforms #8794
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
dd737f8
6484465
91f40d0
d0b2c47
7f7a218
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -19,7 +19,7 @@ | |||||||||||||
|
|
||||||||||||||
| from collections.abc import Callable, Hashable, Mapping, Sequence | ||||||||||||||
| from copy import deepcopy | ||||||||||||||
| from typing import Any | ||||||||||||||
| from typing import Any, Optional, Union, cast | ||||||||||||||
|
|
||||||||||||||
| import numpy as np | ||||||||||||||
| import torch | ||||||||||||||
|
|
@@ -50,7 +50,7 @@ | |||||||||||||
| from monai.transforms.traits import LazyTrait, MultiSampleTrait | ||||||||||||||
| from monai.transforms.transform import LazyTransform, MapTransform, Randomizable | ||||||||||||||
| from monai.transforms.utils import is_positive | ||||||||||||||
| from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep | ||||||||||||||
| from monai.utils import MAX_SEED, Method, PytorchPadMode, TraceKeys, ensure_tuple_rep | ||||||||||||||
|
|
||||||||||||||
| __all__ = [ | ||||||||||||||
| "Padd", | ||||||||||||||
|
|
@@ -431,17 +431,33 @@ class SpatialCropd(Cropd): | |||||||||||||
| - a spatial center and size | ||||||||||||||
| - the start and end coordinates of the ROI | ||||||||||||||
|
|
||||||||||||||
| ROI parameters (``roi_center``, ``roi_size``, ``roi_start``, ``roi_end``) can also be specified as | ||||||||||||||
| string dictionary keys. When a string is provided, the actual coordinate values are read from the | ||||||||||||||
| data dictionary at call time. This enables pipelines where coordinates are computed by earlier | ||||||||||||||
| transforms (e.g., :py:class:`monai.transforms.TransformPointsWorldToImaged`) and stored in the | ||||||||||||||
| data dictionary under the given key. | ||||||||||||||
|
|
||||||||||||||
| Example:: | ||||||||||||||
|
|
||||||||||||||
| from monai.transforms import Compose, TransformPointsWorldToImaged, SpatialCropd | ||||||||||||||
|
|
||||||||||||||
| pipeline = Compose([ | ||||||||||||||
| TransformPointsWorldToImaged(keys="roi_start", refer_keys="image"), | ||||||||||||||
| TransformPointsWorldToImaged(keys="roi_end", refer_keys="image"), | ||||||||||||||
| SpatialCropd(keys="image", roi_start="roi_start", roi_end="roi_end"), | ||||||||||||||
| ]) | ||||||||||||||
|
|
||||||||||||||
| This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>` | ||||||||||||||
| for more information. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| def __init__( | ||||||||||||||
| self, | ||||||||||||||
| keys: KeysCollection, | ||||||||||||||
| roi_center: Sequence[int] | int | None = None, | ||||||||||||||
| roi_size: Sequence[int] | int | None = None, | ||||||||||||||
| roi_start: Sequence[int] | int | None = None, | ||||||||||||||
| roi_end: Sequence[int] | int | None = None, | ||||||||||||||
| roi_center: Sequence[int] | int | str | None = None, | ||||||||||||||
| roi_size: Sequence[int] | int | str | None = None, | ||||||||||||||
| roi_start: Sequence[int] | int | str | None = None, | ||||||||||||||
| roi_end: Sequence[int] | int | str | None = None, | ||||||||||||||
| roi_slices: Sequence[slice] | None = None, | ||||||||||||||
| allow_missing_keys: bool = False, | ||||||||||||||
| lazy: bool = False, | ||||||||||||||
|
|
@@ -450,19 +466,134 @@ def __init__( | |||||||||||||
| Args: | ||||||||||||||
| keys: keys of the corresponding items to be transformed. | ||||||||||||||
| See also: :py:class:`monai.transforms.compose.MapTransform` | ||||||||||||||
| roi_center: voxel coordinates for center of the crop ROI. | ||||||||||||||
| roi_center: voxel coordinates for center of the crop ROI, or a string key to look up | ||||||||||||||
| the coordinates from the data dictionary. | ||||||||||||||
| roi_size: size of the crop ROI, if a dimension of ROI size is larger than image size, | ||||||||||||||
| will not crop that dimension of the image. | ||||||||||||||
| roi_start: voxel coordinates for start of the crop ROI. | ||||||||||||||
| will not crop that dimension of the image. Can also be a string key. | ||||||||||||||
| roi_start: voxel coordinates for start of the crop ROI, or a string key to look up | ||||||||||||||
| the coordinates from the data dictionary. | ||||||||||||||
| roi_end: voxel coordinates for end of the crop ROI, if a coordinate is out of image, | ||||||||||||||
| use the end coordinate of image. | ||||||||||||||
| use the end coordinate of image. Can also be a string key. | ||||||||||||||
| roi_slices: list of slices for each of the spatial dimensions. | ||||||||||||||
| allow_missing_keys: don't raise exception if key is missing. | ||||||||||||||
| lazy: a flag to indicate whether this transform should execute lazily or not. Defaults to False. | ||||||||||||||
| """ | ||||||||||||||
| cropper = SpatialCrop(roi_center, roi_size, roi_start, roi_end, roi_slices, lazy=lazy) | ||||||||||||||
| self._roi_center = roi_center | ||||||||||||||
| self._roi_size = roi_size | ||||||||||||||
| self._roi_start = roi_start | ||||||||||||||
| self._roi_end = roi_end | ||||||||||||||
| self._roi_slices = roi_slices | ||||||||||||||
| self._has_str_roi = any(isinstance(v, str) for v in [roi_center, roi_size, roi_start, roi_end]) | ||||||||||||||
|
|
||||||||||||||
| if not self._has_str_roi: | ||||||||||||||
| _roi_t = Optional[Union[Sequence[int], int]] | ||||||||||||||
| cropper = SpatialCrop( | ||||||||||||||
| cast(_roi_t, roi_center), | ||||||||||||||
| cast(_roi_t, roi_size), | ||||||||||||||
| cast(_roi_t, roi_start), | ||||||||||||||
| cast(_roi_t, roi_end), | ||||||||||||||
| roi_slices, | ||||||||||||||
| lazy=lazy, | ||||||||||||||
| ) | ||||||||||||||
| else: | ||||||||||||||
| # Placeholder cropper for the string-key path. Replaced on self.cropper at | ||||||||||||||
| # __call__ time once string keys are resolved from the data dictionary. | ||||||||||||||
| cropper = SpatialCrop(roi_start=[0], roi_end=[1], lazy=lazy) | ||||||||||||||
| super().__init__(keys, cropper=cropper, allow_missing_keys=allow_missing_keys, lazy=lazy) | ||||||||||||||
|
|
||||||||||||||
| @staticmethod | ||||||||||||||
| def _resolve_roi_param(val, d): | ||||||||||||||
| """Resolve an ROI parameter from the data dictionary if it is a string key. | ||||||||||||||
|
|
||||||||||||||
| Args: | ||||||||||||||
| val: the ROI parameter value. If a string, it is used as a key to look up | ||||||||||||||
| the actual value from ``d``. Otherwise returned as-is. | ||||||||||||||
| d: the data dictionary. | ||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| The resolved ROI parameter. Tensors and numpy arrays are flattened to 1-D | ||||||||||||||
| and rounded to int64 so they can be consumed by ``Crop.compute_slices``. | ||||||||||||||
|
|
||||||||||||||
| Raises: | ||||||||||||||
| KeyError: if ``val`` is a string key that does not exist in ``d``. | ||||||||||||||
| """ | ||||||||||||||
| if not isinstance(val, str): | ||||||||||||||
| return val | ||||||||||||||
| if val not in d: | ||||||||||||||
| raise KeyError(f"ROI key '{val}' not found in the data dictionary.") | ||||||||||||||
| resolved = d[val] | ||||||||||||||
| # ApplyTransformToPoints outputs tensors of shape (C, N, dims). | ||||||||||||||
| # A single coordinate like [142.5, -67.3, 301.8] becomes shape (1, 1, 3). | ||||||||||||||
| # Flatten to 1-D and round to integers for compute_slices. | ||||||||||||||
| # Uses banker's rounding (torch.round) to avoid systematic bias in spatial coordinates. | ||||||||||||||
| if isinstance(resolved, np.ndarray): | ||||||||||||||
| resolved = torch.from_numpy(resolved) | ||||||||||||||
| if isinstance(resolved, torch.Tensor): | ||||||||||||||
| resolved = torch.round(resolved.flatten()).to(torch.int64) | ||||||||||||||
| return resolved | ||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||
|
|
||||||||||||||
| @property | ||||||||||||||
| def requires_current_data(self): | ||||||||||||||
| """bool: Whether this transform requires the current data dictionary to resolve ROI parameters.""" | ||||||||||||||
| return self._has_str_roi | ||||||||||||||
|
Comment on lines
+536
to
+538
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
|
|
||||||||||||||
| def __call__(self, data: Mapping[Hashable, torch.Tensor], lazy: bool | None = None) -> dict[Hashable, torch.Tensor]: | ||||||||||||||
| """ | ||||||||||||||
| Args: | ||||||||||||||
| data: dictionary of data items to be transformed. | ||||||||||||||
| lazy: whether to execute lazily. If ``None``, uses the instance default. | ||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| Dictionary with cropped data for each key. | ||||||||||||||
| """ | ||||||||||||||
| if not self._has_str_roi: | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| return super().__call__(data, lazy=lazy) | ||||||||||||||
|
|
||||||||||||||
| d = dict(data) | ||||||||||||||
| roi_center = self._resolve_roi_param(self._roi_center, d) | ||||||||||||||
| roi_size = self._resolve_roi_param(self._roi_size, d) | ||||||||||||||
| roi_start = self._resolve_roi_param(self._roi_start, d) | ||||||||||||||
| roi_end = self._resolve_roi_param(self._roi_end, d) | ||||||||||||||
|
|
||||||||||||||
| lazy_ = self.lazy if lazy is None else lazy | ||||||||||||||
| self.cropper = SpatialCrop( | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you're creating a cropper for each invocation, this will not be thread-safe if it is stored in a member like this. I would just use a local variable instead.
||||||||||||||
| roi_center=roi_center, | ||||||||||||||
| roi_size=roi_size, | ||||||||||||||
| roi_start=roi_start, | ||||||||||||||
| roi_end=roi_end, | ||||||||||||||
| roi_slices=self._roi_slices, | ||||||||||||||
| lazy=lazy_, | ||||||||||||||
| ) | ||||||||||||||
| for key in self.key_iterator(d): | ||||||||||||||
| d[key] = self.cropper(d[key], lazy=lazy_) | ||||||||||||||
| return d | ||||||||||||||
|
|
||||||||||||||
| def inverse(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor]: | ||||||||||||||
| """ | ||||||||||||||
| Inverse of the crop transform, restoring the original spatial dimensions via padding. | ||||||||||||||
|
|
||||||||||||||
| For the string-key path, ``self.cropper`` is recreated on each ``__call__``, so its | ||||||||||||||
| ``id()`` won't match the one stored in the MetaTensor's transform stack. This override | ||||||||||||||
| bypasses the ID check and applies the inverse directly using the crop info stored in the | ||||||||||||||
| MetaTensor. | ||||||||||||||
|
|
||||||||||||||
| Args: | ||||||||||||||
| data: dictionary of cropped ``MetaTensor`` items. | ||||||||||||||
|
|
||||||||||||||
| Returns: | ||||||||||||||
| Dictionary with inverse-transformed (padded) data for each key. | ||||||||||||||
| """ | ||||||||||||||
| if not self._has_str_roi: | ||||||||||||||
| return super().inverse(data) | ||||||||||||||
| d = dict(data) | ||||||||||||||
| for key in self.key_iterator(d): | ||||||||||||||
| transform = self.cropper.pop_transform(d[key], check=False) | ||||||||||||||
| cropped = transform[TraceKeys.EXTRA_INFO]["cropped"] | ||||||||||||||
| inverse_transform = BorderPad(cropped) | ||||||||||||||
| with inverse_transform.trace_transform(False): | ||||||||||||||
| d[key] = inverse_transform(d[key]) # type: ignore[assignment] | ||||||||||||||
| return d | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| class CenterSpatialCropd(Cropd): | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was using this and `cast` necessary? I would expect mypy to pick up the type of `roi_*` correctly. This might be a Python version issue; you should have been able to use `Sequence[int] | int | str | None` as the value here.