Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Execute from_json with struct schema using JSONUtils.fromJSONToStructs #11618

Merged
merged 28 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
18b7f3e
Migrate `castJsonStringToBool` to `JSONUtils.castStringsToBooleans`
ttnghia Oct 16, 2024
1865a0a
Migrate undoKeepQuotes` to use `JSONUtils.removeQuote`
ttnghia Oct 16, 2024
a5d5f04
Migrate `fixupQuotedStrings` to `JSONUtils.removeQuotes`
ttnghia Oct 17, 2024
6291972
Use `castStringsToDecimals`
ttnghia Oct 18, 2024
01724df
Use `removeQuotesForFloats` for implementing `castStringToFloat`
ttnghia Oct 18, 2024
692a0cb
Use `JSONUtils.castStringsToIntegers`
ttnghia Oct 18, 2024
df8c595
Throw if not supported type
ttnghia Oct 18, 2024
b3dcffc
Use `JSONUtils.castStringsToDates` for non-legacy conversion
ttnghia Oct 18, 2024
2d1cc03
Revert "Use `JSONUtils.castStringsToDates` for non-legacy conversion"
ttnghia Oct 23, 2024
94aaf95
Merge branch 'branch-24.12' into from_json_post_processing
ttnghia Oct 24, 2024
e9d6a8c
Use `JSONUtils.castStringsToFloats`
ttnghia Oct 24, 2024
2772b9e
Fix compile error
ttnghia Oct 29, 2024
9795a9b
Adopting `fromJSONToStructs`
ttnghia Oct 29, 2024
7e7dd5c
Merge branch 'branch-24.12' into from_json_post_processing
ttnghia Oct 30, 2024
3dd2a0a
Fix style
ttnghia Oct 30, 2024
049a6b4
Merge branch 'branch-24.12' into from_json_post_processing
ttnghia Nov 5, 2024
1239c10
Merge branch 'branch-24.12' into from_json_post_processing
ttnghia Nov 8, 2024
9f69280
Adopt `JSONUtils.convertDataType`
ttnghia Nov 13, 2024
47e7404
Cleanup
ttnghia Nov 13, 2024
e1c451a
Fix import
ttnghia Nov 14, 2024
5755e7f
Revert unrelated change
ttnghia Nov 14, 2024
e2f1724
Remove empty lines
ttnghia Nov 14, 2024
fe0b29b
Merge branch 'branch-24.12' into from_json_post_processing
ttnghia Nov 14, 2024
5a517f8
Change function name
ttnghia Nov 15, 2024
3e82285
Add more data to test
ttnghia Nov 16, 2024
425ddc5
Fix test pattern
ttnghia Nov 16, 2024
f174973
Add test
ttnghia Nov 18, 2024
7de3cce
Merge branch 'branch-24.12' into from_json_post_processing
ttnghia Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ package org.apache.spark.sql.rapids

import java.util.Locale

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Scalar, Schema, Table}
import ai.rapids.cudf.{ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Schema, Table}
import com.fasterxml.jackson.core.JsonParser
import com.nvidia.spark.rapids.{ColumnCastUtil, GpuCast, GpuColumnVector, GpuScalar, GpuTextBasedPartitionReader}
import com.nvidia.spark.rapids.{ColumnCastUtil, GpuColumnVector, GpuScalar, GpuTextBasedPartitionReader}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.{JSONUtils}

import org.apache.spark.sql.catalyst.json.{GpuJsonUtils, JSONOptions}
import org.apache.spark.sql.rapids.shims.GpuJsonToStructsShim
Expand All @@ -47,8 +47,10 @@ object GpuJsonReadCommon {
}
case _: MapType =>
throw new IllegalArgumentException("MapType is not supported yet for schema conversion")
case dt: DecimalType =>
builder.addColumn(GpuColumnVector.getNonNestedRapidsType(dt), name, dt.precision)
case _ =>
builder.addColumn(DType.STRING, name)
builder.addColumn(GpuColumnVector.getNonNestedRapidsType(dt), name)
}

/**
Expand All @@ -62,160 +64,6 @@ object GpuJsonReadCommon {
builder.build
}

private def isQuotedString(input: ColumnView): ColumnVector = {
withResource(Scalar.fromString("\"")) { quote =>
withResource(input.startsWith(quote)) { sw =>
withResource(input.endsWith(quote)) { ew =>
sw.binaryOp(BinaryOp.LOGICAL_AND, ew, DType.BOOL8)
}
}
}
}

private def stripFirstAndLastChar(input: ColumnView): ColumnVector = {
withResource(Scalar.fromInt(1)) { one =>
val end = withResource(input.getCharLengths) { cc =>
withResource(cc.sub(one)) { endWithNulls =>
withResource(endWithNulls.isNull) { eIsNull =>
eIsNull.ifElse(one, endWithNulls)
}
}
}
withResource(end) { _ =>
withResource(ColumnVector.fromScalar(one, end.getRowCount.toInt)) { start =>
input.substring(start, end)
}
}
}
}

private def undoKeepQuotes(input: ColumnView): ColumnVector = {
withResource(isQuotedString(input)) { iq =>
withResource(stripFirstAndLastChar(input)) { stripped =>
iq.ifElse(stripped, input)
}
}
}

private def fixupQuotedStrings(input: ColumnView): ColumnVector = {
withResource(isQuotedString(input)) { iq =>
withResource(stripFirstAndLastChar(input)) { stripped =>
withResource(Scalar.fromString(null)) { ns =>
iq.ifElse(stripped, ns)
}
}
}
}

private lazy val specialUnquotedFloats =
Seq("NaN", "+INF", "-INF", "+Infinity", "Infinity", "-Infinity")
private lazy val specialQuotedFloats = specialUnquotedFloats.map(s => '"'+s+'"')

/**
* JSON has strict rules about valid numeric formats. See https://www.json.org/ for specification.
*
* Spark then has its own rules for supporting NaN and Infinity, which are not
* valid numbers in JSON.
*/
private def sanitizeFloats(input: ColumnView, options: JSONOptions): ColumnVector = {
// Note that this is not 100% consistent with Spark versions prior to Spark 3.3.0
// due to https://issues.apache.org/jira/browse/SPARK-38060
if (options.allowNonNumericNumbers) {
// Need to normalize the quotes to non-quoted to parse properly
withResource(ColumnVector.fromStrings(specialQuotedFloats: _*)) { quoted =>
withResource(ColumnVector.fromStrings(specialUnquotedFloats: _*)) { unquoted =>
input.findAndReplaceAll(quoted, unquoted)
}
}
} else {
input.copyToColumnVector()
}
}

private def sanitizeInts(input: ColumnView): ColumnVector = {
// Integer numbers cannot look like a float, so no `.` or e The rest of the parsing should
// handle this correctly. The rest of the validation is in CUDF itself

val tmp = withResource(Scalar.fromString(".")) { dot =>
withResource(input.stringContains(dot)) { hasDot =>
withResource(Scalar.fromString("e")) { e =>
withResource(input.stringContains(e)) { hase =>
hasDot.or(hase)
}
}
}
}
val invalid = withResource(tmp) { _ =>
withResource(Scalar.fromString("E")) { E =>
withResource(input.stringContains(E)) { hasE =>
tmp.or(hasE)
}
}
}
withResource(invalid) { _ =>
withResource(Scalar.fromNull(DType.STRING)) { nullString =>
invalid.ifElse(nullString, input)
}
}
}

private def sanitizeQuotedDecimalInUSLocale(input: ColumnView): ColumnVector = {
// The US locale is kind of special in that it will remove the , and then parse the
// input normally
withResource(stripFirstAndLastChar(input)) { stripped =>
withResource(Scalar.fromString(",")) { comma =>
withResource(Scalar.fromString("")) { empty =>
stripped.stringReplace(comma, empty)
}
}
}
}

private def sanitizeDecimal(input: ColumnView, options: JSONOptions): ColumnVector = {
assert(options.locale == Locale.US)
withResource(isQuotedString(input)) { isQuoted =>
withResource(sanitizeQuotedDecimalInUSLocale(input)) { quoted =>
isQuoted.ifElse(quoted, input)
}
}
}

private def castStringToFloat(input: ColumnView, dt: DType,
options: JSONOptions): ColumnVector = {
withResource(sanitizeFloats(input, options)) { sanitizedInput =>
CastStrings.toFloat(sanitizedInput, false, dt)
}
}

private def castStringToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {
// TODO there is a bug here around 0 https://github.com/NVIDIA/spark-rapids/issues/10898
CastStrings.toDecimal(input, false, false, dt.precision, -dt.scale)
}

private def castJsonStringToBool(input: ColumnView): ColumnVector = {
// Sadly there is no good kernel right now to do just this check/conversion
val isTrue = withResource(Scalar.fromString("true")) { trueStr =>
input.equalTo(trueStr)
}
withResource(isTrue) { _ =>
val isFalse = withResource(Scalar.fromString("false")) { falseStr =>
input.equalTo(falseStr)
}
val falseOrNull = withResource(isFalse) { _ =>
withResource(Scalar.fromBool(false)) { falseLit =>
withResource(Scalar.fromNull(DType.BOOL8)) { nul =>
isFalse.ifElse(falseLit, nul)
}
}
}
withResource(falseOrNull) { _ =>
withResource(Scalar.fromBool(true)) { trueLit =>
isTrue.ifElse(trueLit, falseOrNull)
}
}
}
}

private def dateFormat(options: JSONOptions): Option[String] =
GpuJsonUtils.optionalDateFormatInRead(options)

Expand All @@ -227,8 +75,9 @@ object GpuJsonReadCommon {
throw new IllegalStateException(s"Don't know how to transform $cv to $dt for JSON")
}


private def nestedColumnViewMismatchTransform(cv: ColumnView,
dt: DataType): (Option[ColumnView], Seq[AutoCloseable]) = {
dt: DataType): (Option[ColumnView], Seq[AutoCloseable]) = {
// In the future we should be able to convert strings to maps/etc, but for
// now we are working around issues where CUDF is not returning a STRING for nested
// types when asked for it.
Expand Down Expand Up @@ -269,34 +118,24 @@ object GpuJsonReadCommon {
options: JSONOptions): ColumnVector = {
ColumnCastUtil.deepTransform(inputCv, Some(topLevelType),
Some(nestedColumnViewMismatchTransform)) {
case (cv, Some(BooleanType)) if cv.getType == DType.STRING =>
castJsonStringToBool(cv)

case (cv, Some(DateType)) if cv.getType == DType.STRING =>
withResource(fixupQuotedStrings(cv)) { fixed =>
withResource(JSONUtils.removeQuotes(cv, true)) { fixed =>
GpuJsonToStructsShim.castJsonStringToDateFromScan(fixed, DType.TIMESTAMP_DAYS,
dateFormat(options))
}
case (cv, Some(TimestampType)) if cv.getType == DType.STRING =>
withResource(fixupQuotedStrings(cv)) { fixed =>
withResource(JSONUtils.removeQuotes(cv, true)) { fixed =>
GpuTextBasedPartitionReader.castStringToTimestamp(fixed, timestampFormat(options),
DType.TIMESTAMP_MICROSECONDS)
}
case (cv, Some(StringType)) if cv.getType == DType.STRING =>
undoKeepQuotes(cv)
case (cv, Some(dt: DecimalType)) if cv.getType == DType.STRING =>
withResource(sanitizeDecimal(cv, options)) { tmp =>
castStringToDecimal(tmp, dt)
}
case (cv, Some(dt)) if (dt == DoubleType || dt == FloatType) && cv.getType == DType.STRING =>
castStringToFloat(cv, GpuColumnVector.getNonNestedRapidsType(dt), options)
case (cv, Some(dt))
if (dt == ByteType || dt == ShortType || dt == IntegerType || dt == LongType ) &&
cv.getType == DType.STRING =>
withResource(sanitizeInts(cv)) { tmp =>
CastStrings.toInteger(tmp, false, GpuColumnVector.getNonNestedRapidsType(dt))
}

case (cv, Some(dt)) if cv.getType == DType.STRING =>
GpuCast.doCast(cv, StringType, dt)
val builder = Schema.builder
populateSchema(dt, "", builder)
JSONUtils.convertDataType(cv, builder.build, cudfJsonOptions(options),
options.locale == Locale.US)

}
}

Expand All @@ -320,6 +159,44 @@ object GpuJsonReadCommon {
}
}

def convertDateTimeType(inputCv: ColumnVector,
topLevelType: DataType,
options: JSONOptions): ColumnVector = {
ColumnCastUtil.deepTransform(inputCv, Some(topLevelType),
Some(nestedColumnViewMismatchTransform)) {


case (cv, Some(DateType)) if cv.getType == DType.STRING =>
withResource(JSONUtils.removeQuotes(cv, true)) { fixed =>
GpuJsonToStructsShim.castJsonStringToDateFromScan(fixed, DType.TIMESTAMP_DAYS,
dateFormat(options))
}
case (cv, Some(TimestampType)) if cv.getType == DType.STRING =>
withResource(JSONUtils.removeQuotes(cv, true)) { fixed =>
GpuTextBasedPartitionReader.castStringToTimestamp(fixed, timestampFormat(options),
DType.TIMESTAMP_MICROSECONDS)
}


}
}


/**
* Convert the parsed input table to the desired output types
* @param input the column to start with
* @param desired the desired output data types
* @param options the options the user provided
* @return an array of converted column vectors in the same order as the input table.
*/
// def convertDateTimeType(input: ColumnVector,
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
// desired: DataType,
// options: JSONOptions): Array[ColumnVector] = {
// withResource(new NvtxRange("convertDateTimeType", NvtxColor.RED)) { _ =>
// convertDateTimeType(input, desired, options)
// }
// }

def cudfJsonOptions(options: JSONOptions): ai.rapids.cudf.JSONOptions =
cudfJsonOptionBuilder(options).build()

Expand Down
Loading
Loading