From f9faa02e0728c43fc91a1319d515f47bf294fd88 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jan 2023 08:53:24 +0000 Subject: [PATCH] feat: add Unix forwarding server implementations Adds optional (disabled by default) implementations of local->remote and remote->local Unix forwarding through OpenSSH's protocol extensions: - streamlocal-forward@openssh.com - cancel-streamlocal-forward@openssh.com - forwarded-streamlocal@openssh.com - direct-streamlocal@openssh.com Adds tests for Unix forwarding, reverse Unix forwarding and reverse TCP forwarding. Co-authored-by: Samuel Corsi-House --- options_test.go | 2 +- server.go | 2 + server_test.go | 4 +- session_test.go | 19 +++- ssh.go | 20 ++++ streamlocal.go | 252 ++++++++++++++++++++++++++++++++++++++++++++ streamlocal_test.go | 206 ++++++++++++++++++++++++++++++++++++ tcpip.go | 69 ++++++++---- tcpip_test.go | 98 ++++++++++++++++- 9 files changed, 642 insertions(+), 30 deletions(-) create mode 100644 streamlocal.go create mode 100644 streamlocal_test.go diff --git a/options_test.go b/options_test.go index 23fca5a..2992b6a 100644 --- a/options_test.go +++ b/options_test.go @@ -49,7 +49,7 @@ func TestPasswordAuth(t *testing.T) { func TestPasswordAuthBadPass(t *testing.T) { t.Parallel() - l := newLocalListener() + l := newLocalTCPListener() srv := &Server{Handler: func(s Session) {}} srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { return false diff --git a/server.go b/server.go index f783ee5..6d7734b 100644 --- a/server.go +++ b/server.go @@ -47,6 +47,8 @@ type Server struct { ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil + LocalUnixForwardingCallback LocalUnixForwardingCallback // callback for allowing local unix forwarding (direct-streamlocal@openssh.com), denies all if nil + ReverseUnixForwardingCallback ReverseUnixForwardingCallback // callback for allowing reverse unix forwarding (streamlocal-forward@openssh.com), denies all if nil ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions diff --git a/server_test.go b/server_test.go index 8028a3a..8db22fa 100644 --- a/server_test.go +++ b/server_test.go @@ -29,7 +29,7 @@ func TestAddHostKey(t *testing.T) { } func TestServerShutdown(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() testBytes := []byte("Hello world\n") s := &Server{ Handler: func(s Session) { @@ -80,7 +80,7 @@ func TestServerShutdown(t *testing.T) { } func TestServerClose(t *testing.T) { - l := newLocalListener() + l := newLocalTCPListener() s := &Server{ Handler: func(s Session) { time.Sleep(5 * time.Second) diff --git a/session_test.go b/session_test.go index c6ce617..69de83a 100644 --- a/session_test.go +++ b/session_test.go @@ -20,14 +20,25 @@ func (srv *Server) serveOnce(l net.Listener) error { return e } srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, + "session": DefaultSessionHandler, + "direct-tcpip": DirectTCPIPHandler, + "direct-streamlocal@openssh.com": DirectStreamLocalHandler, } + + forwardedTCPHandler := &ForwardedTCPHandler{} + forwardedUnixHandler := &ForwardedUnixHandler{} + srv.RequestHandlers = map[string]RequestHandler{ + "tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardedTCPHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": forwardedUnixHandler.HandleSSHRequest, + } + srv.HandleConn(conn) return nil } -func newLocalListener() net.Listener { +func newLocalTCPListener() net.Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { @@ -64,7 +75,7 @@ func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*g } func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() + l := newLocalTCPListener() go srv.serveOnce(l) return newClientSession(t, l.Addr().String(), cfg) } diff --git a/ssh.go b/ssh.go index 775b454..ed3d48a 100644 --- a/ssh.go +++ b/ssh.go @@ -2,6 +2,7 @@ package ssh import ( "crypto/subtle" + "errors" "net" gossh "golang.org/x/crypto/ssh" @@ -29,6 +30,9 @@ const ( // DefaultHandler is the default Handler used by Serve. var DefaultHandler Handler +// ErrReject is returned by some callbacks to reject a request. +var ErrRejected = errors.New("ssh: rejected") + // Option is a functional option handler for Server. type Option func(*Server) error @@ -64,6 +68,22 @@ type LocalPortForwardingCallback func(ctx Context, destinationHost string, desti // ReversePortForwardingCallback is a hook for allowing reverse port forwarding type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool +// LocalUnixForwardingCallback is a hook for allowing unix forwarding +// (direct-streamlocal@openssh.com). Returning ErrRejected will reject the +// request. The returned net.Conn will be closed by the server when no longer +// needed. +// +// Use SimpleUnixLocalForwardingCallback for a basic implementation. +type LocalUnixForwardingCallback func(ctx Context, socketPath string) (net.Conn, error) + +// ReverseUnixForwardingCallback is a hook for allowing reverse unix forwarding +// (streamlocal-forward@openssh.com). Returning ErrRejected will reject the +// request. The returned net.Listener will be closed by the server when no +// longer needed. +// +// Use SimpleUnixReverseForwardingCallback for a basic implementation. +type ReverseUnixForwardingCallback func(ctx Context, socketPath string) (net.Listener, error) + // ServerConfigCallback is a hook for creating custom default server configs type ServerConfigCallback func(ctx Context) *gossh.ServerConfig diff --git a/streamlocal.go b/streamlocal.go new file mode 100644 index 0000000..2daa1a2 --- /dev/null +++ b/streamlocal.go @@ -0,0 +1,252 @@ +package ssh + +import ( + "context" + "errors" + "fmt" + "io/fs" + "net" + "os" + "path/filepath" + "sync" + "syscall" + + gossh "golang.org/x/crypto/ssh" +) + +const ( + forwardedUnixChannelType = "forwarded-streamlocal@openssh.com" +) + +// directStreamLocalChannelData data struct as specified in OpenSSH's protocol +// extensions document, Section 2.4. +// https://cvsweb.openbsd.org/src/usr.bin/ssh/PROTOCOL?annotate=HEAD +type directStreamLocalChannelData struct { + SocketPath string + + Reserved1 string + Reserved2 uint32 +} + +// DirectStreamLocalHandler provides Unix forwarding from client -> server. It +// can be enabled by adding it to the server's ChannelHandlers under +// `direct-streamlocal@openssh.com`. +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +func DirectStreamLocalHandler(srv *Server, _ *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { + var d directStreamLocalChannelData + err := gossh.Unmarshal(newChan.ExtraData(), &d) + if err != nil { + _ = newChan.Reject(gossh.ConnectionFailed, "error parsing direct-streamlocal data: "+err.Error()) + return + } + + if srv.LocalUnixForwardingCallback == nil { + _ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + return + } + dconn, err := srv.LocalUnixForwardingCallback(ctx, d.SocketPath) + if err != nil { + if errors.Is(err, ErrRejected) { + _ = newChan.Reject(gossh.Prohibited, "unix forwarding is disabled") + return + } + _ = newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("dial unix socket %q: %+v", d.SocketPath, err.Error())) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go gossh.DiscardRequests(reqs) + + bicopy(ctx, ch, dconn) +} + +// remoteUnixForwardRequest describes the extra data sent in a +// streamlocal-forward@openssh.com containing the socket path to bind to. +type remoteUnixForwardRequest struct { + SocketPath string +} + +// remoteUnixForwardChannelData describes the data sent as the payload in the new +// channel request when a Unix connection is accepted by the listener. +type remoteUnixForwardChannelData struct { + SocketPath string + Reserved uint32 +} + +// ForwardedUnixHandler can be enabled by creating a ForwardedUnixHandler and +// adding the HandleSSHRequest callback to the server's RequestHandlers under +// `streamlocal-forward@openssh.com` and +// `cancel-streamlocal-forward@openssh.com` +// +// Unix socket support on Windows is not widely available, so this handler may +// not work on all Windows installations and is not tested on Windows. +type ForwardedUnixHandler struct { + sync.Mutex + forwards map[string]net.Listener +} + +func (h *ForwardedUnixHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { + h.Lock() + if h.forwards == nil { + h.forwards = make(map[string]net.Listener) + } + h.Unlock() + conn, ok := ctx.Value(ContextKeyConn).(*gossh.ServerConn) + if !ok { + // TODO: log cast failure + return false, nil + } + + switch req.Type { + case "streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + + if srv.ReverseUnixForwardingCallback == nil { + return false, []byte("unix forwarding is disabled") + } + + addr := reqPayload.SocketPath + h.Lock() + _, ok := h.forwards[addr] + h.Unlock() + if ok { + // TODO: log failure + return false, nil + } + + ln, err := srv.ReverseUnixForwardingCallback(ctx, addr) + if err != nil { + if errors.Is(err, ErrRejected) { + return false, []byte("unix forwarding is disabled") + } + // TODO: log unix listen failure + return false, nil + } + + // The listener needs to successfully start before it can be added to + // the map, so we don't have to worry about checking for an existing + // listener as you can't listen on the same socket twice. + // + // This is also what the TCP version of this code does. + h.Lock() + h.forwards[addr] = ln + h.Unlock() + + ctx, cancel := context.WithCancel(ctx) + go func() { + <-ctx.Done() + _ = ln.Close() + }() + go func() { + defer cancel() + + for { + c, err := ln.Accept() + if err != nil { + // closed below + break + } + payload := gossh.Marshal(&remoteUnixForwardChannelData{ + SocketPath: addr, + }) + + go func() { + ch, reqs, err := conn.OpenChannel(forwardedUnixChannelType, payload) + if err != nil { + _ = c.Close() + return + } + go gossh.DiscardRequests(reqs) + bicopy(ctx, ch, c) + }() + } + + h.Lock() + ln2, ok := h.forwards[addr] + if ok && ln2 == ln { + delete(h.forwards, addr) + } + h.Unlock() + _ = ln.Close() + }() + + return true, nil + + case "cancel-streamlocal-forward@openssh.com": + var reqPayload remoteUnixForwardRequest + err := gossh.Unmarshal(req.Payload, &reqPayload) + if err != nil { + // TODO: log parse failure + return false, nil + } + h.Lock() + ln, ok := h.forwards[reqPayload.SocketPath] + h.Unlock() + if ok { + _ = ln.Close() + } + return true, nil + + default: + return false, nil + } +} + +// unlink removes files and unlike os.Remove, directories are kept. +func unlink(path string) error { + // Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go + // for more details. + for { + err := syscall.Unlink(path) + if !errors.Is(err, syscall.EINTR) { + return err + } + } +} + +// SimpleUnixLocalForwardingCallback provides a basic implementation for +// LocalUnixForwardingCallback. It will simply dial the requested socket using +// a context-aware dialer. +func SimpleUnixLocalForwardingCallback(ctx Context, socketPath string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) +} + +// SimpleUnixReverseForwardingCallback provides a basic implementation for +// ReverseUnixForwardingCallback. The parent directory will be created (with +// os.MkdirAll), and existing files with the same name will be removed. +func SimpleUnixReverseForwardingCallback(_ Context, socketPath string) (net.Listener, error) { + // Create socket parent dir if not exists. + parentDir := filepath.Dir(socketPath) + err := os.MkdirAll(parentDir, 0700) + if err != nil { + return nil, fmt.Errorf("failed to create parent directory %q for socket %q: %w", parentDir, socketPath, err) + } + + // Remove existing socket if it exists. We do not use os.Remove() here + // so that directories are kept. Note that it's possible that we will + // overwrite a regular file here. Both of these behaviors match OpenSSH, + // however, which is why we unlink. + err = unlink(socketPath) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("failed to remove existing file in socket path %q: %w", socketPath, err) + } + + ln, err := net.Listen("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("failed to listen on unix socket %q: %w", socketPath, err) + } + + return ln, err +} diff --git a/streamlocal_test.go b/streamlocal_test.go new file mode 100644 index 0000000..41ae3c9 --- /dev/null +++ b/streamlocal_test.go @@ -0,0 +1,206 @@ +package ssh + +import ( + "bytes" + "fmt" + "io" + "net" + "os" + "path/filepath" + "runtime" + "strings" + "sync/atomic" + "testing" + + gossh "golang.org/x/crypto/ssh" +) + +// tempDirUnixSocket returns a temporary directory that can safely hold unix +// sockets. +// +// On all platforms other than darwin this just returns t.TempDir(). On darwin +// we manually make a temporary directory in /tmp because t.TempDir() returns a +// very long directory name, and the path length limit for Unix sockets on +// darwin is 104 characters. +func tempDirUnixSocket(t *testing.T) string { + t.Helper() + if runtime.GOOS == "darwin" { + testName := strings.ReplaceAll(t.Name(), "/", "_") + dir, err := os.MkdirTemp("/tmp", fmt.Sprintf("gliderlabs-ssh-test-%s-", testName)) + if err != nil { + t.Fatalf("create temp dir for test: %v", err) + } + + t.Cleanup(func() { + err := os.RemoveAll(dir) + if err != nil { + t.Errorf("remove temp dir %s: %v", dir, err) + } + }) + return dir + } + + return t.TempDir() +} + +func newLocalUnixListener(t *testing.T) net.Listener { + path := filepath.Join(tempDirUnixSocket(t), "socket.sock") + l, err := net.Listen("unix", path) + if err != nil { + t.Fatalf("failed to listen on a unix socket %q: %v", path, err) + } + return l +} + +func sampleUnixSocketServer(t *testing.T) net.Listener { + l := newLocalUnixListener(t) + + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + return l +} + +func newTestSessionWithUnixForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { + l := sampleUnixSocketServer(t) + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + LocalUnixForwardingCallback: func(ctx Context, socketPath string) (net.Conn, error) { + if socketPath != l.Addr().String() { + panic("unexpected socket path: " + socketPath) + } + if !forwardingEnabled { + return nil, ErrRejected + } + return SimpleUnixLocalForwardingCallback(ctx, socketPath) + }, + }, nil) + + return l, client, func() { + cleanup() + l.Close() + } +} + +func TestLocalUnixForwardingWorks(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, true) + defer cleanup() + + conn, err := client.Dial("unix", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } +} + +func TestLocalUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + l, client, cleanup := newTestSessionWithUnixForwarding(t, false) + defer cleanup() + + _, err := client.Dial("unix", l.Addr().String()) + if err == nil { + t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) + } + if !strings.Contains(err.Error(), "unix forwarding is disabled") { + t.Fatalf("Expected permission error but got %#v", err) + } +} + +func TestReverseUnixForwardingWorks(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return SimpleUnixReverseForwardingCallback(ctx, socketPath) + }, + }, nil) + defer cleanup() + + l, err := client.ListenUnix(remoteSocketPath) + if err != nil { + t.Fatalf("failed to listen on a unix socket over SSH %q: %v", remoteSocketPath, err) + } + defer l.Close() + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("unix", remoteSocketPath) + if err != nil { + t.Fatalf("Error connecting to %v: %v", remoteSocketPath, err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the Unix socket is gone. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + _, err = os.Stat(remoteSocketPath) + if err == nil && !os.IsNotExist(err) { + t.Fatalf("expected remote socket to be gone but it still exists: %v", err) + } +} + +func TestReverseUnixForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + remoteSocketPath := filepath.Join(tempDirUnixSocket(t), "remote.sock") + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReverseUnixForwardingCallback: func(ctx Context, socketPath string) (net.Listener, error) { + atomic.AddInt64(&called, 1) + if socketPath != remoteSocketPath { + panic("unexpected socket path: " + socketPath) + } + return nil, ErrRejected + }, + }, nil) + defer cleanup() + + _, err := client.ListenUnix(remoteSocketPath) + if err == nil { + t.Fatalf("Expected error listening on %q but it succeeded", remoteSocketPath) + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +} diff --git a/tcpip.go b/tcpip.go index 335fda6..843704a 100644 --- a/tcpip.go +++ b/tcpip.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "io" "log" "net" @@ -53,16 +54,7 @@ func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewCh } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() + bicopy(ctx, ch, dconn) } type remoteForwardRequest struct { @@ -117,8 +109,14 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go // TODO: log listen failure return false, []byte{} } + + // If the bind port was port 0, we need to use the actual port in the + // listener map. _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) destPort, _ := strconv.Atoi(destPortStr) + if reqPayload.BindPort == 0 { + addr = net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(destPort)) + } h.Lock() h.forwards[addr] = ln h.Unlock() @@ -155,16 +153,7 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return } go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() + bicopy(ctx, ch, c) }() } h.Lock() @@ -191,3 +180,43 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go return false, nil } } + +// bicopy copies all of the data between the two connections and will close them +// after one or both of them are done writing. If the context is canceled, both +// of the connections will be closed. +func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + defer func() { + _ = c1.Close() + _ = c2.Close() + }() + + var wg sync.WaitGroup + copyFunc := func(dst io.WriteCloser, src io.Reader) { + defer func() { + wg.Done() + // If one side of the copy fails, ensure the other one exits as + // well. + cancel() + }() + _, _ = io.Copy(dst, src) + } + + wg.Add(2) + go copyFunc(c1, c2) + go copyFunc(c2, c1) + + // Convert waitgroup to a channel so we can also wait on the context. + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-ctx.Done(): + case <-done: + } +} diff --git a/tcpip_test.go b/tcpip_test.go index 4ddf40e..525ca2d 100644 --- a/tcpip_test.go +++ b/tcpip_test.go @@ -2,19 +2,22 @@ package ssh import ( "bytes" + "context" "io" "net" "strconv" "strings" + "sync/atomic" "testing" + "time" gossh "golang.org/x/crypto/ssh" ) var sampleServerResponse = []byte("Hello world") -func sampleSocketServer() net.Listener { - l := newLocalListener() +func sampleTCPSocketServer() net.Listener { + l := newLocalTCPListener() go func() { conn, err := l.Accept() @@ -29,7 +32,7 @@ func sampleSocketServer() net.Listener { } func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() + l := sampleTCPSocketServer() _, client, cleanup := newTestSession(t, &Server{ Handler: func(s Session) {}, @@ -81,3 +84,92 @@ func TestLocalPortForwardingRespectsCallback(t *testing.T) { t.Fatalf("Expected permission error but got %#v", err) } } + +func TestReverseTCPForwardingWorks(t *testing.T) { + t.Parallel() + + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return true + }, + }, nil) + defer cleanup() + + l, err := client.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on a random TCP port over SSH: %v", err) + } + defer l.Close() + go func() { + conn, err := l.Accept() + if err != nil { + return + } + conn.Write(sampleServerResponse) + conn.Close() + }() + + // Dial the listener that should've been created by the server. + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) + } + result, err := io.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, sampleServerResponse) { + t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) + } + + // Close the listener and make sure that the port is no longer in use. + err = l.Close() + if err != nil { + t.Fatalf("failed to close remote listener: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + var d net.Dialer + _, err = d.DialContext(ctx, "tcp", l.Addr().String()) + if err == nil { + t.Fatalf("expected error connecting to %v but it succeeded", l.Addr().String()) + } +} + +func TestReverseTCPForwardingRespectsCallback(t *testing.T) { + t.Parallel() + + var called int64 + _, client, cleanup := newTestSession(t, &Server{ + Handler: func(s Session) {}, + ReversePortForwardingCallback: func(ctx Context, bindHost string, bindPort uint32) bool { + atomic.AddInt64(&called, 1) + if bindHost != "127.0.0.1" { + panic("unexpected bindHost: " + bindHost) + } + if bindPort != 0 { + panic("unexpected bindPort: " + strconv.Itoa(int(bindPort))) + } + return false + }, + }, nil) + defer cleanup() + + _, err := client.Listen("tcp", "127.0.0.1:0") + if err == nil { + t.Fatalf("Expected error listening on random port but it succeeded") + } + + if atomic.LoadInt64(&called) != 1 { + t.Fatalf("Expected callback to be called once but it was called %d times", called) + } +}