Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-27388][SQL] encoder for objects defined by properties (ie. Avro) #24299

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.{Expression, _}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}


/**
* A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s
* for classes whose fields are entirely defined by constructor params but should not be
Expand Down Expand Up @@ -273,6 +272,7 @@ object ScalaReflection extends ScalaReflection {
// We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
// to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
case t if isSubtype(t, localTypeOf[Seq[_]]) ||
isSubtype(t, localTypeOf[java.util.List[_]]) ||
isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
Expand Down Expand Up @@ -312,6 +312,37 @@ object ScalaReflection extends ScalaReflection {
mirror.runtimeClass(t.typeSymbol.asClass)
)

case t if isSubtype(t, localTypeOf[java.util.Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t

val classNameForKey = getClassNameFromType(keyType)
val classNameForValue = getClassNameFromType(valueType)

val newTypePath = walkedTypePath.recordMap(classNameForKey, classNameForValue)

val keyData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(keyType, p, newTypePath),
MapKeys(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(valueType, p, newTypePath),
MapValues(path)),
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
ArrayBasedMapData.getClass,
ObjectType(classOf[java.util.Map[_, _]]),
"toJavaMap",
keyData :: valueData :: Nil,
returnNullable = false)

case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
getConstructor().newInstance()
Expand All @@ -330,9 +361,13 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.getClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)

case t if t.erasure.typeSymbol.asClass.isJavaEnum =>
createDeserializerForTypesSupportValueOf(
createDeserializerForString(path, returnNullable = false),
getClassFromType(tpe))

case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)

val cls = getClassFromType(tpe)

val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
Expand Down Expand Up @@ -365,6 +400,32 @@ object ScalaReflection extends ScalaReflection {
expressions.Literal.create(null, ObjectType(cls)),
newInstance
)
case t =>
val props = getObjectProperties(t)
val cls = getClassFromType(tpe)

val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)

val setters = props.map { case (fieldName, getter, setter, fieldType) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
val clsName = getClassNameFromType(fieldType)
val newTypePath = walkedTypePath.recordField(clsName, fieldName)

val newPath = expressionWithNullSafety(
deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
newTypePath),
nullable = nullable,
newTypePath)

(setter, newPath)

}.toMap

val result = InitializeJavaBean(newInstance, setters)

expressions.If(IsNull(path),
expressions.Literal.create(null, ObjectType(cls)),
result)
}
}

Expand Down Expand Up @@ -437,15 +498,17 @@ object ScalaReflection extends ScalaReflection {
// Since List[_] also belongs to localTypeOf[Product], we put this case before
// "case t if definedByConstructorParams(t)" to make sure it will match to the
// case "localTypeOf[Seq[_]]"
case t if isSubtype(t, localTypeOf[Seq[_]]) =>
case t if isSubtype(t, localTypeOf[Seq[_]]) ||
isSubtype(t, localTypeOf[java.util.List[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)

case t if isSubtype(t, localTypeOf[Array[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
toCatalystArray(inputObject, elementType)

case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
case t if isSubtype(t, localTypeOf[Map[_, _]]) ||
isSubtype(t, localTypeOf[java.util.Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val keyClsName = getClassNameFromType(keyType)
val valueClsName = getClassNameFromType(valueType)
Expand Down Expand Up @@ -525,6 +588,10 @@ object ScalaReflection extends ScalaReflection {
val udtClass = udt.getClass
createSerializerForUserDefinedType(inputObject, udt, udtClass)

case t if t.erasure.typeSymbol.asClass.isJavaEnum =>
createSerializerForString(
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
throw new UnsupportedOperationException(
Expand All @@ -538,10 +605,6 @@ object ScalaReflection extends ScalaReflection {
"cannot be used as field name\n" + walkedTypePath)
}

// SPARK-26730: inputObject won't be null thanks to the If guard below, and KnownNotNull
// is necessary here. For a nullable nested inputObject with a struct data type,
// e.g. StructType(IntegerType, StringType), the IntegerType field would otherwise be
// reported as nullable=true without KnownNotNull, which is not what we expect.
val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
Expand All @@ -550,9 +613,31 @@ object ScalaReflection extends ScalaReflection {
}
createSerializerForObject(inputObject, fields)

case _ =>
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath)
case t =>
val props = getObjectProperties(t)
if (props.isEmpty) {
throw new
UnsupportedOperationException(s"No Encoder found for $tpe\n" + walkedTypePath)
} else {
if (seenTypeSet.contains(t)) {
throw new UnsupportedOperationException(
s"cannot have circular references in class, but got the circular reference of class $t")
}
val fields = props.map { case (fieldName, getter, setter, fieldType) =>
if (javaKeywords.contains(fieldName)) {
throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
"cannot be used as field name\n" + walkedTypePath)
}

val fieldValue = Invoke(KnownNotNull(inputObject), getter, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
val newPath = walkedTypePath.recordField(clsName, fieldName)
(fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
}
createSerializerForObject(inputObject, fields)

}
}
}

Expand Down Expand Up @@ -645,11 +730,13 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if isSubtype(t, localTypeOf[Seq[_]]) =>
case t if isSubtype(t, localTypeOf[Seq[_]]) ||
isSubtype(t, localTypeOf[java.util.List[_]]) =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
case t if isSubtype(t, localTypeOf[Map[_, _]]) ||
isSubtype(t, localTypeOf[java.util.Map[_, _]]) =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
Expand Down Expand Up @@ -689,15 +776,26 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false)
case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
case t if t.erasure.typeSymbol.asClass.isJavaEnum => Schema(StringType, true)
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
Schema(StructType(
params.map { case (fieldName, fieldType) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
StructField(fieldName, dataType, nullable)
}), nullable = true)
case other =>
throw new UnsupportedOperationException(s"Schema for type $other is not supported")
case t =>
val props = getObjectProperties(t)
if (props.isEmpty) {
throw new UnsupportedOperationException(s"Schema for type $t is not supported")
} else {
Schema(StructType(
props.map { case (fieldName, getter, setter, fieldType) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
StructField(fieldName, dataType, nullable)
}), nullable = true)

}
}
}

Expand Down Expand Up @@ -871,6 +969,47 @@ trait ScalaReflection extends Logging {
tpe.dealias.erasure.typeSymbol.asClass.fullName
}

/**
 * Finds the properties of `tpe` that are exposed through a matching getter/setter pair.
 *
 * A getter is any method with exactly one, empty, parameter list; a setter is any method
 * returning `Unit` with exactly one parameter list holding a single parameter. A getter and
 * a setter are paired when their names are equal once a leading "get"/"set" prefix is
 * dropped and the getter's return type equals the setter's parameter type.
 *
 * @param tpe the type whose members are inspected
 * @return one `(propertyName, getterName, setterName, propertyType)` tuple per matching
 *         pair, where `propertyName` is the method name with the leading "get"/"set"
 *         prefix (if any) removed
 */
def getObjectProperties(tpe: `Type`): Seq[(String, String, String, Type)] = {
  // Drops a leading "get"/"set" so that a getter and its setter share one property name.
  def propertyName(name: String): String =
    if (name.startsWith("get") || name.startsWith("set")) name.substring(3) else name

  // (methodName, returnType) for every method taking a single empty parameter list.
  val accessors = tpe.members.collect {
    case sym if sym.isMethod &&
        sym.asMethod.paramLists.size == 1 &&
        sym.asMethod.paramLists.head.isEmpty =>
      (sym.name.decodedName.toString, sym.asMethod.returnType)
  }

  // (methodName, parameterType) for every Unit method taking exactly one parameter.
  val mutators = tpe.members.collect {
    case sym if sym.isMethod &&
        sym.asMethod.returnType =:= typeOf[Unit] &&
        sym.asMethod.paramLists.size == 1 &&
        sym.asMethod.paramLists.head.size == 1 =>
      (sym.name.decodedName.toString, sym.asMethod.paramLists.head.head.typeSignature)
  }

  // Pair every getter with every setter that agrees on property name and type.
  accessors.flatMap { case (getterName, propertyType) =>
    mutators.collect {
      case (setterName, argumentType)
          if propertyName(getterName) == propertyName(setterName) &&
            propertyType =:= argumentType =>
        (propertyName(getterName), getterName, setterName, propertyType)
    }
  }.toSeq
}

/**
* Returns the parameter names and types for the primary constructor of this type.
*
Expand Down Expand Up @@ -924,3 +1063,4 @@ trait ScalaReflection extends Logging {
}

}

Loading