Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocket functionality to httpx #108

Merged
merged 4 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/gorilla/websocket v1.5.1
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
Expand Down
177 changes: 177 additions & 0 deletions httpx/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package httpx

import (
"net/http"
"sync"
"time"

"github.com/gorilla/websocket"
)

const (
// max time for between reading a message before socket is considered closed
maxReadWait = 60 * time.Second

// maximum time to wait for message to be written
maxWriteWait = 15 * time.Second

// how often to send a ping message
pingPeriod = 30 * time.Second
)

var upgrader = websocket.Upgrader{ReadBufferSize: 1024, WriteBufferSize: 1024}

// WebSocket provides a websocket interface similar to that of Javascript.
type WebSocket interface {
// Start begins reading and writing of messages on this socket
Start()

// Send sends the given message over the socket
Send([]byte)

// Close closes the socket connection
Close(int)

// OnMessage is called when the socket receives a message
OnMessage(func([]byte))

// OnClose is called when the socket is closed (even if we initiate the close)
OnClose(func(int))
}

// WebSocket implemention using gorilla library
type socket struct {
conn *websocket.Conn
onMessage func([]byte)
onClose func(int)
outbox chan []byte
readError chan error
writeError chan error
stopWriter chan bool
closingWithCode int
rwWaitGroup sync.WaitGroup
monitorWaitGroup sync.WaitGroup
}

// NewWebSocket creates a new web socket from a regular HTTP request
func NewWebSocket(w http.ResponseWriter, r *http.Request, maxReadBytes int64, sendBuffer int) (WebSocket, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}

conn.SetReadLimit(maxReadBytes)

return &socket{
conn: conn,
onMessage: func([]byte) {},
onClose: func(int) {},
outbox: make(chan []byte, sendBuffer),
readError: make(chan error, 1),
writeError: make(chan error, 1),
stopWriter: make(chan bool, 1),
}, nil
}

func (s *socket) OnMessage(fn func([]byte)) { s.onMessage = fn }
func (s *socket) OnClose(fn func(int)) { s.onClose = fn }

func (s *socket) Start() {
s.conn.SetReadDeadline(time.Now().Add(maxReadWait))
s.conn.SetPongHandler(s.pong)

go s.monitor()
go s.reader()
go s.writer()
}

func (s *socket) Send(msg []byte) {
s.outbox <- msg
}

func (s *socket) Close(code int) {
s.closingWithCode = code
s.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, ""))
s.conn.Close() // causes reader to stop
s.stopWriter <- true

s.monitorWaitGroup.Wait()
}

func (s *socket) pong(m string) error {
s.conn.SetReadDeadline(time.Now().Add(maxReadWait))

return nil
}

func (s *socket) monitor() {
s.monitorWaitGroup.Add(1)
defer s.monitorWaitGroup.Done()

out:
for {
select {
case err := <-s.readError:
if e, ok := err.(*websocket.CloseError); ok && s.closingWithCode == 0 {
s.closingWithCode = e.Code
}
s.stopWriter <- true // ensure writer is stopped
break out
case err := <-s.writeError:
if e, ok := err.(*websocket.CloseError); ok {
s.closingWithCode = e.Code
}
s.conn.Close() // ensure reader is stopped
break out
}
}

s.rwWaitGroup.Wait()

s.onClose(s.closingWithCode)
}

func (s *socket) reader() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()

for {
_, message, err := s.conn.ReadMessage()
if err != nil {
s.readError <- err
return
}

s.onMessage(message)
}
}

func (s *socket) writer() {
s.rwWaitGroup.Add(1)
defer s.rwWaitGroup.Done()

ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()

for {
select {
case msg := <-s.outbox:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

err := s.conn.WriteMessage(websocket.TextMessage, msg)
if err != nil {
s.writeError <- err
return
}
case <-ticker.C:
s.conn.SetWriteDeadline(time.Now().Add(maxWriteWait))

if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
s.writeError <- err
return
}
case <-s.stopWriter:
return
}
}
}
140 changes: 140 additions & 0 deletions httpx/websocket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package httpx_test

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/nyaruka/gocommon/httpx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newSocketServer(t *testing.T, fn func(httpx.WebSocket)) string {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sock, err := httpx.NewWebSocket(w, r, 4096, 5)
require.NoError(t, err)

fn(sock)
}))

return "ws:" + strings.TrimPrefix(s.URL, "http:")
}

func newSocketConnection(t *testing.T, url string) *websocket.Conn {
d := websocket.Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 30 * time.Second,
}
c, _, err := d.Dial(url, nil)
assert.NoError(t, err)
return c
}

func TestSocketMessages(t *testing.T) {
var sock httpx.WebSocket
var serverReceived [][]byte
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnMessage(func(b []byte) {
serverReceived = append(serverReceived, b)
})
sock.OnClose(func(code int) {
serverCloseCode = code
})
sock.Start()
})

conn := newSocketConnection(t, serverURL)

// send a message from the server...
sock.Send([]byte("from server"))

// and read it from the client
msgType, msg, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, 1, msgType)
assert.Equal(t, "from server", string(msg))

// send a message from the client...
conn.WriteMessage(websocket.TextMessage, []byte("to server"))

// and check server received it
time.Sleep(500 * time.Millisecond)
assert.Equal(t, [][]byte{[]byte("to server")}, serverReceived)

pongReceived := false
conn.SetPongHandler(func(appData string) error {
pongReceived = true
return nil
})

// send a ping message from the client...
conn.WriteMessage(websocket.PingMessage, []byte{})

// and give server time to receive it and respond
time.Sleep(500 * time.Millisecond)

// give the connection something to read because ReadMessage will block until it gets a non-ping-pong message
sock.Send([]byte("dummy"))
conn.ReadMessage()

assert.True(t, pongReceived)

var connCloseCode int
conn.SetCloseHandler(func(code int, text string) error {
connCloseCode = code
return nil
})

sock.Close(1001)

conn.ReadMessage() // read the close message

assert.Equal(t, 1001, serverCloseCode)
assert.Equal(t, 1001, connCloseCode)
}

func TestSocketClientCloseWithMessage(t *testing.T) {
var sock httpx.WebSocket
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnClose(func(code int) { serverCloseCode = code })
sock.Start()
})

conn := newSocketConnection(t, serverURL)
conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.ClosePolicyViolation, ""))
conn.Close()

time.Sleep(250 * time.Millisecond)

assert.Equal(t, websocket.ClosePolicyViolation, serverCloseCode)
}

func TestSocketClientCloseWithoutMessage(t *testing.T) {
var sock httpx.WebSocket
var serverCloseCode int

serverURL := newSocketServer(t, func(ws httpx.WebSocket) {
sock = ws
sock.OnClose(func(code int) { serverCloseCode = code })
sock.Start()
})

conn := newSocketConnection(t, serverURL)
conn.Close()

time.Sleep(250 * time.Millisecond)

assert.Equal(t, websocket.CloseAbnormalClosure, serverCloseCode)
}
Loading