fix: replace set_track_meta with F.interpolate in RandSimulateLowResolution (#8409)#8837
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughRandSimulateLowResolution no longer uses MONAI Resize transforms; it converts the input with Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
6df6570 to
b15b2a2
Compare
b15b2a2 to
b991fcb
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/transforms/spatial/array.py`:
- Around line 3577-3585: The downsampling call to
torch.nn.functional.interpolate does not pass align_corners even though
align_corners should apply to both downsample_mode and upsample_mode; update the
img_downsampled interpolation call (the torch.nn.functional.interpolate
invocation that produces img_downsampled from img_float/target_shape) to include
align_corners when downsample_mode is one of the valid modes (use the same
_align_corners_modes check used for upsample_align_corners), e.g., compute
downsample_align_corners = self.align_corners if downsample_mode in
_align_corners_modes else None and pass downsample_align_corners as the
align_corners argument to the interpolate call so behavior matches the
docstring.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 836eec9b-957b-433c-80e6-909cce2aa688
📒 Files selected for processing (2)
monai/transforms/spatial/array.pytests/transforms/test_rand_simulate_low_resolution.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/transforms/test_rand_simulate_low_resolution.py
26f27c3 to
067ed70
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/transforms/spatial/array.py`:
- Around line 3592-3595: The code always wraps the upsampled array into a
MetaTensor and copies metadata (creating img_upsampled = MetaTensor(...);
img_upsampled.copy_meta_from(img)), which ignores get_track_meta() and breaks
callers that disabled meta tracking; change the rebuild logic to check
get_track_meta() and only construct and copy into a MetaTensor when
get_track_meta() is True, otherwise return a plain ndarray (img_upsampled_t)
without metadata, and add a small regression test that disables meta tracking to
ensure the returned object is not a MetaTensor.
In `@tests/transforms/test_rand_simulate_low_resolution.py`:
- Around line 95-117: The test is nondeterministic because each worker records
its baseline at different times; make it deterministic by synchronizing thread
start with a threading.Barrier: create barrier = threading.Barrier(len(threads)
+ 1), have run_transform call barrier.wait() first, then capture before =
get_track_meta(), run tfm(img) and check get_track_meta(), and have the main
test call barrier.wait() after starting all threads to release them
simultaneously; this ensures all workers record the same global baseline and
reliably reproduces the race.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 9ed93f83-2a6c-4bdc-a735-c955ea71e7c8
📒 Files selected for processing (2)
monai/transforms/spatial/array.pytests/transforms/test_rand_simulate_low_resolution.py
4d1b850 to
937a516
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/transforms/spatial/array.py`:
- Line 3568: The computed target_shape can become zero when computing
tuple(np.round(np.array(input_shape) *
self.zoom_factor).astype(np.int_).tolist()), causing interpolate to fail; update
the target_shape calculation in the same block (where target_shape,
self.zoom_factor and input_shape are used) to clamp each axis to at least 1
(e.g., replace the raw rounded/casted values with max(1, value) per axis) so no
spatial dimension is zero before calling torch.nn.functional.interpolate.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 36dda1e2-4b10-45cf-b1e4-aaee25cc8a62
📒 Files selected for processing (2)
monai/transforms/spatial/array.pytests/transforms/test_rand_simulate_low_resolution.py
9645a81 to
3f50313
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@monai/transforms/spatial/array.py`:
- Around line 3577-3589: The interpolation mode choices (downsample_mode /
upsample_mode) can be incompatible with the spatial rank of the tensor and cause
F.interpolate to raise NotImplementedError; update the logic in the block around
downsample_mode, upsample_mode and where img_downsampled/img_upsampled_t are
created to map/validate modes based on the spatial rank (len(input_shape)): use
"linear" for 1D, "bilinear" for 2D and "trilinear" for 3D (or fall back to an
appropriate supported mode), and ensure align_corners logic still only applies
to the supported modes; implement this mapping before calling
torch.nn.functional.interpolate so the mode strings passed to F.interpolate are
valid for the input rank.
In `@tests/transforms/test_rand_simulate_low_resolution.py`:
- Around line 105-123: The test currently only verifies track_meta after the
50-iteration loop and swallows all exceptions, allowing transient races or
hidden failures to escape detection; inside run_transform (where tfm =
RandSimulateLowResolution(...) and the loop for _ in range(50) calls tfm(img)),
move the check get_track_meta() != expected_track_meta to immediately after each
tfm(img) call so each invocation is validated, and replace the blanket except
Exception with capturing the actual exception (and optional traceback)
per-iteration and appending that error to errors so the original worker failure
is preserved and reported.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a3bfbc51-eb7f-49fe-8368-d893d0f7a030
📒 Files selected for processing (2)
monai/transforms/spatial/array.pytests/transforms/test_rand_simulate_low_resolution.py
…lution (Project-MONAI#8409) Signed-off-by: chhayankjain <chhayank44@gmail.com>
ec57803 to
0860882
Compare
for more information, see https://pre-commit.ci
Fixes #8409
Description
RandSimulateLowResolutioninternally performs a downsample → upsample cycle using twoResizetransforms. To prevent these from being recorded in the invertible transformstack, the previous implementation temporarily toggled the global
set_track_meta(False)flag and restored it afterward.
This is not thread-safe: in multi-threaded data loading (e.g.
ThreadDataLoader), anotherthread calling
get_track_meta()between the toggle and the restore would silently receivethe wrong value, causing incorrect metadata tracking behaviour.
Fix: replace the
Resizetransforms with directtorch.nn.functional.interpolatecalls ona plain tensor obtained via
convert_to_tensor(img, track_meta=False). This avoids anyglobal state mutation entirely. Output dtype (float32) and metadata-copy behaviour are
preserved from the original implementation.
set_track_metais also removed from the importsince it is no longer used anywhere in the file.
Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.