diff --git a/Common/GTesting/itkImageRandomCoordinateSamplerGTest.cxx b/Common/GTesting/itkImageRandomCoordinateSamplerGTest.cxx index 22d18a344..453e94bb9 100644 --- a/Common/GTesting/itkImageRandomCoordinateSamplerGTest.cxx +++ b/Common/GTesting/itkImageRandomCoordinateSamplerGTest.cxx @@ -30,64 +30,42 @@ using elx::CoreMainGTestUtilities::CreateImageFilledWithSequenceOfNaturalNumbers #include #include +#include -GTEST_TEST(ImageRandomCoordinateSampler, CheckImageValuesOfSamples) +GTEST_TEST(ImageRandomCoordinateSampler, HasSameOutputWhenUsingMultiThread) { using PixelType = int; - using ImageType = itk::Image; + using ImageType = itk::Image; using SamplerType = itk::ImageRandomCoordinateSampler; - // Use a fixed seed, in order to have a reproducible sampler output. - DerefSmartPointer(itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1); + const auto image = CreateImageFilledWithSequenceOfNaturalNumbers(ImageType::SizeType::Filled(100)); - const auto image = - CreateImageFilledWithSequenceOfNaturalNumbers(ImageType::SizeType::Filled(minimumImageSizeValue)); + const auto generateSamples = [image](const bool useMultiThread, const SamplerType::SeedIntegerType seed) { + elx::DefaultConstruct sampler{}; + sampler.SetUseMultiThread(useMultiThread); - elx::DefaultConstruct sampler{}; + // This sampler uses ITK's global MersenneTwisterRandomVariateGenerator instance. + DerefSmartPointer(itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(seed); - const size_t numberOfSamples{ 3 }; - sampler.SetNumberOfSamples(numberOfSamples); - sampler.SetInput(image); - sampler.Update(); + sampler.SetNumberOfSamples(sampler.GetNumberOfSamples() * 100); + sampler.SetInput(image); + using namespace std::chrono; + const auto timePoint = high_resolution_clock::now(); + sampler.Update(); + std::cout << sampler.GetNumberOfSamples() << " samples " + << " Duration: " << duration_cast>(high_resolution_clock::now() - timePoint).count() + << " seconds " << (useMultiThread ? "MultiThread" : "single threaded") << std::endl; + return std::move(DerefRawPointer(sampler.GetOutput()).CastToSTLContainer()); + }; - const auto & samples = DerefRawPointer(sampler.GetOutput()).CastToSTLConstContainer(); - - ASSERT_EQ(samples.size(), numberOfSamples); - - // The image values that appeared during the development of the test. - const std::array expectedImageValues = { 14.269278, - 14.93714, - 1.882026 }; - - for (size_t i{}; i < numberOfSamples; ++i) + for (SamplerType::SeedIntegerType seed{}; seed < 7; ++seed) { - EXPECT_FLOAT_EQ(samples[i].m_ImageValue, expectedImageValues[i]); - } -} - + const auto multiThreadedSamples = generateSamples(true, seed); + const auto singleThreadedSamples = generateSamples(false, seed); -GTEST_TEST(ImageRandomCoordinateSampler, SetSeedMakesRandomizationDeterministic) -{ - using PixelType = int; - using ImageType = itk::Image; - using SamplerType = itk::ImageRandomCoordinateSampler; - - const auto image = - CreateImageFilledWithSequenceOfNaturalNumbers(ImageType::SizeType::Filled(minimumImageSizeValue)); - - for (const SamplerType::SeedIntegerType seed : { 0, 1 }) - { - const auto generateSamples = [seed, image] { - elx::DefaultConstruct sampler{}; - - DerefSmartPointer(itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(seed); - sampler.SetInput(image); - sampler.Update(); - return std::move(DerefRawPointer(sampler.GetOutput()).CastToSTLContainer()); - }; - - // Do the same test twice, to check that the result remains the same. - EXPECT_EQ(generateSamples(), generateSamples()); + EXPECT_EQ(multiThreadedSamples, singleThreadedSamples); } + + std::cin.get(); } diff --git a/Common/ImageSamplers/itkImageRandomCoordinateSampler.h b/Common/ImageSamplers/itkImageRandomCoordinateSampler.h index 1492090fd..6655cd11b 100644 --- a/Common/ImageSamplers/itkImageRandomCoordinateSampler.h +++ b/Common/ImageSamplers/itkImageRandomCoordinateSampler.h @@ -23,6 +23,8 @@ #include "itkBSplineInterpolateImageFunction.h" #include "itkMersenneTwisterRandomVariateGenerator.h" +#include + namespace itk { @@ -119,13 +121,6 @@ class ITK_TEMPLATE_EXPORT ImageRandomCoordinateSampler : public ImageRandomSampl void GenerateData() override; - /** Multi-threaded functionality that does the work. */ - void - BeforeThreadedGenerateData() override; - - void - ThreadedGenerateData(const InputImageRegionType & inputRegionForThread, ThreadIdType threadId) override; - /** Generate a point randomly in a bounding box. */ virtual void GenerateRandomCoordinate(const InputImageContinuousIndexType & smallestContIndex, @@ -153,6 +148,31 @@ class ITK_TEMPLATE_EXPORT ImageRandomCoordinateSampler : public ImageRandomSampl InputImageContinuousIndexType & largestContIndex); private: + struct UserData + { + ITK_DISALLOW_COPY_AND_ASSIGN(UserData); + + UserData(const std::vector & randomNumberList, + std::vector & samples, + const InputImageType & inputImage, + const InterpolatorType & interpolator) + : m_RandomNumberList(randomNumberList) + , m_Samples(samples) + , m_InputImage(inputImage) + , m_Interpolator(interpolator) + {} + + const std::vector & m_RandomNumberList{}; + std::vector & m_Samples{}; + const InputImageType & m_InputImage{}; + const InterpolatorType & m_Interpolator{}; + }; + + std::optional m_OptionalUserData{}; + + static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION + ThreaderCallback(void * arg); + bool m_UseRandomSampleRegion{ false }; }; diff --git a/Common/ImageSamplers/itkImageRandomCoordinateSampler.hxx b/Common/ImageSamplers/itkImageRandomCoordinateSampler.hxx index 28ea6ff86..23bf27dfb 100644 --- a/Common/ImageSamplers/itkImageRandomCoordinateSampler.hxx +++ b/Common/ImageSamplers/itkImageRandomCoordinateSampler.hxx @@ -19,6 +19,7 @@ #define itkImageRandomCoordinateSampler_hxx #include "itkImageRandomCoordinateSampler.h" +#include "elxDeref.h" #include namespace itk @@ -32,14 +33,6 @@ template void ImageRandomCoordinateSampler::GenerateData() { - /** Get a handle to the mask. If there was no mask supplied we exercise a multi-threaded version. */ - typename MaskType::ConstPointer mask = this->GetMask(); - if (mask.IsNull() && this->m_UseMultiThread) - { - /** Calls ThreadedGenerateData(). */ - return Superclass::GenerateData(); - } - /** Get handles to the input image, output sample container, and interpolator. */ InputImageConstPointer inputImage = this->GetInput(); typename ImageSampleContainerType::Pointer sampleContainer = this->GetOutput(); @@ -61,6 +54,37 @@ ImageRandomCoordinateSampler::GenerateData() InputImageContinuousIndexType largestContIndex; this->GenerateSampleRegion(smallestImageContIndex, largestImageContIndex, smallestContIndex, largestContIndex); + /** Get a handle to the mask. If there was no mask supplied we exercise a multi-threaded version. */ + typename MaskType::ConstPointer mask = this->GetMask(); + if (mask.IsNull() && this->m_UseMultiThread) + { + auto & samples = elastix::Deref(sampleContainer).CastToSTLContainer(); + samples.resize(this->Superclass::m_NumberOfSamples); + + /** Clear the random number list. */ + this->m_RandomNumberList.clear(); + this->m_RandomNumberList.reserve(this->m_NumberOfSamples * InputImageDimension); + + /** Fill the list with random numbers. */ + for (unsigned long i = 0; i < this->m_NumberOfSamples; ++i) + { + InputImageContinuousIndexType randomCIndex; + + this->GenerateRandomCoordinate(smallestContIndex, largestContIndex, randomCIndex); + for (unsigned int j = 0; j < InputImageDimension; ++j) + { + this->m_RandomNumberList.push_back(randomCIndex[j]); + } + } + + m_OptionalUserData.emplace(this->Superclass::m_RandomNumberList, samples, *inputImage, *interpolator); + + MultiThreaderBase & multiThreader = elastix::Deref(this->ProcessObject::GetMultiThreader()); + multiThreader.SetSingleMethod(&Self::ThreaderCallback, &*m_OptionalUserData); + multiThreader.SingleMethodExecute(); + return; + } + /** Reserve memory for the output. */ sampleContainer->Reserve(this->GetNumberOfSamples()); @@ -139,113 +163,59 @@ ImageRandomCoordinateSampler::GenerateData() /** - * ******************* BeforeThreadedGenerateData ******************* + * ******************* ThreaderCallback ******************* */ template -void -ImageRandomCoordinateSampler::BeforeThreadedGenerateData() +ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION +ImageRandomCoordinateSampler::ThreaderCallback(void * const arg) { - /** Set up the interpolator. */ - typename InterpolatorType::Pointer interpolator = this->GetModifiableInterpolator(); - interpolator->SetInputImage(this->GetInput()); // only once per resolution? - - /** Clear the random number list. */ - this->m_RandomNumberList.clear(); - this->m_RandomNumberList.reserve(this->m_NumberOfSamples * InputImageDimension); - - const auto croppedInputImageRegion = this->GetCroppedInputImageRegion(); - - /** Convert inputImageRegion to bounding box in physical space. */ - InputImageSizeType unitSize; - unitSize.Fill(1); - InputImageIndexType smallestIndex = croppedInputImageRegion.GetIndex(); - InputImageIndexType largestIndex = smallestIndex + croppedInputImageRegion.GetSize() - unitSize; - InputImageContinuousIndexType smallestImageCIndex(smallestIndex); - InputImageContinuousIndexType largestImageCIndex(largestIndex); - InputImageContinuousIndexType smallestCIndex, largestCIndex, randomCIndex; - this->GenerateSampleRegion(smallestImageCIndex, largestImageCIndex, smallestCIndex, largestCIndex); + assert(arg); + const auto & info = *static_cast(arg); - /** Fill the list with random numbers. */ - for (unsigned long i = 0; i < this->m_NumberOfSamples; ++i) - { - this->GenerateRandomCoordinate(smallestCIndex, largestCIndex, randomCIndex); - for (unsigned int j = 0; j < InputImageDimension; ++j) - { - this->m_RandomNumberList.push_back(randomCIndex[j]); - } - } + assert(info.UserData); + auto & userData = *static_cast(info.UserData); - /** Initialize variables needed for threads. */ - this->m_ThreaderSampleContainer.clear(); - this->m_ThreaderSampleContainer.resize(this->GetNumberOfWorkUnits()); - for (std::size_t i = 0; i < this->GetNumberOfWorkUnits(); ++i) - { - this->m_ThreaderSampleContainer[i] = ImageSampleContainerType::New(); - } + const auto & randomNumberList = userData.m_RandomNumberList; + auto & samples = userData.m_Samples; + const auto & interpolator = userData.m_Interpolator; -} // end BeforeThreadedGenerateData() + const auto totalNumberOfSamples = samples.size(); + assert((totalNumberOfSamples * InputImageDimension) == randomNumberList.size()); + const auto numberOfSamplesPerWorkUnit = totalNumberOfSamples / info.NumberOfWorkUnits; + const auto remainderNumberOfSamples = totalNumberOfSamples % info.NumberOfWorkUnits; -/** - * ******************* ThreadedGenerateData ******************* - */ + const auto offset = + info.WorkUnitID * numberOfSamplesPerWorkUnit + std::min(info.WorkUnitID, remainderNumberOfSamples); + const auto beginOfRandomNumbers = randomNumberList.data() + InputImageDimension * offset; + const auto beginOfSamples = samples.data() + offset; -template -void -ImageRandomCoordinateSampler::ThreadedGenerateData(const InputImageRegionType &, ThreadIdType threadId) -{ - /** Sanity check. */ - typename MaskType::ConstPointer mask = this->GetMask(); - if (mask.IsNotNull()) - { - itkExceptionMacro("ERROR: do not call this function when a mask is supplied."); - } + const auto & inputImage = userData.m_InputImage; - /** Get handle to the input image. */ - InputImageConstPointer inputImage = this->GetInput(); + const size_t n{ numberOfSamplesPerWorkUnit + (info.WorkUnitID < remainderNumberOfSamples ? 1 : 0) }; - /** Figure out which samples to process. */ - unsigned long chunkSize = this->GetNumberOfSamples() / this->GetNumberOfWorkUnits(); - unsigned long sampleStart = threadId * chunkSize * InputImageDimension; - if (threadId == this->GetNumberOfWorkUnits() - 1) + for (size_t i = 0; i < n; ++i) { - chunkSize = this->GetNumberOfSamples() - ((this->GetNumberOfWorkUnits() - 1) * chunkSize); - } + auto & sample = beginOfSamples[i]; + InputImageContinuousIndexType sampleCIndex; - /** Get a reference to the output and reserve memory for it. */ - ImageSampleContainerPointer & sampleContainerThisThread // & ??? - = this->m_ThreaderSampleContainer[threadId]; - sampleContainerThisThread->Reserve(chunkSize); - - /** Setup an iterator over the sampleContainerThisThread. */ - typename ImageSampleContainerType::Iterator iter; - typename ImageSampleContainerType::ConstIterator end = sampleContainerThisThread->End(); - - /** Fill the local sample container. */ - InputImageContinuousIndexType sampleCIndex; - unsigned long sampleId = sampleStart; - for (iter = sampleContainerThisThread->Begin(); iter != end; ++iter) - { /** Create a random point out of InputImageDimension random numbers. */ - for (unsigned int j = 0; j < InputImageDimension; ++j, sampleId++) + for (unsigned int j = 0; j < InputImageDimension; ++j) { - sampleCIndex[j] = this->m_RandomNumberList[sampleId]; + sampleCIndex[j] = beginOfRandomNumbers[InputImageDimension * i + j]; } - /** Make a reference to the current sample in the container. */ - InputImagePointType & samplePoint = iter->Value().m_ImageCoordinates; - ImageSampleValueType & sampleValue = iter->Value().m_ImageValue; - /** Convert to point */ - inputImage->TransformContinuousIndexToPhysicalPoint(sampleCIndex, samplePoint); + inputImage.TransformContinuousIndexToPhysicalPoint(sampleCIndex, sample.m_ImageCoordinates); - /** Compute the value at the contindex. */ - sampleValue = static_cast(this->m_Interpolator->EvaluateAtContinuousIndex(sampleCIndex)); + /** Compute the value at the continuous index. */ + sample.m_ImageValue = static_cast(interpolator.EvaluateAtContinuousIndex(sampleCIndex)); } // end for loop -} // end ThreadedGenerateData() + return ITK_THREAD_RETURN_DEFAULT_VALUE; +} /**