Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting websocket connection parameters #361

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ func NewClientUsingGet(endpoint string, httpClient Doer) Client {
// The client does not support queries nor mutations, and will return an error
// if passed a request that attempts one.
func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Header) WebSocketClient {
return NewClientUsingWebSocketWithConnectionParams(endpoint, wsDialer, headers, nil)
}

// NewClientUsingWebSocketWithConnectionParams returns a [WebSocketClient] which makes subscription requests
// to the given endpoint using webSocket. It allows to pass additional connection parameters
// to the server during the initial connection handshake.
//
// connectionParams is a map of connection parameters to be sent to the server
// during the initial connection handshake.
func NewClientUsingWebSocketWithConnectionParams(endpoint string, wsDialer Dialer, headers http.Header, connParams map[string]interface{}) WebSocketClient {
if headers == nil {
headers = http.Header{}
}
Expand All @@ -141,6 +151,7 @@ func NewClientUsingWebSocket(endpoint string, wsDialer Dialer, headers http.Head
return &webSocketClient{
Dialer: wsDialer,
Header: headers,
connParams: connParams,
errChan: make(chan error),
endpoint: endpoint,
subscriptions: subscriptionMap{map_: make(map[string]subscription)},
Expand Down
11 changes: 9 additions & 2 deletions graphql/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,18 @@ type webSocketClient struct {
Header http.Header
endpoint string
conn WSConn
connParams map[string]interface{}
errChan chan error
subscriptions subscriptionMap
isClosing bool
sync.Mutex
}

type webSocketInitMessage struct {
Payload map[string]interface{} `json:"payload"`
Type string `json:"type"`
}

type webSocketSendMessage struct {
Payload *Request `json:"payload"`
Type string `json:"type"`
Expand All @@ -67,8 +73,9 @@ type webSocketReceiveMessage struct {
}

func (w *webSocketClient) sendInit() error {
connInitMsg := webSocketSendMessage{
Type: webSocketTypeConnInit,
connInitMsg := webSocketInitMessage{
Type: webSocketTypeConnInit,
Payload: w.connParams,
}
return w.sendStructAsJSON(connInitMsg)
}
Expand Down
60 changes: 60 additions & 0 deletions internal/integration/generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

83 changes: 82 additions & 1 deletion internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ func TestSubscription(t *testing.T) {

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL)
wsClient := newRoundtripWebSocketClient(t, server.URL, nil)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -144,6 +145,86 @@ func TestSubscription(t *testing.T) {
}
}

func TestSubscriptionConnectionParams(t *testing.T) {
_ = `# @genqlient
subscription countAuthorized { countAuthorized }`

authKey := server.AuthKey

ctx := context.Background()
server := server.RunServer()
defer server.Close()

cases := []struct {
connParams map[string]interface{}
name string
expectedError string
}{
{
name: "authorized_user_gets_counter",
connParams: map[string]interface{}{
authKey: "authorized-user-token",
},
},
{
name: "unauthorized_user_gets_error",
expectedError: "input: countAuthorized unauthorized\n",
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
wsClient := newRoundtripWebSocketClient(t, server.URL, tc.connParams)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)

dataChan, subscriptionID, err := countAuthorized(ctx, wsClient)
require.NoError(t, err)
defer wsClient.Close()

var (
counter = 0
start = time.Now()
)

for loop := true; loop; {
select {
case resp, more := <-dataChan:
if !more {
loop = false
break
}

if tc.expectedError != "" {
require.Error(t, resp.Errors)
assert.Equal(t, tc.expectedError, resp.Errors.Error())
continue
}

require.NotNil(t, resp.Data)
assert.Equal(t, counter, resp.Data.CountAuthorized)
require.Nil(t, resp.Errors)

if time.Since(start) > 5*time.Second {
err := wsClient.Unsubscribe(subscriptionID)
require.NoError(t, err)
loop = false
}

counter++

case err := <-errChan:
require.NoError(t, err)

case <-time.After(10 * time.Second):
require.NoError(t, fmt.Errorf("subscription timed out"))
}
}
})
}
}

func TestServerError(t *testing.T) {
_ = `# @genqlient
query failingQuery { fail me { id } }`
Expand Down
11 changes: 8 additions & 3 deletions internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,19 @@ func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeade
return graphql.WSConn(conn), err
}

func newRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient {
func newRoundtripWebSocketClient(t *testing.T, endpoint string, connectionParams map[string]interface{}) graphql.WebSocketClient {
dialer := websocket.DefaultDialer
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}
return &roundtripClient{
wsWrapped: graphql.NewClientUsingWebSocket(endpoint, &MyDialer{Dialer: dialer}, nil),
t: t,
wsWrapped: graphql.NewClientUsingWebSocketWithConnectionParams(
endpoint,
&MyDialer{Dialer: dialer},
nil,
connectionParams,
),
t: t,
}
}
1 change: 1 addition & 0 deletions internal/integration/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Mutation {

type Subscription {
count: Int!
countAuthorized: Int!
}

type User implements Being & Lucky {
Expand Down
71 changes: 70 additions & 1 deletion internal/integration/server/gqlgen_exec.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading