From 3bbd128902b808515b03fe13fa6edc7276d6a079 Mon Sep 17 00:00:00 2001 From: "Alex Ellis (OpenFaaS Ltd)" Date: Sun, 11 Sep 2022 21:23:37 +0100 Subject: [PATCH] Add a shorter timeout and close connections earlier Signed-off-by: Alex Ellis (OpenFaaS Ltd) --- README.md | 4 ++ main.go | 131 ++++++++++++++++++++++++++++++++++----------- rules.example.yaml | 5 ++ 3 files changed, 108 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 09d615e..3817057 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,10 @@ To make the upstream address listen on all interfaces, use `0.0.0.0` instead of The port for the from and to addresses do not need to match. +See also: +* `-t` - specify the dial timeout for an upstream host in the "to" field of the config file. +* `-v` - verbose logging - set to false to turn off logs of connections established and closed. + ## License This software is licensed MIT. diff --git a/main.go b/main.go index 2146514..9465170 100644 --- a/main.go +++ b/main.go @@ -28,45 +28,62 @@ type Rule struct { func main() { var ( - file string + file string + verbose bool + dialTimeout time.Duration ) flag.StringVar(&file, "f", "", "Job to run or leave blank for job.yaml in current directory") - + flag.BoolVar(&verbose, "v", true, "Verbose output for opened and closed connections") + flag.DurationVar(&dialTimeout, "t", time.Millisecond*1500, "Dial timeout") flag.Parse() + if len(file) == 0 { + fmt.Fprintf(os.Stderr, "usage: mixctl -f rules.yaml\n") + os.Exit(1) + } + set := ForwardingSet{} data, err := os.ReadFile(file) if err != nil { - log.Fatalf("error reading file %s %s", file, err.Error()) + fmt.Fprintf(os.Stderr, "error reading file %s %s", file, err.Error()) + os.Exit(1) } if err = yaml.Unmarshal(data, &set); err != nil { - log.Fatalf("error parsing file %s %s", file, err.Error()) + fmt.Fprintf(os.Stderr, "error parsing file %s %s", file, err.Error()) + os.Exit(1) + } + + if len(set.Rules) == 0 { + fmt.Fprintf(os.Stderr, "no rules found in file %s", file) + os.Exit(1) } - fmt.Printf("mixctl by inlets..\n") + fmt.Printf("Starting mixctl by https://inlets.dev/\n\n") wg := sync.WaitGroup{} wg.Add(len(set.Rules)) - for _, f := range set.Rules { - - r := f - go func(rule *Rule) { - fmt.Printf("Forward (%s) from: %s to: %s\n", rule.Name, rule.From, rule.To) + for _, rule := range set.Rules { + fmt.Printf("Forward (%s) from: %s to: %s\n", rule.Name, rule.From, rule.To) + } + fmt.Println() - if err := forward(rule.Name, rule.From, rule.To); err != nil { + for _, rule := range set.Rules { + // Copy the value to avoid the loop variable being reused + r := rule + go func() { + if err := forward(r.Name, r.From, r.To, verbose, dialTimeout); err != nil { log.Printf("error forwarding %s", err.Error()) os.Exit(1) } - defer wg.Done() - }(&r) + }() } - wg.Wait() + wg.Wait() } -func forward(name, from string, to []string) error { +func forward(name, from string, to []string, verbose bool, dialTimeout time.Duration) error { seed := time.Now().UnixNano() rand.Seed(seed) @@ -76,42 +93,92 @@ func forward(name, from string, to []string) error { return fmt.Errorf("error listening on %s %s", from, err.Error()) } + defer l.Close() + for { - conn, err := l.Accept() + // accept a connection on the local port of the load balancer + local, err := l.Accept() if err != nil { return fmt.Errorf("error accepting connection %s", err.Error()) } + // pick randomly from the list of upstream servers + // available index := rand.Intn(len(to)) + upstream := to[index] - remote, err := net.Dial("tcp", to[index]) - if err != nil { - return fmt.Errorf("error dialing %s %s", to[index], err.Error()) - } + // A separate Goroutine means the loop can accept another + // incoming connection on the local address + go connect(local, upstream, from, verbose, dialTimeout) + } +} - go func() { - log.Printf("[%s] %s => %s", - from, - conn.RemoteAddr().String(), - remote.RemoteAddr().String()) - if err := forwardConnection(conn, remote); err != nil && err.Error() != "done" { - log.Printf("error forwarding connection %s", err.Error()) - } - }() +// connect dials the upstream address, then copies data +// between it and connection accepted on a local port +func connect(local net.Conn, upstreamAddr, from string, verbose bool, dialTimeout time.Duration) { + defer local.Close() + + // If Dial is used on its own, then the timeout can be as long + // as 2 minutes on MacOS for an unreachable host + upstream, err := net.DialTimeout("tcp", upstreamAddr, dialTimeout) + if err != nil { + log.Printf("error dialing %s %s", upstreamAddr, err.Error()) + return + } + defer upstream.Close() + + if verbose { + log.Printf("Connected %s => %s (%s)", + from, + upstream.RemoteAddr().String(), + local.RemoteAddr().String()) + } + + ctx := context.Background() + if err := copy(ctx, local, upstream); err != nil && err.Error() != "done" { + log.Printf("error forwarding connection %s", err.Error()) + } + + if verbose { + log.Printf("Closed %s => %s (%s)", + from, + upstream.RemoteAddr().String(), + local.RemoteAddr().String()) } } -func forwardConnection(from, to net.Conn) error { - errgrp, _ := errgroup.WithContext(context.Background()) +// copy copies data between two connections using io.Copy +// and will exit when either connection is closed or runs +// into an error +func copy(ctx context.Context, from, to net.Conn) error { + + ctx, cancel := context.WithCancel(ctx) + errgrp, _ := errgroup.WithContext(ctx) errgrp.Go(func() error { io.Copy(from, to) + cancel() return fmt.Errorf("done") }) errgrp.Go(func() error { io.Copy(to, from) + cancel() + return fmt.Errorf("done") }) + errgrp.Go(func() error { + <-ctx.Done() + + // This closes both ends of the connection as + // soon as possible. + from.Close() + to.Close() + return fmt.Errorf("done") + }) + + if err := errgrp.Wait(); err != nil { + return err + } - return errgrp.Wait() + return nil } diff --git a/rules.example.yaml b/rules.example.yaml index 99c3158..13f7933 100644 --- a/rules.example.yaml +++ b/rules.example.yaml @@ -14,3 +14,8 @@ rules: - 192.168.1.19:22 - 192.168.1.21:22 - 192.168.1.20:22 + +- name: remap-local-ssh-port + from: 127.0.0.1:2222 + to: + - 127.0.0.1:22