Skip to content

Commit

Permalink
optimize(head): packet encapsuling
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 12, 2024
1 parent 8215abb commit 6ede65b
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 26 deletions.
77 changes: 59 additions & 18 deletions gold/head/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,44 @@ import (
"github.com/sirupsen/logrus"
)

type PacketFlags uint16

func (pf PacketFlags) IsValid() bool {
return pf&0x8000 == 0
}

func (pf PacketFlags) DontFrag() bool {
return pf&0x4000 == 0x4000
}

func (pf PacketFlags) NoFrag() bool {
return pf == 0x4000
}

func (pf PacketFlags) IsSingle() bool {
return pf == 0
}

func (pf PacketFlags) ZeroOffset() bool {
return pf&0x1fff == 0
}

func (pf PacketFlags) Offset() uint16 {
return uint16(pf << 3)
}

// Flags extract flags from raw data
func Flags(data []byte) PacketFlags {
return PacketFlags(binary.LittleEndian.Uint16(data[10:12]))
}

// Packet 是发送和接收的最小单位
type Packet struct {
// TeaTypeDataSZ len(Data)
// idxdatsz len(Data)
// 高 5 位指定加密所用 key index
// 高 5-16 位是递增值, 用于 xchacha20 验证 additionalData
// 不得超过 65507-head 字节
TeaTypeDataSZ uint32
idxdatsz uint32
// Proto 详见 head
Proto uint8
// TTL is time to live
Expand All @@ -28,7 +59,7 @@ type Packet struct {
// DstPort 目的端口
DstPort uint16
// Flags 高3位为标志(xDM),低13位为分片偏移
Flags uint16
Flags PacketFlags
// Src 源 ip (ipv4)
Src net.IP
// Dst 目的 ip (ipv4)
Expand All @@ -37,8 +68,8 @@ type Packet struct {
// 生成时 Hash 全 0
// https://github.com/fumiama/blake2b-simd
Hash [32]byte
// CRC64 包头字段的 checksum 值,可以认为在一定时间内唯一
CRC64 uint64
// crc64 包头字段的 checksum 值,可以认为在一定时间内唯一
crc64 uint64
// Data 承载的数据
Data []byte
// 记录还有多少字节未到达
Expand All @@ -64,15 +95,16 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
err = errors.New("data len < 60")
return
}
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != binary.LittleEndian.Uint64(data[52:60]) {
p.crc64 = binary.LittleEndian.Uint64(data[52:60])
if crc64.Checksum(data[:52], crc64.MakeTable(crc64.ISO)) != p.crc64 {
err = errors.New("bad crc checksum")
return
}

sz := p.TeaTypeDataSZ & 0x0000ffff
sz := p.idxdatsz & 0x0000ffff
if sz == 0 && len(p.Data) == 0 {
p.TeaTypeDataSZ = binary.LittleEndian.Uint32(data[:4])
sz = p.TeaTypeDataSZ & 0x0000ffff
p.idxdatsz = binary.LittleEndian.Uint32(data[:4])
sz = p.idxdatsz & 0x0000ffff
if int(sz)+52 == len(data) {
p.Data = data[52:]
p.rembytes = 0
Expand All @@ -87,20 +119,19 @@ func (p *Packet) Unmarshal(data []byte) (complete bool, err error) {
p.DstPort = binary.LittleEndian.Uint16(data[8:10])
}

flags := binary.LittleEndian.Uint16(data[10:12])
flags := PacketFlags(binary.LittleEndian.Uint16(data[10:12]))

if flags&0x1fff == 0 {
if flags.ZeroOffset() {
p.Flags = flags
p.Src = make(net.IP, 4)
copy(p.Src, data[12:16])
p.Dst = make(net.IP, 4)
copy(p.Dst, data[16:20])
copy(p.Hash[:], data[20:52])
p.CRC64 = binary.LittleEndian.Uint64(data[52:60])
}

if p.rembytes > 0 {
p.rembytes -= copy(p.Data[flags<<3:], data[60:])
p.rembytes -= copy(p.Data[flags.Offset():], data[60:])
logrus.Debugln("[packet] copied frag", hex.EncodeToString(p.Hash[:]), "rembytes:", p.rembytes)
}

Expand All @@ -118,7 +149,7 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
}

if src != nil {
p.TeaTypeDataSZ = uint32(teatype)<<27 | (uint32(additional&0x07ff) << 16) | datasz&0xffff
p.idxdatsz = (uint32(teatype) << 27) | (uint32(additional&0x07ff) << 16) | datasz&0xffff
p.Src = src
offset &= 0x1fff
if dontfrag {
Expand All @@ -127,15 +158,15 @@ func (p *Packet) Marshal(src net.IP, teatype uint8, additional uint16, datasz ui
if hasmore {
offset |= 0x2000
}
p.Flags = offset
p.Flags = PacketFlags(offset)
}

return helper.OpenWriterF(func(w *helper.Writer) {
w.WriteUInt32(p.TeaTypeDataSZ)
w.WriteUInt32(p.idxdatsz)
w.WriteUInt16((uint16(p.TTL) << 8) | uint16(p.Proto))
w.WriteUInt16(p.SrcPort)
w.WriteUInt16(p.DstPort)
w.WriteUInt16(p.Flags)
w.WriteUInt16(uint16(p.Flags))
w.Write(p.Src.To4())
w.Write(p.Dst.To4())
w.Write(p.Hash[:])
Expand Down Expand Up @@ -171,7 +202,17 @@ func (p *Packet) IsVaildHash() bool {

// AdditionalData 获得 packet 的 additionalData
func (p *Packet) AdditionalData() uint16 {
return uint16((p.TeaTypeDataSZ >> 16) & 0x07ff)
return uint16((p.idxdatsz >> 16) & 0x07ff)
}

// CipherIndex packet 加密使用的密钥集目录
func (p *Packet) CipherIndex() uint8 {
return uint8(p.idxdatsz >> 27)
}

// Len is packet size
func (p *Packet) Len() int {
return int(p.idxdatsz & 0xffff)
}

// Put 将自己放回池中
Expand Down
68 changes: 68 additions & 0 deletions gold/head/packet_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package head

import (
crand "crypto/rand"
"math/rand"
"net"
"testing"
)

func TestMarshalUnmarshal(t *testing.T) {
data := make([]byte, 4096)
_, err := crand.Read(data)
if err != nil {
t.Fatal(err)
}
for i := 0; i < 0x7ff; i++ {
proto := uint8(rand.Intn(255))
teatype := uint8(rand.Intn(32))
srcPort := uint16(rand.Intn(65535))
dstPort := uint16(rand.Intn(65535))
src := make(net.IP, 4)
_, err = crand.Read(src)
if err != nil {
t.Fatal(err)
}
dst := make(net.IP, 4)
_, err = crand.Read(dst)
if err != nil {
t.Fatal(err)
}
p := NewPacket(proto, srcPort, dst, dstPort, data)
p.FillHash()
d, cl := p.Marshal(src, teatype, uint16(i), uint32(len(data)), 0, true, false)
p = SelectPacket()
ok, err := p.Unmarshal(d)
cl()
if !ok {
t.Fatal("index", i)
}
if err != nil {
t.Fatal(err)
}
if !p.IsVaildHash() {
t.Fatal("index", i)
}
if p.Proto != proto {
t.Fatal("index", i)
}
if p.CipherIndex() != teatype {
t.Fatal("index", i, "expect", teatype, "got", p.CipherIndex())
}
if p.SrcPort != srcPort {
t.Fatal("index", i)
}
if p.DstPort != dstPort {
t.Fatal("index", i)
}
if !p.Src.Equal(src) {
t.Fatal("index", i)
}
if !p.Dst.Equal(dst) {
t.Fatal("index", i)
}
if p.AdditionalData() != uint16(i) {
t.Fatal("index", i)
}
}
}
2 changes: 1 addition & 1 deletion gold/head/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func SelectPacket() *Packet {

// PutPacket 将 Packet 放回池中
func PutPacket(p *Packet) {
p.TeaTypeDataSZ = 0
p.idxdatsz = 0
p.Data = nil
packetPool.Put(p)
}
5 changes: 2 additions & 3 deletions gold/link/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ func (m *Me) listenudp() (conn *net.UDPConn, err error) {
func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, finish func()) {
defer finish()
defer logrus.Debugln("[listen] unlock index", index)
sz := packet.TeaTypeDataSZ & 0x0000ffff
r := int(sz) - len(packet.Data)
r := packet.Len() - len(packet.Data)
if r > 0 {
logrus.Warnln("[listen] @", index, "packet from endpoint", addr, "is smaller than it declared: drop it")
packet.Put()
Expand All @@ -112,7 +111,7 @@ func (m *Me) listenthread(packet *head.Packet, addr *net.UDPAddr, index int, fin
case p.IsToMe(packet.Dst):
addt := packet.AdditionalData()
var err error
packet.Data, err = p.Decode(uint8(packet.TeaTypeDataSZ>>27), addt, packet.Data)
packet.Data, err = p.Decode(packet.CipherIndex(), addt, packet.Data)
if err != nil {
logrus.Debugln("[listen] @", index, "drop invalid packet, addt:", addt, "err:", err)
packet.Put()
Expand Down
8 changes: 5 additions & 3 deletions gold/link/recv.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package link
import (
"encoding/binary"
"encoding/hex"
"strconv"
"unsafe"

"github.com/fumiama/WireGold/gold/head"
Expand All @@ -27,17 +28,18 @@ func (m *Me) wait(data []byte) *head.Packet {
logrus.Debugln("[recv] data bytes", hex.EncodeToString(data[:bound]), endl)
data = m.xordec(data)
logrus.Debugln("[recv] data xored", hex.EncodeToString(data[:bound]), endl)
flags := binary.LittleEndian.Uint16(data[10:12])
if flags&0x8000 != 0 { // not a valid packet
flags := head.Flags(data)
if !flags.IsValid() {
logrus.Debugln("[recv] drop invalid flags packet:", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
return nil
}
crc := binary.LittleEndian.Uint64(data[52:60])
if m.recved.Get(crc) { // 是重放攻击
logrus.Warnln("[recv] ignore duplicated crc packet", strconv.FormatUint(crc, 16))
return nil
}
logrus.Debugln("[recv]", len(data), "bytes data with flag", hex.EncodeToString(data[11:12]), hex.EncodeToString(data[10:11]))
if flags == 0 || flags == 0x4000 {
if flags.IsSingle() || flags.NoFrag() {
h := head.SelectPacket()
_, err := h.Unmarshal(data)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion gold/link/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (l *Link) WriteAndPut(p *head.Packet, istransfer bool) (n int, err error) {
if len(p.Data) <= delta {
return l.write(p, teatype, sndcnt, uint32(len(p.Data)), 0, istransfer, false)
}
if istransfer && p.Flags&0x4000 == 0x4000 && len(p.Data) > delta {
if istransfer && p.Flags.DontFrag() && len(p.Data) > delta {
return 0, errors.New("drop don't fragmnet big trans packet")
}
data := p.Data
Expand Down

0 comments on commit 6ede65b

Please sign in to comment.