From 391bc091f3f032e3dddda79a2c1ac73206dd6a68 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Fri, 23 Sep 2022 11:10:36 +0100 Subject: [PATCH 1/2] Push C++ standard to 17 and add MinAbsolute to complement MaxAbsolute --- CMakeLists.txt | 2 +- intgemm/intgemm.cc | 7 +++++++ intgemm/intgemm.h | 3 +++ intgemm/stats.h | 20 ++++++++++++++++++++ intgemm/stats.inl | 36 +++++++++++++++++++++++++++++++++++ test/multiply_test.cc | 44 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 111 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 11d613e..fa0e10e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) if(MSVC) add_compile_options(/W4 /WX) diff --git a/intgemm/intgemm.cc b/intgemm/intgemm.cc index 58e4bc5..ccabd7c 100644 --- a/intgemm/intgemm.cc +++ b/intgemm/intgemm.cc @@ -136,6 +136,11 @@ float Unsupported_MaxAbsolute(const float * /*begin*/, const float * /*end*/) { return 0.0f; } +float Unsupported_MinAbsolute(const float * /*begin*/, const float * /*end*/) { + UnsupportedCPUError(); + return 0.0f; +} + MeanStd Unsupported_VectorMeanStd(const float * /*begin*/, const float * /*end*/, bool /*absolute*/) { UnsupportedCPUError(); return MeanStd(); @@ -186,6 +191,8 @@ using AVX2::VectorMeanStd; float (*const MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(AVX512BW::MaxAbsolute, AVX512BW::MaxAbsolute, AVX2::MaxAbsolute, SSE2::MaxAbsolute, SSE2::MaxAbsolute, Unsupported_MaxAbsolute); +float (*const MinAbsolute)(const float *begin, const float *end) = ChooseCPU(AVX512BW::MinAbsolute, AVX512BW::MinAbsolute, AVX2::MinAbsolute, SSE2::MinAbsolute, SSE2::MinAbsolute, Unsupported_MinAbsolute); + MeanStd (*const VectorMeanStd)(const float *begin, const float *end, bool absolute) = ChooseCPU(AVX512BW::VectorMeanStd, AVX512BW::VectorMeanStd, AVX2::VectorMeanStd, SSE2::VectorMeanStd, 
SSE2::VectorMeanStd, Unsupported_VectorMeanStd); constexpr const char *const Unsupported_16bit::kName; diff --git a/intgemm/intgemm.h b/intgemm/intgemm.h index 26febb5..99eb3d5 100644 --- a/intgemm/intgemm.h +++ b/intgemm/intgemm.h @@ -352,6 +352,9 @@ extern const CPUType kCPU; // Get the maximum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned. extern float (*const MaxAbsolute)(const float *begin, const float *end); +// Get the minimum absolute value of an array of floats. The number of floats must be a multiple of 16 and 64-byte aligned. +extern float (*const MinAbsolute)(const float *begin, const float *end); + // Get a Quantization value that is equant to the mean of the data +N standard deviations. Use 2 by default extern MeanStd (*const VectorMeanStd)(const float *begin, const float *end, bool); diff --git a/intgemm/stats.h b/intgemm/stats.h index 9573c4b..c741883 100644 --- a/intgemm/stats.h +++ b/intgemm/stats.h @@ -21,6 +21,16 @@ INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) { // This casting compiles to nothing. return *reinterpret_cast(&a); } +INTGEMM_SSE2 static inline float MinFloat32(__m128 a) { + // Fold to just using the first 64 bits. + __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2); + a = _mm_min_ps(a, second_half); + // Fold to just using the first 32 bits. + second_half = _mm_shuffle_ps(a, a, 1); + a = _mm_min_ps(a, second_half); + // This casting compiles to nothing. + return *reinterpret_cast(&a); +} INTGEMM_SSE2 static inline float AddFloat32(__m128 a) { // Fold to just using the first 64 bits. 
__m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2); @@ -36,6 +46,9 @@ INTGEMM_SSE2 static inline float AddFloat32(__m128 a) { INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) { return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1))); } +INTGEMM_AVX2 static inline float MinFloat32(__m256 a) { + return MinFloat32(min_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1))); +} INTGEMM_AVX2 static inline float AddFloat32(__m256 a) { return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1))); } @@ -49,6 +62,13 @@ INTGEMM_AVX512F static inline float MaxFloat32(__m512 a) { __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1)); return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper)); } +// Find the minimum float. +INTGEMM_AVX512F static inline float MinFloat32(__m512 a) { + // _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking. + // So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps. + __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1)); + return MinFloat32(min_ps(_mm512_castps512_ps256(a), upper)); +} INTGEMM_AVX512F static inline float AddFloat32(__m512 a) { __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1)); return AddFloat32(add_ps(_mm512_castps512_ps256(a), upper)); diff --git a/intgemm/stats.inl b/intgemm/stats.inl index 68a5b8e..816eedd 100644 --- a/intgemm/stats.inl +++ b/intgemm/stats.inl @@ -60,6 +60,42 @@ INTGEMM_TARGET static inline float MaxAbsolute(const float *begin_float, const f return ret; } +/* Compute the minimum absolute value over floats aligned to register size. + * Do not call this function directly; it's a subroutine of MinAbsolute. + */ +INTGEMM_TARGET static inline float MinAbsoluteThread(const FRegister *begin, const FRegister *end) { + const FRegister abs_mask = cast_ps(set1_epi32(kFloatAbsoluteMask)); + // Technically not the lowest but we don't have set_inf_ps. 
Instead we want a low POSITIVE number + FRegister lowest = and_ps(abs_mask, set1_ps(*reinterpret_cast(begin))); +#pragma omp for + for (const FRegister *i = begin; i < end; ++i) { + FRegister reg = and_ps(abs_mask, *i); + lowest = min_ps(lowest, reg); + } + return MinFloat32(lowest); +} + +/* Compute the minimum absolute value of an array of floats. + * begin_float must be aligned to a multiple of the register size. +*/ +INTGEMM_TARGET static inline float MinAbsolute(const float *begin_float, const float *end_float) { + assert(reinterpret_cast(begin_float) % sizeof(FRegister) == 0); + const float *end_reg = end_float - (reinterpret_cast(end_float) % sizeof(FRegister)) / sizeof(float); + float ret = std::abs(begin_float[0]); +#pragma omp parallel reduction(min:ret) num_threads(std::max(1, std::min(omp_get_max_threads(), (end_float - begin_float) / 16384))) + { + float shard_max = MinAbsoluteThread( + reinterpret_cast(begin_float), + reinterpret_cast(end_reg)); + ret = std::min(ret, shard_max); + } + // Overhang, don't bother with vectorising it + for (const float *i = end_reg; i < end_float; ++i) { + ret = std::min(ret, std::fabs(*i)); + } + return ret; +} + /* Computes the euclidean norm and returns the mean and the standard deviation. Optionally it can be the mean and standard deviation in absolute terms. 
*/ INTGEMM_TARGET static inline MeanStd VectorMeanStd(const float *begin_float, const float *end_float, bool absolute) { assert(end_float > begin_float); diff --git a/test/multiply_test.cc b/test/multiply_test.cc index f72758f..a1ad6ef 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -223,6 +223,31 @@ template void TestMaxAbsolute( } } +void CompareMinAbs(const float *begin, const float *end, float test, std::size_t offset) { + float minabs = std::reduce(begin, end, begin[0], [&](float a, float b){return std::min(std::fabs(a), std::fabs(b));}); + CHECK_MESSAGE(minabs == test, "Error: " << minabs << " versus " << test << " in length " << (end - begin) << " offset " << offset); +} + +template void TestMinAbsolute() { + std::mt19937 gen; + std::uniform_real_distribution dist(-8.0, 8.0); + const std::size_t kLengthMax = 65; + AlignedVector test(kLengthMax); + for (std::size_t len = 1; len < kLengthMax; ++len) { + for (std::size_t t = 0; t < len; ++t) { + // Fill with [-8, 8). 
+ for (auto& it : test) { + it = dist(gen); + } + CompareMinAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t); + test[t] = -32.0; + CompareMinAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t); + test[t] = 32.0; + CompareMinAbs(test.begin(), test.begin() + len, Backend(test.begin(), test.begin() + len), t); + } + } +} + TEST_CASE("MaxAbsolute SSE2", "[max]") { if (kCPU < CPUType::SSE2) return; TestMaxAbsolute(); @@ -242,6 +267,25 @@ TEST_CASE("MaxAbsolute AVX512BW", "[max]") { } #endif +TEST_CASE("MinAbsolute SSE2", "[min]") { + if (kCPU < CPUType::SSE2) return; + TestMinAbsolute(); +} + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 +TEST_CASE("MinAbsolute AVX2", "[min]") { + if (kCPU < CPUType::AVX2) return; + TestMinAbsolute(); +} +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +TEST_CASE("MinAbsolute AVX512BW", "[min]") { + if (kCPU < CPUType::AVX512BW) return; + TestMinAbsolute(); +} +#endif + // Based on https://arxiv.org/abs/1705.01991 // Copyright (c) 2017 Microsoft Corporation From f1481ec0e9e46272bc1261d4a28d8d0f7dca08f1 Mon Sep 17 00:00:00 2001 From: XapaJIaMnu Date: Fri, 23 Sep 2022 13:45:46 +0100 Subject: [PATCH 2/2] Roll back c++17 --- CMakeLists.txt | 2 +- test/multiply_test.cc | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index fa0e10e..11d613e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 11) if(MSVC) add_compile_options(/W4 /WX) diff --git a/test/multiply_test.cc b/test/multiply_test.cc index a1ad6ef..4278a97 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -224,7 +224,12 @@ template void TestMaxAbsolute( } void CompareMinAbs(const float *begin, const float *end, float test, std::size_t offset) { - float minabs = std::reduce(begin, end, begin[0], [&](float a, float b){return 
std::min(std::fabs(a), std::fabs(b));}); + float minabs = std::abs(begin[0]); + for (const float * it = begin; it < end; it++) { + minabs = std::min(minabs, std::abs(*it)); + } + // For when we get C++17 + //float minabs = std::reduce(begin, end, begin[0], [&](float a, float b){return std::min(std::fabs(a), std::fabs(b));}); CHECK_MESSAGE(minabs == test, "Error: " << minabs << " versus " << test << " in length " << (end - begin) << " offset " << offset); }