diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java index 94ef9c32133..448c4540b6f 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoin.java @@ -34,6 +34,7 @@ import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.metadata.RelMdCollation; +import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rex.RexNode; import org.apache.calcite.util.BuiltInMethod; @@ -55,6 +56,7 @@ public class EnumerableBatchNestedLoopJoin extends Join implements EnumerableRel { private final ImmutableBitSet requiredColumns; + private final double rightSideFilterSelectivity; protected EnumerableBatchNestedLoopJoin( RelOptCluster cluster, RelTraitSet traits, @@ -63,9 +65,11 @@ protected EnumerableBatchNestedLoopJoin( RexNode condition, Set variablesSet, ImmutableBitSet requiredColumns, - JoinRelType joinType) { + JoinRelType joinType, + double rightSideFilterSelectivity) { super(cluster, traits, ImmutableList.of(), left, right, condition, variablesSet, joinType); this.requiredColumns = requiredColumns; + this.rightSideFilterSelectivity = rightSideFilterSelectivity; } public static EnumerableBatchNestedLoopJoin create( @@ -74,7 +78,8 @@ public static EnumerableBatchNestedLoopJoin create( RexNode condition, ImmutableBitSet requiredColumns, Set variablesSet, - JoinRelType joinType) { + JoinRelType joinType, + double rightSideFilterSelectivity) { final RelOptCluster cluster = left.getCluster(); final RelMetadataQuery mq = cluster.getMetadataQuery(); final RelTraitSet traitSet = @@ -89,7 +94,19 @@ public static EnumerableBatchNestedLoopJoin create( condition, variablesSet, requiredColumns, - joinType); + joinType, + rightSideFilterSelectivity); + } + + @Deprecated + public static EnumerableBatchNestedLoopJoin create( + RelNode left, + RelNode right, + RexNode condition, + ImmutableBitSet requiredColumns, + Set variablesSet, + JoinRelType joinType) { + return create(left, right, condition, requiredColumns, variablesSet, joinType, 1.0); } @Override public @Nullable Pair> passThroughTraits( @@ -115,14 +132,14 @@ public static EnumerableBatchNestedLoopJoin create( @Override public EnumerableBatchNestedLoopJoin copy(RelTraitSet traitSet, RexNode condition, RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone) { - return new EnumerableBatchNestedLoopJoin(getCluster(), traitSet, - left, right, condition, variablesSet, requiredColumns, joinType); + return new EnumerableBatchNestedLoopJoin(getCluster(), traitSet, left, right, condition, + variablesSet, requiredColumns, joinType, rightSideFilterSelectivity); } @Override public @Nullable RelOptCost computeSelfCost( final RelOptPlanner planner, final RelMetadataQuery mq) { - double rowCount = mq.getRowCount(this); + double rowCount = estimateRowCount(mq); final double rightRowCount = right.estimateRowCount(mq); final double leftRowCount = left.estimateRowCount(mq); @@ -144,6 +161,18 @@ public static EnumerableBatchNestedLoopJoin create( rowCount + leftRowCount, 0, 0).plus(rescanCost); } + @Override public double estimateRowCount(RelMetadataQuery mq) { + return unwrapDouble(RelMdUtil.getJoinRowCount(mq, this, condition, + unwrapDouble(mq.getRowCount(getRight())) / rightSideFilterSelectivity)); + } + + static double unwrapDouble(Double value) { + if (value == null) { + return Double.POSITIVE_INFINITY; + } + return value.doubleValue(); + } + @Override public RelWriter explainTerms(RelWriter pw) { super.explainTerms(pw); return pw.item("batchSize", variablesSet.size()); diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java index d7accbd7058..09d0762be94 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableBatchNestedLoopJoinRule.java @@ -40,6 +40,8 @@ import java.util.List; import java.util.Set; +import static org.apache.calcite.adapter.enumerable.EnumerableBatchNestedLoopJoin.unwrapDouble; + /** Rule to convert a {@link LogicalJoin} to an {@link EnumerableBatchNestedLoopJoin}. * You may provide a custom config to convert other nodes that extend {@link Join}. * @@ -134,9 +136,14 @@ public EnumerableBatchNestedLoopJoinRule(RelBuilderFactory relBuilderFactory, conditionList.add(condition2); } + RexNode filterCondition = relBuilder.or(conditionList); + // Push a filter with batchSize disjunctions - relBuilder.push(join.getRight()).filter(relBuilder.or(conditionList)); + relBuilder.push(join.getRight()).filter(filterCondition); final RelNode right = relBuilder.build(); + final double filterSelectivity = right.getInputs().size() == 1 + ? unwrapDouble(call.getMetadataQuery().getSelectivity(right.getInput(0), filterCondition)) + : 1.0; call.transformTo( EnumerableBatchNestedLoopJoin.create( @@ -147,7 +154,8 @@ public EnumerableBatchNestedLoopJoinRule(RelBuilderFactory relBuilderFactory, join.getCondition(), requiredColumns.build(), correlationIds, - join.getJoinType())); + join.getJoinType(), + filterSelectivity)); } /** Rule configuration. */ diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java index 696daa77db0..3d40f50c48e 100644 --- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java +++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdUtil.java @@ -796,6 +796,12 @@ public static double getMinusRowCount(RelMetadataQuery mq, Minus minus) { /** Returns an estimate of the number of rows returned by a {@link Join}. */ public static @Nullable Double getJoinRowCount(RelMetadataQuery mq, Join join, RexNode condition) { + return getJoinRowCount(mq, join, condition, mq.getRowCount(join.getRight())); + } + + /** Returns an estimate of the number of rows returned by a {@link Join}. */ + public static @Nullable Double getJoinRowCount(RelMetadataQuery mq, Join join, + RexNode condition, Double rightRowCount) { if (!join.getJoinType().projectsRight()) { // Create a RexNode representing the selectivity of the // semijoin filter and pass it to getSelectivity @@ -813,7 +819,7 @@ public static double getMinusRowCount(RelMetadataQuery mq, Minus minus) { // Row count estimates of 0 will be rounded up to 1. // So, use maxRowCount where the product is very small. final Double left = mq.getRowCount(join.getLeft()); - final Double right = mq.getRowCount(join.getRight()); + final Double right = rightRowCount; if (left == null || right == null) { return null; } diff --git a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableBatchNestedLoopJoinTest.java b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableBatchNestedLoopJoinTest.java index 36e141a8dd5..e6e78b88372 100644 --- a/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableBatchNestedLoopJoinTest.java +++ b/core/src/test/java/org/apache/calcite/test/enumerable/EnumerableBatchNestedLoopJoinTest.java @@ -226,6 +226,7 @@ class EnumerableBatchNestedLoopJoinTest { + "join locations l on e.empid <> l.empid and d.deptno = l.empid") .withHook(Hook.PLANNER, (Consumer) planner -> { planner.removeRule(EnumerableRules.ENUMERABLE_CORRELATE_RULE); + planner.removeRule(EnumerableRules.ENUMERABLE_JOIN_RULE); // Use a small batch size, otherwise we will run into Janino's // "InternalCompilerException: Code of method grows beyond 64 KB". planner.addRule(