Skip to content

Commit

Permalink
nmdc: fix races when closing connections; fixes #114
Browse files Browse the repository at this point in the history
  • Loading branch information
dennwc committed Feb 23, 2020
1 parent e4f08e3 commit 175c6f9
Showing 1 changed file with 40 additions and 10 deletions.
50 changes: 40 additions & 10 deletions nmdc/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,16 @@ func NewConn(conn net.Conn) (*Conn, error) {
type Conn struct {
kps []string // keyprints, set by TLS

cmu sync.Mutex
closed bool

fallback encoding.Encoding

conn net.Conn

w *nmdc.Writer
r *nmdc.Reader
wmu sync.Mutex
w *nmdc.Writer
closed bool

rmu sync.Mutex
r *nmdc.Reader
}

// GetKeyPrints returns keyprints set by TLS, if any.
Expand Down Expand Up @@ -182,6 +183,8 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
}

func (c *Conn) SetWriteTimeout(dt time.Duration) {
c.wmu.Lock()
defer c.wmu.Unlock()
if dt <= 0 {
c.w.Timeout = nil
return
Expand Down Expand Up @@ -223,57 +226,78 @@ func (c *Conn) setEncoding(enc encoding.Encoding, event bool) {
}

func (c *Conn) SetEncoding(enc encoding.Encoding) {
c.wmu.Lock()
defer c.wmu.Unlock()
c.setEncoding(enc, false)
}

func (c *Conn) SetFallbackEncoding(enc encoding.Encoding) {
c.wmu.Lock()
defer c.wmu.Unlock()
c.fallback = enc
}

func (c *Conn) ZOn(lvl int) error {
c.wmu.Lock()
defer c.wmu.Unlock()
return c.w.ZOnLevel(lvl)
}

// Close closes the connection.
func (c *Conn) Close() error {
c.cmu.Lock()
defer c.cmu.Unlock()
c.wmu.Lock()
defer c.wmu.Unlock()
if c.closed {
return nil
}
c.closed = true
// should not hold any other mutex
var last error
if err := c.r.Close(); err != nil {
last = err
}
// first close the writer so it flushes all buffers
if err := c.w.Close(); err != nil {
last = err
}
c.rmu.Lock()
defer c.rmu.Unlock()
// then close the connection so it unblocks the reader
_ = c.conn.Close()
// finally close the reader
if err := c.r.Close(); err != nil {
last = err
}
return last
}

func (c *Conn) WriteMsg(m ...nmdc.Message) error {
c.wmu.Lock()
defer c.wmu.Unlock()
return c.w.WriteMsg(m...)
}

func (c *Conn) WriteLine(data []byte) error {
c.wmu.Lock()
defer c.wmu.Unlock()
return c.w.WriteLine(data)
}

func (c *Conn) Flush() error {
c.wmu.Lock()
defer c.wmu.Unlock()
return c.w.Flush()
}

func (c *Conn) WriteOneMsg(m nmdc.Message) error {
c.wmu.Lock()
defer c.wmu.Unlock()
if err := c.w.WriteMsg(m); err != nil {
return err
}
return c.w.Flush()
}

func (c *Conn) WriteOneLine(data []byte) error {
c.wmu.Lock()
defer c.wmu.Unlock()
if err := c.w.WriteLine(data); err != nil {
return err
}
Expand Down Expand Up @@ -303,6 +327,8 @@ func (c *Conn) ReadMsgTo(deadline time.Time, m nmdc.Message) error {
if m == nil {
panic("nil message to decode")
}
c.rmu.Lock()
defer c.rmu.Unlock()
if !deadline.IsZero() {
c.conn.SetReadDeadline(deadline)
defer c.conn.SetReadDeadline(time.Time{})
Expand All @@ -314,6 +340,8 @@ func (c *Conn) ReadMsgToAny(deadline time.Time, m ...nmdc.Message) (nmdc.Message
if len(m) == 0 {
panic("no messages to decode")
}
c.rmu.Lock()
defer c.rmu.Unlock()
if !deadline.IsZero() {
c.conn.SetReadDeadline(deadline)
defer c.conn.SetReadDeadline(time.Time{})
Expand All @@ -322,6 +350,8 @@ func (c *Conn) ReadMsgToAny(deadline time.Time, m ...nmdc.Message) (nmdc.Message
}

func (c *Conn) ReadMsg(deadline time.Time) (nmdc.Message, error) {
c.rmu.Lock()
defer c.rmu.Unlock()
if !deadline.IsZero() {
c.conn.SetReadDeadline(deadline)
defer c.conn.SetReadDeadline(time.Time{})
Expand Down

0 comments on commit 175c6f9

Please sign in to comment.