diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index f0c1d1949d..6a76f49bfc 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -300,7 +300,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep ) if self.ensure_channel_first: - img = EnsureChannelFirst()(img) + img = EnsureChannelFirst()(img, meta_dict=meta_data) if self.image_only: return img return img, img.meta if isinstance(img, MetaTensor) else meta_data diff --git a/tests/transforms/test_load_image.py b/tests/transforms/test_load_image.py index 031e38272e..75fb577ab2 100644 --- a/tests/transforms/test_load_image.py +++ b/tests/transforms/test_load_image.py @@ -497,6 +497,16 @@ def test_correct(self, input_param, expected_shape, track_meta): self.assertNotIsInstance(r, MetaTensor) self.assertFalse(hasattr(r, "affine")) + def test_track_meta_false_ensure_channel_first(self): + try: + set_track_meta(False) + r = LoadImage(image_only=True, ensure_channel_first=True)(self.test_data) + self.assertTupleEqual(r.shape, (1, 128, 128, 128)) + self.assertIsInstance(r, torch.Tensor) + self.assertNotIsInstance(r, MetaTensor) + finally: + set_track_meta(True) + if __name__ == "__main__": unittest.main()