diff --git a/adc/conn.go b/adc/conn.go index 7fa8662..a8fd102 100644 --- a/adc/conn.go +++ b/adc/conn.go @@ -1,10 +1,10 @@ package adc import ( - "bufio" "bytes" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -21,6 +21,8 @@ var ( Debug bool ) +const lineDelim = 0x0a + type Route interface { WriteMessage(msg Message) error Flush() error @@ -75,11 +77,16 @@ func NewConn(conn net.Conn) (*Conn, error) { c := &Conn{ conn: conn, } - c.write.w = bufio.NewWriter(conn) - c.read = lineproto.NewReader(conn, 0x0a) + c.write = lineproto.NewWriter(conn) + c.read = lineproto.NewReader(conn, lineDelim) if Debug { + c.write.OnLine = func(line []byte) (bool, error) { + line = bytes.TrimSuffix(line, []byte{lineDelim}) + log.Println("->", string(line)) + return true, nil + } c.read.OnLine = func(line []byte) (bool, error) { - line = bytes.TrimSuffix(line, []byte{0x0a}) + line = bytes.TrimSuffix(line, []byte{lineDelim}) log.Println("<-", string(line)) return true, nil } @@ -97,12 +104,8 @@ type Conn struct { conn net.Conn - write struct { - sync.Mutex - err error - w *bufio.Writer - } - read *lineproto.Reader + write *lineproto.Writer + read *lineproto.Reader } func (c *Conn) LocalAddr() net.Addr { @@ -142,7 +145,7 @@ func (c *Conn) KeepAlive(interval time.Duration) { case <-ticker.C: } // empty packet serves as keep-alive for ADC - err := c.writeRawPacket(nil) + err := c.writeRawPacket([]byte{lineDelim}) if err == nil { err = c.Flush() } @@ -177,10 +180,10 @@ func (c *Conn) readPacket(deadline time.Time) ([]byte, error) { s, err := c.read.ReadLine() if err != nil { return nil, err + } else if len(s) == 0 || s[len(s)-1] != lineDelim { + return nil, errors.New("invalid packet delimiter") } - // trim delimiter - s = s[:len(s)-1] - if len(s) != 0 { + if len(s) > 1 { return s, nil } // clients may send message containing only 0x0a byte @@ -323,48 +326,16 @@ func (c *Conn) writeRawPacket(s []byte) error { c.bin.RLock() defer c.bin.RUnlock() - c.write.Lock() - defer c.write.Unlock() - - if err := c.write.err; err != nil { - return err - } - if Debug { - log.Println("->", string(s)) - } - _, err := c.write.w.Write(s) - if err != nil { - c.write.err = err - } - err = c.write.w.WriteByte(0x0a) - if err != nil { - c.write.err = err - } - return err + return c.write.WriteLine(s) } // Flush the underlying buffer. Should be called after each WritePacket batch. func (c *Conn) Flush() error { - if Debug { - log.Println("-> [flushed]") - } - // make sure connection is not in binary mode c.bin.RLock() defer c.bin.RUnlock() - c.write.Lock() - defer c.write.Unlock() - - if err := c.write.err; err != nil { - return err - } - - err := c.write.w.Flush() - if err != nil { - c.write.err = err - } - return err + return c.write.Flush() } /* diff --git a/adc/packets.go b/adc/packets.go index 47e2ca3..b8be941 100644 --- a/adc/packets.go +++ b/adc/packets.go @@ -57,7 +57,7 @@ func (p BasePacket) Decode() (Message, error) { } func DecodePacket(p []byte) (Packet, error) { - if len(p) < 4 { + if len(p) < 5 { return nil, fmt.Errorf("too short for command: '%s'", string(p)) } if bytes.ContainsAny(p, "\x00") { @@ -90,10 +90,11 @@ func DecodePacket(p []byte) (Packet, error) { p = p[4:] var raw []byte if len(p) > 0 { - if p[0] != ' ' { + if p[0] == ' ' { + raw = p[1:] + } else if p[0] != lineDelim { return nil, fmt.Errorf("name separator expected") } - raw = p[1:] } if err := m.UnmarshalPacket(cname, raw); err != nil { return nil, err @@ -111,13 +112,19 @@ func (*InfoPacket) kind() byte { return kindInfo } func (p *InfoPacket) UnmarshalPacket(name MsgType, data []byte) error { + if len(data) != 0 && data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") + } p.Name = name + if len(data) != 0 { + data = data[:len(data)-1] + } p.Data = data return nil } func (p *InfoPacket) MarshalPacket() ([]byte, error) { - // IINF - n := 4 + // IINF 0x0a + n := 5 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -128,6 +135,7 @@ func (p *InfoPacket) MarshalPacket() ([]byte, error) { buf[4] = ' ' copy(buf[5:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -141,13 +149,19 @@ func (*HubPacket) kind() byte { return kindHub } func (p *HubPacket) UnmarshalPacket(name MsgType, data []byte) error { + if len(data) < 1 { + return errors.New("short hub command") + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") + } + data = data[:len(data)-1] p.Name = name p.Data = data return nil } func (p *HubPacket) MarshalPacket() ([]byte, error) { - // HINF - n := 4 + // HINF 0x0a + n := 5 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -158,6 +172,7 @@ func (p *HubPacket) MarshalPacket() ([]byte, error) { buf[4] = ' ' copy(buf[5:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -179,10 +194,13 @@ func (p *BroadcastPacket) Source() SID { } func (p *BroadcastPacket) UnmarshalPacket(name MsgType, data []byte) error { if len(data) < 4 { - return fmt.Errorf("short broadcast") - } else if len(data) > 4 && data[4] != ' ' { + return errors.New("short broadcast command") + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") + } else if len(data) > 4 && data[4] != ' ' && data[4] != lineDelim { return fmt.Errorf("separator expected: '%s'", string(data[:5])) } + data = data[:len(data)-1] p.Name = name if err := p.ID.UnmarshalAdc(data[0:4]); err != nil { return err @@ -193,8 +211,8 @@ func (p *BroadcastPacket) UnmarshalPacket(name MsgType, data []byte) error { return nil } func (p *BroadcastPacket) MarshalPacket() ([]byte, error) { - // BINF AAAA - n := 9 + // BINF AAAA 0x0a + n := 10 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -208,6 +226,7 @@ func (p *BroadcastPacket) MarshalPacket() ([]byte, error) { buf[9] = ' ' copy(buf[10:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -235,11 +254,14 @@ func (p *DirectPacket) Target() SID { func (p *DirectPacket) UnmarshalPacket(name MsgType, data []byte) error { if len(data) < 9 { return fmt.Errorf("short direct command") + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") } else if data[4] != ' ' { return fmt.Errorf("separator expected: '%s'", string(data[:9])) - } else if len(data) > 9 && data[9] != ' ' { + } else if len(data) > 9 && data[9] != ' ' && data[9] != lineDelim { return fmt.Errorf("separator expected: '%s'", string(data[:10])) } + data = data[:len(data)-1] p.Name = name if err := p.ID.UnmarshalAdc(data[0:4]); err != nil { return err @@ -253,8 +275,8 @@ func (p *DirectPacket) UnmarshalPacket(name MsgType, data []byte) error { return nil } func (p DirectPacket) MarshalPacket() ([]byte, error) { - // DCTM AAAA BBBB - n := 14 + // DCTM AAAA BBBB 0x0a + n := 15 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -271,6 +293,7 @@ func (p DirectPacket) MarshalPacket() ([]byte, error) { buf[14] = ' ' copy(buf[15:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -294,11 +317,14 @@ func (p *EchoPacket) Target() SID { func (p *EchoPacket) UnmarshalPacket(name MsgType, data []byte) error { if len(data) < 9 { return fmt.Errorf("short echo command") + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") } else if data[4] != ' ' { return fmt.Errorf("separator expected: '%s'", string(data[:9])) - } else if len(data) > 9 && data[9] != ' ' { + } else if len(data) > 9 && data[9] != ' ' && data[9] != lineDelim { return fmt.Errorf("separator expected: '%s'", string(data[:10])) } + data = data[:len(data)-1] p.Name = name if err := p.ID.UnmarshalAdc(data[0:4]); err != nil { return err @@ -312,8 +338,8 @@ func (p *EchoPacket) UnmarshalPacket(name MsgType, data []byte) error { return nil } func (p *EchoPacket) MarshalPacket() ([]byte, error) { - // EMSG AAAA BBBB - n := 14 + // EMSG AAAA BBBB 0x0a + n := 15 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -330,6 +356,7 @@ func (p *EchoPacket) MarshalPacket() ([]byte, error) { buf[14] = ' ' copy(buf[15:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -343,13 +370,19 @@ func (*ClientPacket) kind() byte { return kindClient } func (p *ClientPacket) UnmarshalPacket(name MsgType, data []byte) error { + if len(data) < 1 { + return errors.New("short client command") + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") + } + data = data[:len(data)-1] p.Name = name p.Data = data return nil } func (p *ClientPacket) MarshalPacket() ([]byte, error) { - // CINF - n := 4 + // CINF 0x0a + n := 5 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -360,6 +393,7 @@ func (p *ClientPacket) MarshalPacket() ([]byte, error) { buf[4] = ' ' copy(buf[5:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -383,9 +417,12 @@ func (p *FeaturePacket) Source() SID { func (p *FeaturePacket) UnmarshalPacket(name MsgType, data []byte) error { if len(data) < 4 { return fmt.Errorf("short feature command") - } else if len(data) > 4 && data[4] != ' ' { + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") + } else if len(data) > 4 && data[4] != ' ' && data[4] != lineDelim { return fmt.Errorf("separator expected: '%s'", string(data[:5])) } + data = data[:len(data)-1] p.Name = name p.Features = make(map[Feature]bool) if err := p.ID.UnmarshalAdc(data[0:4]); err != nil { @@ -433,8 +470,8 @@ func (p *FeaturePacket) UnmarshalPacket(name MsgType, data []byte) error { return nil } func (p *FeaturePacket) MarshalPacket() ([]byte, error) { - // FSCH AAAA +SEGA -NAT0 - n := 9 + // FSCH AAAA +SEGA -NAT0 0x0a + n := 10 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -462,6 +499,7 @@ func (p *FeaturePacket) MarshalPacket() ([]byte, error) { buf[off] = ' ' copy(buf[off+1:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } @@ -478,10 +516,13 @@ func (*UDPPacket) kind() byte { func (p *UDPPacket) UnmarshalPacket(name MsgType, data []byte) error { const l = 39 // len of CID in base32 if len(data) < l { - return fmt.Errorf("short upd command") - } else if len(data) > l && data[l] != ' ' { + return errors.New("short upd command") + } else if data[len(data)-1] != lineDelim { + return errors.New("invalid packet delimiter") + } else if len(data) > l && data[l] != ' ' && data[l] != lineDelim { return fmt.Errorf("separator expected: '%s'", string(data[:l+1])) } + data = data[:len(data)-1] p.Name = name if err := p.ID.FromBase32(string(data[0:l])); err != nil { return fmt.Errorf("wrong CID in upd command: %v", err) @@ -492,8 +533,8 @@ func (p *UDPPacket) UnmarshalPacket(name MsgType, data []byte) error { return nil } func (p *UDPPacket) MarshalPacket() ([]byte, error) { - // UINF - n := 39 + 5 + // UINF 0x0a + n := 39 + 5 + 1 if len(p.Data) > 0 { n += 1 + len(p.Data) } @@ -506,5 +547,6 @@ func (p *UDPPacket) MarshalPacket() ([]byte, error) { buf[5+39] = ' ' copy(buf[5+39+1:], p.Data) } + buf[len(buf)-1] = lineDelim return buf, nil } diff --git a/adc/packets_test.go b/adc/packets_test.go index e80c5f6..413ef5f 100644 --- a/adc/packets_test.go +++ b/adc/packets_test.go @@ -9,6 +9,8 @@ import ( "github.com/direct-connect/go-dcpp/adc/types" ) +const delim = "\x0a" + var casesPackets = []struct { data string packet Packet @@ -201,7 +203,7 @@ var casesPackets = []struct { func TestDecodePacket(t *testing.T) { for _, c := range casesPackets { t.Run("", func(t *testing.T) { - cmd, err := DecodePacket([]byte(c.data)) + cmd, err := DecodePacket([]byte(c.data + delim)) if err != nil { t.Fatal(err) } else if !reflect.DeepEqual(cmd, c.packet) { @@ -217,8 +219,8 @@ func TestEncodePacket(t *testing.T) { data, err := c.packet.MarshalPacket() if err != nil { t.Fatal(err) - } else if !bytes.Equal(data, []byte(c.data)) { - t.Fatalf("\n%#v\nvs\n%#v", string(data), string(c.data)) + } else if !bytes.Equal(data, []byte(c.data+delim)) { + t.Fatalf("\n%#v\nvs\n%#v", string(data), string(c.data+delim)) } }) }