diff --git a/integration_tests/src/main/python/regexp_test.py b/integration_tests/src/main/python/regexp_test.py index b67f9dc6679..d3cadc9e75b 100644 --- a/integration_tests/src/main/python/regexp_test.py +++ b/integration_tests/src/main/python/regexp_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1102,6 +1102,17 @@ def test_regexp_memory_ok(): } ) +def test_illegal_regexp_exception():  # expects REGEXP_REPLACE with a malformed pattern to fail with an error containing "Illegal" (on CPU and GPU, going by the helper's name) + gen = mk_str_gen('[abcdef]{0,5}') + assert_gpu_and_cpu_error( + lambda spark: unary_op_df(spark, gen).selectExpr( + 'REGEXP_REPLACE(a, "a{", "bb")',  # pattern ends in an unclosed '{' + 'REGEXP_REPLACE(a, "\\}\\,\\{", "}>>{")' + ).collect(), + conf=_regexp_conf, + error_message="Illegal" + ) + @datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9731') def test_re_replace_all(): """ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala index 99e82d8913e..cb6672511c0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRegExpReplaceMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -15,6 +15,8 @@ */ package com.nvidia.spark.rapids +import java.util.regex.Pattern + import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace} import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils} import org.apache.spark.sql.types.DataTypes @@ -51,6 +53,8 @@ class GpuRegExpReplaceMeta( expr.regexp match { case Literal(s: UTF8String, DataTypes.StringType) if s != null => javaPattern = Some(s.toString()) + // check that the raw pattern (not the Option wrapper) is valid in Java + Pattern.compile(s.toString) try { val (pat, repl) = new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala index f668195abc7..0f0fe5b5bfa 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids import java.nio.charset.Charset import java.text.DecimalFormatSymbols import java.util.{EnumSet, Locale, Optional} +import java.util.regex.Pattern import scala.annotation.tailrec import scala.collection.mutable @@ -1173,9 +1174,11 @@ class GpuRLikeMeta( GpuRegExpUtils.tagForRegExpEnabled(this) expr.right match { case Literal(str: UTF8String, DataTypes.StringType) if str != null => + val originalPattern = str.toString + // validate with java.util.regex first; throws if Java rejects the pattern + Pattern.compile(originalPattern) try { // verify that we support this regex and can transpile it to cuDF format - val originalPattern = str.toString val regexAst = new RegexParser(originalPattern).parse() if (conf.isRlikeRegexRewriteEnabled) { rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst) @@ -1438,8 +1441,10 @@ class GpuRegExpExtractMeta( expr.regexp match { case Literal(str: UTF8String, DataTypes.StringType) if str != null => + val javaRegexpPattern = str.toString + // validate with java.util.regex first; throws if Java rejects the pattern + Pattern.compile(javaRegexpPattern) try { - val javaRegexpPattern = str.toString // verify that we support this regex and can transpile it to cuDF format val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode).getTranspiledAST( @@ -1567,8 +1572,10 @@ class GpuRegExpExtractAllMeta( expr.regexp match { case Literal(str: UTF8String, DataTypes.StringType) if str != null => + val javaRegexpPattern = str.toString + // validate with java.util.regex first; throws if Java rejects the pattern + Pattern.compile(javaRegexpPattern) try { - val javaRegexpPattern = str.toString // verify that we support this regex and can transpile it to cuDF format val (transpiledAST, _) = new CudfRegexTranspiler(RegexFindMode).getTranspiledAST( @@ -1830,6 +1837,8 @@ abstract class StringSplitRegExpMeta[INPUT <: TernaryExpression](expr: INPUT, case Some(simplified) => pattern = simplified case None => + // validate with java.util.regex first; throws if Java rejects the pattern + Pattern.compile(utf8Str.toString) try { val (transpiledAST, 
_) = transpiler.getTranspiledAST(utf8Str.toString, None, None) GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)