Skip to content

Commit

Permalink
Allow Client Hello with other data attached & Simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
RPRX authored Mar 31, 2023
1 parent 442d33e commit d3d3761
Showing 1 changed file with 22 additions and 62 deletions.
84 changes: 22 additions & 62 deletions tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,54 +40,31 @@ import (
"golang.org/x/crypto/hkdf"
)

type ReaderConn struct {
Conn net.Conn
Reader *bytes.Reader
Written int
Closed bool
type WeakConn struct {
net.Conn
}

func (c *ReaderConn) Read(b []byte) (int, error) {
if c.Closed {
return 0, errors.New("Closed")
}
n, err := c.Reader.Read(b)
if err == io.EOF {
return n, errors.New("io.EOF") // prevent looping
}
return n, err
}

func (c *ReaderConn) Write(b []byte) (int, error) {
if c.Closed {
return 0, errors.New("Closed")
}
c.Written += len(b)
return len(b), nil
func (c *WeakConn) Read(b []byte) (int, error) {
return 0, fmt.Errorf("Read(%v)", len(b))
}

func (c *ReaderConn) Close() error {
c.Closed = true
return nil
}

func (c *ReaderConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
func (c *WeakConn) Write(b []byte) (int, error) {
return 0, fmt.Errorf("Write(%v)", len(b))
}

func (c *ReaderConn) RemoteAddr() net.Addr {
return c.Conn.RemoteAddr()
func (c *WeakConn) Close() error {
return fmt.Errorf("Close()")
}

func (c *ReaderConn) SetDeadline(t time.Time) error {
func (c *WeakConn) SetDeadline(t time.Time) error {
return nil
}

func (c *ReaderConn) SetReadDeadline(t time.Time) error {
func (c *WeakConn) SetReadDeadline(t time.Time) error {
return nil
}

func (c *ReaderConn) SetWriteDeadline(t time.Time) error {
func (c *WeakConn) SetWriteDeadline(t time.Time) error {
return nil
}

Expand Down Expand Up @@ -175,7 +152,7 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
done = true
break
}
if copying || len(c2sSaved) > size || len(s2cSaved) > 0 { // follow; too long; unexpected
if len(c2sSaved) > size || copying { // too long; follow
break
}
if clientHelloLen == 0 && len(c2sSaved) > recordHeaderLen {
Expand All @@ -191,19 +168,12 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
mutex.Unlock()
continue
}
if len(c2sSaved) > clientHelloLen { // unexpected
break
}
readerConn := &ReaderConn{
Conn: conn,
Reader: bytes.NewReader(c2sSaved),
}
hs.c = &Conn{
conn: readerConn,
config: config,
conn: &WeakConn{conn},
config: config,
rawInput: *bytes.NewBuffer(c2sSaved),
}
hs.clientHello, err = hs.c.readClientHello(context.Background())
if err != nil || readerConn.Reader.Len() > 0 || readerConn.Written > 0 || readerConn.Closed {
if hs.clientHello, err = hs.c.readClientHello(context.Background()); err != nil {
break
}
if hs.c.vers != VersionTLS13 || !config.ServerNames[hs.clientHello.serverName] {
Expand Down Expand Up @@ -260,11 +230,8 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
}
break
}
if done {
mutex.Unlock()
} else {
copying = true
mutex.Unlock()
mutex.Unlock()
if !done {
io.CopyBuffer(target, underlying, buf)
}
waitGroup.Done()
Expand All @@ -289,15 +256,11 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
mutex.Lock()
s2cSaved = append(s2cSaved, buf[:n]...)
if hs.c == nil || hs.c.conn != conn {
copying = true
if _, err = conn.Write(buf[:n]); err != nil {
done = true
break
}
if copying || len(s2cSaved) > size { // follow; too long
break
}
mutex.Unlock()
continue
break
}
done = true // special
if len(s2cSaved) > size {
Expand Down Expand Up @@ -386,11 +349,8 @@ func Server(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
handled = true
break
}
if done {
mutex.Unlock()
} else {
copying = true
mutex.Unlock()
mutex.Unlock()
if !done {
io.CopyBuffer(underlying, target, buf)
}
waitGroup.Done()
Expand Down

0 comments on commit d3d3761

Please sign in to comment.