Commit

Execute from_json with struct schema using `JSONUtils.fromJSONToStructs` (#11618)

* Migrate `castJsonStringToBool` to `JSONUtils.castStringsToBooleans`

Signed-off-by: Nghia Truong <[email protected]>

* Migrate `undoKeepQuotes` to use `JSONUtils.removeQuotes`

Signed-off-by: Nghia Truong <[email protected]>

* Migrate `fixupQuotedStrings` to `JSONUtils.removeQuotes`

Signed-off-by: Nghia Truong <[email protected]>

* Use `castStringsToDecimals`

Signed-off-by: Nghia Truong <[email protected]>

* Use `removeQuotesForFloats` for implementing `castStringToFloat`

Signed-off-by: Nghia Truong <[email protected]>

* Use `JSONUtils.castStringsToIntegers`

Signed-off-by: Nghia Truong <[email protected]>

* Throw if the type is not supported

Signed-off-by: Nghia Truong <[email protected]>

* Use `JSONUtils.castStringsToDates` for non-legacy conversion

Signed-off-by: Nghia Truong <[email protected]>

* Revert "Use `JSONUtils.castStringsToDates` for non-legacy conversion"

This reverts commit b3dcffc.

* Use `JSONUtils.castStringsToFloats`

Signed-off-by: Nghia Truong <[email protected]>

* Fix compile error

Signed-off-by: Nghia Truong <[email protected]>

* Adopt `fromJSONToStructs`

Signed-off-by: Nghia Truong <[email protected]>

* Fix style

Signed-off-by: Nghia Truong <[email protected]>

* Adopt `JSONUtils.convertDataType`

Signed-off-by: Nghia Truong <[email protected]>

* Cleanup

Signed-off-by: Nghia Truong <[email protected]>

* Fix import

Signed-off-by: Nghia Truong <[email protected]>

* Revert unrelated change

Signed-off-by: Nghia Truong <[email protected]>

* Remove empty lines

Signed-off-by: Nghia Truong <[email protected]>

* Change function name

Signed-off-by: Nghia Truong <[email protected]>

* Add more data to test

Signed-off-by: Nghia Truong <[email protected]>

* Fix test pattern

Signed-off-by: Nghia Truong <[email protected]>

* Add test

Signed-off-by: Nghia Truong <[email protected]>

---------

Signed-off-by: Nghia Truong <[email protected]>

ttnghia authored Nov 23, 2024
1 parent cacc3ae commit daaaf24
Showing 3 changed files with 111 additions and 287 deletions.
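
The commits above replace several hand-rolled string-to-type kernels with single JNI calls into `JSONUtils`. A minimal sketch of the core pattern, as it would read inside `GpuJsonReadCommon`; the helper name `castViaJni` is hypothetical, but every call it makes appears in the diff below:

import java.util.Locale

import ai.rapids.cudf.{ColumnVector, Schema}
import com.nvidia.spark.rapids.jni.JSONUtils

import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.types.DataType

// Hypothetical helper: convert a column of extracted JSON strings to `dt`.
// The single-column schema is wrapped in a struct because cudf never uses
// the top level of a schema passed down from Java to C++ (see the comment
// in convertToDesiredType below).
def castViaJni(cv: ColumnVector, dt: DataType, options: JSONOptions): ColumnVector = {
  val builder = Schema.builder // built as a struct schema
  populateSchema(dt, "", builder) // populateSchema is defined in the Scala diff below
  JSONUtils.convertFromStrings(cv, builder.build,
    options.allowNonNumericNumbers, options.locale == Locale.US)
}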
integration_tests/src/main/python/json_test.py (30 additions, 0 deletions)

@@ -994,6 +994,36 @@ def test_from_json_struct_of_list(schema):
.select(f.from_json('a', schema)),
conf=_enable_all_types_conf)

@allow_non_gpu(*non_utc_allow)
def test_from_json_struct_of_list_with_mismatched_schema():
json_string_gen = StringGen(r'{"teacher": "[A-Z]{1}[a-z]{2,5}",' \
r'"student": \["[A-Z]{1}[a-z]{2,5}"\]}') \
.with_special_pattern('', weight=50) \
.with_special_pattern('null', weight=50)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, json_string_gen) \
.select(f.from_json('a', 'struct<teacher:string,student:array<struct<name:string,class:string>>>')),
conf=_enable_all_types_conf)

@pytest.mark.parametrize('schema', ['struct<teacher:string>',
'struct<student:array<struct<name:string,class:string>>>',
'struct<teacher:string,student:array<struct<name:string,class:string>>>'])
@allow_non_gpu(*non_utc_allow)
@pytest.mark.xfail(reason='https://github.com/rapidsai/cudf/issues/17349')
def test_from_json_struct_of_list_with_mixed_nested_types_input(schema):
json_string_gen = StringGen(r'{"teacher": "[A-Z]{1}[a-z]{2,5}",' \
r'"student": \[{"name": "[A-Z]{1}[a-z]{2,5}", "class": "junior"},' \
r'{"name": "[A-Z]{1}[a-z]{2,5}", "class": "freshman"}\]}') \
.with_special_pattern('', weight=50) \
.with_special_pattern('null', weight=50) \
.with_special_pattern('{"student": \["[A-Z]{1}[a-z]{2,5}"\]}', weight=100) \
.with_special_pattern('{"student": \[[1-9]{1,5}\]}', weight=100) \
.with_special_pattern('{"student": {"[A-Z]{1}[a-z]{2,5}": "[A-Z]{1}[a-z]{2,5}"}}', weight=100)
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, json_string_gen) \
.select(f.from_json('a', schema)),
conf=_enable_all_types_conf)

@pytest.mark.parametrize('schema', [
'struct<a:string>'
])

GpuJsonReadCommon.scala

@@ -14,17 +14,16 @@
* limitations under the License.
*/


package org.apache.spark.sql.rapids

import java.util.Locale

import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Scalar, Schema, Table}
import ai.rapids.cudf.{ColumnVector, ColumnView, DType, NvtxColor, NvtxRange, Schema, Table}
import com.fasterxml.jackson.core.JsonParser
import com.nvidia.spark.rapids.{ColumnCastUtil, GpuCast, GpuColumnVector, GpuScalar, GpuTextBasedPartitionReader}
import com.nvidia.spark.rapids.{ColumnCastUtil, GpuColumnVector, GpuScalar, GpuTextBasedPartitionReader}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
import com.nvidia.spark.rapids.jni.CastStrings
import com.nvidia.spark.rapids.jni.JSONUtils

import org.apache.spark.sql.catalyst.json.{GpuJsonUtils, JSONOptions}
import org.apache.spark.sql.rapids.shims.GpuJsonToStructsShim
@@ -47,8 +46,10 @@ object GpuJsonReadCommon {
}
case _: MapType =>
throw new IllegalArgumentException("MapType is not supported yet for schema conversion")
case dt: DecimalType =>
builder.addColumn(GpuColumnVector.getNonNestedRapidsType(dt), name, dt.precision)
case _ =>
builder.addColumn(DType.STRING, name)
builder.addColumn(GpuColumnVector.getNonNestedRapidsType(dt), name)
}
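
The new `DecimalType` case above matters because cudf needs the precision to pick a decimal width. A small sketch of what that case records for a `decimal(10,2)` column; the column name "price" is illustrative:

import ai.rapids.cudf.Schema
import com.nvidia.spark.rapids.GpuColumnVector

import org.apache.spark.sql.types.DecimalType

val dt = DecimalType(10, 2)
val builder = Schema.builder
// The precision now travels into the cudf Schema, instead of the column
// being read back as a string and re-parsed on the JVM side.
builder.addColumn(GpuColumnVector.getNonNestedRapidsType(dt), "price", dt.precision)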

/**
@@ -62,160 +63,6 @@
builder.build
}

private def isQuotedString(input: ColumnView): ColumnVector = {
withResource(Scalar.fromString("\"")) { quote =>
withResource(input.startsWith(quote)) { sw =>
withResource(input.endsWith(quote)) { ew =>
sw.binaryOp(BinaryOp.LOGICAL_AND, ew, DType.BOOL8)
}
}
}
}

private def stripFirstAndLastChar(input: ColumnView): ColumnVector = {
withResource(Scalar.fromInt(1)) { one =>
val end = withResource(input.getCharLengths) { cc =>
withResource(cc.sub(one)) { endWithNulls =>
withResource(endWithNulls.isNull) { eIsNull =>
eIsNull.ifElse(one, endWithNulls)
}
}
}
withResource(end) { _ =>
withResource(ColumnVector.fromScalar(one, end.getRowCount.toInt)) { start =>
input.substring(start, end)
}
}
}
}

private def undoKeepQuotes(input: ColumnView): ColumnVector = {
withResource(isQuotedString(input)) { iq =>
withResource(stripFirstAndLastChar(input)) { stripped =>
iq.ifElse(stripped, input)
}
}
}

private def fixupQuotedStrings(input: ColumnView): ColumnVector = {
withResource(isQuotedString(input)) { iq =>
withResource(stripFirstAndLastChar(input)) { stripped =>
withResource(Scalar.fromString(null)) { ns =>
iq.ifElse(stripped, ns)
}
}
}
}
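
Per the commit message, both quote helpers above collapse into `JSONUtils.removeQuotes`, whose boolean flag appears later in this diff as `nullifyIfNotQuoted`. Mapping the two helpers onto that flag is an inference from their behavior, so treat the flag values below as assumptions:

import ai.rapids.cudf.{ColumnVector, ColumnView}
import com.nvidia.spark.rapids.jni.JSONUtils

// undoKeepQuotes keeps non-quoted rows unchanged (assumed flag value).
def undoKeepQuotesJni(input: ColumnView): ColumnVector =
  JSONUtils.removeQuotes(input, /*nullifyIfNotQuoted*/ false)

// fixupQuotedStrings nulls out non-quoted rows, matching the usage in
// convertStringToDate later in this diff.
def fixupQuotedStringsJni(input: ColumnView): ColumnVector =
  JSONUtils.removeQuotes(input, /*nullifyIfNotQuoted*/ true)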

private lazy val specialUnquotedFloats =
Seq("NaN", "+INF", "-INF", "+Infinity", "Infinity", "-Infinity")
private lazy val specialQuotedFloats = specialUnquotedFloats.map(s => '"'+s+'"')

/**
* JSON has strict rules about valid numeric formats. See https://www.json.org/ for specification.
*
* Spark then has its own rules for supporting NaN and Infinity, which are not
* valid numbers in JSON.
*/
private def sanitizeFloats(input: ColumnView, options: JSONOptions): ColumnVector = {
// Note that this is not 100% consistent with Spark versions prior to Spark 3.3.0
// due to https://issues.apache.org/jira/browse/SPARK-38060
if (options.allowNonNumericNumbers) {
// Need to normalize the quotes to non-quoted to parse properly
withResource(ColumnVector.fromStrings(specialQuotedFloats: _*)) { quoted =>
withResource(ColumnVector.fromStrings(specialUnquotedFloats: _*)) { unquoted =>
input.findAndReplaceAll(quoted, unquoted)
}
}
} else {
input.copyToColumnVector()
}
}

private def sanitizeInts(input: ColumnView): ColumnVector = {
// Integer numbers cannot look like a float, so no `.` or `e`. The rest of the
// parsing should handle this correctly. The rest of the validation is in CUDF itself.

val tmp = withResource(Scalar.fromString(".")) { dot =>
withResource(input.stringContains(dot)) { hasDot =>
withResource(Scalar.fromString("e")) { e =>
withResource(input.stringContains(e)) { hase =>
hasDot.or(hase)
}
}
}
}
val invalid = withResource(tmp) { _ =>
withResource(Scalar.fromString("E")) { E =>
withResource(input.stringContains(E)) { hasE =>
tmp.or(hasE)
}
}
}
withResource(invalid) { _ =>
withResource(Scalar.fromNull(DType.STRING)) { nullString =>
invalid.ifElse(nullString, input)
}
}
}

private def sanitizeQuotedDecimalInUSLocale(input: ColumnView): ColumnVector = {
// The US locale is kind of special in that it will remove the `,` and then parse the
// input normally.
withResource(stripFirstAndLastChar(input)) { stripped =>
withResource(Scalar.fromString(",")) { comma =>
withResource(Scalar.fromString("")) { empty =>
stripped.stringReplace(comma, empty)
}
}
}
}

private def sanitizeDecimal(input: ColumnView, options: JSONOptions): ColumnVector = {
assert(options.locale == Locale.US)
withResource(isQuotedString(input)) { isQuoted =>
withResource(sanitizeQuotedDecimalInUSLocale(input)) { quoted =>
isQuoted.ifElse(quoted, input)
}
}
}

private def castStringToFloat(input: ColumnView, dt: DType,
options: JSONOptions): ColumnVector = {
withResource(sanitizeFloats(input, options)) { sanitizedInput =>
CastStrings.toFloat(sanitizedInput, false, dt)
}
}

private def castStringToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = {
// TODO there is a bug here around 0 https://github.com/NVIDIA/spark-rapids/issues/10898
CastStrings.toDecimal(input, false, false, dt.precision, -dt.scale)
}

private def castJsonStringToBool(input: ColumnView): ColumnVector = {
// Sadly there is no good kernel right now to do just this check/conversion
val isTrue = withResource(Scalar.fromString("true")) { trueStr =>
input.equalTo(trueStr)
}
withResource(isTrue) { _ =>
val isFalse = withResource(Scalar.fromString("false")) { falseStr =>
input.equalTo(falseStr)
}
val falseOrNull = withResource(isFalse) { _ =>
withResource(Scalar.fromBool(false)) { falseLit =>
withResource(Scalar.fromNull(DType.BOOL8)) { nul =>
isFalse.ifElse(falseLit, nul)
}
}
}
withResource(falseOrNull) { _ =>
withResource(Scalar.fromBool(true)) { trueLit =>
isTrue.ifElse(trueLit, falseOrNull)
}
}
}
}
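
The commit message migrates this helper to `JSONUtils.castStringsToBooleans` as an intermediate step; the exact JNI signature is not visible in this diff, so the one-argument form below is an assumption, and the final code funnels boolean casts through the `JSONUtils.convertFromStrings` catch-all instead:

import ai.rapids.cudf.{ColumnVector, ColumnView}
import com.nvidia.spark.rapids.jni.JSONUtils

// Assumed intermediate form (the JNI signature is not shown in this diff).
def castJsonStringToBoolJni(input: ColumnView): ColumnVector =
  JSONUtils.castStringsToBooleans(input)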

private def dateFormat(options: JSONOptions): Option[String] =
GpuJsonUtils.optionalDateFormatInRead(options)

@@ -228,7 +75,7 @@ object GpuJsonReadCommon {
}

private def nestedColumnViewMismatchTransform(cv: ColumnView,
dt: DataType): (Option[ColumnView], Seq[AutoCloseable]) = {
dt: DataType): (Option[ColumnView], Seq[AutoCloseable]) = {
// In the future we should be able to convert strings to maps/etc, but for
// now we are working around issues where CUDF is not returning a STRING for nested
// types when asked for it.
@@ -264,43 +111,40 @@
}
}

private def convertStringToDate(input: ColumnView, options: JSONOptions): ColumnVector = {
withResource(JSONUtils.removeQuotes(input, /*nullifyIfNotQuoted*/ true)) { removedQuotes =>
GpuJsonToStructsShim.castJsonStringToDateFromScan(removedQuotes, DType.TIMESTAMP_DAYS,
dateFormat(options))
}
}

private def convertStringToTimestamp(input: ColumnView, options: JSONOptions): ColumnVector = {
withResource(JSONUtils.removeQuotes(input, /*nullifyIfNotQuoted*/ true)) { removedQuotes =>
GpuTextBasedPartitionReader.castStringToTimestamp(removedQuotes, timestampFormat(options),
DType.TIMESTAMP_MICROSECONDS)
}
}

private def convertToDesiredType(inputCv: ColumnVector,
topLevelType: DataType,
options: JSONOptions): ColumnVector = {
ColumnCastUtil.deepTransform(inputCv, Some(topLevelType),
Some(nestedColumnViewMismatchTransform)) {
case (cv, Some(BooleanType)) if cv.getType == DType.STRING =>
castJsonStringToBool(cv)
case (cv, Some(DateType)) if cv.getType == DType.STRING =>
withResource(fixupQuotedStrings(cv)) { fixed =>
GpuJsonToStructsShim.castJsonStringToDateFromScan(fixed, DType.TIMESTAMP_DAYS,
dateFormat(options))
}
convertStringToDate(cv, options)
case (cv, Some(TimestampType)) if cv.getType == DType.STRING =>
withResource(fixupQuotedStrings(cv)) { fixed =>
GpuTextBasedPartitionReader.castStringToTimestamp(fixed, timestampFormat(options),
DType.TIMESTAMP_MICROSECONDS)
}
case (cv, Some(StringType)) if cv.getType == DType.STRING =>
undoKeepQuotes(cv)
case (cv, Some(dt: DecimalType)) if cv.getType == DType.STRING =>
withResource(sanitizeDecimal(cv, options)) { tmp =>
castStringToDecimal(tmp, dt)
}
case (cv, Some(dt)) if (dt == DoubleType || dt == FloatType) && cv.getType == DType.STRING =>
castStringToFloat(cv, GpuColumnVector.getNonNestedRapidsType(dt), options)
case (cv, Some(dt))
if (dt == ByteType || dt == ShortType || dt == IntegerType || dt == LongType ) &&
cv.getType == DType.STRING =>
withResource(sanitizeInts(cv)) { tmp =>
CastStrings.toInteger(tmp, false, GpuColumnVector.getNonNestedRapidsType(dt))
}
convertStringToTimestamp(cv, options)
case (cv, Some(dt)) if cv.getType == DType.STRING =>
GpuCast.doCast(cv, StringType, dt)
// There is an issue with the Schema implementation such that the schema's top level
// is never used when passing down data schema from Java to C++.
// As such, we have to wrap the current column schema `dt` in a struct schema.
val builder = Schema.builder // This is created as a struct schema
populateSchema(dt, "", builder)
JSONUtils.convertFromStrings(cv, builder.build, options.allowNonNumericNumbers,
options.locale == Locale.US)
}
}


/**
* Convert the parsed input table to the desired output types
* @param table the table to start with
Expand All @@ -320,10 +164,28 @@ object GpuJsonReadCommon {
}
}

def cudfJsonOptions(options: JSONOptions): ai.rapids.cudf.JSONOptions =
cudfJsonOptionBuilder(options).build()
/**
* Convert a strings column into date/time types.
* @param inputCv The input column vector
* @param topLevelType The desired output data type
* @param options JSON options for the conversion
* @return The converted column vector
*/
def convertDateTimeType(inputCv: ColumnVector,
topLevelType: DataType,
options: JSONOptions): ColumnVector = {
withResource(new NvtxRange("convertDateTimeType", NvtxColor.RED)) { _ =>
ColumnCastUtil.deepTransform(inputCv, Some(topLevelType),
Some(nestedColumnViewMismatchTransform)) {
case (cv, Some(DateType)) if cv.getType == DType.STRING =>
convertStringToDate(cv, options)
case (cv, Some(TimestampType)) if cv.getType == DType.STRING =>
convertStringToTimestamp(cv, options)
}
}
}
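
Dates and timestamps keep this separate conversion path because the `JSONUtils.castStringsToDates` adoption was reverted earlier in this PR (see the commit message). A minimal usage sketch of the public helper above; the parameter names are illustrative:

import ai.rapids.cudf.ColumnVector
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.rapids.GpuJsonReadCommon
import org.apache.spark.sql.types.DateType

// Illustrative caller: convert a column of extracted JSON strings to DATE.
def toDates(parsedStrings: ColumnVector, jsonOptions: JSONOptions): ColumnVector =
  GpuJsonReadCommon.convertDateTimeType(parsedStrings, DateType, jsonOptions)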

def cudfJsonOptionBuilder(options: JSONOptions): ai.rapids.cudf.JSONOptions.Builder = {
def cudfJsonOptions(options: JSONOptions): ai.rapids.cudf.JSONOptions = {
// This is really ugly, but options.allowUnquotedControlChars is marked as private
// and this is the only way I know to get it without even uglier tricks
@scala.annotation.nowarn("msg=Java enum ALLOW_UNQUOTED_CONTROL_CHARS in " +
@@ -332,16 +194,17 @@ object GpuJsonReadCommon {
.isEnabled(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS)

ai.rapids.cudf.JSONOptions.builder()
.withRecoverWithNull(true)
.withMixedTypesAsStrings(true)
.withNormalizeWhitespace(true)
.withKeepQuotes(true)
.withNormalizeSingleQuotes(options.allowSingleQuotes)
.withStrictValidation(true)
.withLeadingZeros(options.allowNumericLeadingZeros)
.withNonNumericNumbers(options.allowNonNumericNumbers)
.withUnquotedControlChars(allowUnquotedControlChars)
.withCudfPruneSchema(true)
.withExperimental(true)
.withRecoverWithNull(true)
.withMixedTypesAsStrings(true)
.withNormalizeWhitespace(true)
.withKeepQuotes(true)
.withNormalizeSingleQuotes(options.allowSingleQuotes)
.withStrictValidation(true)
.withLeadingZeros(options.allowNumericLeadingZeros)
.withNonNumericNumbers(options.allowNonNumericNumbers)
.withUnquotedControlChars(allowUnquotedControlChars)
.withCudfPruneSchema(true)
.withExperimental(true)
.build()
}
}
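
A short usage sketch of the options builder above: the result feeds cudf's JSON reader together with a Schema. The `Table.readJSON` overload and the file path here are assumptions, not part of this diff:

import java.io.File

import ai.rapids.cudf.{Schema, Table}
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.rapids.GpuJsonReadCommon

// Illustrative caller: read a JSON file with cudf using the derived options.
// Table.readJSON's (Schema, JSONOptions, File) overload is assumed here.
def readWithCudf(schema: Schema, sparkJsonOptions: JSONOptions): Table = {
  val cudfOpts = GpuJsonReadCommon.cudfJsonOptions(sparkJsonOptions)
  Table.readJSON(schema, cudfOpts, new File("/tmp/input.json"))
}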