Commit 41f6da1 (parent 34689ea)
* sbt infra
* update .gitignore
* TDigestUDT.scala
* A working UDAF for TDigest
* ignore null entries
* bug repro
* workaround SPARK-21277 with flattened unsafe storage
* remove silex dep
* complete draft of DataFrame UDAF suite
* rc1
Showing 7 changed files with 469 additions and 0 deletions.
.gitignore
@@ -1,2 +1,17 @@
*.class
*.log

# sbt specific
.cache
.history
.lib/
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/

# Scala-IDE specific
.scala_dependencies
.worksheet
build.sbt
@@ -0,0 +1,59 @@
name := "isarn-sketches-spark" | ||
|
||
organization := "org.isarnproject" | ||
|
||
bintrayOrganization := Some("isarn") | ||
|
||
version := "0.1.0.rc1" | ||
|
||
scalaVersion := "2.11.8" | ||
|
||
crossScalaVersions := Seq("2.10.6", "2.11.8") | ||
|
||
def commonSettings = Seq( | ||
libraryDependencies ++= Seq( | ||
"org.isarnproject" %% "isarn-sketches" % "0.1.0", | ||
"org.apache.spark" %% "spark-core" % "2.1.0" % Provided, | ||
"org.apache.spark" %% "spark-sql" % "2.1.0" % Provided, | ||
"org.apache.spark" %% "spark-mllib" % "2.1.0" % Provided, | ||
"org.isarnproject" %% "isarn-scalatest" % "0.0.1" % Test, | ||
"org.scalatest" %% "scalatest" % "2.2.4" % Test, | ||
"org.apache.commons" % "commons-math3" % "3.6.1" % Test), | ||
initialCommands in console := """ | ||
|import org.apache.spark.SparkConf | ||
|import org.apache.spark.SparkContext | ||
|import org.apache.spark.sql.SparkSession | ||
|import org.apache.spark.SparkContext._ | ||
|import org.apache.spark.rdd.RDD | ||
|import org.apache.spark.ml.linalg.Vectors | ||
|import org.isarnproject.sketches.TDigest | ||
|import org.isarnproject.sketches.udaf._ | ||
|import org.apache.spark.isarnproject.sketches.udt._ | ||
|val initialConf = new SparkConf().setAppName("repl").set("spark.serializer", "org.apache.spark.serializer.KryoSerializer").set("spark.kryoserializer.buffer", "16mb") | ||
|val spark = SparkSession.builder.config(initialConf).master("local[2]").getOrCreate() | ||
|import spark._ | ||
|val sc = spark.sparkContext | ||
|import org.apache.log4j.{Logger, ConsoleAppender, Level} | ||
|Logger.getRootLogger().getAppender("console").asInstanceOf[ConsoleAppender].setThreshold(Level.WARN) | ||
""".stripMargin, | ||
cleanupCommands in console := "spark.stop" | ||
) | ||
|
||
seq(commonSettings:_*) | ||
|
||
licenses += ("Apache-2.0", url("http://opensource.org/licenses/Apache-2.0")) | ||
|
||
scalacOptions ++= Seq("-unchecked", "-deprecation", "-feature") | ||
|
||
scalacOptions in (Compile, doc) ++= Seq("-doc-root-content", baseDirectory.value+"/root-doc.txt") | ||
|
||
site.settings | ||
|
||
site.includeScaladoc() | ||
|
||
// Re-enable if/when we want to support gh-pages w/ jekyll | ||
// site.jekyllSupport() | ||
|
||
ghpages.settings | ||
|
||
git.remoteRepo := "[email protected]:isarn/isarn-sketches-spark.git" |
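Not part of the diff: the initialCommands above mean `sbt console` drops straight into a preconfigured local Spark session. A minimal session sketch follows; `tdigestUDAF` is an assumed entry point from the org.isarnproject.sketches.udaf package this commit adds, which is not shown in this diff.

// Runs inside `sbt console`; `spark` and the isarn imports are already
// in scope via the initialCommands setting above.
// NOTE: tdigestUDAF is an assumed name from org.isarnproject.sketches.udaf,
// a package added by this commit but not visible in this diff.
val df = spark.createDataFrame(Vector.fill(1000) { Tuple1(scala.util.Random.nextGaussian()) })
  .toDF("x")
val sketch = df.agg(tdigestUDAF[Double](df("x")))
sketch.show()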
project/build.properties
@@ -0,0 +1 @@
sbt.version=0.13.11
project/plugins.sbt
@@ -0,0 +1,20 @@
resolvers += Resolver.url(
  "bintray-sbt-plugin-releases",
  url("http://dl.bintray.com/content/sbt/sbt-plugin-releases"))(
    Resolver.ivyStylePatterns)

resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"

resolvers += "jgit-repo" at "http://download.eclipse.org/jgit/maven"

addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")

addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.5.4")

addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0")

// scoverage and coveralls deps are at old versions to avoid a bug in the current versions
// update these when this fix is released: https://github.com/scoverage/sbt-coveralls/issues/73
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.0.4")

addSbtPlugin("org.scoverage" % "sbt-coveralls" % "1.0.0")
src/main/scala/org/apache/spark/isarnproject/sketches/udt/TDigestUDT.scala (136 additions, 0 deletions)
@@ -0,0 +1,136 @@
/*
Copyright 2017 Erik Erlandson

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package org.apache.spark.isarnproject.sketches.udt

import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap

@SQLUserDefinedType(udt = classOf[TDigestUDT])
case class TDigestSQL(tdigest: TDigest)

class TDigestUDT extends UserDefinedType[TDigestSQL] {
  def userClass: Class[TDigestSQL] = classOf[TDigestSQL]

  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("nclusters", IntegerType, false) ::
    StructField("clustX", ArrayType(DoubleType, false), false) ::
    StructField("clustM", ArrayType(DoubleType, false), false) ::
    Nil)

  def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

  def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

  private[sketches] def serializeTD(td: TDigest): InternalRow = {
    val TDigest(delta, maxDiscrete, nclusters, clusters) = td
    val row = new GenericInternalRow(5)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.setInt(2, nclusters)
    val clustX = clusters.keys.toArray
    val clustM = clusters.values.toArray
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  private[sketches] def deserializeTD(datum: Any): TDigest = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val nclusters = row.getInt(2)
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
      val clusters = clustX.zip(clustM)
        .foldLeft(TDigestMap.empty) { case (td, e) => td + e }
      TDigest(delta, maxDiscrete, nclusters, clusters)
    case u => throw new Exception(s"failed to deserialize: $u")
  }
}

case object TDigestUDT extends TDigestUDT

@SQLUserDefinedType(udt = classOf[TDigestArrayUDT])
case class TDigestArraySQL(tdigests: Array[TDigest])

class TDigestArrayUDT extends UserDefinedType[TDigestArraySQL] {
  def userClass: Class[TDigestArraySQL] = classOf[TDigestArraySQL]

  // Spark seems to have trouble with ArrayType data that isn't
  // serialized using UnsafeArrayData (SPARK-21277), so my workaround
  // is to store all the cluster information flattened into single Unsafe arrays.
  // To deserialize, I unpack the slices.
  def sqlType: DataType = StructType(
    StructField("delta", DoubleType, false) ::
    StructField("maxDiscrete", IntegerType, false) ::
    StructField("clusterS", ArrayType(IntegerType, false), false) ::
    StructField("clusterX", ArrayType(DoubleType, false), false) ::
StructField("ClusterM", ArrayType(DoubleType, false), false) :: | ||
    Nil)

  def serialize(tdasql: TDigestArraySQL): Any = {
    val row = new GenericInternalRow(5)
    val tda: Array[TDigest] = tdasql.tdigests
    val delta = if (tda.isEmpty) 0.0 else tda.head.delta
    val maxDiscrete = if (tda.isEmpty) 0 else tda.head.maxDiscrete
    val clustS = tda.map(_.nclusters)
    val clustX = tda.flatMap(_.clusters.keys)
    val clustM = tda.flatMap(_.clusters.values)
    row.setDouble(0, delta)
    row.setInt(1, maxDiscrete)
    row.update(2, UnsafeArrayData.fromPrimitiveArray(clustS))
    row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
    row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
    row
  }

  def deserialize(datum: Any): TDigestArraySQL = datum match {
    case row: InternalRow =>
      require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
      val delta = row.getDouble(0)
      val maxDiscrete = row.getInt(1)
      val clustS = row.getArray(2).toIntArray()
      val clustX = row.getArray(3).toDoubleArray()
      val clustM = row.getArray(4).toDoubleArray()
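      // clusterS holds each digest's cluster count: walk the flattened
      // clusterX/clusterM arrays, carving out one slice per digest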
      var beg = 0
      val tda = clustS.map { nclusters =>
        val x = clustX.slice(beg, beg + nclusters)
        val m = clustM.slice(beg, beg + nclusters)
        val clusters = x.zip(m).foldLeft(TDigestMap.empty) { case (td, e) => td + e }
        val td = TDigest(delta, maxDiscrete, nclusters, clusters)
        beg += nclusters
        td
      }
      TDigestArraySQL(tda)
    case u => throw new Exception(s"failed to deserialize: $u")
  }
}

case object TDigestArrayUDT extends TDigestArrayUDT

// VectorUDT is private[spark], but I can expose what I need this way:
object TDigestUDTInfra {
  private val udtML = new org.apache.spark.ml.linalg.VectorUDT
  def udtVectorML: DataType = udtML

  private val udtMLLib = new org.apache.spark.mllib.linalg.VectorUDT
  def udtVectorMLLib: DataType = udtMLLib
}
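Not part of the diff: a minimal round-trip sketch of the two UDTs above, e.g. from the sbt console. The UDT calls are exactly those defined in this file; the `TDigest.sketch` factory is assumed from isarn-sketches.

import scala.util.Random
import org.isarnproject.sketches.TDigest
import org.apache.spark.isarnproject.sketches.udt._

// build a t-digest over some synthetic data (TDigest.sketch assumed from isarn-sketches)
val td = TDigest.sketch(Vector.fill(1000) { Random.nextGaussian() })

// single-digest UDT round trip: InternalRow out, equivalent digest back
val row = TDigestUDT.serialize(TDigestSQL(td))
val TDigestSQL(td2) = TDigestUDT.deserialize(row)
assert(td2.nclusters == td.nclusters)

// array UDT round trip exercises the flattened SPARK-21277 workaround
val rowA = TDigestArrayUDT.serialize(TDigestArraySQL(Array(td, td)))
val TDigestArraySQL(tds) = TDigestArrayUDT.deserialize(rowA)
assert(tds.length == 2 && tds.forall(_.nclusters == td.nclusters))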