From c19f01af60458bf6ffd24f2505581d924fffffd2 Mon Sep 17 00:00:00 2001
From: Nihed MBAREK
Date: Thu, 16 Feb 2017 02:47:30 -0800
Subject: [PATCH] Add support for DateType

Based on https://github.com/databricks/spark-avro/issues/67, this pull
request adds write support for Spark SQL's DateType. Date values are
converted with java.sql.Date#getTime, so they are stored in the Avro
output as longs holding epoch milliseconds, and the generated Avro
schema maps DateType to long.

Author: Nihed MBAREK
Author: vlyubin
Author: nihed

Closes #124 from nihed/master.
---
 .../spark/avro/AvroOutputWriter.scala         |  3 ++
 .../spark/avro/SchemaConverters.scala         |  2 ++
 .../com/databricks/spark/avro/AvroSuite.scala | 35 ++++++++++++++++----
 .../spark/avro/AvroWriteBenchmark.scala       |  7 ++--
 4 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
index bc71564e..297c39d6 100644
--- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
+++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -19,6 +19,7 @@ package com.databricks.spark.avro
 import java.io.{IOException, OutputStream}
 import java.nio.ByteBuffer
 import java.sql.Timestamp
+import java.sql.Date
 import java.util.HashMap
 
 import org.apache.hadoop.fs.Path
@@ -90,6 +91,8 @@ private[avro] class AvroOutputWriter(
     case _: DecimalType => (item: Any) => if (item == null) null else item.toString
     case TimestampType => (item: Any) =>
       if (item == null) null else item.asInstanceOf[Timestamp].getTime
+    case DateType => (item: Any) =>
+      if (item == null) null else item.asInstanceOf[Date].getTime
     case ArrayType(elementType, _) =>
       val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
       (item: Any) => {
diff --git a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
index aa634d4c..7f8e20f4 100644
--- a/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
+++ b/src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
@@ -328,6 +328,7 @@ object SchemaConverters {
       case BinaryType => schemaBuilder.bytesType()
       case BooleanType => schemaBuilder.booleanType()
       case TimestampType => schemaBuilder.longType()
+      case DateType => schemaBuilder.longType()
 
       case ArrayType(elementType, _) =>
         val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
@@ -371,6 +372,7 @@ object SchemaConverters {
       case BinaryType => newFieldBuilder.bytesType()
       case BooleanType => newFieldBuilder.booleanType()
       case TimestampType => newFieldBuilder.longType()
+      case DateType => newFieldBuilder.longType()
 
       case ArrayType(elementType, _) =>
         val builder = getSchemaBuilder(dataType.asInstanceOf[ArrayType].containsNull)
diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
index 1b5d07aa..4843ad46 100644
--- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
+++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -19,22 +19,21 @@ package com.databricks.spark.avro
 import java.io._
 import java.nio.ByteBuffer
 import java.nio.file.Files
-import java.sql.Timestamp
-import java.util.UUID
+import java.sql.{Date, Timestamp}
+import java.util.{TimeZone, UUID}
 
 import scala.collection.JavaConversions._
-
-import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
 import org.apache.avro.Schema
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.file.DataFileWriter
-import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
-
-import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.SparkContext
+import org.apache.spark.sql._
 import org.apache.spark.sql.types._
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import com.databricks.spark.avro.SchemaConverters.IncompatibleSchemaException
 
 class AvroSuite extends FunSuite with BeforeAndAfterAll {
   val episodesFile = "src/test/resources/episodes.avro"
@@ -297,6 +296,28 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("Date field type") {
+    TestUtils.withTempDir { dir =>
+      val schema = StructType(Seq(
+        StructField("float", FloatType, true),
+        StructField("date", DateType, true)
+      ))
+      TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
+      val rdd = spark.sparkContext.parallelize(Seq(
+        Row(1f, null),
+        Row(2f, new Date(1451948400000L)),
+        Row(3f, new Date(1460066400500L))
+      ))
+      val df = spark.createDataFrame(rdd, schema)
+      df.write.avro(dir.toString)
+      assert(spark.read.avro(dir.toString).count == rdd.count)
+      // Catalyst truncates DateType to whole days, so each value reads back
+      // as the epoch milliseconds of its date at midnight UTC.
+      assert(spark.read.avro(dir.toString).select("date").collect().map(_(0)).toSet ==
+        Array(null, 1451865600000L, 1459987200000L).toSet)
+    }
+  }
+
   test("Array data types") {
     TestUtils.withTempDir { dir =>
       val testSchema = StructType(Seq(
diff --git a/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala b/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala
index b36438c1..2ccc456b 100644
--- a/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala
+++ b/src/test/scala/com/databricks/spark/avro/AvroWriteBenchmark.scala
@@ -16,6 +16,7 @@
 
 package com.databricks.spark.avro
 
+import java.sql.Date
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConversions._
@@ -23,8 +24,7 @@ import scala.util.Random
 
 import com.google.common.io.Files
 import org.apache.commons.io.FileUtils
-
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql._
 import org.apache.spark.sql.types._
 
 /**
@@ -40,6 +40,7 @@ object AvroWriteBenchmark {
   val testSchema = StructType(Seq(
     StructField("StringField", StringType, false),
     StructField("IntField", IntegerType, true),
+    StructField("DateField", DateType, true),
     StructField("DoubleField", DoubleType, false),
     StructField("DecimalField", DecimalType(10, 10), true),
     StructField("ArrayField", ArrayType(BooleanType), false),
@@ -48,7 +49,7 @@
 
   private def generateRandomRow(): Row = {
     val rand = new Random()
-    Row(rand.nextString(defaultSize), rand.nextInt(), rand.nextDouble(), rand.nextDouble(),
+    Row(rand.nextString(defaultSize), rand.nextInt(), new Date(rand.nextLong()), rand.nextDouble(), rand.nextDouble(),
       TestUtils.generateRandomArray(rand, defaultSize).toSeq,
       TestUtils.generateRandomMap(rand, defaultSize).toMap, Row(rand.nextInt()))
   }
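
Usage sketch (reviewer note, not part of the patch): a minimal example of the
behavior this change enables, using the same write/read path as the test above.
It assumes the reader/writer implicits from the com.databricks.spark.avro
package object are in scope; the object name and the /tmp output path are
illustrative only.

import java.sql.Date

import com.databricks.spark.avro._
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._

object DateTypeRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("DateTypeRoundTrip")
      .getOrCreate()

    val schema = StructType(Seq(StructField("date", DateType, true)))
    val rows = spark.sparkContext.parallelize(Seq(
      Row(new Date(1451865600000L)), // 2016-01-04, midnight UTC
      Row(null)))                    // nulls are preserved by the converter
    val df = spark.createDataFrame(rows, schema)

    // With this patch applied, a DateType column is written as an Avro long
    // holding java.sql.Date#getTime (epoch milliseconds).
    df.write.avro("/tmp/dates.avro")

    // The Avro schema records a plain long, so the column reads back as
    // LongType rather than DateType.
    spark.read.avro("/tmp/dates.avro").printSchema()

    spark.stop()
  }
}

Because no Avro logical type is attached to the long, downstream readers see
epoch-millisecond values and must convert back to dates themselves; that is
the main trade-off of this encoding to weigh in review.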