Initial PR for the framework reusing Vanilla Spark's unit tests #10743

Merged · 20 commits · May 10, 2024
11 changes: 11 additions & 0 deletions NOTICE
@@ -48,6 +48,17 @@ The Apache Software Foundation (http://www.apache.org/).

--------------------------------------------------------------------------------

This project includes software from the Apache Gluten project
(www.github.com/apache/incubator-gluten/).

Apache Gluten (Incubating)
Copyright (2024) The Apache Software Foundation

This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).

--------------------------------------------------------------------------------

This project includes code from Kite, developed at Cloudera, Inc. with
the following copyright notice:

47 changes: 47 additions & 0 deletions pom.xml
@@ -1037,6 +1037,53 @@
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client-api</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client-runtime</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

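Note: these test-jar entries are what let this framework compile against Spark's own test sources. A minimal sketch of the intended reuse (the suite name and test body are illustrative; QueryTest, SharedSparkSession, and checkAnswer ship in the spark-sql test-jar managed above):

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSparkSession

// Extends Spark's own test harness classes, which are only published
// in the spark-sql test-jar pulled in by the entries above.
class ReusedSparkSmokeSuite extends QueryTest with SharedSparkSession {
  test("filter round-trips through the shared session") {
    import testImplicits._
    val df = Seq((1, "a"), (2, "b")).toDF("id", "name")
    checkAnswer(df.filter("id > 1"), Seq(Row(2, "b")))
  }
}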
47 changes: 47 additions & 0 deletions scala2.13/pom.xml
@@ -1037,6 +1037,53 @@
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client-api</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client-runtime</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>

21 changes: 21 additions & 0 deletions scala2.13/tests/pom.xml
@@ -103,6 +103,27 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-avro_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.scalatestplus</groupId>
<artifactId>scalatestplus-scalacheck_${scala.binary.version}</artifactId>
<version>3.1.0.0-RC2</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
21 changes: 21 additions & 0 deletions tests/pom.xml
@@ -103,6 +103,27 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-avro_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.scalatestplus</groupId>
<artifactId>scalatestplus-scalacheck_${scala.binary.version}</artifactId>
<version>3.1.0.0-RC2</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
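Note: scalatestplus-scalacheck (declared in both tests poms above) is needed because many of Spark's test suites mix in ScalaCheck-driven property checks. A minimal sketch of that style under this dependency, assuming a matching ScalaTest 3.1.x on the classpath (the suite and property here are illustrative, not taken from Spark):

import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks

class ReversePropertySuite extends AnyFunSuite with ScalaCheckDrivenPropertyChecks {
  test("reversing a list twice is the identity") {
    // ScalaCheck generates the inputs; the body is the property under test.
    forAll { (xs: List[Int]) =>
      assert(xs.reverse.reverse == xs)
    }
  }
}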
205 changes: 205 additions & 0 deletions tests/src/test/java/com/nvidia/spark/rapids/TestStats.java
@@ -0,0 +1,205 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.spark.rapids;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.Stack;

/** Only for use in the unit-test environment. Not thread safe. */
public class TestStats {
private static final String HEADER_FORMAT = "<tr><th>%s</th><th colspan=5>%s</th></tr>";
private static final String ROW_FORMAT =
"<tr><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td></tr>";

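// Flipped on by beginStatistic(); the collection and reporting methods below are no-ops until then.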
private static boolean UT_ENV = false;
private static final Map<String, CaseInfo> caseInfos = new HashMap<>();
private static String currentCase;
public static int offloadRapidsUnitNumber = 0;
public static int testUnitNumber = 0;

// whether the RAPIDS backend executed the query
public static boolean offloadRapids = true;
public static int suiteTestNumber = 0;
public static int offloadRapidsTestNumber = 0;

public static void beginStatistic() {
UT_ENV = true;
}

public static void reset() {
offloadRapids = false;
suiteTestNumber = 0;
offloadRapidsTestNumber = 0;
testUnitNumber = 0;
offloadRapidsUnitNumber = 0;
resetCase();
caseInfos.clear();
}

private static int totalSuiteTestNumber = 0;
public static int totalOffloadRapidsTestNumber = 0;

public static int totalTestUnitNumber = 0;
public static int totalOffloadRapidsCaseNumber = 0;

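// Emits an HTML header row for the suite and one row per case, then folds this suite's counts into the running totals.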
public static void printMarkdown(String suiteName) {
if (!UT_ENV) {
return;
}

String title = "print_markdown_" + suiteName;

String info =
"Case Count: %d, OffloadRapids Case Count: %d, "
+ "Unit Count: %d, OffloadRapids Unit Count: %d";

System.out.println(
String.format(
HEADER_FORMAT,
title,
String.format(
info,
TestStats.suiteTestNumber,
TestStats.offloadRapidsTestNumber,
TestStats.testUnitNumber,
TestStats.offloadRapidsUnitNumber)));

caseInfos.forEach(
(key, value) ->
System.out.println(
String.format(
ROW_FORMAT,
title,
key,
value.status,
value.type,
String.join("<br/>", value.fallbackExpressionName),
String.join("<br/>", value.fallbackClassName))));

totalSuiteTestNumber += suiteTestNumber;
totalOffloadRapidsTestNumber += offloadRapidsTestNumber;
totalTestUnitNumber += testUnitNumber;
totalOffloadRapidsCaseNumber += offloadRapidsUnitNumber;
System.out.println(
"total_markdown_ totalCaseNum:"
+ totalSuiteTestNumber
+ " offloadRapids: "
+ totalOffloadRapidsTestNumber
+ " total unit: "
+ totalTestUnitNumber
+ " offload unit: "
+ totalOffloadRapidsCaseNumber);
}

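// Pops the most recently recorded expression for the current case and pairs it with the fallback's class name.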
public static void addFallBackClassName(String className) {
if (!UT_ENV) {
return;
}

if (caseInfos.containsKey(currentCase) && !caseInfos.get(currentCase).stack.isEmpty()) {
CaseInfo info = caseInfos.get(currentCase);
caseInfos.get(currentCase).fallbackExpressionName.add(info.stack.pop());
caseInfos.get(currentCase).fallbackClassName.add(className);
}
}

public static void addFallBackCase() {
if (!UT_ENV) {
return;
}

if (caseInfos.containsKey(currentCase)) {
caseInfos.get(currentCase).type = "fallback";
}
}

public static void addExpressionClassName(String className) {
if (!UT_ENV) {
return;
}

if (caseInfos.containsKey(currentCase)) {
CaseInfo info = caseInfos.get(currentCase);
info.stack.add(className);
}
}

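// Note: despite the name, this returns the fallback expression names recorded for the current case.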
public static Set<String> getFallBackClassName() {
if (!UT_ENV) {
return Collections.emptySet();
}

if (caseInfos.containsKey(currentCase)) {
return Collections.unmodifiableSet(caseInfos.get(currentCase).fallbackExpressionName);
}

return Collections.emptySet();
}

public static void addIgnoreCaseName(String caseName) {
if (!UT_ENV) {
return;
}

if (caseInfos.containsKey(caseName)) {
caseInfos.get(caseName).type = "fatal";
}
}

public static void resetCase() {
if (!UT_ENV) {
return;
}

if (caseInfos.containsKey(currentCase)) {
caseInfos.get(currentCase).stack.clear();
}
currentCase = "";
}

public static void startCase(String caseName) {
if (!UT_ENV) {
return;
}

caseInfos.putIfAbsent(caseName, new CaseInfo());
currentCase = caseName;
}

public static void endCase(boolean status) {
if (!UT_ENV) {
return;
}

if (caseInfos.containsKey(currentCase)) {
caseInfos.get(currentCase).status = status ? "success" : "error";
}

resetCase();
}
}

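/** Per-case bookkeeping: expressions seen (stack), fallback records, and the case's type and status. */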
class CaseInfo {
final Stack<String> stack = new Stack<>();
Set<String> fallbackExpressionName = new HashSet<>();
Set<String> fallbackClassName = new HashSet<>();
String type = "";
String status = "";
}
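
As a usage sketch, a harness could drive TestStats across a test case's lifecycle like this (the wrapper object and argument strings are illustrative; the TestStats calls match the class above):

object TestStatsExample {
  def main(args: Array[String]): Unit = {
    TestStats.beginStatistic()             // enable collection (unit-test env only)
    TestStats.startCase("simple filter")   // register the running test case
    // ... run the query, recording expressions as the plan is walked ...
    TestStats.addExpressionClassName("Add")
    TestStats.addFallBackClassName("ProjectExec") // pairs with the popped expression
    TestStats.endCase(true)                // mark success and reset the current case
    TestStats.printMarkdown("ExampleSuite")
  }
}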