Skip to content

Commit

Permalink
Adds FnProvider, Fn, AggProvider, and Agg
Browse files Browse the repository at this point in the history
Adds a RoutineSignature and RoutineProviderSignature

Adds a RoutineProviderParam

Adds scalar and aggregate function builders

Updates existing function implementations to use new APIs
  • Loading branch information
johnedquinn committed Jan 22, 2025
1 parent ec92c6b commit 25043b2
Show file tree
Hide file tree
Showing 68 changed files with 1,609 additions and 1,112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.partiql.spi.catalog.Identifier
import org.partiql.spi.errors.PError
import org.partiql.spi.errors.PErrorKind
import org.partiql.spi.errors.Severity
import org.partiql.spi.function.Function
import org.partiql.spi.function.FnProvider
import org.partiql.spi.types.PType
import java.io.PrintWriter
import java.io.Writer
Expand Down Expand Up @@ -186,7 +186,7 @@ object ErrorMessageFormatter {
*/
private fun fnTypeMismatch(error: PError): String {
val functionName = error.getOrNull("FN_ID", Identifier::class.java)
val candidates = error.getListOrNull("CANDIDATES", Function::class.java)
val candidates = error.getListOrNull("CANDIDATES", FnProvider::class.java)
val args = error.getListOrNull("ARG_TYPES", PType::class.java)
val fnNameStr = prepare(functionName.toString(), " ", "")
val fnStr = when {
Expand All @@ -196,20 +196,6 @@ object ErrorMessageFormatter {
}
return buildString {
append("Undefined function$fnStr.")
if (!candidates.isNullOrEmpty()) {
appendLine(" Did you mean: ")
for (variant in candidates) {
variant as Function
append("- ")
append(variant.getName())
append(
variant.getParameters().joinToString(", ", "(", ")") {
"${it.getName()}: ${it.getType()}"
}
)
appendLine()
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ internal class StandardCompiler(strategies: List<Strategy>) : PartiQLCompiler {
// Compile the candidates
val candidates = Array(functions.size) {
val fn = functions[it]
val fnArity = fn.getParameters().size
val fnArity = fn.signature.arity
if (arity == -1) {
// set first
arity = fnArity
Expand All @@ -385,7 +385,7 @@ internal class StandardCompiler(strategies: List<Strategy>) : PartiQLCompiler {
override fun visitCall(rex: RexCall, ctx: Unit): ExprValue {
val func = rex.getFunction()
val args = rex.getArgs()
val catch = func.parameters.any { it.code() == PType.DYNAMIC }
val catch = func.signature.parameters.any { it.type.code() == PType.DYNAMIC }
return when (catch) {
true -> ExprCall(func, Array(args.size) { i -> compile(args[i], Unit).catch() })
else -> ExprCall(func, Array(args.size) { i -> compile(args[i], Unit) })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.partiql.spi.errors.PError
import org.partiql.spi.errors.PErrorKind
import org.partiql.spi.errors.PRuntimeException
import org.partiql.spi.errors.Severity
import org.partiql.spi.function.Function
import org.partiql.spi.function.FnProvider
import org.partiql.spi.types.PType

internal object PErrors {
Expand All @@ -20,7 +20,7 @@ internal object PErrors {
/**
* Returns a PRuntimeException with code: [PError.FUNCTION_TYPE_MISMATCH].
*/
fun functionTypeMismatchException(name: String, actualTypes: List<PType>, candidates: List<Function>): PRuntimeException {
fun functionTypeMismatchException(name: String, actualTypes: List<PType>, candidates: List<FnProvider>): PRuntimeException {
val pError = functionTypeMismatch(name, actualTypes, candidates)
return PRuntimeException(pError)
}
Expand Down Expand Up @@ -158,7 +158,7 @@ internal object PErrors {
/**
* Returns a PError with code: [PError.FUNCTION_TYPE_MISMATCH].
*/
private fun functionTypeMismatch(name: String, actualTypes: List<PType>, candidates: List<Function>): PError {
private fun functionTypeMismatch(name: String, actualTypes: List<PType>, candidates: List<FnProvider>): PError {
return PError(
PError.FUNCTION_TYPE_MISMATCH,
Severity.ERROR(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package org.partiql.eval.internal.operator

import org.partiql.eval.ExprValue
import org.partiql.spi.function.Aggregation
import org.partiql.spi.function.Agg

/**
* Simple data class to hold a compile aggregation call.
*/
internal class Aggregate(
val agg: Aggregation,
val agg: Agg,
val args: List<ExprValue>,
val distinct: Boolean
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import org.partiql.eval.ExprValue
import org.partiql.eval.Row
import org.partiql.eval.internal.helpers.DatumArrayComparator
import org.partiql.eval.internal.operator.Aggregate
import org.partiql.spi.function.Aggregation
import org.partiql.spi.types.PType
import org.partiql.spi.function.Accumulator
import org.partiql.spi.value.Datum
import java.util.TreeMap
import java.util.TreeSet
Expand All @@ -23,12 +22,12 @@ internal class RelOpAggregate(
private val aggregationMap = TreeMap<Array<Datum>, List<AccumulatorWrapper>>(DatumArrayComparator)

/**
* Wraps an [Aggregation.Accumulator] to help with filtering distinct values.
* Wraps an [Accumulator] to help with filtering distinct values.
*
* @property seen maintains which values have already been seen. If null, we accumulate all values coming through.
*/
class AccumulatorWrapper(
val delegate: Aggregation.Accumulator,
val delegate: Accumulator,
val args: List<ExprValue>,
val seen: TreeSet<Array<Datum>>?
)
Expand All @@ -47,13 +46,10 @@ internal class RelOpAggregate(
}
}

// TODO IT DOES NOT MATTER NOW, BUT SqlCompiler SHOULD HANDLE GET THE ARGUMENT TYPES FOR .getAccumulator
val args: Array<PType> = emptyArray()

val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) {
aggregates.map {
AccumulatorWrapper(
delegate = it.agg.getAccumulator(args),
delegate = it.agg.accumulator,
args = it.args,
seen = if (it.distinct) TreeSet(DatumArrayComparator) else null
)
Expand Down Expand Up @@ -84,7 +80,7 @@ internal class RelOpAggregate(
if (groups.isEmpty() && aggregationMap.isEmpty()) {
val record = mutableListOf<Datum>()
aggregates.forEach { function ->
val accumulator = function.agg.getAccumulator(args = emptyArray())
val accumulator = function.agg.accumulator
record.add(accumulator.value())
}
records = iterator { yield(Row(record.toTypedArray())) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.eval.Environment
import org.partiql.eval.ExprValue
import org.partiql.spi.function.Function
import org.partiql.spi.function.Fn
import org.partiql.spi.value.Datum

/**
Expand All @@ -12,14 +12,14 @@ import org.partiql.spi.value.Datum
* @property args Input argument expressions.
*/
internal class ExprCall(
private var function: Function.Instance,
private var function: Fn,
private var args: Array<ExprValue>,
) : ExprValue {

private var isNullCall: Boolean = function.isNullCall
private var isMissingCall: Boolean = function.isMissingCall
private var nil = { Datum.nullValue(function.returns) }
private var missing = { Datum.missing(function.returns) }
private var isNullCall: Boolean = function.signature.isNullCall
private var isMissingCall: Boolean = function.signature.isMissingCall
private var nil = { Datum.nullValue(function.signature.returns) }
private var missing = { Datum.missing(function.signature.returns) }

override fun eval(env: Environment): Datum {
// Evaluate arguments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import org.partiql.eval.internal.helpers.PErrors
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.Candidate
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.DYNAMIC
import org.partiql.eval.internal.operator.rex.ExprCallDynamic.CoercionFamily.UNKNOWN
import org.partiql.spi.function.Function
import org.partiql.spi.function.Fn
import org.partiql.spi.function.FnProvider
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum

Expand All @@ -31,7 +32,7 @@ import org.partiql.spi.value.Datum
*/
internal class ExprCallDynamic(
private val name: String,
private val functions: Array<Function>,
private val functions: Array<FnProvider>,
private val args: Array<ExprValue>
) : ExprValue {

Expand Down Expand Up @@ -69,13 +70,13 @@ internal class ExprCallDynamic(
val argFamilies = args.map { family(it.code()) }
functions.indices.forEach { candidateIndex ->
var currentExactMatches = 0
val params = functions[candidateIndex].getInstance(args.toTypedArray())?.parameters ?: return@forEach
val params = functions[candidateIndex].getInstance(args.toTypedArray())?.signature?.parameters ?: return@forEach
for (paramIndex in paramIndices) {
val argType = args[paramIndex]
val paramType = params[paramIndex]
if (paramType.code() == argType.code()) { currentExactMatches++ } // TODO: Convert all functions to use the new modelling, or else we need to only check kinds
if (paramType.type.code() == argType.code()) { currentExactMatches++ } // TODO: Convert all functions to use the new modelling, or else we need to only check kinds
val argFamily = argFamilies[paramIndex]
val paramFamily = family(paramType.code())
val paramFamily = family(paramType.type.code())
if (paramFamily != argFamily && argFamily != UNKNOWN && paramFamily != DYNAMIC) { return@forEach }
}
if (currentExactMatches > exactMatches) {
Expand Down Expand Up @@ -160,28 +161,28 @@ internal class ExprCallDynamic(
*
* @see ExprCallDynamic
*/
private class Candidate(private var function: Function.Instance) {
private class Candidate(private var function: Fn) {

private var nil = { Datum.nullValue(function.returns) }
private var missing = { Datum.missing(function.returns) }
private var nil = { Datum.nullValue(function.signature.returns) }
private var missing = { Datum.missing(function.signature.returns) }

/**
* Function instance parameters (just types).
*/
fun eval(args: Array<Datum>): Datum {
val coerced = Array(args.size) { i ->
val arg = args[i]
if (function.isNullCall && arg.isNull) {
if (function.signature.isNullCall && arg.isNull) {
return nil.invoke()
}
if (function.isMissingCall && arg.isMissing) {
if (function.signature.isMissingCall && arg.isMissing) {
return missing.invoke()
}
val argType = arg.type
val paramType = function.parameters[i]
when (paramType == argType) {
val paramType = function.signature.parameters[i]
when (paramType.type == argType) {
true -> arg
false -> CastTable.cast(arg, paramType)
false -> CastTable.cast(arg, paramType.type)
}
}
return function.invoke(coerced)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.Mode
import org.partiql.eval.compiler.PartiQLCompiler
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
import org.partiql.value.PartiQLValue
import org.partiql.value.bagValue
Expand Down Expand Up @@ -1412,14 +1413,10 @@ class PartiQLEvaluatorTest {
fun developmentTest() {
val tc =
SuccessTestCase(
input = "SELECT DISTINCT VALUE t * 100 FROM <<0, 1, 2.0, 3.0>> AS t;",
expected = bagValue(
int32Value(0),
int32Value(100),
decimalValue(BigDecimal.valueOf(2000, 1)),
decimalValue(BigDecimal.valueOf(3000, 1)),
),
mode = Mode.STRICT()
input = """
non_existing_column = 1
""".trimIndent(),
expected = Datum.nullValue(PType.bool())
)
tc.run()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.Environment
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.spi.function.Function
import org.partiql.spi.function.FnProvider
import org.partiql.spi.function.Parameter
import org.partiql.spi.types.PType
import org.partiql.spi.value.Datum
Expand Down Expand Up @@ -58,32 +58,15 @@ class ExprCallDynamicTest {
PartiQLValueType.LIST to PartiQLValueType.ANY, // Index 11
PartiQLValueType.ANY to PartiQLValueType.ANY, // Index 12
)

internal val functions: Array<Function> = params.mapIndexed { index, it ->
object : Function {

override fun getName(): String {
return "example"
}

override fun getParameters(): Array<Parameter> {
return arrayOf(Parameter("lhs", it.first.toPType()), Parameter("rhs", it.second.toPType()))
}

override fun getReturnType(args: Array<PType>): PType {
return PType.integer()
}

override fun getInstance(args: Array<PType>): Function.Instance {
return object : Function.Instance(
name = "example",
returns = PType.integer(),
parameters = arrayOf(it.first.toPType(), it.second.toPType())
) {
override fun invoke(args: Array<Datum>): Datum = integer(index)
}
}
}
internal val functions: Array<FnProvider> = params.mapIndexed { index, it ->
FnProvider.Builder("example")
.returns(PType.integer())
.addParameters(
Parameter("lhs", it.first.toPType()),
Parameter("rhs", it.second.toPType())
)
.body { integer(index) }
.build()
}.toTypedArray()
}
}
Expand Down
12 changes: 6 additions & 6 deletions partiql-plan/api/partiql-plan.api
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ public abstract interface class org/partiql/plan/Operators {
public abstract fun aggregate (Lorg/partiql/plan/rel/Rel;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/rel/RelAggregate;
public abstract fun array (Ljava/util/List;)Lorg/partiql/plan/rex/RexArray;
public abstract fun bag (Ljava/util/Collection;)Lorg/partiql/plan/rex/RexBag;
public abstract fun call (Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public abstract fun call (Lorg/partiql/spi/function/Fn;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public abstract fun caseWhen (Lorg/partiql/plan/rex/Rex;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase;
public abstract fun cast (Lorg/partiql/plan/rex/Rex;Lorg/partiql/spi/types/PType;)Lorg/partiql/plan/rex/RexCast;
public abstract fun coalesce (Ljava/util/List;)Lorg/partiql/plan/rex/RexCoalesce;
Expand Down Expand Up @@ -344,7 +344,7 @@ public final class org/partiql/plan/Operators$DefaultImpls {
public static fun aggregate (Lorg/partiql/plan/Operators;Lorg/partiql/plan/rel/Rel;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/rel/RelAggregate;
public static fun array (Lorg/partiql/plan/Operators;Ljava/util/List;)Lorg/partiql/plan/rex/RexArray;
public static fun bag (Lorg/partiql/plan/Operators;Ljava/util/Collection;)Lorg/partiql/plan/rex/RexBag;
public static fun call (Lorg/partiql/plan/Operators;Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public static fun call (Lorg/partiql/plan/Operators;Lorg/partiql/spi/function/Fn;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public static fun caseWhen (Lorg/partiql/plan/Operators;Lorg/partiql/plan/rex/Rex;Ljava/util/List;Lorg/partiql/plan/rex/Rex;)Lorg/partiql/plan/rex/RexCase;
public static fun cast (Lorg/partiql/plan/Operators;Lorg/partiql/plan/rex/Rex;Lorg/partiql/spi/types/PType;)Lorg/partiql/plan/rex/RexCast;
public static fun coalesce (Lorg/partiql/plan/Operators;Ljava/util/List;)Lorg/partiql/plan/rex/RexCoalesce;
Expand Down Expand Up @@ -408,14 +408,14 @@ public abstract class org/partiql/plan/rel/RelAggregate : org/partiql/plan/rel/R
public abstract fun getGroups ()Ljava/util/List;
public abstract fun getInput ()Lorg/partiql/plan/rel/Rel;
public abstract fun getMeasures ()Ljava/util/List;
public static fun measure (Lorg/partiql/spi/function/Aggregation;Ljava/util/List;Z)Lorg/partiql/plan/rel/RelAggregate$Measure;
public static fun measure (Lorg/partiql/spi/function/Agg;Ljava/util/List;Z)Lorg/partiql/plan/rel/RelAggregate$Measure;
protected final fun operands ()Ljava/util/List;
protected final fun type ()Lorg/partiql/plan/rel/RelType;
}

public class org/partiql/plan/rel/RelAggregate$Measure {
public fun copy (Ljava/util/List;)Lorg/partiql/plan/rel/RelAggregate$Measure;
public fun getAgg ()Lorg/partiql/spi/function/Aggregation;
public fun getAgg ()Lorg/partiql/spi/function/Agg;
public fun getArgs ()Ljava/util/List;
public fun isDistinct ()Z
}
Expand Down Expand Up @@ -659,9 +659,9 @@ public abstract class org/partiql/plan/rex/RexBase : org/partiql/plan/rex/Rex {
public abstract class org/partiql/plan/rex/RexCall : org/partiql/plan/rex/RexBase {
public fun <init> ()V
public fun accept (Lorg/partiql/plan/OperatorVisitor;Ljava/lang/Object;)Ljava/lang/Object;
public static fun create (Lorg/partiql/spi/function/Function$Instance;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public static fun create (Lorg/partiql/spi/function/Fn;Ljava/util/List;)Lorg/partiql/plan/rex/RexCall;
public abstract fun getArgs ()Ljava/util/List;
public abstract fun getFunction ()Lorg/partiql/spi/function/Function$Instance;
public abstract fun getFunction ()Lorg/partiql/spi/function/Fn;
protected fun operands ()Ljava/util/List;
protected fun type ()Lorg/partiql/plan/rex/RexType;
}
Expand Down
Loading

0 comments on commit 25043b2

Please sign in to comment.