diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
index 7c6f739b..22f65510 100644
--- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
+++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -20,33 +20,49 @@
 import java.io.{IOException, OutputStream}
 import java.nio.ByteBuffer
 import java.sql.Timestamp
 import java.sql.Date
+import java.util
 import java.util.HashMap
 
 import org.apache.hadoop.fs.Path
+
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 
 import scala.collection.immutable.Map
-
 import org.apache.avro.generic.GenericData.Record
-import org.apache.avro.generic.GenericRecord
+import org.apache.avro.generic.{GenericData, GenericRecord, GenericRecordBuilder}
 import org.apache.avro.{Schema, SchemaBuilder}
 import org.apache.avro.mapred.AvroKey
 import org.apache.avro.mapreduce.AvroKeyOutputFormat
 import org.apache.hadoop.io.NullWritable
 import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID}
-
+import org.apache.log4j.Logger
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.datasources.OutputWriter
 import org.apache.spark.sql.types._
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
 
 // NOTE: This class is instantiated and used on executor side only, no need to be serializable.
 private[avro] class AvroOutputWriter(
     path: String,
     context: TaskAttemptContext,
     schema: StructType,
     recordName: String,
-    recordNamespace: String) extends OutputWriter {
+    recordNamespace: String,
+    forceSchema: String) extends OutputWriter {
+
+  private val logger = Logger.getLogger(this.getClass)
+
+  private val forceAvroSchema = if (forceSchema.contentEquals("")) {
+    None
+  } else {
+    Option(new Schema.Parser().parse(forceSchema))
+  }
+  private lazy val converter = createConverterToAvro(
+    schema, recordName, recordNamespace, forceAvroSchema
+  )
-  private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)
 
   // copy of the old conversion logic after api change in SPARK-19085
   private lazy val internalRowConverter =
     CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row]
@@ -83,6 +99,27 @@ private[avro] class AvroOutputWriter(
 
   override def close(): Unit = recordWriter.close(context)
 
+  private def resolveStructTypeToAvroUnion(schema: Schema, dataType: String): Schema = {
+    val allowedAvroTypes = dataType match {
+      case "boolean" => List(Schema.Type.BOOLEAN)
+      case "integer" => List(Schema.Type.INT)
+      case "long" => List(Schema.Type.LONG)
+      case "float" => List(Schema.Type.FLOAT)
+      case "double" => List(Schema.Type.DOUBLE)
+      case "binary" => List(Schema.Type.BYTES, Schema.Type.FIXED)
+      case "array" => List(Schema.Type.ARRAY)
+      case "map" => List(Schema.Type.MAP)
+      case "string" => List(Schema.Type.STRING, Schema.Type.ENUM)
+      case "struct" => List(Schema.Type.ARRAY, Schema.Type.RECORD)
+      case default => {
+        throw new RuntimeException(
+          s"Cannot map SparkSQL type '$dataType' against Avro schema '$schema'"
+        )
+      }
+    }
+    schema.getTypes.find(allowedAvroTypes contains _.getType).get
+  }
+
   /**
    * This function constructs converter function for a given sparkSQL datatype. This is used in
    * writing Avro records out to disk
@@ -90,21 +127,51 @@ private[avro] class AvroOutputWriter(
   private def createConverterToAvro(
       dataType: DataType,
       structName: String,
-      recordNamespace: String): (Any) => Any = {
+      recordNamespace: String,
+      forceAvroSchema: Option[Schema]): (Any) => Any = {
     dataType match {
       case BinaryType => (item: Any) => item match {
         case null => null
-        case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
+        case bytes: Array[Byte] => if (forceAvroSchema.isDefined) {
+          // Handle mapping from binary => bytes|fixed w/ forceSchema
+          forceAvroSchema.get.getType match {
+            case Schema.Type.BYTES => ByteBuffer.wrap(bytes)
+            case Schema.Type.FIXED => new GenericData.Fixed(
+              forceAvroSchema.get, bytes
+            )
+            case default => bytes
+          }
+        } else {
+          ByteBuffer.wrap(bytes)
+        }
       }
       case ByteType | ShortType | IntegerType | LongType |
-           FloatType | DoubleType | StringType | BooleanType => identity
+           FloatType | DoubleType | BooleanType => identity
+      case StringType => (item: Any) => if (forceAvroSchema.isDefined) {
+        // Handle case when forcing schema where this string should map
+        // to an ENUM
+        forceAvroSchema.get.getType match {
+          case Schema.Type.ENUM => new GenericData.EnumSymbol(
+            forceAvroSchema.get, item.toString
+          )
+          case default => item
+        }
+      } else {
+        item
+      }
       case _: DecimalType =>
         (item: Any) => if (item == null) null else item.toString
       case TimestampType => (item: Any) =>
         if (item == null) null else item.asInstanceOf[Timestamp].getTime
       case DateType => (item: Any) =>
         if (item == null) null else item.asInstanceOf[Date].getTime
       case ArrayType(elementType, _) =>
-        val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
+        val elementConverter = if (forceAvroSchema.isDefined) {
+          createConverterToAvro(elementType, structName,
+            recordNamespace, Option(forceAvroSchema.get.getElementType))
+        } else {
+          createConverterToAvro(elementType, structName,
+            recordNamespace, forceAvroSchema)
+        }
         (item: Any) => {
           if (item == null) {
             null
@@ -117,14 +184,20 @@ private[avro] class AvroOutputWriter(
               targetArray(idx) = elementConverter(sourceArray(idx))
               idx += 1
             }
-            targetArray
+            targetArray.toSeq.asJava
           }
         }
       case MapType(StringType, valueType, _) =>
-        val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
+        val valueConverter = if (forceAvroSchema.isDefined) {
+          createConverterToAvro(valueType, structName,
+            recordNamespace, Option(forceAvroSchema.get.getValueType))
+        } else {
+          createConverterToAvro(valueType, structName,
+            recordNamespace, forceAvroSchema)
+        }
         (item: Any) => {
           if (item == null) {
-            null
+            if (forceAvroSchema.isDefined) new HashMap[String, Any]() else null
           } else {
             val javaMap = new HashMap[String, Any]()
             item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
@@ -135,10 +208,36 @@ private[avro] class AvroOutputWriter(
         }
       case structType: StructType =>
         val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
-        val schema: Schema = SchemaConverters.convertStructToAvro(
-          structType, builder, recordNamespace)
-        val fieldConverters = structType.fields.map(field =>
-          createConverterToAvro(field.dataType, field.name, recordNamespace))
+        val schema: Schema = if (!forceAvroSchema.isDefined) {
+          SchemaConverters.convertStructToAvro(
+            structType, builder, recordNamespace)
+        } else {
+          if (forceAvroSchema.get.getType == Schema.Type.ARRAY) {
+            forceAvroSchema.get.getElementType
+          } else {
+            forceAvroSchema.get
+          }
+        }
+
+        val fieldConverters = structType.fields.map(
+          field => {
+            val fieldConvertSchema = if (forceAvroSchema.isDefined) {
+              val thisFieldSchema = schema.getField(field.name).schema
+              Option(
+                thisFieldSchema.getType match {
+                  case Schema.Type.UNION => {
+                    resolveStructTypeToAvroUnion(thisFieldSchema, field.dataType.typeName)
+                  }
+                  case default => thisFieldSchema
+                }
+              )
+            } else {
+              forceAvroSchema
+            }
+            createConverterToAvro(field.dataType, field.name, recordNamespace, fieldConvertSchema)
+          }
+        )
+
         (item: Any) => {
           if (item == null) {
             null
@@ -150,9 +249,27 @@ private[avro] class AvroOutputWriter(
             while (convertersIterator.hasNext) {
               val converter = convertersIterator.next()
-              record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
+              val fieldValue = rowIterator.next()
+              val fieldName = fieldNamesIterator.next()
+              try {
+                record.put(fieldName, converter(fieldValue))
+              } catch {
+                case ex: NullPointerException => {
+                  // This can happen with forceAvroSchema conversion
+                  if (forceAvroSchema.isDefined) {
+                    logger.debug(s"Trying to write field $fieldName which may be null? $fieldValue")
+                  } else {
+                    // Keep previous behavior when forceAvroSchema is not used
+                    throw ex
+                  }
+                }
+              }
+            }
+            if (forceAvroSchema.isDefined) {
+              new GenericRecordBuilder(record).build()
+            } else {
+              record
             }
-            record
           }
         }
     }
diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
index 3f3cbf07..728c4125 100644
--- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
+++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
@@ -16,24 +16,33 @@
 
 package com.databricks.spark.avro
 
+import org.apache.avro.Schema
 import org.apache.hadoop.mapreduce.TaskAttemptContext
-
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
 import org.apache.spark.sql.types.StructType
 
 private[avro] class AvroOutputWriterFactory(
     schema: StructType,
     recordName: String,
-    recordNamespace: String) extends OutputWriterFactory {
+    recordNamespace: String,
+    forceSchema: String) extends OutputWriterFactory {
 
-  override def getFileExtension(context: TaskAttemptContext): String = {
+  def getFileExtension(context: TaskAttemptContext): String = {
     ".avro"
   }
 
-  override def newInstance(
+  def newInstance(
+      path: String,
+      bucketId: Option[Int],
+      dataSchema: StructType,
+      context: TaskAttemptContext): OutputWriter = {
+    newInstance(path, dataSchema, context)
+  }
+
+  def newInstance(
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
-    new AvroOutputWriter(path, context, schema, recordName, recordNamespace)
+    new AvroOutputWriter(path, context, schema, recordName, recordNamespace, forceSchema)
   }
 }
diff --git a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala
index bfbadd7c..191022ab 100644
--- a/src/main/scala/com/databricks/spark/avro/DefaultSource.scala
+++ b/src/main/scala/com/databricks/spark/avro/DefaultSource.scala
@@ -112,10 +112,17 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
       dataSchema: StructType): OutputWriterFactory = {
     val recordName = options.getOrElse("recordName", "topLevelRecord")
     val recordNamespace = options.getOrElse("recordNamespace", "")
+    val forceAvroSchema = options.getOrElse("forceSchema", "")
     val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
-    val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
+    val outputAvroSchema = if (forceAvroSchema.contentEquals("")) {
+      SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
+    } else {
+      val parser = new Schema.Parser()
+      parser.parse(forceAvroSchema)
+    }
     AvroJob.setOutputKeySchema(job, outputAvroSchema)
+
     val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
     val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
     val COMPRESS_KEY = "mapred.output.compress"
@@ -142,7 +149,7 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
         log.error(s"unsupported compression codec $unknown")
     }
 
-    new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace)
+    new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace, forceAvroSchema)
   }
 
   override def buildReader(
diff --git a/src/test/resources/messy.avro b/src/test/resources/messy.avro
new file mode 100644
index 00000000..3dec11c5
Binary files /dev/null and b/src/test/resources/messy.avro differ
diff --git a/src/test/resources/messy.avsc b/src/test/resources/messy.avsc
new file mode 100644
index 00000000..a0a385c8
--- /dev/null
+++ b/src/test/resources/messy.avsc
@@ -0,0 +1,224 @@
+{
+  "type": "record",
+  "name": "MessyRecord",
+  "namespace": "foo.bar.baz",
+  "doc": "This record encapsulates many failure cases for forceSchema support",
+  "fields": [
+    {
+      "name": "someEnumField",
+      "type": {
+        "type": "enum",
+        "name": "someEnum",
+        "symbols": ["string1", "string2", "string3"]
+      }
+    },
+    {
+      "name": "someString",
+      "type": "string"
+    },
+    {
+      "name": "someNullUnionBool",
+      "type": ["null", "boolean"]
+    },
+    {
+      "name": "someNullUnionInt",
+      "type": ["null", "int"]
+    },
+    {
+      "name": "someNullUnionLong",
+      "type": ["null", "long"]
+    },
+    {
+      "name": "someNullUnionFloat",
+      "type": ["null", "float"]
+    },
+    {
+      "name": "someNullUnionDouble",
+      "type": ["null", "double"]
+    },
+    {
+      "name": "someNullUnionBytes",
+      "type": ["null", "bytes"]
+    },
+    {
+      "name": "someNullUnionString",
+      "type": ["null", "string"]
+    },
+    {
+      "name": "someNullUnionEnum",
+      "type": [
+        "null",
+        {
+          "type": "enum",
+          "name": "anotherEnum",
+          "symbols": ["string1", "string2", "string3"]
+        }
+      ]
+    },
+    {
+      "name": "someNullUnionArray",
+      "type": [
+        "null",
+        {
+          "type": "array",
+          "items": "string"
+        }
+      ]
+    },
+    {
+      "name": "someNullUnionMap",
+      "type": [
+        "null",
+        {
+          "type": "map",
+          "values": "long"
+        }
+      ],
+      "order": "ignore",
+      "default": null
+    },
+    {
+      "name": "someNullUnionFixed",
+      "type": [
+        "null",
+        {
+          "type": "fixed",
+          "name": "foo",
+          "size": 8
+        }
+      ]
+    },
+    {
+      "name": "someRecordField",
+      "doc": "Field contains a record (or null) with crazy nested stuff",
+      "type": [
+        "null",
+        {
+          "type": "record",
+          "name": "someRecord",
+          "fields": [
+            {
+              "name": "someSubRecordField",
+              "type": {
+                "type": "record",
+                "name": "someSubRecords",
+                "fields": [
+                  {
+                    "name": "someSubRecordsField",
+                    "type": {
+                      "type": "array",
+                      "items": {
+                        "type": "record",
+                        "name": "someSubSubRecords",
+                        "fields": [
+                          {
+                            "name": "notQuiteDeepEnoughRecordsField",
+                            "type": {
+                              "type": "array",
+                              "items": {
+                                "type": "record",
+                                "name": "notQuiteDeepEnoughRecord",
+                                "fields": [
+                                  {
+                                    "name": "notQuiteDeepEnoughRecords",
+                                    "type": {
+                                      "type": "array",
+                                      "items": {
+                                        "type": "record",
+                                        "name": "okDeepEnoughRecord",
+                                        "fields": [
+                                          { "name": "someReallyDeepString", "type": "string" },
+                                          { "name": "someReallyDeepInt", "type": "int" },
+                                          { "name": "someReallyDeepBooleanWithDefault",
+                                            "type": "boolean", "default": false }
+                                        ]
+                                      }
+                                    }
+                                  }
+                                ]
+                              }
+                            }
+                          }
+                        ]
+                      }
+                    }
+                  }
+                ]
+              }
+            }
+          ]
+        }
+      ],
+      "default": null
+    },
+    {
+      "name": "someIntWithDefault",
+      "type": "int",
+      "default": 0
+    },
+    {
+      "name": "someFloatWithDefault",
+      "type": "float",
+      "default": 3.14159
+    },
+    {
+      "name": "someLongWithDefault",
+      "type": "long",
+      "default": 9999999999999
+    },
+    {
+      "name": "someDoubleWithDefault",
+      "type": "double",
+      "default": 3.141592654
+    },
+    {
+      "name": "someStringMap",
+      "doc": "This will make sure order: ignore is preserved, null values are defaulted, etc.",
+      "type": {
+        "type": "map",
+        "values": "string"
+      },
+      "order": "ignore",
+      "default": {}
+    }
+  ]
+}
diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
index 4843ad46..c450d47e 100644
--- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
+++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -34,9 +34,13 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.types._
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 
 class AvroSuite extends FunSuite with BeforeAndAfterAll {
   val episodesFile = "src/test/resources/episodes.avro"
+  val messyFile = "src/test/resources/messy.avro"
+  val messySchemaFile = "src/test/resources/messy.avsc"
   val testFile = "src/test/resources/test.avro"
 
   private var spark: SparkSession = _
@@ -642,6 +646,85 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  // Read an Avro file, write it with the converted schema, read that back, then write it again
+  // with the original schema supplied via forceSchema. Finally, check that the data does not change.
+  test("test read original schema, write with converted schema, read, write with original schema") {
+    // Test if load works as expected
+    TestUtils.withTempDir { tempDir =>
+      // Helpers to dig into the nested someRecordField structure
+      def getSomeSubRecords(df: Dataset[Row]): RDD[GenericRowWithSchema] = {
+        df.rdd.flatMap {
+          row => {
+            val topStruct = row.getAs[GenericRowWithSchema]("someRecordField")
+            if (topStruct != null) {
+              val subStruct = topStruct.getAs[GenericRowWithSchema](
+                topStruct.fieldIndex("someSubRecordField")
+              )
+              subStruct.getSeq(subStruct.fieldIndex("someSubRecordsField"))
+            } else {
+              Array[GenericRowWithSchema]()
+            }
+          }
+        }
+      }
+
+      def getNotQuiteDeepEnoughRecords(df: Dataset[Row]): RDD[GenericRowWithSchema] = {
+        getSomeSubRecords(df).flatMap(
+          row => row.getSeq(row.fieldIndex("notQuiteDeepEnoughRecordsField"))
+        )
+      }
+
+      def getSumOfReallyDeepInts(df: Dataset[Row]): Long = {
+        getNotQuiteDeepEnoughRecords(df).flatMap {
+          row => {
+            val topLevel: Seq[GenericRowWithSchema] = row
+              .getSeq(row.fieldIndex("notQuiteDeepEnoughRecords"))
+            topLevel.map(sub => sub.getInt(sub.fieldIndex("someReallyDeepInt")))
+          }
+        }.reduce((x: Int, y: Int) => x + y)
+      }
+
+      val forceSchema = scala.io.Source
+        .fromFile(messySchemaFile)
+        .getLines()
+        .mkString("\n")
+
+      val df = spark.read.avro(messyFile)
+      assert(df.count == 10)
+
+      val tempSaveDir1 = s"$tempDir/save1/"
+      val tempSaveDir2 = s"$tempDir/save2/"
+
+      df.write.avro(tempSaveDir1)
+
+      val newDf = spark.read.avro(tempSaveDir1)
+      assert(newDf.count == 10)
+
+      // number of someSubRecords in dataset
+      val numSomeSubRecords1 = getSomeSubRecords(newDf).collect.length
+      val notQuiteDeepEnoughRecords1 = getNotQuiteDeepEnoughRecords(newDf).collect.length
+      val sumOfReallyDeepInts1 = getSumOfReallyDeepInts(newDf)
+      assert(numSomeSubRecords1 == 33)
+      assert(notQuiteDeepEnoughRecords1 == 375)
+      assert(sumOfReallyDeepInts1 == -367812589)
+
+      newDf.write
+        .option("forceSchema", forceSchema)
+        .avro(tempSaveDir2)
+
+      val newerDf = spark.read.avro(tempSaveDir2)
+      assert(newerDf.count == 10)
+
+      // number of someSubRecords in dataset, recomputed on the forced-schema output
+      val numSomeSubRecords2 = getSomeSubRecords(newerDf).collect.length
+      val notQuiteDeepEnoughRecords2 = getNotQuiteDeepEnoughRecords(newerDf).collect.length
+      val sumOfReallyDeepInts2 = getSumOfReallyDeepInts(newerDf)
+      assert(numSomeSubRecords2 == 33)
+      assert(notQuiteDeepEnoughRecords2 == 375)
+      assert(sumOfReallyDeepInts2 == -367812589)
+    }
+  }
+
   test("read avro with user defined schema: read partial columns") {
     val partialColumns = StructType(Seq(
       StructField("string", StringType, false),
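
Usage note for reviewers: a minimal sketch of how the new forceSchema write option would be exercised end to end, mirroring the round-trip test above. It relies on the com.databricks.spark.avro._ implicits the suite already uses; the driver object name and output path are illustrative placeholders, not part of this patch.

import scala.io.Source

import org.apache.spark.sql.SparkSession
import com.databricks.spark.avro._

// Hypothetical driver, not part of the patch: round-trips an Avro file while keeping
// its original schema (enums stay enums, fixed stays fixed) by passing the .avsc text
// through the new "forceSchema" write option.
object ForceSchemaRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("forceSchema-roundtrip").getOrCreate()

    // Read records whose writer schema contains unions, enums and fixed types.
    val df = spark.read.avro("src/test/resources/messy.avro")

    // Hand the writer the original schema text; without this option the output
    // would carry the Spark-converted schema instead.
    val originalSchema = Source.fromFile("src/test/resources/messy.avsc").getLines().mkString("\n")

    df.write
      .option("forceSchema", originalSchema)
      .avro("/tmp/messy-roundtrip")  // illustrative output path

    spark.stop()
  }
}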