diff --git a/types/alias.go b/types/alias.go index e68ca28..49f3346 100644 --- a/types/alias.go +++ b/types/alias.go @@ -12,7 +12,6 @@ func NewAlias(t Type, name string) *Alias { if t == nil { panic("alias to unknown must be used in built-in types only") } - return &Alias{ base: t, actual: removeAlias(t), @@ -41,9 +40,21 @@ func SkipAlias(t Type) Type { } func removeAlias(t0 Type) Type { - if a, ok := t0.(*Alias); ok { - return removeAlias(a.actual) + if t0 == nil { + return nil + } + + t := t0 + a, ok := t0.(*Alias) + + for ok && a != nil { + t = a.actual + a, ok = t.(*Alias) } - return t0 + if t == nil { + panic("invalid alias") + } + + return t } diff --git a/types/array.go b/types/array.go index e0dc544..db4099e 100644 --- a/types/array.go +++ b/types/array.go @@ -11,15 +11,13 @@ func NewArray(size int, t Type) *Array { if size < 0 { panic(fmt.Sprintf("invalid array size (%d)", size)) } - return &Array{size, t} } func (t *Array) Equals(other Type) bool { - if otherArray, ok := other.Underlying().(*Array); ok { - return (t.size == -1 || t.size == otherArray.size) && t.elem.Equals(otherArray.elem) + if t2 := AsArray(other); t2 != nil { + return t.size == t2.size && t.elem.Equals(t2.elem) } - return false } @@ -29,7 +27,6 @@ func (t *Array) String() string { if t.size == -1 { return "[_]" + t.elem.String() } - return fmt.Sprintf("[%d]%s", t.size, t.elem) } diff --git a/types/func.go b/types/func.go index ad3df17..f0a3f8a 100644 --- a/types/func.go +++ b/types/func.go @@ -15,11 +15,9 @@ func NewFunc(result *Tuple, params *Tuple) *Func { if result == nil { result = Unit } - if params == nil { params = Unit } - return &Func{ params: params, result: result, @@ -27,11 +25,9 @@ func NewFunc(result *Tuple, params *Tuple) *Func { } func (t *Func) Equals(other Type) bool { - if otherFunc, ok := other.(*Func); ok { - return t.result.Equals(otherFunc.result) && - t.params.Equals(otherFunc.params) + if t2 := AsFunc(other); t2 != nil { + return t.result.Equals(t2.result) && t.params.Equals(t2.params) } - return false } diff --git a/types/primitive.go b/types/primitive.go index da68adf..be21e08 100644 --- a/types/primitive.go +++ b/types/primitive.go @@ -6,9 +6,12 @@ type Primitive struct { kind PrimitiveKind } -func (p *Primitive) Equals(other Type) bool { - p2, _ := other.Underlying().(*Primitive) - return p.kind == KindAny || (p2 != nil && (p.kind == p2.kind || p.kind == SkipUntyped(p2).(*Primitive).kind)) +func (t *Primitive) Equals(other Type) bool { + if t2 := AsPrimitive(other); t2 != nil { + return t.kind == KindAny || (t2 != nil && (t.kind == t2.kind || t.kind == SkipUntyped(t2).(*Primitive).kind)) + } + + return false } func (p *Primitive) Underlying() Type { return p } @@ -25,7 +28,6 @@ func AsPrimitive(t Type) *Primitive { if primitive, _ := t.Underlying().(*Primitive); primitive != nil { return primitive } - return nil } @@ -38,11 +40,10 @@ func SkipUntyped(t Type) Type { case KindUntypedInt: return I32 - // case UntypedFloat, UntypedString: - // panic("not implemented") + // case UntypedFloat, UntypedString: + // panic("not implemented") } } - return t } diff --git a/types/ref.go b/types/ref.go index 20b629f..df9d965 100644 --- a/types/ref.go +++ b/types/ref.go @@ -8,15 +8,13 @@ func NewRef(t Type) *Ref { if IsTypeDesc(t) { panic("references to meta type is not allowed") } - return &Ref{base: t} } func (t *Ref) Equals(other Type) bool { - if otherRef, ok := other.Underlying().(*Ref); ok { - return t.base.Equals(otherRef.base) + if t2 := AsRef(other); t2 != nil { + return t.base.Equals(t2.base) } - return false } diff --git a/types/struct.go b/types/struct.go index aa3855e..614952c 100644 --- a/types/struct.go +++ b/types/struct.go @@ -16,18 +16,15 @@ func NewStruct(fields map[string]Type) *Struct { } func (t *Struct) Equals(other Type) bool { - if otherStruct, _ := other.Underlying().(*Struct); otherStruct != nil { + if t2 := AsStruct(other); t2 != nil { for name, tField := range t.fields { - tOtherField, ok := otherStruct.fields[name] - - if !ok || !tField.Equals(tOtherField) { + t2Field, ok := t2.fields[name] + if !ok || !tField.Equals(t2Field) { return false } } - return true } - return false } diff --git a/types/tuple.go b/types/tuple.go index b167569..a6f24c0 100644 --- a/types/tuple.go +++ b/types/tuple.go @@ -19,17 +19,16 @@ func NewTuple(types ...Type) *Tuple { // // NOTE: name of the elements are not required to be the same. func (t *Tuple) Equals(other Type) bool { - if otherTuple, ok := other.Underlying().(*Tuple); ok { + if t2 := AsTuple(other); t2 != nil { return slices.EqualFunc( t.types, - otherTuple.types, + t2.types, func(a, b Type) bool { return a.Equals(b) }, ) } else if underlying := t.Underlying(); underlying != t { // The tuple has 1 element and can be equals to the element type. return underlying.Equals(other.Underlying()) } - return false } diff --git a/types/typedesc.go b/types/typedesc.go index e727b0d..f601a28 100644 --- a/types/typedesc.go +++ b/types/typedesc.go @@ -21,8 +21,10 @@ func NewTypeDesc(t Type) *TypeDesc { } func (t *TypeDesc) Equals(other Type) bool { - typedesc, _ := other.(*TypeDesc) - return typedesc != nil && t.base.Equals(typedesc.base) + if t2 := AsTypeDesc(other); t2 != nil { + return t.base.Equals(t2.base) + } + return false } func (t *TypeDesc) Underlying() Type { return t }