diff --git a/README.md b/README.md index c9f7dff..db74ca0 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,10 @@ GLOBAL OPTIONS: --server-json value Use an alternative server list from remote JSON file --local-json value Use an alternative server list from local JSON file, or read from stdin with "--local-json -". - --source SOURCE SOURCE IP address to bind to + --source SOURCE SOURCE IP address to bind to. Incompatible with --interface. + --interface INTERFACE The name of the network interface to bind to. Example: "enp0s3". + Not supported on Windows and incompatible with --source. + Implies --no-icmp. --timeout TIMEOUT HTTP TIMEOUT in seconds. (default: 15) --duration value Upload and download test duration in seconds (default: 15) --chunks value Chunks to download from server, chunk size depends on server configuration (default: 100) diff --git a/defs/options.go b/defs/options.go index 255a5b1..bc7e506 100644 --- a/defs/options.go +++ b/defs/options.go @@ -24,6 +24,7 @@ const ( OptionExclude = "exclude" OptionServerJSON = "server-json" OptionSource = "source" + OptionInterface = "interface" OptionTimeout = "timeout" OptionChunks = "chunks" OptionUploadSize = "upload-size" diff --git a/main.go b/main.go index a60b437..5cb6643 100644 --- a/main.go +++ b/main.go @@ -139,6 +139,10 @@ func main() { Name: defs.OptionSource, Usage: "`SOURCE` IP address to bind to", }, + &cli.StringFlag{ + Name: defs.OptionInterface, + Usage: "network INTERFACE to bind to", + }, &cli.IntFlag{ Name: defs.OptionTimeout, Usage: "HTTP `TIMEOUT` in seconds.", diff --git a/speedtest/helper.go b/speedtest/helper.go index ddc9bc1..4d1d5d7 100644 --- a/speedtest/helper.go +++ b/speedtest/helper.go @@ -28,7 +28,7 @@ const ( ) // doSpeedTest is where the actual speed test happens -func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.TelemetryServer, network string, silent bool) error { +func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.TelemetryServer, network string, silent bool, noICMP bool) error { if serverCount := len(servers); serverCount > 1 { log.Infof("Testing against %d servers", serverCount) } @@ -70,7 +70,7 @@ func doSpeedTest(c *cli.Context, servers []defs.Server, telemetryServer defs.Tel } // skip ICMP if option given - currentServer.NoICMP = c.Bool(defs.OptionNoICMP) + currentServer.NoICMP = noICMP p, jitter, err := currentServer.ICMPPingAndJitter(pingCount, c.String(defs.OptionSource), network) if err != nil { diff --git a/speedtest/speedtest.go b/speedtest/speedtest.go index 8a68092..93e815f 100644 --- a/speedtest/speedtest.go +++ b/speedtest/speedtest.go @@ -74,6 +74,10 @@ func SpeedTest(c *cli.Context) error { return nil } + if c.String(defs.OptionSource) != "" && c.String(defs.OptionInterface) != "" { + return fmt.Errorf("incompatible options '%s' and '%s'", defs.OptionSource, defs.OptionInterface) + } + // set CSV delimiter gocsv.TagSeparator = c.String(defs.OptionCSVDelimiter) @@ -138,6 +142,8 @@ func SpeedTest(c *cli.Context) error { return errors.New("invalid concurrent requests setting") } + noICMP := c.Bool(defs.OptionNoICMP) + // HTTP requests timeout http.DefaultClient.Timeout = time.Duration(c.Int(defs.OptionTimeout)) * time.Second @@ -157,57 +163,48 @@ func SpeedTest(c *cli.Context) error { transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: c.Bool(defs.OptionSkipCertVerify)} - // bind to source IP address if given, or if ipv4/ipv6 is forced - if src := c.String(defs.OptionSource); src != "" || (forceIPv4 || forceIPv6) { - var localTCPAddr *net.TCPAddr - if src != "" { - // first we parse the IP to see if it's valid - addr, err := net.ResolveIPAddr(network, src) - if err != nil { - if strings.Contains(err.Error(), "no suitable address") { - if forceIPv6 { - log.Errorf("Address %s is not a valid IPv6 address", src) - } else { - log.Errorf("Address %s is not a valid IPv4 address", src) - } - } else { - log.Errorf("Error parsing source IP: %s", err) - } - return err - } - - log.Debugf("Using %s as source IP", src) - localTCPAddr = &net.TCPAddr{IP: addr.IP} + dialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + // bind to source IP address if given + if src := c.String(defs.OptionSource); src != "" { + var err error + dialer, err = newDialerAddressBound(src, network) + if err != nil { + return err } + } - var dialContext func(context.Context, string, string) (net.Conn, error) - defaultDialer := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, + // bind to interface if given + if iface := c.String(defs.OptionInterface); iface != "" { + var err error + dialer, err = newDialerInterfaceBound(iface) + if err != nil { + return err } + // ICMP ping does not support interface binding. + noICMP = true + } - if localTCPAddr != nil { - defaultDialer.LocalAddr = localTCPAddr + // enforce if ipv4/ipv6 is forced + var dialContext func(context.Context, string, string) (net.Conn, error) + switch { + case forceIPv4: + dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp4", address) } - - switch { - case forceIPv4: - dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) { - return defaultDialer.DialContext(ctx, "tcp4", address) - } - case forceIPv6: - dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) { - return defaultDialer.DialContext(ctx, "tcp6", address) - } - default: - dialContext = defaultDialer.DialContext + case forceIPv6: + dialContext = func(ctx context.Context, network, address string) (conn net.Conn, err error) { + return dialer.DialContext(ctx, "tcp6", address) } - - // set default HTTP client's Transport to the one that binds the source address - // this is modified from http.DefaultTransport - transport.DialContext = dialContext + default: + dialContext = dialer.DialContext } + // set default HTTP client's Transport to the one that binds the source address + // this is modified from http.DefaultTransport + transport.DialContext = dialContext http.DefaultClient.Transport = transport // load server list @@ -258,7 +255,7 @@ func SpeedTest(c *cli.Context) error { // if --server is given, do speed tests with all of them if len(c.IntSlice(defs.OptionServer)) > 0 { - return doSpeedTest(c, servers, telemetryServer, network, silent) + return doSpeedTest(c, servers, telemetryServer, network, silent, noICMP) } else { // else select the fastest server from the list log.Info("Selecting the fastest server based on ping") @@ -272,7 +269,7 @@ func SpeedTest(c *cli.Context) error { // spawn 10 concurrent pingers for i := 0; i < 10; i++ { - go pingWorker(jobs, results, &wg, c.String(defs.OptionSource), network, c.Bool(defs.OptionNoICMP)) + go pingWorker(jobs, results, &wg, c.String(defs.OptionSource), network, noICMP) } // send ping jobs to workers @@ -309,7 +306,7 @@ func SpeedTest(c *cli.Context) error { } // do speed test on the server - return doSpeedTest(c, []defs.Server{servers[serverIdx]}, telemetryServer, network, silent) + return doSpeedTest(c, []defs.Server{servers[serverIdx]}, telemetryServer, network, silent, noICMP) } } @@ -474,3 +471,31 @@ func contains(arr []int, val int) bool { } return false } + +func newDialerAddressBound(src string, network string) (dialer *net.Dialer, err error) { + // first we parse the IP to see if it's valid + addr, err := net.ResolveIPAddr(network, src) + if err != nil { + if strings.Contains(err.Error(), "no suitable address") { + if network == "ip6" { + log.Errorf("Address %s is not a valid IPv6 address", src) + } else { + log.Errorf("Address %s is not a valid IPv4 address", src) + } + } else { + log.Errorf("Error parsing source IP: %s", err) + } + return nil, err + } + + log.Debugf("Using %s as source IP", src) + localTCPAddr := &net.TCPAddr{IP: addr.IP} + + defaultDialer := &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + } + + defaultDialer.LocalAddr = localTCPAddr + return defaultDialer, nil +} diff --git a/speedtest/util_linux.go b/speedtest/util_linux.go new file mode 100644 index 0000000..e2a0a64 --- /dev/null +++ b/speedtest/util_linux.go @@ -0,0 +1,32 @@ +package speedtest + +import ( + "net" + "syscall" + "time" + + "golang.org/x/sys/unix" +) + +func newDialerInterfaceBound(iface string) (dialer *net.Dialer, err error) { + // In linux there is the socket option SO_BINDTODEVICE. + // Therefore we can really bind the socket to the device instead of binding to the address that + // would be affected by the default routes. + control := func(network, address string, c syscall.RawConn) error { + var errSock error + err := c.Control((func(fd uintptr) { + errSock = unix.BindToDevice(int(fd), iface) + })) + if err != nil { + return err + } + return errSock + } + + dialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Control: control, + } + return dialer, nil +} diff --git a/speedtest/util_windows.go b/speedtest/util_windows.go new file mode 100644 index 0000000..c4a0cd0 --- /dev/null +++ b/speedtest/util_windows.go @@ -0,0 +1,10 @@ +package speedtest + +import ( + "fmt" + "net" +) + +func newDialerInterfaceBound(iface string) (dialer *net.Dialer, err error) { + return nil, fmt.Errorf("cannot bound to interface on Windows") +}