Skip to content

Commit

Permalink
refactor: use AsT funcs for Type.Equals methods
Browse files Browse the repository at this point in the history
  • Loading branch information
saffage committed May 9, 2024
1 parent 1a71d40 commit 18b1d90
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 37 deletions.
19 changes: 15 additions & 4 deletions types/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ func NewAlias(t Type, name string) *Alias {
if t == nil {
panic("alias to unknown must be used in built-in types only")
}

return &Alias{
base: t,
actual: removeAlias(t),
Expand Down Expand Up @@ -41,9 +40,21 @@ func SkipAlias(t Type) Type {
}

func removeAlias(t0 Type) Type {
if a, ok := t0.(*Alias); ok {
return removeAlias(a.actual)
if t0 == nil {
return nil
}

t := t0
a, ok := t0.(*Alias)

for ok && a != nil {
t = a.actual
a, ok = t.(*Alias)
}

return t0
if t == nil {
panic("invalid alias")
}

return t
}
7 changes: 2 additions & 5 deletions types/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ func NewArray(size int, t Type) *Array {
if size < 0 {
panic(fmt.Sprintf("invalid array size (%d)", size))
}

return &Array{size, t}
}

func (t *Array) Equals(other Type) bool {
if otherArray, ok := other.Underlying().(*Array); ok {
return (t.size == -1 || t.size == otherArray.size) && t.elem.Equals(otherArray.elem)
if t2 := AsArray(other); t2 != nil {
return t.size == t2.size && t.elem.Equals(t2.elem)
}

return false
}

Expand All @@ -29,7 +27,6 @@ func (t *Array) String() string {
if t.size == -1 {
return "[_]" + t.elem.String()
}

return fmt.Sprintf("[%d]%s", t.size, t.elem)
}

Expand Down
8 changes: 2 additions & 6 deletions types/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,19 @@ func NewFunc(result *Tuple, params *Tuple) *Func {
if result == nil {
result = Unit
}

if params == nil {
params = Unit
}

return &Func{
params: params,
result: result,
}
}

func (t *Func) Equals(other Type) bool {
if otherFunc, ok := other.(*Func); ok {
return t.result.Equals(otherFunc.result) &&
t.params.Equals(otherFunc.params)
if t2 := AsFunc(other); t2 != nil {
return t.result.Equals(t2.result) && t.params.Equals(t2.params)
}

return false
}

Expand Down
15 changes: 8 additions & 7 deletions types/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ type Primitive struct {
kind PrimitiveKind
}

func (p *Primitive) Equals(other Type) bool {
p2, _ := other.Underlying().(*Primitive)
return p.kind == KindAny || (p2 != nil && (p.kind == p2.kind || p.kind == SkipUntyped(p2).(*Primitive).kind))
func (t *Primitive) Equals(other Type) bool {
if t2 := AsPrimitive(other); t2 != nil {
return t.kind == KindAny || (t2 != nil && (t.kind == t2.kind || t.kind == SkipUntyped(t2).(*Primitive).kind))
}

return false
}

func (p *Primitive) Underlying() Type { return p }
Expand All @@ -25,7 +28,6 @@ func AsPrimitive(t Type) *Primitive {
if primitive, _ := t.Underlying().(*Primitive); primitive != nil {
return primitive
}

return nil
}

Expand All @@ -38,11 +40,10 @@ func SkipUntyped(t Type) Type {
case KindUntypedInt:
return I32

// case UntypedFloat, UntypedString:
// panic("not implemented")
// case UntypedFloat, UntypedString:
// panic("not implemented")
}
}

return t
}

Expand Down
6 changes: 2 additions & 4 deletions types/ref.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ func NewRef(t Type) *Ref {
if IsTypeDesc(t) {
panic("references to meta type is not allowed")
}

return &Ref{base: t}
}

func (t *Ref) Equals(other Type) bool {
if otherRef, ok := other.Underlying().(*Ref); ok {
return t.base.Equals(otherRef.base)
if t2 := AsRef(other); t2 != nil {
return t.base.Equals(t2.base)
}

return false
}

Expand Down
9 changes: 3 additions & 6 deletions types/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,15 @@ func NewStruct(fields map[string]Type) *Struct {
}

func (t *Struct) Equals(other Type) bool {
if otherStruct, _ := other.Underlying().(*Struct); otherStruct != nil {
if t2 := AsStruct(other); t2 != nil {
for name, tField := range t.fields {
tOtherField, ok := otherStruct.fields[name]

if !ok || !tField.Equals(tOtherField) {
t2Field, ok := t2.fields[name]
if !ok || !tField.Equals(t2Field) {
return false
}
}

return true
}

return false
}

Expand Down
5 changes: 2 additions & 3 deletions types/tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ func NewTuple(types ...Type) *Tuple {
//
// NOTE: name of the elements are not required to be the same.
func (t *Tuple) Equals(other Type) bool {
if otherTuple, ok := other.Underlying().(*Tuple); ok {
if t2 := AsTuple(other); t2 != nil {
return slices.EqualFunc(
t.types,
otherTuple.types,
t2.types,
func(a, b Type) bool { return a.Equals(b) },
)
} else if underlying := t.Underlying(); underlying != t {
// The tuple has 1 element and can be equals to the element type.
return underlying.Equals(other.Underlying())
}

return false
}

Expand Down
6 changes: 4 additions & 2 deletions types/typedesc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ func NewTypeDesc(t Type) *TypeDesc {
}

func (t *TypeDesc) Equals(other Type) bool {
typedesc, _ := other.(*TypeDesc)
return typedesc != nil && t.base.Equals(typedesc.base)
if t2 := AsTypeDesc(other); t2 != nil {
return t.base.Equals(t2.base)
}
return false
}

func (t *TypeDesc) Underlying() Type { return t }
Expand Down

0 comments on commit 18b1d90

Please sign in to comment.