From 1f295e304e0c0f3edcae8fbc37a1aeeb16b84fc5 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 14 Mar 2022 11:23:18 +0000 Subject: [PATCH 1/2] Fix lonlat convention. Fix references to ST_Area in ST_NumPoints. Fix wrong sql name in ST_IntersectionAggregate. --- .../mosaic/core/index/H3IndexSystem.scala | 12 ++++----- .../mosaic/core/index/IndexSystem.scala | 6 ++--- .../geometry/ST_IntersectionAggregate.scala | 2 +- .../expressions/geometry/ST_NumPoints.scala | 5 ++-- ...dexLatLon.scala => PointIndexLonLat.scala} | 26 +++++++++---------- .../mosaic/functions/MosaicContext.scala | 22 ++++++++-------- .../mosaic/sql/MosaicAnalyzer.scala | 2 +- 7 files changed, 37 insertions(+), 38 deletions(-) rename src/main/scala/com/databricks/mosaic/expressions/index/{PointIndexLatLon.scala => PointIndexLonLat.scala} (80%) diff --git a/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala b/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala index 652b967d4..a2c41d88c 100644 --- a/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala +++ b/src/main/scala/com/databricks/mosaic/core/index/H3IndexSystem.scala @@ -159,17 +159,17 @@ object H3IndexSystem extends IndexSystem with Serializable { /** * Get the index ID corresponding to the provided coordinates. * - * @param x - * X coordinate of the point. - * @param y - * Y coordinate of the point. + * @param lon + * Longitude coordinate of the point. + * @param lat + * Latitude coordinate of the point. * @param resolution * Resolution of the index. * @return * Index ID in this index system. 
*/ - override def pointToIndex(x: Double, y: Double, resolution: Int): Long = { - h3.geoToH3(x, y, resolution) + override def pointToIndex(lon: Double, lat: Double, resolution: Int): Long = { + h3.geoToH3(lat, lon, resolution) } override def minResolution: Int = 0 diff --git a/src/main/scala/com/databricks/mosaic/core/index/IndexSystem.scala b/src/main/scala/com/databricks/mosaic/core/index/IndexSystem.scala index 74acef93b..701081f37 100644 --- a/src/main/scala/com/databricks/mosaic/core/index/IndexSystem.scala +++ b/src/main/scala/com/databricks/mosaic/core/index/IndexSystem.scala @@ -117,15 +117,15 @@ trait IndexSystem extends Serializable { /** * Get the index ID corresponding to the provided coordinates. * - * @param x + * @param lon * X coordinate of the point. - * @param y + * @param lat * Y coordinate of the point. * @param resolution * Resolution of the index. * @return * Index ID in this index system. */ - def pointToIndex(x: Double, y: Double, resolution: Int): Long + def pointToIndex(lon: Double, lat: Double, resolution: Int): Long } diff --git a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala index 9069df00b..eba19b04d 100644 --- a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala +++ b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_IntersectionAggregate.scala @@ -29,7 +29,7 @@ case class ST_IntersectionAggregate( override val dataType: DataType = BinaryType private val emptyWKB = geometryAPI.geometry("POLYGON(EMPTY)", "WKT").toWKB - override def prettyName: String = "st_reduce_intersection" + override def prettyName: String = "st_intersection_aggregate" override def update(accumulator: Array[Byte], inputRow: InternalRow): Array[Byte] = { val state = accumulator diff --git a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala 
b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala index e267a904f..c2d75ebbd 100644 --- a/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala +++ b/src/main/scala/com/databricks/mosaic/expressions/geometry/ST_NumPoints.scala @@ -9,8 +9,7 @@ import com.databricks.mosaic.core.geometry.api.GeometryAPI case class ST_NumPoints(inputGeom: Expression, geometryAPIName: String) extends UnaryExpression with NullIntolerant with CodegenFallback { /** - * ST_Area expression returns are covered by the - * [[org.locationtech.jts.geom.Geometry]] instance extracted from inputGeom + * ST_NumPoints expression returns the number of points for a given geometry. * expression. */ @@ -27,7 +26,7 @@ case class ST_NumPoints(inputGeom: Expression, geometryAPIName: String) extends override def makeCopy(newArgs: Array[AnyRef]): Expression = { val asArray = newArgs.take(1).map(_.asInstanceOf[Expression]) - val res = ST_Area(asArray(0), geometryAPIName) + val res = ST_NumPoints(asArray(0), geometryAPIName) res.copyTagsFrom(this) res } diff --git a/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLatLon.scala b/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala similarity index 80% rename from src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLatLon.scala rename to src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala index 048a00e1d..56f9fba5c 100644 --- a/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLatLon.scala +++ b/src/main/scala/com/databricks/mosaic/expressions/index/PointIndexLonLat.scala @@ -6,7 +6,7 @@ import org.apache.spark.sql.types._ import com.databricks.mosaic.core.index.{H3IndexSystem, IndexSystemID} -case class PointIndexLatLon(lat: Expression, lon: Expression, resolution: Expression, indexSystemName: String) +case class PointIndexLonLat(lon: Expression, lat: Expression, resolution: Expression, indexSystemName: String) 
extends TernaryExpression with ExpectsInputTypes with NullIntolerant @@ -17,19 +17,19 @@ case class PointIndexLatLon(lat: Expression, lon: Expression, resolution: Expres /** Expression output DataType. */ override def dataType: DataType = LongType - override def toString: String = s"point_index_latlon($lat, $lon, $resolution)" + override def toString: String = s"point_index_lonlat($lon, $lat, $resolution)" /** Overridden to ensure [[Expression.sql]] is properly formatted. */ - override def prettyName: String = "point_index_latlon" + override def prettyName: String = "point_index_lonlat" /** * Computes the H3 index corresponding to the provided lat and long * coordinates. * * @param input1 - * Any instance containing latitude. - * @param input2 * Any instance containing longitude. + * @param input2 + * Any instance containing latitude. * @param input3 * Any instance containing resolution. * @return @@ -37,17 +37,17 @@ case class PointIndexLatLon(lat: Expression, lon: Expression, resolution: Expres */ override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { val resolution: Int = H3IndexSystem.getResolution(input3) - val x: Double = input1.asInstanceOf[Double] - val y: Double = input2.asInstanceOf[Double] + val lon: Double = input1.asInstanceOf[Double] + val lat: Double = input2.asInstanceOf[Double] val indexSystem = IndexSystemID.getIndexSystem(IndexSystemID(indexSystemName)) - indexSystem.pointToIndex(x, y, resolution) + indexSystem.pointToIndex(lon, lat, resolution) } override def makeCopy(newArgs: Array[AnyRef]): Expression = { val asArray = newArgs.take(3).map(_.asInstanceOf[Expression]) - val res = PointIndexLatLon(asArray(0), asArray(1), asArray(2), indexSystemName) + val res = PointIndexLonLat(asArray(0), asArray(1), asArray(2), indexSystemName) res.copyTagsFrom(this) res } @@ -63,16 +63,16 @@ case class PointIndexLatLon(lat: Expression, lon: Expression, resolution: Expres } -object PointIndexLatLon { +object PointIndexLonLat { /** Entry to 
use in the function registry. */ def registryExpressionInfo(db: Option[String]): ExpressionInfo = new ExpressionInfo( - classOf[PointIndexLatLon].getCanonicalName, + classOf[PointIndexLonLat].getCanonicalName, db.orNull, - "point_index_latlon", + "point_index_lonlat", """ - | _FUNC_(lat, lng, resolution) - Returns the h3 index of a point(lat, lng) at resolution. + | _FUNC_(lat, lng, resolution) - Returns the h3 index of a point(lon, lat) at resolution. """.stripMargin, "", """ diff --git a/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala index 925174f27..d1b21317d 100644 --- a/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala @@ -280,12 +280,12 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ) registry.registerFunction( FunctionIdentifier("point_index_latlon", database), - PointIndexLatLon.registryExpressionInfo(database), - (exprs: Seq[Expression]) => PointIndexLatLon(exprs(0), exprs(1), exprs(2), indexSystem.name) + PointIndexLonLat.registryExpressionInfo(database), + (exprs: Seq[Expression]) => PointIndexLonLat(exprs(0), exprs(1), exprs(2), indexSystem.name) ) registry.registerFunction( FunctionIdentifier("point_index", database), - PointIndexLatLon.registryExpressionInfo(database), + PointIndex.registryExpressionInfo(database), (exprs: Seq[Expression]) => PointIndex(exprs(0), exprs(1), indexSystem.name, geometryAPI.name) ) registry.registerFunction( @@ -383,14 +383,14 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(MosaicFill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) def mosaicfill(geom: Column, resolution: Int): Column = ColumnAdapter(MosaicFill(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) - def point_index(lat: Column, lng: Column, resolution: Column): 
Column = - ColumnAdapter(PointIndexLatLon(lat.expr, lng.expr, resolution.expr, indexSystem.name)) - def point_index(lat: Column, lng: Column, resolution: Int): Column = - ColumnAdapter(PointIndexLatLon(lat.expr, lng.expr, lit(resolution).expr, indexSystem.name)) - def point_index(geom: Column, resolution: Column): Column = - ColumnAdapter(PointIndex(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) - def point_index(geom: Column, resolution: Int): Column = - ColumnAdapter(PointIndex(geom.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) + def point_index(point: Column, resolution: Column): Column = + ColumnAdapter(PointIndex(point.expr, resolution.expr, indexSystem.name, geometryAPI.name)) + def point_index(point: Column, resolution: Int): Column = + ColumnAdapter(PointIndex(point.expr, lit(resolution).expr, indexSystem.name, geometryAPI.name)) + def point_index_lonlat(lon: Column, lat: Column, resolution: Column): Column = + ColumnAdapter(PointIndexLonLat(lon.expr, lat.expr, resolution.expr, indexSystem.name)) + def point_index_lonlat(lon: Column, lat: Column, resolution: Int): Column = + ColumnAdapter(PointIndexLonLat(lon.expr, lat.expr, lit(resolution).expr, indexSystem.name)) def polyfill(geom: Column, resolution: Column): Column = ColumnAdapter(Polyfill(geom.expr, resolution.expr, indexSystem.name, geometryAPI.name)) def polyfill(geom: Column, resolution: Int): Column = diff --git a/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala b/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala index 4d272bb84..289065fdc 100644 --- a/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala +++ b/src/main/scala/com/databricks/mosaic/sql/MosaicAnalyzer.scala @@ -99,7 +99,7 @@ object MosaicAnalyzer { mean( st_area( index_geometry( - point_index( + point_index_lonlat( col("centroid").getItem("x"), col("centroid").getItem("y"), lit(resolution) From a7ef71826f798353fd942f66ffb92d2ae6d1f109 Mon Sep 17 00:00:00 2001 From: 
"milos.colic" Date: Mon, 14 Mar 2022 11:55:03 +0000 Subject: [PATCH 2/2] Fix misnomers. --- .../scala/com/databricks/mosaic/functions/MosaicContext.scala | 2 +- .../expressions/geometry/TestIntersectionExpressions.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala index d1b21317d..b17c34230 100644 --- a/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/mosaic/functions/MosaicContext.scala @@ -247,7 +247,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends (exprs: Seq[Expression]) => ST_Intersects(exprs(0), exprs(1), geometryAPI.name) ) registry.registerFunction( - FunctionIdentifier("st_intersects", database), + FunctionIdentifier("st_intersection", database), ST_Intersection.registryExpressionInfo(database), (exprs: Seq[Expression]) => ST_Intersection(exprs(0), exprs(1), geometryAPI.name) ) diff --git a/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala b/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala index c7ef78ce7..b4e2186a1 100644 --- a/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala +++ b/src/test/scala/com/databricks/mosaic/expressions/geometry/TestIntersectionExpressions.scala @@ -14,7 +14,7 @@ class TestIntersectionExpressions extends AnyFlatSpec with IntersectionExpressio it should behave like intersects(MosaicContext.build(H3IndexSystem, JTS), spark) } - "ST_IntersectionAggregate" should "compute the intersects flag via aggregate expression" in { + "ST_IntersectionAggregate" should "compute the intersection via aggregate expression" in { it should behave like intersection(MosaicContext.build(H3IndexSystem, OGC), spark) it should behave like 
intersection(MosaicContext.build(H3IndexSystem, JTS), spark) }