diff --git a/inlining_test.go b/inlining_test.go index b3611e3..267fa3a 100644 --- a/inlining_test.go +++ b/inlining_test.go @@ -43,8 +43,9 @@ func TestInlining(t *testing.T) { }) for _, want := range []string{ "(*IPSet).Add", + "(*IPSet).Clone", "(*IPSet).Remove", - "(*IPSet).RemoveRange", + "(*IPSet).removeRangeLocked", "(*uint128).halves", "IP.BitLen", "IP.IPAddr", @@ -60,6 +61,7 @@ func TestInlining(t *testing.T) { "IP.Prior", "IP.Unmap", "IP.Zone", + "IP.lessOrEq", "IP.v4", "IP.v6", "IP.v6u16", @@ -74,6 +76,7 @@ func TestInlining(t *testing.T) { "IPPrefix.Masked", "IPPrefix.Valid", "IPRange.Prefixes", + "IPRange.entirelyBefore", "IPRange.prefixFrom128AndBits", "IPRange.prefixFrom128AndBits-fm", "IPv4", diff --git a/ipset.go b/ipset.go index a98caaa..40ef9f4 100644 --- a/ipset.go +++ b/ipset.go @@ -4,7 +4,10 @@ package netaddr -import "sort" +import ( + "sort" + "sync" +) // IPSet represents a set of IP addresses. // @@ -15,25 +18,24 @@ import "sort" // nothing on an empty set. Ranges may be fully, partially, or not // overlapping. type IPSet struct { + mu sync.Mutex + // in are the ranges in the set. in []IPRange // out are the ranges to be removed from 'in'. out []IPRange -} -// toInOnly updates s to clear s.out, by merging any s.out into s.in. -func (s *IPSet) toInOnly() { - if len(s.out) > 0 { - s.in = s.Ranges() - s.out = nil - } + // normalized indicates that 'in' is in sorted order, and 'out' is + // empty. + normalized bool } // Clone returns a copy of s that shares no memory with s. func (s *IPSet) Clone() *IPSet { return &IPSet{ - in: s.Ranges(), + in: s.Ranges(), + normalized: true, } } @@ -45,13 +47,22 @@ func (s *IPSet) AddPrefix(p IPPrefix) { s.AddRange(p.Range()) } // AddRange adds r to s. func (s *IPSet) AddRange(r IPRange) { + s.mu.Lock() + defer s.mu.Unlock() + s.addRangeLocked(r) +} + +func (s *IPSet) addRangeLocked(r IPRange) { if !r.Valid() { return } // If there are any removals (s.out), then we need to compact the set // first to get the order right. - s.toInOnly() + if len(s.out) > 0 { + s.rangesLocked() + } s.in = append(s.in, r) + s.normalized = false } // Remove removes ip from the set s. @@ -59,7 +70,9 @@ func (s *IPSet) Remove(ip IP) { s.RemoveRange(IPRange{ip, ip}) } // RemoveFreePrefix removes and returns a Prefix of length bits from the IPSet. func (s *IPSet) RemoveFreePrefix(bitLen uint8) (p IPPrefix, ok bool) { - prefixes := s.Prefixes() + s.mu.Lock() + defer s.mu.Unlock() + prefixes := s.prefixesLocked() if len(prefixes) == 0 { return IPPrefix{}, false } @@ -83,7 +96,7 @@ func (s *IPSet) RemoveFreePrefix(bitLen uint8) (p IPPrefix, ok bool) { } prefix := IPPrefix{IP: bestFit.IP, Bits: bitLen} - s.RemovePrefix(prefix) + s.removeRangeLocked(prefix.Range()) return prefix, true } @@ -92,34 +105,49 @@ func (s *IPSet) RemovePrefix(p IPPrefix) { s.RemoveRange(p.Range()) } // RemoveRange removes r from s. func (s *IPSet) RemoveRange(r IPRange) { + s.mu.Lock() + defer s.mu.Unlock() + s.removeRangeLocked(r) +} + +func (s *IPSet) removeRangeLocked(r IPRange) { if r.Valid() { s.out = append(s.out, r) + s.normalized = false } } // AddSet adds all ranges in b to s. func (s *IPSet) AddSet(b *IPSet) { - for _, r := range b.Ranges() { - s.AddRange(r) + rr := b.Ranges() + s.mu.Lock() + defer s.mu.Unlock() + for _, r := range rr { + s.addRangeLocked(r) } } // RemoveSet removes all ranges in b from s. func (s *IPSet) RemoveSet(b *IPSet) { - for _, r := range b.Ranges() { - s.RemoveRange(r) + rr := b.Ranges() + s.mu.Lock() + defer s.mu.Unlock() + for _, r := range rr { + s.removeRangeLocked(r) } } // Complement updates s to contain the complement of its current // contents. func (s *IPSet) Complement() { - s.toInOnly() - s.out = s.in + s.mu.Lock() + defer s.mu.Unlock() + s.out = s.rangesLocked() s.in = []IPRange{ IPPrefix{IP: IPv4(0, 0, 0, 0), Bits: 0}.Range(), IPPrefix{IP: IPv6Unspecified(), Bits: 0}.Range(), } + s.normalized = false } // Intersect updates s to the set intersection of s and b. @@ -154,6 +182,16 @@ var debugf = discardf // Ranges returns the minimum and sorted set of IP // ranges that covers s. func (s *IPSet) Ranges() []IPRange { + s.mu.Lock() + defer s.mu.Unlock() + return s.rangesLocked() +} + +func (s *IPSet) rangesLocked() []IPRange { + if s.normalized { + return s.in + } + const debug = false if debug { debugf("ranges start in=%v out=%v", s.in, s.out) @@ -274,7 +312,9 @@ func (s *IPSet) Ranges() []IPRange { } } - // TODO: possibly update s.in and s.out, if #110 supports that. + s.in = ret + s.out = nil + s.normalized = true return ret } @@ -284,8 +324,14 @@ func (s *IPSet) Ranges() []IPRange { // returning a new slice of prefixes that covers all of the given 'add' // prefixes with all the 'remove' prefixes removed. func (s *IPSet) Prefixes() []IPPrefix { + s.mu.Lock() + defer s.mu.Unlock() + return s.prefixesLocked() +} + +func (s *IPSet) prefixesLocked() []IPPrefix { var out []IPPrefix - for _, r := range s.Ranges() { + for _, r := range s.rangesLocked() { out = append(out, r.Prefixes()...) } return out diff --git a/ipset_test.go b/ipset_test.go index 092b25f..70e3166 100644 --- a/ipset_test.go +++ b/ipset_test.go @@ -481,11 +481,11 @@ func TestIPSetOverlaps(t *testing.T) { for _, test := range tests { got := test.a.Overlaps(test.b) if got != test.want { - t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.a, test.b, got, test.want) + t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.a.Ranges(), test.b.Ranges(), got, test.want) } got = test.b.Overlaps(test.a) if got != test.want { - t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.b, test.a, got, test.want) + t.Errorf("(%s).Overlaps(%s) = %v, want %v", test.b.Ranges(), test.a.Ranges(), got, test.want) } } }