-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use coder/websocket for everything, get rid of gobwas.
supposedly it is faster, and anyway it's better to use it since we're already using it for wasm/js. (previously named nhooyr/websocket).
- Loading branch information
Showing
9 changed files
with
101 additions
and
243 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,187 +1,53 @@ | ||
//go:build !js | ||
|
||
package nostr | ||
|
||
import ( | ||
"bytes" | ||
"compress/flate" | ||
"context" | ||
"crypto/tls" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net" | ||
"net/http" | ||
|
||
"github.com/gobwas/httphead" | ||
"github.com/gobwas/ws" | ||
"github.com/gobwas/ws/wsflate" | ||
"github.com/gobwas/ws/wsutil" | ||
ws "github.com/coder/websocket" | ||
) | ||
|
||
type Connection struct { | ||
conn net.Conn | ||
enableCompression bool | ||
controlHandler wsutil.FrameHandlerFunc | ||
flateReader *wsflate.Reader | ||
reader *wsutil.Reader | ||
flateWriter *wsflate.Writer | ||
writer *wsutil.Writer | ||
msgStateR *wsflate.MessageState | ||
msgStateW *wsflate.MessageState | ||
conn *ws.Conn | ||
} | ||
|
||
func NewConnection(ctx context.Context, url string, requestHeader http.Header, tlsConfig *tls.Config) (*Connection, error) { | ||
dialer := ws.Dialer{ | ||
Header: ws.HandshakeHeaderHTTP(requestHeader), | ||
Extensions: []httphead.Option{ | ||
wsflate.DefaultParameters.Option(), | ||
}, | ||
TLSConfig: tlsConfig, | ||
} | ||
conn, _, hs, err := dialer.Dial(ctx, url) | ||
c, _, err := ws.Dial(ctx, url, getConnectionOptions(requestHeader, tlsConfig)) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to dial: %w", err) | ||
} | ||
|
||
enableCompression := false | ||
state := ws.StateClientSide | ||
for _, extension := range hs.Extensions { | ||
if string(extension.Name) == wsflate.ExtensionName { | ||
enableCompression = true | ||
state |= ws.StateExtended | ||
break | ||
} | ||
} | ||
|
||
// reader | ||
var flateReader *wsflate.Reader | ||
var msgStateR wsflate.MessageState | ||
if enableCompression { | ||
msgStateR.SetCompressed(true) | ||
|
||
flateReader = wsflate.NewReader(nil, func(r io.Reader) wsflate.Decompressor { | ||
return flate.NewReader(r) | ||
}) | ||
} | ||
|
||
controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide) | ||
reader := &wsutil.Reader{ | ||
Source: conn, | ||
State: state, | ||
OnIntermediate: controlHandler, | ||
CheckUTF8: false, | ||
Extensions: []wsutil.RecvExtension{ | ||
&msgStateR, | ||
}, | ||
} | ||
|
||
// writer | ||
var flateWriter *wsflate.Writer | ||
var msgStateW wsflate.MessageState | ||
if enableCompression { | ||
msgStateW.SetCompressed(true) | ||
|
||
flateWriter = wsflate.NewWriter(nil, func(w io.Writer) wsflate.Compressor { | ||
fw, err := flate.NewWriter(w, 4) | ||
if err != nil { | ||
InfoLogger.Printf("Failed to create flate writer: %v", err) | ||
} | ||
return fw | ||
}) | ||
return nil, err | ||
} | ||
|
||
writer := wsutil.NewWriter(conn, state, ws.OpText) | ||
writer.SetExtensions(&msgStateW) | ||
|
||
return &Connection{ | ||
conn: conn, | ||
enableCompression: enableCompression, | ||
controlHandler: controlHandler, | ||
flateReader: flateReader, | ||
reader: reader, | ||
msgStateR: &msgStateR, | ||
flateWriter: flateWriter, | ||
writer: writer, | ||
msgStateW: &msgStateW, | ||
conn: c, | ||
}, nil | ||
} | ||
|
||
func (c *Connection) WriteMessage(ctx context.Context, data []byte) error { | ||
select { | ||
case <-ctx.Done(): | ||
return errors.New("context canceled") | ||
default: | ||
} | ||
|
||
if c.msgStateW.IsCompressed() && c.enableCompression { | ||
c.flateWriter.Reset(c.writer) | ||
if _, err := io.Copy(c.flateWriter, bytes.NewReader(data)); err != nil { | ||
return fmt.Errorf("failed to write message: %w", err) | ||
} | ||
|
||
if err := c.flateWriter.Close(); err != nil { | ||
return fmt.Errorf("failed to close flate writer: %w", err) | ||
} | ||
} else { | ||
if _, err := io.Copy(c.writer, bytes.NewReader(data)); err != nil { | ||
return fmt.Errorf("failed to write message: %w", err) | ||
} | ||
} | ||
|
||
if err := c.writer.Flush(); err != nil { | ||
return fmt.Errorf("failed to flush writer: %w", err) | ||
if err := c.conn.Write(ctx, ws.MessageText, data); err != nil { | ||
return fmt.Errorf("failed to write message: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (c *Connection) ReadMessage(ctx context.Context, buf io.Writer) error { | ||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return errors.New("context canceled") | ||
default: | ||
} | ||
|
||
h, err := c.reader.NextFrame() | ||
if err != nil { | ||
c.conn.Close() | ||
return fmt.Errorf("failed to advance frame: %w", err) | ||
} | ||
|
||
if h.OpCode.IsControl() { | ||
if err := c.controlHandler(h, c.reader); err != nil { | ||
return fmt.Errorf("failed to handle control frame: %w", err) | ||
} | ||
} else if h.OpCode == ws.OpBinary || | ||
h.OpCode == ws.OpText { | ||
break | ||
} | ||
|
||
if err := c.reader.Discard(); err != nil { | ||
return fmt.Errorf("failed to discard: %w", err) | ||
} | ||
_, reader, err := c.conn.Reader(ctx) | ||
if err != nil { | ||
return fmt.Errorf("failed to get reader: %w", err) | ||
} | ||
|
||
if c.msgStateR.IsCompressed() && c.enableCompression { | ||
c.flateReader.Reset(c.reader) | ||
if _, err := io.Copy(buf, c.flateReader); err != nil { | ||
return fmt.Errorf("failed to read message: %w", err) | ||
} | ||
} else { | ||
if _, err := io.Copy(buf, c.reader); err != nil { | ||
return fmt.Errorf("failed to read message: %w", err) | ||
} | ||
if _, err := io.Copy(buf, reader); err != nil { | ||
return fmt.Errorf("failed to read message: %w", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (c *Connection) Close() error { | ||
return c.conn.Close() | ||
return c.conn.Close(ws.StatusNormalClosure, "") | ||
} | ||
|
||
func (c *Connection) Ping(ctx context.Context) error { | ||
return wsutil.WriteClientMessage(c.conn, ws.OpPing, nil) | ||
return c.conn.Ping(ctx) | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
//go:build !js | ||
|
||
package nostr | ||
|
||
import ( | ||
"crypto/tls" | ||
"net/http" | ||
"net/textproto" | ||
|
||
ws "github.com/coder/websocket" | ||
) | ||
|
||
var defaultConnectionOptions = &ws.DialOptions{ | ||
CompressionMode: ws.CompressionContextTakeover, | ||
HTTPHeader: http.Header{ | ||
textproto.CanonicalMIMEHeaderKey("User-Agent"): {"github.com/nbd-wtf/go-nostr"}, | ||
}, | ||
} | ||
|
||
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions { | ||
if requestHeader == nil && tlsConfig == nil { | ||
return defaultConnectionOptions | ||
} | ||
|
||
return &ws.DialOptions{ | ||
HTTPHeader: requestHeader, | ||
CompressionMode: ws.CompressionContextTakeover, | ||
HTTPClient: &http.Client{ | ||
Transport: &http.Transport{ | ||
TLSClientConfig: tlsConfig, | ||
}, | ||
}, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package nostr | ||
|
||
import ( | ||
"crypto/tls" | ||
"net/http" | ||
|
||
ws "github.com/coder/websocket" | ||
) | ||
|
||
var emptyOptions = ws.DialOptions{} | ||
|
||
func getConnectionOptions(requestHeader http.Header, tlsConfig *tls.Config) *ws.DialOptions { | ||
// on javascript we ignore everything because there is nothing else we can do | ||
return &emptyOptions | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.