Skip to content

Commit

Permalink
ENH: Improve ImageRandomSamplerSparseMask UseMultiThread support
Browse files Browse the repository at this point in the history
ImageRandomSamplerSparseMask now forwards `UseMultiThread` to its internal FullSampler, _and_ uses its own ThreaderCallback, instead of the old `ThreadedGenerateData` mechanism.

Follow-up to:

- pull request #973 commit ca9e9ec (ImageRandomSampler)
- pull request #978 commit 1ad43ea (ImageRandomCoordinateSampler)
  • Loading branch information
N-Dekker committed Nov 14, 2023
1 parent 92145c1 commit 61db741
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 71 deletions.
32 changes: 32 additions & 0 deletions Common/GTesting/itkImageRandomSamplerSparseMaskGTest.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,35 @@ GTEST_TEST(ImageRandomSamplerSparseMask, SetSeedMakesRandomizationDeterministic)
EXPECT_EQ(generateSamples(), samples);
}
}


GTEST_TEST(ImageRandomSamplerSparseMask, HasSameOutputWhenUsingMultiThread)
{
using PixelType = int;
constexpr auto Dimension = 2;
using ImageType = itk::Image<PixelType, Dimension>;
using SamplerType = itk::ImageRandomSamplerSparseMask<ImageType>;
using MaskSpatialObjectType = itk::ImageMaskSpatialObject<Dimension>;

const auto image =
CreateImageFilledWithSequenceOfNaturalNumbers<PixelType>(ImageType::SizeType::Filled(minimumImageSizeValue));

const auto maskImage = CreateImage<MaskSpatialObjectType::PixelType>(ImageDomain(*image));
FillImageRegion(*maskImage, itk::Index<Dimension>::Filled(1), ImageType::SizeType::Filled(minimumImageSizeValue - 1));

const auto maskSpatialObject = MaskSpatialObjectType::New();
maskSpatialObject->SetImage(maskImage);
maskSpatialObject->Update();

const auto generateSamples = [image, maskSpatialObject](const bool useMultiThread) {
DerefSmartPointer(MersenneTwisterRandomVariateGenerator::GetInstance()).SetSeed(1);
elx::DefaultConstruct<SamplerType> sampler{};
sampler.SetUseMultiThread(useMultiThread);
sampler.SetInput(image);
sampler.SetMask(maskSpatialObject);
sampler.Update();
return std::move(DerefRawPointer(sampler.GetOutput()).CastToSTLContainer());
};

EXPECT_EQ(generateSamples(true), generateSamples(false));
}
32 changes: 25 additions & 7 deletions Common/ImageSamplers/itkImageRandomSamplerSparseMask.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "itkImageRandomSamplerBase.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
#include "itkImageFullSampler.h"
#include <optional>

namespace itk
{
Expand Down Expand Up @@ -94,15 +95,32 @@ class ITK_TEMPLATE_EXPORT ImageRandomSamplerSparseMask : public ImageRandomSampl
void
GenerateData() override;

/** Multi-threaded functionality that does the work. */
void
BeforeThreadedGenerateData() override;

void
ThreadedGenerateData(const InputImageRegionType & inputRegionForThread, ThreadIdType threadId) override;

RandomGeneratorPointer m_RandomGenerator{ RandomGeneratorType::GetInstance() };
InternalFullSamplerPointer m_InternalFullSampler{ InternalFullSamplerType::New() };

private:
struct UserData
{
ITK_DISALLOW_COPY_AND_MOVE(UserData);

UserData(const std::vector<ImageSampleType> & allValidSamples,
const std::vector<size_t> & randomIndices,
std::vector<ImageSampleType> & samples)
: m_AllValidSamples(allValidSamples)
, m_RandomIndices(randomIndices)
, m_Samples(samples)
{}

const std::vector<ImageSampleType> & m_AllValidSamples;
const std::vector<size_t> & m_RandomIndices;
std::vector<ImageSampleType> & m_Samples;
};

std::vector<size_t> m_RandomIndices{};
std::optional<UserData> m_OptionalUserData{};

static ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
ThreaderCallback(void * arg);
};

} // end namespace itk
Expand Down
110 changes: 46 additions & 64 deletions Common/ImageSamplers/itkImageRandomSamplerSparseMask.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#define itkImageRandomSamplerSparseMask_hxx

#include "itkImageRandomSamplerSparseMask.h"
#include "elxDeref.h"

namespace itk
{
Expand Down Expand Up @@ -53,6 +54,7 @@ ImageRandomSamplerSparseMask<TInputImage>::GenerateData()
this->m_InternalFullSampler->SetInput(inputImage);
this->m_InternalFullSampler->SetMask(mask);
this->m_InternalFullSampler->SetInputImageRegion(this->GetCroppedInputImageRegion());
this->m_InternalFullSampler->SetUseMultiThread(Superclass::m_UseMultiThread);

/** Use try/catch, since the full sampler may crash, due to insufficient memory. */
try
Expand All @@ -79,16 +81,32 @@ ImageRandomSamplerSparseMask<TInputImage>::GenerateData()
itkExceptionMacro(<< message);
}

/** Get a handle to the full sampler output. */
typename ImageSampleContainerType::Pointer allValidSamples = this->m_InternalFullSampler->GetOutput();
unsigned long numberOfValidSamples = allValidSamples->Size();


/** If desired we exercise a multi-threaded version. */
if (this->m_UseMultiThread)
{
/** Calls ThreadedGenerateData(). */
return Superclass::GenerateData();
}
m_RandomIndices.clear();
m_RandomIndices.reserve(Superclass::m_NumberOfSamples);

/** Get a handle to the full sampler output. */
typename ImageSampleContainerType::Pointer allValidSamples = this->m_InternalFullSampler->GetOutput();
unsigned long numberOfValidSamples = allValidSamples->Size();
for (unsigned int i = 0; i < Superclass::m_NumberOfSamples; ++i)
{
m_RandomIndices.push_back(m_RandomGenerator->GetIntegerVariate(numberOfValidSamples - 1));
}

auto & samples = elastix::Deref(sampleContainer).CastToSTLContainer();
samples.resize(m_RandomIndices.size());

m_OptionalUserData.emplace(elastix::Deref(allValidSamples).CastToSTLConstContainer(), m_RandomIndices, samples);

MultiThreaderBase & multiThreader = elastix::Deref(this->ProcessObject::GetMultiThreader());
multiThreader.SetSingleMethod(&Self::ThreaderCallback, &*m_OptionalUserData);
multiThreader.SingleMethodExecute();
return;
}

/** Take random samples from the allValidSamples-container. */
for (unsigned int i = 0; i < this->GetNumberOfSamples(); ++i)
Expand All @@ -103,75 +121,39 @@ ImageRandomSamplerSparseMask<TInputImage>::GenerateData()
} // end GenerateData()


/**
* ******************* BeforeThreadedGenerateData *******************
*/

template <class TInputImage>
void
ImageRandomSamplerSparseMask<TInputImage>::BeforeThreadedGenerateData()
ITK_THREAD_RETURN_FUNCTION_CALL_CONVENTION
ImageRandomSamplerSparseMask<TInputImage>::ThreaderCallback(void * const arg)
{
/** Clear the random number list. */
this->m_RandomNumberList.clear();
this->m_RandomNumberList.reserve(this->m_NumberOfSamples);

/** Get a handle to the full sampler output size. */
const unsigned long numberOfValidSamples = this->m_InternalFullSampler->GetOutput()->Size();
assert(arg);
const auto & info = *static_cast<const MultiThreaderBase::WorkUnitInfo *>(arg);

/** Fill the list with random numbers. */
for (unsigned int i = 0; i < this->GetNumberOfSamples(); ++i)
{
unsigned long randomIndex = this->m_RandomGenerator->GetIntegerVariate(numberOfValidSamples - 1);
this->m_RandomNumberList.push_back(randomIndex);
}

/** Initialize variables needed for threads. */
this->m_ThreaderSampleContainer.clear();
this->m_ThreaderSampleContainer.resize(this->GetNumberOfWorkUnits());
for (std::size_t i = 0; i < this->GetNumberOfWorkUnits(); ++i)
{
this->m_ThreaderSampleContainer[i] = ImageSampleContainerType::New();
}

} // end BeforeThreadedGenerateData()
assert(info.UserData);
auto & userData = *static_cast<UserData *>(info.UserData);

const auto & randomIndices = userData.m_RandomIndices;
auto & samples = userData.m_Samples;

/**
* ******************* ThreadedGenerateData *******************
*/

template <class TInputImage>
void
ImageRandomSamplerSparseMask<TInputImage>::ThreadedGenerateData(const InputImageRegionType &, ThreadIdType threadId)
{
/** Get a handle to the full sampler output. */
typename ImageSampleContainerType::Pointer allValidSamples = this->m_InternalFullSampler->GetOutput();
const auto totalNumberOfSamples = samples.size();
assert(totalNumberOfSamples == randomIndices.size());

/** Figure out which samples to process. */
unsigned long chunkSize = this->GetNumberOfSamples() / this->GetNumberOfWorkUnits();
unsigned long sampleStart = threadId * chunkSize;
if (threadId == this->GetNumberOfWorkUnits() - 1)
{
chunkSize = this->GetNumberOfSamples() - ((this->GetNumberOfWorkUnits() - 1) * chunkSize);
}
const auto numberOfSamplesPerWorkUnit = totalNumberOfSamples / info.NumberOfWorkUnits;
const auto remainderNumberOfSamples = totalNumberOfSamples % info.NumberOfWorkUnits;

/** Get a reference to the output and reserve memory for it. */
ImageSampleContainerPointer & sampleContainerThisThread = this->m_ThreaderSampleContainer[threadId];
sampleContainerThisThread->Reserve(chunkSize);
const auto offset =
info.WorkUnitID * numberOfSamplesPerWorkUnit + std::min<size_t>(info.WorkUnitID, remainderNumberOfSamples);
const auto beginOfRandomIndices = randomIndices.data() + offset;
const auto beginOfSamples = samples.data() + offset;
const auto & allValidSamples = userData.m_AllValidSamples;

/** Setup an iterator over the sampleContainerThisThread. */
typename ImageSampleContainerType::Iterator iter;
typename ImageSampleContainerType::ConstIterator end = sampleContainerThisThread->End();
const size_t n{ numberOfSamplesPerWorkUnit + (info.WorkUnitID < remainderNumberOfSamples ? 1 : 0) };

/** Take random samples from the allValidSamples-container. */
unsigned long sampleId = sampleStart;
for (iter = sampleContainerThisThread->Begin(); iter != end; ++iter, sampleId++)
for (size_t i = 0; i < n; ++i)
{
unsigned long randomIndex = static_cast<unsigned long>(this->m_RandomNumberList[sampleId]);
iter->Value() = allValidSamples->ElementAt(randomIndex);
beginOfSamples[i] = allValidSamples[beginOfRandomIndices[i]];
}

} // end ThreadedGenerateData()
return ITK_THREAD_RETURN_DEFAULT_VALUE;
}


/**
Expand Down

0 comments on commit 61db741

Please sign in to comment.