Adopt changes from JNI for casting from float to decimal #10917

Merged 19 commits on Jul 30, 2024
23 changes: 21 additions & 2 deletions integration_tests/src/main/python/cast_test.py
@@ -184,7 +184,7 @@ def test_cast_string_timestamp_fallback():
decimal_gen_32bit,
pytest.param(decimal_gen_32bit_neg_scale, marks=
pytest.mark.skipif(is_dataproc_serverless_runtime(),
reason="Dataproc Serverless does not support negative scale for Decimal cast")),
DecimalGen(precision=7, scale=7),
decimal_gen_64bit, decimal_gen_128bit, DecimalGen(precision=30, scale=2),
DecimalGen(precision=36, scale=5), DecimalGen(precision=38, scale=0),
@@ -265,6 +265,25 @@ def test_cast_long_to_decimal_overflow():
lambda spark : unary_op_df(spark, long_gen).select(
f.col('a').cast(DecimalType(18, -1))))


_float_special_cases = [(float("inf"), 5.0), (float("-inf"), 5.0), (float("nan"), 5.0)]
@pytest.mark.parametrize('data_gen', [FloatGen(special_cases=_float_special_cases),
DoubleGen(special_cases=_float_special_cases)],
ids=idfn)
@pytest.mark.parametrize('to_type', [
DecimalType(7, 1),
DecimalType(9, 9),
DecimalType(15, 2),
DecimalType(15, 15),
DecimalType(30, 3),
DecimalType(5, -3),
DecimalType(3, 0)], ids=idfn)
def test_cast_floating_point_to_decimal(data_gen, to_type):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).select(
f.col('a'), f.col('a').cast(to_type)),
conf={'spark.rapids.sql.castFloatToDecimal.enabled': 'true'})

# casting these types to string should pass
basic_gens_for_cast_to_string = [ByteGen, ShortGen, IntegerGen, LongGen, StringGen, BooleanGen, DateGen, TimestampGen]
basic_array_struct_gens_for_cast_to_string = [f() for f in basic_gens_for_cast_to_string] + [null_gen] + decimal_gens
@@ -310,7 +329,7 @@ def test_cast_array_to_string(data_gen, legacy):
_assert_cast_to_string_equal(
data_gen,
{"spark.sql.legacy.castComplexTypesToString.enabled": legacy})

def test_cast_float_to_string():
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, FloatGen()).selectExpr("cast(cast(a as string) as float)"),
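For reference, a minimal standalone PySpark sketch of the behavior the new test_cast_floating_point_to_decimal pins down (a hypothetical local session; assumes the RAPIDS plugin is loaded and uses the same config key as the test): NaN and +/-Infinity have no decimal representation, so with ANSI mode off Spark maps them to NULL, while finite values are rounded half-up to the target scale.

from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import DecimalType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.rapids.sql.castFloatToDecimal.enabled", "true")

# NaN/Inf -> NULL with ANSI off; finite 5.0 -> 5.0 as DECIMAL(7,1)
df = spark.createDataFrame(
    [(float("nan"),), (float("inf"),), (float("-inf"),), (5.0,)], "a double")
df.select(f.col("a").cast(DecimalType(7, 1))).show()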
117 changes: 11 additions & 106 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -22,11 +22,11 @@ import java.util.Optional

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DecimalUtils, DType, RegexProgram, Scalar}
import ai.rapids.cudf.{BinaryOp, CaptureGroups, ColumnVector, ColumnView, DType, RegexProgram, Scalar}
import ai.rapids.cudf
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.jni.{CastStrings, GpuTimeZoneDB}
import com.nvidia.spark.rapids.jni.{CastStrings, DecimalUtils, GpuTimeZoneDB}
import com.nvidia.spark.rapids.shims.{AnsiUtil, GpuCastShims, GpuIntervalUtils, GpuTypeShims, SparkShimImpl, YearParseUtil}
import org.apache.commons.text.StringEscapeUtils

@@ -192,7 +192,7 @@ object CastOptions {
val ARITH_ANSI_OPTIONS = new CastOptions(false, true, false)
val TO_PRETTY_STRING_OPTIONS = ToPrettyStringOptions

def getArithmeticCastOptions(failOnError: Boolean): CastOptions =
if (failOnError) ARITH_ANSI_OPTIONS else DEFAULT_CAST_OPTIONS

object ToPrettyStringOptions extends CastOptions(false, false, false,
@@ -628,7 +628,7 @@ object GpuCast {
case (TimestampType, DateType) if options.timeZoneId.isDefined =>
val zoneId = DateTimeUtils.getZoneId(options.timeZoneId.get)
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(input.asInstanceOf[ColumnVector],
zoneId.normalized())) {
shifted => shifted.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType))
}
case _ =>
@@ -696,49 +696,6 @@
}
}

/**
* Detects values of a column that fall outside a given range, and replaces them with
* an inputted substitution value.
*
* @param values ColumnVector to be performed with range check
* @param minValue Named parameter for function to create Scalar representing range minimum value
* @param maxValue Named parameter for function to create Scalar representing range maximum value
* @param replaceValue Named parameter for function to create scalar to substitute outlier value
* @param inclusiveMin Whether the min value is included in the valid range or not
* @param inclusiveMax Whether the max value is included in the valid range or not
*/
private def replaceOutOfRangeValues(values: ColumnView,
minValue: => Scalar,
maxValue: => Scalar,
replaceValue: => Scalar,
inclusiveMin: Boolean,
inclusiveMax: Boolean): ColumnVector = {

withResource(minValue) { minValue =>
withResource(maxValue) { maxValue =>
val minPredicate = if (inclusiveMin) {
values.lessThan(minValue)
} else {
values.lessOrEqualTo(minValue)
}
withResource(minPredicate) { minPredicate =>
val maxPredicate = if (inclusiveMax) {
values.greaterThan(maxValue)
} else {
values.greaterOrEqualTo(maxValue)
}
withResource(maxPredicate) { maxPredicate =>
withResource(maxPredicate.or(minPredicate)) { rangePredicate =>
withResource(replaceValue) { nullScalar =>
rangePredicate.ifElse(nullScalar, values)
}
}
}
}
}
}
}

def castToString(
input: ColumnView,
fromDataType: DataType, options: CastOptions): ColumnVector = fromDataType match {
@@ -1638,65 +1595,13 @@
input: ColumnView,
dt: DecimalType,
ansiMode: Boolean): ColumnVector = {

// Approach to minimize difference between CPUCast and GPUCast:
// step 1. cast input to FLOAT64 (if necessary)
// step 2. cast FLOAT64 to container DECIMAL (which keeps one more digit for rounding)
// step 3. perform HALF_UP rounding on container DECIMAL
val checkedInput = withResource(input.castTo(DType.FLOAT64)) { double =>
val roundedDouble = double.round(dt.scale, cudf.RoundMode.HALF_UP)
withResource(roundedDouble) { rounded =>
// We rely on the container DECIMAL to perform more precise rounding. So, we have to take
// the extra space cost of the container into consideration when we run the bound check.
val containerScaleBound = DType.DECIMAL128_MAX_PRECISION - (dt.scale + 1)
val bound = math.pow(10, (dt.precision - dt.scale) min containerScaleBound)
if (ansiMode) {
assertValuesInRange[Double](rounded,
minValue = -bound,
maxValue = bound,
inclusiveMin = false,
inclusiveMax = false)
rounded.incRefCount()
} else {
replaceOutOfRangeValues(rounded,
minValue = Scalar.fromDouble(-bound),
maxValue = Scalar.fromDouble(bound),
inclusiveMin = false,
inclusiveMax = false,
replaceValue = Scalar.fromNull(DType.FLOAT64))
}
}
}

withResource(checkedInput) { checked =>
val targetType = DecimalUtil.createCudfDecimal(dt)
// If the target scale reaches DECIMAL128_MAX_PRECISION, the container DECIMAL cannot
// be created because of precision overflow. In this case, we perform the cast directly.
val casted = if (DType.DECIMAL128_MAX_PRECISION == dt.scale) {
checked.castTo(targetType)
} else {
// Increase precision by one along with the scale in case of overflow, which may lead to
// an upcast of the cuDF decimal type. If precision already hits the max precision, it is
// safe to increase only the scale, because we have checked and replaced out-of-range values.
val containerType = DecimalUtils.createDecimalType(
dt.precision + 1 min DType.DECIMAL128_MAX_PRECISION, dt.scale + 1)
withResource(checked.castTo(containerType)) { container =>
withResource(container.round(dt.scale, cudf.RoundMode.HALF_UP)) { rd =>
// The cast here is for cases where the cuDF decimal type got promoted to precision + 1.
// Need to convert back to the original cuDF type, to keep aligned with the precision.
rd.castTo(targetType)
}
}
}
// Cast NaN values to nulls
withResource(casted) { casted =>
withResource(input.isNan) { inputIsNan =>
withResource(Scalar.fromNull(targetType)) { nullScalar =>
inputIsNan.ifElse(nullScalar, casted)
}
}
}
val targetType = DecimalUtil.createCudfDecimal(dt)
val converted = DecimalUtils.floatingPointToDecimal(input, targetType, dt.precision)
if (ansiMode && converted.hasFailure) {
converted.result.close()
throw RapidsErrorUtils.arithmeticOverflowError(OVERFLOW_MESSAGE)
}
converted.result
}

def fixDecimalBounds(input: ColumnView,
@@ -1901,4 +1806,4 @@ case class GpuCast(

override def doColumnar(input: GpuColumnVector): ColumnVector =
doCast(input.getBase, input.dataType(), dataType, options)
}
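The GpuCast.scala change above replaces the hand-rolled FLOAT64/container-DECIMAL rounding and range check with a single call to DecimalUtils.floatingPointToDecimal from spark-rapids-jni, which reports overflow through a failure flag on its result: per the new code, ANSI mode turns that flag into an arithmetic overflow error, while non-ANSI mode keeps the nulls the kernel produced. A hedged sketch of that user-visible contract, continuing the hypothetical session from the previous example:

# ANSI off: a value too large for DECIMAL(7,1) becomes NULL
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT CAST(CAST(1e30 AS DOUBLE) AS DECIMAL(7,1)) AS v").show()

# ANSI on: the same cast raises an arithmetic overflow error
spark.conf.set("spark.sql.ansi.enabled", "true")
try:
    spark.sql("SELECT CAST(CAST(1e30 AS DOUBLE) AS DECIMAL(7,1)) AS v").collect()
except Exception as e:
    print(type(e).__name__)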
44 changes: 40 additions & 4 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -306,7 +306,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

private def compareFloatToStringResults(float: Boolean, fromCpu: Array[Row],
fromGpu: Array[Row]): Unit = {
fromCpu.zip(fromGpu).foreach {
case (c, g) =>
@@ -438,12 +438,12 @@
}

test("cast float to string") {
testCastToString[Float](DataTypes.FloatType, comparisonFunc =
Some(compareStringifiedFloats(true)))
}

test("cast double to string") {
testCastToString[Double](DataTypes.DoubleType, comparisonFunc =
Some(compareStringifiedFloats(false)))
}

@@ -693,6 +693,11 @@
List(-10, -1, 0, 1, 10).foreach { scale =>
testCastToDecimal(DataTypes.FloatType, scale,
customDataGenerator = Some(floatsIncludeNaNs))
assertThrows[Throwable] {
testCastToDecimal(DataTypes.FloatType, scale,
customDataGenerator = Some(floatsIncludeNaNs),
ansiEnabled = true)
}
}
}

@@ -710,6 +715,11 @@
List(-10, -1, 0, 1, 10).foreach { scale =>
testCastToDecimal(DataTypes.DoubleType, scale,
customDataGenerator = Some(doublesIncludeNaNs))
assertThrows[Throwable] {
testCastToDecimal(DataTypes.DoubleType, scale,
customDataGenerator = Some(doublesIncludeNaNs),
ansiEnabled = true)
}
}
}

@@ -729,6 +739,32 @@
customDataGenerator = Option(genDoubles))
}

test("cast float/double to decimal (borderline value rounding)") {
val genFloats_12_7: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(3527.61953125f))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.FloatType, precision = 12, scale = 7,
customDataGenerator = Option(genFloats_12_7))

val genDoubles_12_7: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(3527.61953125))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.DoubleType, precision = 12, scale = 7,
customDataGenerator = Option(genDoubles_12_7))

val genFloats_3_1: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(9.95f))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.FloatType, precision = 3, scale = 1,
customDataGenerator = Option(genFloats_3_1))

val genDoubles_3_1: SparkSession => DataFrame = (ss: SparkSession) => {
ss.createDataFrame(List(Tuple1(9.95))).selectExpr("_1 AS col")
}
testCastToDecimal(DataTypes.DoubleType, precision = 3, scale = 1,
customDataGenerator = Option(genDoubles_3_1))
}

test("cast decimal to decimal") {
// fromScale == toScale
testCastToDecimal(DataTypes.createDecimalType(18, 0),
@@ -967,7 +1003,7 @@
dataType: DataType,
scale: Int,
precision: Int = ai.rapids.cudf.DType.DECIMAL128_MAX_PRECISION,
floatEpsilon: Double = 1e-9,
floatEpsilon: Double = 1e-14,
customDataGenerator: Option[SparkSession => DataFrame] = None,
customRandGenerator: Option[scala.util.Random] = None,
ansiEnabled: Boolean = false,
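The borderline-rounding cases added to CastOpSuite above (for example 9.95f to DECIMAL(3,1)) exist because such literals are not exactly representable in binary floating point: the nearest float32 to 9.95 sits just below it, so a correct half-up cast at scale 1 must produce 9.9, not 10.0, and rounding through an intermediate FLOAT64 container as the removed GpuCast code did can disagree with the CPU on exactly these inputs. A small standalone Python illustration of the representation issue (no Spark needed):

import struct
from decimal import Decimal, ROUND_HALF_UP

# Round-trip 9.95 through float32 to see the value a Spark FloatType column holds.
f32 = struct.unpack("f", struct.pack("f", 9.95))[0]
print(Decimal(f32))                                          # 9.94999980926513671875
print(Decimal(f32).quantize(Decimal("0.1"), ROUND_HALF_UP))  # 9.9, not 10.0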
@@ -1458,7 +1458,7 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll {
-9223183700000000000L
).toDF("longs")
}

def datesPostEpochDf(session: SparkSession): DataFrame = {
import session.sqlContext.implicits._
Seq(
@@ -31,7 +31,7 @@ class RapidsTestSettings extends BackendTestSettings {
.exclude("SPARK-35719: cast timestamp with local time zone to timestamp without timezone", WONT_FIX_ISSUE("https://issues.apache.org/jira/browse/SPARK-40851"))
.exclude("SPARK-35112: Cast string to day-time interval", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10980"))
.exclude("SPARK-35735: Take into account day-time interval fields in cast", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10980"))
.exclude("casting to fixed-precision decimals", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10809"))
.exclude("casting to fixed-precision decimals", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/11250"))
.exclude("SPARK-32828: cast from a derived user-defined type to a base type", WONT_FIX_ISSUE("User-defined types are not supported"))
.exclude("cast string to timestamp", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/blob/main/docs/compatibility.md#string-to-timestamp"))
.exclude("cast string to date", KNOWN_ISSUE("https://github.com/NVIDIA/spark-rapids/issues/10771"))