From 502f5a3cd96e458c8471794af9d2e209d9f0b42f Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 10 Sep 2024 15:50:50 -0700 Subject: [PATCH] Fixed some of the failing parquet_tests [databricks] (#11429) * Fixed some of the failing parquet_tests * Signing off Signed-off-by: Raza Jafri * addressed review comments * removed unused import --------- Signed-off-by: Raza Jafri --- .../src/main/python/parquet_test.py | 47 ++++++++++--------- .../spark/rapids/shims/GpuBatchScanExec.scala | 3 +- .../spark/rapids/shims/GpuBatchScanExec.scala | 1 + 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py index e21ba622f46..6032d469fb2 100644 --- a/integration_tests/src/main/python/parquet_test.py +++ b/integration_tests/src/main/python/parquet_test.py @@ -35,15 +35,19 @@ def read_parquet_df(data_path): def read_parquet_sql(data_path): return lambda spark : spark.sql('select * from parquet.`{}`'.format(data_path)) +datetimeRebaseModeInWriteKey = 'spark.sql.parquet.datetimeRebaseModeInWrite' +int96RebaseModeInWriteKey = 'spark.sql.parquet.int96RebaseModeInWrite' +datetimeRebaseModeInReadKey = 'spark.sql.parquet.datetimeRebaseModeInRead' +int96RebaseModeInReadKey = 'spark.sql.parquet.int96RebaseModeInRead' rebase_write_corrected_conf = { - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'CORRECTED', - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'CORRECTED' + datetimeRebaseModeInWriteKey : 'CORRECTED', + int96RebaseModeInWriteKey : 'CORRECTED' } rebase_write_legacy_conf = { - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': 'LEGACY', - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': 'LEGACY' + datetimeRebaseModeInWriteKey : 'LEGACY', + int96RebaseModeInWriteKey : 'LEGACY' } # Like the standard map_gens_sample but with timestamps limited @@ -146,8 +150,8 @@ def test_parquet_read_coalescing_multiple_files(spark_tmp_path, parquet_gens, re all_confs = copy_and_update(reader_confs, { 'spark.sql.sources.useV1SourceList': v1_enabled_list, # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) + int96RebaseModeInReadKey : 'CORRECTED', + datetimeRebaseModeInReadKey : 'CORRECTED'}) # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround # for nested timestamp/date support assert_gpu_and_cpu_are_equal_collect(read_func(data_path), @@ -188,8 +192,8 @@ def test_parquet_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader all_confs = copy_and_update(reader_confs, { 'spark.sql.sources.useV1SourceList': v1_enabled_list, # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) + int96RebaseModeInReadKey : 'CORRECTED', + datetimeRebaseModeInReadKey : 'CORRECTED'}) # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround # for nested timestamp/date support assert_gpu_and_cpu_are_equal_collect(read_func(data_path), @@ -199,6 +203,7 @@ def test_parquet_read_round_trip(spark_tmp_path, parquet_gens, read_func, reader @allow_non_gpu('FileSourceScanExec') @pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql]) @pytest.mark.parametrize('disable_conf', ['spark.rapids.sql.format.parquet.enabled', 'spark.rapids.sql.format.parquet.read.enabled']) +@disable_ansi_mode def test_parquet_fallback(spark_tmp_path, read_func, disable_conf): data_gens = [string_gen, byte_gen, short_gen, int_gen, long_gen, boolean_gen] + decimal_gens @@ -225,8 +230,8 @@ def test_parquet_read_round_trip_binary(std_input_path, read_func, binary_as_str all_confs = copy_and_update(reader_confs, { 'spark.sql.parquet.binaryAsString': binary_as_string, # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) + int96RebaseModeInReadKey : 'CORRECTED', + datetimeRebaseModeInReadKey : 'CORRECTED'}) # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround # for nested timestamp/date support assert_gpu_and_cpu_are_equal_collect(read_func(data_path), @@ -245,8 +250,8 @@ def test_binary_df_read(spark_tmp_path, binary_as_string, read_func, data_gen): all_confs = { 'spark.sql.parquet.binaryAsString': binary_as_string, # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInRead': 'CORRECTED', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'} + int96RebaseModeInReadKey : 'CORRECTED', + datetimeRebaseModeInReadKey : 'CORRECTED'} assert_gpu_and_cpu_are_equal_collect(read_func(data_path), conf=all_confs) @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"]) @@ -256,8 +261,8 @@ def test_parquet_read_forced_binary_schema(std_input_path, v1_enabled_list): all_confs = copy_and_update(reader_opt_confs[0], { 'spark.sql.sources.useV1SourceList': v1_enabled_list, # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) + int96RebaseModeInReadKey : 'CORRECTED', + datetimeRebaseModeInReadKey : 'CORRECTED'}) # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround # for nested timestamp/date support @@ -277,8 +282,8 @@ def test_parquet_read_round_trip_binary_as_string(std_input_path, read_func, rea 'spark.sql.sources.useV1SourceList': v1_enabled_list, 'spark.sql.parquet.binaryAsString': 'true', # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU - 'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED', - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) + int96RebaseModeInReadKey : 'CORRECTED', + datetimeRebaseModeInReadKey : 'CORRECTED'}) # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround # for nested timestamp/date support assert_gpu_and_cpu_are_equal_collect(read_func(data_path), @@ -342,16 +347,16 @@ def test_parquet_read_roundtrip_datetime_with_legacy_rebase(spark_tmp_path, parq gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gens)] data_path = spark_tmp_path + '/PARQUET_DATA' write_confs = {'spark.sql.parquet.outputTimestampType': ts_type, - 'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase_write[0], - 'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase_write[1]} + datetimeRebaseModeInWriteKey : ts_rebase_write[0], + int96RebaseModeInWriteKey : ts_rebase_write[1]} with_cpu_session( lambda spark: gen_df(spark, gen_list).write.parquet(data_path), conf=write_confs) # The rebase modes in read configs should be ignored and overridden by the same modes in write # configs, which are retrieved from the written files. read_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list, - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': ts_rebase_read[0], - 'spark.sql.legacy.parquet.int96RebaseModeInRead': ts_rebase_read[1]}) + datetimeRebaseModeInReadKey : ts_rebase_read[0], + int96RebaseModeInReadKey : ts_rebase_read[1]}) assert_gpu_and_cpu_are_equal_collect( lambda spark: spark.read.parquet(data_path), conf=read_confs) @@ -734,7 +739,7 @@ def test_nested_pruning_and_case_insensitive(spark_tmp_path, data_gen, read_sche all_confs = copy_and_update(reader_confs, { 'spark.sql.sources.useV1SourceList': v1_enabled_list, 'spark.sql.optimizer.nestedSchemaPruning.enabled': nested_enabled, - 'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}) + datetimeRebaseModeInReadKey : 'CORRECTED'}) # This is a hack to get the type in a slightly less verbose way rs = StructGen(read_schema, nullable=False).data_type assert_gpu_and_cpu_are_equal_collect(lambda spark : spark.read.schema(rs).parquet(data_path), diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index 4e3b86d5ef0..7e5a6f4cd35 100644 --- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.catalyst.util.InternalRowSet import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.execution.datasources.rapids.DataSourceStrategyUtils -import org.apache.spark.sql.execution.datasources.v2._ case class GpuBatchScanExec( output: Seq[AttributeReference], @@ -46,7 +45,7 @@ case class GpuBatchScanExec( // TODO: unify the equal/hashCode implementation for all data source v2 query plans. override def equals(other: Any): Boolean = other match { - case other: BatchScanExec => + case other: GpuBatchScanExec => this.batch == other.batch && this.runtimeFilters == other.runtimeFilters case _ => false diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala index 4fc62d82df3..c2d2c8d934e 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuBatchScanExec.scala @@ -137,6 +137,7 @@ case class GpuBatchScanExec( override lazy val readerFactory: PartitionReaderFactory = batch.createReaderFactory() override lazy val inputRDD: RDD[InternalRow] = { + scan.metrics = allMetrics val rdd = if (filteredPartitions.isEmpty && outputPartitioning == SinglePartition) { // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1)