Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for UDP proxying #154

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ type DeviceConfig struct {
CheckAliveInterval int
}

type UDPProxyTunnelConfig struct {
BindAddress string
Target string
InactivityTimeout int
}

type TCPClientTunnelConfig struct {
BindAddress *net.TCPAddr
Target string
Expand Down Expand Up @@ -434,6 +440,34 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) {
return config, nil
}

func parseUDPProxyTunnelConfig(section *ini.Section) (RoutineSpawner, error) {
config := &UDPProxyTunnelConfig{}

bindAddress, err := parseString(section, "BindAddress")
if err != nil {
return nil, err
}
config.BindAddress = bindAddress

target, err := parseString(section, "Target")
if err != nil {
return nil, err
}
config.Target = target

inactivityTimeout := 0
if sectionKey, err := section.GetKey("InactivityTimeout"); err == nil {
timeoutVal, err := sectionKey.Int()
if err != nil {
return nil, err
}
inactivityTimeout = timeoutVal
}
config.InactivityTimeout = inactivityTimeout

return config, nil
}

// Takes a function that parses an individual section into a config, and apply it on all
// specified sections
func parseRoutinesConfig(routines *[]RoutineSpawner, cfg *ini.File, sectionName string, f func(*ini.Section) (RoutineSpawner, error)) error {
Expand Down Expand Up @@ -518,6 +552,11 @@ func ParseConfig(path string) (*Configuration, error) {
return nil, err
}

err = parseRoutinesConfig(&routinesSpawners, cfg, "UDPProxyTunnel", parseUDPProxyTunnelConfig)
if err != nil {
return nil, err
}

return &Configuration{
Device: device,
Routines: routinesSpawners,
Expand Down
160 changes: 160 additions & 0 deletions udp_proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package wireproxy

import (
"fmt"
"log"
"net"
"sync"
"time"
)

// udpSession represents a UDP forwarding session, keyed by the local source address.
// remoteConn is the UDP connection to the remote endpoint (on the WireGuard side).
type udpSession struct {
remoteConn net.Conn
lastActive time.Time
closeChan chan struct{}
inactivityDur time.Duration
}

// SpawnRoutine implements the RoutineSpawner interface.
// It starts listening on config.BindAddress, handling each unique source (client) address
// with its own udpSession. If InactivityTimeout > 0, sessions automatically close after inactivity
func (conf *UDPProxyTunnelConfig) SpawnRoutine(vt *VirtualTun) {
addr, err := net.ResolveUDPAddr("udp", conf.BindAddress)
if err != nil {
log.Fatalf("UDPProxyTunnelConfig: could not resolve bind address %s: %v", conf.BindAddress, err)
}

listener, err := net.ListenUDP("udp", addr)
if err != nil {
log.Fatalf("UDPProxyTunnelConfig: could not listen on %s: %v", conf.BindAddress, err)
}
log.Printf("UDPProxyTunnel listening on %s, forwarding to %s", conf.BindAddress, conf.Target)

inactivityDur := time.Duration(conf.InactivityTimeout) * time.Second
sessions := make(map[string]*udpSession)
var sessionMu sync.Mutex

// Periodically clean up expired sessions if inactivity timeout is enabled
if conf.InactivityTimeout > 0 {
go func() {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
sessionMu.Lock()
for key, sess := range sessions {
if now.Sub(sess.lastActive) >= inactivityDur {
log.Printf("UDPProxyTunnel: closing inactive session for %s", key)
close(sess.closeChan)
delete(sessions, key)
}
}
sessionMu.Unlock()
}
}()
}

// Create or get a UDP session based on the local source address
getOrCreateSession := func(srcAddr string) (*udpSession, error) {
sessionMu.Lock()
defer sessionMu.Unlock()

// return if session already exists
if s, ok := sessions[srcAddr]; ok {
s.lastActive = time.Now()
return s, nil
}

// Create a new session
remoteConn, err := vt.Tnet.Dial("udp", conf.Target)
if err != nil {
return nil, fmt.Errorf("UDPProxyTunnel: could not Dial(%s): %w", conf.Target, err)
}

s := &udpSession{
remoteConn: remoteConn,
lastActive: time.Now(),
closeChan: make(chan struct{}),
inactivityDur: inactivityDur,
}
sessions[srcAddr] = s

// Spin up a goroutine to handle traffic from remote -> local
go conf.handleRemoteToLocal(listener, srcAddr, s)
return s, nil
}

// Main loop to read from local client and forward to remote
go func() {
buf := make([]byte, 64*1024) // typical max UDP size
for {
n, src, err := listener.ReadFromUDP(buf)
if err != nil {
log.Printf("UDPProxyTunnel: error reading from UDP: %v", err)
continue
}

srcKey := src.String() // identify session by the local client's IP:port
s, err := getOrCreateSession(srcKey)
if err != nil {
errorLogger.Printf("UDPProxyTunnel: getOrCreateSession failed for %s: %v", srcKey, err)
continue
}

s.lastActive = time.Now()
_, err = s.remoteConn.Write(buf[:n])
if err != nil {
errorLogger.Printf("UDPProxyTunnel: could not write to remote (%s): %v", conf.Target, err)
}
}
}()
}

// handles data from the remote WireGuard side back to the local client
// this function blocks until the session is closed
func (conf *UDPProxyTunnelConfig) handleRemoteToLocal(listener *net.UDPConn, srcAddr string, s *udpSession) {
defer func() {
_ = s.remoteConn.Close()
}()
buf := make([]byte, 64*1024)

for {
select {
case <-s.closeChan:
return
default:
}

_ = s.remoteConn.SetReadDeadline(time.Now().Add(5 * time.Second))
n, err := s.remoteConn.Read(buf)
if err != nil {
// If a timeout or temporary error, continue to see if the session is closed
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
select {
case <-s.closeChan:
return
default:
continue
}
}
errorLogger.Printf("UDPProxyTunnel: read error from remote: %v", err)
return
}

s.lastActive = time.Now()

dstUDPAddr, err := net.ResolveUDPAddr("udp", srcAddr)
if err != nil {
errorLogger.Printf("UDPProxyTunnel: cannot resolve local address %s: %v", srcAddr, err)
return
}

_, err = listener.WriteToUDP(buf[:n], dstUDPAddr)
if err != nil {
errorLogger.Printf("UDPProxyTunnel: cannot write to local %s: %v", srcAddr, err)
return
}
}
}