diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index 775b4a9d1cb..1d395d0e29a 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -676,6 +676,27 @@ def test_write_daytime_interval(spark_tmp_path): data_path, conf=writer_confs) + +hold_gpu_configs = [True, False] +@pytest.mark.parametrize('hold_gpu', hold_gpu_configs, ids=idfn) +def test_async_writer(spark_tmp_path, hold_gpu): + data_path = spark_tmp_path + '/PARQUET_DATA' + num_rows = 2048 + num_cols = 10 + parquet_gen = [int_gen for _ in range(num_cols)] + gen_list = [('_c' + str(i), gen) for i, gen in enumerate(parquet_gen)] + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list, length=num_rows).coalesce(1).write.parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + copy_and_update( + writer_confs, + {"spark.rapids.sql.asyncWrite.queryOutput.enabled": "true", + "spark.rapids.sql.batchSizeBytes": 4 * num_cols * 100, # 100 rows per batch + "spark.rapids.sql.queryOutput.holdGpuInTask": hold_gpu} + )) + + @ignore_order @pytest.mark.skipif(is_before_spark_320(), reason="is only supported in Spark 320+") def test_concurrent_writer(spark_tmp_path): diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala index df62683d346..8d89583d9df 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala @@ -73,13 +73,14 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, dataSchema: StructType, rangeName: String, includeRetry: Boolean, - holdGpuBetweenBatches: Boolean = false) extends HostBufferConsumer with Logging { + holdGpuBetweenBatches: Boolean = false, + useAsyncWrite: Boolean = false) extends HostBufferConsumer with Logging { protected val tableWriter: TableWriter protected val conf: Configuration = context.getConfiguration - private val trafficController: Option[TrafficController] = TrafficController.getInstance + private val trafficController: TrafficController = TrafficController.getInstance private def openOutputStream(): OutputStream = { val hadoopPath = new Path(path) @@ -90,10 +91,12 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, // This is implemented as a method to make it easier to subclass // ColumnarOutputWriter in the tests, and override this behavior. protected def getOutputStream: OutputStream = { - trafficController.map(controller => { + if (useAsyncWrite) { logWarning("Async output write enabled") - new AsyncOutputStream(() => openOutputStream(), controller) - }).getOrElse(openOutputStream()) + new AsyncOutputStream(() => openOutputStream(), trafficController) + } else { + openOutputStream() + } } protected val outputStream: OutputStream = getOutputStream diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 2b5f246e56a..e5aa52c727d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -283,7 +283,7 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { context: TaskAttemptContext): ColumnarOutputWriter = { new GpuParquetWriter(path, dataSchema, compressionType, outputTimestampType.toString, dateTimeRebaseMode, timestampRebaseMode, context, parquetFieldIdWriteEnabled, - holdGpuBetweenBatches) + holdGpuBetweenBatches, asyncOutputWriteEnabled) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -306,8 +306,10 @@ class GpuParquetWriter( timestampRebaseMode: DateTimeRebaseMode, context: TaskAttemptContext, parquetFieldIdEnabled: Boolean, - holdGpuBetweenBatches: Boolean) - extends ColumnarOutputWriter(context, dataSchema, "Parquet", true, holdGpuBetweenBatches) { + holdGpuBetweenBatches: Boolean, + useAsyncWrite: Boolean) + extends ColumnarOutputWriter(context, dataSchema, "Parquet", true, holdGpuBetweenBatches, + useAsyncWrite) { override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = { val cols = GpuColumnVector.extractBases(batch) cols.foreach { col => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala index 0110f2d89ca..cd698966c8c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala @@ -124,14 +124,14 @@ object TrafficController { * This is called once per executor. */ def initialize(conf: RapidsConf): Unit = synchronized { - if (conf.isAsyncOutputWriteEnabled && instance == null) { + if (instance == null) { instance = new TrafficController( new HostMemoryThrottle(conf.asyncWriteMaxInFlightHostMemoryBytes)) } } - def getInstance: Option[TrafficController] = synchronized { - Option(instance) + def getInstance: TrafficController = synchronized { + instance } def shutdown(): Unit = synchronized {