diff --git a/README.md b/README.md index d81eb60..96ec4a1 100644 --- a/README.md +++ b/README.md @@ -150,7 +150,7 @@ To use it, use `-tags=libsecp256k1` whenever you're compiling your program that Install [wasmbrowsertest](https://github.com/agnivade/wasmbrowsertest), then run tests: ```sh -TEST_RELAY_URL= GOOS=js GOARCH=wasm go test -short ./... +GOOS=js GOARCH=wasm go test -short ./... ``` ## Warning: risk of goroutine bloat (if used incorrectly) diff --git a/connection.go b/connection.go index 0d08c7d..dee35f2 100644 --- a/connection.go +++ b/connection.go @@ -1,187 +1,53 @@ -//go:build !js - package nostr import ( - "bytes" - "compress/flate" "context" "crypto/tls" - "errors" "fmt" "io" - "net" "net/http" - "github.com/gobwas/httphead" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsflate" - "github.com/gobwas/ws/wsutil" + ws "github.com/coder/websocket" ) type Connection struct { - conn net.Conn - enableCompression bool - controlHandler wsutil.FrameHandlerFunc - flateReader *wsflate.Reader - reader *wsutil.Reader - flateWriter *wsflate.Writer - writer *wsutil.Writer - msgStateR *wsflate.MessageState - msgStateW *wsflate.MessageState + conn *ws.Conn } func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) { - dialer := ws.Dialer{ - Header: ws.HandshakeHeaderHTTP(requestHeader), - Extensions: []httphead.Option{ - wsflate.DefaultParameters.Option(), - }, - TLSConfig: tlsConfig, - } - conn, _, hs, err := dialer.Dial(ctx, url) + c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig)) if err != nil { - return nil, fmt.Errorf("failed to dial: %w", err) - } - - enableCompression := false - state := ws.StateClientSide - for _, extension := range hs.Extensions { - if string(extension.Name) == wsflate.ExtensionName { - enableCompression = true - state |= ws.StateExtended - break - } - } - - // reader - var flateReader *wsflate.Reader - var msgStateR wsflate.MessageState - if enableCompression { - msgStateR.SetCompressed(true) - - flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { - return flate.NewReader(r) - }) - } - - controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide) - reader := &wsutil.Reader{ - Source: conn, - State: state, - OnIntermediate: controlHandler, - CheckUTF8: false, - Extensions: []wsutil.RecvExtension{ - &msgStateR, - }, - } - - // writer - var flateWriter *wsflate.Writer - var msgStateW wsflate.MessageState - if enableCompression { - msgStateW.SetCompressed(true) - - flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { - fw, err := flate.NewWriter(w, 4) - if err != nil { - InfoLogger.Printf("Failed to create flate writer: %v", err) - } - return fw - }) + return nil, err } - writer := wsutil.NewWriter(conn, state, ws.OpText) - writer.SetExtensions(&msgStateW) - return &Connection{ - conn: conn, - enableCompression: enableCompression, - controlHandler: controlHandler, - flateReader: flateReader, - reader: reader, - msgStateR: &msgStateR, - flateWriter: flateWriter, - writer: writer, - msgStateW: &msgStateW, + conn: c, }, nil } func (c *Connection) WriteMessage(ctx context.Context, data []byte) error { - select { - case <-ctx.Done(): - return errors.New("context canceled") - default: - } - - if c.msgStateW.IsCompressed() && c.enableCompression { - c.flateWriter.Reset(c.writer) - if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { - return fmt.Errorf("failed to write message: %w", err) - } - - if err := c.flateWriter.Close(); err != nil { - return fmt.Errorf("failed to close flate writer: %w", err) - } - } else { - if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil { - return fmt.Errorf("failed to write message: %w", err) - } - } - - if err := c.writer.Flush(); err != nil { - return fmt.Errorf("failed to flush writer: %w", err) + if err := c.conn.Write(ctx, ws.MessageText, data); err != nil { + return fmt.Errorf("failed to write message: %w", err) } return nil } func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { - for { - select { - case <-ctx.Done(): - return errors.New("context canceled") - default: - } - - h, err := c.reader.NextFrame() - if err != nil { - c.conn.Close() - return fmt.Errorf("failed to advance frame: %w", err) - } - - if h.OpCode.IsControl() { - if err := c.controlHandler(h, c.reader); err != nil { - return fmt.Errorf("failed to handle control frame: %w", err) - } - } else if h.OpCode == ws.OpBinary || - h.OpCode == ws.OpText { - break - } - - if err := c.reader.Discard(); err != nil { - return fmt.Errorf("failed to discard: %w", err) - } + _, reader, err := c.conn.Reader(ctx) + if err != nil { + return fmt.Errorf("failed to get reader: %w", err) } - - if c.msgStateR.IsCompressed() && c.enableCompression { - c.flateReader.Reset(c.reader) - if _, err := io.Copy(buf, c.flateReader); err != nil { - return fmt.Errorf("failed to read message: %w", err) - } - } else { - if _, err := io.Copy(buf, c.reader); err != nil { - return fmt.Errorf("failed to read message: %w", err) - } + if _, err := io.Copy(buf, reader); err != nil { + return fmt.Errorf("failed to read message: %w", err) } - return nil } func (c *Connection) Close() error { - return c.conn.Close() + return c.conn.Close(ws.StatusNormalClosure, "") } func (c *Connection) Ping(ctx context.Context) error { - return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil) + return c.conn.Ping(ctx) } diff --git a/connection_js.go b/connection_js.go deleted file mode 100644 index ddcfee0..0000000 --- a/connection_js.go +++ /dev/null @@ -1,55 +0,0 @@ -//go:build js - -package nostr - -import ( - "context" - "crypto/tls" - "fmt" - "io" - "net/http" - - ws "github.com/coder/websocket" -) - -type Connection struct { - conn *ws.Conn -} - -func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) { - c, _, err := ws.Dial(ctx, url, nil) - if err != nil { - return nil, err - } - - return &Connection{ - conn: c, - }, nil -} - -func (c *Connection) WriteMessage(ctx context.Context, data []byte) error { - if err := c.conn.Write(ctx, ws.MessageBinary, data); err != nil { - return fmt.Errorf("failed to write message: %w", err) - } - - return nil -} - -func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { - _, reader, err := c.conn.Reader(ctx) - if err != nil { - return fmt.Errorf("failed to get reader: %w", err) - } - if _, err := io.Copy(buf, reader); err != nil { - return fmt.Errorf("failed to read message: %w", err) - } - return nil -} - -func (c *Connection) Close() error { - return c.conn.Close(ws.StatusNormalClosure, "") -} - -func (c *Connection) Ping(ctx context.Context) error { - return c.conn.Ping(ctx) -} diff --git a/connection_options.go b/connection_options.go new file mode 100644 index 0000000..69a8bf9 --- /dev/null +++ b/connection_options.go @@ -0,0 +1,34 @@ +//go:build !js + +package nostr + +import ( + "crypto/tls" + "net/http" + "net/textproto" + + ws "github.com/coder/websocket" +) + +var defaultConnectionOptions = &ws.DialOptions{ + CompressionMode: ws.CompressionContextTakeover, + HTTPHeader: http.Header{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"}, + }, +} + +func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions { + if requestHeader == nil && tlsConfig == nil { + return defaultConnectionOptions + } + + return &ws.DialOptions{ + HTTPHeader: requestHeader, + CompressionMode: ws.CompressionContextTakeover, + HTTPClient: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + }, + } +} diff --git a/connection_options_js.go b/connection_options_js.go new file mode 100644 index 0000000..f40ec86 --- /dev/null +++ b/connection_options_js.go @@ -0,0 +1,15 @@ +package nostr + +import ( + "crypto/tls" + "net/http" + + ws "github.com/coder/websocket" +) + +var emptyOptions = ws.DialOptions{} + +func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions { + // on javascript we ignore everything because there is nothing else we can do + return &emptyOptions +} diff --git a/pool.go b/pool.go index b5509b7..173b5f0 100644 --- a/pool.go +++ b/pool.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "math" + "net/http" "slices" "strings" "sync" @@ -31,7 +32,7 @@ type SimplePool struct { // custom things not often used penaltyBoxMu sync.Mutex penaltyBox map[string][2]float64 - userAgent string + relayOptions []RelayOption } type DirectedFilters struct { @@ -69,6 +70,17 @@ func NewSimplePool(ctx context.Context, opts ...PoolOption) *SimplePool { return pool } +// WithRelayOptions sets options that will be used on every relay instance created by this pool. +func WithRelayOptions(ropts ...RelayOption) withRelayOptionsOpt { + return ropts +} + +type withRelayOptionsOpt []RelayOption + +func (h withRelayOptionsOpt) ApplyPoolOption(pool *SimplePool) { + pool.relayOptions = h +} + // WithAuthHandler must be a function that signs the auth event when called. // it will be called whenever any relay in the pool returns a `CLOSED` message // with the "auth-required:" prefix, only once for each relay @@ -129,20 +141,11 @@ func (h WithAuthorKindQueryMiddleware) ApplyPoolOption(pool *SimplePool) { pool.queryMiddleware = h } -// WithUserAgent sets the user-agent header for all relay connections in the pool. -func WithUserAgent(userAgent string) withUserAgentOpt { return withUserAgentOpt(userAgent) } - -type withUserAgentOpt string - -func (h withUserAgentOpt) ApplyPoolOption(pool *SimplePool) { - pool.userAgent = string(h) -} - var ( _ PoolOption = (WithAuthHandler)(nil) _ PoolOption = (WithEventMiddleware)(nil) _ PoolOption = WithPenaltyBox() - _ PoolOption = WithUserAgent("") + _ PoolOption = WithRelayOptions(WithRequestHeader(http.Header{})) ) func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { @@ -169,9 +172,7 @@ func (pool *SimplePool) EnsureRelay(url string) (*Relay, error) { ctx, cancel := context.WithTimeout(pool.Context, time.Second*15) defer cancel() - relay = NewRelay(context.Background(), url) - relay.RequestHeader.Set("User-Agent", pool.userAgent) - + relay = NewRelay(context.Background(), url, pool.relayOptions...) if err := relay.Connect(ctx); err != nil { if pool.penaltyBox != nil { // putting relay in penalty box diff --git a/relay.go b/relay.go index 795b4ab..ca23f38 100644 --- a/relay.go +++ b/relay.go @@ -23,7 +23,7 @@ type Relay struct { closeMutex sync.Mutex URL string - RequestHeader http.Header // e.g. for origin header + requestHeader http.Header // e.g. for origin header Connection *Connection Subscriptions *xsync.MapOf[int64, *Subscription] @@ -60,7 +60,7 @@ func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay { okCallbacks: xsync.NewMapOf[string, func(bool, string)](), writeQueue: make(chan writeRequest), subscriptionChannelCloseQueue: make(chan *Subscription), - RequestHeader: make(http.Header, 1), + requestHeader: nil, } for _, opt := range opts { @@ -88,6 +88,7 @@ type RelayOption interface { var ( _ RelayOption = (WithNoticeHandler)(nil) _ RelayOption = (WithCustomHandler)(nil) + _ RelayOption = (WithRequestHeader)(nil) ) // WithNoticeHandler just takes notices and is expected to do something with them. @@ -106,6 +107,13 @@ func (ch WithCustomHandler) ApplyRelayOption(r *Relay) { r.customHandler = ch } +// WithRequestHeader sets the HTTP request header of the websocket preflight request. +type WithRequestHeader http.Header + +func (ch WithRequestHeader) ApplyRelayOption(r *Relay) { + r.requestHeader = http.Header(ch) +} + // String just returns the relay URL. func (r *Relay) String() string { return r.URL @@ -146,11 +154,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error defer cancel() } - if r.RequestHeader.Get("User-Agent") == "" { - r.RequestHeader.Set("User-Agent", "github.com/nbd-wtf/go-nostr") - } - - conn, err := NewConnection(ctx, r.URL, r.RequestHeader, tlsConfig) + conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig) if err != nil { return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err) } diff --git a/relay_js_test.go b/relay_js_test.go index f087917..c278a2e 100644 --- a/relay_js_test.go +++ b/relay_js_test.go @@ -12,40 +12,33 @@ import ( "github.com/stretchr/testify/require" ) -func TestConnectContext(t *testing.T) { +var testRelayURL = func() string { url := os.Getenv("TEST_RELAY_URL") - if url == "" { - t.Fatal("please set the environment: $TEST_RELAY_URL") + if url != "" { + return url } + return "wss://nos.lol" +}() +func TestConnectContext(t *testing.T) { // relay client ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - r, err := RelayConnect(ctx, url) + r, err := RelayConnect(ctx, testRelayURL) assert.NoError(t, err) defer r.Close() } func TestConnectContextCanceled(t *testing.T) { - url := os.Getenv("TEST_RELAY_URL") - if url == "" { - t.Fatal("please set the environment: $TEST_RELAY_URL") - } - // relay client ctx, cancel := context.WithCancel(context.Background()) cancel() // make ctx expired - _, err := RelayConnect(ctx, url) + _, err := RelayConnect(ctx, testRelayURL) assert.ErrorIs(t, err, context.Canceled) } func TestPublish(t *testing.T) { - url := os.Getenv("TEST_RELAY_URL") - if url == "" { - t.Fatal("please set the environment: $TEST_RELAY_URL") - } - // test note to be sent over websocket priv, pub := makeKeyPair(t) textNote := Event{ @@ -59,7 +52,7 @@ func TestPublish(t *testing.T) { assert.NoError(t, err) // connect a client and send the text note - rl := mustRelayConnect(t, url) + rl := mustRelayConnect(t, testRelayURL) err = rl.Publish(context.Background(), textNote) assert.NoError(t, err) } diff --git a/relay_test.go b/relay_test.go index 7dc7187..acf2f18 100644 --- a/relay_test.go +++ b/relay_test.go @@ -149,8 +149,8 @@ func TestConnectWithOrigin(t *testing.T) { defer ws.Close() // relay client - r := NewRelay(context.Background(), NormalizeURL(ws.URL)) - r.RequestHeader = http.Header{"origin": {"https://example.com"}} + r := NewRelay(context.Background(), NormalizeURL(ws.URL), + WithRequestHeader(http.Header{"origin": {"https://example.com"}})) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() err := r.Connect(ctx)