Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface option #65

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ GLOBAL OPTIONS:
--server-json value Use an alternative server list from remote JSON file
--local-json value Use an alternative server list from local JSON file,
or read from stdin with "--local-json -".
--source SOURCE SOURCE IP address to bind to
--source SOURCE SOURCE IP address to bind to. Incompatible with --interface.
--interface INTERFACE The name of the network interface to bind to. Example: "enp0s3".
Not supported on Windows and incompatible with --source.
Implies --no-icmp.
--timeout TIMEOUT HTTP TIMEOUT in seconds. (default: 15)
--duration value Upload and download test duration in seconds (default: 15)
--chunks value Chunks to download from server, chunk size depends on server configuration (default: 100)
Expand Down
1 change: 1 addition & 0 deletions defs/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
OptionExclude = "exclude"
OptionServerJSON = "server-json"
OptionSource = "source"
OptionInterface = "interface"
OptionTimeout = "timeout"
OptionChunks = "chunks"
OptionUploadSize = "upload-size"
Expand Down
4 changes: 4 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ func main() {
Name: defs.OptionSource,
Usage: "`SOURCE` IP address to bind to",
},
&cli.StringFlag{
Name: defs.OptionInterface,
Usage: "network INTERFACE to bind to",
},
&cli.IntFlag{
Name: defs.OptionTimeout,
Usage: "HTTP `TIMEOUT` in seconds.",
Expand Down
4 changes: 2 additions & 2 deletions speedtest/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (
)

// doSpeedTest is where the actual speed test happens
func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.TelemetryServer, network string, silent bool) error {
func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.TelemetryServer, network string, silent bool, noICMP bool) error {
if serverCount := len(servers); serverCount > 1 {
log.Infof("Testing against %d servers", serverCount)
}
Expand Down Expand Up @@ -70,7 +70,7 @@ func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.Tel
}

// skip ICMP if option given
currentServer.NoICMP = c.Bool(defs.OptionNoICMP)
currentServer.NoICMP = noICMP

p, jitter, err := currentServer.ICMPPingAndJitter(pingCount, c.String(defs.OptionSource), network)
if err != nil {
Expand Down
117 changes: 71 additions & 46 deletions speedtest/speedtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ func SpeedTest(c *cli.Context) error {
return nil
}

if c.String(defs.OptionSource) != "" && c.String(defs.OptionInterface) != "" {
return fmt.Errorf("incompatible options '%s' and '%s'", defs.OptionSource, defs.OptionInterface)
}

// set CSV delimiter
gocsv.TagSeparator = c.String(defs.OptionCSVDelimiter)

Expand Down Expand Up @@ -138,6 +142,8 @@ func SpeedTest(c *cli.Context) error {
return errors.New("invalid concurrent requests setting")
}

noICMP := c.Bool(defs.OptionNoICMP)

// HTTP requests timeout
http.DefaultClient.Timeout = time.Duration(c.Int(defs.OptionTimeout)) * time.Second

Expand All @@ -157,57 +163,48 @@ func SpeedTest(c *cli.Context) error {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: c.Bool(defs.OptionSkipCertVerify)}

// bind to source IP address if given, or if ipv4/ipv6 is forced
if src := c.String(defs.OptionSource); src != "" || (forceIPv4 || forceIPv6) {
var localTCPAddr *net.TCPAddr
if src != "" {
// first we parse the IP to see if it's valid
addr, err := net.ResolveIPAddr(network, src)
if err != nil {
if strings.Contains(err.Error(), "no suitable address") {
if forceIPv6 {
log.Errorf("Address %s is not a valid IPv6 address", src)
} else {
log.Errorf("Address %s is not a valid IPv4 address", src)
}
} else {
log.Errorf("Error parsing source IP: %s", err)
}
return err
}

log.Debugf("Using %s as source IP", src)
localTCPAddr = &net.TCPAddr{IP: addr.IP}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
// bind to source IP address if given
if src := c.String(defs.OptionSource); src != "" {
var err error
dialer, err = newDialerAddressBound(src, network)
if err != nil {
return err
}
}

var dialContext func(context.Context, string, string) (net.Conn, error)
defaultDialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
// bind to interface if given
if iface := c.String(defs.OptionInterface); iface != "" {
var err error
dialer, err = newDialerInterfaceBound(iface)
if err != nil {
return err
}
// ICMP ping does not support interface binding.
noICMP = true
}

if localTCPAddr != nil {
defaultDialer.LocalAddr = localTCPAddr
// enforce if ipv4/ipv6 is forced
var dialContext func(context.Context, string, string) (net.Conn, error)
switch {
case forceIPv4:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return dialer.DialContext(ctx, "tcp4", address)
}

switch {
case forceIPv4:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return defaultDialer.DialContext(ctx, "tcp4", address)
}
case forceIPv6:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return defaultDialer.DialContext(ctx, "tcp6", address)
}
default:
dialContext = defaultDialer.DialContext
case forceIPv6:
dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) {
return dialer.DialContext(ctx, "tcp6", address)
}

// set default HTTP client's Transport to the one that binds the source address
// this is modified from http.DefaultTransport
transport.DialContext = dialContext
default:
dialContext = dialer.DialContext
}

// set default HTTP client's Transport to the one that binds the source address
// this is modified from http.DefaultTransport
transport.DialContext = dialContext
http.DefaultClient.Transport = transport

// load server list
Expand Down Expand Up @@ -258,7 +255,7 @@ func SpeedTest(c *cli.Context) error {

// if --server is given, do speed tests with all of them
if len(c.IntSlice(defs.OptionServer)) > 0 {
return doSpeedTest(c, servers, telemetryServer, network, silent)
return doSpeedTest(c, servers, telemetryServer, network, silent, noICMP)
} else {
// else select the fastest server from the list
log.Info("Selecting the fastest server based on ping")
Expand All @@ -272,7 +269,7 @@ func SpeedTest(c *cli.Context) error {

// spawn 10 concurrent pingers
for i := 0; i < 10; i++ {
go pingWorker(jobs, results, &wg, c.String(defs.OptionSource), network, c.Bool(defs.OptionNoICMP))
go pingWorker(jobs, results, &wg, c.String(defs.OptionSource), network, noICMP)
}

// send ping jobs to workers
Expand Down Expand Up @@ -309,7 +306,7 @@ func SpeedTest(c *cli.Context) error {
}

// do speed test on the server
return doSpeedTest(c, []defs.Server{servers[serverIdx]}, telemetryServer, network, silent)
return doSpeedTest(c, []defs.Server{servers[serverIdx]}, telemetryServer, network, silent, noICMP)
}
}

Expand Down Expand Up @@ -474,3 +471,31 @@ func contains(arr []int, val int) bool {
}
return false
}

func newDialerAddressBound(src string, network string) (dialer *net.Dialer, err error) {
// first we parse the IP to see if it's valid
addr, err := net.ResolveIPAddr(network, src)
if err != nil {
if strings.Contains(err.Error(), "no suitable address") {
if network == "ip6" {
log.Errorf("Address %s is not a valid IPv6 address", src)
} else {
log.Errorf("Address %s is not a valid IPv4 address", src)
}
} else {
log.Errorf("Error parsing source IP: %s", err)
}
return nil, err
}

log.Debugf("Using %s as source IP", src)
localTCPAddr := &net.TCPAddr{IP: addr.IP}

defaultDialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}

defaultDialer.LocalAddr = localTCPAddr
return defaultDialer, nil
}
32 changes: 32 additions & 0 deletions speedtest/util_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package speedtest

import (
"net"
"syscall"
"time"

"golang.org/x/sys/unix"
)

func newDialerInterfaceBound(iface string) (dialer *net.Dialer, err error) {
// In linux there is the socket option SO_BINDTODEVICE.
// Therefore we can really bind the socket to the device instead of binding to the address that
// would be affected by the default routes.
control := func(network, address string, c syscall.RawConn) error {
var errSock error
err := c.Control((func(fd uintptr) {
errSock = unix.BindToDevice(int(fd), iface)
}))
if err != nil {
return err
}
return errSock
}

dialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Control: control,
}
return dialer, nil
}
10 changes: 10 additions & 0 deletions speedtest/util_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package speedtest

import (
"fmt"
"net"
)

func newDialerInterfaceBound(iface string) (dialer *net.Dialer, err error) {
return nil, fmt.Errorf("cannot bound to interface on Windows")
}