Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Java Compatibility Check for Regex Patterns #11912

Open
wants to merge 2 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1102,6 +1102,17 @@ def test_regexp_memory_ok():
}
)

def test_illegal_regexp_exception():
    """Check that illegal regex patterns fail on both the CPU and the GPU.

    Each REGEXP_REPLACE expression is run through its own
    assert_gpu_and_cpu_error call. If they shared a single selectExpr, one
    failing expression would be enough to satisfy the assertion and the
    other pattern would never be verified independently.
    """
    gen = mk_str_gen('[abcdef]{0,5}')
    illegal_exprs = [
        'REGEXP_REPLACE(a, "a{", "bb")',
        'REGEXP_REPLACE(a, "\\}\\,\\{", "}>>{")',
    ]
    for illegal_expr in illegal_exprs:
        assert_gpu_and_cpu_error(
            # bind the current expression eagerly so the lambda never sees a later one
            lambda spark, sql_expr=illegal_expr: unary_op_df(spark, gen).selectExpr(
                sql_expr
            ).collect(),
            conf=_regexp_conf,
            error_message="Illegal"
        )

@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9731')
def test_re_replace_all():
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,8 @@
*/
package com.nvidia.spark.rapids

import java.util.regex.Pattern

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils}
import org.apache.spark.sql.types.DataTypes
Expand Down Expand Up @@ -51,6 +53,8 @@ class GpuRegExpReplaceMeta(
expr.regexp match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
javaPattern = Some(s.toString())
// Check that this is valid in Java. Compile the raw pattern text rather than
// javaPattern.toString: javaPattern is an Option, so its toString yields
// "Some(<pattern>)" and would validate a different regex than the one supplied
// (e.g. the invalid pattern "a)(b" would become the valid "Some(a)(b)").
Pattern.compile(s.toString)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should move this to the RegexTranspiler code. I think it could be done in both getTranspiledAST(...)/RegexParser(...)

Otherwise this line will have to be called in every Regexp Expression class, and it could easily be lost in a few places. The transpiler is used by all of these methods, so this makes sense as a shortcut.

try {
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids
import java.nio.charset.Charset
import java.text.DecimalFormatSymbols
import java.util.{EnumSet, Locale, Optional}
import java.util.regex.Pattern

import scala.annotation.tailrec
import scala.collection.mutable
Expand Down Expand Up @@ -1173,9 +1174,11 @@ class GpuRLikeMeta(
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.right match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
val originalPattern = str.toString
// check that this is valid in Java
Pattern.compile(originalPattern)
try {
// verify that we support this regex and can transpile it to cuDF format
val originalPattern = str.toString
val regexAst = new RegexParser(originalPattern).parse()
if (conf.isRlikeRegexRewriteEnabled) {
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst)
Expand Down Expand Up @@ -1438,8 +1441,10 @@ class GpuRegExpExtractMeta(

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
val javaRegexpPattern = str.toString
// check that this is valid in Java
Pattern.compile(javaRegexpPattern)
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(
Expand Down Expand Up @@ -1567,8 +1572,10 @@ class GpuRegExpExtractAllMeta(

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
val javaRegexpPattern = str.toString
// check that this is valid in Java
Pattern.compile(javaRegexpPattern)
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(
Expand Down Expand Up @@ -1830,6 +1837,8 @@ abstract class StringSplitRegExpMeta[INPUT <: TernaryExpression](expr: INPUT,
case Some(simplified) =>
pattern = simplified
case None =>
// check that this is valid in Java
Pattern.compile(utf8Str.toString)
try {
val (transpiledAST, _) = transpiler.getTranspiledAST(utf8Str.toString, None, None)
GpuRegExpUtils.validateRegExpComplexity(this, transpiledAST)
Expand Down
Loading