Skip to content

Commit

Permalink
Merge branch 'main' into documentation
Browse files Browse the repository at this point in the history
# Conflicts:
#	src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala
#	src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala
#	src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala
  • Loading branch information
milos.colic committed Mar 14, 2022
2 parents a1ae2e1 + 6a60429 commit d657523
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ object H3IndexSystem extends IndexSystem with Serializable {
/**
* Get the index ID corresponding to the provided coordinates.
*
* @param x
* X coordinate of the point.
* @param y
* Y coordinate of the point.
* @param lon
* Longitude coordinate of the point.
* @param lat
* Latitude coordinate of the point.
* @param resolution
* Resolution of the index.
* @return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ case class ST_IntersectionAggregate(
override val dataType: DataType = BinaryType
private val emptyWKB = geometryAPI.geometry("POLYGON(EMPTY)", "WKT").toWKB

override def prettyName: String = "st_reduce_intersection"
override def prettyName: String = "st_intersection_aggregate"

override def update(accumulator: Array[Byte], inputRow: InternalRow): Array[Byte] = {
val state = accumulator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import com.databricks.mosaic.core.geometry.api.GeometryAPI
case class ST_NumPoints(inputGeom: Expression, geometryAPIName: String) extends UnaryExpression with NullIntolerant with CodegenFallback {

/**
* ST_Area expression returns are covered by the
* [[org.locationtech.jts.geom.Geometry]] instance extracted from inputGeom
* ST_NumPoints expression returns the number of points for a given geometry
* expression.
*/

Expand All @@ -27,7 +26,7 @@ case class ST_NumPoints(inputGeom: Expression, geometryAPIName: String) extends

override def makeCopy(newArgs: Array[AnyRef]): Expression = {
val asArray = newArgs.take(1).map(_.asInstanceOf[Expression])
val res = ST_Area(asArray(0), geometryAPIName)
val res = ST_NumPoints(asArray(0), geometryAPIName)
res.copyTagsFrom(this)
res
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ object PointIndexLonLat {
new ExpressionInfo(
classOf[PointIndexLonLat].getCanonicalName,
db.orNull,
"point_index_latlon",
"point_index_lonlat",
"""
| _FUNC_(lat, lng, resolution) - Returns the h3 index of a point(lat, lng) at resolution.
| _FUNC_(lon, lat, resolution) - Returns the h3 index of a point(lon, lat) at resolution.
""".stripMargin,
"",
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
(exprs: Seq[Expression]) => ST_Intersects(exprs(0), exprs(1), geometryAPI.name)
)
registry.registerFunction(
FunctionIdentifier("st_intersects", database),
FunctionIdentifier("st_intersection", database),
ST_Intersection.registryExpressionInfo(database),
(exprs: Seq[Expression]) => ST_Intersection(exprs(0), exprs(1), geometryAPI.name)
)
Expand Down Expand Up @@ -388,14 +388,14 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
ColumnAdapter(MosaicFill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name))
def mosaicfill(geom: Column, resolution: Int): Column =
ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name))
def point_index(lon: Column, lat: Column, resolution: Column): Column =
def point_index(point: Column, resolution: Column): Column =
ColumnAdapter(PointIndex(point.expr, resolution.expr, indexSystem.name, geometryAPI.name))
def point_index(point: Column, resolution: Int): Column =
ColumnAdapter(PointIndex(point.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name))
def point_index_lonlat(lon: Column, lat: Column, resolution: Column): Column =
ColumnAdapter(PointIndexLonLat(lon.expr, lat.expr, resolution.expr, indexSystem.name))
def point_index(lon: Column, lat: Column, resolution: Int): Column =
def point_index_lonlat(lon: Column, lat: Column, resolution: Int): Column =
ColumnAdapter(PointIndexLonLat(lon.expr, lat.expr, lit(resolution).expr, indexSystem.name))
def point_index(geom: Column, resolution: Column): Column =
ColumnAdapter(PointIndex(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name))
def point_index(geom: Column, resolution: Int): Column =
ColumnAdapter(PointIndex(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name))
def polyfill(geom: Column, resolution: Column): Column =
ColumnAdapter(Polyfill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name))
def polyfill(geom: Column, resolution: Int): Column =
Expand Down
34 changes: 21 additions & 13 deletions src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.databricks.mosaic.sql

import scala.util.{Failure, Success, Try}
import scala.util._

import org.apache.spark.sql._
import org.apache.spark.sql.functions._
Expand All @@ -10,18 +10,15 @@ import com.databricks.mosaic.functions.MosaicContext

class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) {

val spark: SparkSession = analyzerMosaicFrame.sparkSession
import spark.implicits._
val mosaicContext: MosaicContext = MosaicContext.context
import mosaicContext.functions._

val defaultSampleFraction = 0.01

def getOptimalResolution(sampleFraction: Double): Int = {
getOptimalResolution(SampleStrategy(sampleFraction = Some(sampleFraction)))
}

private def getOptimalResolution(sampleStrategy: SampleStrategy): Int = {
def getOptimalResolution(sampleStrategy: SampleStrategy): Int = {
val ss = SparkSession.builder().getOrCreate()
import ss.implicits._

val metrics = getResolutionMetrics(sampleStrategy, 1, 100)
.select("resolution", "percentile_50_geometry_area")
Expand All @@ -33,6 +30,10 @@ class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) {
}

def getResolutionMetrics(sampleStrategy: SampleStrategy, lowerLimit: Int = 5, upperLimit: Int = 500): DataFrame = {
val mosaicContext = MosaicContext.context
import mosaicContext.functions._
val spark = SparkSession.builder().getOrCreate()

def areaPercentile(p: Double): Column = percentile_approx(col("area"), lit(p), lit(10000))

val percentiles = analyzerMosaicFrame
Expand Down Expand Up @@ -94,13 +95,24 @@ class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) {
private def getMeanIndexArea(sampleStrategy: SampleStrategy, resolution: Int): Double = {
val mosaicContext = MosaicContext.context
import mosaicContext.functions._
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

val meanIndexAreaDf = analyzerMosaicFrame
.transform(sampleStrategy.transformer)
.withColumn("centroid", st_centroid2D(analyzerMosaicFrame.getGeometryColumn))
.select(
mean(st_area(index_geometry(point_index($"centroid.x", $"centroid.x", lit(resolution)))))
mean(
st_area(
index_geometry(
point_index_lonlat(
col("centroid").getItem("x"),
col("centroid").getItem("y"),
lit(resolution)
)
)
)
)
)

Try(meanIndexAreaDf.as[Double].collect.head) match {
Expand All @@ -109,18 +121,14 @@ class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) {
}
}

def getOptimalResolution(sampleRows: Int): Int = {
getOptimalResolution(SampleStrategy(sampleRows = Some(sampleRows)))
}

}

case class SampleStrategy(sampleFraction: Option[Double] = None, sampleRows: Option[Int] = None) {
def transformer(df: DataFrame): DataFrame = {
(sampleFraction, sampleRows) match {
case (Some(d), None) => df.sample(d)
case (None, Some(l)) => df.limit(l)
case (Some(d), Some(l)) => df.limit(l)
case (Some(_), Some(l)) => df.limit(l)
case _ => df
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TestIntersectionExpressions extends AnyFlatSpec with IntersectionExpressio
it should behave like intersects(MosaicContext.build(H3IndexSystem, JTS), spark)
}

"ST_IntersectionAggregate" should "compute the intersects flag via aggregate expression" in {
"ST_IntersectionAggregate" should "compute the intersection via aggregate expression" in {
it should behave like intersection(MosaicContext.build(H3IndexSystem, OGC), spark)
it should behave like intersection(MosaicContext.build(H3IndexSystem, JTS), spark)
}
Expand Down

0 comments on commit d657523

Please sign in to comment.