UDAFs for Datasets (#1)
* sbt infra

* update .gitignore

* TDigestUDT.scala

* A working UDAF for TDigest

* ignore null entries

* bug repro

* workaround SPARK-21277 with flattened unsafe storage

* remove silex dep

* complete draft of DataFrame UDAF suite

* rc1
erikerlandson authored Jul 1, 2017
1 parent 34689ea commit 41f6da1
Showing 7 changed files with 469 additions and 0 deletions.
15 changes: 15 additions & 0 deletions .gitignore
@@ -1,2 +1,17 @@
*.class
*.log

# sbt specific
.cache
.history
.lib/
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/

# Scala-IDE specific
.scala_dependencies
.worksheet
59 changes: 59 additions & 0 deletions build.sbt
@@ -0,0 +1,59 @@
name := "isarn-sketches-spark"

organization := "org.isarnproject"

bintrayOrganization := Some("isarn")

version := "0.1.0.rc1"

scalaVersion := "2.11.8"

crossScalaVersions := Seq("2.10.6", "2.11.8")

def commonSettings = Seq(
  libraryDependencies ++= Seq(
    "org.isarnproject" %% "isarn-sketches" % "0.1.0",
    "org.apache.spark" %% "spark-core" % "2.1.0" % Provided,
    "org.apache.spark" %% "spark-sql" % "2.1.0" % Provided,
    "org.apache.spark" %% "spark-mllib" % "2.1.0" % Provided,
    "org.isarnproject" %% "isarn-scalatest" % "0.0.1" % Test,
    "org.scalatest" %% "scalatest" % "2.2.4" % Test,
    "org.apache.commons" % "commons-math3" % "3.6.1" % Test),
  initialCommands in console := """
    |import org.apache.spark.SparkConf
    |import org.apache.spark.SparkContext
    |import org.apache.spark.sql.SparkSession
    |import org.apache.spark.SparkContext._
    |import org.apache.spark.rdd.RDD
    |import org.apache.spark.ml.linalg.Vectors
    |import org.isarnproject.sketches.TDigest
    |import org.isarnproject.sketches.udaf._
    |import org.apache.spark.isarnproject.sketches.udt._
    |val initialConf = new SparkConf().setAppName("repl").set("spark.serializer", "org.apache.spark.serializer.KryoSerializer").set("spark.kryoserializer.buffer", "16mb")
    |val spark = SparkSession.builder.config(initialConf).master("local[2]").getOrCreate()
    |import spark._
    |val sc = spark.sparkContext
    |import org.apache.log4j.{Logger, ConsoleAppender, Level}
    |Logger.getRootLogger().getAppender("console").asInstanceOf[ConsoleAppender].setThreshold(Level.WARN)
  """.stripMargin,
  cleanupCommands in console := "spark.stop"
)

seq(commonSettings:_*)

licenses += ("Apache-2.0", url("http://opensource.org/licenses/Apache-2.0"))

scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature")

scalacOptions in (Compile, doc) ++= Seq("-doc-root-content", baseDirectory.value+"/root-doc.txt")

site.settings

site.includeScaladoc()

// Re-enable if/when we want to support gh-pages w/ jekyll
// site.jekyllSupport()

ghpages.settings

git.remoteRepo := "[email protected]:isarn/isarn-sketches-spark.git"
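
Note that the Spark dependencies above are marked Provided, so a consuming project supplies its own Spark at a matching version. The fragment below is a hypothetical downstream build.sbt sketch, not part of this commit; the coordinates are taken from the organization, name, and version settings declared above, and the consumer would also need a resolver for the isarn Bintray repository.

// hypothetical consumer build.sbt fragment (coordinates from the settings above)
libraryDependencies ++= Seq(
  "org.isarnproject" %% "isarn-sketches-spark" % "0.1.0.rc1",
  "org.apache.spark" %% "spark-sql" % "2.1.0" % Provided  // supplied by the Spark runtime
)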
1 change: 1 addition & 0 deletions project/build.properties
@@ -0,0 +1 @@
sbt.version=0.13.11
20 changes: 20 additions & 0 deletions project/plugins.sbt
@@ -0,0 +1,20 @@
resolvers += Resolver.url(
  "bintray-sbt-plugin-releases",
  url("http://dl.bintray.com/content/sbt/sbt-plugin-releases"))(
  Resolver.ivyStylePatterns)

resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"

resolvers += "jgit-repo" at "http://download.eclipse.org/jgit/maven"

addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")

addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.5.4")

addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0")

// scoverage and coveralls deps are at old versions to avoid a bug in the current versions
// update these when this fix is released: https://github.com/scoverage/sbt-coveralls/issues/73
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.0.4")

addSbtPlugin("org.scoverage" % "sbt-coveralls" % "1.0.0")
136 changes: 136 additions & 0 deletions TDigestUDT.scala
@@ -0,0 +1,136 @@
/*
Copyright 2017 Erik Erlandson
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package org.apache.spark.isarnproject.sketches.udt

import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap

@SQLUserDefinedType(udt = classOf[TDigestUDT])
case class TDigestSQL(tdigest: TDigest)

class TDigestUDT extends UserDefinedType[TDigestSQL] {
  def userClass: Class[TDigestSQL] = classOf[TDigestSQL]

  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("nclusters", IntegerType, false) ::
    StructField("clustX", ArrayType(DoubleType, false), false) ::
    StructField("clustM", ArrayType(DoubleType, false), false) ::
    Nil)

  def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

  def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

  private[sketches] def serializeTD(td: TDigest): InternalRow = {
    val TDigest(delta, maxDiscrete, nclusters, clusters) = td
    val row = new GenericInternalRow(5)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.setInt(2, nclusters)
    val clustX = clusters.keys.toArray
    val clustM = clusters.values.toArray
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  private[sketches] def deserializeTD(datum: Any): TDigest = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val nclusters = row.getInt(2)
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      val clusters = clustX.zip(clustM)
        .foldLeft(TDigestMap.empty) { case (td, e) => td + e }
      TDigest(delta, maxDiscrete, nclusters, clusters)
    case u => throw new Exception(s"failed to deserialize: $u")
  }
}

case object TDigestUDT extends TDigestUDT

@SQLUserDefinedType(udt = classOf[TDigestArrayUDT])
case class TDigestArraySQL(tdigests: Array[TDigest])

class TDigestArrayUDT extends UserDefinedType[TDigestArraySQL] {
  def userClass: Class[TDigestArraySQL] = classOf[TDigestArraySQL]

  // Spark seems to have trouble with ArrayType data that isn't
  // serialized using UnsafeArrayData (SPARK-21277), so my workaround
  // is to store all the cluster information flattened into single Unsafe arrays.
  // To deserialize, I unpack the slices.
  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("clusterS", ArrayType(IntegerType, false), false) ::
    StructField("clusterX", ArrayType(DoubleType, false), false) ::
    StructField("clusterM", ArrayType(DoubleType, false), false) ::
    Nil)

  def serialize(tdasql: TDigestArraySQL): Any = {
    val row = new GenericInternalRow(5)
    val tda: Array[TDigest] = tdasql.tdigests
    val delta = if (tda.isEmpty) 0.0 else tda.head.delta
    val maxDiscrete = if (tda.isEmpty) 0 else tda.head.maxDiscrete
    val clustS = tda.map(_.nclusters)
    val clustX = tda.flatMap(_.clusters.keys)
    val clustM = tda.flatMap(_.clusters.values)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.update(2, UnsafeArrayData.fromPrimitiveArray(clustS))
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  def deserialize(datum: Any): TDigestArraySQL = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val clustS = row.getArray(2).toIntArray()
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      var beg = 0
      val tda = clustS.map { nclusters =>
        val x = clustX.slice(beg, beg + nclusters)
        val m = clustM.slice(beg, beg + nclusters)
        val clusters = x.zip(m).foldLeft(TDigestMap.empty) { case (td, e) => td + e }
        val td = TDigest(delta, maxDiscrete, nclusters, clusters)
        beg += nclusters
        td
      }
      TDigestArraySQL(tda)
    case u => throw new Exception(s"failed to deserialize: $u")
  }
}

case object TDigestArrayUDT extends TDigestArrayUDT

// VectorUDT is private[spark], but I can expose what I need this way:
object TDigestUDTInfra {
  private val udtML = new org.apache.spark.ml.linalg.VectorUDT
  def udtVectorML: DataType = udtML

  private val udtMLLib = new org.apache.spark.mllib.linalg.VectorUDT
  def udtVectorMLLib: DataType = udtMLLib
}
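
The snippet below is a small REPL sketch, not part of this commit, that exercises the public serialize/deserialize round trip for both UDTs in the console configured by build.sbt. It assumes only the constructors, field accessors, and TDigestMap operations already used in the file above; the sample cluster values are arbitrary.

import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap
import org.apache.spark.isarnproject.sketches.udt._

// assemble a tiny digest using only the TDigest constructor and TDigestMap operations seen above
val clusters = Seq((1.0, 2.0), (3.0, 1.0), (5.0, 4.0))
  .foldLeft(TDigestMap.empty) { case (m, e) => m + e }
val td = TDigest(0.5, 0, 3, clusters)  // delta, maxDiscrete, nclusters, clusters

// round-trip through the InternalRow representation produced by TDigestUDT
val row = TDigestUDT.serialize(TDigestSQL(td))
val td2 = TDigestUDT.deserialize(row).tdigest
assert(td2.delta == td.delta)
assert(td2.nclusters == td.nclusters)
assert(td2.clusters.keys.toArray.sameElements(td.clusters.keys.toArray))

// the array UDT flattens all clusters into single unsafe arrays and re-slices them on read
val arow = TDigestArrayUDT.serialize(TDigestArraySQL(Array(td, td)))
val tda = TDigestArrayUDT.deserialize(arow).tdigests
assert(tda.length == 2 && tda.forall(_.nclusters == td.nclusters))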