Skip to content

Commit

Permalink
Propagate half-closes correctly in forward
Browse files Browse the repository at this point in the history
Before, the following would not work as you would expect:

```go
// # One terminal
// $ ncat --recv-only -l 9090

// ngrok-go code
fwd, err := sess.ListenAndForward(
  ctx,
  "127.0.0.1:9090",
  config.TCPEndpoint(),
)

// fwd.URL() is 0.tcp.jp.ngrok.io:14517 for this example

// another terminal
// $ ncat --send-only 0.tcp.jp.ngrok.io 14517 < hello-world.txt
```

What we would expect from the above would be for the send side to send
"hello world" and exit, and then the recv side to print "hello world"
and also exit.

This is what happens if you do `ncat --send-only localhost 9090`
instead of copying through the ngrok tcp tunnel.

Before this change, when copying through ngrok the recv side would not
exit because the 'Close' of the connection did not get propagated
through the 'join'.

I've also added a unit test showing this.

Thank you to @abakum for originally noticing this issue and offering a
fix over in #137.
In the hopes of landing this more quickly, I've written a new version, derived from
the internal ngrok agent's join code, which should thus be easier to
review etc.
  • Loading branch information
abakum authored and euank committed May 28, 2024
1 parent 79dc3b2 commit 981f6b2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 8 deletions.
20 changes: 12 additions & 8 deletions forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,20 @@ func (fwd *forwarder) Wait() error {
// compile-time check that we're implementing the proper interface
var _ Forwarder = (*forwarder)(nil)

func join(ctx context.Context, left, right io.ReadWriter) {
func join(logger log15.Logger, left, right net.Conn) {
g := &sync.WaitGroup{}
g.Add(2)
go func() {
_, _ = io.Copy(left, right)
g.Done()
defer g.Done()
defer left.Close()
n, err := io.Copy(left, right)
logger.Debug("left join finished", "err", err, "bytes", n)
}()
go func() {
_, _ = io.Copy(right, left)
g.Done()
defer g.Done()
defer right.Close()
n, err := io.Copy(right, left)
logger.Debug("right join finished", "err", err, "bytes", n)
}()
g.Wait()
}
Expand All @@ -85,21 +89,21 @@ func forwardTunnel(ctx context.Context, tun Tunnel, url *url.URL) Forwarder {
if err != nil {
return err
}
logger.Debug("accept connection from", "address", conn.RemoteAddr())
fwdTasks.Add(1)

go func() {
ngrokConn := conn.(Conn)
defer ngrokConn.Close()

backend, err := openBackend(ctx, logger, tun, ngrokConn, url)
if err != nil {
defer ngrokConn.Close()
logger.Warn("failed to connect to backend url", "error", err)
fwdTasks.Done()
return
}

defer backend.Close()
join(ctx, ngrokConn, backend)
join(logger.New("url", url), ngrokConn, backend)
fwdTasks.Done()
}()
}
Expand Down
46 changes: 46 additions & 0 deletions forward_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package ngrok

import (
"errors"
"io"
"net"
"testing"

"github.com/inconshreveable/log15/v3"
"github.com/stretchr/testify/require"
)

func TestHalfCloseJoin(t *testing.T) {
srv, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

waitSrvConn := make(chan net.Conn)
go func() {
srvConn, err := srv.Accept()
if err != nil {
panic(err)
}
waitSrvConn <- srvConn
}()

browser, ngrokEndpoint := net.Pipe()
agent, userService := net.Pipe()

waitJoinDone := make(chan struct{})
go func() {
defer close(waitJoinDone)
join(log15.New(), ngrokEndpoint, agent)
}()

_, err = browser.Write([]byte("hello world"))
require.NoError(t, err)
var b [len("hello world")]byte
_, err = userService.Read(b[:])
require.NoError(t, err)
require.Equal(t, []byte("hello world"), b[:])
browser.Close()
_, err = userService.Read(b[:])
require.Truef(t, errors.Is(err, io.EOF), "io.EOF expected, got %v", err)

<-waitJoinDone
}

0 comments on commit 981f6b2

Please sign in to comment.