diff --git a/ast/walk.go b/ast/walk.go index 3fc4f66..c5f1b94 100644 --- a/ast/walk.go +++ b/ast/walk.go @@ -11,7 +11,7 @@ import ( // The result must be the next action to be performed on each // of the child nodes. If no action is required and this branch // should be dropped, nil should be returned. -type Visitor func(Node) (Visitor, error) +type Visitor func(Node) Visitor // Preorder\top-down traversal. // Visit a parent node before visiting its children. @@ -26,17 +26,15 @@ type Visitor func(Node) (Visitor, error) // - - List (length = 0) // - - nil (List) // - nil (List) -func WalkTopDown(visit Visitor, tree Node) error { +func WalkTopDown(visit Visitor, tree Node) { if tree == nil { panic("can't walk a nil node") } - if v, err := visit(tree); err != nil { - return err - } else if v == nil { - return nil - } else { + if v := visit(tree); v != nil { visit = v + } else { + return } switch n := tree.(type) { @@ -47,189 +45,120 @@ func WalkTopDown(visit Visitor, tree Node) error { assert.Ok(n.Binding.Name != nil) assert.Ok(n.Binding.Type != nil || (n.Value != nil && n.Operator != nil)) - if err := WalkTopDown(visit, n.Binding.Name); err != nil { - return err - } + WalkTopDown(visit, n.Binding.Name) if n.Type != nil { - if err := WalkTopDown(visit, n.Binding.Type); err != nil { - return err - } + WalkTopDown(visit, n.Binding.Type) } if n.Value != nil { - if err := WalkTopDown(visit, n.Operator); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Value); err != nil { - return err - } + WalkTopDown(visit, n.Operator) + WalkTopDown(visit, n.Value) } case *Signature: assert.Ok(n.Params != nil) - if err := walkExprList(visit, n.Params.ExprList); err != nil { - return err - } + walkExprList(visit, n.Params.ExprList) if n.Result != nil { - if err := WalkTopDown(visit, n.Result); err != nil { - return err - } + WalkTopDown(visit, n.Result) } case *Call: assert.Ok(n.X != nil) assert.Ok(n.Args != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } - - if err := walkExprList(visit, n.Args.ExprList); err != nil { - return err - } + WalkTopDown(visit, n.X) + walkExprList(visit, n.Args.ExprList) case *Index: assert.Ok(n.X != nil) assert.Ok(n.Args != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } - - if err := walkExprList(visit, n.Args.ExprList); err != nil { - return err - } + WalkTopDown(visit, n.X) + walkExprList(visit, n.Args.ExprList) case *ArrayType: assert.Ok(n.X != nil) assert.Ok(n.Args != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } - - if err := walkExprList(visit, n.Args.ExprList); err != nil { - return err - } + WalkTopDown(visit, n.X) + walkExprList(visit, n.Args.ExprList) case *MemberAccess: assert.Ok(n.X != nil) assert.Ok(n.Selector != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Selector); err != nil { - return err - } + WalkTopDown(visit, n.X) + WalkTopDown(visit, n.Selector) case *PrefixOp: assert.Ok(n.X != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } + WalkTopDown(visit, n.X) case *InfixOp: assert.Ok(n.X != nil) assert.Ok(n.Y != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Y); err != nil { - return err - } + WalkTopDown(visit, n.X) + WalkTopDown(visit, n.Y) case *PostfixOp: assert.Ok(n.X != nil) - if err := WalkTopDown(visit, n.X); err != nil { - return err - } + WalkTopDown(visit, n.X) case *List: - if err := walkList(visit, n); err != nil { - return err - } + walkList(visit, n) case *ExprList: - if err := walkExprList(visit, n); err != nil { - return err - } + walkExprList(visit, n) case *BracketList: - if err := walkExprList(visit, n.ExprList); err != nil { - return err - } + walkExprList(visit, n.ExprList) case *ParenList: - if err := walkExprList(visit, n.ExprList); err != nil { - return err - } + walkExprList(visit, n.ExprList) case *CurlyList: - if err := walkList(visit, n.List); err != nil { - return err - } + walkList(visit, n.List) case *AttributeList: assert.Ok(n.List != nil) - if err := walkExprList(visit, n.List.ExprList); err != nil { - return err - } + walkExprList(visit, n.List.ExprList) case *BuiltInCall: assert.Ok(n.Name != nil) assert.Ok(n.Args != nil) - if err := WalkTopDown(visit, n.Name); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Args); err != nil { - return err - } + WalkTopDown(visit, n.Name) + WalkTopDown(visit, n.Args) case *ModuleDecl: assert.Ok(n.Name != nil) assert.Ok(n.Body != nil) if n.Attrs != nil { - if err := WalkTopDown(visit, n.Attrs); err != nil { - return err - } + WalkTopDown(visit, n.Attrs) } - if err := WalkTopDown(visit, n.Name); err != nil { - return err - } + WalkTopDown(visit, n.Name) switch b := n.Body.(type) { case *List: - if err := walkList(visit, b); err != nil { - return err - } + walkList(visit, b) case *ExprList: - if err := walkExprList(visit, b); err != nil { - return err - } + walkExprList(visit, b) case *CurlyList: - if err := walkList(visit, b.List); err != nil { - return err - } + walkList(visit, b.List) default: - return fmt.Errorf("unexpected node type '%T' for module body", n.Body) + panic(fmt.Sprintf("unexpected node type '%T' for module body", n.Body)) } case *VarDecl: @@ -237,62 +166,32 @@ func WalkTopDown(visit Visitor, tree Node) error { assert.Ok(n.Binding.Type != nil || n.Value != nil) if n.Attrs != nil { - if err := WalkTopDown(visit, n.Attrs); err != nil { - return err - } + WalkTopDown(visit, n.Attrs) } - if err := WalkTopDown(visit, n.Binding.Name); err != nil { - return err - } + WalkTopDown(visit, n.Binding.Name) if n.Binding.Type != nil { - if err := WalkTopDown(visit, n.Binding.Type); err != nil { - return err - } + WalkTopDown(visit, n.Binding.Type) } if n.Value != nil { - if err := WalkTopDown(visit, n.Binding.Type); err != nil { - return err - } + WalkTopDown(visit, n.Binding.Type) } - // case *GenericDecl: - // assert.Ok(n.Field != nil) - - // if n.Attrs != nil { - // if err := WalkTopDown(visit, n.Attrs); err != nil { - // return err - // } - // } - - // if err := WalkTopDown(visit, n.Field); err != nil { - // return err - // } - case *FuncDecl: assert.Ok(n.Name != nil) assert.Ok(n.Signature != nil) if n.Attrs != nil { - if err := WalkTopDown(visit, n.Attrs); err != nil { - return err - } + WalkTopDown(visit, n.Attrs) } - if err := WalkTopDown(visit, n.Name); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Signature); err != nil { - return err - } + WalkTopDown(visit, n.Name) + WalkTopDown(visit, n.Signature) if n.Body != nil { - if err := WalkTopDown(visit, n.Body); err != nil { - return err - } + WalkTopDown(visit, n.Body) } case *TypeAliasDecl: @@ -300,75 +199,48 @@ func WalkTopDown(visit Visitor, tree Node) error { assert.Ok(n.Expr != nil) if n.Attrs != nil { - if err := WalkTopDown(visit, n.Attrs); err != nil { - return err - } + WalkTopDown(visit, n.Attrs) } - if err := WalkTopDown(visit, n.Name); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Expr); err != nil { - return err - } + WalkTopDown(visit, n.Name) + WalkTopDown(visit, n.Expr) case *If: assert.Ok(n.Cond != nil) assert.Ok(n.Body != nil) - if err := WalkTopDown(visit, n.Cond); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Body); err != nil { - return err - } + WalkTopDown(visit, n.Cond) + WalkTopDown(visit, n.Body) if n.Else != nil { - if err := WalkTopDown(visit, n.Else); err != nil { - return err - } + WalkTopDown(visit, n.Else) } case *Else: assert.Ok(n.Body != nil) - if err := WalkTopDown(visit, n.Body); err != nil { - return err - } + WalkTopDown(visit, n.Body) case *While: assert.Ok(n.Cond != nil) assert.Ok(n.Body != nil) - if err := WalkTopDown(visit, n.Cond); err != nil { - return err - } - - if err := WalkTopDown(visit, n.Body); err != nil { - return err - } + WalkTopDown(visit, n.Cond) + WalkTopDown(visit, n.Body) case *Return: if n.X != nil { - if err := WalkTopDown(visit, n.X); err != nil { - return err - } + WalkTopDown(visit, n.X) } case *Break: if n.Label != nil { - if err := WalkTopDown(visit, n.Label); err != nil { - return err - } + WalkTopDown(visit, n.Label) } case *Continue: if n.Label != nil { - if err := WalkTopDown(visit, n.Label); err != nil { - return err - } + WalkTopDown(visit, n.Label) } default: @@ -376,30 +248,21 @@ func WalkTopDown(visit Visitor, tree Node) error { panic(fmt.Sprintf("unknown node type '%T'", n)) } - _, err := visit(nil) - return err + visit(nil) } -func walkList(visit Visitor, list *List) error { +func walkList(visit Visitor, list *List) { if list != nil { for _, node := range list.Nodes { - if err := WalkTopDown(visit, node); err != nil { - return err - } + WalkTopDown(visit, node) } } - - return nil } -func walkExprList(visit Visitor, list *ExprList) error { +func walkExprList(visit Visitor, list *ExprList) { if list != nil { for _, node := range list.Exprs { - if err := WalkTopDown(visit, node); err != nil { - return err - } + WalkTopDown(visit, node) } } - - return nil } diff --git a/checker/block.go b/checker/block.go index 819276e..4f5b5f7 100644 --- a/checker/block.go +++ b/checker/block.go @@ -1,8 +1,6 @@ package checker import ( - "fmt" - "github.com/saffage/jet/ast" "github.com/saffage/jet/types" ) @@ -16,38 +14,30 @@ func NewBlock(scope *Scope) *Block { return &Block{scope, types.Unit} } -func (expr *Block) visit(node ast.Node) (ast.Visitor, error) { - switch node := node.(type) { - case ast.Decl: - switch decl := node.(type) { - case *ast.VarDecl: - if err := resolveVar(decl, expr.scope); err != nil { - return nil, err - } - - fmt.Printf(">>> def local var `%s`\n", decl.Binding.Name) +func (check *Checker) blockVisitor(expr *Block) ast.Visitor { + return func(node ast.Node) ast.Visitor { + if decl, _ := node.(ast.Decl); decl != nil { + switch decl := decl.(type) { + case *ast.VarDecl: + check.resolveVarDecl(decl) + expr.t = types.Unit - expr.t = types.Unit - return nil, nil + case *ast.TypeAliasDecl, *ast.FuncDecl, *ast.ModuleDecl: + panic("not implemented") - case *ast.TypeAliasDecl, *ast.FuncDecl, *ast.ModuleDecl: - panic("not implemented") + default: + panic("unreachable") + } - default: - panic("unreachable") + return nil } - default: - t, err := expr.scope.TypeOf(node) - if err != nil { - return nil, err + t := check.typeOf(node) + if t == nil { + return nil } expr.t = t - return nil, nil - - // fmt.Printf("unchecked node: '%T'\n", node) - // expr.t = types.Unit - // return nil, nil + return nil } } diff --git a/checker/builtin_fns.go b/checker/builtin_fns.go index 290435a..4475ead 100644 --- a/checker/builtin_fns.go +++ b/checker/builtin_fns.go @@ -5,68 +5,58 @@ import ( "github.com/saffage/jet/types" ) -func builtInMagic(args ast.Node, scope *Scope) (*Value, error) { - argList, ok := args.(*ast.ParenList) - if !ok { - return nil, NewError(args, "expected argument list") +func (check *Checker) builtInMagic(args ast.Node, scope *Scope) *TypedValue { + argList, _ := args.(*ast.ParenList) + if argList == nil { + check.errorf(args, "expected argument list") + return nil } - arg1, ok := argList.Exprs[0].(*ast.Literal) - if !ok { - return nil, NewError(argList.Exprs[0], "expected string literal") + arg1, _ := argList.Exprs[0].(*ast.Literal) + if arg1 == nil || arg1.Kind != ast.StringLiteral { + check.errorf(argList.Exprs[0], "expected string literal") + return nil } - // tArg1, err := scope.TypeOf(argList.Exprs[0]) - // if err != nil { - // return nil, NewError(argList.Exprs[0], "expected literal") - // } - - // if !types.Primitives[types.UntypedString].Equals(tArg1) { - // return nil, NewErrorf( - // argList.Exprs[0], - // "expected 'untyped string', got '%s' instead", - // tArg1, - // ) - // } - switch arg1.Value { case "Bool": - return &Value{types.NewTypeDesc(types.Primitives[types.Bool]), nil}, nil + return &TypedValue{types.NewTypeDesc(types.Primitives[types.Bool]), nil} case "I32": - return &Value{types.NewTypeDesc(types.Primitives[types.I32]), nil}, nil + return &TypedValue{types.NewTypeDesc(types.Primitives[types.I32]), nil} default: - return nil, NewErrorf(arg1, "unknown magic '%s'", arg1.Value) + check.errorf(arg1, "unknown magic '%s'", arg1.Value) + return nil } } -func builtInTypeOf(args ast.Node, scope *Scope) (*Value, error) { - argList, ok := args.(*ast.ParenList) - if !ok { - return nil, NewError(args, "expected argument list") +func (check *Checker) builtInTypeOf(args ast.Node, scope *Scope) *TypedValue { + argList, _ := args.(*ast.ParenList) + if argList == nil { + check.errorf(args, "expected argument list") + return nil } - arg1 := argList.Exprs[0] - - t, err := scope.TypeOf(arg1) - if err != nil { - return nil, err + t := check.typeOf(argList.Exprs[0]) + if t == nil { + return nil } - return &Value{types.NewTypeDesc(types.SkipUntyped(t)), nil}, nil + return &TypedValue{types.NewTypeDesc(types.SkipUntyped(t)), nil} } -func builtInPrint(args ast.Node, scope *Scope) (*Value, error) { - argList, ok := args.(*ast.ParenList) - if !ok { - return nil, NewError(args, "expected argument list") +func (check *Checker) builtInPrint(args ast.Node, scope *Scope) *TypedValue { + argList, _ := args.(*ast.ParenList) + if argList == nil { + check.errorf(args, "expected argument list") + return nil } - _, err := scope.TypeOf(argList) - if err != nil { - return nil, err + t := check.typeOf(argList) + if t == nil { + return nil } - return &Value{types.Unit, nil}, nil + return &TypedValue{types.Unit, nil} } diff --git a/checker/builtins.go b/checker/builtins.go index 1a48865..a75f102 100644 --- a/checker/builtins.go +++ b/checker/builtins.go @@ -5,7 +5,7 @@ import ( "github.com/saffage/jet/types" ) -type BuiltInFn func(args ast.Node, scope *Scope) (*Value, error) +type BuiltInFn func(args ast.Node, scope *Scope) *TypedValue type BuiltIn struct { name string @@ -19,15 +19,11 @@ func (b *BuiltIn) Name() string { return b.name } func (b *BuiltIn) Ident() *ast.Ident { return nil } func (b *BuiltIn) Node() ast.Node { return nil } -func (b *BuiltIn) setType(t types.Type) { panic("can't change the type of the built-in function") } - -var builtIns []*BuiltIn - -func init() { - builtIns = []*BuiltIn{ +func (check *Checker) defBuiltIns() { + check.builtIns = []*BuiltIn{ { name: "magic", - f: builtInMagic, + f: check.builtInMagic, t: types.NewFunc( types.NewTuple(types.Primitives[types.AnyTypeDesc]), types.NewTuple(types.Primitives[types.UntypedString]), @@ -35,7 +31,7 @@ func init() { }, { name: "type_of", - f: builtInTypeOf, + f: check.builtInTypeOf, t: types.NewFunc( types.NewTuple(types.Primitives[types.AnyTypeDesc]), types.NewTuple(types.Primitives[types.Any]), @@ -43,7 +39,7 @@ func init() { }, { name: "print", - f: builtInPrint, + f: check.builtInPrint, t: types.NewFunc( types.Unit, types.NewTuple(types.Primitives[types.Any]), diff --git a/checker/checker.go b/checker/checker.go index bd6ee98..0734106 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -1,7 +1,134 @@ package checker -// type Checker struct { -// errors []error -// } +import ( + "os" -// func (check *Checker) Errors() []error { return check.errors } + "github.com/davecgh/go-spew/spew" + "github.com/saffage/jet/ast" + "github.com/saffage/jet/internal/assert" + "github.com/saffage/jet/types" +) + +type Checker struct { + types map[ast.Node]TypedValue + defs map[*ast.Ident]Symbol + uses map[*ast.Ident]Symbol + + module *Module + scope *Scope + builtIns []*BuiltIn + errors []error + isErrorHandled bool +} + +func Check(node *ast.ModuleDecl) []error { + module := NewModule(node) + check := &Checker{ + types: make(map[ast.Node]TypedValue), + defs: make(map[*ast.Ident]Symbol), + uses: make(map[*ast.Ident]Symbol), + module: module, + scope: module.scope, + errors: make([]error, 0), + isErrorHandled: true, + } + + check.defBuiltIns() + + { + nodes := []ast.Node(nil) + + switch body := node.Body.(type) { + case *ast.List: + nodes = body.Nodes + + case *ast.CurlyList: + nodes = body.List.Nodes + + default: + panic("ill-formed AST") + } + + for _, node := range nodes { + ast.WalkTopDown(check.visit, node) + } + + module.completed = true + } + + { + f, err := os.Create("checker_state.txt") + if err != nil { + panic(err) + } + defer f.Close() + spew.Fdump(f, check) + } + + return check.errors +} + +// Type checks 'expr' and returns its type. +// If error was occured, result is undefined +func (check *Checker) typeOf(expr ast.Node) types.Type { + if t, ok := check.types[expr]; ok { + return t.Type + } + + if v := check.valueOfInternal(expr); v != nil { + check.setValue(expr, *v) + return v.Type + } + + t := check.typeOfInternal(expr) + check.setType(expr, t) + return t +} + +func (check *Checker) valueOf(expr ast.Node) *TypedValue { + if t, ok := check.types[expr]; ok { + return &t + } + + if value := check.valueOfInternal(expr); value != nil { + check.setValue(expr, *value) + return value + } + + return nil +} + +func (check *Checker) setType(expr ast.Node, t types.Type) { + assert.Ok(expr != nil) + assert.Ok(t != nil) + + if check.types != nil { + check.types[expr] = TypedValue{t, nil} + } +} + +func (check *Checker) setValue(expr ast.Node, value TypedValue) { + assert.Ok(expr != nil) + assert.Ok(value.Type != nil) + + if check.types != nil { + check.types[expr] = value + } +} + +func (check *Checker) newDef(ident *ast.Ident, sym Symbol) { + assert.Ok(ident != nil) + + if check.defs != nil { + check.defs[ident] = sym + } +} + +func (check *Checker) newUse(ident *ast.Ident, sym Symbol) { + assert.Ok(ident != nil) + assert.Ok(sym != nil) + + if check.uses != nil { + check.uses[ident] = sym + } +} diff --git a/checker/error.go b/checker/error.go index a0bf6dd..637c7a2 100644 --- a/checker/error.go +++ b/checker/error.go @@ -9,23 +9,27 @@ import ( type Error struct { Message string Node ast.Node - Notes []Error + Notes []*Error } -func NewError(node ast.Node, message string) Error { - return Error{ +func NewError(node ast.Node, message string) *Error { + return &Error{ Message: message, Node: node, } } -func NewErrorf(node ast.Node, format string, args ...any) Error { - return Error{ +func NewErrorf(node ast.Node, format string, args ...any) *Error { + return &Error{ Message: fmt.Sprintf(format, args...), Node: node, } } -func (err Error) Error() string { - return err.Message +func (err *Error) Error() string { return err.Message } + +func (check *Checker) errorf(node ast.Node, format string, args ...any) { + err := NewErrorf(node, format, args...) + check.errors = append(check.errors, err) + check.isErrorHandled = false } diff --git a/checker/func.go b/checker/func.go index 2589664..c004524 100644 --- a/checker/func.go +++ b/checker/func.go @@ -1,21 +1,19 @@ package checker import ( - "fmt" - "github.com/saffage/jet/ast" "github.com/saffage/jet/types" ) type Func struct { owner *Scope - scope *Scope + local *Scope t *types.Func node *ast.FuncDecl } -func NewFunc(owner *Scope, t *types.Func, node *ast.FuncDecl) *Func { - return &Func{owner, NewScope(owner), t, node} +func NewFunc(owner *Scope, local *Scope, t *types.Func, node *ast.FuncDecl) *Func { + return &Func{owner, local, t, node} } func (sym *Func) Owner() *Scope { return sym.owner } @@ -23,49 +21,3 @@ func (sym *Func) Type() types.Type { return sym.t } func (sym *Func) Name() string { return sym.node.Name.Name } func (sym *Func) Ident() *ast.Ident { return sym.node.Name } func (sym *Func) Node() ast.Node { return sym.node } - -func (sym *Func) setType(t types.Type) { - if t, _ := t.(*types.Func); t != nil { - sym.t = t - return - } - - panic(fmt.Sprintf("type '%s' is not a function", t)) -} - -// func (sym *Func) Check(args []types.Type) error { -// if len(args) > len(t.Params) { -// return NewErrorf(node, "too many arguments (expected %d)", len(b.params)) -// } - -// if len(args) < len(t.Params) { -// return NewErrorf(node, "not enough arguments (expected %d)", len(b.params)) -// } - -// for i := 0; i < maxlen; i++ { -// if i < len(args.Nodes) { -// if type_, err := scope.TypeOf(args.Nodes[i]); err == nil { -// actual = type_ -// node = args.Nodes[i] -// } else { -// return err -// } -// } - -// if i < len(b.params) { -// expected = b.params[i] -// } - -// if expected == nil { -// } - -// if actual == nil { -// } - -// if !expected.Equals(actual) { -// return NewErrorf(node, "expected '%s' for %d argument but got '%s'", expected, i+1, actual) -// } -// } - -// return nil -// } diff --git a/checker/module.go b/checker/module.go index 047d903..8556d97 100644 --- a/checker/module.go +++ b/checker/module.go @@ -1,8 +1,6 @@ package checker import ( - "fmt" - "github.com/saffage/jet/ast" "github.com/saffage/jet/types" ) @@ -14,33 +12,12 @@ type Module struct { completed bool } -func NewModule(node *ast.ModuleDecl) (*Module, error) { - m := &Module{ - scope: NewScope(nil), - node: node, +func NewModule(node *ast.ModuleDecl) *Module { + return &Module{ + scope: NewScope(nil), + node: node, + completed: false, } - - nodes := []ast.Node(nil) - - switch body := node.Body.(type) { - case *ast.List: - nodes = body.Nodes - - case *ast.CurlyList: - nodes = body.List.Nodes - - default: - panic("ill-formed AST") - } - - for _, node := range nodes { - if err := ast.WalkTopDown(m.visit, node); err != nil { - return nil, err - } - } - - m.completed = true - return m, nil } func (m *Module) Owner() *Scope { panic("modules have no owner") } @@ -49,18 +26,13 @@ func (m *Module) Name() string { return m.node.Name.Name } func (m *Module) Ident() *ast.Ident { return m.node.Name } func (m *Module) Node() ast.Node { return m.node } -func (m *Module) setType(t types.Type) { panic("modules have no type") } - -func (m *Module) visit(node ast.Node) (ast.Visitor, error) { - if _, isEmpty := node.(*ast.Empty); isEmpty { - return nil, nil - } - +func (check *Checker) visit(node ast.Node) ast.Visitor { decl, isDecl := node.(ast.Decl) if !isDecl { // NOTE parser should prevent this in future. - return nil, NewError(node, "expected declaration") + check.errorf(node, "expected declaration") + return nil } switch decl := decl.(type) { @@ -68,49 +40,17 @@ func (m *Module) visit(node ast.Node) (ast.Visitor, error) { panic("not implemented") case *ast.VarDecl: - if err := resolveVar(decl, m.scope); err != nil { - return nil, err - } - - fmt.Printf(">>> def var `%s`\n", decl.Binding.Name) + check.resolveVarDecl(decl) case *ast.FuncDecl: - sym := NewFunc(m.scope, nil, decl) - - if defined := m.scope.Define(sym); defined != nil { - return nil, errorAlreadyDefined(sym.Ident(), defined.Ident()) - } - - fmt.Printf(">>> def func `%s`\n", decl.Name.Name) - - if err := resolveFuncDecl(sym); err != nil { - return nil, err - } + check.resolveFuncDecl(decl) case *ast.TypeAliasDecl: - t, err := m.scope.TypeOf(decl.Expr) - if err != nil { - return nil, err - } - - t = types.SkipAlias(t) - - typedesc, _ := t.(*types.TypeDesc) - if typedesc == nil { - return nil, NewErrorf(decl.Expr, "expression is not a type (%s)", t) - } - - sym := NewTypeAlias(m.scope, typedesc, decl) - - if defined := m.scope.Define(sym); defined != nil { - return nil, errorAlreadyDefined(sym.Ident(), defined.Ident()) - } - - fmt.Printf(">>> def alias `%s` for `%s`\n", sym.Name(), typedesc.Base()) + check.resolveTypeAliasDecl(decl) default: - panic(fmt.Sprintf("unhandled declaration kind (%T)", decl)) + panic("unreachable") } - return nil, nil + return nil } diff --git a/checker/resolver.go b/checker/resolver.go index 5db15ed..a20687e 100644 --- a/checker/resolver.go +++ b/checker/resolver.go @@ -7,98 +7,114 @@ import ( "github.com/saffage/jet/types" ) -func resolveVar(node *ast.VarDecl, scope *Scope) error { +func (check *Checker) resolveVarDecl(node *ast.VarDecl) { // 'tValue' can be nil. - tValue, err := resolveVarValue(node.Value, scope) - if err != nil { - return err + tValue, ok := check.resolveVarValue(node.Value) + if !ok { + return } // 'tType' must be not nil. - tType, err := resolveVarType(node.Binding.Type, tValue, scope) - if err != nil { - return err + tType := check.resolveVarType(node.Binding.Type, tValue) + if tType == nil { + return + } + + if tValue != nil { + fmt.Printf(">>> var value '%s'\n", tValue) } - fmt.Printf(">>> var value '%s'\n", tValue) fmt.Printf(">>> var type '%s'\n", tType) if tValue != nil && !tType.Equals(tValue) { - return NewErrorf( + check.errorf( node.Binding.Name, "type mismatch, expected '%s', got '%s'", tType, - tValue, - ) + tValue) + return } fmt.Printf(">>> var actual type '%s'\n", tType) - sym := NewVar(scope, tType, node.Binding, node.Binding.Name) + sym := NewVar(check.scope, tType, node.Binding, node.Binding.Name) - if defined := scope.Define(sym); defined != nil { - return errorAlreadyDefined(sym.Ident(), defined.Ident()) + if defined := check.scope.Define(sym); defined != nil { + err := errorAlreadyDefined(sym.Ident(), defined.Ident()) + check.errors = append(check.errors, err) + return } - return nil + check.newDef(node.Binding.Name, sym) + fmt.Printf(">>> def var `%s`\n", node.Binding.Name) } -func resolveVarValue(value ast.Node, scope *Scope) (types.Type, error) { +func (check *Checker) resolveVarValue(value ast.Node) (types.Type, bool) { if value != nil { - t, err := scope.TypeOf(value) - if err != nil { - return nil, err + t := check.typeOf(value) + if t == nil { + return nil, false } if types.IsTypeDesc(t) { - return nil, NewErrorf(value, "expected value, got type '%s' instead", t.Underlying()) + check.errorf(value, "expected value, got type '%s' instead", t) + return nil, false } - return types.SkipUntyped(t), nil + return types.SkipUntyped(t), true } - return nil, nil + return nil, true } -func resolveVarType(typeExpr ast.Node, value types.Type, scope *Scope) (types.Type, error) { - if typeExpr != nil { - t, err := scope.TypeOf(typeExpr) - if err != nil { - return t, err - } +func (check *Checker) resolveVarType(typeExpr ast.Node, value types.Type) types.Type { + if typeExpr == nil { + return value + } - if typedesc := types.AsTypeDesc(t); typedesc != nil { - return typedesc.Base(), nil - } + t := check.typeOf(typeExpr) + if t == nil { + return value + } - return nil, NewError(typeExpr, "expression is not a type") + typedesc := types.AsTypeDesc(t) + if typedesc == nil { + check.errorf(typeExpr, "expression is not a type") + return nil } - return value, nil + return typedesc.Base() } -func resolveFuncDecl(sym *Func) error { - sig := sym.node.Signature +func (check *Checker) resolveFuncDecl(node *ast.FuncDecl) { + sig := node.Signature tParams := []types.Type{} + local := NewScope(check.scope) for _, param := range sig.Params.Exprs { switch param := param.(type) { case *ast.Binding: - t, err := sym.owner.TypeOf(param.Type) - if err != nil { - return err + t := check.typeOf(param.Type) + if t == nil { + return } t = types.SkipTypeDesc(t) tParams = append(tParams, t) - paramSym := NewVar(sym.scope, t, nil, param.Name) - sym.scope.Define(paramSym) + paramSym := NewVar(local, t, param, param.Name) + if defined := local.Define(paramSym); defined != nil { + check.errorf(param, "paramter with the same name was already defined") + return + } + + check.newDef(param.Name, paramSym) fmt.Printf(">>> set `%s` type `%s`\n", paramSym.Name(), t) fmt.Printf(">>> def param `%s`\n", paramSym.Name()) case *ast.BindingWithValue: - return NewError(param, "parameters can't have the default value") + check.errorf(param, "parameters can't have a default value") + return default: panic(fmt.Sprintf("ill-formed AST: unexpected node type '%T'", param)) @@ -110,9 +126,9 @@ func resolveFuncDecl(sym *Func) error { tResult := types.Unit if sig.Result != nil { - t, err := sym.owner.TypeOf(sig.Result) - if err != nil { - return err + t := check.typeOf(sig.Result) + if t == nil { + return } tResult = types.NewTuple(types.SkipTypeDesc(t)) @@ -121,29 +137,74 @@ func resolveFuncDecl(sym *Func) error { // Produce function type. t := types.NewFunc(tResult, types.NewTuple(tParams...)) + sym := NewFunc(check.scope, local, t, node) + + fmt.Printf(">>> set `%s` type `%s`\n", sym.Name(), t) + + if defined := check.scope.Define(sym); defined != nil { + err := errorAlreadyDefined(sym.Ident(), defined.Ident()) + check.errors = append(check.errors, err) + return + } - sym.setType(t) - fmt.Printf(">>> set `%s` type `%s`\n", sym.Name(), t.String()) + // Define function symbol inside their scope for recursion. + local.Define(sym) // Body. - if sym.node.Body != nil { - tBody, err := sym.scope.TypeOf(sym.node.Body) - if err != nil { - return err - } + if sym.node.Body == nil { + check.errorf(sym.Ident(), "functions without body is not allowed") + return + } - if !tResult.Equals(tBody) { - return NewErrorf( - sym.node.Body.Nodes[len(sym.node.Body.Nodes)-1], - "expected expression of type '%s' for function result, got '%s' instead", - tResult, - tBody, - ) - } - } else { - return NewError(sym.Ident(), "functions without body is not allowed") + prevScope := check.scope + check.scope = local + + defer func() { + check.scope = prevScope + }() + + tBody := check.typeOf(sym.node.Body) + if tBody == nil { + return + } + + if !tResult.Equals(tBody) { + check.errorf( + sym.node.Body.Nodes[len(sym.node.Body.Nodes)-1], + "expected expression of type '%s' for function result, got '%s' instead", + tResult, + tBody, + ) + return + } + + check.newDef(node.Name, sym) + fmt.Printf(">>> def func `%s`\n", node.Name.Name) +} + +func (check *Checker) resolveTypeAliasDecl(node *ast.TypeAliasDecl) { + t := check.typeOf(node.Expr) + if t == nil { + return + } + + typedesc := types.AsTypeDesc(t) + + if typedesc == nil { + check.errorf(node.Expr, "expression is not a type (%s)", t) + return + } + + sym := NewTypeAlias(check.scope, typedesc, node) + + if defined := check.scope.Define(sym); defined != nil { + err := errorAlreadyDefined(sym.Ident(), defined.Ident()) + check.errors = append(check.errors, err) + return } - return nil + check.newDef(node.Name, sym) + check.setType(node, typedesc) + fmt.Printf(">>> def alias `%s`\n", node.Name.Name) } diff --git a/checker/scope.go b/checker/scope.go index b1806b8..7719e09 100644 --- a/checker/scope.go +++ b/checker/scope.go @@ -67,11 +67,11 @@ func (scope *Scope) Member(name string) Symbol { return nil } -func errorAlreadyDefined(ident, previous *ast.Ident) Error { - err := NewErrorf(ident, "name '%s' is already declared in this scope", ident.Name) +func errorAlreadyDefined(ident, previous *ast.Ident) *Error { + err := NewErrorf(ident, "name '%s' is already defined in this scope", ident.Name) if previous != nil { - err.Notes = []Error{ + err.Notes = []*Error{ NewError(previous, "previous declaration was here"), } } diff --git a/checker/symbol.go b/checker/symbol.go index 4c276de..9c94d46 100644 --- a/checker/symbol.go +++ b/checker/symbol.go @@ -11,6 +11,4 @@ type Symbol interface { Name() string // Identifier or name of a symbol. Ident() *ast.Ident // Identifier node. Node() ast.Node // Related AST node. - - setType(types.Type) } diff --git a/checker/type_alias.go b/checker/type_alias.go index d31a057..575dfba 100644 --- a/checker/type_alias.go +++ b/checker/type_alias.go @@ -26,7 +26,3 @@ func (sym *TypeAlias) Type() types.Type { return types.NewTypeDesc(sym.t) } func (sym *TypeAlias) Name() string { return sym.name.Name } func (sym *TypeAlias) Ident() *ast.Ident { return sym.name } func (sym *TypeAlias) Node() ast.Node { return sym.node } - -func (sym *TypeAlias) setType(t types.Type) { - sym.t = types.NewAlias(types.SkipTypeDesc(t), sym.Name()) -} diff --git a/checker/type_resolver.go b/checker/type_resolver.go index 3d7e0dd..3b012a8 100644 --- a/checker/type_resolver.go +++ b/checker/type_resolver.go @@ -3,6 +3,7 @@ package checker import ( "fmt" "math" + "slices" "strconv" "github.com/saffage/jet/ast" @@ -10,8 +11,9 @@ import ( "github.com/saffage/jet/types" ) -// Return type is never nil, if no error. -func (scope *Scope) TypeOf(expr ast.Node) (types.Type, error) { +// Type checks 'expr' and returns its type. +// If error was occured, result is undefined +func (check *Checker) typeOfInternal(expr ast.Node) types.Type { switch node := expr.(type) { case nil: panic("got nil not for expr") @@ -30,55 +32,55 @@ func (scope *Scope) TypeOf(expr ast.Node) (types.Type, error) { panic("ill-formed AST") case *ast.Empty: - return types.Unit, nil + return types.Unit case *ast.Ident: - return typeCheckIdent(node, scope) + return check.typeOfIdent(node) case *ast.Literal: - return typeCheckLiteral(node) + return check.typeOfLiteral(node) // case *ast.Operator: // panic("not implemented") case *ast.BuiltInCall: - return typeCheckBuiltInCall(node, scope) + return check.typeOfBuiltInCall(node) case *ast.Call: - return typeCheckCall(node, scope) + return check.typeOfCall(node) case *ast.Index: - return typeCheckIndex(node, scope) + return check.typeOfIndex(node) case *ast.ArrayType: - return typeCheckArrayType(node, scope) + return check.typeOfArrayType(node) case *ast.Signature: - return typeCheckSignature(node, scope) + return check.typeOfSignature(node) case *ast.PrefixOp: - return typeCheckPrefixOp(node, scope) + return check.typeOfPrefixOp(node) case *ast.InfixOp: - return typeCheckInfixOp(node, scope) + return check.typeOfInfixOp(node) case *ast.PostfixOp: - return typeCheckPostfixOp(node, scope) + return check.typeOfPostfixOp(node) case *ast.BracketList: - return typeCheckBracketList(node, scope) + return check.typeOfBracketList(node) case *ast.ParenList: - return typeCheckParenList(node, scope) + return check.typeOfParenList(node) case *ast.CurlyList: - return typeCheckCurlyList(scope, node) + return check.typeOfCurlyList(node) case *ast.If: - return typeCheckIf(node, scope) + return check.typeOfIf(node) case *ast.While: - return nil, typeCheckWhile(node, scope) + return check.typeOfWhile(node) // case *ast.Return, *ast.Break, *ast.Continue: // panic("not implemented") @@ -88,267 +90,261 @@ func (scope *Scope) TypeOf(expr ast.Node) (types.Type, error) { } } -// For the `_` identifier the result is (nil, nil). This is the -// only 1 way to get this result. -func (scope *Scope) ValueOf(expr ast.Node) (*Value, error) { +func (check *Checker) valueOfInternal(expr ast.Node) *TypedValue { switch node := expr.(type) { case *ast.Literal: value := constant.FromNode(node) type_ := types.FromConstant(value) - return &Value{ - Type: type_, - Value: value, - }, nil + return &TypedValue{type_, value} - case *ast.Ident: - if node.Name == "_" { - return nil, nil - } - - if sym, _ := scope.Lookup(node.Name); sym != nil { - panic("constants are not implemented") - } + // case *ast.Ident: + // panic("constants are not implemented") - return nil, NewErrorf(node, "identifier `%s` is undefined", node.Name) - - case *ast.PrefixOp, *ast.PostfixOp, *ast.InfixOp: - panic("not implemented") + // case *ast.PrefixOp, *ast.PostfixOp, *ast.InfixOp: + // panic("not implemented") } - return nil, NewError(expr, "expression is not a constant value") + return nil } -func (scope *Scope) SymbolOf(ident *ast.Ident) Symbol { - if sym, _ := scope.Lookup(ident.Name); sym != nil { +func (check *Checker) symbolOf(ident *ast.Ident) Symbol { + if sym, _ := check.scope.Lookup(ident.Name); sym != nil { return sym } return nil } -func typeCheckIdent(node *ast.Ident, scope *Scope) (types.Type, error) { - if sym := scope.SymbolOf(node); sym != nil { - if sym.Type() == nil { - return nil, NewErrorf(node, "expression `%s` has no type", node.Name) +func (check *Checker) typeOfIdent(node *ast.Ident) types.Type { + if sym := check.symbolOf(node); sym != nil { + if sym.Type() != nil { + return sym.Type() } - return sym.Type(), nil + check.errorf(node, "expression `%s` has no type", node.Name) + return nil } - return nil, NewErrorf(node, "identifier `%s` is undefined", node.Name) + check.errorf(node, "identifier `%s` is undefined", node.Name) + return nil } -func typeCheckLiteral(node *ast.Literal) (types.Type, error) { +func (check *Checker) typeOfLiteral(node *ast.Literal) types.Type { switch node.Kind { case ast.IntLiteral: - return types.Primitives[types.UntypedInt], nil + return types.Primitives[types.UntypedInt] case ast.FloatLiteral: - return types.Primitives[types.UntypedFloat], nil + return types.Primitives[types.UntypedFloat] case ast.StringLiteral: - return types.Primitives[types.UntypedString], nil + return types.Primitives[types.UntypedString] default: panic(fmt.Sprintf("unhandled literal kind: '%s'", node.Kind.String())) } } -func typeCheckBuiltInCall(node *ast.BuiltInCall, scope *Scope) (types.Type, error) { - var builtIn *BuiltIn +func (check *Checker) typeOfBuiltInCall(node *ast.BuiltInCall) types.Type { + builtIn := (*BuiltIn)(nil) + idx := slices.IndexFunc(check.builtIns, func(b *BuiltIn) bool { + return b.name == node.Name.Name + }) - for _, b := range builtIns { - if b.name == node.Name.Name { - builtIn = b - } + if idx != -1 { + builtIn = check.builtIns[idx] } if builtIn == nil { - return nil, NewErrorf(node.Name, "unknown built-in function '@%s'", node.Name.Name) + check.errorf(node.Name, "unknown built-in function '@%s'", node.Name.Name) + return nil } - args, ok := node.Args.(*ast.ParenList) - if !ok { - return nil, NewError(node.Args, "block as built-in function argument is not yet supported") + args, _ := node.Args.(*ast.ParenList) + if args == nil { + check.errorf(node.Args, "block as built-in function argument is not yet supported") + return nil } - argTypes, err := scope.TypeOf(args) - if err != nil { - return nil, err + tArgs := check.typeOfParenList(args) + if tArgs == nil { + return nil } - if idx, err := builtIn.t.CheckArgs(argTypes.(*types.Tuple)); err != nil { + if idx, err := builtIn.t.CheckArgs(tArgs.(*types.Tuple)); err != nil { n := ast.Node(args) if idx < len(args.Exprs) { n = args.Exprs[idx] } - return nil, NewErrorf(n, err.Error()) + check.errorf(n, err.Error()) + return nil } - value, err := builtIn.f(args, scope) - if err != nil { - return nil, err - } - - if value != nil { - return value.Type, nil + value := builtIn.f(args, check.scope) + if value == nil { + return nil } - return types.Unit, nil + return value.Type } -func typeCheckCall(node *ast.Call, scope *Scope) (types.Type, error) { - t, err := scope.TypeOf(node.X) - if err != nil { - return nil, err +func (check *Checker) typeOfCall(node *ast.Call) types.Type { + tOperand := check.typeOf(node.X) + if tOperand == nil { + return nil } - fn, ok := t.Underlying().(*types.Func) - if !ok { - return nil, NewError(node.X, "expression is not a function") + fn := types.AsFunc(tOperand) + if fn == nil { + check.errorf(node.X, "expression is not a function") + return nil } - argTypes, err := scope.TypeOf(node.Args) - if err != nil { - return nil, err + tArgs := check.typeOfParenList(node.Args) + if tArgs == nil { + return nil } - if idx, err := fn.CheckArgs(argTypes.(*types.Tuple)); err != nil { + if idx, err := fn.CheckArgs(tArgs.(*types.Tuple)); err != nil { n := ast.Node(node.Args) if idx < len(node.Args.Exprs) { n = node.Args.Exprs[idx] } - return nil, NewErrorf(n, err.Error()) + check.errorf(n, err.Error()) + return nil } - return fn.Result(), nil + return fn.Result() } -func typeCheckIndex(node *ast.Index, scope *Scope) (types.Type, error) { - t, err := scope.TypeOf(node.X) - if err != nil { - return nil, err +func (check *Checker) typeOfIndex(node *ast.Index) types.Type { + t := check.typeOf(node.X) + if t == nil { + return nil } if len(node.Args.Exprs) != 1 { - return nil, NewErrorf(node.Args.ExprList, "expected 1 argument") + check.errorf(node.Args.ExprList, "expected 1 argument") + return nil } - i, err := scope.TypeOf(node.Args.Exprs[0]) - if err != nil { - return nil, err + tIndex := check.typeOf(node.Args.Exprs[0]) + if tIndex == nil { + return nil } if array := types.AsArray(t); array != nil { - if !types.Primitives[types.I32].Equals(i) { - return nil, NewErrorf(node.Args.Exprs[0], "expected type 'i32' for index, got '%s' instead", i) + if !types.Primitives[types.I32].Equals(tIndex) { + check.errorf(node.Args.Exprs[0], "expected type (i32) for index, got (%s) instead", tIndex) + return nil } - return array.ElemType(), nil - } else if tuple := types.AsTuple(t); tuple != nil { - - index := uint64(0) + return array.ElemType() + } - if lit, _ := node.Args.Exprs[0].(*ast.Literal); lit != nil && lit.Kind == ast.IntLiteral { - n, err := strconv.ParseInt(lit.Value, 0, 64) - if err != nil { - panic(err) - } + tuple := types.AsTuple(t) + if tuple == nil { + check.errorf(node.X, "expression is not an array or tuple") + return nil + } - if n < 0 || n > int64(tuple.Len())-1 { - return nil, NewErrorf(node.Args.Exprs[0], "index must be in range 0..%d", tuple.Len()-1) - } + // TODO use [Scope.ValueOf] + lit, _ := node.Args.Exprs[0].(*ast.Literal) + if lit == nil || lit.Kind != ast.IntLiteral { + check.errorf(node.Args.Exprs[0], "expected integer literal") + return nil + } - index = uint64(n) - } else { - return nil, NewError(node.Args.Exprs[0], "expected integer literal") - } + n, err := strconv.ParseInt(lit.Value, 0, 64) + if err != nil { + panic(err) + } - return tuple.Types()[index], nil - } else { - return nil, NewError(node.X, "expression is not an array or tuple") + if n < 0 || n > int64(tuple.Len())-1 { + check.errorf(node.Args.Exprs[0], "index must be in range 0..%d", tuple.Len()-1) + return nil } + + return tuple.Types()[uint64(n)] } -func typeCheckArrayType(node *ast.ArrayType, scope *Scope) (types.Type, error) { +func (check *Checker) typeOfArrayType(node *ast.ArrayType) types.Type { if len(node.Args.Exprs) == 0 { - return nil, NewError(node.Args, "slices are not implemented") + check.errorf(node.Args, "slices are not implemented") + return nil } if len(node.Args.Exprs) > 1 { - return nil, NewError(node.Args, "expected 1 argument") - } - - value, err := scope.ValueOf(node.Args.Exprs[0]) - if err != nil { - return nil, err + check.errorf(node.Args, "expected 1 argument") + return nil } + value := check.valueOf(node.Args.Exprs[0]) if value == nil { - return nil, NewError(node.Args.Exprs[0], "array size cannot be infered") + check.errorf(node.Args.Exprs[0], "array size cannot be infered") + return nil } intValue := constant.AsInt(value.Value) if intValue == nil { - return nil, NewError(node.Args.Exprs[0], "expected integer value for array size") + check.errorf(node.Args.Exprs[0], "expected integer value for array size") + return nil } if intValue.Sign() == -1 || intValue.Int64() > math.MaxInt { - return nil, NewErrorf(node.Args.Exprs[0], "size must be in range 0..9223372036854775807") + check.errorf(node.Args.Exprs[0], "size must be in range 0..9223372036854775807") + return nil } - elemType, err := scope.TypeOf(node.X) - if err != nil { - return nil, err + elemType := check.typeOf(node.X) + if elemType == nil { + return nil } if !types.IsTypeDesc(elemType) { - return nil, NewErrorf(node.X, "expected type, got '%s'", elemType) + check.errorf(node.X, "expected type, got (%s)", elemType) + return nil } size := int(intValue.Int64()) t := types.NewArray(size, types.SkipTypeDesc(elemType)) - return types.NewTypeDesc(t), nil + return types.NewTypeDesc(t) } -func typeCheckSignature(node *ast.Signature, scope *Scope) (types.Type, error) { - params, err := scope.TypeOf(node.Params) - if err != nil { - return nil, err +func (check *Checker) typeOfSignature(node *ast.Signature) types.Type { + tParams := check.typeOfParenList(node.Params) + if tParams == nil { + return nil } - result := types.Unit + tResult := types.Unit if node.Result != nil { - tResult, err := scope.TypeOf(node.Result) - if err != nil { - return nil, err + tActualResult := check.typeOf(node.Result) + if tActualResult == nil { + return nil } - if !types.IsTypeDesc(tResult) { - return nil, NewErrorf( - node.Result, - "expected type, got (%s) instead", - tResult, - ) + if !types.IsTypeDesc(tActualResult) { + check.errorf(node.Result, "expected type, got (%s) instead", tActualResult) + return nil } - result = types.WrapInTuple(types.SkipTypeDesc(tResult)) + tResult = types.WrapInTuple(types.SkipTypeDesc(tActualResult)) } - // 'param' should be a [*types.Tuple] because 'node.Params' is a [*ast.ParenList]. - t := types.NewFunc(result, params.(*types.Tuple)) - return types.NewTypeDesc(t), nil + t := types.NewFunc(tResult, tParams.(*types.Tuple)) + return types.NewTypeDesc(t) } -func typeCheckPrefixOp(node *ast.PrefixOp, scope *Scope) (types.Type, error) { - tOperand, err := scope.TypeOf(node.X) - if err != nil { - return nil, err +func (check *Checker) typeOfPrefixOp(node *ast.PrefixOp) types.Type { + tOperand := check.typeOf(node.X) + if tOperand == nil { + return nil } switch node.Opr.Kind { @@ -356,39 +352,39 @@ func typeCheckPrefixOp(node *ast.PrefixOp, scope *Scope) (types.Type, error) { if p := types.AsPrimitive(tOperand); p != nil { switch p.Kind() { case types.UntypedInt, types.UntypedFloat, types.I32: - return tOperand, nil + return tOperand } } - return nil, NewErrorf( + check.errorf( node.Opr, "operator '%s' is not defined for the type (%s)", node.Opr.Kind.String(), - tOperand.String(), - ) + tOperand.String()) + return nil case ast.OperatorNot: if p, ok := tOperand.Underlying().(*types.Primitive); ok { switch p.Kind() { case types.UntypedBool, types.Bool: - return tOperand, nil + return tOperand } } - return nil, NewErrorf( + check.errorf( node.X, "operator '%s' is not defined for the type (%s)", node.Opr.Kind.String(), - tOperand.String(), - ) + tOperand.String()) + return nil case ast.OperatorAddr: if types.IsTypeDesc(tOperand) { t := types.NewRef(types.SkipTypeDesc(tOperand)) - return types.NewTypeDesc(t), nil + return types.NewTypeDesc(t) } - return types.NewRef(types.SkipUntyped(tOperand)), nil + return types.NewRef(types.SkipUntyped(tOperand)) case ast.OperatorMutAddr: panic("not implemented") @@ -398,23 +394,24 @@ func typeCheckPrefixOp(node *ast.PrefixOp, scope *Scope) (types.Type, error) { } } -func typeCheckInfixOp(node *ast.InfixOp, scope *Scope) (types.Type, error) { - tOperandX, err := scope.TypeOf(node.X) - if err != nil { - return nil, err +func (check *Checker) typeOfInfixOp(node *ast.InfixOp) types.Type { + tOperandX := check.typeOf(node.X) + if tOperandX == nil { + return nil } - tOperandY, err := scope.TypeOf(node.Y) - if err != nil { - return nil, err + tOperandY := check.typeOf(node.Y) + if tOperandY == nil { + return nil } if !tOperandX.Equals(tOperandY) { - return nil, NewErrorf(node, "type mismatch (%s and %s)", tOperandX, tOperandY) + check.errorf(node, "type mismatch (%s and %s)", tOperandX, tOperandY) + return nil } if node.Opr.Kind == ast.OperatorAssign { - return types.Unit, nil + return types.Unit } if primitive := types.AsPrimitive(tOperandX); primitive != nil { @@ -423,16 +420,16 @@ func typeCheckInfixOp(node *ast.InfixOp, scope *Scope) (types.Type, error) { ast.OperatorBitAnd, ast.OperatorBitOr, ast.OperatorBitXor, ast.OperatorBitShl, ast.OperatorBitShr: switch primitive.Kind() { case types.UntypedInt, types.UntypedFloat, types.I32: - return tOperandX, nil + return tOperandX } case ast.OperatorEq, ast.OperatorNe, ast.OperatorLt, ast.OperatorLe, ast.OperatorGt, ast.OperatorGe: switch primitive.Kind() { case types.UntypedBool, types.UntypedInt, types.UntypedFloat: - return types.Primitives[types.UntypedBool], nil + return types.Primitives[types.UntypedBool] case types.Bool, types.I32: - return types.Primitives[types.Bool], nil + return types.Primitives[types.Bool] } default: @@ -440,27 +437,28 @@ func typeCheckInfixOp(node *ast.InfixOp, scope *Scope) (types.Type, error) { } } - return nil, NewErrorf( + check.errorf( node.Opr, "operator '%s' is not defined for the type '%s'", node.Opr.Kind.String(), - tOperandX.String(), - ) + tOperandX.String()) + return nil } -func typeCheckPostfixOp(node *ast.PostfixOp, scope *Scope) (types.Type, error) { - tOperand, err := scope.TypeOf(node.X) - if err != nil { - return nil, err +func (check *Checker) typeOfPostfixOp(node *ast.PostfixOp) types.Type { + tOperand := check.typeOf(node.X) + if tOperand == nil { + return nil } switch node.Opr.Kind { case ast.OperatorUnwrap: if ref := types.AsRef(tOperand); ref != nil { - return ref.Base(), nil + return ref.Base() } - return nil, NewError(node.X, "expression is not a reference type") + check.errorf(node.X, "expression is not a reference type") + return nil case ast.OperatorTry: panic("not inplemented") @@ -470,13 +468,13 @@ func typeCheckPostfixOp(node *ast.PostfixOp, scope *Scope) (types.Type, error) { } } -func typeCheckBracketList(node *ast.BracketList, scope *Scope) (types.Type, error) { +func (check *Checker) typeOfBracketList(node *ast.BracketList) types.Type { var elemType types.Type for _, expr := range node.Exprs { - t, err := scope.TypeOf(expr) - if err != nil { - return nil, err + t := check.typeOf(expr) + if t == nil { + return nil } if elemType == nil { @@ -485,122 +483,114 @@ func typeCheckBracketList(node *ast.BracketList, scope *Scope) (types.Type, erro } if !elemType.Equals(t) { - return nil, NewErrorf( - expr, - "expected type '%s' for element, got '%s' instead", - elemType, - t, - ) + check.errorf(expr, "expected type (%s) for element, got (%s) instead", elemType, t) + return nil } } size := len(node.Exprs) - return types.NewArray(size, elemType), nil + return types.NewArray(size, elemType) } -func typeCheckParenList(node *ast.ParenList, scope *Scope) (types.Type, error) { +func (check *Checker) typeOfParenList(node *ast.ParenList) types.Type { // Either typedesc or tuple contructor. if len(node.Exprs) == 0 { - return types.Unit, nil + return types.Unit } elemTypes := []types.Type{} isTypeDescTuple := false - { - t, err := scope.TypeOf(node.Exprs[0]) - if err != nil { - return nil, err - } + t := check.typeOf(node.Exprs[0]) + if t == nil { + return nil + } - if types.IsTypeDesc(t) { - isTypeDescTuple = true - elemTypes = append(elemTypes, types.SkipTypeDesc(t)) - } else { - elemTypes = append(elemTypes, types.SkipUntyped(t)) - } + if types.IsTypeDesc(t) { + isTypeDescTuple = true + elemTypes = append(elemTypes, types.SkipTypeDesc(t)) + } else { + elemTypes = append(elemTypes, types.SkipUntyped(t)) } for _, expr := range node.Exprs[1:] { - t, err := scope.TypeOf(expr) - if err != nil { - return nil, err + t := check.typeOf(expr) + if t == nil { + return nil } if isTypeDescTuple { if !types.IsTypeDesc(t) { - return nil, NewErrorf(expr, "expected type, got '%s' instead", t) + check.errorf(expr, "expected type, got '%s' instead", t) + return nil } elemTypes = append(elemTypes, types.SkipTypeDesc(t)) } else { if types.IsTypeDesc(t) { - return nil, NewErrorf(expr, "expected expression, got type '%s' instead", t) + check.errorf(expr, "expected expression, got type '%s' instead", t) + return nil } elemTypes = append(elemTypes, types.SkipUntyped(t)) } } - t := types.NewTuple(elemTypes...) - if isTypeDescTuple { - return types.NewTypeDesc(t), nil + return types.NewTypeDesc(types.NewTuple(elemTypes...)) } - return t, nil + return types.NewTuple(elemTypes...) } -func typeCheckCurlyList(scope *Scope, node *ast.CurlyList) (types.Type, error) { - block := NewBlock(scope) +func (check *Checker) typeOfCurlyList(node *ast.CurlyList) types.Type { + block := NewBlock(NewScope(check.scope)) fmt.Printf(">>> push local\n") for _, node := range node.Nodes { - if err := ast.WalkTopDown(block.visit, node); err != nil { - return nil, err - } + ast.WalkTopDown(check.blockVisitor(block), node) } fmt.Printf(">>> pop local\n") - - return block.t, nil + return block.t } -func typeCheckIf(node *ast.If, scope *Scope) (types.Type, error) { +func (check *Checker) typeOfIf(node *ast.If) types.Type { // We check the body type before the condition to return the // body type in case the condition is not a boolean expression. - tBody, err := scope.TypeOf(node.Body) - if err != nil { - return nil, err + tBody := check.typeOf(node.Body) + if tBody == nil { + return nil } if node.Else != nil { - if err := typeCheckElse(node.Else, tBody, scope); err != nil { - return tBody, err + tElse := check.typeOfElse(node.Else, tBody) + if tElse == nil { + return tBody } } - tCondition, err := scope.TypeOf(node.Cond) - if err != nil { - return tBody, err + tCondition := check.typeOf(node.Cond) + if tCondition == nil { + return tBody } if !types.Primitives[types.Bool].Equals(tCondition) { - return tBody, NewErrorf( + check.errorf( node.Cond, "expected type (bool) for condition, got (%s) instead", - tCondition, - ) + tCondition) + return tBody } - return tBody, nil + return tBody } -func typeCheckElse(node *ast.Else, expectedType types.Type, scope *Scope) error { - tBody, err := scope.TypeOf(node.Body) - if err != nil { - return err +func (check *Checker) typeOfElse(node *ast.Else, expectedType types.Type) types.Type { + tBody := check.typeOf(node.Body) + if tBody == nil { + return nil } if !expectedType.Equals(tBody) { @@ -615,35 +605,37 @@ func typeCheckElse(node *ast.Else, expectedType types.Type, scope *Scope) error lastNode = body.Body.Nodes[len(body.Body.Nodes)-1] } - return NewErrorf( + check.errorf( lastNode, "all branches must have the same type with first branch (%s), got (%s) instead", expectedType, - tBody, - ) + tBody) + return nil } - return nil + return tBody } -func typeCheckWhile(node *ast.While, scope *Scope) error { - tBody, err := scope.TypeOf(node.Body) - if err != nil { - return err +func (check *Checker) typeOfWhile(node *ast.While) types.Type { + tBody := check.typeOf(node.Body) + if tBody == nil { + return nil } if !types.Unit.Equals(tBody) { - return NewErrorf(node.Body, "while loop body must have no type, but body has type '%s'", tBody) + check.errorf(node.Body, "while loop body must have no type, but got (%s)", tBody) + return nil } - tCond, err := scope.TypeOf(node.Cond) - if err != nil { - return err + tCond := check.typeOf(node.Cond) + if tCond == nil { + return nil } if !types.Primitives[types.Bool].Equals(tCond) { - return NewErrorf(node.Cond, "expected type 'bool' for condition, got '%s' instead", tCond) + check.errorf(node.Cond, "expected type 'bool' for condition, got (%s) instead", tCond) + return nil } - return nil + return types.Unit } diff --git a/checker/value.go b/checker/value.go index 49c210d..8b71e69 100644 --- a/checker/value.go +++ b/checker/value.go @@ -7,7 +7,7 @@ import ( // Represents a compile-time known value. // Also can represent a type in some situations. -type Value struct { +type TypedValue struct { Type types.Type Value constant.Value // Can be nil. } diff --git a/checker/var.go b/checker/var.go index d065b47..95613f7 100644 --- a/checker/var.go +++ b/checker/var.go @@ -2,7 +2,6 @@ package checker import ( "github.com/saffage/jet/ast" - "github.com/saffage/jet/internal/assert" "github.com/saffage/jet/types" ) @@ -27,8 +26,3 @@ func (v *Var) Type() types.Type { return v.t } func (v *Var) Name() string { return v.name.Name } func (v *Var) Ident() *ast.Ident { return v.name } func (v *Var) Node() ast.Node { return v.node } - -func (v *Var) setType(t types.Type) { - assert.Ok(t != nil) - v.t = t -} diff --git a/constant/float.go b/constant/float.go deleted file mode 100644 index 3f2495e..0000000 --- a/constant/float.go +++ /dev/null @@ -1 +0,0 @@ -package constant diff --git a/constant/int.go b/constant/int.go deleted file mode 100644 index 3f2495e..0000000 --- a/constant/int.go +++ /dev/null @@ -1 +0,0 @@ -package constant diff --git a/constant/kind.go b/constant/kind.go deleted file mode 100644 index 3f2495e..0000000 --- a/constant/kind.go +++ /dev/null @@ -1 +0,0 @@ -package constant diff --git a/constant/kind_string.go b/constant/kind_string.go new file mode 100644 index 0000000..e9fe041 --- /dev/null +++ b/constant/kind_string.go @@ -0,0 +1,26 @@ +// Code generated by "stringer -type=Kind"; DO NOT EDIT. + +package constant + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Int-0] + _ = x[Float-1] + _ = x[String-2] + _ = x[Bool-3] +} + +const _Kind_name = "IntFloatStringBool" + +var _Kind_index = [...]uint8{0, 3, 8, 14, 18} + +func (i Kind) String() string { + if i >= Kind(len(_Kind_index)-1) { + return "Kind(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Kind_name[_Kind_index[i]:_Kind_index[i+1]] +} diff --git a/constant/value.go b/constant/value.go index 72c359b..2be5019 100644 --- a/constant/value.go +++ b/constant/value.go @@ -8,27 +8,80 @@ import ( "github.com/saffage/jet/ast" ) +type Kind byte + +//go:generate stringer -type=Kind +const ( + Int Kind = iota + Float + String + Bool +) + type Value interface { + // String() string Kind() Kind - String() string - implValue() } -func NewBool(value bool) Value { - return boolValue{value} +func NewBigInt(value *big.Int) Value { + if value == nil { + panic("nil argument") + } + + return &intValue{value} } -func NewInt(value *big.Int) Value { - return intValue{value} +func NewBigFloat(value *big.Float) Value { + if value == nil { + panic("nil argument") + } + + return &floatValue{value} } -func NewFloat(value *big.Float) floatValue { - return floatValue{value} +func NewInt(value int64) Value { return &intValue{big.NewInt(value)} } +func NewFloat(value float64) Value { return &floatValue{big.NewFloat(value)} } +func NewString(value string) Value { return &stringValue{value} } +func NewBool(value bool) Value { return &boolValue{value} } + +func IsInt(value Value) bool { return value.Kind() == Int } +func IsFloat(value Value) bool { return value.Kind() == Float } +func IsString(value Value) bool { return value.Kind() == String } +func IsBool(value Value) bool { return value.Kind() == Bool } + +func AsInt(value Value) *big.Int { + if IsInt(value) { + return value.(*intValue).val + } + + return nil } -func NewString(value string) Value { - return stringValue{value} +func AsFloat(value Value) *big.Float { + if IsFloat(value) { + return value.(*floatValue).val + } + + return nil +} + +func AsString(value Value) *string { + if IsString(value) { + val := value.(*stringValue).val + return &val + } + + return nil +} + +func AsBool(value Value) *bool { + if IsBool(value) { + val := value.(*boolValue).val + return &val + } + + return nil } func FromNode(node *ast.Literal) Value { @@ -38,16 +91,16 @@ func FromNode(node *ast.Literal) Value { switch node.Kind { case ast.IntLiteral: - if value, ok := new(big.Int).SetString(node.Value, 0); ok { - return NewInt(value) + if value, ok := big.NewInt(0).SetString(node.Value, 0); ok { + return NewBigInt(value) } // Unreachable? panic(fmt.Sprintf("invalid integer value for constant: '%s'", node.Value)) case ast.FloatLiteral: - if value, ok := new(big.Float).SetString(node.Value); ok { - return NewFloat(value) + if value, ok := big.NewFloat(0.0).SetString(node.Value); ok { + return NewBigFloat(value) } // Unreachable? @@ -61,69 +114,28 @@ func FromNode(node *ast.Literal) Value { } } -func AsBool(value Value) *bool { - if v, ok := value.(boolValue); ok { - return &v.bool - } - - return nil -} - -func AsInt(value Value) *big.Int { - if v, ok := value.(intValue); ok && v.Int != nil { - return v.Int - } - - return nil -} - -func AsFloat(value Value) *big.Float { - if v, ok := value.(floatValue); ok && v.Float != nil { - return v.Float - } - - return nil -} - -func AsString(value Value) *string { - if v, ok := value.(stringValue); ok { - return &v.string - } - - return nil -} - -type Kind byte - -const ( - Unknown Kind = iota - - Bool // TODO delete and implement through attributes. - Int - Float - String -) +//------------------------------------------------ +// Value implementation +//------------------------------------------------ type ( - boolValue struct{ bool } - intValue struct{ *big.Int } - floatValue struct{ *big.Float } - stringValue struct{ string } + intValue struct{ val *big.Int } + floatValue struct{ val *big.Float } + stringValue struct{ val string } + boolValue struct{ val bool } ) -func (boolValue) implValue() {} +func (v *intValue) String() string { return v.val.String() } +func (v *floatValue) String() string { return v.val.String() } +func (v *stringValue) String() string { return strconv.Quote(v.val) } +func (v *boolValue) String() string { return strconv.FormatBool(v.val) } + +func (v *intValue) Kind() Kind { return Int } +func (v *floatValue) Kind() Kind { return Float } +func (v *stringValue) Kind() Kind { return String } +func (v *boolValue) Kind() Kind { return Bool } + func (intValue) implValue() {} func (floatValue) implValue() {} func (stringValue) implValue() {} - -func (v boolValue) Kind() Kind { return Bool } -func (v boolValue) String() string { return strconv.FormatBool(v.bool) } - -func (v intValue) Kind() Kind { return Int } -func (v intValue) String() string { return v.Int.String() } - -func (v floatValue) Kind() Kind { return Float } -func (v floatValue) String() string { return v.Float.String() } - -func (v stringValue) Kind() Kind { return String } -func (v stringValue) String() string { return "\"" + v.string + "\"" } +func (boolValue) implValue() {} diff --git a/internal/jet/jet.go b/internal/jet/jet.go index 042fb37..0f2132e 100644 --- a/internal/jet/jet.go +++ b/internal/jet/jet.go @@ -38,7 +38,7 @@ func reportError(cfg *config.Config, err error) { report.Note(cfg, "parser note: "+note, token.Loc{}, token.Loc{}) } - case checker.Error: + case *checker.Error: start, end := token.Loc{}, token.Loc{} if err.Node != nil { @@ -122,11 +122,11 @@ func process( if err := recover(); err != nil { switch e := err.(type) { case checker.Error: - reportError(cfg, e) + reportError(cfg, &e) case []checker.Error: for i := range e { - reportError(cfg, e[i]) + reportError(cfg, &e[i]) } default: @@ -140,14 +140,15 @@ func process( Name: &ast.Ident{Name: "repl"}, Body: &ast.CurlyList{List: nodeList}, } - checker.NewFunc(nil, nil, decl) + checker.NewFunc(nil, nil, nil, decl) } else { mod := &ast.ModuleDecl{ Name: &ast.Ident{Name: cfg.Files[config.MainFileID].Name}, Body: nodeList, } - _, err := checker.NewModule(mod) - if err != nil { + errs := checker.Check(mod) + + for _, err := range errs { reportError(cfg, err) } } diff --git a/types/func.go b/types/func.go index 877ba51..ad3df17 100644 --- a/types/func.go +++ b/types/func.go @@ -96,6 +96,18 @@ func (t *Func) CheckArgs(args *Tuple) (idx int, err error) { return -1, nil } +func IsFunc(t Type) bool { return AsFunc(t) != nil } + +func AsFunc(t Type) *Func { + if t != nil { + if fn, _ := t.Underlying().(*Func); fn != nil { + return fn + } + } + + return nil +} + func ordinalize(num int) string { s := strconv.Itoa(num)