diff --git a/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala b/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala index 3b8fbfa62..a2c41d88c 100644 --- a/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala +++ b/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala @@ -159,10 +159,10 @@ object H3IndexSystem extends IndexSystem with Serializable { /** * Get the index ID corresponding to the provided coordinates. * - * @param x - * X coordinate of the point. - * @param y - * Y coordinate of the point. + * @param lon + * Longitude coordinate of the point. + * @param lat + * Latitude coordinate of the point. * @param resolution * Resolution of the index. * @return diff --git a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala index 9069df00b..eba19b04d 100644 --- a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala +++ b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala @@ -29,7 +29,7 @@ case class ST_IntersectionAggregate( override val dataType: DataType = BinaryType private val emptyWKB = geometryAPI.geometry("POLYGON(EMPTY)", "WKT").toWKB - override def prettyName: String = "st_reduce_intersection" + override def prettyName: String = "st_intersection_aggregate" override def update(accumulator: Array[Byte], inputRow: InternalRow): Array[Byte] = { val state = accumulator diff --git a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala index e267a904f..c2d75ebbd 100644 --- a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala +++ b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala @@ -9,8 +9,7 @@ import com.databricks.mosaic.core.geometry.api.GeometryAPI case class ST_NumPoints(inputGeom: Expression, geometryAPIName: String) extends UnaryExpression with NullIntolerant with CodegenFallback { /** - * ST_Area expression returns are covered by the - * [[org.locationtech.jts.geom.Geometry]] instance extracted from inputGeom + * ST_NumPoints expression returns the number of points for a given geometry. * expression. */ @@ -27,7 +26,7 @@ case class ST_NumPoints(inputGeom: Expression, geometryAPIName: String) extends override def makeCopy(newArgs: Array[AnyRef]): Expression = { val asArray = newArgs.take(1).map(_.asInstanceOf[Expression]) - val res = ST_Area(asArray(0), geometryAPIName) + val res = ST_NumPoints(asArray(0), geometryAPIName) res.copyTagsFrom(this) res } diff --git a/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLatLon.scala b/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLatLon.scala deleted file mode 100644 index 048a00e1d..000000000 --- a/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLatLon.scala +++ /dev/null @@ -1,90 +0,0 @@ -package com.databricks.mosaic.expressions.index - -import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, NullIntolerant, TernaryExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types._ - -import com.databricks.mosaic.core.index.{H3IndexSystem, IndexSystemID} - -case class PointIndexLatLon(lat: Expression, lon: Expression, resolution: Expression, indexSystemName: String) - extends TernaryExpression - with ExpectsInputTypes - with NullIntolerant - with CodegenFallback { - - override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType, IntegerType) - - /** Expression output DataType. */ - override def dataType: DataType = LongType - - override def toString: String = s"point_index_latlon($lat, $lon, $resolution)" - - /** Overridden to ensure [[Expression.sql]] is properly formatted. */ - override def prettyName: String = "point_index_latlon" - - /** - * Computes the H3 index corresponding to the provided lat and long - * coordinates. - * - * @param input1 - * Any instance containing latitude. - * @param input2 - * Any instance containing longitude. - * @param input3 - * Any instance containing resolution. - * @return - * H3 index id in Long. - */ - override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - val resolution: Int = H3IndexSystem.getResolution(input3) - val x: Double = input1.asInstanceOf[Double] - val y: Double = input2.asInstanceOf[Double] - - val indexSystem = IndexSystemID.getIndexSystem(IndexSystemID(indexSystemName)) - - indexSystem.pointToIndex(x, y, resolution) - } - - override def makeCopy(newArgs: Array[AnyRef]): Expression = { - val asArray = newArgs.take(3).map(_.asInstanceOf[Expression]) - val res = PointIndexLatLon(asArray(0), asArray(1), asArray(2), indexSystemName) - res.copyTagsFrom(this) - res - } - - override def first: Expression = lat - - override def second: Expression = lon - - override def third: Expression = resolution - - override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - copy(lat = newFirst, lon = newSecond, resolution = newThird) - -} - -object PointIndexLatLon { - - /** Entry to use in the function registry. */ - def registryExpressionInfo(db: Option[String]): ExpressionInfo = - new ExpressionInfo( - classOf[PointIndexLatLon].getCanonicalName, - db.orNull, - "point_index_latlon", - """ - | _FUNC_(lat, lng, resolution) - Returns the h3 index of a point(lat, lng) at resolution. - """.stripMargin, - "", - """ - | Examples: - | > SELECT _FUNC_(a, b, 10); - | 622236721348804607 - | """.stripMargin, - "", - "misc_funcs", - "1.0", - "", - "built-in" - ) - -} \ No newline at end of file diff --git a/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala b/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala index 5ac5f27ce..eeec35de5 100644 --- a/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala +++ b/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala @@ -70,9 +70,9 @@ object PointIndexLonLat { new ExpressionInfo( classOf[PointIndexLonLat].getCanonicalName, db.orNull, - "point_index_latlon", + "point_index_lonlat", """ - | _FUNC_(lat, lng, resolution) - Returns the h3 index of a point(lat, lng) at resolution. + | _FUNC_(lon, lat, resolution) - Returns the h3 index of a point(lon, lat) at resolution. """.stripMargin, "", """ diff --git a/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala index 2f24d6512..55801e0b5 100644 --- a/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala @@ -247,7 +247,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends (exprs: Seq[Expression]) => ST_Intersects(exprs(0), exprs(1), geometryAPI.name) ) registry.registerFunction( - FunctionIdentifier("st_intersects", database), + FunctionIdentifier("st_intersection", database), ST_Intersection.registryExpressionInfo(database), (exprs: Seq[Expression]) => ST_Intersection(exprs(0), exprs(1), geometryAPI.name) ) @@ -388,14 +388,14 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(MosaicFill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) def mosaicfill(geom: Column, resolution: Int): Column = ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) - def point_index(lon: Column, lat: Column, resolution: Column): Column = + def point_index(point: Column, resolution: Column): Column = + ColumnAdapter(PointIndex(point.expr, resolution.expr, indexSystem.name, geometryAPI.name)) + def point_index(point: Column, resolution: Int): Column = + ColumnAdapter(PointIndex(point.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) + def point_index_lonlat(lon: Column, lat: Column, resolution: Column): Column = ColumnAdapter(PointIndexLonLat(lon.expr, lat.expr, resolution.expr, indexSystem.name)) - def point_index(lon: Column, lat: Column, resolution: Int): Column = + def point_index_lonlat(lon: Column, lat: Column, resolution: Int): Column = ColumnAdapter(PointIndexLonLat(lon.expr, lat.expr, lit(resolution).expr, indexSystem.name)) - def point_index(geom: Column, resolution: Column): Column = - ColumnAdapter(PointIndex(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) - def point_index(geom: Column, resolution: Int): Column = - ColumnAdapter(PointIndex(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) def polyfill(geom: Column, resolution: Column): Column = ColumnAdapter(Polyfill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) def polyfill(geom: Column, resolution: Int): Column = diff --git a/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala b/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala index 37936f827..34b0585e4 100644 --- a/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala +++ b/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala @@ -1,6 +1,6 @@ package com.databricks.mosaic.sql -import scala.util.{Failure, Success, Try} +import scala.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -10,18 +10,15 @@ import com.databricks.mosaic.functions.MosaicContext class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) { - val spark: SparkSession = analyzerMosaicFrame.sparkSession - import spark.implicits._ - val mosaicContext: MosaicContext = MosaicContext.context - import mosaicContext.functions._ - val defaultSampleFraction = 0.01 def getOptimalResolution(sampleFraction: Double): Int = { getOptimalResolution(SampleStrategy(sampleFraction = Some(sampleFraction))) } - private def getOptimalResolution(sampleStrategy: SampleStrategy): Int = { + def getOptimalResolution(sampleStrategy: SampleStrategy): Int = { + val ss = SparkSession.builder().getOrCreate() + import ss.implicits._ val metrics = getResolutionMetrics(sampleStrategy, 1, 100) .select("resolution", "percentile_50_geometry_area") @@ -33,6 +30,10 @@ class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) { } def getResolutionMetrics(sampleStrategy: SampleStrategy, lowerLimit: Int = 5, upperLimit: Int = 500): DataFrame = { + val mosaicContext = MosaicContext.context + import mosaicContext.functions._ + val spark = SparkSession.builder().getOrCreate() + def areaPercentile(p: Double): Column = percentile_approx(col("area"), lit(p), lit(10000)) val percentiles = analyzerMosaicFrame @@ -94,13 +95,24 @@ class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) { private def getMeanIndexArea(sampleStrategy: SampleStrategy, resolution: Int): Double = { val mosaicContext = MosaicContext.context import mosaicContext.functions._ + val spark = SparkSession.builder().getOrCreate() import spark.implicits._ val meanIndexAreaDf = analyzerMosaicFrame .transform(sampleStrategy.transformer) .withColumn("centroid", st_centroid2D(analyzerMosaicFrame.getGeometryColumn)) .select( - mean(st_area(index_geometry(point_index($"centroid.x", $"centroid.x", lit(resolution))))) + mean( + st_area( + index_geometry( + point_index_lonlat( + col("centroid").getItem("x"), + col("centroid").getItem("y"), + lit(resolution) + ) + ) + ) + ) ) Try(meanIndexAreaDf.as[Double].collect.head) match { @@ -109,10 +121,6 @@ class MosaicAnalyzer(analyzerMosaicFrame: MosaicFrame) { } } - def getOptimalResolution(sampleRows: Int): Int = { - getOptimalResolution(SampleStrategy(sampleRows = Some(sampleRows))) - } - } case class SampleStrategy(sampleFraction: Option[Double] = None, sampleRows: Option[Int] = None) { @@ -120,7 +128,7 @@ case class SampleStrategy(sampleFraction: Option[Double] = None, sampleRows: Opt (sampleFraction, sampleRows) match { case (Some(d), None) => df.sample(d) case (None, Some(l)) => df.limit(l) - case (Some(d), Some(l)) => df.limit(l) + case (Some(_), Some(l)) => df.limit(l) case _ => df } } diff --git a/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala b/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala index c7ef78ce7..b4e2186a1 100644 --- a/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala +++ b/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala @@ -14,7 +14,7 @@ class TestIntersectionExpressions extends AnyFlatSpec with IntersectionExpressio it should behave like intersects(MosaicContext.build(H3IndexSystem, JTS), spark) } - "ST_IntersectionAggregate" should "compute the intersects flag via aggregate expression" in { + "ST_IntersectionAggregate" should "compute the intersection via aggregate expression" in { it should behave like intersection(MosaicContext.build(H3IndexSystem, OGC), spark) it should behave like intersection(MosaicContext.build(H3IndexSystem, JTS), spark) }