diff --git a/conn/bind_std.go b/conn/bind_std.go index 69789b33f..46df7fd4e 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "context" "errors" + "fmt" "net" "net/netip" "runtime" @@ -29,16 +30,19 @@ var ( // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { - mu sync.Mutex // protects all fields except as specified - ipv4 *net.UDPConn - ipv6 *net.UDPConn - ipv4PC *ipv4.PacketConn // will be nil on non-Linux - ipv6PC *ipv6.PacketConn // will be nil on non-Linux - - // these three fields are not guarded by mu - udpAddrPool sync.Pool - ipv4MsgsPool sync.Pool - ipv6MsgsPool sync.Pool + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool blackhole4 bool blackhole6 bool @@ -54,23 +58,14 @@ func NewStdNetBind() Bind { }, }, - ipv4MsgsPool: sync.Pool{ - New: func() any { - msgs := make([]ipv4.Message, IdealBatchSize) - for i := range msgs { - msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) - } - return &msgs - }, - }, - - ipv6MsgsPool: sync.Pool{ + msgsPool: sync.Pool{ New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, srcControlSize) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) } return &msgs }, @@ -81,11 +76,10 @@ func NewStdNetBind() Bind { type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort - // src is the current sticky source address and interface index, if supported. - src struct { - netip.Addr - ifidx int32 - } + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte } var ( @@ -104,21 +98,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { } func (e *StdNetEndpoint) ClearSrc() { - e.src.ifidx = 0 - e.src.Addr = netip.Addr{} + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } -func (e *StdNetEndpoint) SrcIP() netip.Addr { - return e.src.Addr -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - return e.src.ifidx -} +// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
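The new []byte src field holds the PKTINFO control message verbatim, in the layout the kernel emits it: a cmsghdr followed by the pktinfo payload, padded out to unix.CmsgSpace. A minimal Linux-only sketch of that layout, mirroring the setSrc helper added to the sticky tests further down; buildPktinfo and the addresses are illustrative, not part of the patch:

//go:build linux

package main

import (
	"fmt"
	"net/netip"
	"unsafe"

	"golang.org/x/sys/unix"
)

// buildPktinfo packs an IP_PKTINFO cmsg the way StdNetEndpoint.src stores it:
// header first, then the Inet4Pktinfo payload at the CmsgLen(0) offset.
func buildPktinfo(addr netip.Addr, ifidx int32) []byte {
	buf := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&buf[0]))
	hdr.Level = unix.IPPROTO_IP
	hdr.Type = unix.IP_PKTINFO
	hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
	info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&buf[unix.CmsgLen(0)]))
	info.Ifindex = ifidx
	info.Spec_dst = addr.As4()
	return buf
}

func main() {
	src := buildPktinfo(netip.MustParseAddr("192.0.2.1"), 5)
	// Reading it back is the inverse cast, as in the sticky_linux.go SrcIP below.
	info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&src[unix.CmsgLen(0)]))
	fmt.Println(netip.AddrFrom4(info.Spec_dst), info.Ifindex) // 192.0.2.1 5
}

Keeping the raw cmsg around means ClearSrc is a truncation and setSrcControl becomes a single append, rather than re-encoding decoded fields on every send.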
func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() @@ -129,10 +119,6 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func (e *StdNetEndpoint) SrcToString() string { - return e.src.Addr.String() -} - func listenNet(network string, port int) (*net.UDPConn, int, error) { conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { @@ -188,19 +174,21 @@ again: } var fns []ReceiveFunc if v4conn != nil { - if runtime.GOOS == "linux" { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } - fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) s.ipv4 = v4conn } if v6conn != nil { - if runtime.GOOS == "linux" { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } - fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) s.ipv6 = v6conn } if len(fns) == 0 { @@ -210,76 +198,101 @@ again: return fns, uint16(port), nil } -func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - defer s.ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. 
+ _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue } - return numMsgs, nil + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep } + return numMsgs, nil } -func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - defer s.ipv6MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - sizes[i] = msg.N - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - getSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) } } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. 
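When rxOffload is set, receiveIP reads into only the tail of the message slice: each coalesced read of msg.N bytes can fan out into up to udpSegmentMaxDatagrams packets, so IdealBatchSize / udpSegmentMaxDatagrams read slots are enough to fill the whole batch once splitCoalescedMessages runs. A tiny standalone sketch of that arithmetic; the constants and sizes are stand-ins for the ones in this patch:

package main

import "fmt"

// Stand-ins for IdealBatchSize and udpSegmentMaxDatagrams in this patch.
const (
	idealBatchSize         = 128
	udpSegmentMaxDatagrams = 64
)

func main() {
	// GRO path: ReadBatch only gets the slots at the tail of the slice.
	readAt := idealBatchSize - idealBatchSize/udpSegmentMaxDatagrams
	fmt.Println("slots read per batch:", idealBatchSize-readAt) // 2

	// One coalesced datagram: three full segments plus a short tail.
	msgN, gsoSize := 4800, 1400
	numToSplit := (msgN + gsoSize - 1) / gsoSize // ceil division, as in splitCoalescedMessages
	fmt.Println("packets after split:", numToSplit) // 4
}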
func (s *StdNetBind) BatchSize() int { - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { return IdealBatchSize } return 1 @@ -302,28 +315,42 @@ func (s *StdNetBind) Close() error { } s.blackhole4 = false s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false if err1 != nil { return err1 } return err2 } +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 - var ( - pc4 *ipv4.PacketConn - pc6 *ipv6.PacketConn - ) + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 - pc6 = s.ipv6PC + br = s.ipv6PC is6 = true - } else { - pc4 = s.ipv4PC + offload = s.ipv6TxOffload } s.mu.Unlock() @@ -333,85 +360,185 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { if conn == nil { return syscall.EAFNOSUPPORT } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) if is6 { - return s.send6(conn, pc6, endpoint, bufs) + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + } else { + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } } else { - return s.send4(conn, pc4, endpoint, bufs) + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} } + return err } -func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as4 := ep.DstIP().As4() - copy(ua.IP, as4[:]) - ua.IP = ua.IP[:4] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { var ( n int err error start int ) - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { break } start += n } } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) if err != nil { break } 
} } - s.udpAddrPool.Put(ua) - s.ipv4MsgsPool.Put(msgs) return err } -func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { - ua := s.udpAddrPool.Get().(*net.UDPAddr) - as16 := ep.DstIP().As16() - copy(ua.IP, as16[:]) - ua.IP = ua.IP[:16] - ua.Port = int(ep.(*StdNetEndpoint).Port()) - msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) - for i, buf := range bufs { - (*msgs)[i].Buffers[0] = buf - (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) - } +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { var ( - n int - err error - start int + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs ) - if runtime.GOOS == "linux" { - for { - n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) - if err != nil || n == len((*msgs)[start:len(bufs)]) { - break + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue } - start += n } - } else { - for i, buf := range bufs { - _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) - if err != nil { - break + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. 
+ endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 } } - s.udpAddrPool.Put(ua) - s.ipv6MsgsPool.Put(msgs) - return err + return n, nil } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 1e4677654..34a3c9acf 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -1,6 +1,12 @@ package conn -import "testing" +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { bind := NewStdNetBind().(*StdNetBind) @@ -20,3 +26,225 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { fn(bufs, sizes, eps) } } + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := 
len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 228167e39..d5095e004 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go index a2396fe89..f6ab1d2ec 100644 --- 
a/conn/controlfns_linux.go +++ b/conn/controlfns_linux.go @@ -57,5 +57,13 @@ func init() { } return err }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, ) } diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go index c4536d4bb..91692c0a6 100644 --- a/conn/controlfns_unix.go +++ b/conn/controlfns_unix.go @@ -1,4 +1,4 @@ -//go:build !windows && !linux && !js +//go:build !windows && !linux && !wasm /* SPDX-License-Identifier: MIT * diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 000000000..f1e5b90e5 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 000000000..8e61000f8 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 000000000..d53ff5f7b --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 000000000..8959d9358 --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + txOffload = errSyscall == nil + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + rxOffload = errSyscall == nil && opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/conn/gso_default.go b/conn/gso_default.go new file mode 100644 index 000000000..57780dbb5 --- /dev/null +++ b/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/conn/gso_linux.go b/conn/gso_linux.go new file mode 100644 index 000000000..8596b292e --- /dev/null +++ b/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. +var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/conn/sticky_default.go b/conn/sticky_default.go index 05f00ea5b..0b213867d 100644 --- a/conn/sticky_default.go +++ b/conn/sticky_default.go @@ -7,8 +7,23 @@ package conn -// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but -// use alternatively named flags and need ports and require testing. +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. @@ -20,8 +35,8 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { func setSrcControl(control *[]byte, ep *StdNetEndpoint) { } -// srcControlSize returns the recommended buffer size for pooling sticky control -// data. -const srcControlSize = 0 +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. 
+const stickyControlSize = 0 const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go index 274fa38a1..8e206e90b 100644 --- a/conn/sticky_linux.go +++ b/conn/sticky_linux.go @@ -14,6 +14,37 @@ import ( "golang.org/x/sys/unix" ) +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { @@ -35,83 +66,47 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { if hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO { - info := pktInfoFromBuf[unix.Inet4Pktinfo](data) - ep.src.Addr = netip.AddrFrom4(info.Spec_dst) - ep.src.ifidx = info.Ifindex + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } if hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO { - info := pktInfoFromBuf[unix.Inet6Pktinfo](data) - ep.src.Addr = netip.AddrFrom16(info.Addr) - ep.src.ifidx = int32(info.Ifindex) + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } } } -// pktInfoFromBuf returns type T populated from the provided buf via copy(). It -// panics if buf is of insufficient size. -func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { - size := int(unsafe.Sizeof(t)) - if len(buf) < size { - panic("pktInfoFromBuf: buffer too small") - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) - return t -} - // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // and source ifindex found in ep. control's len will be set to 0 in the event // that ep is a default value. 
func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - *control = (*control)[:cap(*control)] - if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { - *control = (*control)[:0] + if cap(*control) < len(ep.src) { return } - - if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { - *control = (*control)[:0] - return - } - - if len(*control) < srcControlSize { - *control = (*control)[:0] - return - } - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) - if ep.SrcIP().Is4() { - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = ep.src.ifidx - if ep.SrcIP().IsValid() { - info.Spec_dst = ep.SrcIP().As4() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - } else { - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = uint32(ep.src.ifidx) - if ep.SrcIP().IsValid() { - info.Addr = ep.SrcIP().As16() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - } - + *control = (*control)[:0] + *control = append(*control, ep.src...) } -var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 0219ac300..d2bd58436 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -18,15 +18,49 @@ import ( "golang.org/x/sys/unix" ) +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + func Test_setSrcControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), } - ep.src.Addr = netip.MustParseAddr("127.0.0.1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) setSrcControl(&control, ep) @@ -53,10 +87,9 @@ func Test_setSrcControl(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("[::1]:1234"), } - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) 
setSrcControl(&control, ep) @@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) { }) t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 @@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) { func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(control, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("IPv6", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, stickyControlSize) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO @@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) { if ep.SrcIP() != netip.MustParseAddr("::1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("ClearOnEmpty", func(t *testing.T) { - control := make([]byte, srcControlSize) + var control []byte ep := &StdNetEndpoint{} - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) getSrcFromControl(control, ep) if ep.SrcIP().IsValid() { - t.Errorf("unexpected address: %v", ep.src.Addr) + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 0 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("Multiple", func(t *testing.T) { @@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) { zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) zeroHdr.SetLen(unix.CmsgLen(0)) - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(combined, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) } diff --git a/device/channels.go b/device/channels.go index 039d8dfd0..e526f6bb1 100644 --- a/device/channels.go +++ b/device/channels.go @@ -19,13 +19,13 @@ import ( // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. 
type outboundQueue struct { - c chan *QueueOutboundElement + c chan *QueueOutboundElementsContainer wg sync.WaitGroup } func newOutboundQueue() *outboundQueue { q := &outboundQueue{ - c: make(chan *QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } q.wg.Add(1) go func() { @@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue { // A inboundQueue is similar to an outboundQueue; see those docs. type inboundQueue struct { - c chan *QueueInboundElement + c chan *QueueInboundElementsContainer wg sync.WaitGroup } func newInboundQueue() *inboundQueue { q := &inboundQueue{ - c: make(chan *QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } q.wg.Add(1) go func() { @@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue { } type autodrainingInboundQueue struct { - c chan *[]*QueueInboundElement + c chan *QueueInboundElementsContainer } // newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. @@ -81,7 +81,7 @@ type autodrainingInboundQueue struct { // some other means, such as sending a sentinel nil values. func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { q := &autodrainingInboundQueue{ - c: make(chan *[]*QueueInboundElement, QueueInboundSize), + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), } runtime.SetFinalizer(q, device.flushInboundQueue) return q @@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) default: return } @@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { } type autodrainingOutboundQueue struct { - c chan *[]*QueueOutboundElement + c chan *QueueOutboundElementsContainer } // newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. @@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct { // All sends to the channel must be best-effort, because there may be no receivers. 
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { q := &autodrainingOutboundQueue{ - c: make(chan *[]*QueueOutboundElement, QueueOutboundSize), + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), } runtime.SetFinalizer(q, device.flushOutboundQueue) return q @@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { for { select { - case elems := <-q.c: - for _, elem := range *elems { - elem.Lock() + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) default: return } diff --git a/device/device.go b/device/device.go index 091c8d40a..f9557a075 100644 --- a/device/device.go +++ b/device/device.go @@ -68,11 +68,11 @@ type Device struct { cookieChecker CookieChecker pool struct { - outboundElementsSlice *WaitPool - inboundElementsSlice *WaitPool - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { @@ -368,6 +368,8 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() device.state.Lock() defer device.state.Unlock() if device.isClosed() { diff --git a/device/peer.go b/device/peer.go index 0ac48962c..2fb5da62a 100644 --- a/device/peer.go +++ b/device/peer.go @@ -45,9 +45,9 @@ type Peer struct { } queue struct { - staged chan *[]*QueueOutboundElement // staged packets before a handshake is available - outbound *autodrainingOutboundQueue // sequential ordering of udp transmission - inbound *autodrainingInboundQueue // sequential ordering of tun writing + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device peer.queue.outbound = newAutodrainingOutboundQueue(device) peer.queue.inbound = newAutodrainingInboundQueue(device) - peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key _, ok := device.peers.keyMap[pk] diff --git a/device/pools.go b/device/pools.go index 02a5d6acb..94f3dc7e6 100644 --- a/device/pools.go +++ b/device/pools.go @@ -46,13 +46,13 @@ func (p *WaitPool) Put(x any) { } func (device *Device) PopulatePools() { - device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { - s := make([]*QueueOutboundElement, 0, device.BatchSize()) - return &s - }) - device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) - return &s + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := 
make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} }) device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) @@ -65,28 +65,32 @@ func (device *Device) PopulatePools() { }) } -func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement { - return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement) +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.outboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) } -func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement { - return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement) +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) { - for i := range *s { - (*s)[i] = nil +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } - *s = (*s)[:0] - device.pool.inboundElementsSlice.Put(s) + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) } func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index 3d80eadb0..25f700a90 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -14,6 +14,6 @@ const ( QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 ) diff --git a/device/receive.go b/device/receive.go index e24d29f5b..4b32dc587 100644 --- a/device/receive.go +++ b/device/receive.go @@ -27,7 +27,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -35,6 +34,11 @@ type QueueInboundElement struct { endpoint conn.Endpoint } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement +} + // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. 
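The elements containers introduced here replace the old per-element mutex with one lock per batch: RoutineReceiveIncoming (and its TUN-side counterpart) publishes the container locked, a crypto worker unlocks it after processing every element, and the sequential routine blocks on Lock() until the whole batch is ready. A minimal standalone sketch of that handoff, with illustrative types in place of the Queue*ElementsContainer machinery:

package main

import (
	"fmt"
	"sync"
)

// batch stands in for QueueInbound/OutboundElementsContainer.
type batch struct {
	sync.Mutex
	elems []int
}

func main() {
	work := make(chan *batch, 1)       // fan-out to a crypto worker
	sequential := make(chan *batch, 1) // in-order consumer

	b := &batch{elems: []int{1, 2, 3}}
	b.Lock() // published locked: the consumer must wait for the worker
	sequential <- b
	work <- b

	go func() {
		w := <-work
		for i := range w.elems {
			w.elems[i] *= 10 // stand-in for seal/open of each element
		}
		w.Unlock() // release the whole batch at once
	}()

	s := <-sequential
	s.Lock() // blocks until the worker is done with the batch
	fmt.Println(s.elems) // [10 20 30]
}

One channel send per batch instead of one per element also cuts the queue traffic that the old design paid for every packet.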
@@ -87,7 +91,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int - elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) for i := range bufsArrs { @@ -170,15 +174,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetInboundElementsSlice() + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) bufsArrs[i] = device.GetMessageBuffer() bufs[i] = bufsArrs[i][:] continue @@ -217,18 +220,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive default: } } - for peer, elems := range elemsByPeer { + for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { - peer.queue.inbound.c <- elems - for _, elem := range *elems { - device.queue.decryption.c <- elem - } + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } delete(elemsByPeer, peer) } @@ -241,26 +242,28 @@ func (device *Device) RoutineDecryption(id int) { defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) device.log.Verbosef("Routine: decryption worker %d - started", id) - for elem := range device.queue.decryption.c { - // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // decrypt and release to consumer - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - // copy counter to nonce - binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.packet = nil + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { + // split message into fields + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // decrypt and release to consumer + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.packet = nil + } } - elem.Unlock() + elemsContainer.Unlock() } } @@ -437,12 +440,12 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.inbound.c { - if elems == nil { + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return } - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if elem.packet == nil { // decryption failed continue @@ -515,11 +518,11 @@ 
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutInboundElement(elem) } bufs = bufs[:0] - device.PutInboundElementsSlice(elems) + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index d22bf264e..769720af8 100644 --- a/device/send.go +++ b/device/send.go @@ -17,6 +17,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun" ) @@ -45,7 +46,6 @@ import ( */ type QueueOutboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -53,10 +53,14 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -78,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() - elems := peer.device.GetOutboundElementsSlice() - *elems = append(*elems, elem) + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elems: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -218,7 +222,7 @@ func (device *Device) RoutineReadFromTUN() { readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, batchSize) - elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageTransportHeaderSize @@ -275,10 +279,10 @@ func (device *Device) RoutineReadFromTUN() { } elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetOutboundElementsSlice() + elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } @@ -288,11 +292,11 @@ func (device *Device) RoutineReadFromTUN() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { - for _, elem := range *elemsForPeer { + for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elemsForPeer) + device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } @@ -316,7 +320,7 @@ func (device *Device) RoutineReadFromTUN() { } } -func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { +func (peer *Peer) StagePackets(elems 
*QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: @@ -325,11 +329,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { } select { case tooOld := <-peer.queue.staged: - for _, elem := range *tooOld { + for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(tooOld) + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -348,54 +352,52 @@ top: } for { - var elemsOOO *[]*QueueOutboundElement + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elems := <-peer.queue.staged: + case elemsContainer := <-peer.queue.staged: i := 0 - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { elem.peer = peer elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { keypair.sendNonce.Store(RejectAfterMessages) - if elemsOOO == nil { - elemsOOO = peer.device.GetOutboundElementsSlice() + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - *elemsOOO = append(*elemsOOO, elem) + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { - (*elems)[i] = elem + elemsContainer.elems[i] = elem i++ } elem.keypair = keypair - elem.Lock() } - *elems = (*elems)[:i] + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - if elemsOOO != nil { - peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } - if len(*elems) == 0 { - peer.device.PutOutboundElementsSlice(elems) + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elems - for _, elem := range *elems { - peer.device.queue.encryption.c <- elem - } + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } - if elemsOOO != nil { + if elemsContainerOOO != nil { goto top } default: @@ -407,12 +409,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elems := <-peer.queue.staged: - for _, elem := range *elems { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -446,32 +448,34 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { + // populate header fields + header := elem.buffer[:MessageTransportHeaderSize] - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] + fieldType := header[0:4] + 
fieldReceiver := header[4:8] + fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) - // encrypt content and release to consumer + // encrypt content and release to consumer - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + } + elemsContainer.Unlock() } } @@ -485,9 +489,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.outbound.c { + for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] - if elems == nil { + if elemsContainer == nil { return } if !peer.isRunning.Load() { @@ -497,16 +501,16 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. 
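The send.go hunks above replace the per-element sync.Mutex with a single mutex on QueueOutboundElementsContainer, taken once per batch: SendStagedPackets locks the container before publishing it to both the per-peer outbound queue and the shared encryption queue, RoutineEncryption unlocks it after sealing every element, and the sequential sender's Lock() then acts as a barrier that waits for the whole batch. A minimal, self-contained sketch of that hand-off (toy types; GetOutboundElementsContainer and PutOutboundElementsContainer are assumed to be sync.Pool-backed and are omitted):

package main

import (
	"fmt"
	"sync"
)

type elem struct{ packet []byte }

type container struct {
	sync.Mutex
	elems []*elem
}

func main() {
	encryption := make(chan *container, 1)
	outbound := make(chan *container, 1)

	c := &container{elems: []*elem{{packet: []byte("p1")}, {packet: []byte("p2")}}}
	c.Lock() // one lock for the whole batch, taken before hand-off
	outbound <- c
	encryption <- c

	done := make(chan struct{})
	go func() { // stands in for RoutineEncryption
		b := <-encryption
		for _, e := range b.elems {
			e.packet = append(e.packet, "-sealed"...) // stand-in for AEAD Seal
		}
		b.Unlock() // a single Unlock releases the entire batch
		close(done)
	}()

	b := <-outbound // stands in for RoutineSequentialSender
	b.Lock()        // blocks until the encryption worker has sealed the batch
	for _, e := range b.elems {
		fmt.Printf("send %s\n", e.packet)
	}
	b.Unlock()
	<-done
}

One lock/unlock pair per batch, and one channel send into the encryption queue per batch, replace one of each per packet, which is the point of the container refactor.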
- for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } continue } dataSent := false - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true } @@ -520,11 +524,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue diff --git a/go.mod b/go.mod index c04e1bb61..919dc4927 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module golang.zx2c4.com/wireguard go 1.20 require ( - golang.org/x/crypto v0.6.0 - golang.org/x/net v0.7.0 - golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 + golang.org/x/crypto v0.13.0 + golang.org/x/net v0.15.0 + golang.org/x/sys v0.12.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 - gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 ) require ( github.com/google/btree v1.0.1 // indirect - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) diff --git a/go.sum b/go.sum index cfeaee623..6bcecea3f 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,14 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= -golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89 h1:260HNjMTPDya+jq5AM1zZLgG9pv9GASPAGiEEJUbRg4= -golang.org/x/sys v0.5.1-0.20230222185716-a3b23cc77e89/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= -golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun 
v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 h1:Wobr37noukisGxpKo5jAsLREcpj61RxrWYzD8uwveOY= -gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0/go.mod h1:Dn5idtptoW1dIos9U6A2rpebLs/MtTwFacjKb8jLdQA= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/ipc/uapi_js.go b/ipc/uapi_wasm.go similarity index 80% rename from ipc/uapi_js.go rename to ipc/uapi_wasm.go index 2570515e2..fa84684aa 100644 --- a/ipc/uapi_js.go +++ b/ipc/uapi_wasm.go @@ -5,7 +5,7 @@ package ipc -// Made up sentinel error codes for the js/wasm platform. +// Made up sentinel error codes for {js,wasip1}/wasm. const ( IpcErrorIO = 1 IpcErrorInvalid = 2 diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index 63e1510b1..e397c0e8a 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -1,4 +1,4 @@ -//go:build !windows && !js +//go:build !windows && !wasm /* SPDX-License-Identifier: MIT * diff --git a/rwcancel/rwcancel_stub.go b/rwcancel/rwcancel_stub.go index 182940b32..2a98b2b4a 100644 --- a/rwcancel/rwcancel_stub.go +++ b/rwcancel/rwcancel_stub.go @@ -1,4 +1,4 @@ -//go:build windows || js +//go:build windows || wasm // SPDX-License-Identifier: MIT diff --git a/tun/checksum.go b/tun/checksum.go index f4f847164..29a8fc8fc 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -3,23 +3,99 @@ package tun import "encoding/binary" // TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. func checksumNoFold(b []byte, initial uint64) uint64 { ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + 
if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] } - if n == 1 { - ac += uint64(b[i]) << 8 + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + return ac } diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 000000000..c1ccff531 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index fa15f5361..2b73054b2 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -25,7 +25,7 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" - "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -43,7 +43,7 @@ type netTun struct { ep *channel.Endpoint stack *stack.Stack events chan tun.Event - incomingPacket chan *bufferv2.View + incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool @@ -61,7 +61,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, ep: channel.New(1024, uint32(mtu), ""), stack: stack.New(opts), events: make(chan tun.Event, 10), - incomingPacket: make(chan *bufferv2.View), + incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } @@ 
-84,7 +84,7 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, - AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { @@ -140,7 +140,7 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { continue } - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) switch packet[0] >> 4 { case 4: tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) @@ -198,7 +198,7 @@ func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.Networ } return tcpip.FullAddress{ NIC: 1, - Addr: tcpip.Address(endpoint.Addr().AsSlice()), + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } @@ -453,7 +453,7 @@ func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } - remoteAddr, _ := netip.AddrFromSlice([]byte(res.RemoteAddr.Addr)) + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } @@ -912,7 +912,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, } } } - // We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) diff --git a/tun/tcp_offload_linux.go b/tun/tcp_offload_linux.go index 39a7180c5..1afd27edf 100644 --- a/tun/tcp_offload_linux.go +++ b/tun/tcp_offload_linux.go @@ -269,11 +269,11 @@ func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool { type coalesceResult int const ( - coalesceInsufficientCap coalesceResult = 0 - coalescePSHEnding coalesceResult = 1 - coalesceItemInvalidCSum coalesceResult = 2 - coalescePktInvalidCSum coalesceResult = 3 - coalesceSuccess coalesceResult = 4 + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess ) // coalesceTCPPackets attempts to coalesce pkt with the packet described by @@ -339,42 +339,6 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize if gsoSize > item.gsoSize { item.gsoSize = gsoSize } - hdr := virtioNetHdr{ - flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(headersLen), - gsoSize: uint16(item.gsoSize), - csumStart: uint16(item.iphLen), - csumOffset: 16, - } - - // Recalculate the total len (IPv4) or payload len (IPv6). Recalculate the - // (IPv4) header checksum. 
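The header rewriting being removed from coalesceTCPPackets in this hunk is relocated rather than deleted: it reappears further down as applyCoalesceAccounting, which runs once after all GRO evaluation instead of on every successful merge, and tcpGRO's boolean result becomes the three-valued tcpGROResult so that handleGRO can distinguish "write now with an empty virtio-net header" from "write later, after accounting". A self-contained toy of that two-pass shape (all types here are invented for illustration, not taken from the patch):

package main

import "fmt"

// Toy two-pass coalescing: pass 1 merges payloads per flow key and only
// counts merges; pass 2 performs the per-packet fix-up once per surviving
// packet instead of once per merge.
type pkt struct {
	flow    string
	payload []byte
	merged  int
}

func main() {
	in := []pkt{{"a", []byte("12"), 0}, {"a", []byte("34"), 0}, {"b", []byte("56"), 0}}

	// Pass 1: coalesce. Indices of surviving packets play the role of toWrite.
	index := map[string]int{}
	var out []pkt
	for _, p := range in {
		if i, ok := index[p.flow]; ok {
			out[i].payload = append(out[i].payload, p.payload...)
			out[i].merged++ // "coalesced": this packet is not written out
			continue
		}
		index[p.flow] = len(out) // "table insert": written after accounting
		out = append(out, p)
	}

	// Pass 2: accounting, once per flow head rather than once per merge.
	for _, p := range out {
		fmt.Printf("flow %s: merged=%d, rewrite header for len=%d\n", p.flow, p.merged, len(p.payload))
	}
}

Deferring the fix-up means lengths, IPv4 header checksums, and virtio-net headers are computed once per coalesced flow instead of redundantly after every merged segment.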
- if isV6 { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pktHead[4:], uint16(coalescedLen)-uint16(item.iphLen)) // set new payload len - } else { - hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 - pktHead[10], pktHead[11] = 0, 0 // clear checksum field - binary.BigEndian.PutUint16(pktHead[2:], uint16(coalescedLen)) // set new total length - iphCSum := ^checksum(pktHead[:item.iphLen], 0) // compute checksum - binary.BigEndian.PutUint16(pktHead[10:], iphCSum) // set checksum field - } - hdr.encode(bufs[item.bufsIndex][bufsOffset-virtioNetHdrLen:]) - - // Calculate the pseudo header checksum and place it at the TCP checksum - // offset. Downstream checksum offloading will combine this with computation - // of the tcp header and payload checksum. - addrLen := 4 - addrOffset := ipv4SrcAddrOffset - if isV6 { - addrLen = 16 - addrOffset = ipv6SrcAddrOffset - } - srcAddrAt := bufsOffset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(coalescedLen-int(item.iphLen))) - binary.BigEndian.PutUint16(pktHead[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) item.numMerged++ return coalesceSuccess @@ -390,43 +354,52 @@ const ( maxUint16 = 1<<16 - 1 ) +type tcpGROResult int + +const ( + tcpGROResultNoop tcpGROResult = iota + tcpGROResultTableInsert + tcpGROResultCoalesced +) + // tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with -// existing packets tracked in table. It will return false when pktI is not -// coalesced, otherwise true. This indicates to the caller if bufs[pktI] -// should be written to the Device. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) (pktCoalesced bool) { +// existing packets tracked in table. It returns a tcpGROResultNoop when no +// action was taken, tcpGROResultTableInsert when the evaluated packet was +// inserted into table, and tcpGROResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult { pkt := bufs[pktI][offset:] if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. 
- return false + return tcpGROResultNoop } iphLen := int((pkt[0] & 0x0F) * 4) if isV6 { iphLen = 40 ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) if ipv6HPayloadLen != len(pkt)-iphLen { - return false + return tcpGROResultNoop } } else { totalLen := int(binary.BigEndian.Uint16(pkt[2:])) if totalLen != len(pkt) { - return false + return tcpGROResultNoop } } if len(pkt) < iphLen { - return false + return tcpGROResultNoop } tcphLen := int((pkt[iphLen+12] >> 4) * 4) if tcphLen < 20 || tcphLen > 60 { - return false + return tcpGROResultNoop } if len(pkt) < iphLen+tcphLen { - return false + return tcpGROResultNoop } if !isV6 { if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { // no GRO support for fragmented segments for now - return false + return tcpGROResultNoop } } tcpFlags := pkt[iphLen+tcpFlagsOffset] @@ -434,14 +407,14 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // not a candidate if any non-ACK flags (except PSH+ACK) are set if tcpFlags != tcpFlagACK { if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { - return false + return tcpGROResultNoop } pshSet = true } gsoSize := uint16(len(pkt) - tcphLen - iphLen) // not a candidate if payload len is 0 if gsoSize < 1 { - return false + return tcpGROResultNoop } seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) srcAddrOffset := ipv4SrcAddrOffset @@ -452,7 +425,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) if !existing { - return false + return tcpGROResultNoop } for i := len(items) - 1; i >= 0; i-- { // In the best case of packets arriving in order iterating in reverse is @@ -470,20 +443,20 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) switch result { case coalesceSuccess: table.updateAt(item, i) - return true + return tcpGROResultCoalesced case coalesceItemInvalidCSum: // delete the item with an invalid csum table.deleteAt(item.key, i) case coalescePktInvalidCSum: // no point in inserting an item that we can't coalesce - return false + return tcpGROResultNoop default: } } } // failed to coalesce with any other packets; store the item in the flow table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) - return false + return tcpGROResultTableInsert } func isTCP4NoIPOptions(b []byte) bool { @@ -515,6 +488,64 @@ func isTCP6NoEH(b []byte) bool { return true } +// applyCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. 
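Both checksum call sites in this added block rely on RFC 1071 one's-complement arithmetic: ^checksum(pkt[:item.iphLen], 0) regenerates the IPv4 header checksum, and checksum([]byte{}, psum) folds the pseudo-header sum so downstream offload can finish the TCP checksum, while checksumNoFold in tun/checksum.go above only accumulates big-endian 32-bit words into a uint64. The fold step itself is not part of this diff; a sketch of what it presumably does, checked against the worked example in RFC 1071 section 3:

package main

import "fmt"

// foldChecksum collapses a 64-bit one's-complement accumulator, as produced
// by checksumNoFold, into 16 bits. This is a sketch; the patch's actual
// checksum helper is not shown in this diff.
func foldChecksum(ac uint64) uint16 {
	for ac>>16 != 0 {
		ac = (ac >> 16) + (ac & 0xFFFF)
	}
	return uint16(ac)
}

func main() {
	// RFC 1071 section 3 example: words 0x0001 0xF203 0xF4F5 0xF6F7.
	ac := uint64(0x0001) + 0xF203 + 0xF4F5 + 0xF6F7 // 0x2DDF0 before folding
	fmt.Printf("folded: %#04x checksum: %#04x\n", foldChecksum(ac), ^foldChecksum(ac))
	// folded: 0xddf2 checksum: 0x220d
}

The unrolled accumulator this patch adds can be measured with the new BenchmarkChecksum above, e.g. go test -bench=Checksum ./tun.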
+ if isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + // handleGRO evaluates bufs for GRO, and writes the indices of the resulting // packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset @@ -524,23 +555,28 @@ func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toW if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } - var coalesced bool + var result tcpGROResult switch { case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp4Table, false) + result = tcpGRO(bufs, offset, i, tcp4Table, false) case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce - coalesced = tcpGRO(bufs, offset, i, tcp6Table, true) + result = tcpGRO(bufs, offset, i, tcp6Table, true) } - if !coalesced { + switch result { + case tcpGROResultNoop: hdr := virtioNetHdr{} err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) if err != nil { return err } + fallthrough + case tcpGROResultTableInsert: *toWrite = append(*toWrite, i) } } - return nil + err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false) + err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true) + return errors.Join(err4, err6) } // tcpTSO splits packets from in into outBuffs, writing the size of each diff --git a/tun/tcp_offload_linux_test.go b/tun/tcp_offload_linux_test.go index 9160e18cd..ddddc4868 100644 --- a/tun/tcp_offload_linux_test.go +++ b/tun/tcp_offload_linux_test.go @@ -35,8 +35,8 @@ func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header. srcAs4 := srcIPPort.Addr().As4() dstAs4 := dstIPPort.Addr().As4() ipFields := &header.IPv4Fields{ - SrcAddr: tcpip.Address(srcAs4[:]), - DstAddr: tcpip.Address(dstAs4[:]), + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), Protocol: unix.IPPROTO_TCP, TTL: 64, TotalLength: uint16(totalLen), @@ -72,8 +72,8 @@ func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header. 
srcAs16 := srcIPPort.Addr().As16() dstAs16 := dstIPPort.Addr().As16() ipFields := &header.IPv6Fields{ - SrcAddr: tcpip.Address(srcAs16[:]), - DstAddr: tcpip.Address(dstAs16[:]), + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), TransportProtocol: unix.IPPROTO_TCP, HopLimit: 64, PayloadLength: uint16(segmentSize + 20), diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 0cb4ce192..34f29805d 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -127,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update {
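The final hunk makes ForceMTU a no-op once the Windows device has been closed, guarded by what the tun.close.Load() call indicates is an atomic flag. A minimal sketch of that use-after-close guard (toy type; close is assumed to be an atomic.Bool, and the real NativeTun carries more state):

package main

import (
	"fmt"
	"sync/atomic"
)

type dev struct {
	close     atomic.Bool
	forcedMTU int
}

func (d *dev) Close() { d.close.Store(true) }

// ForceMTU mirrors the guard added above: after Close, MTU updates are
// ignored instead of touching resources that may already be released.
func (d *dev) ForceMTU(mtu int) {
	if d.close.Load() {
		return
	}
	d.forcedMTU = mtu
}

func main() {
	d := &dev{}
	d.ForceMTU(1420)
	d.Close()
	d.ForceMTU(9000)         // dropped by the guard
	fmt.Println(d.forcedMTU) // prints 1420
}

Without the early return, the MTU-update path that follows it in tun_windows.go could run against a session that Close has already torn down.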