Skip to content

Commit

Permalink
WIP - Add null support
Browse files Browse the repository at this point in the history
Code is mostly copy-pasted from the PR graphql-go#536
  • Loading branch information
AndrewSisley committed Sep 12, 2022
1 parent b2134d2 commit 8eafea7
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 22 deletions.
1 change: 1 addition & 0 deletions language/ast/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var _ Node = (*IntValue)(nil)
var _ Node = (*FloatValue)(nil)
var _ Node = (*StringValue)(nil)
var _ Node = (*BooleanValue)(nil)
var _ Node = (*NullValue)(nil)
var _ Node = (*EnumValue)(nil)
var _ Node = (*ListValue)(nil)
var _ Node = (*ObjectValue)(nil)
Expand Down
28 changes: 28 additions & 0 deletions language/ast/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ var _ Value = (*IntValue)(nil)
var _ Value = (*FloatValue)(nil)
var _ Value = (*StringValue)(nil)
var _ Value = (*BooleanValue)(nil)
var _ Value = (*NullValue)(nil)
var _ Value = (*EnumValue)(nil)
var _ Value = (*ListValue)(nil)
var _ Value = (*ObjectValue)(nil)
Expand Down Expand Up @@ -172,6 +173,33 @@ func (v *BooleanValue) GetValue() interface{} {
return v.Value
}

type NullValue struct {
Kind string
Loc *Location
Value interface{}
}

func NewNullValue(v *NullValue) *NullValue {

return &NullValue{
Kind: kinds.NullValue,
Loc: v.Loc,
Value: v.Value,
}
}

func (v *NullValue) GetKind() string {
return v.Kind
}

func (v *NullValue) GetLoc() *Location {
return v.Loc
}

func (v *NullValue) GetValue() interface{} {
return nil
}

// EnumValue implements Node, Value
type EnumValue struct {
Kind string
Expand Down
1 change: 1 addition & 0 deletions language/kinds/kinds.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
FloatValue = "FloatValue"
StringValue = "StringValue"
BooleanValue = "BooleanValue"
NullValue = "NullValue"
EnumValue = "EnumValue"
ListValue = "ListValue"
ObjectValue = "ObjectValue"
Expand Down
14 changes: 11 additions & 3 deletions language/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,23 @@ func parseValueLiteral(parser *Parser, isConst bool) (ast.Value, error) {
Value: value,
Loc: loc(parser, token.Start),
}), nil
} else if token.Value != "null" {
} else if token.Value == "null" {
if err := advance(parser); err != nil {
return nil, err
}
return ast.NewEnumValue(&ast.EnumValue{
Value: token.Value,
return ast.NewNullValue(&ast.NullValue{
Value: nil,
Loc: loc(parser, token.Start),
}), nil
}

if err := advance(parser); err != nil {
return nil, err
}
return ast.NewEnumValue(&ast.EnumValue{
Value: token.Value,
Loc: loc(parser, token.Start),
}), nil
case lexer.DOLLAR:
if !isConst {
return parseVariable(parser)
Expand Down
9 changes: 9 additions & 0 deletions language/printer/printer.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,15 @@ var printDocASTReducer = map[string]visitor.VisitFunc{
}
return visitor.ActionNoChange, nil
},
"NullValue": func(p visitor.VisitFuncParams) (string, interface{}) {
switch node := p.Node.(type) {
case *ast.NullValue:
return visitor.ActionUpdate, fmt.Sprintf("%v", node.Value)
case map[string]interface{}:
return visitor.ActionUpdate, getMapValueString(node, "Value")
}
return visitor.ActionNoChange, nil
},
"EnumValue": func(p visitor.VisitFuncParams) (string, interface{}) {
switch node := p.Node.(type) {
case *ast.EnumValue:
Expand Down
10 changes: 7 additions & 3 deletions rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,10 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) {
return true, nil
}

if valueAST.GetKind() == kinds.NullValue {
return true, nil
}

// This function only tests literals, and assumes variables will provide
// values of the correct type.
if valueAST.GetKind() == kinds.Variable {
Expand All @@ -1742,7 +1746,7 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) {
if e := ttype.Error(); e != nil {
return false, []string{e.Error()}
}
if valueAST == nil {
if valueAST == nil || valueAST.GetKind() == kinds.NullValue {
if ttype.OfType.Name() != "" {
return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())}
}
Expand Down Expand Up @@ -1797,11 +1801,11 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) {
}
return (len(messagesReduce) == 0), messagesReduce
case *Scalar:
if isNullish(ttype.ParseLiteral(valueAST)) {
if !isNullish(ttype.ParseLiteral(valueAST)) {
return false, []string{fmt.Sprintf(`Expected type "%v", found %v.`, ttype.Name(), printer.Print(valueAST))}
}
case *Enum:
if isNullish(ttype.ParseLiteral(valueAST)) {
if !isNullish(ttype.ParseLiteral(valueAST)) {
return false, []string{fmt.Sprintf(`Expected type "%v", found %v.`, ttype.Name(), printer.Print(valueAST))}
}
}
Expand Down
63 changes: 47 additions & 16 deletions values.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"github.com/graphql-go/graphql/language/printer"
)

// Used to detect the difference between a "null" literal and not present
type nullValue struct{}

// Prepares an object map of variableValues of the correct type based on the
// provided variable definitions and arbitrary input. If the input cannot be
// parsed to match the variable definitions, a GraphQLError will be returned.
Expand All @@ -27,7 +30,7 @@ func getVariableValues(
continue
}
varName := defAST.Variable.Name.Value
if varValue, err := getVariableValue(schema, defAST, inputs[varName]); err != nil {
if varValue, err := getVariableValue(schema, defAST, getValueOrNull(inputs, varName)); err != nil {
return values, err
} else {
values[varName] = varValue
Expand All @@ -36,6 +39,25 @@ func getVariableValues(
return values, nil
}

func getValueOrNull(values map[string]interface{}, name string) interface{} {
if tmp, ok := values[name]; ok { // Is present
if tmp == nil {
return nullValue{} // Null value
} else {
return tmp
}
}
return nil // Not present
}

func addValueOrNull(values map[string]interface{}, name string, value interface{}) {
if _, ok := value.(nullValue); ok { // Null value
values[name] = nil
} else if !isNullish(value) { // Not present
values[name] = value
}
}

// Prepares an object map of argument values given a list of argument
// definitions and list of argument AST nodes.
func getArgumentValues(
Expand All @@ -60,9 +82,7 @@ func getArgumentValues(
if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) {
tmp = argDef.DefaultValue
}
if !isNullish(tmp) {
results[argDef.PrivateName] = tmp
}
addValueOrNull(results, argDef.PrivateName, tmp)
}
return results
}
Expand Down Expand Up @@ -97,7 +117,7 @@ func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, inpu
}
return coerceValue(ttype, input), nil
}
if isNullish(input) {
if _, ok := input.(nullValue); ok || isNullish(input) {
return "", gqlerrors.NewError(
fmt.Sprintf(`Variable "$%v" of required type `+
`"%v" was not provided.`, variable.Name.Value, printer.Print(definitionAST.Type)),
Expand Down Expand Up @@ -134,6 +154,11 @@ func coerceValue(ttype Input, value interface{}) interface{} {
if isNullish(value) {
return nil
}

if _, ok := value.(nullValue); ok {
return nullValue{}
}

switch ttype := ttype.(type) {
case *NonNull:
return coerceValue(ttype.OfType, value)
Expand All @@ -156,13 +181,11 @@ func coerceValue(ttype Input, value interface{}) interface{} {
}

for name, field := range ttype.Fields() {
fieldValue := coerceValue(field.Type, valueMap[name])
fieldValue := coerceValue(field.Type, getValueOrNull(valueMap, name))
if isNullish(fieldValue) {
fieldValue = field.DefaultValue
}
if !isNullish(fieldValue) {
obj[name] = fieldValue
}
addValueOrNull(obj, name, fieldValue)
}
return obj
case *Scalar:
Expand Down Expand Up @@ -212,7 +235,7 @@ func typeFromAST(schema Schema, inputTypeAST ast.Type) (Type, error) {
// accepted for that type. This is primarily useful for validating the
// runtime values of query variables.
func isValidInputValue(value interface{}, ttype Input) (bool, []string) {
if isNullish(value) {
if _, ok := value.(nullValue); ok || isNullish(value) {
if ttype, ok := ttype.(*NonNull); ok {
if ttype.OfType.Name() != "" {
return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())}
Expand All @@ -233,9 +256,14 @@ func isValidInputValue(value interface{}, ttype Input) (bool, []string) {
messagesReduce := []string{}
for i := 0; i < valType.Len(); i++ {
val := valType.Index(i).Interface()
_, messages := isValidInputValue(val, ttype.OfType)
for idx, message := range messages {
messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, idx+1, message))
var messages []string
if _, ok := val.(nullValue); ok {
messages = []string{"Unexpected null value."}
} else {
_, messages = isValidInputValue(val, ttype.OfType)
}
for _, message := range messages {
messagesReduce = append(messagesReduce, fmt.Sprintf(`In element #%v: %v`, i+1, message))
}
}
return (len(messagesReduce) == 0), messagesReduce
Expand Down Expand Up @@ -352,6 +380,11 @@ func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interfac
if valueAST == nil {
return nil
}

if valueAST.GetKind() == kinds.NullValue {
return nullValue{}
}

// precedence: value > type
if valueAST, ok := valueAST.(*ast.Variable); ok {
if valueAST.Name == nil || variables == nil {
Expand Down Expand Up @@ -398,9 +431,7 @@ func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interfac
} else {
value = field.DefaultValue
}
if !isNullish(value) {
obj[name] = value
}
addValueOrNull(obj, name, value)
}
return obj
case *Scalar:
Expand Down

0 comments on commit 8eafea7

Please sign in to comment.