Skip to content

Commit

Permalink
use coder/websocket for everything, get rid of gobwas.
Browse files Browse the repository at this point in the history
supposedly it is faster, and anyway it's better to use it since we're already using it for wasm/js.

(previously named nhooyr/websocket).
  • Loading branch information
fiatjaf committed Jan 3, 2025
1 parent b33cfb1 commit defc349
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 243 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ To use it, use `-tags=libsecp256k1` whenever you're compiling your program that
Install [wasmbrowsertest](https://github.com/agnivade/wasmbrowsertest), then run tests:

```sh
TEST_RELAY_URL=<relay_url> GOOS=js GOARCH=wasm go test -short ./...
GOOS=js GOARCH=wasm go test -short ./...
```

## Warning: risk of goroutine bloat (if used incorrectly)
Expand Down
162 changes: 14 additions & 148 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,187 +1,53 @@
//go:build !js

package nostr

import (
"bytes"
"compress/flate"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"

"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
"github.com/gobwas/ws/wsutil"
ws "github.com/coder/websocket"
)

type Connection struct {
conn net.Conn
enableCompression bool
controlHandler wsutil.FrameHandlerFunc
flateReader *wsflate.Reader
reader *wsutil.Reader
flateWriter *wsflate.Writer
writer *wsutil.Writer
msgStateR *wsflate.MessageState
msgStateW *wsflate.MessageState
conn *ws.Conn
}

func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) {
dialer := ws.Dialer{
Header: ws.HandshakeHeaderHTTP(requestHeader),
Extensions: []httphead.Option{
wsflate.DefaultParameters.Option(),
},
TLSConfig: tlsConfig,
}
conn, _, hs, err := dialer.Dial(ctx, url)
c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig))
if err != nil {
return nil, fmt.Errorf("failed to dial: %w", err)
}

enableCompression := false
state := ws.StateClientSide
for _, extension := range hs.Extensions {
if string(extension.Name) == wsflate.ExtensionName {
enableCompression = true
state |= ws.StateExtended
break
}
}

// reader
var flateReader *wsflate.Reader
var msgStateR wsflate.MessageState
if enableCompression {
msgStateR.SetCompressed(true)

flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor {
return flate.NewReader(r)
})
}

controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide)
reader := &wsutil.Reader{
Source: conn,
State: state,
OnIntermediate: controlHandler,
CheckUTF8: false,
Extensions: []wsutil.RecvExtension{
&msgStateR,
},
}

// writer
var flateWriter *wsflate.Writer
var msgStateW wsflate.MessageState
if enableCompression {
msgStateW.SetCompressed(true)

flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor {
fw, err := flate.NewWriter(w, 4)
if err != nil {
InfoLogger.Printf("Failed to create flate writer: %v", err)
}
return fw
})
return nil, err
}

writer := wsutil.NewWriter(conn, state, ws.OpText)
writer.SetExtensions(&msgStateW)

return &Connection{
conn: conn,
enableCompression: enableCompression,
controlHandler: controlHandler,
flateReader: flateReader,
reader: reader,
msgStateR: &msgStateR,
flateWriter: flateWriter,
writer: writer,
msgStateW: &msgStateW,
conn: c,
}, nil
}

func (c *Connection) WriteMessage(ctx context.Context, data []byte) error {
select {
case <-ctx.Done():
return errors.New("context canceled")
default:
}

if c.msgStateW.IsCompressed() && c.enableCompression {
c.flateWriter.Reset(c.writer)
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}

if err := c.flateWriter.Close(); err != nil {
return fmt.Errorf("failed to close flate writer: %w", err)
}
} else {
if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}
}

if err := c.writer.Flush(); err != nil {
return fmt.Errorf("failed to flush writer: %w", err)
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil {
return fmt.Errorf("failed to write message: %w", err)
}

return nil
}

func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error {
for {
select {
case <-ctx.Done():
return errors.New("context canceled")
default:
}

h, err := c.reader.NextFrame()
if err != nil {
c.conn.Close()
return fmt.Errorf("failed to advance frame: %w", err)
}

if h.OpCode.IsControl() {
if err := c.controlHandler(h, c.reader); err != nil {
return fmt.Errorf("failed to handle control frame: %w", err)
}
} else if h.OpCode == ws.OpBinary ||
h.OpCode == ws.OpText {
break
}

if err := c.reader.Discard(); err != nil {
return fmt.Errorf("failed to discard: %w", err)
}
_, reader, err := c.conn.Reader(ctx)
if err != nil {
return fmt.Errorf("failed to get reader: %w", err)
}

if c.msgStateR.IsCompressed() && c.enableCompression {
c.flateReader.Reset(c.reader)
if _, err := io.Copy(buf, c.flateReader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
} else {
if _, err := io.Copy(buf, c.reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}
if _, err := io.Copy(buf, reader); err != nil {
return fmt.Errorf("failed to read message: %w", err)
}

return nil
}

func (c *Connection) Close() error {
return c.conn.Close()
return c.conn.Close(ws.StatusNormalClosure, "")
}

func (c *Connection) Ping(ctx context.Context) error {
return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil)
return c.conn.Ping(ctx)
}
55 changes: 0 additions & 55 deletions connection_js.go

This file was deleted.

34 changes: 34 additions & 0 deletions connection_options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//go:build !js

package nostr

import (
"crypto/tls"
"net/http"
"net/textproto"

ws "github.com/coder/websocket"
)

var defaultConnectionOptions = &ws.DialOptions{
CompressionMode: ws.CompressionContextTakeover,
HTTPHeader: http.Header{
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"},
},
}

func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
if requestHeader == nil && tlsConfig == nil {
return defaultConnectionOptions
}

return &ws.DialOptions{
HTTPHeader: requestHeader,
CompressionMode: ws.CompressionContextTakeover,
HTTPClient: &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
},
}
}
15 changes: 15 additions & 0 deletions connection_options_js.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package nostr

import (
"crypto/tls"
"net/http"

ws "github.com/coder/websocket"
)

var emptyOptions = ws.DialOptions{}

func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions {
// on javascript we ignore everything because there is nothing else we can do
return &emptyOptions
}
29 changes: 15 additions & 14 deletions pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log"
"math"
"net/http"
"slices"
"strings"
"sync"
Expand All @@ -31,7 +32,7 @@ type SimplePool struct {
// custom things not often used
penaltyBoxMu sync.Mutex
penaltyBox map[string][2]float64
userAgent string
relayOptions []RelayOption
}

type DirectedFilters struct {
Expand Down Expand Up @@ -69,6 +70,17 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool {
return pool
}

// WithRelayOptions sets options that will be used on every relay instance created by this pool.
func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt {
return ropts
}

type withRelayOptionsOpt []RelayOption

func (h withRelayOptionsOpt) ApplyPoolOption(pool *SimplePool) {
pool.relayOptions = h
}

// WithAuthHandler must be a function that signs the auth event when called.
// it will be called whenever any relay in the pool returns a `CLOSED` message
// with the "auth-required:" prefix, only once for each relay
Expand Down Expand Up @@ -129,20 +141,11 @@ func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *SimplePool) {
pool.queryMiddleware = h
}

// WithUserAgent sets the user-agent header for all relay connections in the pool.
func WithUserAgent(userAgent string) withUserAgentOpt { return withUserAgentOpt(userAgent) }

type withUserAgentOpt string

func (h withUserAgentOpt) ApplyPoolOption(pool *SimplePool) {
pool.userAgent = string(h)
}

var (
_ PoolOption = (WithAuthHandler)(nil)
_ PoolOption = (WithEventMiddleware)(nil)
_ PoolOption = WithPenaltyBox()
_ PoolOption = WithUserAgent("")
_ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{}))
)

func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
Expand All @@ -169,9 +172,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) {
ctx, cancel := context.WithTimeout(pool.Context, time.Second*15)
defer cancel()

relay = NewRelay(context.Background(), url)
relay.RequestHeader.Set("User-Agent", pool.userAgent)

relay = NewRelay(context.Background(), url, pool.relayOptions...)
if err := relay.Connect(ctx); err != nil {
if pool.penaltyBox != nil {
// putting relay in penalty box
Expand Down
Loading

0 comments on commit defc349

Please sign in to comment.