From 206fb2bf77e115f698e5c53d1e256cc510f29900 Mon Sep 17 00:00:00 2001 From: Corvin Kuebler Date: Tue, 16 Apr 2024 13:51:35 +0200 Subject: [PATCH] [EXPB-2111] MergeJoin removes null values in front of non null values Signed-off-by: Corvin Kuebler --- .../enumerable/EnumerableMergeJoin.java | 8 +- .../apache/calcite/util/BuiltInMethod.java | 13 +-- .../linq4j/MergeJoinNotNullEnumerable.java | 87 +++++++++++++++++++ 3 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 linq4j/src/main/java/org/apache/calcite/linq4j/MergeJoinNotNullEnumerable.java diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java index 97eae019d3f..46ad7d81a67 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeJoin.java @@ -500,8 +500,12 @@ public static EnumerableMergeJoin create(RelNode left, RelNode right, Expressions.call( BuiltInMethod.MERGE_JOIN.method, Expressions.list( - leftExpression, - rightExpression, + Expressions.call(BuiltInMethod.MERGE_JOIN_NOT_NULL_ENUMERABLE.method, leftExpression, + Expressions.lambda( + leftKeyPhysType.record(leftExpressions), left_)), + Expressions.call(BuiltInMethod.MERGE_JOIN_NOT_NULL_ENUMERABLE.method, rightExpression, + Expressions.lambda( + rightKeyPhysType.record(rightExpressions), right_)), Expressions.lambda( leftKeyPhysType.record(leftExpressions), left_), Expressions.lambda( diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 392e89fe4de..1a2b94d3734 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -32,16 +32,7 @@ import org.apache.calcite.interpreter.Context; import org.apache.calcite.interpreter.Row; import org.apache.calcite.interpreter.Scalar; -import org.apache.calcite.linq4j.AbstractEnumerable; -import org.apache.calcite.linq4j.Enumerable; -import org.apache.calcite.linq4j.EnumerableDefaults; -import org.apache.calcite.linq4j.Enumerator; -import org.apache.calcite.linq4j.ExtendedEnumerable; -import org.apache.calcite.linq4j.JoinType; -import org.apache.calcite.linq4j.Linq4j; -import org.apache.calcite.linq4j.MemoryFactory; -import org.apache.calcite.linq4j.QueryProvider; -import org.apache.calcite.linq4j.Queryable; +import org.apache.calcite.linq4j.*; import org.apache.calcite.linq4j.function.EqualityComparer; import org.apache.calcite.linq4j.function.Function0; import org.apache.calcite.linq4j.function.Function1; @@ -213,6 +204,8 @@ public enum BuiltInMethod { MERGE_JOIN(EnumerableDefaults.class, "mergeJoin", Enumerable.class, Enumerable.class, Function1.class, Function1.class, Predicate2.class, Function2.class, JoinType.class, Comparator.class, EqualityComparer.class), + MERGE_JOIN_NOT_NULL_ENUMERABLE( + MergeJoinNotNullEnumerable.class, "create", Enumerable.class, Function1.class), SLICE0(Enumerables.class, "slice0", Enumerable.class), SEMI_JOIN(EnumerableDefaults.class, "semiJoin", Enumerable.class, Enumerable.class, Function1.class, Function1.class, diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/MergeJoinNotNullEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/MergeJoinNotNullEnumerable.java new file mode 100644 index 00000000000..b528ae1aef7 --- /dev/null +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/MergeJoinNotNullEnumerable.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.calcite.linq4j; + +import org.apache.calcite.linq4j.function.Function1; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class MergeJoinNotNullEnumerable extends AbstractEnumerable { + private final Enumerable enumerable; + private final Function1 keySelector; + private boolean skipNulls = true; + + private MergeJoinNotNullEnumerable(Enumerable enumerable, Function1 keySelector) { + this.enumerable = enumerable; + this.keySelector = keySelector; + } + + public static MergeJoinNotNullEnumerable create(Enumerable enumerable, Function1 keySelector) { + return new MergeJoinNotNullEnumerable<>(enumerable, keySelector); + } + + @Override + public Enumerator enumerator() { + Enumerator enumerator = enumerable.enumerator(); + return new Enumerator() { + @Override + public T current() { + return enumerator.current(); + } + + @Override + public boolean moveNext() { + boolean next = enumerator.moveNext(); + if (!skipNulls) { + return next; + } + while (next) { + K key = keySelector.apply(enumerator.current()); + if (key != null) { + if (key instanceof Object[]) { + if(Arrays.stream((Object[]) key).noneMatch(Objects::isNull)) { + break; + } + } else if (key instanceof List) { + if (((List) key).stream().noneMatch(Objects::isNull)) { + break; + } + } else { + break; + } + } + next = enumerator.moveNext(); + } + skipNulls = false; + return next; + } + + @Override + public void reset() { + enumerator.reset(); + } + + @Override + public void close() { + enumerator.close(); + } + }; + } +}