Skip to content

Commit

Permalink
add Pattern.compile()
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Aralihalli <[email protected]>
  • Loading branch information
SurajAralihalli committed Jan 3, 2025
1 parent 4df6d60 commit 275817f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
13 changes: 12 additions & 1 deletion integration_tests/src/main/python/regexp_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Copyright (c) 2022-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -1102,6 +1102,17 @@ def test_regexp_memory_ok():
}
)

def test_illegal_regexp_exception():
gen = mk_str_gen('[abcdef]{0,5}')
assert_gpu_and_cpu_error(
lambda spark: unary_op_df(spark, gen).selectExpr(
'REGEXP_REPLACE(a, "a{", "bb")',
'REGEXP_REPLACE(a, "\\}\\,\\{", "}>>{")'
).collect(),
conf=_regexp_conf,
error_message="Illegal"
)

@datagen_overrides(seed=0, reason='https://github.com/NVIDIA/spark-rapids/issues/9731')
def test_re_replace_all():
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,8 @@
*/
package com.nvidia.spark.rapids

import java.util.regex.Pattern

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RegExpReplace}
import org.apache.spark.sql.rapids.{GpuRegExpReplace, GpuRegExpReplaceWithBackref, GpuRegExpUtils}
import org.apache.spark.sql.types.DataTypes
Expand Down Expand Up @@ -51,6 +53,8 @@ class GpuRegExpReplaceMeta(
expr.regexp match {
case Literal(s: UTF8String, DataTypes.StringType) if s != null =>
javaPattern = Some(s.toString())
// check that this is valid in Java
Pattern.compile(javaPattern.toString)
try {
val (pat, repl) =
new CudfRegexTranspiler(RegexReplaceMode).getTranspiledAST(s.toString, None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids
import java.nio.charset.Charset
import java.text.DecimalFormatSymbols
import java.util.{EnumSet, Locale, Optional}
import java.util.regex.Pattern

import scala.annotation.tailrec
import scala.collection.mutable
Expand Down Expand Up @@ -1173,9 +1174,11 @@ class GpuRLikeMeta(
GpuRegExpUtils.tagForRegExpEnabled(this)
expr.right match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
val originalPattern = str.toString
// check that this is valid in Java
Pattern.compile(originalPattern)
try {
// verify that we support this regex and can transpile it to cuDF format
val originalPattern = str.toString
val regexAst = new RegexParser(originalPattern).parse()
if (conf.isRlikeRegexRewriteEnabled) {
rewriteOptimizationType = RegexRewrite.matchSimplePattern(regexAst)
Expand Down Expand Up @@ -1438,8 +1441,10 @@ class GpuRegExpExtractMeta(

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
val javaRegexpPattern = str.toString
// check that this is valid in Java
Pattern.compile(javaRegexpPattern)
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(
Expand Down Expand Up @@ -1567,8 +1572,10 @@ class GpuRegExpExtractAllMeta(

expr.regexp match {
case Literal(str: UTF8String, DataTypes.StringType) if str != null =>
val javaRegexpPattern = str.toString
// check that this is valid in Java
Pattern.compile(javaRegexpPattern)
try {
val javaRegexpPattern = str.toString
// verify that we support this regex and can transpile it to cuDF format
val (transpiledAST, _) =
new CudfRegexTranspiler(RegexFindMode).getTranspiledAST(
Expand Down

0 comments on commit 275817f

Please sign in to comment.