From f7a36fbd9ee970bf2cf1295069edfbcffe314648 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Thu, 11 Jan 2024 16:06:03 -0500 Subject: [PATCH 1/4] Add websocket struct to httpx --- go.mod | 1 + go.sum | 2 + httpx/websocket.go | 174 ++++++++++++++++++++++++++++++++++++++++ httpx/websocket_test.go | 69 ++++++++++++++++ 4 files changed, 246 insertions(+) create mode 100644 httpx/websocket.go create mode 100644 httpx/websocket_test.go diff --git a/go.mod b/go.mod index 741e606..2ab20a8 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/gorilla/websocket v1.5.1 github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/leodido/go-urn v1.2.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/go.sum b/go.sum index d05d53d..556029a 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= diff --git a/httpx/websocket.go b/httpx/websocket.go new file mode 100644 index 0000000..147a75a --- /dev/null +++ b/httpx/websocket.go @@ -0,0 +1,174 @@ +package httpx + +import ( + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + // max time for between reading a message before socket is considered closed + maxReadWait = 60 * time.Second + + // maximum time to wait for message to be written + maxWriteWait = 15 * time.Second + + // how often to send a ping message + pingPeriod = 30 * time.Second +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +type WebSocket interface { + Start() + Send(msg []byte) + Close() + + OnMessage(fn func([]byte)) + OnClose(fn func(int)) +} + +// Socket implemention using gorilla library +type socket struct { + conn *websocket.Conn + onMessage func([]byte) + onClose func(int) + outbox chan []byte + readError chan error + writeError chan error + stopWriter chan bool + stopMonitor chan bool + rwWaitGroup sync.WaitGroup + monitorWaitGroup sync.WaitGroup +} + +func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, err + } + + conn.SetReadLimit(maxReadBytes) + + return &socket{ + conn: conn, + onMessage: func([]byte) {}, + onClose: func(int) {}, + outbox: make(chan []byte, sendBuffer), + readError: make(chan error, 1), + writeError: make(chan error, 1), + stopWriter: make(chan bool, 1), + stopMonitor: make(chan bool, 1), + }, nil +} + +func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn } +func (s *socket) OnClose(fn func(int)) { s.onClose = fn } + +func (s *socket) Start() { + s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) + s.conn.SetPongHandler(s.pong) + + go s.monitor() + go s.reader() + go s.writer() +} + +func (s *socket) Send(msg []byte) { + s.outbox <- msg +} + +func (s *socket) Close() { + s.conn.Close() // causes reader to stop + s.stopWriter <- true + s.stopMonitor <- true + + s.monitorWaitGroup.Wait() +} + +func (s *socket) pong(m string) error { + s.conn.SetReadDeadline(time.Now().Add(maxReadWait)) + + return nil +} + +func (s *socket) monitor() { + s.monitorWaitGroup.Add(1) + defer s.monitorWaitGroup.Done() + + closeCode := websocket.CloseNormalClosure + +out: + for { + select { + case err := <-s.readError: + if e, ok := err.(*websocket.CloseError); ok { + closeCode = e.Code + } + s.stopWriter <- true // ensure writer is stopped + break out + case err := <-s.writeError: + if e, ok := err.(*websocket.CloseError); ok { + closeCode = e.Code + } + s.conn.Close() // ensure reader is stopped + break out + case <-s.stopMonitor: + break out + } + } + + s.rwWaitGroup.Wait() + + s.onClose(closeCode) +} + +func (s *socket) reader() { + s.rwWaitGroup.Add(1) + defer s.rwWaitGroup.Done() + + for { + _, message, err := s.conn.ReadMessage() + if err != nil { + s.readError <- err + return + } + + s.onMessage(message) + } +} + +func (s *socket) writer() { + s.rwWaitGroup.Add(1) + defer s.rwWaitGroup.Done() + + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + for { + select { + case msg := <-s.outbox: + s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) + + err := s.conn.WriteMessage(websocket.TextMessage, msg) + if err != nil { + s.writeError <- err + return + } + case <-ticker.C: + s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait)) + + if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + s.writeError <- err + return + } + case <-s.stopWriter: + return + } + } +} diff --git a/httpx/websocket_test.go b/httpx/websocket_test.go new file mode 100644 index 0000000..714ccda --- /dev/null +++ b/httpx/websocket_test.go @@ -0,0 +1,69 @@ +package httpx_test + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/nyaruka/gocommon/httpx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSocket(t *testing.T) { + var sock httpx.WebSocket + var err error + + var serverReceived [][]byte + var serverCloseCode int + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sock, err = httpx.NewWebSocket(w, r, 4096, 5) + + sock.OnMessage(func(b []byte) { + serverReceived = append(serverReceived, b) + }) + sock.OnClose(func(code int) { + serverCloseCode = code + }) + + sock.Start() + + require.NoError(t, err) + })) + + wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:") + + d := websocket.Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, + } + conn, _, err := d.Dial(wsURL, nil) + assert.NoError(t, err) + assert.NotNil(t, conn) + + // send a message from the server... + sock.Send([]byte("from server")) + + // and read it from the client + msgType, msg, err := conn.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, 1, msgType) + assert.Equal(t, "from server", string(msg)) + + // send a message from the client... + conn.WriteMessage(websocket.TextMessage, []byte("to server")) + + // and check server received it + time.Sleep(time.Second) + assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived) + + sock.Close() + + assert.Equal(t, 1000, serverCloseCode) +} From fdb6d17f557f5ee06e04e373b7c74e2db107e89f Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Fri, 12 Jan 2024 10:29:39 -0500 Subject: [PATCH 2/4] Tests and fixes --- httpx/websocket.go | 44 ++++++++--------- httpx/websocket_test.go | 101 ++++++++++++++++++++++++++++++---------- 2 files changed, 96 insertions(+), 49 deletions(-) diff --git a/httpx/websocket.go b/httpx/websocket.go index 147a75a..e0ece26 100644 --- a/httpx/websocket.go +++ b/httpx/websocket.go @@ -19,15 +19,12 @@ const ( pingPeriod = 30 * time.Second ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} +var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} type WebSocket interface { Start() - Send(msg []byte) - Close() + Send([]byte) + Close(int) OnMessage(fn func([]byte)) OnClose(fn func(int)) @@ -42,11 +39,12 @@ type socket struct { readError chan error writeError chan error stopWriter chan bool - stopMonitor chan bool + closingWithCode int rwWaitGroup sync.WaitGroup monitorWaitGroup sync.WaitGroup } +// NewWebSocket creates a new web socket from a regular HTTP request func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -56,14 +54,13 @@ func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, se conn.SetReadLimit(maxReadBytes) return &socket{ - conn: conn, - onMessage: func([]byte) {}, - onClose: func(int) {}, - outbox: make(chan []byte, sendBuffer), - readError: make(chan error, 1), - writeError: make(chan error, 1), - stopWriter: make(chan bool, 1), - stopMonitor: make(chan bool, 1), + conn: conn, + onMessage: func([]byte) {}, + onClose: func(int) {}, + outbox: make(chan []byte, sendBuffer), + readError: make(chan error, 1), + writeError: make(chan error, 1), + stopWriter: make(chan bool, 1), }, nil } @@ -83,10 +80,11 @@ func (s *socket) Send(msg []byte) { s.outbox <- msg } -func (s *socket) Close() { +func (s *socket) Close(code int) { + s.closingWithCode = code + s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, "")) s.conn.Close() // causes reader to stop s.stopWriter <- true - s.stopMonitor <- true s.monitorWaitGroup.Wait() } @@ -101,31 +99,27 @@ func (s *socket) monitor() { s.monitorWaitGroup.Add(1) defer s.monitorWaitGroup.Done() - closeCode := websocket.CloseNormalClosure - out: for { select { case err := <-s.readError: - if e, ok := err.(*websocket.CloseError); ok { - closeCode = e.Code + if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 { + s.closingWithCode = e.Code } s.stopWriter <- true // ensure writer is stopped break out case err := <-s.writeError: if e, ok := err.(*websocket.CloseError); ok { - closeCode = e.Code + s.closingWithCode = e.Code } s.conn.Close() // ensure reader is stopped break out - case <-s.stopMonitor: - break out } } s.rwWaitGroup.Wait() - s.onClose(closeCode) + s.onClose(s.closingWithCode) } func (s *socket) reader() { diff --git a/httpx/websocket_test.go b/httpx/websocket_test.go index 714ccda..a78c119 100644 --- a/httpx/websocket_test.go +++ b/httpx/websocket_test.go @@ -13,39 +13,46 @@ import ( "github.com/stretchr/testify/require" ) -func TestSocket(t *testing.T) { - var sock httpx.WebSocket - var err error +func newSocketServer(t *testing.T, fn func(httpx.WebSocket)) string { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sock, err := httpx.NewWebSocket(w, r, 4096, 5) + require.NoError(t, err) + + fn(sock) + })) + return "ws:" + strings.TrimPrefix(s.URL, "http:") +} + +func newSocketConnection(t *testing.T, url string) *websocket.Conn { + d := websocket.Dialer{ + Subprotocols: []string{"p1", "p2"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + HandshakeTimeout: 30 * time.Second, + } + c, _, err := d.Dial(url, nil) + assert.NoError(t, err) + return c +} + +func TestSocketMessages(t *testing.T) { + var sock httpx.WebSocket var serverReceived [][]byte var serverCloseCode int - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - sock, err = httpx.NewWebSocket(w, r, 4096, 5) - + serverURL := newSocketServer(t, func(ws httpx.WebSocket) { + sock = ws sock.OnMessage(func(b []byte) { serverReceived = append(serverReceived, b) }) sock.OnClose(func(code int) { serverCloseCode = code }) - sock.Start() + }) - require.NoError(t, err) - })) - - wsURL := "ws:" + strings.TrimPrefix(server.URL, "http:") - - d := websocket.Dialer{ - Subprotocols: []string{"p1", "p2"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, - HandshakeTimeout: 30 * time.Second, - } - conn, _, err := d.Dial(wsURL, nil) - assert.NoError(t, err) - assert.NotNil(t, conn) + conn := newSocketConnection(t, serverURL) // send a message from the server... sock.Send([]byte("from server")) @@ -60,10 +67,56 @@ func TestSocket(t *testing.T) { conn.WriteMessage(websocket.TextMessage, []byte("to server")) // and check server received it - time.Sleep(time.Second) + time.Sleep(500 * time.Millisecond) assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived) - sock.Close() + var connCloseCode int + conn.SetCloseHandler(func(code int, text string) error { + connCloseCode = code + return nil + }) + + sock.Close(1001) + + conn.ReadMessage() // read the close message + + assert.Equal(t, 1001, serverCloseCode) + assert.Equal(t, 1001, connCloseCode) +} + +func TestSocketClientCloseWithMessage(t *testing.T) { + var sock httpx.WebSocket + var serverCloseCode int + + serverURL := newSocketServer(t, func(ws httpx.WebSocket) { + sock = ws + sock.OnClose(func(code int) { serverCloseCode = code }) + sock.Start() + }) + + conn := newSocketConnection(t, serverURL) + conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "")) + conn.Close() + + time.Sleep(250 * time.Millisecond) + + assert.Equal(t, websocket.ClosePolicyViolation, serverCloseCode) +} + +func TestSocketClientCloseWithoutMessage(t *testing.T) { + var sock httpx.WebSocket + var serverCloseCode int + + serverURL := newSocketServer(t, func(ws httpx.WebSocket) { + sock = ws + sock.OnClose(func(code int) { serverCloseCode = code }) + sock.Start() + }) + + conn := newSocketConnection(t, serverURL) + conn.Close() + + time.Sleep(250 * time.Millisecond) - assert.Equal(t, 1000, serverCloseCode) + assert.Equal(t, websocket.CloseAbnormalClosure, serverCloseCode) } From 366930ebd935e892286cf1f47239392518b03a1e Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Fri, 12 Jan 2024 10:34:57 -0500 Subject: [PATCH 3/4] More comments in code --- httpx/websocket.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/httpx/websocket.go b/httpx/websocket.go index e0ece26..9af3521 100644 --- a/httpx/websocket.go +++ b/httpx/websocket.go @@ -21,16 +21,25 @@ const ( var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024} +// WebSocket provides a websocket interface similar to that of Javascript. type WebSocket interface { + // Start begins reading and writing of messages on this socket Start() + + // Send sends the given message over the socket Send([]byte) + + // Close closes the socket connection Close(int) - OnMessage(fn func([]byte)) - OnClose(fn func(int)) + // OnMessage is called when the socket receives a message + OnMessage(func([]byte)) + + // OnClose is called when the socket is closed (even if we initiate the close) + OnClose(func(int)) } -// Socket implemention using gorilla library +// WebSocket implemention using gorilla library type socket struct { conn *websocket.Conn onMessage func([]byte) From 35142c6c6615673bf2f29d9ec336df576756f680 Mon Sep 17 00:00:00 2001 From: Rowan Seymour Date: Fri, 12 Jan 2024 10:56:29 -0500 Subject: [PATCH 4/4] Add test for ping handling --- httpx/websocket_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/httpx/websocket_test.go b/httpx/websocket_test.go index a78c119..80ec40a 100644 --- a/httpx/websocket_test.go +++ b/httpx/websocket_test.go @@ -70,6 +70,24 @@ func TestSocketMessages(t *testing.T) { time.Sleep(500 * time.Millisecond) assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived) + pongReceived := false + conn.SetPongHandler(func(appData string) error { + pongReceived = true + return nil + }) + + // send a ping message from the client... + conn.WriteMessage(websocket.PingMessage, []byte{}) + + // and give server time to receive it and respond + time.Sleep(500 * time.Millisecond) + + // give the connection something to read because ReadMessage will block until it gets a non-ping-pong message + sock.Send([]byte("dummy")) + conn.ReadMessage() + + assert.True(t, pongReceived) + var connCloseCode int conn.SetCloseHandler(func(code int, text string) error { connCloseCode = code