diff --git a/src/main/scala/com/databricks/spark/avro/RddUtils.scala b/src/main/scala/com/databricks/spark/avro/RddUtils.scala
new file mode 100644
index 00000000..12495bc1
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/avro/RddUtils.scala
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2014 Databricks
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.databricks.spark.avro
+
+import SchemaConverters._
+import scala.util.Try
+import org.apache.avro.generic.GenericRecord
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.rdd.RDD
+
+/**
+ * [[RDD]] implicits.
+ */
+object RddUtils {
+  /**
+   * Extensions to [[RDD]]s of [[GenericRecord]]s.
+   *
+   * @param rdd the [[RDD]] to decorate with additional functionality.
+   */
+  implicit class RddToDataFrame(val rdd: RDD[GenericRecord]) {
+    /**
+     * Convert a [[RDD]] of [[GenericRecord]]s to a [[DataFrame]].
+     *
+     * @return the [[DataFrame]]
+     */
+    def toDF(): DataFrame = {
+      val spark = SparkSession
+        .builder
+        .config(rdd.sparkContext.getConf)
+        .getOrCreate()
+
+      val avroSchema = rdd.take(1)(0).getSchema
+      val dataFrameSchema = toSqlType(avroSchema).dataType.asInstanceOf[StructType]
+      val converter = createConverterToSQL(avroSchema, dataFrameSchema)
+
+      val rowRdd = rdd.flatMap { record =>
+        Try(converter(record).asInstanceOf[Row]).toOption
+      }
+
+      spark.createDataFrame(rowRdd, dataFrameSchema)
+    }
+  }
+}
diff --git a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
index 1b5d07aa..b3a0c97f 100644
--- a/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
+++ b/src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -30,8 +30,10 @@
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.file.DataFileWriter
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
+import org.apache.avro.mapred.{AvroInputFormat, AvroWrapper}
 import org.apache.commons.io.FileUtils
+import org.apache.hadoop.io.NullWritable
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.types._
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -48,6 +50,7 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
       .master("local[2]")
       .appName("AvroSuite")
       .config("spark.sql.files.maxPartitionBytes", 1024)
+      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
       .getOrCreate()
   }
 
@@ -59,6 +62,22 @@
     }
   }
 
+  test("converting rdd to dataframe") {
+    val rdd = spark.sparkContext.hadoopFile[
+      AvroWrapper[GenericRecord],
+      NullWritable,
+      AvroInputFormat[GenericRecord]
+    ](testFile).map(_._1.datum)
+
+    import RddUtils.RddToDataFrame
+
+    val df1 = rdd.toDF
+    val df2 = spark.read.avro(testFile)
+
+    assert(df1.schema.simpleString === df2.schema.simpleString)
+    assert(df1.orderBy("string").collect === df2.orderBy("string").collect)
+  }
+
   test("reading and writing partitioned data") {
     val df = spark.read.avro(episodesFile)
     val fields = List("title", "air_date", "doctor")