From 11b4fd5286c4ca2895d0cc1895910f09e0bba833 Mon Sep 17 00:00:00 2001 From: Grigory Pomadchin Date: Thu, 26 May 2022 17:31:17 -0400 Subject: [PATCH] Fix Array HSerializers --- .../hiveless/serializers/HSerializer.scala | 4 +- .../hiveless/serializers/GroupString.scala | 23 +++++++++++ .../serializers/HSerializerSpec.scala | 38 +++++++++++++++++++ 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 spatial/src/test/scala/com/azavea/hiveless/serializers/GroupString.scala create mode 100644 spatial/src/test/scala/com/azavea/hiveless/serializers/HSerializerSpec.scala diff --git a/core/src/main/scala/com/azavea/hiveless/serializers/HSerializer.scala b/core/src/main/scala/com/azavea/hiveless/serializers/HSerializer.scala index 1c56a0d..53681d9 100644 --- a/core/src/main/scala/com/azavea/hiveless/serializers/HSerializer.scala +++ b/core/src/main/scala/com/azavea/hiveless/serializers/HSerializer.scala @@ -16,7 +16,9 @@ package com.azavea.hiveless.serializers +import com.azavea.hiveless.serializers.syntax._ import com.azavea.hiveless.spark.encoders.syntax._ + import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -106,6 +108,6 @@ object HSerializer extends Serializable { implicit def arraySerializer[T: HSerializer: ClassTag: λ[τ => C[τ] => Seq[τ]], C[_]]: HSerializer[C[T]] = new HSerializer[C[T]] { def dataType: DataType = ArrayType(HSerializer[T].dataType) - def serialize: C[T] => Any = seq => ArrayData.toArrayData(seq.toArray) + def serialize: C[T] => Any = seq => ArrayData.toArrayData(seq.map(_.serialize).toArray) } } diff --git a/spatial/src/test/scala/com/azavea/hiveless/serializers/GroupString.scala b/spatial/src/test/scala/com/azavea/hiveless/serializers/GroupString.scala new file mode 100644 index 0000000..2b7faee --- /dev/null +++ b/spatial/src/test/scala/com/azavea/hiveless/serializers/GroupString.scala @@ -0,0 +1,23 @@ +/* + * Copyright 2022 Azavea + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.azavea.hiveless.serializers + +import com.azavea.hiveless.HUDF + +class GroupString extends HUDF[(String, Int), Array[String]] { + def function = { case (str, size) => str.grouped(size).toArray } +} diff --git a/spatial/src/test/scala/com/azavea/hiveless/serializers/HSerializerSpec.scala b/spatial/src/test/scala/com/azavea/hiveless/serializers/HSerializerSpec.scala new file mode 100644 index 0000000..c70ce77 --- /dev/null +++ b/spatial/src/test/scala/com/azavea/hiveless/serializers/HSerializerSpec.scala @@ -0,0 +1,38 @@ +/* + * Copyright 2022 Azavea + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.azavea.hiveless.serializers + +import com.azavea.hiveless.{SpatialHiveTestEnvironment, SpatialTestTables} +import org.apache.spark.sql.SparkSession +import org.scalatest.funspec.AnyFunSpec + +class HSerializerSpec extends AnyFunSpec with SpatialHiveTestEnvironment with SpatialTestTables { + override def registerHiveUDFs(ssc: SparkSession): Unit = { + super.registerHiveUDFs(ssc) + ssc.sql("CREATE OR REPLACE FUNCTION groupString as 'com.azavea.hiveless.serializers.GroupString';") + } + + describe("HSerializerSpec") { + it("should serialize array of strings") { + val (str, n) = ("HSerializerSpecString", 3) + val expected = str.grouped(n).toArray + val df = ssc.sql(s"SELECT groupString('$str', $n);") + + df.collect().head.getAs[Array[String]](0) shouldBe expected + } + } +}