Fix determinism in augmentation#13
Conversation
|
The related PR in nnunet is here :) |
|
Hey, thanks for this work! One question from my side: This PR is based on the assumption that the torch RNG is not properly seeded in a multiprocessing environment. Why is it not possible to just seed all used RNGs to obtain deterministic augmentations. I would like to avoid bouncing back and forth between torch and numpy all the time. The few times that we are doing this already bug be quite a lot. |
Hey, following the comment of @FabianIsensee - I've opened a new PR in |
Hi MIC-DKFZ team,
This PR fixes an issue that was preventing fully deterministic behavior in the augmentation pipeline, even when a seed was provided. I noticed this behavior when trying to train a nnunet model deterministically, I'll submit a PR for the nnunet side too, hoping that can help !
Problem
The main issue was that several transforms used PyTorch's random functions (
torch.rand,torch.normal). In a multi-process environment likeMultiThreadedAugmenter(like in nnunet), the numpy RNG is correctly seeded in each worker, but the PyTorch RNG is not. This led to unpredictable results from any transform usingtorchfor randomness.I also found a couple of related issues: the
benchmark=Trueparameter inGaussianBlurTransformis inherently non-deterministic, andSpatialTransformhad some unstable randomness from mixing torch and numpy operations.Solution
The fix was to go through the library and make sure all random operations rely on NumPy's random generator. This way, everything is controlled by the single RNG that's properly seeded in the data loader's workers.
Here are the transforms that were updated:
RandomTransformSpatialTransform(for elastic deform)MirrorTransformSimulateLowResolutionTransformGaussianBlurTransformGaussianNoiseTransformMultiplicativeBrightnessTransformContrastTransformGammaTransformRicianNoiseTransformInvertImageTransformRemoveRandomConnectedComponentFromOneHotEncodingTransformApplyRandomBinaryOperatorTransformHow It's Tested
To make sure these fixes work and to catch any future regressions, I've added a new testing script,
determinism_test_pipeline.py.The script tests every transform in the library for both 2D and 3D data. The key part of the test is how it checks for determinism: it runs each transform twice, but for the second run, it only re-seeds the NumPy RNG. This aims to mimic the multi-worker environment. The script demonstrates the non determinism on the original code.
With these changes, the whole library now passes this test, so I believe we can be confident that augmentation pipelines are fully reproducible. This change should allow for fully reproducible training pipelines, which is a big deal for research. The performance impact should be minimal, and might even be a little better since some inefficient operations and benchmarking overhead were removed.
Thanks for maintaining this great library. Hope this helps, and let me know what you think!
Post Scriptum
Original image :


Augmented image :
Augmented image :
