diff --git a/ipset.go b/ipset.go index a98caaa..f6cf7f4 100644 --- a/ipset.go +++ b/ipset.go @@ -309,3 +309,18 @@ func (s *IPSet) ContainsFunc() (contains func(IP) bool) { return rv[i].contains(ip) } } + +// Equal reports whether s contains exactly the same IPs as t. +func (s *IPSet) Equal(t *IPSet) bool { + sr := s.Ranges() + tr := t.Ranges() + if len(sr) != len(tr) { + return false + } + for i := range sr { + if sr[i] != tr[i] { + return false + } + } + return true +} diff --git a/ipset_test.go b/ipset_test.go index 092b25f..b88ffdc 100644 --- a/ipset_test.go +++ b/ipset_test.go @@ -701,3 +701,35 @@ func TestIPSetRangesStress(t *testing.T) { } } } + +func TestIPSetEqual(t *testing.T) { + a := new(IPSet) + b := new(IPSet) + + assertEqual := func(want bool) { + t.Helper() + if got := a.Equal(b); got != want { + t.Errorf("%v.Equal(%v) = %v want %v", a, b, got, want) + } + } + + a.Add(MustParseIP("1.1.1.0")) + a.Add(MustParseIP("1.1.1.1")) + a.Add(MustParseIP("1.1.1.2")) + b.AddPrefix(MustParseIPPrefix("1.1.1.0/31")) + b.Add(MustParseIP("1.1.1.2")) + assertEqual(true) + + a.RemoveSet(a) + assertEqual(false) + b.RemoveSet(b) + assertEqual(true) + + a.Add(MustParseIP("1.1.1.0")) + a.Add(MustParseIP("1.1.1.1")) + a.Add(MustParseIP("1.1.1.2")) + + b.AddPrefix(MustParseIPPrefix("1.1.1.0/30")) + b.Remove(MustParseIP("1.1.1.3")) + assertEqual(true) +}