Skip to content

Commit

Permalink
Copy UDP GSO support from tailscale
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 22, 2024
1 parent 55a70eb commit 8355396
Show file tree
Hide file tree
Showing 11 changed files with 1,312 additions and 819 deletions.
2 changes: 1 addition & 1 deletion stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
_, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
Expand Down
6 changes: 4 additions & 2 deletions tun.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package tun

import (
"github.com/sagernet/sing/common/control"
"io"
"net"
"net/netip"
"runtime"
"strconv"
"strings"

"github.com/sagernet/sing/common/control"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
Expand Down Expand Up @@ -39,7 +39,9 @@ type LinuxTUN interface {
N.FrontHeadroom
BatchSize() int
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
BatchWrite(buffers [][]byte, offset int) (n int, err error)
DisableUDPGRO()
DisableTCPGRO()
TXChecksumOffload() bool
}

Expand Down
2 changes: 1 addition & 1 deletion tun_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package tun
import (
"errors"
"fmt"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"net"
"net/netip"
"os"
"syscall"
"unsafe"

"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
Expand Down
154 changes: 118 additions & 36 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tun

import (
"errors"
"fmt"
"math/rand"
"net"
"net/netip"
Expand Down Expand Up @@ -35,13 +36,15 @@ type NativeTun struct {
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
options Options
ruleIndex6 []int
gsoEnabled bool
gsoBuffer []byte
readAccess sync.Mutex
writeAccess sync.Mutex
vnetHdr bool
writeBuffer []byte
gsoToWrite []int
gsoReadAccess sync.Mutex
tcpGROAccess sync.Mutex
tcp4GROTable *tcpGROTable
tcp6GROTable *tcpGROTable
tcpGROTable *tcpGROTable
udpGroAccess sync.Mutex
udpGROTable *udpGROTable
gro groDisablementFlags
txChecksumOffload bool
}

Expand Down Expand Up @@ -81,20 +84,23 @@ func New(options Options) (Tun, error) {
}

// FrontHeadroom reports how many bytes must be left free in front of each
// packet: virtioNetHdrLen when the virtio-net header flag is set, 0 otherwise.
// The flag is read from t.gsoEnabled on the removed line and from t.vnetHdr
// on the line that replaces it.
func (t *NativeTun) FrontHeadroom() int {
if t.gsoEnabled {
if t.vnetHdr {
return virtioNetHdrLen
}
return 0
}

func (t *NativeTun) Read(p []byte) (n int, err error) {
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if t.vnetHdr {
n, err = t.tunFile.Read(t.writeBuffer)
if err != nil {
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
}
return
}
var sizes [1]int
n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0)
n, err = handleVirtioRead(t.writeBuffer[:n], [][]byte{p}, sizes[:], 0)
if err != nil {
return
}
Expand All @@ -108,9 +114,50 @@ func (t *NativeTun) Read(p []byte) (n int, err error) {
}
}

// handleVirtioRead decodes the virtio-net header at the front of in, strips
// it, and passes the remaining bytes to GSOSplit so they are split into bufs,
// leaving offset bytes at the front of each buffer. It mutates sizes to
// reflect the size of each element of bufs, and returns the number of packets
// read.
//
// An error is returned when the header cannot be decoded or converted to GSO
// options, when a TCP GSO packet is too short to contain its data-offset
// byte, or when that byte encodes a header length outside 20..60 bytes.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
	var hdr virtioNetHdr
	err := hdr.decode(in)
	if err != nil {
		return 0, err
	}
	in = in[virtioNetHdrLen:]

	options, err := hdr.toGSOOptions()
	if err != nil {
		return 0, err
	}

	// Don't trust HdrLen from the kernel as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the transport header length and add it onto
	// CsumStart, which is synonymous for IP header length.
	switch options.GSOType {
	case GSONone:
		// Not a GSO packet: HdrLen is left exactly as decoded.
	case GSOUDPL4:
		// A UDP header is always 8 bytes.
		options.HdrLen = options.CsumStart + 8
	default:
		// The TCP data-offset nibble lives in byte 12 of the TCP header.
		// Compute its position in int: doing the addition in CsumStart's
		// 16-bit type can wrap for large values, which would make the bounds
		// check below validate (and then read) the wrong byte.
		dataOffsetIndex := int(options.CsumStart) + 12
		if len(in) <= dataOffsetIndex {
			return 0, errors.New("packet is too short")
		}

		// The upper nibble counts 32-bit words; multiply by 4 for bytes.
		tcpHLen := uint16(in[dataOffsetIndex]>>4) * 4
		if tcpHLen < 20 || tcpHLen > 60 {
			// A TCP header must be between 20 and 60 bytes in length.
			return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
		}
		options.HdrLen = options.CsumStart + tcpHLen
	}

	return GSOSplit(in, options, bufs, sizes, offset)
}

func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
if t.vnetHdr {
buffer := buf.Get(virtioNetHdrLen + len(p))
copy(buffer[virtioNetHdrLen:], p)
_, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen)
buf.Put(buffer)
if err != nil {
return
}
Expand All @@ -121,7 +168,7 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
}

func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
if t.gsoEnabled {
if t.vnetHdr {
n := buf.LenMulti(buffers)
buffer := buf.NewSize(virtioNetHdrLen + n)
buffer.Truncate(virtioNetHdrLen)
Expand All @@ -135,7 +182,7 @@ func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
}

func (t *NativeTun) BatchSize() int {
if !t.gsoEnabled {
if !t.vnetHdr {
return 1
}
/* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605
Expand All @@ -147,36 +194,67 @@ func (t *NativeTun) BatchSize() int {
return idealBatchSize
}

// DisableUDPGRO disables UDP GRO if it is enabled by calling
// t.gro.disableUDPGRO(). The call is made while holding writeAccess, the
// mutex BatchWrite holds while it passes t.gro to handleGRO.
func (t *NativeTun) DisableUDPGRO() {
	t.writeAccess.Lock()
	defer t.writeAccess.Unlock()

	t.gro.disableUDPGRO()
}

// DisableTCPGRO disables TCP GRO if it is enabled by calling
// t.gro.disableTCPGRO(). The call is made while holding writeAccess, the
// mutex BatchWrite holds while it passes t.gro to handleGRO.
func (t *NativeTun) DisableTCPGRO() {
	t.writeAccess.Lock()
	defer t.writeAccess.Unlock()

	t.gro.disableTCPGRO()
}

// BatchRead performs one read from the TUN file into the device's scratch
// buffer while holding the read mutex, then passes the bytes that were read to
// handleVirtioRead together with buffers, readN and offset. It returns the
// read error if the read fails, otherwise whatever handleVirtioRead returns.
//
// The mutex and scratch buffer are gsoReadAccess/gsoBuffer on the removed
// lines and readAccess/writeBuffer on the lines that replace them.
//
// NOTE(review): the visible part of Read also reads into t.writeBuffer and
// shows no lock being taken first — confirm Read and BatchRead are never
// called concurrently.
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.gsoReadAccess.Lock()
defer t.gsoReadAccess.Unlock()
n, err = t.tunFile.Read(t.gsoBuffer)
t.readAccess.Lock()
defer t.readAccess.Unlock()
n, err = t.tunFile.Read(t.writeBuffer)
if err != nil {
return
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
return handleVirtioRead(t.writeBuffer[:n], buffers, readN, offset)
}

// BatchWrite writes the packets in buffers to the TUN file; each packet's
// data starts at offset within its buffer.
//
// As changed by this commit (the replacement lines below):
//   - the whole call runs under writeAccess, and the TCP and UDP GRO tables
//     are reset when it returns;
//   - when t.vnetHdr is set, handleGRO selects the buffer indexes to write
//     into t.gsoToWrite and offset is reduced by virtioNetHdrLen; otherwise
//     every buffer index is queued as-is;
//   - a write failing with EBADFD stops the loop and returns os.ErrClosed;
//     any other write error is accumulated with errors.Join and the loop
//     continues with the next buffer;
//   - the first result is the sum of the byte counts of the successful
//     writes, so it can be non-zero even when the returned error is non-nil.
//
// The removed lines returned only an error and stopped at the first failure.
//
// NOTE(review): in the visible configure hunk the GRO tables are allocated
// next to `t.vnetHdr = true`; confirm reset() tolerates a nil table if
// BatchWrite can be reached with vnetHdr false.
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
t.tcpGROAccess.Lock()
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) (int, error) {
t.writeAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
t.tcpGROTable.reset()
t.udpGROTable.reset()
t.writeAccess.Unlock()
}()
var (
errs error
total int
)
// Reuse the index slice's backing array across calls.
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
if t.vnetHdr {
err := handleGRO(buffers, offset, t.tcpGROTable, t.udpGROTable, t.gro, &t.gsoToWrite)
if err != nil {
return 0, err
}
// Move the write start back by virtioNetHdrLen so the bytes in front of
// each packet are written too.
offset -= virtioNetHdrLen
} else {
for i := range buffers {
t.gsoToWrite = append(t.gsoToWrite, i)
}
}
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
for _, toWrite := range t.gsoToWrite {
n, err := t.tunFile.Write(buffers[toWrite][offset:])
if errors.Is(err, syscall.EBADFD) {
return total, os.ErrClosed
}
if err != nil {
return err
errs = errors.Join(errs, err)
} else {
total += n
}
}
return nil
return total, errs
}

var controlPath string
Expand Down Expand Up @@ -262,10 +340,14 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
if err != nil {
return err
}
t.gsoEnabled = true
t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
t.tcp4GROTable = newTCPGROTable()
t.tcp6GROTable = newTCPGROTable()
t.vnetHdr = true
t.writeBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
t.tcpGROTable = newTCPGROTable()
t.udpGROTable = newUDPGROTable()
err = setUDPOffload(t.tunFd)
if err != nil {
t.gro.disableUDPGRO()
}
}

var rxChecksumOffload bool
Expand All @@ -280,7 +362,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
if err != nil {
return err
}
if err == nil && !txChecksumOffload {
if !txChecksumOffload {
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
if err != nil {
return err
Expand Down
16 changes: 11 additions & 5 deletions tun_linux_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (
"golang.org/x/sys/unix"
)

const (
	// tunTCPOffloads is the flag set setTCPOffload passes to TUNSETOFFLOAD:
	// TUN_F_CSUM together with TUN_F_TSO4 and TUN_F_TSO6.
	// TODO: support TSO with ECN bits
	tunTCPOffloads = unix.TUN_F_CSUM |
		unix.TUN_F_TSO4 |
		unix.TUN_F_TSO6

	// tunUDPOffloads holds the TUN_F_USO4 and TUN_F_USO6 flags; setUDPOffload
	// ORs it with tunTCPOffloads.
	tunUDPOffloads = unix.TUN_F_USO4 |
		unix.TUN_F_USO6
)

func checkVNETHDREnabled(fd int, name string) (bool, error) {
ifr, err := unix.NewIfreq(name)
if err != nil {
Expand All @@ -25,17 +31,17 @@ func checkVNETHDREnabled(fd int, name string) (bool, error) {
}

// setTCPOffload issues a TUNSETOFFLOAD ioctl on fd with the TCP offload flag
// set (TUN_F_CSUM | TUN_F_TSO4 | TUN_F_TSO6). A failure is wrapped as a
// "TUNSETOFFLOAD" syscall error with the cause "enable offload"; nil is
// returned on success.
//
// The removed lines declared the flag set as a function-local constant; the
// replacement line uses the package-level tunTCPOffloads instead.
func setTCPOffload(fd int) error {
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads)
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads)
if err != nil {
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
}
return nil
}

// setUDPOffload issues a TUNSETOFFLOAD ioctl on fd requesting the TCP and UDP
// offload flags together (tunTCPOffloads | tunUDPOffloads).
//
// A failure is wrapped the same way setTCPOffload wraps it, so the error
// names the ioctl that failed. The visible caller only checks the error
// against nil and disables UDP GRO when the request is rejected.
func setUDPOffload(fd int) error {
	err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads)
	if err != nil {
		return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable UDP offload")
	}
	return nil
}

type ifreqData struct {
ifrName [unix.IFNAMSIZ]byte
ifrData uintptr
Expand Down
2 changes: 1 addition & 1 deletion tun_linux_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
var _ GVisorTun = (*NativeTun)(nil)

func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
if t.gsoEnabled {
if t.vnetHdr {
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
Expand Down
Loading

0 comments on commit 8355396

Please sign in to comment.