Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MinAbsolute to complement MaxAbsolute #105

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions intgemm/intgemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ float Unsupported_MaxAbsolute(const float * /*begin*/, const float * /*end*/) {
return 0.0f;
}

float Unsupported_MinAbsolute(const float * /*begin*/, const float * /*end*/) {
UnsupportedCPUError();
return 0.0f;
}

MeanStd Unsupported_VectorMeanStd(const float * /*begin*/, const float * /*end*/, bool /*absolute*/) {
UnsupportedCPUError();
return MeanStd();
Expand Down Expand Up @@ -186,6 +191,8 @@ using AVX2::VectorMeanStd;

float (*const MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(AVX512BW::MaxAbsolute, AVX512BW::MaxAbsolute, AVX2::MaxAbsolute, SSE2::MaxAbsolute, SSE2::MaxAbsolute, Unsupported_MaxAbsolute);

float (*const MinAbsolute)(const float *begin, const float *end) = ChooseCPU(AVX512BW::MinAbsolute, AVX512BW::MinAbsolute, AVX2::MinAbsolute, SSE2::MinAbsolute, SSE2::MinAbsolute, Unsupported_MinAbsolute);

MeanStd (*const VectorMeanStd)(const float *begin, const float *end, bool absolute) = ChooseCPU(AVX512BW::VectorMeanStd, AVX512BW::VectorMeanStd, AVX2::VectorMeanStd, SSE2::VectorMeanStd, SSE2::VectorMeanStd, Unsupported_VectorMeanStd);

constexpr const char *const Unsupported_16bit::kName;
Expand Down
3 changes: 3 additions & 0 deletions intgemm/intgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ extern const CPUType kCPU;
// Get the maximum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned.
extern float (*const MaxAbsolute)(const float *begin, const float *end);

// Get the maximum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned.
extern float (*const MinAbsolute)(const float *begin, const float *end);

// Get a Quantization value that is equant to the mean of the data +N standard deviations. Use 2 by default
extern MeanStd (*const VectorMeanStd)(const float *begin, const float *end, bool);

Expand Down
20 changes: 20 additions & 0 deletions intgemm/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
// This casting compiles to nothing.
return *reinterpret_cast<float*>(&a);
}
INTGEMM_SSE2 static inline float MinFloat32(__m128 a) {
// Fold to just using the first 64 bits.
__m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
a = _mm_min_ps(a, second_half);
// Fold to just using the first 32 bits.
second_half = _mm_shuffle_ps(a, a, 1);
a = _mm_min_ps(a, second_half);
// This casting compiles to nothing.
return *reinterpret_cast<float*>(&a);
}
INTGEMM_SSE2 static inline float AddFloat32(__m128 a) {
// Fold to just using the first 64 bits.
__m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
Expand All @@ -36,6 +46,9 @@ INTGEMM_SSE2 static inline float AddFloat32(__m128 a) {
INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) {
return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
}
INTGEMM_AVX2 static inline float MinFloat32(__m256 a) {
return MinFloat32(min_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
}
INTGEMM_AVX2 static inline float AddFloat32(__m256 a) {
return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
}
Expand All @@ -49,6 +62,13 @@ INTGEMM_AVX512F static inline float MaxFloat32(__m512 a) {
__m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper));
}
// Find the minimum float.
INTGEMM_AVX512F static inline float MinFloat32(__m512 a) {
// _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking.
// So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps.
__m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
return MinFloat32(min_ps(_mm512_castps512_ps256(a), upper));
}
INTGEMM_AVX512F static inline float AddFloat32(__m512 a) {
__m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
return AddFloat32(add_ps(_mm512_castps512_ps256(a), upper));
Expand Down
36 changes: 36 additions & 0 deletions intgemm/stats.inl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,42 @@ INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const f
return ret;
}

/* Compute the minimum absolute value over floats aligned to register size.
* Do not call this function directly; it's a subroutine of MinAbsolute.
*/
INTGEMM_TARGET static inline float MinAbsoluteThread(const FRegister *begin, const FRegister *end) {
const FRegister abs_mask = cast_ps(set1_epi32<Register>(kFloatAbsoluteMask));
// Technically not the lowest but we don't have set_inf_ps. Instead we want a low POSITIVE number
FRegister lowest = and_ps(abs_mask, set1_ps<FRegister>(*reinterpret_cast<const float *>(begin)));
#pragma omp for
for (const FRegister *i = begin; i < end; ++i) {
FRegister reg = and_ps(abs_mask, *i);
lowest = min_ps(lowest, reg);
}
return MinFloat32(lowest);
}

/* Compute the minimum absolute value of an array of floats.
* begin_float must be aligned to a multiple of the register size.
*/
INTGEMM_TARGET static inline float MinAbsolute(const float *begin_float, const float *end_float) {
assert(reinterpret_cast<uintptr_t>(begin_float) % sizeof(FRegister) == 0);
const float *end_reg = end_float - (reinterpret_cast<uintptr_t>(end_float) % sizeof(FRegister)) / sizeof(float);
float ret = std::abs(begin_float[0]);
#pragma omp parallel reduction(min:ret) num_threads(std::max<int>(1, std::min<int>(omp_get_max_threads(), (end_float - begin_float) / 16384)))
{
float shard_max = MinAbsoluteThread(
reinterpret_cast<const FRegister*>(begin_float),
reinterpret_cast<const FRegister*>(end_reg));
ret = std::min(ret, shard_max);
}
// Overhang, don't bother with vectorising it
for (const float *i = end_reg; i < end_float; ++i) {
ret = std::min(ret, std::fabs(*i));
}
return ret;
}

/* Computes the euclidean norm and returns the mean and the standard deviation. Optionally it can be the mean and standard deviation in absolute terms. */
INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) {
assert(end_float > begin_float);
Expand Down
49 changes: 49 additions & 0 deletions test/multiply_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,36 @@ template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute(
}
}

void CompareMinAbs(const float *begin, const float *end, float test, std::size_t offset) {
float minabs = std::abs(begin[0]);
for (const float * it = begin; it < end; it++) {
minabs = std::min(minabs, std::abs(*it));
}
// For when we get C++17
//float minabs = std::reduce(begin, end, begin[0], [&](float a, float b){return std::min(std::fabs(a), std::fabs(b));});
CHECK_MESSAGE(minabs == test, "Error: " << minabs << " versus " << test << " in length " << (end - begin) << " offset " << offset);
}

template <float (*Backend) (const float *, const float *)> void TestMinAbsolute() {
std::mt19937 gen;
std::uniform_real_distribution<float> dist(-8.0, 8.0);
const std::size_t kLengthMax = 65;
AlignedVector<float> test(kLengthMax);
for (std::size_t len = 1; len < kLengthMax; ++len) {
for (std::size_t t = 0; t < len; ++t) {
// Fill with [-8, 8).
for (auto& it : test) {
it = dist(gen);
}
CompareMinAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
test[t] = -32.0;
CompareMinAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
test[t] = 32.0;
CompareMinAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t);
}
}
}

TEST_CASE("MaxAbsolute SSE2", "[max]") {
if (kCPU < CPUType::SSE2) return;
TestMaxAbsolute<SSE2::MaxAbsolute>();
Expand All @@ -242,6 +272,25 @@ TEST_CASE("MaxAbsolute AVX512BW", "[max]") {
}
#endif

TEST_CASE("MinAbsolute SSE2", "[min]") {
if (kCPU < CPUType::SSE2) return;
TestMinAbsolute<SSE2::MinAbsolute>();
}

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("MinAbsolute AVX2", "[min]") {
if (kCPU < CPUType::AVX2) return;
TestMinAbsolute<AVX2::MinAbsolute>();
}
#endif

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("MinAbsolute AVX512BW", "[min]") {
if (kCPU < CPUType::AVX512BW) return;
TestMinAbsolute<AVX512BW::MinAbsolute>();
}
#endif

// Based on https://arxiv.org/abs/1705.01991

// Copyright (c) 2017 Microsoft Corporation
Expand Down