Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: provide a thread-safe packager #2

Merged
merged 1 commit into from
Oct 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions example/concurrencyPack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package main

import (
"fmt"
"github.com/lim-yoona/tcpack"
safetcpack "github.com/lim-yoona/tcpack/safe"
"net"
"sync"
"time"
)

func main() {
// Mainly testing whether messages can be
// safely sent and received in concurrent scenarios.
address, _ := net.ResolveTCPAddr("tcp4", "127.0.0.1:9001")
address2, _ := net.ResolveTCPAddr("tcp4", ":9001")

//server
listener, _ := net.ListenTCP("tcp4", address2)

connChan := make(chan net.Conn, 5)

revNum1 := 0
revNum2 := 0
revNum3 := 0
//var mu sync.Mutex
go func() {
go func() {
tcpConnServer2 := <-connChan
mpServer2 := safetcpack.NewSafeMsgPack(8, tcpConnServer2)
for {
msgRev2, _ := mpServer2.Unpack()
fmt.Println("tcpConnServer2 receive:", msgRev2)
fmt.Println("tcpConnServer2 receive:", string(msgRev2.GetMsgData()))
revNum1++
}
}()
go func() {
tcpConnServer2 := <-connChan
mpServer2 := safetcpack.NewSafeMsgPack(8, tcpConnServer2)
for {
msgRev2, _ := mpServer2.Unpack()
//msgChan <- &msgRev2
fmt.Println("tcpConnServer2 receive:", msgRev2)
fmt.Println("tcpConnServer2 receive:", string(msgRev2.GetMsgData()))
revNum2++
}
}()
for {
tcpConnServer, _ := listener.AcceptTCP()
connChan <- tcpConnServer
connChan <- tcpConnServer
mpServer := safetcpack.NewSafeMsgPack(8, tcpConnServer)
for {
msgRev, _ := mpServer.Unpack()
//msgChan <- &msgRev
fmt.Println("tcpConnServer1 receive:", msgRev)
fmt.Println("tcpConnServer1 receive:", string(msgRev.GetMsgData()))
revNum3++
}
}
}()

// client
tcpConnClient, err := net.DialTCP("tcp4", nil, address)
if err != nil {
fmt.Println("create tcpConn failed")
}

sendNum1 := 0
sendNum2 := 0
sendNum3 := 0
stopChan := make(chan int, 3)
sChan := make(chan int)
go func() {
for {
data := []byte("helloworld!")
mpClient := safetcpack.NewSafeMsgPack(8, tcpConnClient)
msgSend := tcpack.NewMessage(0, uint32(len(data)), data)
_, _ = mpClient.Pack(msgSend)
sendNum1++
select {
case <-stopChan:
<-sChan
default:

}
}
}()
go func() {
for {
data := []byte("1")
mpClient := safetcpack.NewSafeMsgPack(8, tcpConnClient)
msgSend := tcpack.NewMessage(0, uint32(len(data)), data)
_, _ = mpClient.Pack(msgSend)
sendNum2++
select {
case <-stopChan:
<-sChan
default:

}
}
}()
go func() {
for {
data := []byte("qwertyuiopasdfghjklzxcvbnm1234567890")
mpClient := safetcpack.NewSafeMsgPack(8, tcpConnClient)
msgSend := tcpack.NewMessage(0, uint32(len(data)), data)
_, _ = mpClient.Pack(msgSend)
sendNum3++
select {
case <-stopChan:
<-sChan
default:

}
}
}()

var mu sync.Mutex
for true {
mu.Lock()
if sendNum1+sendNum2+sendNum3 == 50000 {
stopChan <- 1
stopChan <- 1
stopChan <- 1
break
}
mu.Unlock()
}

time.Sleep(time.Second * 100)
fmt.Println("sendNum:", sendNum1+sendNum2+sendNum3)
fmt.Println("revNum:", revNum1+revNum2+revNum3)
fmt.Println("succeed!")
}
102 changes: 102 additions & 0 deletions safe/safetcpack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package safetcpack

import (
"bytes"
"encoding/binary"
"github.com/lim-yoona/tcpack"
"io"
"log"
"net"
"sync"
)

// connMap is a map ensures that only one SafeMsgPack instance for one conn.
var connMap map[net.Conn]*SafeMsgPack

// mu is a mutex lock for Concurrently reading and writing connMap.
var mu sync.Mutex

// SafeMsgPack implements the interface IMsgPack,
// carrying HeadLen and conn for Pack() and Unpack(),
// and mutex for concurrent Pack() and Unpack().
type SafeMsgPack struct {
headLen uint32
conn net.Conn
mutex sync.Mutex
}

// NewSafeMsgPack returns a thread-safe packager *SafeMsgPack.
// NewSafeMsgPack returns the same packager for the same TCP connection,
// so the value of the headLen is consistent with the first time you new a SafeMsgPack.
func NewSafeMsgPack(headleng uint32, conn net.Conn) *SafeMsgPack {
defer mu.Unlock()
mu.Lock()
smp, ok := connMap[conn]
if !ok {
smp = &SafeMsgPack{
headLen: headleng,
conn: conn,
}
connMap[conn] = smp
}
return smp
}

// GetHeadLen return headLen of the message.
func (smp *SafeMsgPack) GetHeadLen() uint32 {
return smp.headLen
}

// Pack packs a message to bytes stream and sends it.
func (smp *SafeMsgPack) Pack(msg tcpack.Imessage) (uint32, error) {
buffer := bytes.NewBuffer([]byte{})
if err := binary.Write(buffer, binary.LittleEndian, msg.GetDataLen()); err != nil {
return 0, err
}
if err := binary.Write(buffer, binary.LittleEndian, msg.GetMsgId()); err != nil {
return 0, err
}
if err := binary.Write(buffer, binary.LittleEndian, msg.GetMsgData()); err != nil {
return 0, err
}
smp.mutex.Lock()
num, err := smp.conn.Write(buffer.Bytes())
smp.mutex.Unlock()
if err != nil {
return 0, err
}
return uint32(num), nil
}

// Unpack unpacks a certain length bytes stream to a message.
func (smp *SafeMsgPack) Unpack() (tcpack.Imessage, error) {
defer smp.mutex.Unlock()
smp.mutex.Lock()
headDate := make([]byte, smp.GetHeadLen())
_, err := io.ReadFull(smp.conn, headDate)
if err != nil {
log.Println("read headData failed:", err)
return nil, err
}
buffer := bytes.NewReader(headDate)
msg := tcpack.NewMessage(0, 0, nil)
if err := binary.Read(buffer, binary.LittleEndian, &msg.DataLen); err != nil {
return nil, err
}
if err := binary.Read(buffer, binary.LittleEndian, &msg.Id); err != nil {
return nil, err
}
if msg.GetDataLen() > 0 {
msg.Data = make([]byte, msg.GetDataLen())
_, err := io.ReadFull(smp.conn, msg.Data)
if err != nil {
log.Println("read msgData failed:", err)
return nil, err
}
}
return msg, nil
}

func init() {
connMap = make(map[net.Conn]*SafeMsgPack, 5)
}
154 changes: 154 additions & 0 deletions safe/safetcpack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package safetcpack

import (
"fmt"
"github.com/lim-yoona/tcpack"
. "github.com/smartystreets/goconvey/convey"
"net"
"sync"
"testing"
"time"
)

func TestNewSafeMsgPack(t *testing.T) {

address2, _ := net.ResolveTCPAddr("tcp4", ":8099")
listener, _ := net.ListenTCP("tcp4", address2)
go func() {
for {
_, _ = listener.AcceptTCP()
}
}()

address, _ := net.ResolveTCPAddr("tcp4", "127.0.0.1:8099")
tcpConn, _ := net.DialTCP("tcp4", nil, address)
smp := NewSafeMsgPack(8, tcpConn)
smp2 := NewSafeMsgPack(8, tcpConn)
smp3 := NewSafeMsgPack(6, tcpConn)
Convey("whether only onne packager for the same connection", t, func() {
So(smp, ShouldEqual, smp2)
So(8, ShouldEqual, smp3.GetHeadLen())
So(smp, ShouldEqual, smp3)
So(smp2, ShouldEqual, smp3)
})
}

func TestSafeMsgPack_Pack(t *testing.T) {
// Mainly testing whether messages can be
// safely sent and received in concurrent scenarios.
address, _ := net.ResolveTCPAddr("tcp4", "127.0.0.1:9001")
address2, _ := net.ResolveTCPAddr("tcp4", ":9001")

//server
listener, _ := net.ListenTCP("tcp4", address2)

connChan := make(chan net.Conn, 5)

revNum1 := 0
revNum2 := 0
revNum3 := 0
go func() {
go func() {
tcpConnServer2 := <-connChan
mpServer2 := NewSafeMsgPack(8, tcpConnServer2)
for {
_, _ = mpServer2.Unpack()
revNum1++
}
}()
go func() {
tcpConnServer2 := <-connChan
mpServer2 := NewSafeMsgPack(8, tcpConnServer2)
for {
_, _ = mpServer2.Unpack()
revNum2++
}
}()
for {
tcpConnServer, _ := listener.AcceptTCP()
connChan <- tcpConnServer
connChan <- tcpConnServer
mpServer := NewSafeMsgPack(8, tcpConnServer)
for {
_, _ = mpServer.Unpack()
revNum3++
}
}
}()

// client
tcpConnClient, err := net.DialTCP("tcp4", nil, address)
if err != nil {
fmt.Println("create tcpConn failed")
}

sendNum1 := 0
sendNum2 := 0
sendNum3 := 0
stopChan := make(chan int, 3)
sChan := make(chan int)
go func() {
for {
data := []byte("helloworld")
mpClient := NewSafeMsgPack(8, tcpConnClient)
msgSend := tcpack.NewMessage(0, uint32(len(data)), data)
_, _ = mpClient.Pack(msgSend)
sendNum1++
select {
case <-stopChan:
<-sChan
default:

}
}
}()
go func() {
for {
data := []byte("1")
mpClient := NewSafeMsgPack(8, tcpConnClient)
msgSend := tcpack.NewMessage(0, uint32(len(data)), data)
_, _ = mpClient.Pack(msgSend)
sendNum2++
select {
case <-stopChan:
<-sChan
default:

}
}
}()
go func() {
for {
data := make([]byte, 65600)
mpClient := NewSafeMsgPack(8, tcpConnClient)
msgSend := tcpack.NewMessage(0, uint32(len(data)), data)
_, _ = mpClient.Pack(msgSend)
sendNum3++
select {
case <-stopChan:
<-sChan
default:

}
}
}()

var mu sync.Mutex
for true {
mu.Lock()
if sendNum1+sendNum2+sendNum3 == 50000 {
stopChan <- 1
stopChan <- 1
stopChan <- 1
break
}
mu.Unlock()
}

time.Sleep(time.Second * 100)
fmt.Println("send message number:", sendNum1+sendNum2+sendNum3)
fmt.Println("receive message number:", revNum1+revNum2+revNum3)
Convey("correct send and receive message number whether equal", t, func() {
So(sendNum1+sendNum2+sendNum3, ShouldEqual, revNum1+revNum2+revNum3)
})
}
Loading