Skip to content

Commit

Permalink
Fix remaining corener cases
Browse files Browse the repository at this point in the history
  • Loading branch information
pmattione-nvidia committed Jul 10, 2024
1 parent a2ce4f9 commit 80e9926
Showing 1 changed file with 44 additions and 23 deletions.
67 changes: 44 additions & 23 deletions cpp/include/cudf/fixed_point/floating_conversion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cudf/utilities/traits.hpp>

#include <cuda/std/cmath>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

Expand Down Expand Up @@ -1050,10 +1051,11 @@ template <typename Rep,
CUDF_HOST_DEVICE inline Rep convert_floating_to_integral_SPARK_RAPIDS(FloatingType floating,
scale_type const& scale)
{
// The rounding and precision decisions made here are chosen to match Apache Spark.
// Spark wants to perform the conversion as double to have the most precision.
// However, the behavior is still slightly different if the original type was float.

// Extract components of the floating point number
// Extract components of the (double-ized) floating point number
using converter = floating_converter<double>;
auto const integer_rep = converter::bit_cast_to_integer(double(floating));
if (converter::is_zero(integer_rep)) { return 0; }
Expand All @@ -1062,41 +1064,60 @@ CUDF_HOST_DEVICE inline Rep convert_floating_to_integral_SPARK_RAPIDS(FloatingTy
auto const is_negative = converter::get_is_negative(integer_rep);
auto const [significand, floating_pow2] = converter::get_significand_and_pow2(integer_rep);

// Spark wants to round the last decimal place, so we'll perform the conversion
// with one lower power of 10 so that we can round at the end.
// Spark often wants to round the last decimal place, so we'll perform the conversion
// with one lower power of 10 so that we can (optionally) round at the end.
auto const pow10 = static_cast<int>(scale);
auto const shifting_pow10 = pow10 - 1;

// Four doubles, add half a bit to correct for compiler rounding text to nearest floating-point value.
// See comments in add_half_if_truncates(), except here we always add it ... but only for doubles.
// Even if we don't add (floats), shift bits to line up with what the shifting algorithm is expecting.
auto base2_value = (significand << 1);
auto const pow2 = floating_pow2 - 1;
if constexpr (cuda::std::is_same_v<FloatingType, double>) {
++base2_value;
}
// Sometimes add half a bit to correct for compiler rounding text to nearest floating-point value.
// See comments in add_half_if_truncates(), with differences detailed below.
// Even if we don't add the bit, shift bits to line up with what the shifting algorithm is expecting.
bool const is_whole_number = (cuda::std::floor(floating) == floating);
auto const [base2_value, pow2] = [is_whole_number](auto significand, auto floating_pow2){
if constexpr (cuda::std::is_same_v<FloatingType, double>) {
// Add the 1/2 bit regardless of truncation, but still not for whole numbers
auto const base2_value = (significand << 1) + static_cast<decltype(significand)>(!is_whole_number);
return std::make_pair(base2_value, floating_pow2 - 1);
} else {
// Input was float: never add 1/2 bit.
// Why? Because we converted to double, and the 1/2 bit beyond float is WAY too large compared
// to double's precision. And the 1/2 bit beyond double is not due to user input.
return std::make_pair(significand << 1, floating_pow2 - 1);
}
}(significand, floating_pow2);

// Main algorithm: Apply the powers of 2 and 10 (except for the last power-of-10)
auto magnitude = convert_floating_to_integral_shifting<Rep, double>(base2_value, shifting_pow10, pow2);

//To round the final decimal place, add 5 to one past the last decimal place
magnitude += 5;

// Spark wants to floor the last digits of the output, clearing data that was beyond the
// precision that was available in double.
// How many digits do we need to floor?
// The (rounded) decimal digit corresponding to pow2 (just past double precision) to the end (pow10).
// The conversion from pow2 to pow10 is log10(2), which is ~90/299
int const rounding_term = (pow2 > 0) ? 299 : -299; //round away from zero
int const rounded_pow2_decimal_digit = (180 * pow2 + rounding_term) / 598;
int const floor_pow10 = rounded_pow2_decimal_digit - pow10;
if (floor_pow10 >= 0) {
// Note, if floor_pow10 is negative, the scale factor cut off the extra, imprecise bits.
// From the decimal digit corresponding to pow2 (just past double precision) to the end (pow10).
// The conversion from pow2 to pow10 is log10(2), which is ~ 90/299 (close enough for ints)
int const floor_pow10 = (90 * pow2) / 299 - pow10;
if (floor_pow10 < 0) {
// Truncated: The scale factor cut off the extra, imprecise bits.
// To round to the final decimal place, add 5 to one past the last decimal place
magnitude += 5U;
magnitude /= 10U; //Apply the last power of 10
} else {
// We are keeping decimal digits with data beyond the precision of double
// We want to truncate these digits, but sometimes we want to round first
// We will round if and only if we didn't already add a half-bit earlier
if constexpr (cuda::std::is_same_v<FloatingType, double>) {
// For doubles, only round the extra digits of whole numbers
// If it was not a whole number, we already added 1/2 a bit at higher precision than this earlier.
if (is_whole_number) {
magnitude += multiply_power10<Rep>(decltype(magnitude)(5), floor_pow10);
}
} else {
// Input was float: we didn't add a half-bit earlier, so round at the edge of precision here.
magnitude += multiply_power10<Rep>(decltype(magnitude)(5), floor_pow10);
}

// +1: Divide the last power-of-10 that we postponed earlier to do rounding.
auto const truncated = divide_power10<Rep>(magnitude, floor_pow10 + 1);
magnitude = multiply_power10<Rep>(truncated, floor_pow10);
} else {
magnitude /= 10; //Apply the last power of 10
}

// Reapply the sign and return
Expand Down

0 comments on commit 80e9926

Please sign in to comment.