Skip to content

Commit

Permalink
ENH: ImageRandomCoordinateSampler UseMultiThread use ITK's ThreadPool
Browse files Browse the repository at this point in the history
Internally calling `multiThreader.SingleMethodExecute()`, instead of overriding `ThreadedGenerateData`.

When 10'000 random samples are retrieved, a significant performance improvement is observed by enabling `UseMultiThread`: from ~48 milliseconds (single-threaded) to ~40 milliseconds (multi-threaded). When 100'000 random samples are retrieved, the improvement is even larger: from ~123 to ~56 milliseconds. Tested on a Windows 10 PC (6 cores, 12 logical processors), VS2019 Release build.

Follow-up to pull request #973 commit ca9e9ec "ENH: ImageRandomSampler `UseMultiThread` now using ITK's ThreadPool"
  • Loading branch information
N-Dekker committed Oct 26, 2023
1 parent 106076f commit 7ac934e
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 146 deletions.
72 changes: 25 additions & 47 deletions Common/GTesting/itkImageRandomCoordinateSamplerGTest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -30,64 +30,42 @@ using elx::CoreMainGTestUtilities::CreateImageFilledWithSequenceOfNaturalNumbers

#include <gtest/gtest.h>
#include <array>
#include <chrono>


GTEST_TEST(ImageRandomCoordinateSampler, CheckImageValuesOfSamples)
GTEST_TEST(ImageRandomCoordinateSampler, HasSameOutputWhenUsingMultiThread)
{
using PixelType = int;
using ImageType = itk::Image<PixelType>;
using ImageType = itk::Image<PixelType, 3>;
using SamplerType = itk::ImageRandomCoordinateSampler<ImageType>;

// Use a fixed seed, in order to have a reproducible sampler output.
DerefSmartPointer(itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1);
const auto image = CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(ImageType::SizeType::Filled(100));

const auto image =
CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(ImageType::SizeType::Filled(minimumImageSizeValue));
const auto generateSamples = [image](const bool useMultiThread, const SamplerType::SeedIntegerType seed) {
elx::DefaultConstruct<SamplerType> sampler{};
sampler.SetUseMultiThread(useMultiThread);

elx::DefaultConstruct<SamplerType> sampler{};
// This sampler uses ITK's global MersenneTwisterRandomVariateGenerator instance.
DerefSmartPointer(itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(seed);

const size_t numberOfSamples{ 3 };
sampler.SetNumberOfSamples(numberOfSamples);
sampler.SetInput(image);
sampler.Update();
sampler.SetNumberOfSamples(sampler.GetNumberOfSamples() * 100);
sampler.SetInput(image);
using namespace std::chrono;
const auto timePoint = high_resolution_clock::now();
sampler.Update();
std::cout << sampler.GetNumberOfSamples() << " samples "
<< " Duration: " << duration_cast<duration<double>>(high_resolution_clock::now() - timePoint).count()
<< " seconds " << (useMultiThread ? "MultiThread" : "single threaded") << std::endl;
return std::move(DerefRawPointer(sampler.GetOutput()).CastToSTLContainer());
};

const auto & samples = DerefRawPointer(sampler.GetOutput()).CastToSTLConstContainer();

ASSERT_EQ(samples.size(), numberOfSamples);

// The image values that appeared during the development of the test.
const std::array<SamplerType::ImageSampleType::RealType, numberOfSamples> expectedImageValues = { 14.269278,
14.93714,
1.882026 };

for (size_t i{}; i < numberOfSamples; ++i)
for (SamplerType::SeedIntegerType seed{}; seed < 7; ++seed)
{
EXPECT_FLOAT_EQ(samples[i].m_ImageValue, expectedImageValues[i]);
}
}

const auto multiThreadedSamples = generateSamples(true, seed);
const auto singleThreadedSamples = generateSamples(false, seed);

GTEST_TEST(ImageRandomCoordinateSampler, SetSeedMakesRandomizationDeterministic)
{
using PixelType = int;
using ImageType = itk::Image<PixelType>;
using SamplerType = itk::ImageRandomCoordinateSampler<ImageType>;

const auto image =
CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(ImageType::SizeType::Filled(minimumImageSizeValue));

for (const SamplerType::SeedIntegerType seed : { 0, 1 })
{
const auto generateSamples = [seed, image] {
elx::DefaultConstruct<SamplerType> sampler{};

DerefSmartPointer(itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(seed);
sampler.SetInput(image);
sampler.Update();
return std::move(DerefRawPointer(sampler.GetOutput()).CastToSTLContainer());
};

// Do the same test twice, to check that the result remains the same.
EXPECT_EQ(generateSamples(), generateSamples());
EXPECT_EQ(multiThreadedSamples, singleThreadedSamples);
}

std::cin.get();
}
34 changes: 27 additions & 7 deletions Common/ImageSamplers/itkImageRandomCoordinateSampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "itkBSplineInterpolateImageFunction.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"

#include <optional>

namespace itk
{

Expand Down Expand Up @@ -119,13 +121,6 @@ class ITK_TEMPLATE_EXPORT ImageRandomCoordinateSampler : public ImageRandomSampl
void
GenerateData() override;

/** Multi-threaded functionality that does the work. */
void
BeforeThreadedGenerateData() override;

void
ThreadedGenerateData(const InputImageRegionType & inputRegionForThread, ThreadIdType threadId) override;

/** Generate a point randomly in a bounding box. */
virtual void
GenerateRandomCoordinate(const InputImageContinuousIndexType & smallestContIndex,
Expand Down Expand Up @@ -153,6 +148,31 @@ class ITK_TEMPLATE_EXPORT ImageRandomCoordinateSampler : public ImageRandomSampl
InputImageContinuousIndexType & largestContIndex);

private:
/** Non-owning bundle of the objects that the multi-threaded code path works on.
 * All members are references, so nothing is copied; the referenced objects must
 * stay alive for as long as this struct is in use. Copying and assignment are
 * disabled. */
struct UserData
{
  ITK_DISALLOW_COPY_AND_ASSIGN(UserData);

  /** Binds each member reference to the corresponding argument.
   * \param randomNumberList Flat list of continuous index components
   *        (InputImageDimension values per sample), read-only.
   * \param samples Output samples, written by the work units.
   * \param inputImage Image used to convert continuous indices to physical points.
   * \param interpolator Interpolator used to evaluate the image value of each sample. */
  UserData(const std::vector<double> &    randomNumberList,
           std::vector<ImageSampleType> & samples,
           const InputImageType &         inputImage,
           const InterpolatorType &       interpolator)
    : m_RandomNumberList(randomNumberList)
    , m_Samples(samples)
    , m_InputImage(inputImage)
    , m_Interpolator(interpolator)
  {}

  // No default member initializers here: a reference cannot be initialized from
  // an empty brace list in a meaningful way, and the constructor above (the only
  // way to create a UserData) always binds all four references explicitly.
  const std::vector<double> &    m_RandomNumberList;
  std::vector<ImageSampleType> & m_Samples;
  const InputImageType &         m_InputImage;
  const InterpolatorType &       m_Interpolator;
};

/** Data for the multi-threaded code path. GenerateData() emplaces it when no mask
 * is supplied and UseMultiThread is enabled, and passes its address to the
 * multi-threader as the user data of the single method. */
std::optional<UserData> m_OptionalUserData{};

/** Callback that GenerateData() registers via MultiThreaderBase::SetSingleMethod.
 * `arg` is expected to point to a MultiThreaderBase::WorkUnitInfo whose UserData
 * field points to a UserData object. */
static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
ThreaderCallback(void * arg);

bool m_UseRandomSampleRegion{ false };
};

Expand Down
154 changes: 62 additions & 92 deletions Common/ImageSamplers/itkImageRandomCoordinateSampler.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define itkImageRandomCoordinateSampler_hxx

#include "itkImageRandomCoordinateSampler.h"
#include "elxDeref.h"
#include <vnl/vnl_math.h>

namespace itk
Expand All @@ -32,14 +33,6 @@ template <class TInputImage>
void
ImageRandomCoordinateSampler<TInputImage>::GenerateData()
{
/** Get a handle to the mask. If there was no mask supplied we exercise a multi-threaded version. */
typename MaskType::ConstPointer mask = this->GetMask();
if (mask.IsNull() && this->m_UseMultiThread)
{
/** Calls ThreadedGenerateData(). */
return Superclass::GenerateData();
}

/** Get handles to the input image, output sample container, and interpolator. */
InputImageConstPointer inputImage = this->GetInput();
typename ImageSampleContainerType::Pointer sampleContainer = this->GetOutput();
Expand All @@ -61,6 +54,37 @@ ImageRandomCoordinateSampler<TInputImage>::GenerateData()
InputImageContinuousIndexType largestContIndex;
this->GenerateSampleRegion(smallestImageContIndex, largestImageContIndex, smallestContIndex, largestContIndex);

/** Get a handle to the mask. If there was no mask supplied we exercise a multi-threaded version. */
typename MaskType::ConstPointer mask = this->GetMask();
if (mask.IsNull() && this->m_UseMultiThread)
{
auto & samples = elastix::Deref(sampleContainer).CastToSTLContainer();
samples.resize(this->Superclass::m_NumberOfSamples);

/** Clear the random number list. */
this->m_RandomNumberList.clear();
this->m_RandomNumberList.reserve(this->m_NumberOfSamples * InputImageDimension);

/** Fill the list with random numbers. */
for (unsigned long i = 0; i < this->m_NumberOfSamples; ++i)
{
InputImageContinuousIndexType randomCIndex;

this->GenerateRandomCoordinate(smallestContIndex, largestContIndex, randomCIndex);
for (unsigned int j = 0; j < InputImageDimension; ++j)
{
this->m_RandomNumberList.push_back(randomCIndex[j]);
}
}

m_OptionalUserData.emplace(this->Superclass::m_RandomNumberList, samples, *inputImage, *interpolator);

MultiThreaderBase & multiThreader = elastix::Deref(this->ProcessObject::GetMultiThreader());
multiThreader.SetSingleMethod(&Self::ThreaderCallback, &*m_OptionalUserData);
multiThreader.SingleMethodExecute();
return;
}

/** Reserve memory for the output. */
sampleContainer->Reserve(this->GetNumberOfSamples());

Expand Down Expand Up @@ -139,113 +163,59 @@ ImageRandomCoordinateSampler<TInputImage>::GenerateData()


/**
* ******************* BeforeThreadedGenerateData *******************
* ******************* ThreaderCallback *******************
*/

template <class TInputImage>
void
ImageRandomCoordinateSampler<TInputImage>::BeforeThreadedGenerateData()
ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
ImageRandomCoordinateSampler<TInputImage>::ThreaderCallback(void * const arg)
{
/** Set up the interpolator. */
typename InterpolatorType::Pointer interpolator = this->GetModifiableInterpolator();
interpolator->SetInputImage(this->GetInput()); // only once per resolution?

/** Clear the random number list. */
this->m_RandomNumberList.clear();
this->m_RandomNumberList.reserve(this->m_NumberOfSamples * InputImageDimension);

const auto croppedInputImageRegion = this->GetCroppedInputImageRegion();

/** Convert inputImageRegion to bounding box in physical space. */
InputImageSizeType unitSize;
unitSize.Fill(1);
InputImageIndexType smallestIndex = croppedInputImageRegion.GetIndex();
InputImageIndexType largestIndex = smallestIndex + croppedInputImageRegion.GetSize() - unitSize;
InputImageContinuousIndexType smallestImageCIndex(smallestIndex);
InputImageContinuousIndexType largestImageCIndex(largestIndex);
InputImageContinuousIndexType smallestCIndex, largestCIndex, randomCIndex;
this->GenerateSampleRegion(smallestImageCIndex, largestImageCIndex, smallestCIndex, largestCIndex);
assert(arg);
const auto & info = *static_cast<const MultiThreaderBase::WorkUnitInfo *>(arg);

/** Fill the list with random numbers. */
for (unsigned long i = 0; i < this->m_NumberOfSamples; ++i)
{
this->GenerateRandomCoordinate(smallestCIndex, largestCIndex, randomCIndex);
for (unsigned int j = 0; j < InputImageDimension; ++j)
{
this->m_RandomNumberList.push_back(randomCIndex[j]);
}
}
assert(info.UserData);
auto & userData = *static_cast<UserData *>(info.UserData);

/** Initialize variables needed for threads. */
this->m_ThreaderSampleContainer.clear();
this->m_ThreaderSampleContainer.resize(this->GetNumberOfWorkUnits());
for (std::size_t i = 0; i < this->GetNumberOfWorkUnits(); ++i)
{
this->m_ThreaderSampleContainer[i] = ImageSampleContainerType::New();
}
const auto & randomNumberList = userData.m_RandomNumberList;
auto & samples = userData.m_Samples;
const auto & interpolator = userData.m_Interpolator;

} // end BeforeThreadedGenerateData()
const auto totalNumberOfSamples = samples.size();
assert((totalNumberOfSamples * InputImageDimension) == randomNumberList.size());

const auto numberOfSamplesPerWorkUnit = totalNumberOfSamples / info.NumberOfWorkUnits;
const auto remainderNumberOfSamples = totalNumberOfSamples % info.NumberOfWorkUnits;

/**
* ******************* ThreadedGenerateData *******************
*/
const auto offset =
info.WorkUnitID * numberOfSamplesPerWorkUnit + std::min<size_t>(info.WorkUnitID, remainderNumberOfSamples);
const auto beginOfRandomNumbers = randomNumberList.data() + InputImageDimension * offset;
const auto beginOfSamples = samples.data() + offset;

template <class TInputImage>
void
ImageRandomCoordinateSampler<TInputImage>::ThreadedGenerateData(const InputImageRegionType &, ThreadIdType threadId)
{
/** Sanity check. */
typename MaskType::ConstPointer mask = this->GetMask();
if (mask.IsNotNull())
{
itkExceptionMacro("ERROR: do not call this function when a mask is supplied.");
}
const auto & inputImage = userData.m_InputImage;

/** Get handle to the input image. */
InputImageConstPointer inputImage = this->GetInput();
const size_t n{ numberOfSamplesPerWorkUnit + (info.WorkUnitID < remainderNumberOfSamples ? 1 : 0) };

/** Figure out which samples to process. */
unsigned long chunkSize = this->GetNumberOfSamples() / this->GetNumberOfWorkUnits();
unsigned long sampleStart = threadId * chunkSize * InputImageDimension;
if (threadId == this->GetNumberOfWorkUnits() - 1)
for (size_t i = 0; i < n; ++i)
{
chunkSize = this->GetNumberOfSamples() - ((this->GetNumberOfWorkUnits() - 1) * chunkSize);
}
auto & sample = beginOfSamples[i];
InputImageContinuousIndexType sampleCIndex;

/** Get a reference to the output and reserve memory for it. */
ImageSampleContainerPointer & sampleContainerThisThread // & ???
= this->m_ThreaderSampleContainer[threadId];
sampleContainerThisThread->Reserve(chunkSize);

/** Setup an iterator over the sampleContainerThisThread. */
typename ImageSampleContainerType::Iterator iter;
typename ImageSampleContainerType::ConstIterator end = sampleContainerThisThread->End();

/** Fill the local sample container. */
InputImageContinuousIndexType sampleCIndex;
unsigned long sampleId = sampleStart;
for (iter = sampleContainerThisThread->Begin(); iter != end; ++iter)
{
/** Create a random point out of InputImageDimension random numbers. */
for (unsigned int j = 0; j < InputImageDimension; ++j, sampleId++)
for (unsigned int j = 0; j < InputImageDimension; ++j)
{
sampleCIndex[j] = this->m_RandomNumberList[sampleId];
sampleCIndex[j] = beginOfRandomNumbers[InputImageDimension * i + j];
}

/** Make a reference to the current sample in the container. */
InputImagePointType & samplePoint = iter->Value().m_ImageCoordinates;
ImageSampleValueType & sampleValue = iter->Value().m_ImageValue;

/** Convert to point */
inputImage->TransformContinuousIndexToPhysicalPoint(sampleCIndex, samplePoint);
inputImage.TransformContinuousIndexToPhysicalPoint(sampleCIndex, sample.m_ImageCoordinates);

/** Compute the value at the contindex. */
sampleValue = static_cast<ImageSampleValueType>(this->m_Interpolator->EvaluateAtContinuousIndex(sampleCIndex));
/** Compute the value at the continuous index. */
sample.m_ImageValue = static_cast<ImageSampleValueType>(interpolator.EvaluateAtContinuousIndex(sampleCIndex));

} // end for loop

} // end ThreadedGenerateData()
return ITK_THREAD_RETURN_DEFAULT_VALUE;
}


/**
Expand Down

0 comments on commit 7ac934e

Please sign in to comment.