From 71d8f7a842e34c289b043c62db776d2d8b036e02 Mon Sep 17 00:00:00 2001 From: Chen Kai <281165273grape@gmail.com> Date: Wed, 20 Nov 2024 22:19:37 +0800 Subject: [PATCH 01/13] feat: move portal wire out of p2p package Signed-off-by: Chen Kai <281165273grape@gmail.com> --- cmd/devp2p/discv4cmd.go | 2 +- cmd/devp2p/discv5cmd.go | 2 +- p2p/discover/api.go | 12 +- p2p/discover/lookup.go | 48 +- p2p/discover/node.go | 26 +- p2p/discover/portal_protocol.go | 78 +- p2p/discover/portal_utp.go | 2 +- p2p/discover/table.go | 87 +- p2p/discover/table_reval.go | 2 +- p2p/discover/table_reval_test.go | 2 +- p2p/discover/table_test.go | 70 +- p2p/discover/table_util_test.go | 18 +- p2p/discover/v4_lookup_test.go | 4 +- p2p/discover/v4_udp.go | 52 +- p2p/discover/v4_udp_test.go | 20 +- p2p/discover/v5_udp.go | 88 +- p2p/discover/v5_udp_test.go | 32 +- p2p/discover/v5wire/encoding.go | 4 +- portalnetwork/api.go | 543 +++++ portalnetwork/nat.go | 172 ++ portalnetwork/portal_protocol.go | 1918 +++++++++++++++++ portalnetwork/portal_protocol_metrics.go | 67 + portalnetwork/portal_protocol_test.go | 503 +++++ portalnetwork/portal_utp.go | 139 ++ portalnetwork/portalwire/messages.go | 336 +++ portalnetwork/portalwire/messages_encoding.go | 957 ++++++++ portalnetwork/portalwire/messages_test.go | 212 ++ 27 files changed, 5138 insertions(+), 258 deletions(-) create mode 100644 portalnetwork/api.go create mode 100644 portalnetwork/nat.go create mode 100644 portalnetwork/portal_protocol.go create mode 100644 portalnetwork/portal_protocol_metrics.go create mode 100644 portalnetwork/portal_protocol_test.go create mode 100644 portalnetwork/portal_utp.go create mode 100644 portalnetwork/portalwire/messages.go create mode 100644 portalnetwork/portalwire/messages_encoding.go create mode 100644 portalnetwork/portalwire/messages_test.go diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 8c48b3a557c1..0c832262a67c 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -163,7 +163,7 @@ func discv4Ping(ctx *cli.Context) error { defer disc.Close() start := time.Now() - if err := disc.Ping(n); err != nil { + if err := disc.PingWithoutResp(n); err != nil { return fmt.Errorf("node didn't respond: %v", err) } fmt.Printf("node responded to ping (RTT %v).\n", time.Since(start)) diff --git a/cmd/devp2p/discv5cmd.go b/cmd/devp2p/discv5cmd.go index 2422ef6644c9..b8a02b560acb 100644 --- a/cmd/devp2p/discv5cmd.go +++ b/cmd/devp2p/discv5cmd.go @@ -84,7 +84,7 @@ func discv5Ping(ctx *cli.Context) error { disc, _ := startV5(ctx) defer disc.Close() - fmt.Println(disc.Ping(n)) + fmt.Println(disc.PingWithoutResp(n)) return nil } diff --git a/p2p/discover/api.go b/p2p/discover/api.go index 4915fa688e2a..e7fe5c764ba3 100644 --- a/p2p/discover/api.go +++ b/p2p/discover/api.go @@ -114,7 +114,7 @@ func (d *DiscV5API) GetEnr(nodeId string) (bool, error) { if err != nil { return false, err } - n := d.DiscV5.tab.getNode(id) + n := d.DiscV5.tab.GetNode(id) if n == nil { return false, errors.New("record not in local routing table") } @@ -128,7 +128,7 @@ func (d *DiscV5API) DeleteEnr(nodeId string) (bool, error) { return false, err } - n := d.DiscV5.tab.getNode(id) + n := d.DiscV5.tab.GetNode(id) if n == nil { return false, errors.New("record not in local routing table") } @@ -161,7 +161,7 @@ func (d *DiscV5API) Ping(enr string) (*DiscV5PongResp, error) { return nil, err } - pong, err := d.DiscV5.pingInner(n) + pong, err := d.DiscV5.PingWithResp(n) if err != nil { return nil, err } @@ -178,7 +178,7 @@ func (d *DiscV5API) 
FindNodes(enr string, distances []uint) ([]string, error) { if err != nil { return nil, err } - findNodes, err := d.DiscV5.findnode(n, distances) + findNodes, err := d.DiscV5.Findnode(n, distances) if err != nil { return nil, err } @@ -283,7 +283,7 @@ func (p *PortalProtocolAPI) GetEnr(nodeId string) (string, error) { return p.portalProtocol.localNode.Node().String(), nil } - n := p.portalProtocol.table.getNode(id) + n := p.portalProtocol.table.GetNode(id) if n == nil { return "", errors.New("record not in local routing table") } @@ -297,7 +297,7 @@ func (p *PortalProtocolAPI) DeleteEnr(nodeId string) (bool, error) { return false, err } - n := p.portalProtocol.table.getNode(id) + n := p.portalProtocol.table.GetNode(id) if n == nil { return false, nil } diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index 09808b71e079..86e606ac5c79 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -24,30 +24,30 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" ) -// lookup performs a network search for nodes close to the given target. It approaches the +// Lookup performs a network search for nodes close to the given target. It approaches the // target by querying nodes that are closer to it on each iteration. The given target does // not need to be an actual node identifier. -type lookup struct { +type Lookup struct { tab *Table queryfunc queryFunc replyCh chan []*enode.Node cancelCh <-chan struct{} asked, seen map[enode.ID]bool - result nodesByDistance + result NodesByDistance replyBuffer []*enode.Node queries int } type queryFunc func(*enode.Node) ([]*enode.Node, error) -func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { - it := &lookup{ +func NewLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *Lookup { + it := &Lookup{ tab: tab, queryfunc: q, asked: make(map[enode.ID]bool), seen: make(map[enode.ID]bool), - result: nodesByDistance{target: target}, - replyCh: make(chan []*enode.Node, alpha), + result: NodesByDistance{Target: target}, + replyCh: make(chan []*enode.Node, Alpha), cancelCh: ctx.Done(), queries: -1, } @@ -57,16 +57,16 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l return it } -// run runs the lookup to completion and returns the closest nodes found. -func (it *lookup) run() []*enode.Node { +// Run runs the lookup to completion and returns the closest nodes found. +func (it *Lookup) Run() []*enode.Node { for it.advance() { } - return it.result.entries + return it.result.Entries } // advance advances the lookup until any new nodes have been found. // It returns false when the lookup has ended. -func (it *lookup) advance() bool { +func (it *Lookup) advance() bool { for it.startQueries() { select { case nodes := <-it.replyCh: @@ -74,7 +74,7 @@ func (it *lookup) advance() bool { for _, n := range nodes { if n != nil && !it.seen[n.ID()] { it.seen[n.ID()] = true - it.result.push(n, bucketSize) + it.result.Push(n, BucketSize) it.replyBuffer = append(it.replyBuffer, n) } } @@ -89,7 +89,7 @@ func (it *lookup) advance() bool { return false } -func (it *lookup) shutdown() { +func (it *Lookup) shutdown() { for it.queries > 0 { <-it.replyCh it.queries-- @@ -98,28 +98,28 @@ func (it *lookup) shutdown() { it.replyBuffer = nil } -func (it *lookup) startQueries() bool { +func (it *Lookup) startQueries() bool { if it.queryfunc == nil { return false } // The first query returns nodes from the local table. 
if it.queries == -1 { - closest := it.tab.findnodeByID(it.result.target, bucketSize, false) + closest := it.tab.FindnodeByID(it.result.Target, BucketSize, false) // Avoid finishing the lookup too quickly if table is empty. It'd be better to wait // for the table to fill in this case, but there is no good mechanism for that // yet. - if len(closest.entries) == 0 { + if len(closest.Entries) == 0 { it.slowdown() } it.queries = 1 - it.replyCh <- closest.entries + it.replyCh <- closest.Entries return true } // Ask the closest nodes that we haven't asked yet. - for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ { - n := it.result.entries[i] + for i := 0; i < len(it.result.Entries) && it.queries < Alpha; i++ { + n := it.result.Entries[i] if !it.asked[n.ID()] { it.asked[n.ID()] = true it.queries++ @@ -130,7 +130,7 @@ func (it *lookup) startQueries() bool { return it.queries > 0 } -func (it *lookup) slowdown() { +func (it *Lookup) slowdown() { sleep := time.NewTimer(1 * time.Second) defer sleep.Stop() select { @@ -139,9 +139,9 @@ func (it *lookup) slowdown() { } } -func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) { +func (it *Lookup) query(n *enode.Node, reply chan<- []*enode.Node) { r, err := it.queryfunc(n) - if !errors.Is(err, errClosed) { // avoid recording failures on shutdown. + if !errors.Is(err, ErrClosed) { // avoid recording failures on shutdown. success := len(r) > 0 it.tab.trackRequest(n, success, r) if err != nil { @@ -158,10 +158,10 @@ type lookupIterator struct { nextLookup lookupFunc ctx context.Context cancel func() - lookup *lookup + lookup *Lookup } -type lookupFunc func(ctx context.Context) *lookup +type lookupFunc func(ctx context.Context) *Lookup func newLookupIterator(ctx context.Context, next lookupFunc) *lookupIterator { ctx, cancel := context.WithCancel(ctx) diff --git a/p2p/discover/node.go b/p2p/discover/node.go index ac34b7c5b2ea..8b6ec83c0376 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -54,27 +54,27 @@ func (n *tableNode) String() string { return n.Node.String() } -// nodesByDistance is a list of nodes, ordered by distance to target. -type nodesByDistance struct { - entries []*enode.Node - target enode.ID +// NodesByDistance is a list of nodes, ordered by distance to target. +type NodesByDistance struct { + Entries []*enode.Node + Target enode.ID } -// push adds the given node to the list, keeping the total size below maxElems. -func (h *nodesByDistance) push(n *enode.Node, maxElems int) { - ix := sort.Search(len(h.entries), func(i int) bool { - return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 +// Push adds the given node to the list, keeping the total size below maxElems. +func (h *NodesByDistance) Push(n *enode.Node, maxElems int) { + ix := sort.Search(len(h.Entries), func(i int) bool { + return enode.DistCmp(h.Target, h.Entries[i].ID(), n.ID()) > 0 }) - end := len(h.entries) - if len(h.entries) < maxElems { - h.entries = append(h.entries, n) + end := len(h.Entries) + if len(h.Entries) < maxElems { + h.Entries = append(h.Entries, n) } if ix < end { // Slide existing entries down to make room. // This will overwrite the entry we just appended. 
- copy(h.entries[ix+1:], h.entries[ix:]) - h.entries[ix] = n + copy(h.Entries[ix+1:], h.Entries[ix:]) + h.Entries[ix] = n } } diff --git a/p2p/discover/portal_protocol.go b/p2p/discover/portal_protocol.go index b1be63233c4e..8e2129854e73 100644 --- a/p2p/discover/portal_protocol.go +++ b/p2p/discover/portal_protocol.go @@ -255,7 +255,7 @@ func (p *PortalProtocol) Start() error { return err } - go p.table.loop() + go p.table.Loop() for i := 0; i < concurrentOffers; i++ { go p.offerWorker() @@ -269,7 +269,7 @@ func (p *PortalProtocol) Start() error { func (p *PortalProtocol) Stop() { p.cancelCloseCtx() - p.table.close() + p.table.Close() p.DiscV5.Close() if p.Utp != nil { p.Utp.Stop() @@ -335,7 +335,7 @@ func (p *PortalProtocol) setupDiscV5AndTable() error { Log: p.Log, } - p.table, err = newTable(p, p.localNode.Database(), cfg) + p.table, err = NewTable(p, p.localNode.Database(), cfg) if err != nil { return err } @@ -343,7 +343,7 @@ func (p *PortalProtocol) setupDiscV5AndTable() error { return nil } -func (p *PortalProtocol) ping(node *enode.Node) (uint64, error) { +func (p *PortalProtocol) Ping(node *enode.Node) (uint64, error) { pong, err := p.pingInner(node) if err != nil { return 0, err @@ -515,7 +515,7 @@ func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request * if metrics.Enabled { p.portalMetrics.messagesReceivedAccept.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -651,7 +651,7 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, if metrics.Enabled { p.portalMetrics.messagesReceivedContent.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -669,7 +669,7 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, if metrics.Enabled { p.portalMetrics.messagesReceivedContent.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -729,7 +729,7 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, if metrics.Enabled { p.portalMetrics.messagesReceivedContent.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -757,7 +757,7 @@ func (p *PortalProtocol) processNodes(target *enode.Node, resp []byte, distances return nil, err } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -828,7 +828,7 @@ func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwi if metrics.Enabled { p.portalMetrics.messagesReceivedPong.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -840,8 +840,8 @@ func (p 
*PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwi } func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte { - if n := p.DiscV5.getNode(id); n != nil { - p.table.addInboundNode(n) + if n := p.DiscV5.GetNode(id); n != nil { + p.table.AddInboundNode(n) } msgCode := msg[0] @@ -1377,7 +1377,7 @@ func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, d return nil, errors.New("not contained in netrestrict list") } if n.UDP() <= 1024 { - return nil, errLowPort + return nil, ErrLowPort } if distances != nil { nd := enode.LogDist(sender.ID(), n.ID()) @@ -1394,24 +1394,24 @@ func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, d -// lookupRandom looks up a random target. +// LookupRandom looks up a random target. // This is needed to satisfy the transport interface. -func (p *PortalProtocol) lookupRandom() []*enode.Node { - return p.newRandomLookup(p.closeCtx).run() +func (p *PortalProtocol) LookupRandom() []*enode.Node { + return p.newRandomLookup(p.closeCtx).Run() } -// lookupSelf looks up our own node ID. +// LookupSelf looks up our own node ID. // This is needed to satisfy the transport interface. -func (p *PortalProtocol) lookupSelf() []*enode.Node { - return p.newLookup(p.closeCtx, p.Self().ID()).run() +func (p *PortalProtocol) LookupSelf() []*enode.Node { + return p.newLookup(p.closeCtx, p.Self().ID()).Run() } -func (p *PortalProtocol) newRandomLookup(ctx context.Context) *lookup { +func (p *PortalProtocol) newRandomLookup(ctx context.Context) *Lookup { var target enode.ID _, _ = crand.Read(target[:]) return p.newLookup(ctx, target) } -func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *lookup { - return newLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { +func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *Lookup { + return NewLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { return p.lookupWorker(n, target) }) } @@ -1419,28 +1419,28 @@ func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *lookup // lookupWorker performs FINDNODE calls against a single node during lookup. func (p *PortalProtocol) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { var ( - dists = lookupDistances(target, destNode.ID()) - nodes = nodesByDistance{target: target} + dists = LookupDistances(target, destNode.ID()) + nodes = NodesByDistance{Target: target} err error ) var r []*enode.Node r, err = p.findNodes(destNode, dists) - if errors.Is(err, errClosed) { + if errors.Is(err, ErrClosed) { return nil, err } for _, n := range r { if n.ID() != p.Self().ID() { - isAdded := p.table.addFoundNode(n, false) + isAdded := p.table.AddFoundNode(n, false) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) } else { log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) } - nodes.push(n, portalFindnodesResultLimit) + nodes.Push(n, portalFindnodesResultLimit) } } - return nodes.entries, err + return nodes.Entries, err } func (p *PortalProtocol) offerWorker() { @@ -1496,13 +1496,13 @@ func (p *PortalProtocol) findNodesCloseToContent(contentId []byte, limit int) [] // Lookup performs a recursive lookup for the given target. // It returns the closest nodes to target. 
func (p *PortalProtocol) Lookup(target enode.ID) []*enode.Node { - return p.newLookup(p.closeCtx, target).run() + return p.newLookup(p.closeCtx, target).Run() } // Resolve searches for a specific Node with the given ID and tries to get the most recent // version of the Node record for it. It returns n if the Node could not be resolved. func (p *PortalProtocol) Resolve(n *enode.Node) *enode.Node { - if intable := p.table.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + if intable := p.table.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { n = intable } // Try asking directly. This works if the Node is still responding on the endpoint we have. @@ -1527,7 +1527,7 @@ func (p *PortalProtocol) ResolveNodeId(id enode.ID) *enode.Node { return p.Self() } - n := p.table.getNode(id) + n := p.table.GetNode(id) if n != nil { p.Log.Debug("found Id in table and will request enr from the node", "id", id.String()) // Try asking directly. This works if the Node is still responding on the endpoint we have. @@ -1564,7 +1564,7 @@ func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit i processed[dist] = struct{}{} checkLive := !p.table.cfg.NoFindnodeLivenessCheck - for _, n := range p.table.appendBucketNodes(dist, bn[:0], checkLive) { + for _, n := range p.table.AppendBucketNodes(dist, bn[:0], checkLive) { // Apply some pre-checks to avoid sending invalid nodes. // Note liveness is checked by appendLiveNodes. if netutil.CheckRelayIP(rip, n.IP()) != nil { @@ -1582,7 +1582,7 @@ func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit i func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bool, error) { lookupContext, cancel := context.WithCancel(context.Background()) - resChan := make(chan *traceContentInfoResp, alpha) + resChan := make(chan *traceContentInfoResp, Alpha) hasResult := int32(0) result := ContentInfoResp{} @@ -1600,9 +1600,9 @@ func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bo } }() - newLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { + NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult) - }).run() + }).Run() close(resChan) wg.Wait() @@ -1616,7 +1616,7 @@ func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bo func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*TraceContentResult, error) { lookupContext, cancel := context.WithCancel(context.Background()) // resp channel - resChan := make(chan *traceContentInfoResp, alpha) + resChan := make(chan *traceContentInfoResp, Alpha) hasResult := int32(0) @@ -1633,10 +1633,10 @@ func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*Trac Cancelled: make([]string, 0), } - nodes := p.table.findnodeByID(enode.ID(contentId), bucketSize, false) + nodes := p.table.FindnodeByID(enode.ID(contentId), BucketSize, false) - localResponse := make([]string, 0, len(nodes.entries)) - for _, node := range nodes.entries { + localResponse := make([]string, 0, len(nodes.Entries)) + for _, node := range nodes.Entries { id := "0x" + node.ID().String() localResponse = append(localResponse, id) } @@ -1698,10 +1698,10 @@ func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*Trac } }() - lookup := newLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { + 
lookup := NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult) }) - lookup.run() + lookup.Run() close(resChan) wg.Wait() diff --git a/p2p/discover/portal_utp.go b/p2p/discover/portal_utp.go index e8c8e8f74ccd..589bd2bd15fe 100644 --- a/p2p/discover/portal_utp.go +++ b/p2p/discover/portal_utp.go @@ -122,7 +122,7 @@ func (p *PortalUtp) packetRouterFunc(buf []byte, id enode.ID, addr *net.UDPAddr) if n, ok := p.discV5.GetCachedNode(addr.String()); ok { //_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf) req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf} - p.discV5.sendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req) + p.discV5.SendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req) return len(buf), nil } else { diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 0ad7f1bef496..547bf3440986 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -39,8 +39,8 @@ import ( ) const ( - alpha = 3 // Kademlia concurrency factor - bucketSize = 16 // Kademlia bucket size + Alpha = 3 // Kademlia concurrency factor + BucketSize = 16 // Kademlia bucket size maxReplacements = 10 // Size of per-bucket replacement list // We keep buckets for the upper 1/15 of distances because @@ -92,9 +92,9 @@ type Table struct { type transport interface { Self() *enode.Node RequestENR(*enode.Node) (*enode.Node, error) - lookupRandom() []*enode.Node - lookupSelf() []*enode.Node - ping(*enode.Node) (seq uint64, err error) + LookupRandom() []*enode.Node + LookupSelf() []*enode.Node + Ping(*enode.Node) (seq uint64, err error) } // bucket contains nodes, ordered by their last activity. the entry @@ -118,7 +118,7 @@ type trackRequestOp struct { success bool } -func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) { +func NewTable(t transport, db *enode.DB, cfg Config) (*Table, error) { cfg = cfg.withDefaults() tab := &Table{ net: t, @@ -196,8 +196,8 @@ func (tab *Table) self() *enode.Node { return tab.net.Self() } -// getNode returns the node with the given ID or nil if it isn't in the table. -func (tab *Table) getNode(id enode.ID) *enode.Node { +// GetNode returns the node with the given ID or nil if it isn't in the table. +func (tab *Table) GetNode(id enode.ID) *enode.Node { tab.mutex.Lock() defer tab.mutex.Unlock() @@ -210,8 +210,8 @@ func (tab *Table) getNode(id enode.ID) *enode.Node { return nil } -// close terminates the network listener and flushes the node database. -func (tab *Table) close() { +// Close terminates the network listener and flushes the node database. +func (tab *Table) Close() { close(tab.closeReq) <-tab.closed } @@ -255,40 +255,40 @@ func (tab *Table) refresh() <-chan struct{} { return done } -// findnodeByID returns the n nodes in the table that are closest to the given id. +// FindnodeByID returns the n nodes in the table that are closest to the given id. // This is used by the FINDNODE/v4 handler. // // The preferLive parameter says whether the caller wants liveness-checked results. If // preferLive is true and the table contains any verified nodes, the result will not // contain unverified nodes. However, if there are no verified nodes at all, the result // will contain unverified nodes. 
-func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance { +func (tab *Table) FindnodeByID(target enode.ID, nresults int, preferLive bool) *NodesByDistance { tab.mutex.Lock() defer tab.mutex.Unlock() // Scan all buckets. There might be a better way to do this, but there aren't that many // buckets, so this solution should be fine. The worst-case complexity of this loop // is O(tab.len() * nresults). - nodes := &nodesByDistance{target: target} - liveNodes := &nodesByDistance{target: target} + nodes := &NodesByDistance{Target: target} + liveNodes := &NodesByDistance{Target: target} for _, b := range &tab.buckets { for _, n := range b.entries { - nodes.push(n.Node, nresults) + nodes.Push(n.Node, nresults) if preferLive && n.isValidatedLive { - liveNodes.push(n.Node, nresults) + liveNodes.Push(n.Node, nresults) } } } - if preferLive && len(liveNodes.entries) > 0 { + if preferLive && len(liveNodes.Entries) > 0 { return liveNodes } return nodes } -// appendBucketNodes adds nodes at the given distance to the result slice. +// AppendBucketNodes adds nodes at the given distance to the result slice. // This is used by the FINDNODE/v5 handler. -func (tab *Table) appendBucketNodes(dist uint, result []*enode.Node, checkLive bool) []*enode.Node { +func (tab *Table) AppendBucketNodes(dist uint, result []*enode.Node, checkLive bool) []*enode.Node { if dist > 256 { return result } @@ -322,12 +322,12 @@ func (tab *Table) len() (n int) { return n } -// addFoundNode adds a node which may not be live. If the bucket has space available, +// AddFoundNode adds a node which may not be live. If the bucket has space available, // adding the node succeeds immediately. Otherwise, the node is added to the replacements // list. // // The caller must not hold tab.mutex. -func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { +func (tab *Table) AddFoundNode(n *enode.Node, forceSetLive bool) bool { op := addNodeOp{node: n, isInbound: false, forceSetLive: forceSetLive} select { case tab.addNodeCh <- op: @@ -337,7 +337,7 @@ func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { } } -// addInboundNode adds a node from an inbound contact. If the bucket has no space, the +// AddInboundNode adds a node from an inbound contact. If the bucket has no space, the // node is added to the replacements list. // // There is an additional safety measure: if the table is still initializing the node is @@ -345,7 +345,7 @@ func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { // repeatedly. // // The caller must not hold tab.mutex. -func (tab *Table) addInboundNode(n *enode.Node) bool { +func (tab *Table) AddInboundNode(n *enode.Node) bool { op := addNodeOp{node: n, isInbound: true} select { case tab.addNodeCh <- op: @@ -363,8 +363,8 @@ func (tab *Table) trackRequest(n *enode.Node, success bool, foundNodes []*enode. } } -// loop is the main loop of Table. -func (tab *Table) loop() { +// Loop is the main loop of Table. +func (tab *Table) Loop() { var ( refresh = time.NewTimer(tab.nextRefreshTime()) refreshDone = make(chan struct{}) // where doRefresh reports completion @@ -447,7 +447,7 @@ func (tab *Table) doRefresh(done chan struct{}) { tab.loadSeedNodes() // Run self lookup to discover new neighbor nodes. - tab.net.lookupSelf() + tab.net.LookupSelf() // The Kademlia paper specifies that the bucket refresh should // perform a lookup in the least recently used bucket. 
We cannot // adhere to this because the findnode target is a 512bit value (not hash-sized) // and it is not easily possible to generate a // sha3 preimage that falls into a chosen bucket. // We perform a few lookups with a random target instead. for i := 0; i < 3; i++ { - tab.net.lookupRandom() + tab.net.LookupRandom() } } @@ -542,7 +542,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool { tab.log.Debug("the node is already in table", "id", req.node.ID()) return false } - if len(b.entries) >= bucketSize { + if len(b.entries) >= BucketSize { // Bucket full, maybe add as replacement. tab.log.Debug("the bucket is full and will add in replacement", "id", req.node.ID()) tab.addReplacement(b, req.node) @@ -697,7 +697,7 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) { // many times, but only if there are enough other nodes in the bucket. This latter // condition specifically exists to make bootstrapping in smaller test networks more // reliable. - if fails >= maxFindnodeFailures && len(b.entries) >= bucketSize/4 { + if fails >= maxFindnodeFailures && len(b.entries) >= BucketSize/4 { tab.deleteInBucket(b, op.node.ID()) } @@ -717,3 +717,36 @@ func pushNode(list []*tableNode, n *tableNode, max int) ([]*tableNode, *tableNod list[0] = n return list, removed } + +// WaitInit blocks until the table's initial refresh has completed. +func (tab *Table) WaitInit() { + <-tab.initDone +} + +// NodeIds returns the IDs of all nodes in the table, grouped by bucket. +func (tab *Table) NodeIds() [][]string { + tab.mutex.Lock() + defer tab.mutex.Unlock() + nodes := make([][]string, 0) + for _, b := range &tab.buckets { + bucketNodes := make([]string, 0) + for _, n := range b.entries { + bucketNodes = append(bucketNodes, "0x"+n.ID().String()) + } + nodes = append(nodes, bucketNodes) + } + return nodes +} + +// Config returns the table's configuration. +func (tab *Table) Config() Config { + return tab.cfg +} + +// DeleteNode removes n from the table. +func (tab *Table) DeleteNode(n *enode.Node) { + tab.mutex.Lock() + defer tab.mutex.Unlock() + b := tab.bucket(n.ID()) + tab.deleteInBucket(b, n.ID()) +} diff --git a/p2p/discover/table_reval.go b/p2p/discover/table_reval.go index 2465fee9066f..844094cbb80a 100644 --- a/p2p/discover/table_reval.go +++ b/p2p/discover/table_reval.go @@ -111,7 +111,7 @@ func (tr *tableRevalidation) startRequest(tab *Table, n *tableNode) { func (tab *Table) doRevalidate(resp revalidationResponse, node *enode.Node) { // Ping the selected node and wait for a pong response. - remoteSeq, err := tab.net.ping(node) + remoteSeq, err := tab.net.Ping(node) resp.didRespond = err == nil // Also fetch record if the node replied and returned a higher sequence number. diff --git a/p2p/discover/table_reval_test.go b/p2p/discover/table_reval_test.go index 360544393439..16357e42aab0 100644 --- a/p2p/discover/table_reval_test.go +++ b/p2p/discover/table_reval_test.go @@ -63,7 +63,7 @@ func TestRevalidation_nodeRemoved(t *testing.T) { tr.handleResponse(tab, resp) // Ensure the node was not re-added to the table. - if tab.getNode(node.ID()) != nil { + if tab.GetNode(node.ID()) != nil { t.Fatal("node was re-added to Table") } if tr.fast.contains(node.ID()) || tr.slow.contains(node.ID()) { diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 8cc4ae33b2eb..63fa152ffc9d 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -59,7 +59,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding Log: testlog.Logger(t, log.LevelTrace), }) defer db.Close() - defer tab.close() + defer tab.Close() <-tab.initDone @@ -79,7 +79,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding transport.dead[replacementNode.ID()] = !newNodeIsResponding // Add replacement node to table. 
- tab.addFoundNode(replacementNode, false) + tab.AddFoundNode(replacementNode, false) t.Log("last:", last.ID()) t.Log("replacement:", replacementNode.ID()) @@ -108,7 +108,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding // Check bucket content. tab.mutex.Lock() defer tab.mutex.Unlock() - wantSize := bucketSize + wantSize := BucketSize if !lastInBucketIsResponding && !newNodeIsResponding { wantSize-- } @@ -150,11 +150,11 @@ func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport, Config{}) defer db.Close() - defer tab.close() + defer tab.Close() for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) - tab.addFoundNode(n, false) + tab.AddFoundNode(n, false) } if tab.len() > tableIPLimit { t.Errorf("too many nodes in table") @@ -167,12 +167,12 @@ func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport, Config{}) defer db.Close() - defer tab.close() + defer tab.Close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) - tab.addFoundNode(n, false) + tab.AddFoundNode(n, false) } if tab.len() > bucketIPLimit { t.Errorf("too many nodes in table") @@ -204,11 +204,11 @@ func TestTable_findnodeByID(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport, Config{}) defer db.Close() - defer tab.close() + defer tab.Close() fillTable(tab, test.All, true) // check that closest(Target, N) returns nodes - result := tab.findnodeByID(test.Target, test.N, false).entries + result := tab.FindnodeByID(test.Target, test.N, false).Entries if hasDuplicates(result) { t.Errorf("result contains duplicates") return false @@ -264,7 +264,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { t := &closeTest{ Self: gen(enode.ID{}, rand).(enode.ID), Target: gen(enode.ID{}, rand).(enode.ID), - N: rand.Intn(bucketSize), + N: rand.Intn(BucketSize), } for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { r := new(enr.Record) @@ -279,20 +279,20 @@ func TestTable_addInboundNode(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addFoundNode(n1, false) - tab.addFoundNode(n2, false) + tab.AddFoundNode(n1, false) + tab.AddFoundNode(n2, false) checkBucketContent(t, tab, []*enode.Node{n1, n2}) // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) n2v2 := enode.SignNull(newrec, n2.ID()) - tab.addInboundNode(n2v2) + tab.AddInboundNode(n2v2) checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) // Try updating n2 without sequence number change. The update is accepted @@ -301,7 +301,7 @@ func TestTable_addInboundNode(t *testing.T) { newrec.Set(enr.IP{100, 100, 100, 100}) newrec.SetSeq(n2.Seq()) n2v3 := enode.SignNull(newrec, n2.ID()) - tab.addInboundNode(n2v3) + tab.AddInboundNode(n2v3) checkBucketContent(t, tab, []*enode.Node{n1, n2v3}) } @@ -309,20 +309,20 @@ func TestTable_addFoundNode(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Insert two nodes. 
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addFoundNode(n1, false) - tab.addFoundNode(n2, false) + tab.AddFoundNode(n1, false) + tab.AddFoundNode(n2, false) checkBucketContent(t, tab, []*enode.Node{n1, n2}) // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) n2v2 := enode.SignNull(newrec, n2.ID()) - tab.addFoundNode(n2v2, false) + tab.AddFoundNode(n2v2, false) checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) // Try updating n2 without a sequence number change. @@ -331,7 +331,7 @@ func TestTable_addFoundNode(t *testing.T) { newrec.Set(enr.IP{100, 100, 100, 100}) newrec.SetSeq(n2.Seq()) n2v3 := enode.SignNull(newrec, n2.ID()) - tab.addFoundNode(n2v3, false) + tab.AddFoundNode(n2v3, false) checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) } @@ -340,18 +340,18 @@ func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Add a v4 node. key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) - tab.addInboundNode(n1) + tab.AddInboundNode(n1) checkBucketContent(t, tab, []*enode.Node{n1}) // Add an updated version with changed IP. // The update will be accepted because it is inbound. n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) - tab.addInboundNode(n1v2) + tab.AddInboundNode(n1v2) checkBucketContent(t, tab, []*enode.Node{n1v2}) } @@ -361,18 +361,18 @@ func TestTable_addFoundNodeV4UpdateReject(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Add a v4 node. key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) - tab.addFoundNode(n1, false) + tab.AddFoundNode(n1, false) checkBucketContent(t, tab, []*enode.Node{n1}) // Add an updated version with changed IP. // The update won't be accepted because it isn't inbound. n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) - tab.addFoundNode(n1v2, false) + tab.AddFoundNode(n1v2, false) checkBucketContent(t, tab, []*enode.Node{n1}) } @@ -407,14 +407,14 @@ func TestTable_revalidateSyncRecord(t *testing.T) { }) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Insert a node. var r enr.Record r.Set(enr.IP(net.IP{127, 0, 0, 1})) id := enode.ID{1} n1 := enode.SignNull(&r, id) - tab.addFoundNode(n1, false) + tab.AddFoundNode(n1, false) // Update the node record. r.Set(enr.WithEntry("foo", "bar")) @@ -426,7 +426,7 @@ func TestTable_revalidateSyncRecord(t *testing.T) { waitForRevalidationPing(t, transport, tab, n2.ID()) waitForRevalidationPing(t, transport, tab, n2.ID()) - intable := tab.getNode(id) + intable := tab.GetNode(id) if !reflect.DeepEqual(intable, n2) { t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq()) } @@ -448,22 +448,22 @@ func TestNodesPush(t *testing.T) { // Insert all permutations into lists with size limit 3. 
for _, nodes := range perm { - list := nodesByDistance{target: target} + list := NodesByDistance{Target: target} for _, n := range nodes { - list.push(n, 3) + list.Push(n, 3) } - if !slices.EqualFunc(list.entries, perm[0], nodeIDEqual) { + if !slices.EqualFunc(list.Entries, perm[0], nodeIDEqual) { t.Fatal("not equal") } } // Insert all permutations into lists with size limit 2. for _, nodes := range perm { - list := nodesByDistance{target: target} + list := NodesByDistance{Target: target} for _, n := range nodes { - list.push(n, 2) + list.Push(n, 2) } - if !slices.EqualFunc(list.entries, perm[0][:2], nodeIDEqual) { + if !slices.EqualFunc(list.Entries, perm[0][:2], nodeIDEqual) { t.Fatal("not equal") } } diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 254471c25a1e..343be71f2f4b 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -45,14 +45,14 @@ func init() { func newTestTable(t transport, cfg Config) (*Table, *enode.DB) { tab, db := newInactiveTestTable(t, cfg) - go tab.loop() + go tab.Loop() return tab, db } // newInactiveTestTable creates a Table without running the main loop. func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) { db, _ := enode.OpenDB("") - tab, _ := newTable(t, db, cfg) + tab, _ := NewTable(t, db, cfg) return tab, db } @@ -110,20 +110,20 @@ func intIP(i int) net.IP { func fillBucket(tab *Table, id enode.ID) (last *tableNode) { ld := enode.LogDist(tab.self().ID(), id) b := tab.bucket(id) - for len(b.entries) < bucketSize { + for len(b.entries) < BucketSize { node := nodeAtDistance(tab.self().ID(), ld, intIP(ld)) - if !tab.addFoundNode(node, false) { + if !tab.AddFoundNode(node, false) { panic("node not added") } } - return b.entries[bucketSize-1] + return b.entries[BucketSize-1] } // fillTable adds nodes the table to the end of their corresponding bucket // if the bucket is not full. The caller must not hold tab.mutex. func fillTable(tab *Table, nodes []*enode.Node, setLive bool) { for _, n := range nodes { - tab.addFoundNode(n, setLive) + tab.AddFoundNode(n, setLive) } } @@ -160,8 +160,8 @@ func (t *pingRecorder) updateRecord(n *enode.Node) { // Stubs to satisfy the transport interface. func (t *pingRecorder) Self() *enode.Node { return nullNode } -func (t *pingRecorder) lookupSelf() []*enode.Node { return nil } -func (t *pingRecorder) lookupRandom() []*enode.Node { return nil } +func (t *pingRecorder) LookupSelf() []*enode.Node { return nil } +func (t *pingRecorder) LookupRandom() []*enode.Node { return nil } func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node { t.mu.Lock() @@ -190,7 +190,7 @@ func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node { } // ping simulates a ping request. 
-func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { +func (t *pingRecorder) Ping(n *enode.Node) (seq uint64, err error) { t.mu.Lock() defer t.mu.Unlock() diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 29a9dd6645e0..f7515fd3a9fa 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -59,8 +59,8 @@ func TestUDPv4_Lookup(t *testing.T) { for _, e := range results { t.Logf(" ld=%d, %x", enode.LogDist(lookupTestnet.target.ID(), e.ID()), e.ID().Bytes()) } - if len(results) != bucketSize { - t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) + if len(results) != BucketSize { + t.Errorf("wrong number of results: got %d, want %d", len(results), BucketSize) } checkLookupResults(t, lookupTestnet, results) } diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index 29ae5f2c084d..f1db0d63f234 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -43,8 +43,8 @@ var ( errUnknownNode = errors.New("unknown node") errTimeout = errors.New("RPC timeout") errClockWarp = errors.New("reply deadline too far in the future") - errClosed = errors.New("socket closed") - errLowPort = errors.New("low port") + ErrClosed = errors.New("socket closed") + ErrLowPort = errors.New("low port") errNoUDPEndpoint = errors.New("node has no UDP endpoint") ) @@ -143,12 +143,12 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { log: cfg.Log, } - tab, err := newTable(t, ln.Database(), cfg) + tab, err := NewTable(t, ln.Database(), cfg) if err != nil { return nil, err } t.tab = tab - go tab.loop() + go tab.Loop() t.wg.Add(2) go t.loop() @@ -167,7 +167,7 @@ func (t *UDPv4) Close() { t.cancelCloseCtx() t.conn.Close() t.wg.Wait() - t.tab.close() + t.tab.Close() }) } @@ -179,7 +179,7 @@ func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { return rn } // Check table for the ID, we might have a newer version there. - if intable := t.tab.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + if intable := t.tab.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { n = intable if rn, err := t.RequestENR(n); err == nil { return rn @@ -210,14 +210,14 @@ func (t *UDPv4) ourEndpoint() v4wire.Endpoint { return v4wire.NewEndpoint(addr, uint16(node.TCP())) } -// Ping sends a ping message to the given node. -func (t *UDPv4) Ping(n *enode.Node) error { - _, err := t.ping(n) +// PingWithoutResp sends a ping message to the given node. +func (t *UDPv4) PingWithoutResp(n *enode.Node) error { + _, err := t.Ping(n) return err } -// ping sends a ping message to the given node and waits for a reply. -func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { +// Ping sends a ping message to the given node and waits for a reply. +func (t *UDPv4) Ping(n *enode.Node) (seq uint64, err error) { addr, ok := n.UDPEndpoint() if !ok { return 0, errNoUDPEndpoint } @@ -271,7 +271,7 @@ func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { // case and run the bootstrapping logic. <-t.tab.refresh() } - return t.newLookup(t.closeCtx, v4wire.EncodePubkey(key)).run() + return t.newLookup(t.closeCtx, v4wire.EncodePubkey(key)).Run() } // RandomNodes is an iterator yielding nodes from a random walk of the DHT. @@ -280,25 +280,25 @@ func (t *UDPv4) RandomNodes() enode.Iterator { } -// lookupRandom implements transport. -func (t *UDPv4) lookupRandom() []*enode.Node { - return t.newRandomLookup(t.closeCtx).run() +// LookupRandom implements transport. +func (t *UDPv4) LookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).Run() } -// lookupSelf implements transport. +// LookupSelf implements transport. 
-func (t *UDPv4) lookupSelf() []*enode.Node { +func (t *UDPv4) LookupSelf() []*enode.Node { pubkey := v4wire.EncodePubkey(&t.priv.PublicKey) - return t.newLookup(t.closeCtx, pubkey).run() + return t.newLookup(t.closeCtx, pubkey).Run() } -func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { +func (t *UDPv4) newRandomLookup(ctx context.Context) *Lookup { var target v4wire.Pubkey crand.Read(target[:]) return t.newLookup(ctx, target) } -func (t *UDPv4) newLookup(ctx context.Context, targetKey v4wire.Pubkey) *lookup { +func (t *UDPv4) newLookup(ctx context.Context, targetKey v4wire.Pubkey) *Lookup { target := enode.ID(crypto.Keccak256Hash(targetKey[:])) - it := newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { + it := NewLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { addr, ok := n.UDPEndpoint() if !ok { return nil, errNoUDPEndpoint @@ -315,7 +315,7 @@ func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is // active until enough nodes have been received. - nodes := make([]*enode.Node, 0, bucketSize) + nodes := make([]*enode.Node, 0, BucketSize) nreceived := 0 rm := t.pending(toid, toAddrPort.Addr(), v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { reply := r.(*v4wire.Neighbors) @@ -328,7 +328,7 @@ func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire } nodes = append(nodes, n) } - return true, nreceived >= bucketSize + return true, nreceived >= BucketSize }) t.send(toAddrPort, toid, &v4wire.Findnode{ Target: target, @@ -400,7 +400,7 @@ func (t *UDPv4) pending(id enode.ID, ip netip.Addr, ptype byte, callback replyMa case t.addReplyMatcher <- p: // loop will handle it case <-t.closeCtx.Done(): - ch <- errClosed + ch <- ErrClosed } return p } @@ -461,7 +461,7 @@ func (t *UDPv4) loop() { select { case <-t.closeCtx.Done(): for el := plist.Front(); el != nil; el = el.Next() { - el.Value.(*replyMatcher).errc <- errClosed + el.Value.(*replyMatcher).errc <- ErrClosed } return @@ -599,7 +599,7 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) { func (t *UDPv4) nodeFromRPC(sender netip.AddrPort, rn v4wire.Node) (*enode.Node, error) { if rn.UDP <= 1024 { - return nil, errLowPort + return nil, ErrLowPort } if err := netutil.CheckRelayIP(sender.Addr().AsSlice(), rn.IP); err != nil { return nil, err @@ -692,10 +692,10 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port())) if time.Since(t.db.LastPongReceived(n.ID(), from.Addr())) > bondExpiration { t.sendPing(fromID, from, func() { - t.tab.addInboundNode(n) + t.tab.AddInboundNode(n) }) } else { - t.tab.addInboundNode(n) + t.tab.AddInboundNode(n) } // Update node database and endpoint predictor. @@ -747,7 +747,7 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID e // Determine closest nodes. target := enode.ID(crypto.Keccak256Hash(req.Target[:])) preferLive := !t.tab.cfg.NoFindnodeLivenessCheck - closest := t.tab.findnodeByID(target, bucketSize, preferLive).entries + closest := t.tab.FindnodeByID(target, BucketSize, preferLive).Entries // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the packet size limit. 
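Taken together, the renames above complete the exported surface this patch is really about: the routing table's transport contract (Self, RequestENR, LookupRandom, LookupSelf, Ping) and its lifecycle (NewTable, Loop, WaitInit, Close) are now reachable from outside p2p/discover, while the transport interface type itself stays unexported. Below is a minimal sketch of what that enables for an out-of-package protocol such as the new portalnetwork code; the noopTransport type and its trivial method bodies are illustrative assumptions, not part of this patch:

package main

import (
	"github.com/ethereum/go-ethereum/crypto"
	"github.com/ethereum/go-ethereum/p2p/discover"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

// noopTransport is a hypothetical stand-in for a real wire protocol such as
// PortalProtocol. Its exported methods match the (still unexported) transport
// interface in p2p/discover/table.go, which is enough to drive a Table.
type noopTransport struct{ self *enode.Node }

func (t *noopTransport) Self() *enode.Node                             { return t.self }
func (t *noopTransport) RequestENR(n *enode.Node) (*enode.Node, error) { return n, nil }
func (t *noopTransport) LookupRandom() []*enode.Node                   { return nil }
func (t *noopTransport) LookupSelf() []*enode.Node                     { return nil }
func (t *noopTransport) Ping(n *enode.Node) (uint64, error)            { return n.Seq(), nil }

func main() {
	db, _ := enode.OpenDB("") // empty path opens an in-memory node database
	defer db.Close()
	key, _ := crypto.GenerateKey()
	self := enode.NewLocalNode(db, key).Node()

	tab, err := discover.NewTable(&noopTransport{self: self}, db, discover.Config{})
	if err != nil {
		panic(err)
	}
	go tab.Loop()  // run the table's main loop, as PortalProtocol.Start does
	tab.WaitInit() // block until the initial refresh has finished
	defer tab.Close()
}

Go permits passing &noopTransport{...} here even though the parameter type is an unexported interface: only the method set has to match. That is why exporting the methods, and not the interface itself, is sufficient for the portalnetwork split.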
diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index 1af31f4f1b9b..004fe6d7e80a 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -112,7 +112,7 @@ func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() dgram, err := test.pipe.receive() - if err == errClosed { + if err == ErrClosed { return true } else if err != nil { test.t.Error("packet receive error:", err) @@ -151,7 +151,7 @@ func TestUDPv4_pingTimeout(t *testing.T) { key := newkey() toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} node := enode.NewV4(&key.PublicKey, toaddr.IP, 0, toaddr.Port) - if _, err := test.udp.ping(node); err != errTimeout { + if _, err := test.udp.Ping(node); err != errTimeout { t.Error("expected timeout error, got", err) } } @@ -256,9 +256,9 @@ func TestUDPv4_findnode(t *testing.T) { // put a few nodes into the table. their exact // distribution shouldn't matter much, although we need to // take care not to overflow any bucket. - nodes := &nodesByDistance{target: testTarget.ID()} + nodes := &NodesByDistance{Target: testTarget.ID()} live := make(map[enode.ID]bool) - numCandidates := 2 * bucketSize + numCandidates := 2 * BucketSize for i := 0; i < numCandidates; i++ { key := newkey() ip := net.IP{10, 13, 0, byte(i)} @@ -267,8 +267,8 @@ func TestUDPv4_findnode(t *testing.T) { if i > numCandidates/2 { live[n.ID()] = true } - test.table.addFoundNode(n, live[n.ID()]) - nodes.push(n, numCandidates) + test.table.AddFoundNode(n, live[n.ID()]) + nodes.Push(n, numCandidates) } // ensure there's a bond with the test node, @@ -277,7 +277,7 @@ func TestUDPv4_findnode(t *testing.T) { test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr(), time.Now()) // check that closest neighbors are returned. - expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) + expected := test.table.FindnodeByID(testTarget.ID(), BucketSize, true) test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp}) waitNeighbors := func(want []*enode.Node) { test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) { @@ -287,7 +287,7 @@ func TestUDPv4_findnode(t *testing.T) { } for i, n := range p.Nodes { if n.ID.ID() != want[i].ID() { - t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i]) + t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.Entries[i]) } if !live[n.ID.ID()] { t.Errorf("result includes dead node %v", n.ID.ID()) @@ -296,7 +296,7 @@ func TestUDPv4_findnode(t *testing.T) { }) } // Receive replies. 
- want := expected.entries + want := expected.Entries if len(want) > v4wire.MaxNeighbors { waitNeighbors(want[:v4wire.MaxNeighbors]) want = want[v4wire.MaxNeighbors:] } @@ -644,7 +644,7 @@ func (c *dgramPipe) receive() (dgram, error) { c.cond.Wait() } if c.closed { - return dgram{}, errClosed + return dgram{}, ErrClosed } if timedOut { return dgram{}, errTimeout diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 6383f5e4a731..db66231b02da 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -140,7 +140,7 @@ func ListenV5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { if err != nil { return nil, err } - go t.tab.loop() + go t.tab.Loop() t.wg.Add(2) go t.readLoop() go t.dispatch() @@ -180,7 +180,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { cancelCloseCtx: cancelCloseCtx, } t.talk = newTalkSystem(t) - tab, err := newTable(t, t.db, cfg) + tab, err := NewTable(t, t.db, cfg) if err != nil { return nil, err } @@ -200,20 +200,20 @@ func (t *UDPv5) Close() { t.conn.Close() t.talk.wait() t.wg.Wait() - t.tab.close() + t.tab.Close() }) } -// Ping sends a ping message to the given node. -func (t *UDPv5) Ping(n *enode.Node) error { - _, err := t.ping(n) +// PingWithoutResp sends a ping message to the given node. +func (t *UDPv5) PingWithoutResp(n *enode.Node) error { + _, err := t.Ping(n) return err } // Resolve searches for a specific node with the given ID and tries to get the most recent // version of the node record for it. It returns n if the node could not be resolved. func (t *UDPv5) Resolve(n *enode.Node) *enode.Node { - if intable := t.tab.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + if intable := t.tab.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { n = intable } // Try asking directly. This works if the node is still responding on the endpoint we have. @@ -237,7 +237,7 @@ func (t *UDPv5) ResolveNodeId(id enode.ID) *enode.Node { return t.Self() } - n := t.tab.getNode(id) + n := t.tab.GetNode(id) if n != nil { // Try asking directly. This works if the Node is still responding on the endpoint we have. if resp, err := t.RequestENR(n); err == nil { @@ -341,29 +341,29 @@ func (t *UDPv5) RandomNodes() enode.Iterator { // Lookup performs a recursive lookup for the given target. // It returns the closest nodes to target. func (t *UDPv5) Lookup(target enode.ID) []*enode.Node { - return t.newLookup(t.closeCtx, target).run() + return t.newLookup(t.closeCtx, target).Run() } -// lookupRandom looks up a random target. +// LookupRandom looks up a random target. // This is needed to satisfy the transport interface. -func (t *UDPv5) lookupRandom() []*enode.Node { - return t.newRandomLookup(t.closeCtx).run() +func (t *UDPv5) LookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).Run() } -// lookupSelf looks up our own node ID. +// LookupSelf looks up our own node ID. // This is needed to satisfy the transport interface. 
-func (t *UDPv5) lookupSelf() []*enode.Node { - return t.newLookup(t.closeCtx, t.Self().ID()).run() +func (t *UDPv5) LookupSelf() []*enode.Node { + return t.newLookup(t.closeCtx, t.Self().ID()).Run() } -func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup { +func (t *UDPv5) newRandomLookup(ctx context.Context) *Lookup { var target enode.ID crand.Read(target[:]) return t.newLookup(ctx, target) } -func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { - return newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { +func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *Lookup { + return NewLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { return t.lookupWorker(n, target) }) } @@ -371,27 +371,27 @@ func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { // lookupWorker performs FINDNODE calls against a single node during lookup. func (t *UDPv5) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { var ( - dists = lookupDistances(target, destNode.ID()) - nodes = nodesByDistance{target: target} + dists = LookupDistances(target, destNode.ID()) + nodes = NodesByDistance{Target: target} err error ) var r []*enode.Node - r, err = t.findnode(destNode, dists) - if errors.Is(err, errClosed) { + r, err = t.Findnode(destNode, dists) + if errors.Is(err, ErrClosed) { return nil, err } for _, n := range r { if n.ID() != t.Self().ID() { - nodes.push(n, findnodeResultLimit) + nodes.Push(n, findnodeResultLimit) } } - return nodes.entries, err + return nodes.Entries, err } -// lookupDistances computes the distance parameter for FINDNODE calls to dest. +// LookupDistances computes the distance parameter for FINDNODE calls to dest. // It chooses distances adjacent to logdist(target, dest), e.g. for a target // with logdist(target, dest) = 255 the result is [255, 256, 254]. -func lookupDistances(target, dest enode.ID) (dists []uint) { +func LookupDistances(target, dest enode.ID) (dists []uint) { td := enode.LogDist(target, dest) dists = append(dists, uint(td)) for i := 1; len(dists) < lookupRequestLimit; i++ { @@ -406,8 +406,8 @@ } -// ping calls PING on a node and waits for a PONG response. -func (t *UDPv5) ping(n *enode.Node) (uint64, error) { - pong, err := t.pingInner(n) +// Ping calls PING on a node and waits for a PONG response. +func (t *UDPv5) Ping(n *enode.Node) (uint64, error) { + pong, err := t.PingWithResp(n) if err != nil { return 0, err } @@ -415,8 +415,8 @@ return pong.ENRSeq, nil } -// pingInner calls PING on a node and waits for a PONG response. +// PingWithResp calls PING on a node and waits for a PONG response. func (t *UDPv5) PingWithResp(n *enode.Node) (*v5wire.Pong, error) { req := &v5wire.Ping{ENRSeq: t.localNode.Node().Seq()} resp := t.callToNode(n, v5wire.PongMsg, req) defer t.callDone(resp) @@ -431,7 +431,7 @@ // RequestENR requests n's record. func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { - nodes, err := t.findnode(n, []uint{0}) + nodes, err := t.Findnode(n, []uint{0}) if err != nil { return nil, err } @@ -441,8 +441,8 @@ return nodes[0], nil } -// findnode calls FINDNODE on a node and waits for responses. 
-func lookupDistances(target, dest enode.ID) (dists []uint) {
+func LookupDistances(target, dest enode.ID) (dists []uint) {
 	td := enode.LogDist(target, dest)
 	dists = append(dists, uint(td))
 	for i := 1; len(dists) < lookupRequestLimit; i++ {
@@ -406,8 +406,8 @@ func lookupDistances(target, dest enode.ID) (dists []uint) {
 }
 
-// ping calls PING on a node and waits for a PONG response.
-func (t *UDPv5) ping(n *enode.Node) (uint64, error) {
-	pong, err := t.pingInner(n)
+// Ping calls PING on a node and waits for a PONG response.
+func (t *UDPv5) Ping(n *enode.Node) (uint64, error) {
+	pong, err := t.PingWithResp(n)
 	if err != nil {
 		return 0, err
 	}
@@ -415,8 +415,8 @@ func (t *UDPv5) ping(n *enode.Node) (uint64, error) {
 	return pong.ENRSeq, nil
 }
 
-// pingInner calls PING on a node and waits for a PONG response.
-func (t *UDPv5) pingInner(n *enode.Node) (*v5wire.Pong, error) {
+// PingWithResp calls PING on a node and waits for a PONG response.
+func (t *UDPv5) PingWithResp(n *enode.Node) (*v5wire.Pong, error) {
 	req := &v5wire.Ping{ENRSeq: t.localNode.Node().Seq()}
 	resp := t.callToNode(n, v5wire.PongMsg, req)
 	defer t.callDone(resp)
@@ -431,7 +431,7 @@ func (t *UDPv5) pingInner(n *enode.Node) (*v5wire.Pong, error) {
 
 // RequestENR requests n's record.
 func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) {
-	nodes, err := t.findnode(n, []uint{0})
+	nodes, err := t.Findnode(n, []uint{0})
 	if err != nil {
 		return nil, err
 	}
@@ -441,8 +441,8 @@ func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) {
 	return nodes[0], nil
 }
 
-// findnode calls FINDNODE on a node and waits for responses.
-func (t *UDPv5) findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) {
+// Findnode calls FINDNODE on a node and waits for responses.
+func (t *UDPv5) Findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) {
 	resp := t.callToNode(n, v5wire.NodesMsg, &v5wire.Findnode{Distances: distances})
 	return t.waitForNodes(resp, distances)
 }
@@ -493,7 +493,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s
 		return nil, errors.New("not contained in netrestrict list")
 	}
 	if node.UDP() <= 1024 {
-		return nil, errLowPort
+		return nil, ErrLowPort
 	}
 	if distances != nil {
 		nd := enode.LogDist(c.id, node.ID())
@@ -537,7 +537,7 @@ func (t *UDPv5) initCall(c *callV5, responseType byte, packet v5wire.Packet) {
 	select {
 	case t.callCh <- c:
 	case <-t.closeCtx.Done():
-		c.err <- errClosed
+		c.err <- ErrClosed
 	}
 }
@@ -630,12 +630,12 @@ func (t *UDPv5) dispatch() {
 			close(t.readNextCh)
 			for id, queue := range t.callQueue {
 				for _, c := range queue {
-					c.err <- errClosed
+					c.err <- ErrClosed
 				}
 				delete(t.callQueue, id)
 			}
 			for id, c := range t.activeCallByNode {
-				c.err <- errClosed
+				c.err <- ErrClosed
 				delete(t.activeCallByNode, id)
 				delete(t.activeCallByAuth, c.nonce)
 			}
@@ -709,7 +709,7 @@ func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr netip.AddrPort, pack
 	}
 }
 
-func (t *UDPv5) sendFromAnotherThreadWithNode(node *enode.Node, toAddr netip.AddrPort, packet v5wire.Packet) {
+func (t *UDPv5) SendFromAnotherThreadWithNode(node *enode.Node, toAddr netip.AddrPort, packet v5wire.Packet) {
 	select {
 	case t.sendCh <- sendRequest{node.ID(), node, toAddr, packet}:
 	case <-t.closeCtx.Done():
@@ -792,7 +792,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error {
 	}
 	if fromNode != nil {
 		// Handshake succeeded, add to table.
-		t.tab.addInboundNode(fromNode)
+		t.tab.AddInboundNode(fromNode)
 		t.putCache(fromAddr.String(), fromNode)
 	}
 	if packet.Kind() != v5wire.WhoareyouPacket {
@@ -825,9 +825,9 @@ func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr netip.AddrPort, p v
 	return true
 }
 
-// getNode looks for a node record in table and database.
-func (t *UDPv5) getNode(id enode.ID) *enode.Node {
-	if n := t.tab.getNode(id); n != nil {
+// GetNode looks for a node record in table and database.
+func (t *UDPv5) GetNode(id enode.ID) *enode.Node {
+	if n := t.tab.GetNode(id); n != nil {
 		return n
 	}
 	if n := t.localNode.Database().Node(id); n != nil {
@@ -865,7 +865,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort
 func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) {
 	challenge := &v5wire.Whoareyou{Nonce: p.Nonce}
 	crand.Read(challenge.IDNonce[:])
-	if n := t.getNode(fromID); n != nil {
+	if n := t.GetNode(fromID); n != nil {
 		challenge.Node = n
 		challenge.RecordSeq = n.Seq()
 	}
@@ -952,7 +952,7 @@ func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) [
 		processed[dist] = struct{}{}
 
 		checkLive := !t.tab.cfg.NoFindnodeLivenessCheck
-		for _, n := range t.tab.appendBucketNodes(dist, bn[:0], checkLive) {
+		for _, n := range t.tab.AppendBucketNodes(dist, bn[:0], checkLive) {
 			// Apply some pre-checks to avoid sending invalid nodes.
 			// Note liveness is checked by appendLiveNodes.
 			if netutil.CheckRelayAddr(rip, n.IPAddr()) != nil {
@@ -1014,3 +1014,7 @@ func (t *UDPv5) GetCachedNode(addr string) (*enode.Node, bool) {
 	n, ok := t.cachedAddrNode[addr]
 	return n, ok
 }
+
+// Table exposes the node table so the portalnetwork package can manage it.
+func (t *UDPv5) Table() *Table {
+	return t.tab
+}
diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go
index 2db9824e9708..3abea16884d3 100644
--- a/p2p/discover/v5_udp_test.go
+++ b/p2p/discover/v5_udp_test.go
@@ -143,7 +143,7 @@ func TestUDPv5_unknownPacket(t *testing.T) {
 
 	// Make Node known.
 	n := test.getNode(test.remotekey, test.remoteaddr).Node()
-	test.table.addFoundNode(n, false)
+	test.table.AddFoundNode(n, false)
 
 	test.packetIn(&v5wire.Unknown{Nonce: nonce})
 	test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) {
@@ -237,7 +237,7 @@ func TestUDPv5_pingCall(t *testing.T) {
 
 	// This ping times out.
 	go func() {
-		_, err := test.udp.ping(remote)
+		_, err := test.udp.Ping(remote)
 		done <- err
 	}()
 	test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {})
@@ -247,7 +247,7 @@ func TestUDPv5_pingCall(t *testing.T) {
 
 	// This ping works.
 	go func() {
-		_, err := test.udp.ping(remote)
+		_, err := test.udp.Ping(remote)
 		done <- err
 	}()
 	test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
@@ -259,7 +259,7 @@ func TestUDPv5_pingCall(t *testing.T) {
 
 	// This ping gets a reply from the wrong endpoint.
 	go func() {
-		_, err := test.udp.ping(remote)
+		_, err := test.udp.Ping(remote)
 		done <- err
 	}()
 	test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {
@@ -288,7 +288,7 @@ func TestUDPv5_findnodeCall(t *testing.T) {
 	)
 	go func() {
 		var err error
-		response, err = test.udp.findnode(remote, distances)
+		response, err = test.udp.Findnode(remote, distances)
 		done <- err
 	}()
 
@@ -330,11 +330,11 @@ func TestUDPv5_callResend(t *testing.T) {
 	remote := test.getNode(test.remotekey, test.remoteaddr).Node()
 	done := make(chan error, 2)
 	go func() {
-		_, err := test.udp.ping(remote)
+		_, err := test.udp.Ping(remote)
 		done <- err
 	}()
 	go func() {
-		_, err := test.udp.ping(remote)
+		_, err := test.udp.Ping(remote)
 		done <- err
 	}()
 
@@ -367,7 +367,7 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) {
 	remote := test.getNode(test.remotekey, test.remoteaddr).Node()
 	done := make(chan error, 1)
 	go func() {
-		_, err := test.udp.ping(remote)
+		_, err := test.udp.Ping(remote)
 		done <- err
 	}()
 
@@ -398,7 +398,7 @@ func TestUDPv5_callTimeoutReset(t *testing.T) {
 		done   = make(chan error, 1)
 	)
 	go func() {
-		_, err := test.udp.findnode(remote, []uint{distance})
+		_, err := test.udp.Findnode(remote, []uint{distance})
 		done <- err
 	}()
 
@@ -535,38 +535,38 @@ func TestUDPv5_talkRequest(t *testing.T) {
 	}
 }
 
-// This test checks that lookupDistances works.
+// This test checks that LookupDistances works.
 func TestUDPv5_lookupDistances(t *testing.T) {
 	test := newUDPV5Test(t)
 	lnID := test.table.self().ID()
 
 	t.Run("target distance of 1", func(t *testing.T) {
 		node := nodeAtDistance(lnID, 1, intIP(0))
-		dists := lookupDistances(lnID, node.ID())
+		dists := LookupDistances(lnID, node.ID())
 		require.Equal(t, []uint{1, 2, 3}, dists)
 	})
 
 	t.Run("target distance of 2", func(t *testing.T) {
 		node := nodeAtDistance(lnID, 2, intIP(0))
-		dists := lookupDistances(lnID, node.ID())
+		dists := LookupDistances(lnID, node.ID())
 		require.Equal(t, []uint{2, 3, 1}, dists)
 	})
 
 	t.Run("target distance of 128", func(t *testing.T) {
 		node := nodeAtDistance(lnID, 128, intIP(0))
-		dists := lookupDistances(lnID, node.ID())
+		dists := LookupDistances(lnID, node.ID())
 		require.Equal(t, []uint{128, 129, 127}, dists)
 	})
 
 	t.Run("target distance of 255", func(t *testing.T) {
 		node := nodeAtDistance(lnID, 255, intIP(0))
-		dists := lookupDistances(lnID, node.ID())
+		dists := LookupDistances(lnID, node.ID())
 		require.Equal(t, []uint{255, 256, 254}, dists)
 	})
 
 	t.Run("target distance of 256", func(t *testing.T) {
 		node := nodeAtDistance(lnID, 256, intIP(0))
-		dists := lookupDistances(lnID, node.ID())
+		dists := LookupDistances(lnID, node.ID())
 		require.Equal(t, []uint{256, 255, 254}, dists)
 	})
 }
@@ -817,7 +817,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) {
 	exptype := fn.Type().In(0)
 
 	dgram, err := test.pipe.receive()
-	if err == errClosed {
+	if err == ErrClosed {
 		return true
 	}
 	if err == errTimeout {
diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go
index a6cc278bba7e..6f33aa831c48 100644
--- a/p2p/discover/v5wire/encoding.go
+++ b/p2p/discover/v5wire/encoding.go
@@ -94,7 +94,7 @@ const (
 	// Should reject packets smaller than minPacketSize.
 	minPacketSize = 63
 
-	maxPacketSize = 1280
+	MaxPacketSize = 1280
 
 	minMessageSize      = 48 // this refers to data after static headers
 	randomPacketMsgSize = 20
@@ -169,7 +169,7 @@ func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, pr
 		privkey:    key,
 		sc:         NewSessionCache(1024, clock),
 		protocolID: DefaultProtocolID,
-		decbuf:     make([]byte, maxPacketSize),
+		decbuf:     make([]byte, MaxPacketSize),
 	}
 	if protocolID != nil {
 		c.protocolID = *protocolID
diff --git a/portalnetwork/api.go b/portalnetwork/api.go
new file mode 100644
index 000000000000..bc7305ef8b57
--- /dev/null
+++ b/portalnetwork/api.go
@@ -0,0 +1,543 @@
+package portalnetwork
+
+import (
+	"errors"
+
+	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/p2p/discover"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/portalnetwork/portalwire"
+	"github.com/holiman/uint256"
+)
+
+// DiscV5API json-rpc spec
+// https://playground.open-rpc.org/?schemaUrl=https://raw.githubusercontent.com/ethereum/portal-network-specs/assembled-spec/jsonrpc/openrpc.json&uiSchema%5BappBar%5D%5Bui:splitView%5D=false&uiSchema%5BappBar%5D%5Bui:input%5D=false&uiSchema%5BappBar%5D%5Bui:examplesDropdown%5D=false
+type DiscV5API struct {
+	DiscV5 *discover.UDPv5
+}
+
+func NewDiscV5API(discV5 *discover.UDPv5) *DiscV5API {
+	return &DiscV5API{discV5}
+}
+
+type NodeInfo struct {
+	NodeId string `json:"nodeId"`
+	Enr    string `json:"enr"`
+	Ip     string `json:"ip"`
+}
+
+type RoutingTableInfo struct {
+	Buckets     [][]string `json:"buckets"`
+	LocalNodeId string     `json:"localNodeId"`
+}
+
+type DiscV5PongResp struct {
+	EnrSeq        uint64 `json:"enrSeq"`
+	RecipientIP   string `json:"recipientIP"`
+	RecipientPort uint16 `json:"recipientPort"`
+}
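+
+// An illustrative (not normative) discv5_ping JSON result shaped by
+// DiscV5PongResp:
+//
+//	{"enrSeq": 1, "recipientIP": "203.0.113.5", "recipientPort": 30303}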
+
+type PortalPongResp struct {
+	EnrSeq     uint32 `json:"enrSeq"`
+	DataRadius string `json:"dataRadius"`
+}
+
+type ContentInfo struct {
+	Content     string `json:"content"`
+	UtpTransfer bool   `json:"utpTransfer"`
+}
+
+type TraceContentResult struct {
+	Content     string `json:"content"`
+	UtpTransfer bool   `json:"utpTransfer"`
+	Trace       Trace  `json:"trace"`
+}
+
+type Trace struct {
+	Origin       string                   `json:"origin"`       // local node id
+	TargetId     string                   `json:"targetId"`     // target content id
+	ReceivedFrom string                   `json:"receivedFrom"` // the node id the content was received from
+	Responses    map[string]RespByNode    `json:"responses"`    // map of node id to the node ids it responded with
+	Metadata     map[string]*NodeMetadata `json:"metadata"`     // map of node id to its metadata object
+	StartedAtMs  int                      `json:"startedAtMs"`  // timestamp of the beginning of this request in milliseconds
+	Cancelled    []string                 `json:"cancelled"`    // the node ids to which requests were sent but then cancelled
+}
+
+type NodeMetadata struct {
+	Enr      string `json:"enr"`
+	Distance string `json:"distance"`
+}
+
+type RespByNode struct {
+	DurationMs    int32    `json:"durationMs"`
+	RespondedWith []string `json:"respondedWith"`
+}
+
+type Enrs struct {
+	Enrs []string `json:"enrs"`
+}
+
+func (d *DiscV5API) NodeInfo() *NodeInfo {
+	n := d.DiscV5.LocalNode().Node()
+
+	return &NodeInfo{
+		NodeId: "0x" + n.ID().String(),
+		Enr:    n.String(),
+		Ip:     n.IP().String(),
+	}
+}
+
+func (d *DiscV5API) RoutingTableInfo() *RoutingTableInfo {
+	n := d.DiscV5.LocalNode().Node()
+	bucketNodes := d.DiscV5.RoutingTableInfo()
+
+	return &RoutingTableInfo{
+		Buckets:     bucketNodes,
+		LocalNodeId: "0x" + n.ID().String(),
+	}
+}
+
+func (d *DiscV5API) AddEnr(enr string) (bool, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return false, err
+	}
+
+	// immediately add the node to the routing table
+	d.DiscV5.Table().AddInboundNode(n)
+	return true, nil
+}
+
+func (d *DiscV5API) GetEnr(nodeId string) (bool, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return false, err
+	}
+	n := d.DiscV5.Table().GetNode(id)
+	if n == nil {
+		return false, errors.New("record not in local routing table")
+	}
+
+	return true, nil
+}
+
+func (d *DiscV5API) DeleteEnr(nodeId string) (bool, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return false, err
+	}
+
+	n := d.DiscV5.Table().GetNode(id)
+	if n == nil {
+		return false, errors.New("record not in local routing table")
+	}
+
+	d.DiscV5.Table().DeleteNode(n)
+	return true, nil
+}
+
+func (d *DiscV5API) LookupEnr(nodeId string) (string, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return "", err
+	}
+
+	enr := d.DiscV5.ResolveNodeId(id)
+
+	if enr == nil {
+		return "", errors.New("record not found in DHT lookup")
+	}
+
+	return enr.String(), nil
+}
+
+func (d *DiscV5API) Ping(enr string) (*DiscV5PongResp, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+
+	pong, err := d.DiscV5.PingWithResp(n)
+	if err != nil {
+		return nil, err
+	}
+
+	return &DiscV5PongResp{
+		EnrSeq:        pong.ENRSeq,
+		RecipientIP:   pong.ToIP.String(),
+		RecipientPort: pong.ToPort,
+	}, nil
+}
+
+func (d *DiscV5API) FindNodes(enr string, distances []uint) ([]string, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+	findNodes, err := d.DiscV5.Findnode(n, distances)
+	if err != nil {
+		return nil, err
+	}
+
+	enrs := make([]string, 0, len(findNodes))
+	for _, r := range findNodes {
+		enrs = append(enrs, r.String())
+	}
+
+	return enrs, nil
+}
+
+func (d *DiscV5API) TalkReq(enr string, protocol string, payload string) (string, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return "", err
+	}
+
+	req, err := hexutil.Decode(payload)
+	if err != nil {
+		return "", err
+	}
+
+	talkResp, err := d.DiscV5.TalkRequest(n, protocol, req)
+	if err != nil {
+		return "", err
+	}
+	return hexutil.Encode(talkResp), nil
+}
+
+func (d *DiscV5API) RecursiveFindNodes(nodeId string) ([]string, error) {
+	findNodes := d.DiscV5.Lookup(enode.HexID(nodeId))
+
+	enrs := make([]string, 0, len(findNodes))
+	for _, r := range findNodes {
+		enrs = append(enrs, r.String())
+	}
+
+	return enrs, nil
+}
+
+type PortalProtocolAPI struct {
+	portalProtocol *PortalProtocol
+}
+
+func NewPortalAPI(portalProtocol *PortalProtocol) *PortalProtocolAPI {
+	return &PortalProtocolAPI{
+		portalProtocol: portalProtocol,
+	}
+}
+
+func (p *PortalProtocolAPI) NodeInfo() *NodeInfo {
+	n := p.portalProtocol.localNode.Node()
+
+	return &NodeInfo{
+		NodeId: "0x" + n.ID().String(),
+		Enr:    n.String(),
+		Ip:     n.IP().String(),
+	}
+}
+
+func (p *PortalProtocolAPI) RoutingTableInfo() *RoutingTableInfo {
+	n := p.portalProtocol.localNode.Node()
+	bucketNodes := p.portalProtocol.RoutingTableInfo()
+
+	return &RoutingTableInfo{
+		Buckets:     bucketNodes,
+		LocalNodeId: "0x" + n.ID().String(),
+	}
+}
+
+func (p *PortalProtocolAPI) AddEnr(enr string) (bool, error) {
+	p.portalProtocol.Log.Debug("serving AddEnr", "enr", enr)
+	n, err := enode.ParseForAddEnr(enode.ValidSchemes, enr)
+	if err != nil {
+		return false, err
+	}
+	p.portalProtocol.AddEnr(n)
+	return true, nil
+}
+
+func (p *PortalProtocolAPI) AddEnrs(enrs []string) bool {
+	// Note: unspecified RPC, but useful for our local testnet test
+	for _, enr := range enrs {
+		n, err := enode.ParseForAddEnr(enode.ValidSchemes, enr)
+		if err != nil {
+			continue
+		}
+		p.portalProtocol.AddEnr(n)
+	}
+
+	return true
+}
+
+func (p *PortalProtocolAPI) GetEnr(nodeId string) (string, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return "", err
+	}
+
+	if id == p.portalProtocol.localNode.Node().ID() {
+		return p.portalProtocol.localNode.Node().String(), nil
+	}
+
+	n := p.portalProtocol.table.GetNode(id)
+	if n == nil {
+		return "", errors.New("record not in local routing table")
+	}
+
+	return n.String(), nil
+}
+
+func (p *PortalProtocolAPI) DeleteEnr(nodeId string) (bool, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return false, err
+	}
+
+	n := p.portalProtocol.table.GetNode(id)
+	if n == nil {
+		return false, nil
+	}
+
+	p.portalProtocol.table.DeleteNode(n)
+	return true, nil
+}
+
+func (p *PortalProtocolAPI) LookupEnr(nodeId string) (string, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return "", err
+	}
+
+	enr := p.portalProtocol.ResolveNodeId(id)
+
+	if enr == nil {
+		return "", errors.New("record not found in DHT lookup")
+	}
+
+	return enr.String(), nil
+}
+
+func (p *PortalProtocolAPI) Ping(enr string) (*PortalPongResp, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+
+	pong, err := p.portalProtocol.pingInner(n)
+	if err != nil {
+		return nil, err
+	}
+
+	customPayload := &portalwire.PingPongCustomData{}
+	err = customPayload.UnmarshalSSZ(pong.CustomPayload)
+	if err != nil {
+		return nil, err
+	}
+
+	nodeRadius := new(uint256.Int)
+	err = nodeRadius.UnmarshalSSZ(customPayload.Radius)
+	if err != nil {
+		return nil, err
+	}
+
+	return &PortalPongResp{
+		EnrSeq:     uint32(pong.EnrSeq),
+		DataRadius: nodeRadius.Hex(),
+	}, nil
+}
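+
+// The pong's CustomPayload is the SSZ-encoded PingPongCustomData container;
+// its Radius field is a 32-byte value (a node willing to store everything
+// reports MaxDistance), which is why it round-trips through uint256 above.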
+
+func (p *PortalProtocolAPI) FindNodes(enr string, distances []uint) ([]string, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+	findNodes, err := p.portalProtocol.findNodes(n, distances)
+	if err != nil {
+		return nil, err
+	}
+
+	enrs := make([]string, 0, len(findNodes))
+	for _, r := range findNodes {
+		enrs = append(enrs, r.String())
+	}
+
+	return enrs, nil
+}
+
+func (p *PortalProtocolAPI) FindContent(enr string, contentKey string) (interface{}, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+
+	contentKeyBytes, err := hexutil.Decode(contentKey)
+	if err != nil {
+		return nil, err
+	}
+
+	flag, findContent, err := p.portalProtocol.findContent(n, contentKeyBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	switch flag {
+	case portalwire.ContentRawSelector:
+		contentInfo := &ContentInfo{
+			Content:     hexutil.Encode(findContent.([]byte)),
+			UtpTransfer: false,
+		}
+		p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo)
+		return contentInfo, nil
+	case portalwire.ContentConnIdSelector:
+		contentInfo := &ContentInfo{
+			Content:     hexutil.Encode(findContent.([]byte)),
+			UtpTransfer: true,
+		}
+		p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo)
+		return contentInfo, nil
+	default:
+		enrs := make([]string, 0)
+		for _, r := range findContent.([]*enode.Node) {
+			enrs = append(enrs, r.String())
+		}
+
+		p.portalProtocol.Log.Trace("FindContent", "enrs", enrs)
+		return &Enrs{
+			Enrs: enrs,
+		}, nil
+	}
+}
+
+func (p *PortalProtocolAPI) Offer(enr string, contentItems [][2]string) (string, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return "", err
+	}
+
+	entries := make([]*ContentEntry, 0, len(contentItems))
+	for _, contentItem := range contentItems {
+		contentKey, err := hexutil.Decode(contentItem[0])
+		if err != nil {
+			return "", err
+		}
+		contentValue, err := hexutil.Decode(contentItem[1])
+		if err != nil {
+			return "", err
+		}
+		contentEntry := &ContentEntry{
+			ContentKey: contentKey,
+			Content:    contentValue,
+		}
+		entries = append(entries, contentEntry)
+	}
+
+	transientOfferRequest := &TransientOfferRequest{
+		Contents: entries,
+	}
+
+	offerReq := &OfferRequest{
+		Kind:    TransientOfferRequestKind,
+		Request: transientOfferRequest,
+	}
+	accept, err := p.portalProtocol.offer(n, offerReq)
+	if err != nil {
+		return "", err
+	}
+
+	return hexutil.Encode(accept), nil
+}
+
+func (p *PortalProtocolAPI) RecursiveFindNodes(nodeId string) ([]string, error) {
+	findNodes := p.portalProtocol.Lookup(enode.HexID(nodeId))
+
+	enrs := make([]string, 0, len(findNodes))
+	for _, r := range findNodes {
+		enrs = append(enrs, r.String())
+	}
+
+	return enrs, nil
+}
+
+func (p *PortalProtocolAPI) RecursiveFindContent(contentKeyHex string) (*ContentInfo, error) {
+	contentKey, err := hexutil.Decode(contentKeyHex)
+	if err != nil {
+		return nil, err
+	}
+	contentId := p.portalProtocol.toContentId(contentKey)
+
+	data, err := p.portalProtocol.Get(contentKey, contentId)
+	if err == nil {
+		return &ContentInfo{
+			Content:     hexutil.Encode(data),
+			UtpTransfer: false,
+		}, err
+	}
+	p.portalProtocol.Log.Warn("find content err", "contentKey", hexutil.Encode(contentKey), "err", err)
+
+	content, utpTransfer, err := p.portalProtocol.ContentLookup(contentKey, contentId)
+	if err != nil {
+		return nil, err
+	}
+
+	return &ContentInfo{
+		Content:     hexutil.Encode(content),
+		UtpTransfer: utpTransfer,
+	}, err
+}
+
+func (p *PortalProtocolAPI) LocalContent(contentKeyHex string) (string, error) {
+	contentKey, err := hexutil.Decode(contentKeyHex)
+	if err != nil {
+		return "", err
+	}
+	contentId := p.portalProtocol.ToContentId(contentKey)
+	content, err := p.portalProtocol.Get(contentKey, contentId)
+	if err != nil {
+		return "", err
+	}
+	return hexutil.Encode(content), nil
+}
+
+func (p *PortalProtocolAPI) Store(contentKeyHex string, contentHex string) (bool, error) {
+	contentKey, err := hexutil.Decode(contentKeyHex)
+	if err != nil {
+		return false, err
+	}
+	contentId := p.portalProtocol.ToContentId(contentKey)
+	if !p.portalProtocol.InRange(contentId) {
+		return false, nil
+	}
+	content, err := hexutil.Decode(contentHex)
+	if err != nil {
+		return false, err
+	}
+	err = p.portalProtocol.Put(contentKey, contentId, content)
+	if err != nil {
+		return false, err
+	}
+	return true, nil
+}
+
+func (p *PortalProtocolAPI) Gossip(contentKeyHex, contentHex string) (int, error) {
+	contentKey, err := hexutil.Decode(contentKeyHex)
+	if err != nil {
+		return 0, err
+	}
+	content, err := hexutil.Decode(contentHex)
+	if err != nil {
+		return 0, err
+	}
+	id := p.portalProtocol.Self().ID()
+	return p.portalProtocol.Gossip(&id, [][]byte{contentKey}, [][]byte{content})
+}
+
+func (p *PortalProtocolAPI) TraceRecursiveFindContent(contentKeyHex string) (*TraceContentResult, error) {
+	contentKey, err := hexutil.Decode(contentKeyHex)
+	if err != nil {
+		return nil, err
+	}
+	contentId := p.portalProtocol.toContentId(contentKey)
+	return p.portalProtocol.TraceContentLookup(contentKey, contentId)
+}
diff --git a/portalnetwork/nat.go b/portalnetwork/nat.go
new file mode 100644
index 000000000000..ca479d7e457d
--- /dev/null
+++ b/portalnetwork/nat.go
@@ -0,0 +1,172 @@
+package portalnetwork
+
+import (
+	"net"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/p2p/nat"
+)
+
+const (
+	portMapDuration        = 10 * time.Minute
+	portMapRefreshInterval = 8 * time.Minute
+	portMapRetryInterval   = 5 * time.Minute
+	extipRetryInterval     = 2 * time.Minute
+)
+
+type portMapping struct {
+	protocol string
+	name     string
+	port     int
+
+	// for use by the portMappingLoop goroutine:
+	extPort  int // the mapped port returned by the NAT interface
+	nextTime mclock.AbsTime
+}
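+
+// A mapping is requested for portMapDuration (10 minutes) but refreshed every
+// portMapRefreshInterval (8 minutes), so a healthy mapping is renewed before
+// it can lapse; failures are retried on the slower portMapRetryInterval.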
+
+// setupPortMapping starts the port mapping loop if necessary.
+// Note: this needs to be called after the LocalNode instance has been set on the server.
+func (p *PortalProtocol) setupPortMapping() {
+	// portMappingRegister will receive up to two values: one for the TCP port if
+	// listening is enabled, and one more for enabling UDP port mapping if discovery is
+	// enabled. We make it buffered to avoid blocking setup while a mapping request is in
+	// progress.
+	p.portMappingRegister = make(chan *portMapping, 2)
+
+	switch p.NAT.(type) {
+	case nil:
+		// No NAT interface configured.
+		go p.consumePortMappingRequests()
+
+	case nat.ExtIP:
+		// ExtIP doesn't block, set the IP right away.
+		ip, _ := p.NAT.ExternalIP()
+		p.localNode.SetStaticIP(ip)
+		go p.consumePortMappingRequests()
+
+	case nat.STUN:
+		// STUN doesn't block, set the IP right away.
+		ip, _ := p.NAT.ExternalIP()
+		p.localNode.SetStaticIP(ip)
+		go p.consumePortMappingRequests()
+
+	default:
+		go p.portMappingLoop()
+	}
+}
+
+func (p *PortalProtocol) consumePortMappingRequests() {
+	for {
+		select {
+		case <-p.closeCtx.Done():
+			return
+		case <-p.portMappingRegister:
+		}
+	}
+}
+
+// portMappingLoop manages port mappings for UDP and TCP.
+func (p *PortalProtocol) portMappingLoop() {
+	newLogger := func(proto string, e int, i int) log.Logger {
+		return log.New("proto", proto, "extport", e, "intport", i, "interface", p.NAT)
+	}
+
+	var (
+		mappings  = make(map[string]*portMapping, 2)
+		refresh   = mclock.NewAlarm(p.clock)
+		extip     = mclock.NewAlarm(p.clock)
+		lastExtIP net.IP
+	)
+	extip.Schedule(p.clock.Now())
+	defer func() {
+		refresh.Stop()
+		extip.Stop()
+		for _, m := range mappings {
+			if m.extPort != 0 {
+				log := newLogger(m.protocol, m.extPort, m.port)
+				log.Debug("Deleting port mapping")
+				p.NAT.DeleteMapping(m.protocol, m.extPort, m.port)
+			}
+		}
+	}()
+
+	for {
+		// Schedule refresh of existing mappings.
+		for _, m := range mappings {
+			refresh.Schedule(m.nextTime)
+		}
+
+		select {
+		case <-p.closeCtx.Done():
+			return
+
+		case <-extip.C():
+			extip.Schedule(p.clock.Now().Add(extipRetryInterval))
+			ip, err := p.NAT.ExternalIP()
+			if err != nil {
+				log.Debug("Couldn't get external IP", "err", err, "interface", p.NAT)
+			} else if !ip.Equal(lastExtIP) {
+				log.Debug("External IP changed", "ip", ip, "interface", p.NAT)
+			} else {
+				continue
+			}
+			// Here, we either failed to get the external IP, or it has changed.
+			lastExtIP = ip
+			p.localNode.SetStaticIP(ip)
+			p.Log.Debug("set static ip in nat", "ip", p.localNode.Node().IP().String())
+			// Ensure port mappings are refreshed in case we have moved to a new network.
+			for _, m := range mappings {
+				m.nextTime = p.clock.Now()
+			}
+
+		case m := <-p.portMappingRegister:
+			if m.protocol != "TCP" && m.protocol != "UDP" {
+				panic("unknown NAT protocol name: " + m.protocol)
+			}
+			mappings[m.protocol] = m
+			m.nextTime = p.clock.Now()
+
+		case <-refresh.C():
+			for _, m := range mappings {
+				if p.clock.Now() < m.nextTime {
+					continue
+				}
+
+				external := m.port
+				if m.extPort != 0 {
+					external = m.extPort
+				}
+				log := newLogger(m.protocol, external, m.port)
+
+				log.Trace("Attempting port mapping")
+				port, err := p.NAT.AddMapping(m.protocol, external, m.port, m.name, portMapDuration)
+				if err != nil {
+					log.Debug("Couldn't add port mapping", "err", err)
+					m.extPort = 0
+					m.nextTime = p.clock.Now().Add(portMapRetryInterval)
+					continue
+				}
+				// It was mapped!
+				m.extPort = int(port)
+				m.nextTime = p.clock.Now().Add(portMapRefreshInterval)
+				if external != m.extPort {
+					log = newLogger(m.protocol, m.extPort, m.port)
+					log.Info("NAT mapped alternative port")
+				} else {
+					log.Info("NAT mapped port")
+				}
+
+				// Update port in local ENR.
+				switch m.protocol {
+				case "TCP":
+					p.localNode.Set(enr.TCP(m.extPort))
+				case "UDP":
+					p.localNode.SetFallbackUDP(m.extPort)
+				}
+			}
+		}
+	}
+}
diff --git a/portalnetwork/portal_protocol.go b/portalnetwork/portal_protocol.go
new file mode 100644
index 000000000000..41e20fcfe9fb
--- /dev/null
+++ b/portalnetwork/portal_protocol.go
@@ -0,0 +1,1918 @@
+package portalnetwork
+
+import (
+	"bytes"
+	"context"
+	"crypto/ecdsa"
+	crand "crypto/rand"
+	"crypto/sha256"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"math/big"
+	"math/rand"
+	"net"
+	"slices"
+	"sort"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/VictoriaMetrics/fastcache"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/p2p/discover"
+	"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/p2p/nat"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
+	"github.com/ethereum/go-ethereum/portalnetwork/portalwire"
+	"github.com/ethereum/go-ethereum/portalnetwork/storage"
+	"github.com/ethereum/go-ethereum/rlp"
+	ssz "github.com/ferranbt/fastssz"
+	"github.com/holiman/uint256"
+	"github.com/optimism-java/utp-go"
+	"github.com/optimism-java/utp-go/libutp"
+	"github.com/prysmaticlabs/go-bitfield"
+	"github.com/tetratelabs/wabin/leb128"
+)
+
+const (
+
+	// TalkResp message is a response message so the session is established and a
+	// regular discv5 packet is assumed for size calculation.
+	// Regular message = IV + header + message
+	// talkResp message = rlp: [request-id, response]
+	talkRespOverhead = 16 + // IV size
+		55 + // header size
+		1 + // talkResp msg id
+		3 + // rlp encoding outer list, max length will be encoded in 2 bytes
+		9 + // request id (max = 8) + 1 byte from rlp encoding byte string
+		3 + // rlp encoding response byte string, max length in 2 bytes
+		16 // HMAC
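+	// = 103 bytes of overhead per TALKRESP in total (16+55+1+3+9+3+16)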
+
+	portalFindnodesResultLimit = 32
+
+	defaultUTPConnectTimeout = 15 * time.Second
+
+	defaultUTPWriteTimeout = 60 * time.Second
+
+	defaultUTPReadTimeout = 60 * time.Second
+
+	// These are the concurrent offers per running Portal wire protocol instance.
+	// Using the `offerQueue` allows for limiting the number of offers sent and
+	// thus how many streams can be started.
+	// TODO:
+	// More thought needs to go into this as it is currently on a per network
+	// basis. Keep it simple like that? Or limit it better at the stream transport
+	// level? In the latter case, this might still need to be checked/blocked at
+	// the very start of sending the offer, because blocking/waiting too long
+	// between the received accept message and actually starting the stream and
+	// sending data could give issues due to timeouts on the other side.
+	// And then there are still limits to be applied also for FindContent and the
+	// incoming directions.
+	concurrentOffers = 50
+)
+
+const (
+	TransientOfferRequestKind byte = 0x01
+	PersistOfferRequestKind   byte = 0x02
+)
+
+type ClientTag string
+
+func (c ClientTag) ENRKey() string { return "c" }
+
+const Tag ClientTag = "shisui"
+
+var ErrNilContentKey = errors.New("content key cannot be nil")
+
+var ContentNotFound = storage.ErrContentNotFound
+
+var ErrEmptyResp = errors.New("empty resp")
+
+var MaxDistance = hexutil.MustDecode("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+
+type ContentElement struct {
+	Node        enode.ID
+	ContentKeys [][]byte
+	Contents    [][]byte
+}
+
+type ContentEntry struct {
+	ContentKey []byte
+	Content    []byte
+}
+
+type TransientOfferRequest struct {
+	Contents []*ContentEntry
+}
+
+type PersistOfferRequest struct {
+	ContentKeys [][]byte
+}
+
+type OfferRequest struct {
+	Kind    byte
+	Request interface{}
+}
+
+type OfferRequestWithNode struct {
+	Request *OfferRequest
+	Node    *enode.Node
+}
+
+type ContentInfoResp struct {
+	Content     []byte
+	UtpTransfer bool
+}
+
+type traceContentInfoResp struct {
+	Node        *enode.Node
+	Flag        byte
+	Content     any
+	UtpTransfer bool
+}
+
+type PortalProtocolOption func(p *PortalProtocol)
+
+type PortalProtocolConfig struct {
+	BootstrapNodes []*enode.Node
+	// NodeIP net.IP
+	ListenAddr      string
+	NetRestrict     *netutil.Netlist
+	NodeRadius      *uint256.Int
+	RadiusCacheSize int
+	NodeDBPath      string
+	NAT             nat.Interface
+	clock           mclock.Clock
+}
+
+func DefaultPortalProtocolConfig() *PortalProtocolConfig {
+	return &PortalProtocolConfig{
+		BootstrapNodes:  make([]*enode.Node, 0),
+		ListenAddr:      ":9009",
+		NetRestrict:     nil,
+		RadiusCacheSize: 32 * 1024 * 1024,
+		NodeDBPath:      "",
+		clock:           mclock.System{},
+	}
+}
+
+type PortalProtocol struct {
+	table *discover.Table
+
+	protocolId   string
+	protocolName string
+
+	DiscV5         *discover.UDPv5
+	localNode      *enode.LocalNode
+	Log            log.Logger
+	PrivateKey     *ecdsa.PrivateKey
+	NetRestrict    *netutil.Netlist
+	BootstrapNodes []*enode.Node
+	conn           discover.UDPConn
+
+	Utp       *PortalUtp
+	connIdGen libutp.ConnIdGenerator
+
+	validSchemes   enr.IdentityScheme
+	radiusCache    *fastcache.Cache
+	closeCtx       context.Context
+	cancelCloseCtx context.CancelFunc
+	storage        storage.ContentStorage
+	toContentId    func(contentKey []byte) []byte
+
+	contentQueue chan *ContentElement
+	offerQueue   chan *OfferRequestWithNode
+
+	portMappingRegister chan *portMapping
+	clock               mclock.Clock
+	NAT                 nat.Interface
+
+	portalMetrics *portalMetrics
+}
+
+func defaultContentIdFunc(contentKey []byte) []byte {
+	digest := sha256.Sum256(contentKey)
+	return digest[:]
+}
+
+func NewPortalProtocol(config *PortalProtocolConfig, protocolId portalwire.ProtocolId, privateKey *ecdsa.PrivateKey, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *PortalUtp, storage storage.ContentStorage, contentQueue chan *ContentElement, opts ...PortalProtocolOption) (*PortalProtocol, error) {
+	closeCtx, cancelCloseCtx := context.WithCancel(context.Background())
+
+	protocol := &PortalProtocol{
+		protocolId:     string(protocolId),
+		protocolName:   protocolId.Name(),
+		Log:            log.New("protocol", protocolId.Name()),
+		PrivateKey:     privateKey,
+		NetRestrict:    config.NetRestrict,
+		BootstrapNodes: config.BootstrapNodes,
+		radiusCache:    fastcache.New(config.RadiusCacheSize),
+		closeCtx:       closeCtx,
+		cancelCloseCtx: cancelCloseCtx,
+		localNode:      localNode,
+		validSchemes:   enode.ValidSchemes,
+		storage:        storage,
+		toContentId:    defaultContentIdFunc,
+		contentQueue:   contentQueue,
+		offerQueue:     make(chan *OfferRequestWithNode, concurrentOffers),
+		conn:           conn,
+		DiscV5:         discV5,
+		Utp:            utp,
+		NAT:            config.NAT,
+		clock:          config.clock,
+		connIdGen:      libutp.NewConnIdGenerator(),
+	}
+
+	for _, opt := range opts {
+		opt(protocol)
+	}
+
+	if metrics.Enabled {
+		protocol.portalMetrics = newPortalMetrics(protocolId.Name())
+	}
+
+	return protocol, nil
+}
+
+func (p *PortalProtocol) Start() error {
+	p.setupPortMapping()
+
+	err := p.setupDiscV5AndTable()
+	if err != nil {
+		return err
+	}
+
+	p.DiscV5.RegisterTalkHandler(p.protocolId, p.handleTalkRequest)
+	if p.Utp != nil {
+		if err = p.Utp.Start(); err != nil {
+			return err
+		}
+	}
+
+	go p.table.Loop()
+
+	for i := 0; i < concurrentOffers; i++ {
+		go p.offerWorker()
+	}
+
+	// wait for both initialization processes to complete
+	p.DiscV5.Table().WaitInit()
+	p.table.WaitInit()
+	return nil
+}
+
+func (p *PortalProtocol) Stop() {
+	p.cancelCloseCtx()
+	p.table.Close()
+	p.DiscV5.Close()
+	if p.Utp != nil {
+		p.Utp.Stop()
+	}
+}
+
+func (p *PortalProtocol) RoutingTableInfo() [][]string {
+	return p.table.NodeIds()
+}
+
+func (p *PortalProtocol) AddEnr(n *enode.Node) {
+	added := p.table.AddInboundNode(n)
+	if !added {
+		p.Log.Warn("add node failed", "id", n.ID(), "ip", n.IPAddr())
+		return
+	}
+	id := n.ID().String()
+	p.radiusCache.Set([]byte(id), MaxDistance)
+}
+
+func (p *PortalProtocol) Radius() *uint256.Int {
+	return p.storage.Radius()
+}
+
+func (p *PortalProtocol) setupUDPListening() error {
+	laddr := p.conn.LocalAddr().(*net.UDPAddr)
+	p.localNode.SetFallbackUDP(laddr.Port)
+	p.Log.Debug("UDP listener up", "addr", laddr)
+	// TODO: NAT
+	if !laddr.IP.IsLoopback() && !laddr.IP.IsPrivate() {
+		p.portMappingRegister <- &portMapping{
+			protocol: "UDP",
+			name:     "ethereum portal peer discovery",
+			port:     laddr.Port,
+		}
+	}
+	return nil
+}
+
+func (p *PortalProtocol) setupDiscV5AndTable() error {
+	err := p.setupUDPListening()
+	if err != nil {
+		return err
+	}
+
+	cfg := discover.Config{
+		PrivateKey:  p.PrivateKey,
+		NetRestrict: p.NetRestrict,
+		Bootnodes:   p.BootstrapNodes,
+		Log:         p.Log,
+	}
+
+	p.table, err = discover.NewTable(p, p.localNode.Database(), cfg)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (p *PortalProtocol) Ping(node *enode.Node) (uint64, error) {
+	pong, err := p.pingInner(node)
+	if err != nil {
+		return 0, err
+	}
+
+	return pong.EnrSeq, nil
+}
+
+func (p *PortalProtocol) pingInner(node *enode.Node) (*portalwire.Pong, error) {
+	enrSeq := p.Self().Seq()
+	radiusBytes, err := p.Radius().MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+	customPayload := &portalwire.PingPongCustomData{
+		Radius: radiusBytes,
+	}
+
+	customPayloadBytes, err := customPayload.MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+
+	pingRequest := &portalwire.Ping{
+		EnrSeq:        enrSeq,
+		CustomPayload: customPayloadBytes,
+	}
+
+	p.Log.Trace(">> PING/"+p.protocolName, "protocol", p.protocolName, "ip", p.Self().IP().String(), "source", p.Self().ID(), "target", node.ID(), "ping", pingRequest)
+	if metrics.Enabled {
+		p.portalMetrics.messagesSentPing.Mark(1)
+	}
+	pingRequestBytes, err := pingRequest.MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+
+	talkRequestBytes := make([]byte, 0, len(pingRequestBytes)+1)
+	talkRequestBytes = append(talkRequestBytes, portalwire.PING)
+	talkRequestBytes = append(talkRequestBytes, pingRequestBytes...)
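+	// talkRequestBytes is now the one-byte PING selector followed by the
+	// SSZ-encoded body; every portal wire request shares this framing.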
+
+	talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes)
+	if err != nil {
+		return nil, err
+	}
+
+	p.Log.Trace("<< PONG/"+p.protocolName, "source", p.Self().ID(), "target", node.ID(), "res", talkResp)
+	if metrics.Enabled {
+		p.portalMetrics.messagesReceivedPong.Mark(1)
+	}
+
+	return p.processPong(node, talkResp)
+}
+
+func (p *PortalProtocol) findNodes(node *enode.Node, distances []uint) ([]*enode.Node, error) {
+	if p.localNode.ID().String() == node.ID().String() {
+		return make([]*enode.Node, 0), nil
+	}
+
+	distancesBytes := make([][2]byte, len(distances))
+	for i, distance := range distances {
+		copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), uint16(distance)))
+	}
+
+	findNodes := &portalwire.FindNodes{
+		Distances: distancesBytes,
+	}
+
+	p.Log.Trace(">> FIND_NODES/"+p.protocolName, "id", node.ID(), "findNodes", findNodes)
+	if metrics.Enabled {
+		p.portalMetrics.messagesSentFindNodes.Mark(1)
+	}
+	findNodesBytes, err := findNodes.MarshalSSZ()
+	if err != nil {
+		p.Log.Error("failed to marshal find nodes request", "err", err)
+		return nil, err
+	}
+
+	talkRequestBytes := make([]byte, 0, len(findNodesBytes)+1)
+	talkRequestBytes = append(talkRequestBytes, portalwire.FINDNODES)
+	talkRequestBytes = append(talkRequestBytes, findNodesBytes...)
+
+	talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes)
+	if err != nil {
+		p.Log.Error("failed to send find nodes request", "ip", node.IP().String(), "port", node.UDP(), "err", err)
+		return nil, err
+	}
+
+	return p.processNodes(node, talkResp, distances)
+}
+
+func (p *PortalProtocol) findContent(node *enode.Node, contentKey []byte) (byte, interface{}, error) {
+	findContent := &portalwire.FindContent{
+		ContentKey: contentKey,
+	}
+
+	p.Log.Trace(">> FIND_CONTENT/"+p.protocolName, "id", node.ID(), "findContent", findContent)
+	if metrics.Enabled {
+		p.portalMetrics.messagesSentFindContent.Mark(1)
+	}
+	findContentBytes, err := findContent.MarshalSSZ()
+	if err != nil {
+		p.Log.Error("failed to marshal find content request", "err", err)
+		return 0xff, nil, err
+	}
+
+	talkRequestBytes := make([]byte, 0, len(findContentBytes)+1)
+	talkRequestBytes = append(talkRequestBytes, portalwire.FINDCONTENT)
+	talkRequestBytes = append(talkRequestBytes, findContentBytes...)
+
+	talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes)
+	if err != nil {
+		p.Log.Error("failed to send find content request", "ip", node.IP().String(), "port", node.UDP(), "err", err)
+		return 0xff, nil, err
+	}
+
+	return p.processContent(node, talkResp)
+}
+
+func (p *PortalProtocol) offer(node *enode.Node, offerRequest *OfferRequest) ([]byte, error) {
+	contentKeys := getContentKeys(offerRequest)
+
+	offer := &portalwire.Offer{
+		ContentKeys: contentKeys,
+	}
+
+	p.Log.Trace(">> OFFER/"+p.protocolName, "offer", offer)
+	if metrics.Enabled {
+		p.portalMetrics.messagesSentOffer.Mark(1)
+	}
+	offerBytes, err := offer.MarshalSSZ()
+	if err != nil {
+		p.Log.Error("failed to marshal offer request", "err", err)
+		return nil, err
+	}
+
+	talkRequestBytes := make([]byte, 0, len(offerBytes)+1)
+	talkRequestBytes = append(talkRequestBytes, portalwire.OFFER)
+	talkRequestBytes = append(talkRequestBytes, offerBytes...)
+
+	talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes)
+	if err != nil {
+		p.Log.Error("failed to send offer request", "err", err)
+		return nil, err
+	}
+
+	return p.processOffer(node, talkResp, offerRequest)
+}
+
+func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request *OfferRequest) ([]byte, error) {
+	var err error
+	if len(resp) == 0 {
+		return nil, ErrEmptyResp
+	}
+	if resp[0] != portalwire.ACCEPT {
+		return nil, fmt.Errorf("invalid accept response")
+	}
+
+	p.Log.Info("will process Offer", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP())
+
+	accept := &portalwire.Accept{}
+	err = accept.UnmarshalSSZ(resp[1:])
+	if err != nil {
+		return nil, err
+	}
+
+	p.Log.Trace("<< ACCEPT/"+p.protocolName, "id", target.ID(), "accept", accept)
+	if metrics.Enabled {
+		p.portalMetrics.messagesReceivedAccept.Mark(1)
+	}
+	isAdded := p.table.AddFoundNode(target, true)
+	if isAdded {
+		log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+	} else {
+		log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+	}
+	var contentKeyLen int
+	if request.Kind == TransientOfferRequestKind {
+		contentKeyLen = len(request.Request.(*TransientOfferRequest).Contents)
+	} else {
+		contentKeyLen = len(request.Request.(*PersistOfferRequest).ContentKeys)
+	}
+
+	contentKeyBitlist := bitfield.Bitlist(accept.ContentKeys)
+	if contentKeyBitlist.Len() != uint64(contentKeyLen) {
+		return nil, fmt.Errorf("accepted content key bitlist has invalid size, expected %d, got %d", contentKeyLen, contentKeyBitlist.Len())
+	}
+
+	if contentKeyBitlist.Count() == 0 {
+		return nil, nil
+	}
+
+	connId := binary.BigEndian.Uint16(accept.ConnectionId[:])
+	go func(ctx context.Context) {
+		var conn net.Conn
+		defer func() {
+			if conn == nil {
+				return
+			}
+			err := conn.Close()
+			if err != nil {
+				p.Log.Error("failed to close connection", "err", err)
+			}
+		}()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			default:
+				contents := make([][]byte, 0, contentKeyBitlist.Count())
+				var content []byte
+				if request.Kind == TransientOfferRequestKind {
+					for _, index := range contentKeyBitlist.BitIndices() {
+						content = request.Request.(*TransientOfferRequest).Contents[index].Content
+						contents = append(contents, content)
+					}
+				} else {
+					for _, index := range contentKeyBitlist.BitIndices() {
+						contentKey := request.Request.(*PersistOfferRequest).ContentKeys[index]
+						contentId := p.toContentId(contentKey)
+						if contentId != nil {
+							content, err = p.storage.Get(contentKey, contentId)
+							if err != nil {
+								p.Log.Error("failed to get content from storage", "err", err)
+								contents = append(contents, []byte{})
+							} else {
+								contents = append(contents, content)
+							}
+						} else {
+							contents = append(contents, []byte{})
+						}
+					}
+				}
+
+				var contentsPayload []byte
+				contentsPayload, err = encodeContents(contents)
+				if err != nil {
+					p.Log.Error("failed to encode contents", "err", err)
+					return
+				}
+
+				connctx, conncancel := context.WithTimeout(ctx, defaultUTPConnectTimeout)
+				conn, err = p.Utp.DialWithCid(connctx, target, libutp.ReceConnId(connId).SendId())
+				conncancel()
+				if err != nil {
+					if metrics.Enabled {
+						p.portalMetrics.utpOutFailConn.Inc(1)
+					}
+					p.Log.Error("failed to dial utp connection", "err", err)
+					return
+				}
+
+				err = conn.SetWriteDeadline(time.Now().Add(defaultUTPWriteTimeout))
+				if err != nil {
+					if metrics.Enabled {
+						p.portalMetrics.utpOutFailDeadline.Inc(1)
+					}
+					p.Log.Error("failed to set write deadline", "err", err)
+					return
+				}
+
+				var written int
+				written, err = conn.Write(contentsPayload)
+				if err != nil {
+					if metrics.Enabled {
+						p.portalMetrics.utpOutFailWrite.Inc(1)
+					}
+					p.Log.Error("failed to write to utp connection", "err", err)
+					return
+				}
+				p.Log.Trace(">> CONTENT/"+p.protocolName, "id", target.ID(), "contents", contents, "size", written)
+				if metrics.Enabled {
+					p.portalMetrics.messagesSentContent.Mark(1)
+					p.portalMetrics.utpOutSuccess.Inc(1)
+				}
+				return
+			}
+		}
+	}(p.closeCtx)
+
+	return accept.ContentKeys, nil
+}
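+
+// processContent handles the three CONTENT response variants: a raw payload,
+// a uTP connection id for payloads too large for a single packet, or a list
+// of closer ENRs when the peer does not store the content.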
p.Log.Error("failed to set write deadline", "err", err) + return + } + + var written int + written, err = conn.Write(contentsPayload) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailWrite.Inc(1) + } + p.Log.Error("failed to write to utp connection", "err", err) + return + } + p.Log.Trace(">> CONTENT/"+p.protocolName, "id", target.ID(), "contents", contents, "size", written) + if metrics.Enabled { + p.portalMetrics.messagesSentContent.Mark(1) + p.portalMetrics.utpOutSuccess.Inc(1) + } + return + } + } + }(p.closeCtx) + + return accept.ContentKeys, nil +} + +func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, interface{}, error) { + if len(resp) == 0 { + return 0x00, nil, ErrEmptyResp + } + + if resp[0] != portalwire.CONTENT { + return 0xff, nil, fmt.Errorf("invalid content response") + } + + p.Log.Info("will process content", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) + + switch resp[1] { + case portalwire.ContentRawSelector: + content := &portalwire.Content{} + err := content.UnmarshalSSZ(resp[2:]) + if err != nil { + return 0xff, nil, err + } + + p.Log.Trace("<< CONTENT/"+p.protocolName, "id", target.ID(), "content", content) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + return resp[1], content.Content, nil + case portalwire.ContentConnIdSelector: + connIdMsg := &portalwire.ConnectionId{} + err := connIdMsg.UnmarshalSSZ(resp[2:]) + if err != nil { + return 0xff, nil, err + } + + p.Log.Trace("<< CONTENT_CONNECTION_ID/"+p.protocolName, "id", target.ID(), "resp", common.Bytes2Hex(resp), "connIdMsg", connIdMsg) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + connctx, conncancel := context.WithTimeout(p.closeCtx, defaultUTPConnectTimeout) + connId := binary.BigEndian.Uint16(connIdMsg.Id[:]) + conn, err := p.Utp.DialWithCid(connctx, target, libutp.ReceConnId(connId).SendId()) + defer func() { + if conn == nil { + if metrics.Enabled { + p.portalMetrics.utpInFailConn.Inc(1) + } + return + } + err := conn.Close() + if err != nil { + p.Log.Error("failed to close connection", "err", err) + } + }() + conncancel() + if err != nil { + return 0xff, nil, err + } + + err = conn.SetReadDeadline(time.Now().Add(defaultUTPReadTimeout)) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailDeadline.Inc(1) + } + return 0xff, nil, err + } + // Read ALL the data from the connection until EOF and return it + data, err := io.ReadAll(conn) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailRead.Inc(1) + } + p.Log.Error("failed to read from utp connection", "err", err) + return 0xff, nil, err + } + p.Log.Trace("<< CONTENT/"+p.protocolName, "id", target.ID(), "size", len(data), "data", data) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + p.portalMetrics.utpInSuccess.Inc(1) + } + return resp[1], data, nil + case 
+	case portalwire.ContentEnrsSelector:
+		enrs := &portalwire.Enrs{}
+		err := enrs.UnmarshalSSZ(resp[2:])
+		if err != nil {
+			return 0xff, nil, err
+		}
+
+		p.Log.Trace("<< CONTENT_ENRS/"+p.protocolName, "id", target.ID(), "enrs", enrs)
+		if metrics.Enabled {
+			p.portalMetrics.messagesReceivedContent.Mark(1)
+		}
+		isAdded := p.table.AddFoundNode(target, true)
+		if isAdded {
+			log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+		} else {
+			log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+		}
+		nodes := p.filterNodes(target, enrs.Enrs, nil)
+		return resp[1], nodes, nil
+	default:
+		return 0xff, nil, fmt.Errorf("invalid content response")
+	}
+}
+
+func (p *PortalProtocol) processNodes(target *enode.Node, resp []byte, distances []uint) ([]*enode.Node, error) {
+	if len(resp) == 0 {
+		return nil, ErrEmptyResp
+	}
+
+	if resp[0] != portalwire.NODES {
+		return nil, fmt.Errorf("invalid nodes response")
+	}
+
+	nodesResp := &portalwire.Nodes{}
+	err := nodesResp.UnmarshalSSZ(resp[1:])
+	if err != nil {
+		return nil, err
+	}
+
+	isAdded := p.table.AddFoundNode(target, true)
+	if isAdded {
+		log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+	} else {
+		log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+	}
+	nodes := p.filterNodes(target, nodesResp.Enrs, distances)
+
+	return nodes, nil
+}
+
+func (p *PortalProtocol) filterNodes(target *enode.Node, enrs [][]byte, distances []uint) []*enode.Node {
+	var (
+		nodes    []*enode.Node
+		seen     = make(map[enode.ID]struct{})
+		err      error
+		verified = 0
+		n        *enode.Node
+	)
+
+	for _, b := range enrs {
+		record := &enr.Record{}
+		err = rlp.DecodeBytes(b, record)
+		if err != nil {
+			p.Log.Error("Invalid record in nodes response", "id", target.ID(), "err", err)
+			continue
+		}
+		n, err = p.verifyResponseNode(target, record, distances, seen)
+		if err != nil {
+			p.Log.Error("Invalid record in nodes response", "id", target.ID(), "err", err)
+			continue
+		}
+		verified++
+		nodes = append(nodes, n)
+	}
+
+	p.Log.Trace("<< NODES/"+p.protocolName, "id", target.ID(), "total", len(enrs), "verified", verified, "nodes", nodes)
+	if metrics.Enabled {
+		p.portalMetrics.messagesReceivedNodes.Mark(1)
+	}
+	return nodes
+}
+
+func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwire.Pong, error) {
+	if len(resp) == 0 {
+		return nil, ErrEmptyResp
+	}
+	if resp[0] != portalwire.PONG {
+		return nil, fmt.Errorf("invalid pong response")
+	}
+	pong := &portalwire.Pong{}
+	err := pong.UnmarshalSSZ(resp[1:])
+	if err != nil {
+		return nil, err
+	}
+
+	customPayload := &portalwire.PingPongCustomData{}
+	err = customPayload.UnmarshalSSZ(pong.CustomPayload)
+	if err != nil {
+		return nil, err
+	}
+
+	p.Log.Trace("<< PONG_RESPONSE/"+p.protocolName, "id", target.ID(), "pong", pong, "customPayload", customPayload)
+	if metrics.Enabled {
+		p.portalMetrics.messagesReceivedPong.Mark(1)
+	}
+	isAdded := p.table.AddFoundNode(target, true)
+	if isAdded {
+		log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+	} else {
+		log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP())
+	}
+
+	p.radiusCache.Set([]byte(target.ID().String()), customPayload.Radius)
+	return pong, nil
+}
+
+func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte {
+	if n := p.DiscV5.GetNode(id); n != nil {
+		p.table.AddInboundNode(n)
+	}
+
+	// Reject empty payloads up front: indexing msg[0] below would panic.
+	if len(msg) == 0 {
+		return nil
+	}
+
+	msgCode := msg[0]
+
+	switch msgCode {
+	case portalwire.PING:
+		pingRequest := &portalwire.Ping{}
+		err := pingRequest.UnmarshalSSZ(msg[1:])
+		if err != nil {
+			p.Log.Error("failed to unmarshal ping request", "err", err)
+			return nil
+		}
+
+		p.Log.Trace("<< PING/"+p.protocolName, "protocol", p.protocolName, "source", id, "pingRequest", pingRequest)
+		if metrics.Enabled {
+			p.portalMetrics.messagesReceivedPing.Mark(1)
+		}
+		resp, err := p.handlePing(id, pingRequest)
+		if err != nil {
+			p.Log.Error("failed to handle ping request", "err", err)
+			return nil
+		}
+
+		return resp
+	case portalwire.FINDNODES:
+		findNodesRequest := &portalwire.FindNodes{}
+		err := findNodesRequest.UnmarshalSSZ(msg[1:])
+		if err != nil {
+			p.Log.Error("failed to unmarshal find nodes request", "err", err)
+			return nil
+		}
+
+		p.Log.Trace("<< FIND_NODES/"+p.protocolName, "protocol", p.protocolName, "source", id, "findNodesRequest", findNodesRequest)
+		if metrics.Enabled {
+			p.portalMetrics.messagesReceivedFindNodes.Mark(1)
+		}
+		resp, err := p.handleFindNodes(addr, findNodesRequest)
+		if err != nil {
+			p.Log.Error("failed to handle find nodes request", "err", err)
+			return nil
+		}
+
+		return resp
+	case portalwire.FINDCONTENT:
+		findContentRequest := &portalwire.FindContent{}
+		err := findContentRequest.UnmarshalSSZ(msg[1:])
+		if err != nil {
+			p.Log.Error("failed to unmarshal find content request", "err", err)
+			return nil
+		}
+
+		p.Log.Trace("<< FIND_CONTENT/"+p.protocolName, "protocol", p.protocolName, "source", id, "findContentRequest", findContentRequest)
+		if metrics.Enabled {
+			p.portalMetrics.messagesReceivedFindContent.Mark(1)
+		}
+		resp, err := p.handleFindContent(id, addr, findContentRequest)
+		if err != nil {
+			p.Log.Error("failed to handle find content request", "err", err)
+			return nil
+		}
+
+		return resp
+	case portalwire.OFFER:
+		offerRequest := &portalwire.Offer{}
+		err := offerRequest.UnmarshalSSZ(msg[1:])
+		if err != nil {
+			p.Log.Error("failed to unmarshal offer request", "err", err)
+			return nil
+		}
+
+		p.Log.Trace("<< OFFER/"+p.protocolName, "protocol", p.protocolName, "source", id, "offerRequest", offerRequest)
+		if metrics.Enabled {
+			p.portalMetrics.messagesReceivedOffer.Mark(1)
+		}
+		resp, err := p.handleOffer(id, addr, offerRequest)
+		if err != nil {
+			p.Log.Error("failed to handle offer request", "err", err)
+			return nil
+		}
+
+		return resp
+	}
+
+	return nil
+}
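+
+// handlePing answers a PING with our own ENR sequence number and current data
+// radius, so the peer can refresh its view of what we are willing to store.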
+func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, error) {
+	pingCustomPayload := &portalwire.PingPongCustomData{}
+	err := pingCustomPayload.UnmarshalSSZ(ping.CustomPayload)
+	if err != nil {
+		return nil, err
+	}
+
+	p.radiusCache.Set([]byte(id.String()), pingCustomPayload.Radius)
+
+	enrSeq := p.Self().Seq()
+	radiusBytes, err := p.Radius().MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+	pongCustomPayload := &portalwire.PingPongCustomData{
+		Radius: radiusBytes,
+	}
+
+	pongCustomPayloadBytes, err := pongCustomPayload.MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+
+	pong := &portalwire.Pong{
+		EnrSeq:        enrSeq,
+		CustomPayload: pongCustomPayloadBytes,
+	}
+
+	p.Log.Trace(">> PONG/"+p.protocolName, "protocol", p.protocolName, "source", id, "pong", pong)
+	if metrics.Enabled {
+		p.portalMetrics.messagesSentPong.Mark(1)
+	}
+	pongBytes, err := pong.MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+
+	talkRespBytes := make([]byte, 0, len(pongBytes)+1)
+	talkRespBytes = append(talkRespBytes, portalwire.PONG)
+	talkRespBytes = append(talkRespBytes, pongBytes...)
+
+	return talkRespBytes, nil
+}
+
+func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *portalwire.FindNodes) ([]byte, error) {
+	distances := make([]uint, len(request.Distances))
+	for i, distance := range request.Distances {
+		distances[i] = uint(ssz.UnmarshallUint16(distance[:]))
+	}
+
+	nodes := p.collectTableNodes(fromAddr.IP, distances, portalFindnodesResultLimit)
+
+	nodesOverhead := 1 + 1 + 4 // msg id + total + container offset
+	maxPayloadSize := v5wire.MaxPacketSize - talkRespOverhead - nodesOverhead
+	enrOverhead := 4 // per added ENR, 4 bytes offset overhead
+
+	enrs := p.truncateNodes(nodes, maxPayloadSize, enrOverhead)
+
+	nodesMsg := &portalwire.Nodes{
+		Total: 1,
+		Enrs:  enrs,
+	}
+
+	p.Log.Trace(">> NODES/"+p.protocolName, "protocol", p.protocolName, "source", fromAddr, "nodes", nodesMsg)
+	if metrics.Enabled {
+		p.portalMetrics.messagesSentNodes.Mark(1)
+	}
+	nodesMsgBytes, err := nodesMsg.MarshalSSZ()
+	if err != nil {
+		return nil, err
+	}
+
+	talkRespBytes := make([]byte, 0, len(nodesMsgBytes)+1)
+	talkRespBytes = append(talkRespBytes, portalwire.NODES)
+	talkRespBytes = append(talkRespBytes, nodesMsgBytes...)
+
+	return talkRespBytes, nil
+}
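+
+// handleFindContent serves the content directly when it fits into a single
+// response packet, hands large payloads off to uTP behind a connection id,
+// and falls back to returning closer ENRs when the content is not stored
+// locally. Every reply must fit v5wire.MaxPacketSize (1280 bytes) minus the
+// TALKRESP and CONTENT message framing overhead.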
+ + return talkRespBytes, nil + } +} + +func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *portalwire.Offer) ([]byte, error) { + var err error + contentKeyBitlist := bitfield.NewBitlist(uint64(len(request.ContentKeys))) + if len(p.contentQueue) >= cap(p.contentQueue) { + acceptMsg := &portalwire.Accept{ + ConnectionId: []byte{0, 0}, + ContentKeys: contentKeyBitlist, + } + + p.Log.Trace(">> ACCEPT/"+p.protocolName, "protocol", p.protocolName, "source", addr, "accept", acceptMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentAccept.Mark(1) + } + var acceptMsgBytes []byte + acceptMsgBytes, err = acceptMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) + talkRespBytes = append(talkRespBytes, acceptMsgBytes...) + + return talkRespBytes, nil + } + + contentKeys := make([][]byte, 0) + for i, contentKey := range request.ContentKeys { + contentId := p.toContentId(contentKey) + if contentId != nil { + if inRange(p.Self().ID(), p.Radius(), contentId) { + if _, err = p.storage.Get(contentKey, contentId); err != nil { + contentKeyBitlist.SetBitAt(uint64(i), true) + contentKeys = append(contentKeys, contentKey) + } + } + } else { + return nil, ErrNilContentKey + } + } + + idBuffer := make([]byte, 2) + if contentKeyBitlist.Count() != 0 { + connectionId := p.connIdGen.GenCid(id, false) + + go func(bctx context.Context, connId *libutp.ConnId) { + var conn *utp.Conn + var connectCtx context.Context + var cancel context.CancelFunc + defer func() { + p.connIdGen.Remove(connectionId) + if conn == nil { + return + } + err := conn.Close() + if err != nil { + p.Log.Error("failed to close connection", "err", err) + } + }() + for { + select { + case <-bctx.Done(): + return + default: + p.Log.Debug("will accept offer conn from: ", "source", addr, "connId", connId) + connectCtx, cancel = context.WithTimeout(bctx, defaultUTPConnectTimeout) + conn, err = p.Utp.AcceptWithCid(connectCtx, id, connectionId) + cancel() + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailConn.Inc(1) + } + p.Log.Error("failed to accept utp connection for handle offer", "connId", connectionId.SendId(), "err", err) + return + } + + err = conn.SetReadDeadline(time.Now().Add(defaultUTPReadTimeout)) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailDeadline.Inc(1) + } + p.Log.Error("failed to set read deadline", "err", err) + return + } + // Read ALL the data from the connection until EOF and return it + var data []byte + data, err = io.ReadAll(conn) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailRead.Inc(1) + } + p.Log.Error("failed to read from utp connection", "err", err) + return + } + p.Log.Trace("<< OFFER_CONTENT/"+p.protocolName, "id", id, "size", len(data), "data", data) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + + err = p.handleOfferedContents(id, contentKeys, data) + if err != nil { + p.Log.Error("failed to handle offered Contents", "err", err) + return + } + + if metrics.Enabled { + p.portalMetrics.utpInSuccess.Inc(1) + } + return + } + } + }(p.closeCtx, connectionId) + + binary.BigEndian.PutUint16(idBuffer, connectionId.SendId()) + } else { + binary.BigEndian.PutUint16(idBuffer, uint16(0)) + } + + acceptMsg := &portalwire.Accept{ + ConnectionId: idBuffer, + ContentKeys: []byte(contentKeyBitlist), + } + + p.Log.Trace(">> ACCEPT/"+p.protocolName, "protocol", p.protocolName, "source", addr, 
"accept", acceptMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentAccept.Mark(1) + } + var acceptMsgBytes []byte + acceptMsgBytes, err = acceptMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) + talkRespBytes = append(talkRespBytes, acceptMsgBytes...) + + return talkRespBytes, nil +} + +func (p *PortalProtocol) handleOfferedContents(id enode.ID, keys [][]byte, payload []byte) error { + contents, err := decodeContents(payload) + if err != nil { + if metrics.Enabled { + p.portalMetrics.contentDecodedFalse.Inc(1) + } + return err + } + + keyLen := len(keys) + contentLen := len(contents) + if keyLen != contentLen { + if metrics.Enabled { + p.portalMetrics.contentDecodedFalse.Inc(1) + } + return fmt.Errorf("content keys len %d doesn't match content values len %d", keyLen, contentLen) + } + + contentElement := &ContentElement{ + Node: id, + ContentKeys: keys, + Contents: contents, + } + + p.contentQueue <- contentElement + + if metrics.Enabled { + p.portalMetrics.contentDecodedTrue.Inc(1) + } + return nil +} + +func (p *PortalProtocol) Self() *enode.Node { + return p.localNode.Node() +} + +func (p *PortalProtocol) RequestENR(n *enode.Node) (*enode.Node, error) { + nodes, err := p.findNodes(n, []uint{0}) + if err != nil { + return nil, err + } + if len(nodes) != 1 { + return nil, fmt.Errorf("%d nodes in response for distance zero", len(nodes)) + } + return nodes[0], nil +} + +func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, distances []uint, seen map[enode.ID]struct{}) (*enode.Node, error) { + n, err := enode.New(p.validSchemes, r) + if err != nil { + return nil, err + } + if err = netutil.CheckRelayIP(sender.IP(), n.IP()); err != nil { + return nil, err + } + if p.NetRestrict != nil && !p.NetRestrict.Contains(n.IP()) { + return nil, errors.New("not contained in netrestrict list") + } + if n.UDP() <= 1024 { + return nil, discover.ErrLowPort + } + if distances != nil { + nd := enode.LogDist(sender.ID(), n.ID()) + if !slices.Contains(distances, uint(nd)) { + return nil, errors.New("does not match any requested distance") + } + } + if _, ok := seen[n.ID()]; ok { + return nil, fmt.Errorf("duplicate record") + } + seen[n.ID()] = struct{}{} + return n, nil +} + +// LookupRandom looks up a random target. +// This is needed to satisfy the transport interface. +func (p *PortalProtocol) LookupRandom() []*enode.Node { + return p.newRandomLookup(p.closeCtx).Run() +} + +// LookupSelf looks up our own node ID. +// This is needed to satisfy the transport interface. +func (p *PortalProtocol) LookupSelf() []*enode.Node { + return p.newLookup(p.closeCtx, p.Self().ID()).Run() +} + +func (p *PortalProtocol) newRandomLookup(ctx context.Context) *discover.Lookup { + var target enode.ID + _, _ = crand.Read(target[:]) + return p.newLookup(ctx, target) +} + +func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *discover.Lookup { + return discover.NewLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { + return p.lookupWorker(n, target) + }) +} + +// lookupWorker performs FINDNODE calls against a single node during lookup. 
+func (p *PortalProtocol) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { + var ( + dists = discover.LookupDistances(target, destNode.ID()) + nodes = discover.NodesByDistance{Target: target} + err error + ) + var r []*enode.Node + + r, err = p.findNodes(destNode, dists) + if errors.Is(err, discover.ErrClosed) { + return nil, err + } + for _, n := range r { + if n.ID() != p.Self().ID() { + isAdded := p.table.AddFoundNode(n, false) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) + } + nodes.Push(n, portalFindnodesResultLimit) + } + } + return nodes.Entries, err +} + +func (p *PortalProtocol) offerWorker() { + for { + select { + case <-p.closeCtx.Done(): + return + case offerRequestWithNode := <-p.offerQueue: + p.Log.Trace("offerWorker", "offerRequestWithNode", offerRequestWithNode) + _, err := p.offer(offerRequestWithNode.Node, offerRequestWithNode.Request) + if err != nil { + p.Log.Error("failed to offer", "err", err) + } + } + } +} + +func (p *PortalProtocol) truncateNodes(nodes []*enode.Node, maxSize int, enrOverhead int) [][]byte { + res := make([][]byte, 0) + totalSize := 0 + for _, n := range nodes { + enrBytes, err := rlp.EncodeToBytes(n.Record()) + if err != nil { + p.Log.Error("failed to encode n", "err", err) + continue + } + + if totalSize+len(enrBytes)+enrOverhead > maxSize { + break + } else { + res = append(res, enrBytes) + totalSize += len(enrBytes) + } + } + return res +} + +func (p *PortalProtocol) findNodesCloseToContent(contentId []byte, limit int) []*enode.Node { + allNodes := p.table.NodeList() + sort.Slice(allNodes, func(i, j int) bool { + return enode.LogDist(allNodes[i].ID(), enode.ID(contentId)) < enode.LogDist(allNodes[j].ID(), enode.ID(contentId)) + }) + + if len(allNodes) > limit { + allNodes = allNodes[:limit] + } else { + allNodes = allNodes[:] + } + + return allNodes +} + +// Lookup performs a recursive lookup for the given target. +// It returns the closest nodes to target. +func (p *PortalProtocol) Lookup(target enode.ID) []*enode.Node { + return p.newLookup(p.closeCtx, target).Run() +} + +// Resolve searches for a specific Node with the given ID and tries to get the most recent +// version of the Node record for it. It returns n if the Node could not be resolved. +func (p *PortalProtocol) Resolve(n *enode.Node) *enode.Node { + if intable := p.table.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + n = intable + } + // Try asking directly. This works if the Node is still responding on the endpoint we have. + if resp, err := p.RequestENR(n); err == nil { + return resp + } + // Otherwise do a network lookup. + result := p.Lookup(n.ID()) + for _, rn := range result { + if rn.ID() == n.ID() && rn.Seq() > n.Seq() { + return rn + } + } + return n +} + +// ResolveNodeId searches for a specific Node with the given ID. +// It returns nil if the nodeId could not be resolved. +func (p *PortalProtocol) ResolveNodeId(id enode.ID) *enode.Node { + if id == p.Self().ID() { + p.Log.Debug("Resolve Self Id", "id", id.String()) + return p.Self() + } + + n := p.table.GetNode(id) + if n != nil { + p.Log.Debug("found Id in table and will request enr from the node", "id", id.String()) + // Try asking directly. This works if the Node is still responding on the endpoint we have. 
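+		// RequestENR issues a FINDNODE with distance 0, which the node
+		// answers with its own latest record.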
+		if resp, err := p.RequestENR(n); err == nil {
+			return resp
+		}
+	}
+
+	// Otherwise do a network lookup.
+	result := p.Lookup(id)
+	for _, rn := range result {
+		if rn.ID() == id {
+			if n != nil && rn.Seq() <= n.Seq() {
+				return n
+			} else {
+				return rn
+			}
+		}
+	}
+
+	return n
+}
+
+func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node {
+	var bn []*enode.Node
+	var nodes []*enode.Node
+	var processed = make(map[uint]struct{})
+	for _, dist := range distances {
+		// Reject duplicate / invalid distances.
+		_, seen := processed[dist]
+		if seen || dist > 256 {
+			continue
+		}
+		processed[dist] = struct{}{}
+
+		checkLive := !p.table.Config().NoFindnodeLivenessCheck
+		for _, n := range p.table.AppendBucketNodes(dist, bn[:0], checkLive) {
+			// Apply some pre-checks to avoid sending invalid nodes.
+			// Liveness has already been checked by AppendBucketNodes when enabled.
+			if netutil.CheckRelayIP(rip, n.IP()) != nil {
+				continue
+			}
+			nodes = append(nodes, n)
+			if len(nodes) >= limit {
+				return nodes
+			}
+		}
+	}
+	return nodes
+}
+
+func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bool, error) {
+	lookupContext, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	resChan := make(chan *traceContentInfoResp, discover.Alpha)
+	hasResult := int32(0)
+
+	result := ContentInfoResp{}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	go func() {
+		defer wg.Done()
+		for res := range resChan {
+			if res.Flag != portalwire.ContentEnrsSelector {
+				result.Content = res.Content.([]byte)
+				result.UtpTransfer = res.UtpTransfer
+			}
+		}
+	}()
+
+	discover.NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) {
+		return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult)
+	}).Run()
+	close(resChan)
+
+	wg.Wait()
+	if hasResult == 1 {
+		return result.Content, result.UtpTransfer, nil
+	}
+	return nil, false, ContentNotFound
+}
+
+func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*TraceContentResult, error) {
+	lookupContext, cancel := context.WithCancel(context.Background())
+	// channel for worker responses
+	resChan := make(chan *traceContentInfoResp, discover.Alpha)
+
+	hasResult := int32(0)
+
+	traceContentRes := &TraceContentResult{}
+
+	selfHexId := "0x" + p.Self().ID().String()
+
+	trace := &Trace{
+		Origin:      selfHexId,
+		TargetId:    hexutil.Encode(contentId),
+		StartedAtMs: int(time.Now().UnixMilli()),
+		Responses:   make(map[string]RespByNode),
+		Metadata:    make(map[string]*NodeMetadata),
+		Cancelled:   make([]string, 0),
+	}
+
+	nodes := p.table.FindnodeByID(enode.ID(contentId), discover.BucketSize, false)
+
+	localResponse := make([]string, 0, len(nodes.Entries))
+	for _, node := range nodes.Entries {
+		id := "0x" + node.ID().String()
+		localResponse = append(localResponse, id)
+	}
+	trace.Responses[selfHexId] = RespByNode{
+		DurationMs:    0,
+		RespondedWith: localResponse,
+	}
+
+	dis := p.Distance(p.Self().ID(), enode.ID(contentId))
+
+	trace.Metadata[selfHexId] = &NodeMetadata{
+		Enr:      p.Self().String(),
+		Distance: hexutil.Encode(dis[:]),
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	go func() {
+		defer wg.Done()
+		for res := range resChan {
+			node := res.Node
+			hexId := "0x" + node.ID().String()
+			dis := p.Distance(node.ID(), enode.ID(contentId))
+			p.Log.Debug("received res", "id", hexId, "flag", res.Flag)
+			trace.Metadata[hexId] = &NodeMetadata{
+				Enr:      node.String(),
+				Distance: hexutil.Encode(dis[:]),
+			}
+			// no content returned yet
+			if traceContentRes.Content == "" {
+				if res.Flag == 
portalwire.ContentRawSelector || res.Flag == portalwire.ContentConnIdSelector {
+					trace.ReceivedFrom = hexId
+					content := res.Content.([]byte)
+					traceContentRes.Content = hexutil.Encode(content)
+					traceContentRes.UtpTransfer = res.UtpTransfer
+					trace.Responses[hexId] = RespByNode{}
+				} else {
+					nodes := res.Content.([]*enode.Node)
+					respByNode := RespByNode{
+						RespondedWith: make([]string, 0, len(nodes)),
+					}
+					for _, node := range nodes {
+						idInner := "0x" + node.ID().String()
+						respByNode.RespondedWith = append(respByNode.RespondedWith, idInner)
+						if _, ok := trace.Metadata[idInner]; !ok {
+							dis := p.Distance(node.ID(), enode.ID(contentId))
+							trace.Metadata[idInner] = &NodeMetadata{
+								Enr:      node.String(),
+								Distance: hexutil.Encode(dis[:]),
+							}
+						}
+					}
+					// Record the full response once the list is complete.
+					trace.Responses[hexId] = respByNode
+				}
+			} else {
+				trace.Cancelled = append(trace.Cancelled, hexId)
+			}
+		}
+	}()
+
+	lookup := discover.NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) {
+		return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult)
+	})
+	lookup.Run()
+	close(resChan)
+
+	wg.Wait()
+	if hasResult == 0 {
+		cancel()
+	}
+	traceContentRes.Trace = *trace
+
+	return traceContentRes, nil
+}
+
+func (p *PortalProtocol) contentLookupWorker(n *enode.Node, contentKey []byte, resChan chan<- *traceContentInfoResp, cancel context.CancelFunc, done *int32) ([]*enode.Node, error) {
+	wrappedNode := make([]*enode.Node, 0)
+	flag, content, err := p.findContent(n, contentKey)
+	if err != nil {
+		return nil, err
+	}
+	p.Log.Debug("contentLookupWorker received response", "ip", n.IP().String(), "flag", flag)
+
+	switch flag {
+	case portalwire.ContentRawSelector, portalwire.ContentConnIdSelector:
+		content, ok := content.([]byte)
+		if !ok {
+			return wrappedNode, fmt.Errorf("failed to assert to raw content, value is: %v", content)
+		}
+		res := &traceContentInfoResp{
+			Node:        n,
+			Flag:        flag,
+			Content:     content,
+			UtpTransfer: false,
+		}
+		if flag == portalwire.ContentConnIdSelector {
+			res.UtpTransfer = true
+		}
+		if atomic.CompareAndSwapInt32(done, 0, 1) {
+			p.Log.Debug("contentLookupWorker found content", "ip", n.IP().String(), "port", n.UDP())
+			resChan <- res
+			cancel()
+		}
+		return wrappedNode, err
+	case portalwire.ContentEnrsSelector:
+		nodes, ok := content.([]*enode.Node)
+		if !ok {
+			return wrappedNode, fmt.Errorf("failed to assert to enrs content, value is: %v", content)
+		}
+		resChan <- &traceContentInfoResp{
+			Node:        n,
+			Flag:        flag,
+			Content:     content,
+			UtpTransfer: false,
+		}
+		return nodes, nil
+	}
+	return wrappedNode, nil
+}
+
+func (p *PortalProtocol) ToContentId(contentKey []byte) []byte {
+	return p.toContentId(contentKey)
+}
+
+func (p *PortalProtocol) InRange(contentId []byte) bool {
+	return inRange(p.Self().ID(), p.Radius(), contentId)
+}
+
+func (p *PortalProtocol) Get(contentKey []byte, contentId []byte) ([]byte, error) {
+	content, err := p.storage.Get(contentKey, contentId)
+	p.Log.Trace("get local storage", "contentId", hexutil.Encode(contentId), "content", hexutil.Encode(content), "err", err)
+	return content, err
+}
+
+func (p *PortalProtocol) Put(contentKey []byte, contentId []byte, content []byte) error {
+	err := p.storage.Put(contentKey, contentId, content)
+	p.Log.Trace("put local storage", "contentId", hexutil.Encode(contentId), "content", hexutil.Encode(content), "err", err)
+	return err
+}
+
+func (p *PortalProtocol) GetContent() chan *ContentElement {
+	return p.contentQueue
+}
+
+func (p *PortalProtocol) Gossip(srcNodeId *enode.ID, contentKeys 
[][]byte, content [][]byte) (int, error) { + if len(content) == 0 { + return 0, errors.New("empty content") + } + + contentList := make([]*ContentEntry, 0, portalwire.ContentKeysLimit) + for i := 0; i < len(content); i++ { + contentEntry := &ContentEntry{ + ContentKey: contentKeys[i], + Content: content[i], + } + contentList = append(contentList, contentEntry) + } + + contentId := p.toContentId(contentKeys[0]) + if contentId == nil { + return 0, ErrNilContentKey + } + + maxClosestNodes := 4 + maxFartherNodes := 4 + closestLocalNodes := p.findNodesCloseToContent(contentId, 32) + p.Log.Debug("closest local nodes", "count", len(closestLocalNodes)) + + gossipNodes := make([]*enode.Node, 0) + for _, n := range closestLocalNodes { + radius, found := p.radiusCache.HasGet(nil, []byte(n.ID().String())) + if found { + p.Log.Debug("found closest local nodes", "nodeId", n.ID(), "addr", n.IPAddr().String()) + nodeRadius := new(uint256.Int) + err := nodeRadius.UnmarshalSSZ(radius) + if err != nil { + return 0, err + } + if inRange(n.ID(), nodeRadius, contentId) { + if srcNodeId == nil { + gossipNodes = append(gossipNodes, n) + } else if n.ID() != *srcNodeId { + gossipNodes = append(gossipNodes, n) + } + } + } + } + + if len(gossipNodes) == 0 { + return 0, nil + } + + var finalGossipNodes []*enode.Node + if len(gossipNodes) > maxClosestNodes { + fartherNodes := gossipNodes[maxClosestNodes:] + rand.Shuffle(len(fartherNodes), func(i, j int) { + fartherNodes[i], fartherNodes[j] = fartherNodes[j], fartherNodes[i] + }) + finalGossipNodes = append(gossipNodes[:maxClosestNodes], fartherNodes[:min(maxFartherNodes, len(fartherNodes))]...) + } else { + finalGossipNodes = gossipNodes + } + + for _, n := range finalGossipNodes { + transientOfferRequest := &TransientOfferRequest{ + Contents: contentList, + } + + offerRequest := &OfferRequest{ + Kind: TransientOfferRequestKind, + Request: transientOfferRequest, + } + + offerRequestWithNode := &OfferRequestWithNode{ + Node: n, + Request: offerRequest, + } + p.offerQueue <- offerRequestWithNode + } + + return len(finalGossipNodes), nil +} + +func (p *PortalProtocol) Distance(a, b enode.ID) enode.ID { + res := [32]byte{} + for i := range a { + res[i] = a[i] ^ b[i] + } + return res +} + +func inRange(nodeId enode.ID, nodeRadius *uint256.Int, contentId []byte) bool { + distance := enode.LogDist(nodeId, enode.ID(contentId)) + disBig := new(big.Int).SetInt64(int64(distance)) + return nodeRadius.CmpBig(disBig) > 0 +} + +func encodeContents(contents [][]byte) ([]byte, error) { + contentsBytes := make([]byte, 0) + for _, content := range contents { + contentLen := len(content) + contentLenBytes := leb128.EncodeUint32(uint32(contentLen)) + contentsBytes = append(contentsBytes, contentLenBytes...) + contentsBytes = append(contentsBytes, content...) 
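+		// Each item is framed as a LEB128-encoded uint32 length prefix
+		// followed by the raw content bytes; decodeContents reverses this.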
+	}
+
+	return contentsBytes, nil
+}
+
+func decodeContents(payload []byte) ([][]byte, error) {
+	contents := make([][]byte, 0)
+	buffer := bytes.NewBuffer(payload)
+
+	for {
+		contentLen, contentLenLen, err := leb128.DecodeUint32(bytes.NewReader(buffer.Bytes()))
+		if err != nil {
+			if errors.Is(err, io.EOF) {
+				return contents, nil
+			}
+			return nil, err
+		}
+
+		buffer.Next(int(contentLenLen))
+
+		// Use io.ReadFull so a truncated payload surfaces as an error instead
+		// of silently yielding a zero-padded content item.
+		content := make([]byte, contentLen)
+		if _, err = io.ReadFull(buffer, content); err != nil {
+			if errors.Is(err, io.EOF) {
+				return contents, nil
+			}
+			return nil, err
+		}
+
+		contents = append(contents, content)
+	}
+}
+
+func getContentKeys(request *OfferRequest) [][]byte {
+	if request.Kind == TransientOfferRequestKind {
+		contentKeys := make([][]byte, 0)
+		contents := request.Request.(*TransientOfferRequest).Contents
+		for _, content := range contents {
+			contentKeys = append(contentKeys, content.ContentKey)
+		}
+
+		return contentKeys
+	} else {
+		return request.Request.(*PersistOfferRequest).ContentKeys
+	}
+}
diff --git a/portalnetwork/portal_protocol_metrics.go b/portalnetwork/portal_protocol_metrics.go
new file mode 100644
index 000000000000..343d3f4f00f3
--- /dev/null
+++ b/portalnetwork/portal_protocol_metrics.go
@@ -0,0 +1,67 @@
+package portalnetwork
+
+import "github.com/ethereum/go-ethereum/metrics"
+
+type portalMetrics struct {
+	messagesReceivedAccept      metrics.Meter
+	messagesReceivedNodes       metrics.Meter
+	messagesReceivedFindNodes   metrics.Meter
+	messagesReceivedFindContent metrics.Meter
+	messagesReceivedContent     metrics.Meter
+	messagesReceivedOffer       metrics.Meter
+	messagesReceivedPing        metrics.Meter
+	messagesReceivedPong        metrics.Meter
+
+	messagesSentAccept      metrics.Meter
+	messagesSentNodes       metrics.Meter
+	messagesSentFindNodes   metrics.Meter
+	messagesSentFindContent metrics.Meter
+	messagesSentContent     metrics.Meter
+	messagesSentOffer       metrics.Meter
+	messagesSentPing        metrics.Meter
+	messagesSentPong        metrics.Meter
+
+	utpInFailConn     metrics.Counter
+	utpInFailRead     metrics.Counter
+	utpInFailDeadline metrics.Counter
+	utpInSuccess      metrics.Counter
+
+	utpOutFailConn     metrics.Counter
+	utpOutFailWrite    metrics.Counter
+	utpOutFailDeadline metrics.Counter
+	utpOutSuccess      metrics.Counter
+
+	contentDecodedTrue  metrics.Counter
+	contentDecodedFalse metrics.Counter
+}
+
+func newPortalMetrics(protocolName string) *portalMetrics {
+	return &portalMetrics{
+		messagesReceivedAccept:      metrics.NewRegisteredMeter("portal/"+protocolName+"/received/accept", nil),
+		messagesReceivedNodes:       metrics.NewRegisteredMeter("portal/"+protocolName+"/received/nodes", nil),
+		messagesReceivedFindNodes:   metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_nodes", nil),
+		messagesReceivedFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_content", nil),
+		messagesReceivedContent:     metrics.NewRegisteredMeter("portal/"+protocolName+"/received/content", nil),
+		messagesReceivedOffer:       metrics.NewRegisteredMeter("portal/"+protocolName+"/received/offer", nil),
+		messagesReceivedPing:        metrics.NewRegisteredMeter("portal/"+protocolName+"/received/ping", nil),
+		messagesReceivedPong:        metrics.NewRegisteredMeter("portal/"+protocolName+"/received/pong", nil),
+		messagesSentAccept:          metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/accept", nil),
+		messagesSentNodes:           metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/nodes", nil),
+		messagesSentFindNodes:       metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_nodes", nil),
+		messagesSentFindContent:     
metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_content", nil), + messagesSentContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/content", nil), + messagesSentOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/offer", nil), + messagesSentPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/ping", nil), + messagesSentPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/pong", nil), + utpInFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_conn", nil), + utpInFailRead: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_read", nil), + utpInFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_deadline", nil), + utpInSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/success", nil), + utpOutFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_conn", nil), + utpOutFailWrite: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_write", nil), + utpOutFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_deadline", nil), + utpOutSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/success", nil), + contentDecodedTrue: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/true", nil), + contentDecodedFalse: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/false", nil), + } +} diff --git a/portalnetwork/portal_protocol_test.go b/portalnetwork/portal_protocol_test.go new file mode 100644 index 000000000000..a705f9fe27d7 --- /dev/null +++ b/portalnetwork/portal_protocol_test.go @@ -0,0 +1,503 @@ +package portalnetwork + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/portalnetwork/storage" + "github.com/optimism-java/utp-go" + "github.com/optimism-java/utp-go/libutp" + "github.com/prysmaticlabs/go-bitfield" + "golang.org/x/exp/slices" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover/portalwire" + "github.com/ethereum/go-ethereum/p2p/enode" + assert "github.com/stretchr/testify/require" +) + +func setupLocalPortalNode(addr string, bootNodes []*enode.Node) (*PortalProtocol, error) { + conf := DefaultPortalProtocolConfig() + conf.NAT = nil + if addr != "" { + conf.ListenAddr = addr + } + if bootNodes != nil { + conf.BootstrapNodes = bootNodes + } + + addr1, err := net.ResolveUDPAddr("udp", conf.ListenAddr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", addr1) + if err != nil { + return nil, err + } + + privKey := newkey() + + discCfg := Config{ + PrivateKey: privKey, + NetRestrict: conf.NetRestrict, + Bootnodes: conf.BootstrapNodes, + } + + nodeDB, err := enode.OpenDB(conf.NodeDBPath) + if err != nil { + return nil, err + } + + localNode := enode.NewLocalNode(nodeDB, privKey) + localNode.SetFallbackIP(net.IP{127, 0, 0, 1}) + localNode.Set(Tag) + + if conf.NAT == nil { + var addrs []net.Addr + addrs, err = net.InterfaceAddrs() + + if err != nil { + return nil, err + } + + for _, address := range addrs { + // check ip addr is loopback addr + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() != nil { + localNode.SetStaticIP(ipnet.IP) + break + } + } + } + } + + discV5, err := ListenV5(conn, localNode, 
discCfg) + if err != nil { + return nil, err + } + utpSocket := NewPortalUtp(context.Background(), conf, discV5, conn) + + contentQueue := make(chan *ContentElement, 50) + portalProtocol, err := NewPortalProtocol( + conf, + portalwire.History, + privKey, + conn, + localNode, + discV5, + utpSocket, + &storage.MockStorage{Db: make(map[string][]byte)}, + contentQueue) + if err != nil { + return nil, err + } + + return portalProtocol, nil +} + +func TestPortalWireProtocolUdp(t *testing.T) { + node1, err := setupLocalPortalNode(":8777", nil) + assert.NoError(t, err) + node1.Log = testlog.Logger(t, log.LvlTrace) + err = node1.Start() + assert.NoError(t, err) + + node2, err := setupLocalPortalNode(":8778", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node2.Log = testlog.Logger(t, log.LvlTrace) + err = node2.Start() + assert.NoError(t, err) + time.Sleep(12 * time.Second) + + node3, err := setupLocalPortalNode(":8779", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node3.Log = testlog.Logger(t, log.LvlTrace) + err = node3.Start() + assert.NoError(t, err) + time.Sleep(12 * time.Second) + + cid1 := libutp.ReceConnId(12) + cid2 := libutp.ReceConnId(116) + cliSendMsgWithCid1 := "there are connection id : 12!" + cliSendMsgWithCid2 := "there are connection id: 116!" + + serverEchoWithCid := "accept connection sends back msg: echo" + + largeTestContent := make([]byte, 1199) + _, err = rand.Read(largeTestContent) + assert.NoError(t, err) + + var workGroup sync.WaitGroup + var acceptGroup sync.WaitGroup + workGroup.Add(4) + acceptGroup.Add(1) + go func() { + var acceptConn *utp.Conn + defer func() { + workGroup.Done() + _ = acceptConn.Close() + }() + acceptConn, err := node1.Utp.AcceptWithCid(context.Background(), node2.localNode.ID(), cid1) + if err != nil { + panic(err) + } + acceptGroup.Done() + buf := make([]byte, 100) + n, err := acceptConn.Read(buf) + if err != nil && err != io.EOF { + panic(err) + } + assert.Equal(t, cliSendMsgWithCid1, string(buf[:n])) + _, err = acceptConn.Write([]byte(serverEchoWithCid)) + if err != nil { + panic(err) + } + }() + go func() { + var connId2Conn net.Conn + defer func() { + workGroup.Done() + _ = connId2Conn.Close() + }() + connId2Conn, err := node1.Utp.AcceptWithCid(context.Background(), node2.localNode.ID(), cid2) + if err != nil { + panic(err) + } + buf := make([]byte, 100) + n, err := connId2Conn.Read(buf) + if err != nil && err != io.EOF { + panic(err) + } + assert.Equal(t, cliSendMsgWithCid2, string(buf[:n])) + + _, err = connId2Conn.Write(largeTestContent) + if err != nil { + panic(err) + } + }() + + go func() { + var connWithConnId net.Conn + defer func() { + workGroup.Done() + if connWithConnId != nil { + _ = connWithConnId.Close() + } + }() + connWithConnId, err = node2.Utp.DialWithCid(context.Background(), node1.localNode.Node(), cid1.SendId()) + if err != nil { + panic(err) + } + _, err = connWithConnId.Write([]byte(cliSendMsgWithCid1)) + if err != nil && err != io.EOF { + panic(err) + } + buf := make([]byte, 100) + n, err := connWithConnId.Read(buf) + if err != nil && err != io.EOF { + panic(err) + } + assert.Equal(t, serverEchoWithCid, string(buf[:n])) + }() + go func() { + var ConnId2Conn net.Conn + defer func() { + workGroup.Done() + if ConnId2Conn != nil { + _ = ConnId2Conn.Close() + } + }() + ConnId2Conn, err = node2.Utp.DialWithCid(context.Background(), node1.localNode.Node(), cid2.SendId()) + if err != nil && err != io.EOF { + panic(err) + } + _, err = ConnId2Conn.Write([]byte(cliSendMsgWithCid2)) + if err != 
nil {
+			panic(err)
+		}
+
+		data := make([]byte, 0)
+		buf := make([]byte, 1024)
+		for {
+			var n int
+			n, err = ConnId2Conn.Read(buf)
+			if err != nil {
+				if errors.Is(err, io.EOF) {
+					break
+				}
+				panic(err)
+			}
+			data = append(data, buf[:n]...)
+		}
+		assert.Equal(t, largeTestContent, data)
+	}()
+	workGroup.Wait()
+	node1.Stop()
+	node2.Stop()
+	node3.Stop()
+}
+
+func TestPortalWireProtocol(t *testing.T) {
+	node1, err := setupLocalPortalNode(":7777", nil)
+	assert.NoError(t, err)
+	node1.Log = testlog.Logger(t, log.LevelDebug)
+	err = node1.Start()
+	assert.NoError(t, err)
+
+	node2, err := setupLocalPortalNode(":7778", []*enode.Node{node1.localNode.Node()})
+	assert.NoError(t, err)
+	node2.Log = testlog.Logger(t, log.LevelDebug)
+	err = node2.Start()
+	assert.NoError(t, err)
+
+	time.Sleep(12 * time.Second)
+
+	node3, err := setupLocalPortalNode(":7779", []*enode.Node{node1.localNode.Node()})
+	assert.NoError(t, err)
+	node3.Log = testlog.Logger(t, log.LevelDebug)
+	err = node3.Start()
+	assert.NoError(t, err)
+
+	time.Sleep(12 * time.Second)
+
+	assert.True(t, slices.ContainsFunc(node1.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node2.localNode.Node().ID()
+	}))
+	assert.True(t, slices.ContainsFunc(node1.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node3.localNode.Node().ID()
+	}))
+
+	assert.True(t, slices.ContainsFunc(node2.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node1.localNode.Node().ID()
+	}))
+	assert.True(t, slices.ContainsFunc(node2.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node3.localNode.Node().ID()
+	}))
+
+	assert.True(t, slices.ContainsFunc(node3.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node1.localNode.Node().ID()
+	}))
+	assert.True(t, slices.ContainsFunc(node3.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node2.localNode.Node().ID()
+	}))
+
+	err = node1.storage.Put(nil, node1.toContentId([]byte("test_key")), []byte("test_value"))
+	assert.NoError(t, err)
+
+	flag, content, err := node2.findContent(node1.localNode.Node(), []byte("test_key"))
+	assert.NoError(t, err)
+	assert.Equal(t, portalwire.ContentRawSelector, flag)
+	assert.Equal(t, []byte("test_value"), content)
+
+	flag, content, err = node2.findContent(node3.localNode.Node(), []byte("test_key"))
+	assert.NoError(t, err)
+	assert.Equal(t, portalwire.ContentEnrsSelector, flag)
+	assert.Equal(t, 1, len(content.([]*enode.Node)))
+	assert.Equal(t, node1.localNode.Node().ID(), content.([]*enode.Node)[0].ID())
+
+	// create a byte slice of length 2000 and fill it with random data;
+	// this is large enough that it must be transferred over uTP
+	largeTestContent := make([]byte, 2000)
+	_, err = rand.Read(largeTestContent)
+	assert.NoError(t, err)
+
+	err = node1.storage.Put(nil, node1.toContentId([]byte("large_test_key")), largeTestContent)
+	assert.NoError(t, err)
+
+	flag, content, err = node2.findContent(node1.localNode.Node(), []byte("large_test_key"))
+	assert.NoError(t, err)
+	assert.Equal(t, largeTestContent, content)
+	assert.Equal(t, portalwire.ContentConnIdSelector, flag)
+
+	testEntry1 := &ContentEntry{
+		ContentKey: []byte("test_entry1"),
+		Content:    []byte("test_entry1_content"),
+	}
+
+	testEntry2 := &ContentEntry{
+		ContentKey: []byte("test_entry2"),
+		Content:    []byte("test_entry2_content"),
+	}
+
+	testTransientOfferRequest := &TransientOfferRequest{
+		Contents: []*ContentEntry{testEntry1, testEntry2},
+	}
+
+	offerRequest := &OfferRequest{
+		Kind:    TransientOfferRequestKind,
+		Request: testTransientOfferRequest,
+	}
+
+	contentKeys, err := node1.offer(node3.localNode.Node(), offerRequest)
+	assert.Equal(t, 
uint64(2), bitfield.Bitlist(contentKeys).Count()) + assert.NoError(t, err) + + contentElement := <-node3.contentQueue + assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) + assert.Equal(t, testEntry1.ContentKey, contentElement.ContentKeys[0]) + assert.Equal(t, testEntry1.Content, contentElement.Contents[0]) + assert.Equal(t, testEntry2.ContentKey, contentElement.ContentKeys[1]) + assert.Equal(t, testEntry2.Content, contentElement.Contents[1]) + + testGossipContentKeys := [][]byte{[]byte("test_gossip_content_keys"), []byte("test_gossip_content_keys2")} + testGossipContent := [][]byte{[]byte("test_gossip_content"), []byte("test_gossip_content2")} + id := node1.Self().ID() + gossip, err := node1.Gossip(&id, testGossipContentKeys, testGossipContent) + assert.NoError(t, err) + assert.Equal(t, 2, gossip) + + contentElement = <-node2.contentQueue + assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) + assert.Equal(t, testGossipContentKeys[0], contentElement.ContentKeys[0]) + assert.Equal(t, testGossipContent[0], contentElement.Contents[0]) + assert.Equal(t, testGossipContentKeys[1], contentElement.ContentKeys[1]) + assert.Equal(t, testGossipContent[1], contentElement.Contents[1]) + + contentElement = <-node3.contentQueue + assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) + assert.Equal(t, testGossipContentKeys[0], contentElement.ContentKeys[0]) + assert.Equal(t, testGossipContent[0], contentElement.Contents[0]) + assert.Equal(t, testGossipContentKeys[1], contentElement.ContentKeys[1]) + assert.Equal(t, testGossipContent[1], contentElement.Contents[1]) + + node1.Stop() + node2.Stop() + node3.Stop() +} + +func TestCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + go func(ctx context.Context) { + defer func() { + t.Log("goroutine cancel") + }() + + time.Sleep(time.Second * 5) + }(ctx) + + cancel() + t.Log("after main cancel") + + time.Sleep(time.Second * 3) +} + +func TestContentLookup(t *testing.T) { + node1, err := setupLocalPortalNode(":17777", nil) + assert.NoError(t, err) + node1.Log = testlog.Logger(t, log.LvlTrace) + err = node1.Start() + assert.NoError(t, err) + + node2, err := setupLocalPortalNode(":17778", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node2.Log = testlog.Logger(t, log.LvlTrace) + err = node2.Start() + assert.NoError(t, err) + fmt.Println(node2.localNode.Node().String()) + + node3, err := setupLocalPortalNode(":17779", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node3.Log = testlog.Logger(t, log.LvlTrace) + err = node3.Start() + assert.NoError(t, err) + + defer func() { + node1.Stop() + node2.Stop() + node3.Stop() + }() + + contentKey := []byte{0x3, 0x4} + content := []byte{0x1, 0x2} + contentId := node1.toContentId(contentKey) + + err = node3.storage.Put(nil, contentId, content) + assert.NoError(t, err) + + res, _, err := node1.ContentLookup(contentKey, contentId) + assert.NoError(t, err) + assert.Equal(t, res, content) + + nonExist := []byte{0x2, 0x4} + res, _, err = node1.ContentLookup(nonExist, node1.toContentId(nonExist)) + assert.Equal(t, ContentNotFound, err) + assert.Nil(t, res) +} + +func TestTraceContentLookup(t *testing.T) { + node1, err := setupLocalPortalNode(":17787", nil) + assert.NoError(t, err) + node1.Log = testlog.Logger(t, log.LvlTrace) + err = node1.Start() + assert.NoError(t, err) + + node2, err := setupLocalPortalNode(":17788", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node2.Log = testlog.Logger(t, 
log.LvlTrace) + err = node2.Start() + assert.NoError(t, err) + + node3, err := setupLocalPortalNode(":17789", []*enode.Node{node2.localNode.Node()}) + assert.NoError(t, err) + node3.Log = testlog.Logger(t, log.LvlTrace) + err = node3.Start() + assert.NoError(t, err) + + defer node1.Stop() + defer node2.Stop() + defer node3.Stop() + + contentKey := []byte{0x3, 0x4} + content := []byte{0x1, 0x2} + contentId := node1.toContentId(contentKey) + + err = node1.storage.Put(nil, contentId, content) + assert.NoError(t, err) + + node1Id := hexutil.Encode(node1.Self().ID().Bytes()) + node2Id := hexutil.Encode(node2.Self().ID().Bytes()) + node3Id := hexutil.Encode(node3.Self().ID().Bytes()) + + res, err := node3.TraceContentLookup(contentKey, contentId) + assert.NoError(t, err) + assert.Equal(t, res.Content, hexutil.Encode(content)) + assert.Equal(t, res.UtpTransfer, false) + assert.Equal(t, res.Trace.Origin, node3Id) + assert.Equal(t, res.Trace.TargetId, hexutil.Encode(contentId)) + assert.Equal(t, res.Trace.ReceivedFrom, node1Id) + + // check nodeMeta + node1Meta := res.Trace.Metadata[node1Id] + assert.Equal(t, node1Meta.Enr, node1.Self().String()) + dis := node1.Distance(node1.Self().ID(), enode.ID(contentId)) + assert.Equal(t, node1Meta.Distance, hexutil.Encode(dis[:])) + + node2Meta := res.Trace.Metadata[node2Id] + assert.Equal(t, node2Meta.Enr, node2.Self().String()) + dis = node2.Distance(node2.Self().ID(), enode.ID(contentId)) + assert.Equal(t, node2Meta.Distance, hexutil.Encode(dis[:])) + + node3Meta := res.Trace.Metadata[node3Id] + assert.Equal(t, node3Meta.Enr, node3.Self().String()) + dis = node3.Distance(node3.Self().ID(), enode.ID(contentId)) + assert.Equal(t, node3Meta.Distance, hexutil.Encode(dis[:])) + + // check response + node3Response := res.Trace.Responses[node3Id] + assert.Equal(t, node3Response.RespondedWith, []string{node2Id}) + + node2Response := res.Trace.Responses[node2Id] + assert.Equal(t, node2Response.RespondedWith, []string{node1Id}) + + node1Response := res.Trace.Responses[node1Id] + assert.Equal(t, node1Response.RespondedWith, ([]string)(nil)) +} diff --git a/portalnetwork/portal_utp.go b/portalnetwork/portal_utp.go new file mode 100644 index 000000000000..b1b58a7673ca --- /dev/null +++ b/portalnetwork/portal_utp.go @@ -0,0 +1,139 @@ +package portalnetwork + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/netutil" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" + "github.com/optimism-java/utp-go" + "github.com/optimism-java/utp-go/libutp" + "go.uber.org/zap" +) + +type PortalUtp struct { + ctx context.Context + log log.Logger + discV5 *discover.UDPv5 + conn discover.UDPConn + ListenAddr string + listener *utp.Listener + utpSm *utp.SocketManager + packetRouter *utp.PacketRouter + lAddr *utp.Addr + + startOnce sync.Once +} + +func NewPortalUtp(ctx context.Context, config *PortalProtocolConfig, discV5 *discover.UDPv5, conn discover.UDPConn) *PortalUtp { + return &PortalUtp{ + ctx: ctx, + log: log.New("protocol", "utp", "local", conn.LocalAddr().String()), + discV5: discV5, + conn: conn, + ListenAddr: config.ListenAddr, + } +} + +func (p *PortalUtp) Start() error { + var err error + go p.startOnce.Do(func() { + var logger *zap.Logger + if p.log.Enabled(p.ctx, log.LevelDebug) || p.log.Enabled(p.ctx, log.LevelTrace) { + 
logger, err = zap.NewDevelopmentConfig().Build()
+		} else {
+			logger, err = zap.NewProductionConfig().Build()
+		}
+		if err != nil {
+			return
+		}
+
+		laddr := p.getLocalAddr()
+		p.packetRouter = utp.NewPacketRouter(p.packetRouterFunc)
+		p.utpSm, err = utp.NewSocketManagerWithOptions(
+			"utp",
+			laddr,
+			utp.WithContext(p.ctx),
+			utp.WithLogger(logger.Named(p.ListenAddr)),
+			utp.WithPacketRouter(p.packetRouter),
+			utp.WithMaxPacketSize(1145))
+		if err != nil {
+			return
+		}
+		p.listener, err = utp.ListenUTPOptions("utp", (*utp.Addr)(laddr), utp.WithSocketManager(p.utpSm))
+		if err != nil {
+			return
+		}
+		p.lAddr = p.listener.Addr().(*utp.Addr)
+
+		// register discv5 listener
+		p.discV5.RegisterTalkHandler(string(portalwire.Utp), p.handleUtpTalkRequest)
+	})
+
+	return err
+}
+
+func (p *PortalUtp) Stop() {
+	err := p.listener.Close()
+	if err != nil {
+		p.log.Error("failed to close utp listener", "err", err)
+	}
+	p.discV5.Close()
+}
+
+func (p *PortalUtp) DialWithCid(ctx context.Context, dest *enode.Node, connId uint16) (net.Conn, error) {
+	raddr := &utp.Addr{IP: dest.IP(), Port: dest.UDP()}
+	p.log.Debug("connecting to node", "nodeId", dest.ID().String(), "connId", connId)
+	conn, err := utp.DialUTPOptions("utp", p.lAddr, raddr, utp.WithContext(ctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(connId))
+	return conn, err
+}
+
+func (p *PortalUtp) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
+	raddr := &utp.Addr{IP: dest.IP(), Port: dest.UDP()}
+	p.log.Info("connecting to addr", "addr", raddr.String())
+	conn, err := utp.DialUTPOptions("utp", p.lAddr, raddr, utp.WithContext(ctx), utp.WithSocketManager(p.utpSm))
+	return conn, err
+}
+
+func (p *PortalUtp) AcceptWithCid(ctx context.Context, nodeId enode.ID, cid *libutp.ConnId) (*utp.Conn, error) {
+	p.log.Debug("accepting from node", "nodeId", nodeId.String(), "sendId", cid.SendId(), "recvId", cid.RecvId())
+	return p.listener.AcceptUTPContext(ctx, nodeId, cid)
+}
+
+func (p *PortalUtp) Accept(ctx context.Context) (*utp.Conn, error) {
+	return p.listener.AcceptUTPContext(ctx, enode.ID{}, nil)
+}
+
+func (p *PortalUtp) getLocalAddr() *net.UDPAddr {
+	laddr := p.conn.LocalAddr().(*net.UDPAddr)
+	p.log.Debug("UDP listener up", "addr", laddr)
+	return laddr
+}
+
+func (p *PortalUtp) packetRouterFunc(buf []byte, id enode.ID, addr *net.UDPAddr) (int, error) {
+	p.log.Info("sending data to target", "nodeId", id.String(), "ip", addr.IP.String(), "port", addr.Port, "bufLength", len(buf))
+
+	if n, ok := p.discV5.GetCachedNode(addr.String()); ok {
+		req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf}
+		p.discV5.SendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req)
+
+		return len(buf), nil
+	} else {
+		p.log.Warn("target node not found", "ip", addr.IP.String(), "port", addr.Port, "bufLength", len(buf))
+		return 0, fmt.Errorf("target node not found")
+	}
+}
+
+func (p *PortalUtp) handleUtpTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte {
+	p.log.Trace("received utp data", "nodeId", id.String(), "addr", addr, "msg-length", len(msg))
+	p.packetRouter.ReceiveMessage(msg, &utp.NodeInfo{Id: id, Addr: addr})
+	return []byte("")
+}
diff --git a/portalnetwork/portalwire/messages.go b/portalnetwork/portalwire/messages.go
new file mode 100644
index 000000000000..c7629604d570
--- /dev/null
+++ b/portalnetwork/portalwire/messages.go
@@ -0,0 +1,336 @@
+package portalwire
+
+import (
+	ssz "github.com/ferranbt/fastssz"
+)
+
+// Note: we modified the generated file because of fastssz issues that fail CI,
+// so the go:generate line below is commented out.
+///go:generate sszgen --path messages.go --exclude-objs Content,Enrs,ContentKV
+
+// Message codes for the portal protocol.
+const (
+	PING        byte = 0x00
+	PONG        byte = 0x01
+	FINDNODES   byte = 0x02
+	NODES       byte = 0x03
+	FINDCONTENT byte = 0x04
+	CONTENT     byte = 0x05
+	OFFER       byte = 0x06
+	ACCEPT      byte = 0x07
+)
+
+// Content selectors for the portal protocol.
+const (
+	ContentConnIdSelector byte = 0x00
+	ContentRawSelector    byte = 0x01
+	ContentEnrsSelector   byte = 0x02
+)
+
+const (
+	ContentKeysLimit = 64
+	// OfferMessageOverhead is the overhead of an offer message: 1 byte for the
+	// kind enum and 4 bytes for the offset in the ssz serialization.
+	OfferMessageOverhead = 5
+
+	// PerContentKeyOverhead is the overhead per key in ContentKeysList: each
+	// key carries a uint32 offset, i.e. 4 bytes, when serialized.
+	PerContentKeyOverhead = 4
+)
+
+// Protocol IDs for the portal protocol.
+type ProtocolId []byte
+
+var (
+	State             ProtocolId = []byte{0x50, 0x0A}
+	History           ProtocolId = []byte{0x50, 0x0B}
+	Beacon            ProtocolId = []byte{0x50, 0x0C}
+	CanonicalIndices  ProtocolId = []byte{0x50, 0x0D}
+	VerkleState       ProtocolId = []byte{0x50, 0x0E}
+	TransactionGossip ProtocolId = []byte{0x50, 0x0F}
+	Utp               ProtocolId = []byte{0x75, 0x74, 0x70}
+)
+
+var protocolName = map[string]string{
+	string(State):             "state",
+	string(History):           "history",
+	string(Beacon):            "beacon",
+	string(CanonicalIndices):  "canonical indices",
+	string(VerkleState):       "verkle state",
+	string(TransactionGossip): "transaction gossip",
+}
+
+func (p ProtocolId) Name() string {
+	return protocolName[string(p)]
+}
+
+type ContentKV struct {
+	ContentKey []byte
+	Content    []byte
+}
+
+// Request messages for the portal protocol.
+type (
+	PingPongCustomData struct {
+		Radius []byte `ssz-size:"32"`
+	}
+
+	Ping struct {
+		EnrSeq        uint64
+		CustomPayload []byte `ssz-max:"2048"`
+	}
+
+	FindNodes struct {
+		Distances [][2]byte `ssz-max:"256,2" ssz-size:"?,2"`
+	}
+
+	FindContent struct {
+		ContentKey []byte `ssz-max:"2048"`
+	}
+
+	Offer struct {
+		ContentKeys [][]byte `ssz-max:"64,2048"`
+	}
+)
+
+// Response messages for the portal protocol.
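+//
+// Content and Enrs below carry hand-written SSZ methods because they are
+// excluded from sszgen (see the note at the top of this file). On the wire
+// every talk payload is a one-byte message code followed by the SSZ-encoded
+// message; CONTENT responses additionally carry a union selector, e.g.
+// (illustrative only):
+//
+//	resp := []byte{CONTENT, ContentRawSelector}
+//	resp = append(resp, sszContentBytes...)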
+type ( + Pong struct { + EnrSeq uint64 + CustomPayload []byte `ssz-max:"2048"` + } + + Nodes struct { + Total uint8 + Enrs [][]byte `ssz-max:"32,2048"` + } + + ConnectionId struct { + Id []byte `ssz-size:"2"` + } + + Content struct { + Content []byte `ssz-max:"2048"` + } + + Enrs struct { + Enrs [][]byte `ssz-max:"32,2048"` + } + + Accept struct { + ConnectionId []byte `ssz-size:"2"` + ContentKeys []byte `ssz:"bitlist" ssz-max:"64"` + } +) + +// MarshalSSZ ssz marshals the Content object +func (c *Content) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(c) +} + +// MarshalSSZTo ssz marshals the Content object to a target array +func (c *Content) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Content' + if size := len(c.Content); size > 2048 { + err = ssz.ErrBytesLengthFn("Content.Content", size, 2048) + return + } + dst = append(dst, c.Content...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Content object +func (c *Content) UnmarshalSSZ(buf []byte) error { + var err error + tail := buf + + // Field (0) 'Content' + { + buf = tail[:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(c.Content) == 0 { + c.Content = make([]byte, 0, len(buf)) + } + c.Content = append(c.Content, buf...) + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Content object +func (c *Content) SizeSSZ() (size int) { + // Field (0) 'Content' + return len(c.Content) +} + +// HashTreeRoot ssz hashes the Content object +func (c *Content) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(c) +} + +// HashTreeRootWith ssz hashes the Content object with a hasher +func (c *Content) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Content' + { + elemIndx := hh.Index() + byteLen := uint64(len(c.Content)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(c.Content) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Content object +func (c *Content) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(c) +} + +// MarshalSSZ ssz marshals the Enrs object +func (e *Enrs) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(e) +} + +// MarshalSSZTo ssz marshals the Enrs object to a target array +func (e *Enrs) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(0) + + // Field (0) 'Enrs' + if size := len(e.Enrs); size > 32 { + err = ssz.ErrListTooBigFn("Enrs.Enrs", size, 32) + return + } + { + offset = 4 * len(e.Enrs) + for ii := 0; ii < len(e.Enrs); ii++ { + dst = ssz.WriteOffset(dst, offset) + offset += len(e.Enrs[ii]) + } + } + for ii := 0; ii < len(e.Enrs); ii++ { + if size := len(e.Enrs[ii]); size > 2048 { + err = ssz.ErrBytesLengthFn("Enrs.Enrs[ii]", size, 2048) + return + } + dst = append(dst, e.Enrs[ii]...) + } + + return +} + +// UnmarshalSSZ ssz unmarshals the Enrs object +func (e *Enrs) UnmarshalSSZ(buf []byte) error { + var err error + tail := buf + // Field (0) 'Enrs' + { + buf = tail[:] + num, err := ssz.DecodeDynamicLength(buf, 32) + if err != nil { + return err + } + e.Enrs = make([][]byte, num) + err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(e.Enrs[indx]) == 0 { + e.Enrs[indx] = make([]byte, 0, len(buf)) + } + e.Enrs[indx] = append(e.Enrs[indx], buf...) 
+ return nil + }) + if err != nil { + return err + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Enrs object +func (e *Enrs) SizeSSZ() (size int) { + size = 0 + + // Field (0) 'Enrs' + for ii := 0; ii < len(e.Enrs); ii++ { + size += 4 + size += len(e.Enrs[ii]) + } + + return +} + +// HashTreeRoot ssz hashes the Enrs object +func (e *Enrs) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(e) +} + +// HashTreeRootWith ssz hashes the Enrs object with a hasher +func (e *Enrs) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Enrs' + { + subIndx := hh.Index() + num := uint64(len(e.Enrs)) + if num > 32 { + err = ssz.ErrIncorrectListSize + return + } + for _, elem := range e.Enrs { + { + elemIndx := hh.Index() + byteLen := uint64(len(elem)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.AppendBytes32(elem) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + } + hh.MerkleizeWithMixin(subIndx, num, 32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Enrs object +func (e *Enrs) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(e) +} diff --git a/portalnetwork/portalwire/messages_encoding.go b/portalnetwork/portalwire/messages_encoding.go new file mode 100644 index 000000000000..601150baff1a --- /dev/null +++ b/portalnetwork/portalwire/messages_encoding.go @@ -0,0 +1,957 @@ +// Code generated by fastssz. DO NOT EDIT. +// Hash: 26a61b12807ff78c64a029acdd5bcb580dfe35b7bfbf8bf04ceebae1a3d5cac1 +// Version: 0.1.3 +package portalwire + +import ( + ssz "github.com/ferranbt/fastssz" +) + +// MarshalSSZ ssz marshals the PingPongCustomData object +func (p *PingPongCustomData) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(p) +} + +// MarshalSSZTo ssz marshals the PingPongCustomData object to a target array +func (p *PingPongCustomData) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Radius' + if size := len(p.Radius); size != 32 { + err = ssz.ErrBytesLengthFn("PingPongCustomData.Radius", size, 32) + return + } + dst = append(dst, p.Radius...) + + return +} + +// UnmarshalSSZ ssz unmarshals the PingPongCustomData object +func (p *PingPongCustomData) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 32 { + return ssz.ErrSize + } + + // Field (0) 'Radius' + if cap(p.Radius) == 0 { + p.Radius = make([]byte, 0, len(buf[0:32])) + } + p.Radius = append(p.Radius, buf[0:32]...) 
+ + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the PingPongCustomData object +func (p *PingPongCustomData) SizeSSZ() (size int) { + size = 32 + return +} + +// HashTreeRoot ssz hashes the PingPongCustomData object +func (p *PingPongCustomData) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(p) +} + +// HashTreeRootWith ssz hashes the PingPongCustomData object with a hasher +func (p *PingPongCustomData) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Radius' + if size := len(p.Radius); size != 32 { + err = ssz.ErrBytesLengthFn("PingPongCustomData.Radius", size, 32) + return + } + hh.PutBytes(p.Radius) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the PingPongCustomData object +func (p *PingPongCustomData) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(p) +} + +// MarshalSSZ ssz marshals the Ping object +func (p *Ping) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(p) +} + +// MarshalSSZTo ssz marshals the Ping object to a target array +func (p *Ping) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(12) + + // Field (0) 'EnrSeq' + dst = ssz.MarshalUint64(dst, p.EnrSeq) + + // Offset (1) 'CustomPayload' + dst = ssz.WriteOffset(dst, offset) + offset += len(p.CustomPayload) + + // Field (1) 'CustomPayload' + if size := len(p.CustomPayload); size > 2048 { + err = ssz.ErrBytesLengthFn("Ping.CustomPayload", size, 2048) + return + } + dst = append(dst, p.CustomPayload...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Ping object +func (p *Ping) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 12 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'EnrSeq' + p.EnrSeq = ssz.UnmarshallUint64(buf[0:8]) + + // Offset (1) 'CustomPayload' + if o1 = ssz.ReadOffset(buf[8:12]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 12 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'CustomPayload' + { + buf = tail[o1:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(p.CustomPayload) == 0 { + p.CustomPayload = make([]byte, 0, len(buf)) + } + p.CustomPayload = append(p.CustomPayload, buf...) 
+ } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Ping object +func (p *Ping) SizeSSZ() (size int) { + size = 12 + + // Field (1) 'CustomPayload' + size += len(p.CustomPayload) + + return +} + +// HashTreeRoot ssz hashes the Ping object +func (p *Ping) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(p) +} + +// HashTreeRootWith ssz hashes the Ping object with a hasher +func (p *Ping) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'EnrSeq' + hh.PutUint64(p.EnrSeq) + + // Field (1) 'CustomPayload' + { + elemIndx := hh.Index() + byteLen := uint64(len(p.CustomPayload)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(p.CustomPayload) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Ping object +func (p *Ping) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(p) +} + +// MarshalSSZ ssz marshals the FindNodes object +func (f *FindNodes) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(f) +} + +// MarshalSSZTo ssz marshals the FindNodes object to a target array +func (f *FindNodes) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(4) + + // Offset (0) 'Distances' + dst = ssz.WriteOffset(dst, offset) + offset += len(f.Distances) * 2 + + // Field (0) 'Distances' + if size := len(f.Distances); size > 256 { + err = ssz.ErrListTooBigFn("FindNodes.Distances", size, 256) + return + } + for ii := 0; ii < len(f.Distances); ii++ { + dst = append(dst, f.Distances[ii][:]...) + } + + return +} + +// UnmarshalSSZ ssz unmarshals the FindNodes object +func (f *FindNodes) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 4 { + return ssz.ErrSize + } + + tail := buf + var o0 uint64 + + // Offset (0) 'Distances' + if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { + return ssz.ErrOffset + } + + if o0 < 4 { + return ssz.ErrInvalidVariableOffset + } + + // Field (0) 'Distances' + { + buf = tail[o0:] + num, err := ssz.DivideInt2(len(buf), 2, 256) + if err != nil { + return err + } + f.Distances = make([][2]byte, num) + for ii := 0; ii < num; ii++ { + copy(f.Distances[ii][:], buf[ii*2:(ii+1)*2]) + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the FindNodes object +func (f *FindNodes) SizeSSZ() (size int) { + size = 4 + + // Field (0) 'Distances' + size += len(f.Distances) * 2 + + return +} + +// HashTreeRoot ssz hashes the FindNodes object +func (f *FindNodes) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(f) +} + +// HashTreeRootWith ssz hashes the FindNodes object with a hasher +func (f *FindNodes) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Distances' + { + if size := len(f.Distances); size > 256 { + err = ssz.ErrListTooBigFn("FindNodes.Distances", size, 256) + return + } + subIndx := hh.Index() + for _, i := range f.Distances { + hh.PutBytes(i[:]) + } + numItems := uint64(len(f.Distances)) + hh.MerkleizeWithMixin(subIndx, numItems, 256) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the FindNodes object +func (f *FindNodes) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(f) +} + +// MarshalSSZ ssz marshals the FindContent object +func (f *FindContent) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(f) +} + +// MarshalSSZTo ssz marshals the FindContent object to a target array +func (f *FindContent) MarshalSSZTo(buf 
[]byte) (dst []byte, err error) { + dst = buf + offset := int(4) + + // Offset (0) 'ContentKey' + dst = ssz.WriteOffset(dst, offset) + offset += len(f.ContentKey) + + // Field (0) 'ContentKey' + if size := len(f.ContentKey); size > 2048 { + err = ssz.ErrBytesLengthFn("FindContent.ContentKey", size, 2048) + return + } + dst = append(dst, f.ContentKey...) + + return +} + +// UnmarshalSSZ ssz unmarshals the FindContent object +func (f *FindContent) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 4 { + return ssz.ErrSize + } + + tail := buf + var o0 uint64 + + // Offset (0) 'ContentKey' + if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { + return ssz.ErrOffset + } + + if o0 < 4 { + return ssz.ErrInvalidVariableOffset + } + + // Field (0) 'ContentKey' + { + buf = tail[o0:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(f.ContentKey) == 0 { + f.ContentKey = make([]byte, 0, len(buf)) + } + f.ContentKey = append(f.ContentKey, buf...) + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the FindContent object +func (f *FindContent) SizeSSZ() (size int) { + size = 4 + + // Field (0) 'ContentKey' + size += len(f.ContentKey) + + return +} + +// HashTreeRoot ssz hashes the FindContent object +func (f *FindContent) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(f) +} + +// HashTreeRootWith ssz hashes the FindContent object with a hasher +func (f *FindContent) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'ContentKey' + { + elemIndx := hh.Index() + byteLen := uint64(len(f.ContentKey)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(f.ContentKey) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the FindContent object +func (f *FindContent) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(f) +} + +// MarshalSSZ ssz marshals the Offer object +func (o *Offer) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(o) +} + +// MarshalSSZTo ssz marshals the Offer object to a target array +func (o *Offer) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(4) + + // Offset (0) 'ContentKeys' + dst = ssz.WriteOffset(dst, offset) + for ii := 0; ii < len(o.ContentKeys); ii++ { + offset += 4 + offset += len(o.ContentKeys[ii]) + } + + // Field (0) 'ContentKeys' + if size := len(o.ContentKeys); size > 64 { + err = ssz.ErrListTooBigFn("Offer.ContentKeys", size, 64) + return + } + { + offset = 4 * len(o.ContentKeys) + for ii := 0; ii < len(o.ContentKeys); ii++ { + dst = ssz.WriteOffset(dst, offset) + offset += len(o.ContentKeys[ii]) + } + } + for ii := 0; ii < len(o.ContentKeys); ii++ { + if size := len(o.ContentKeys[ii]); size > 2048 { + err = ssz.ErrBytesLengthFn("Offer.ContentKeys[ii]", size, 2048) + return + } + dst = append(dst, o.ContentKeys[ii]...) 
+ } + + return +} + +// UnmarshalSSZ ssz unmarshals the Offer object +func (o *Offer) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 4 { + return ssz.ErrSize + } + + tail := buf + var o0 uint64 + + // Offset (0) 'ContentKeys' + if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { + return ssz.ErrOffset + } + + if o0 < 4 { + return ssz.ErrInvalidVariableOffset + } + + // Field (0) 'ContentKeys' + { + buf = tail[o0:] + num, err := ssz.DecodeDynamicLength(buf, 64) + if err != nil { + return err + } + o.ContentKeys = make([][]byte, num) + err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(o.ContentKeys[indx]) == 0 { + o.ContentKeys[indx] = make([]byte, 0, len(buf)) + } + o.ContentKeys[indx] = append(o.ContentKeys[indx], buf...) + return nil + }) + if err != nil { + return err + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Offer object +func (o *Offer) SizeSSZ() (size int) { + size = 4 + + // Field (0) 'ContentKeys' + for ii := 0; ii < len(o.ContentKeys); ii++ { + size += 4 + size += len(o.ContentKeys[ii]) + } + + return +} + +// HashTreeRoot ssz hashes the Offer object +func (o *Offer) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(o) +} + +// HashTreeRootWith ssz hashes the Offer object with a hasher +func (o *Offer) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'ContentKeys' + { + subIndx := hh.Index() + num := uint64(len(o.ContentKeys)) + if num > 64 { + err = ssz.ErrIncorrectListSize + return + } + for _, elem := range o.ContentKeys { + { + elemIndx := hh.Index() + byteLen := uint64(len(elem)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.AppendBytes32(elem) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + } + hh.MerkleizeWithMixin(subIndx, num, 64) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Offer object +func (o *Offer) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(o) +} + +// MarshalSSZ ssz marshals the Pong object +func (p *Pong) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(p) +} + +// MarshalSSZTo ssz marshals the Pong object to a target array +func (p *Pong) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(12) + + // Field (0) 'EnrSeq' + dst = ssz.MarshalUint64(dst, p.EnrSeq) + + // Offset (1) 'CustomPayload' + dst = ssz.WriteOffset(dst, offset) + offset += len(p.CustomPayload) + + // Field (1) 'CustomPayload' + if size := len(p.CustomPayload); size > 2048 { + err = ssz.ErrBytesLengthFn("Pong.CustomPayload", size, 2048) + return + } + dst = append(dst, p.CustomPayload...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Pong object +func (p *Pong) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 12 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'EnrSeq' + p.EnrSeq = ssz.UnmarshallUint64(buf[0:8]) + + // Offset (1) 'CustomPayload' + if o1 = ssz.ReadOffset(buf[8:12]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 12 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'CustomPayload' + { + buf = tail[o1:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(p.CustomPayload) == 0 { + p.CustomPayload = make([]byte, 0, len(buf)) + } + p.CustomPayload = append(p.CustomPayload, buf...) 
+ } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Pong object +func (p *Pong) SizeSSZ() (size int) { + size = 12 + + // Field (1) 'CustomPayload' + size += len(p.CustomPayload) + + return +} + +// HashTreeRoot ssz hashes the Pong object +func (p *Pong) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(p) +} + +// HashTreeRootWith ssz hashes the Pong object with a hasher +func (p *Pong) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'EnrSeq' + hh.PutUint64(p.EnrSeq) + + // Field (1) 'CustomPayload' + { + elemIndx := hh.Index() + byteLen := uint64(len(p.CustomPayload)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(p.CustomPayload) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Pong object +func (p *Pong) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(p) +} + +// MarshalSSZ ssz marshals the Nodes object +func (n *Nodes) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(n) +} + +// MarshalSSZTo ssz marshals the Nodes object to a target array +func (n *Nodes) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(5) + + // Field (0) 'Total' + dst = ssz.MarshalUint8(dst, n.Total) + + // Offset (1) 'Enrs' + dst = ssz.WriteOffset(dst, offset) + for ii := 0; ii < len(n.Enrs); ii++ { + offset += 4 + offset += len(n.Enrs[ii]) + } + + // Field (1) 'Enrs' + if size := len(n.Enrs); size > 32 { + err = ssz.ErrListTooBigFn("Nodes.Enrs", size, 32) + return + } + { + offset = 4 * len(n.Enrs) + for ii := 0; ii < len(n.Enrs); ii++ { + dst = ssz.WriteOffset(dst, offset) + offset += len(n.Enrs[ii]) + } + } + for ii := 0; ii < len(n.Enrs); ii++ { + if size := len(n.Enrs[ii]); size > 2048 { + err = ssz.ErrBytesLengthFn("Nodes.Enrs[ii]", size, 2048) + return + } + dst = append(dst, n.Enrs[ii]...) + } + + return +} + +// UnmarshalSSZ ssz unmarshals the Nodes object +func (n *Nodes) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 5 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'Total' + n.Total = ssz.UnmarshallUint8(buf[0:1]) + + // Offset (1) 'Enrs' + if o1 = ssz.ReadOffset(buf[1:5]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 5 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'Enrs' + { + buf = tail[o1:] + num, err := ssz.DecodeDynamicLength(buf, 32) + if err != nil { + return err + } + n.Enrs = make([][]byte, num) + err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(n.Enrs[indx]) == 0 { + n.Enrs[indx] = make([]byte, 0, len(buf)) + } + n.Enrs[indx] = append(n.Enrs[indx], buf...) 
+ return nil + }) + if err != nil { + return err + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Nodes object +func (n *Nodes) SizeSSZ() (size int) { + size = 5 + + // Field (1) 'Enrs' + for ii := 0; ii < len(n.Enrs); ii++ { + size += 4 + size += len(n.Enrs[ii]) + } + + return +} + +// HashTreeRoot ssz hashes the Nodes object +func (n *Nodes) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(n) +} + +// HashTreeRootWith ssz hashes the Nodes object with a hasher +func (n *Nodes) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Total' + hh.PutUint8(n.Total) + + // Field (1) 'Enrs' + { + subIndx := hh.Index() + num := uint64(len(n.Enrs)) + if num > 32 { + err = ssz.ErrIncorrectListSize + return + } + for _, elem := range n.Enrs { + { + elemIndx := hh.Index() + byteLen := uint64(len(elem)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.AppendBytes32(elem) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + } + hh.MerkleizeWithMixin(subIndx, num, 32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Nodes object +func (n *Nodes) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(n) +} + +// MarshalSSZ ssz marshals the ConnectionId object +func (c *ConnectionId) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(c) +} + +// MarshalSSZTo ssz marshals the ConnectionId object to a target array +func (c *ConnectionId) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Id' + if size := len(c.Id); size != 2 { + err = ssz.ErrBytesLengthFn("ConnectionId.Id", size, 2) + return + } + dst = append(dst, c.Id...) + + return +} + +// UnmarshalSSZ ssz unmarshals the ConnectionId object +func (c *ConnectionId) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 2 { + return ssz.ErrSize + } + + // Field (0) 'Id' + if cap(c.Id) == 0 { + c.Id = make([]byte, 0, len(buf[0:2])) + } + c.Id = append(c.Id, buf[0:2]...) + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the ConnectionId object +func (c *ConnectionId) SizeSSZ() (size int) { + size = 2 + return +} + +// HashTreeRoot ssz hashes the ConnectionId object +func (c *ConnectionId) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(c) +} + +// HashTreeRootWith ssz hashes the ConnectionId object with a hasher +func (c *ConnectionId) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Id' + if size := len(c.Id); size != 2 { + err = ssz.ErrBytesLengthFn("ConnectionId.Id", size, 2) + return + } + hh.PutBytes(c.Id) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the ConnectionId object +func (c *ConnectionId) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(c) +} + +// MarshalSSZ ssz marshals the Accept object +func (a *Accept) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(a) +} + +// MarshalSSZTo ssz marshals the Accept object to a target array +func (a *Accept) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(6) + + // Field (0) 'ConnectionId' + if size := len(a.ConnectionId); size != 2 { + err = ssz.ErrBytesLengthFn("Accept.ConnectionId", size, 2) + return + } + dst = append(dst, a.ConnectionId...) 
+ + // Offset (1) 'ContentKeys' + dst = ssz.WriteOffset(dst, offset) + offset += len(a.ContentKeys) + + // Field (1) 'ContentKeys' + if size := len(a.ContentKeys); size > 64 { + err = ssz.ErrBytesLengthFn("Accept.ContentKeys", size, 64) + return + } + dst = append(dst, a.ContentKeys...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Accept object +func (a *Accept) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 6 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'ConnectionId' + if cap(a.ConnectionId) == 0 { + a.ConnectionId = make([]byte, 0, len(buf[0:2])) + } + a.ConnectionId = append(a.ConnectionId, buf[0:2]...) + + // Offset (1) 'ContentKeys' + if o1 = ssz.ReadOffset(buf[2:6]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 6 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'ContentKeys' + { + buf = tail[o1:] + if err = ssz.ValidateBitlist(buf, 64); err != nil { + return err + } + if cap(a.ContentKeys) == 0 { + a.ContentKeys = make([]byte, 0, len(buf)) + } + a.ContentKeys = append(a.ContentKeys, buf...) + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Accept object +func (a *Accept) SizeSSZ() (size int) { + size = 6 + + // Field (1) 'ContentKeys' + size += len(a.ContentKeys) + + return +} + +// HashTreeRoot ssz hashes the Accept object +func (a *Accept) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(a) +} + +// HashTreeRootWith ssz hashes the Accept object with a hasher +func (a *Accept) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'ConnectionId' + if size := len(a.ConnectionId); size != 2 { + err = ssz.ErrBytesLengthFn("Accept.ConnectionId", size, 2) + return + } + hh.PutBytes(a.ConnectionId) + + // Field (1) 'ContentKeys' + if len(a.ContentKeys) == 0 { + err = ssz.ErrEmptyBitlist + return + } + hh.PutBitlist(a.ContentKeys, 64) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Accept object +func (a *Accept) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(a) +} diff --git a/portalnetwork/portalwire/messages_test.go b/portalnetwork/portalwire/messages_test.go new file mode 100644 index 000000000000..9e266cf41789 --- /dev/null +++ b/portalnetwork/portalwire/messages_test.go @@ -0,0 +1,212 @@ +package portalwire + +import ( + "fmt" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" + ssz "github.com/ferranbt/fastssz" + "github.com/holiman/uint256" + "github.com/prysmaticlabs/go-bitfield" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var maxUint256 = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + +// https://github.com/ethereum/portal-network-specs/blob/master/portal-wire-test-vectors.md +// we remove the message type here +func TestPingMessage(t *testing.T) { + dataRadius := maxUint256.Sub(maxUint256, uint256.NewInt(1)) + reverseBytes, err := dataRadius.MarshalSSZ() + require.NoError(t, err) + customData := &PingPongCustomData{ + Radius: reverseBytes, + } + dataBytes, err := customData.MarshalSSZ() + assert.NoError(t, err) + ping := &Ping{ + EnrSeq: 1, + CustomPayload: dataBytes, + } + + expected := "0x01000000000000000c000000feffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + + data, err := ping.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} 
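
As a cross-check of the expected vector above: an SSZ container with one fixed-size field and one variable-size field is laid out as the fixed bytes, then a 4-byte little-endian offset, then the variable bytes. For Ping that is 8 bytes of little-endian EnrSeq plus the 4-byte offset (always 12 here), which is also why Ping.UnmarshalSSZ and Pong.UnmarshalSSZ earlier in this patch reject buffers shorter than 12 bytes. Below is a minimal standalone sketch, not part of this patch, that decodes the test vector by hand using only the standard library:

	package main

	import (
		"encoding/binary"
		"encoding/hex"
		"fmt"
	)

	func main() {
		// Expected Ping encoding from TestPingMessage, with the leading
		// message-type byte already stripped, as in the test above.
		raw, err := hex.DecodeString("01000000000000000c000000feffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
		if err != nil {
			panic(err)
		}
		enrSeq := binary.LittleEndian.Uint64(raw[0:8])  // fixed field: 1
		offset := binary.LittleEndian.Uint32(raw[8:12]) // offset of CustomPayload: 12
		payload := raw[offset:]                         // 32-byte little-endian radius (max uint256 minus 1)
		fmt.Println(enrSeq, offset, len(payload))       // prints: 1 12 32
	}
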
+ +func TestPongMessage(t *testing.T) { + dataRadius := maxUint256.Div(maxUint256, uint256.NewInt(2)) + reverseBytes, err := dataRadius.MarshalSSZ() + require.NoError(t, err) + customData := &PingPongCustomData{ + Radius: reverseBytes, + } + + dataBytes, err := customData.MarshalSSZ() + assert.NoError(t, err) + pong := &Pong{ + EnrSeq: 1, + CustomPayload: dataBytes, + } + + expected := "0x01000000000000000c000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" + + data, err := pong.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} + +func TestFindNodesMessage(t *testing.T) { + distances := []uint16{256, 255} + + distancesBytes := make([][2]byte, len(distances)) + for i, distance := range distances { + copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), distance)) + } + + findNode := &FindNodes{ + Distances: distancesBytes, + } + + data, err := findNode.MarshalSSZ() + expected := "0x040000000001ff00" + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} + +func TestNodes(t *testing.T) { + enrs := []string{ + "enr:-HW4QBzimRxkmT18hMKaAL3IcZF1UcfTMPyi3Q1pxwZZbcZVRI8DC5infUAB_UauARLOJtYTxaagKoGmIjzQxO2qUygBgmlkgnY0iXNlY3AyNTZrMaEDymNMrg1JrLQB2KTGtv6MVbcNEVv0AHacwUAPMljNMTg", + "enr:-HW4QNfxw543Ypf4HXKXdYxkyzfcxcO-6p9X986WldfVpnVTQX1xlTnWrktEWUbeTZnmgOuAY_KUhbVV1Ft98WoYUBMBgmlkgnY0iXNlY3AyNTZrMaEDDiy3QkHAxPyOgWbxp5oF1bDdlYE6dLCUUp8xfVw50jU", + } + + enrsBytes := make([][]byte, 0) + for _, enr := range enrs { + n, err := enode.Parse(enode.ValidSchemes, enr) + assert.NoError(t, err) + + enrBytes, err := rlp.EncodeToBytes(n.Record()) + assert.NoError(t, err) + enrsBytes = append(enrsBytes, enrBytes) + } + + testCases := []struct { + name string + input [][]byte + expected string + }{ + { + name: "empty nodes", + input: make([][]byte, 0), + expected: "0x0105000000", + }, + { + name: "two nodes", + input: enrsBytes, + expected: "0x0105000000080000007f000000f875b8401ce2991c64993d7c84c29a00bdc871917551c7d330fca2dd0d69c706596dc655448f030b98a77d4001fd46ae0112ce26d613c5a6a02a81a6223cd0c4edaa53280182696482763489736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138f875b840d7f1c39e376297f81d7297758c64cb37dcc5c3beea9f57f7ce9695d7d5a67553417d719539d6ae4b445946de4d99e680eb8063f29485b555d45b7df16a1850130182696482763489736563703235366b31a1030e2cb74241c0c4fc8e8166f1a79a05d5b0dd95813a74b094529f317d5c39d235", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + nodes := &Nodes{ + Total: 1, + Enrs: test.input, + } + + data, err := nodes.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, test.expected, fmt.Sprintf("0x%x", data)) + }) + } +} + +func TestContent(t *testing.T) { + contentKey := "0x706f7274616c" + + content := &FindContent{ + ContentKey: hexutil.MustDecode(contentKey), + } + expected := "0x04000000706f7274616c" + data, err := content.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + expected = "0x7468652063616b652069732061206c6965" + + contentRes := &Content{ + Content: hexutil.MustDecode("0x7468652063616b652069732061206c6965"), + } + + data, err = contentRes.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + expectData := &Content{} + err = expectData.UnmarshalSSZ(data) + assert.NoError(t, err) + assert.Equal(t, contentRes.Content, expectData.Content) + + enrs := []string{ + 
"enr:-HW4QBzimRxkmT18hMKaAL3IcZF1UcfTMPyi3Q1pxwZZbcZVRI8DC5infUAB_UauARLOJtYTxaagKoGmIjzQxO2qUygBgmlkgnY0iXNlY3AyNTZrMaEDymNMrg1JrLQB2KTGtv6MVbcNEVv0AHacwUAPMljNMTg", + "enr:-HW4QNfxw543Ypf4HXKXdYxkyzfcxcO-6p9X986WldfVpnVTQX1xlTnWrktEWUbeTZnmgOuAY_KUhbVV1Ft98WoYUBMBgmlkgnY0iXNlY3AyNTZrMaEDDiy3QkHAxPyOgWbxp5oF1bDdlYE6dLCUUp8xfVw50jU", + } + + enrsBytes := make([][]byte, 0) + for _, enr := range enrs { + n, err := enode.Parse(enode.ValidSchemes, enr) + assert.NoError(t, err) + + enrBytes, err := rlp.EncodeToBytes(n.Record()) + assert.NoError(t, err) + enrsBytes = append(enrsBytes, enrBytes) + } + + enrsRes := &Enrs{ + Enrs: enrsBytes, + } + + expected = "0x080000007f000000f875b8401ce2991c64993d7c84c29a00bdc871917551c7d330fca2dd0d69c706596dc655448f030b98a77d4001fd46ae0112ce26d613c5a6a02a81a6223cd0c4edaa53280182696482763489736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138f875b840d7f1c39e376297f81d7297758c64cb37dcc5c3beea9f57f7ce9695d7d5a67553417d719539d6ae4b445946de4d99e680eb8063f29485b555d45b7df16a1850130182696482763489736563703235366b31a1030e2cb74241c0c4fc8e8166f1a79a05d5b0dd95813a74b094529f317d5c39d235" + + data, err = enrsRes.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + expectEnrs := &Enrs{} + err = expectEnrs.UnmarshalSSZ(data) + assert.NoError(t, err) + assert.Equal(t, expectEnrs.Enrs, enrsRes.Enrs) +} + +func TestOfferAndAcceptMessage(t *testing.T) { + contentKey := "0x010203" + contentBytes := hexutil.MustDecode(contentKey) + contentKeys := [][]byte{contentBytes} + offer := &Offer{ + ContentKeys: contentKeys, + } + + expected := "0x0400000004000000010203" + + data, err := offer.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + contentKeyBitlist := bitfield.NewBitlist(8) + contentKeyBitlist.SetBitAt(0, true) + accept := &Accept{ + ConnectionId: []byte{0x01, 0x02}, + ContentKeys: contentKeyBitlist, + } + + expected = "0x0102060000000101" + + data, err = accept.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} From b1ae92d3cf194712171d3983378c393361c88574 Mon Sep 17 00:00:00 2001 From: Chen Kai <281165273grape@gmail.com> Date: Wed, 20 Nov 2024 22:19:37 +0800 Subject: [PATCH 02/13] feat:make portal wire out of p2p package Signed-off-by: Chen Kai <281165273grape@gmail.com> --- cmd/devp2p/discv4cmd.go | 2 +- cmd/devp2p/discv5cmd.go | 2 +- p2p/discover/api.go | 12 +- p2p/discover/lookup.go | 48 +- p2p/discover/node.go | 26 +- p2p/discover/portal_protocol.go | 78 +- p2p/discover/portal_utp.go | 2 +- p2p/discover/table.go | 87 +- p2p/discover/table_reval.go | 2 +- p2p/discover/table_reval_test.go | 2 +- p2p/discover/table_test.go | 70 +- p2p/discover/table_util_test.go | 18 +- p2p/discover/v4_lookup_test.go | 4 +- p2p/discover/v4_udp.go | 52 +- p2p/discover/v4_udp_test.go | 20 +- p2p/discover/v5_udp.go | 88 +- p2p/discover/v5_udp_test.go | 32 +- p2p/discover/v5wire/encoding.go | 4 +- portalnetwork/api.go | 543 +++++ portalnetwork/nat.go | 172 ++ portalnetwork/portal_protocol.go | 1918 +++++++++++++++++ portalnetwork/portal_protocol_metrics.go | 67 + portalnetwork/portal_protocol_test.go | 514 +++++ portalnetwork/portal_utp.go | 139 ++ portalnetwork/portalwire/messages.go | 336 +++ portalnetwork/portalwire/messages_encoding.go | 957 ++++++++ portalnetwork/portalwire/messages_test.go | 212 ++ 27 files changed, 5149 insertions(+), 258 deletions(-) create mode 100644 portalnetwork/api.go create mode 100644 
portalnetwork/nat.go create mode 100644 portalnetwork/portal_protocol.go create mode 100644 portalnetwork/portal_protocol_metrics.go create mode 100644 portalnetwork/portal_protocol_test.go create mode 100644 portalnetwork/portal_utp.go create mode 100644 portalnetwork/portalwire/messages.go create mode 100644 portalnetwork/portalwire/messages_encoding.go create mode 100644 portalnetwork/portalwire/messages_test.go diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 8c48b3a557c1..0c832262a67c 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -163,7 +163,7 @@ func discv4Ping(ctx *cli.Context) error { defer disc.Close() start := time.Now() - if err := disc.Ping(n); err != nil { + if err := disc.PingWithoutResp(n); err != nil { return fmt.Errorf("node didn't respond: %v", err) } fmt.Printf("node responded to ping (RTT %v).\n", time.Since(start)) diff --git a/cmd/devp2p/discv5cmd.go b/cmd/devp2p/discv5cmd.go index 2422ef6644c9..b8a02b560acb 100644 --- a/cmd/devp2p/discv5cmd.go +++ b/cmd/devp2p/discv5cmd.go @@ -84,7 +84,7 @@ func discv5Ping(ctx *cli.Context) error { disc, _ := startV5(ctx) defer disc.Close() - fmt.Println(disc.Ping(n)) + fmt.Println(disc.PingWithoutResp(n)) return nil } diff --git a/p2p/discover/api.go b/p2p/discover/api.go index 4915fa688e2a..e7fe5c764ba3 100644 --- a/p2p/discover/api.go +++ b/p2p/discover/api.go @@ -114,7 +114,7 @@ func (d *DiscV5API) GetEnr(nodeId string) (bool, error) { if err != nil { return false, err } - n := d.DiscV5.tab.getNode(id) + n := d.DiscV5.tab.GetNode(id) if n == nil { return false, errors.New("record not in local routing table") } @@ -128,7 +128,7 @@ func (d *DiscV5API) DeleteEnr(nodeId string) (bool, error) { return false, err } - n := d.DiscV5.tab.getNode(id) + n := d.DiscV5.tab.GetNode(id) if n == nil { return false, errors.New("record not in local routing table") } @@ -161,7 +161,7 @@ func (d *DiscV5API) Ping(enr string) (*DiscV5PongResp, error) { return nil, err } - pong, err := d.DiscV5.pingInner(n) + pong, err := d.DiscV5.PingWithResp(n) if err != nil { return nil, err } @@ -178,7 +178,7 @@ func (d *DiscV5API) FindNodes(enr string, distances []uint) ([]string, error) { if err != nil { return nil, err } - findNodes, err := d.DiscV5.findnode(n, distances) + findNodes, err := d.DiscV5.Findnode(n, distances) if err != nil { return nil, err } @@ -283,7 +283,7 @@ func (p *PortalProtocolAPI) GetEnr(nodeId string) (string, error) { return p.portalProtocol.localNode.Node().String(), nil } - n := p.portalProtocol.table.getNode(id) + n := p.portalProtocol.table.GetNode(id) if n == nil { return "", errors.New("record not in local routing table") } @@ -297,7 +297,7 @@ func (p *PortalProtocolAPI) DeleteEnr(nodeId string) (bool, error) { return false, err } - n := p.portalProtocol.table.getNode(id) + n := p.portalProtocol.table.GetNode(id) if n == nil { return false, nil } diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index 09808b71e079..86e606ac5c79 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -24,30 +24,30 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" ) -// lookup performs a network search for nodes close to the given target. It approaches the +// Lookup performs a network search for nodes close to the given target. It approaches the // target by querying nodes that are closer to it on each iteration. The given target does // not need to be an actual node identifier. 
-type lookup struct { +type Lookup struct { tab *Table queryfunc queryFunc replyCh chan []*enode.Node cancelCh <-chan struct{} asked, seen map[enode.ID]bool - result nodesByDistance + result NodesByDistance replyBuffer []*enode.Node queries int } type queryFunc func(*enode.Node) ([]*enode.Node, error) -func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { - it := &lookup{ +func NewLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *Lookup { + it := &Lookup{ tab: tab, queryfunc: q, asked: make(map[enode.ID]bool), seen: make(map[enode.ID]bool), - result: nodesByDistance{target: target}, - replyCh: make(chan []*enode.Node, alpha), + result: NodesByDistance{Target: target}, + replyCh: make(chan []*enode.Node, Alpha), cancelCh: ctx.Done(), queries: -1, } @@ -57,16 +57,16 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l return it } -// run runs the lookup to completion and returns the closest nodes found. -func (it *lookup) run() []*enode.Node { +// Run runs the lookup to completion and returns the closest nodes found. +func (it *Lookup) Run() []*enode.Node { for it.advance() { } - return it.result.entries + return it.result.Entries } // advance advances the lookup until any new nodes have been found. // It returns false when the lookup has ended. -func (it *lookup) advance() bool { +func (it *Lookup) advance() bool { for it.startQueries() { select { case nodes := <-it.replyCh: @@ -74,7 +74,7 @@ func (it *lookup) advance() bool { for _, n := range nodes { if n != nil && !it.seen[n.ID()] { it.seen[n.ID()] = true - it.result.push(n, bucketSize) + it.result.Push(n, BucketSize) it.replyBuffer = append(it.replyBuffer, n) } } @@ -89,7 +89,7 @@ func (it *lookup) advance() bool { return false } -func (it *lookup) shutdown() { +func (it *Lookup) shutdown() { for it.queries > 0 { <-it.replyCh it.queries-- @@ -98,28 +98,28 @@ func (it *lookup) shutdown() { it.replyBuffer = nil } -func (it *lookup) startQueries() bool { +func (it *Lookup) startQueries() bool { if it.queryfunc == nil { return false } // The first query returns nodes from the local table. if it.queries == -1 { - closest := it.tab.findnodeByID(it.result.target, bucketSize, false) + closest := it.tab.FindnodeByID(it.result.Target, BucketSize, false) // Avoid finishing the lookup too quickly if table is empty. It'd be better to wait // for the table to fill in this case, but there is no good mechanism for that // yet. - if len(closest.entries) == 0 { + if len(closest.Entries) == 0 { it.slowdown() } it.queries = 1 - it.replyCh <- closest.entries + it.replyCh <- closest.Entries return true } // Ask the closest nodes that we haven't asked yet. - for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ { - n := it.result.entries[i] + for i := 0; i < len(it.result.Entries) && it.queries < Alpha; i++ { + n := it.result.Entries[i] if !it.asked[n.ID()] { it.asked[n.ID()] = true it.queries++ @@ -130,7 +130,7 @@ func (it *lookup) startQueries() bool { return it.queries > 0 } -func (it *lookup) slowdown() { +func (it *Lookup) slowdown() { sleep := time.NewTimer(1 * time.Second) defer sleep.Stop() select { @@ -139,9 +139,9 @@ func (it *lookup) slowdown() { } } -func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) { +func (it *Lookup) query(n *enode.Node, reply chan<- []*enode.Node) { r, err := it.queryfunc(n) - if !errors.Is(err, errClosed) { // avoid recording failures on shutdown. 
+ if !errors.Is(err, ErrClosed) { // avoid recording failures on shutdown. success := len(r) > 0 it.tab.trackRequest(n, success, r) if err != nil { @@ -158,10 +158,10 @@ type lookupIterator struct { nextLookup lookupFunc ctx context.Context cancel func() - lookup *lookup + lookup *Lookup } -type lookupFunc func(ctx context.Context) *lookup +type lookupFunc func(ctx context.Context) *Lookup func newLookupIterator(ctx context.Context, next lookupFunc) *lookupIterator { ctx, cancel := context.WithCancel(ctx) diff --git a/p2p/discover/node.go b/p2p/discover/node.go index ac34b7c5b2ea..8b6ec83c0376 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -54,27 +54,27 @@ func (n *tableNode) String() string { return n.Node.String() } -// nodesByDistance is a list of nodes, ordered by distance to target. -type nodesByDistance struct { - entries []*enode.Node - target enode.ID +// NodesByDistance is a list of nodes, ordered by distance to target. +type NodesByDistance struct { + Entries []*enode.Node + Target enode.ID } -// push adds the given node to the list, keeping the total size below maxElems. -func (h *nodesByDistance) push(n *enode.Node, maxElems int) { - ix := sort.Search(len(h.entries), func(i int) bool { - return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 +// Push adds the given node to the list, keeping the total size below maxElems. +func (h *NodesByDistance) Push(n *enode.Node, maxElems int) { + ix := sort.Search(len(h.Entries), func(i int) bool { + return enode.DistCmp(h.Target, h.Entries[i].ID(), n.ID()) > 0 }) - end := len(h.entries) - if len(h.entries) < maxElems { - h.entries = append(h.entries, n) + end := len(h.Entries) + if len(h.Entries) < maxElems { + h.Entries = append(h.Entries, n) } if ix < end { // Slide existing entries down to make room. // This will overwrite the entry we just appended. 
- copy(h.entries[ix+1:], h.entries[ix:]) - h.entries[ix] = n + copy(h.Entries[ix+1:], h.Entries[ix:]) + h.Entries[ix] = n } } diff --git a/p2p/discover/portal_protocol.go b/p2p/discover/portal_protocol.go index b1be63233c4e..8e2129854e73 100644 --- a/p2p/discover/portal_protocol.go +++ b/p2p/discover/portal_protocol.go @@ -255,7 +255,7 @@ func (p *PortalProtocol) Start() error { return err } - go p.table.loop() + go p.table.Loop() for i := 0; i < concurrentOffers; i++ { go p.offerWorker() @@ -269,7 +269,7 @@ func (p *PortalProtocol) Start() error { func (p *PortalProtocol) Stop() { p.cancelCloseCtx() - p.table.close() + p.table.Close() p.DiscV5.Close() if p.Utp != nil { p.Utp.Stop() @@ -335,7 +335,7 @@ func (p *PortalProtocol) setupDiscV5AndTable() error { Log: p.Log, } - p.table, err = newTable(p, p.localNode.Database(), cfg) + p.table, err = NewTable(p, p.localNode.Database(), cfg) if err != nil { return err } @@ -343,7 +343,7 @@ func (p *PortalProtocol) setupDiscV5AndTable() error { return nil } -func (p *PortalProtocol) ping(node *enode.Node) (uint64, error) { +func (p *PortalProtocol) Ping(node *enode.Node) (uint64, error) { pong, err := p.pingInner(node) if err != nil { return 0, err @@ -515,7 +515,7 @@ func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request * if metrics.Enabled { p.portalMetrics.messagesReceivedAccept.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -651,7 +651,7 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, if metrics.Enabled { p.portalMetrics.messagesReceivedContent.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -669,7 +669,7 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, if metrics.Enabled { p.portalMetrics.messagesReceivedContent.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -729,7 +729,7 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, if metrics.Enabled { p.portalMetrics.messagesReceivedContent.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -757,7 +757,7 @@ func (p *PortalProtocol) processNodes(target *enode.Node, resp []byte, distances return nil, err } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -828,7 +828,7 @@ func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwi if metrics.Enabled { p.portalMetrics.messagesReceivedPong.Mark(1) } - isAdded := p.table.addFoundNode(target, true) + isAdded := p.table.AddFoundNode(target, true) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } else { @@ -840,8 +840,8 @@ func (p 
*PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwi } func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte { - if n := p.DiscV5.getNode(id); n != nil { - p.table.addInboundNode(n) + if n := p.DiscV5.GetNode(id); n != nil { + p.table.AddInboundNode(n) } msgCode := msg[0] @@ -1377,7 +1377,7 @@ func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, d return nil, errors.New("not contained in netrestrict list") } if n.UDP() <= 1024 { - return nil, errLowPort + return nil, ErrLowPort } if distances != nil { nd := enode.LogDist(sender.ID(), n.ID()) @@ -1394,24 +1394,24 @@ func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, d // lookupRandom looks up a random target. // This is needed to satisfy the transport interface. -func (p *PortalProtocol) lookupRandom() []*enode.Node { - return p.newRandomLookup(p.closeCtx).run() +func (p *PortalProtocol) LookupRandom() []*enode.Node { + return p.newRandomLookup(p.closeCtx).Run() } // lookupSelf looks up our own node ID. // This is needed to satisfy the transport interface. -func (p *PortalProtocol) lookupSelf() []*enode.Node { - return p.newLookup(p.closeCtx, p.Self().ID()).run() +func (p *PortalProtocol) LookupSelf() []*enode.Node { + return p.newLookup(p.closeCtx, p.Self().ID()).Run() } -func (p *PortalProtocol) newRandomLookup(ctx context.Context) *lookup { +func (p *PortalProtocol) newRandomLookup(ctx context.Context) *Lookup { var target enode.ID _, _ = crand.Read(target[:]) return p.newLookup(ctx, target) } -func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *lookup { - return newLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { +func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *Lookup { + return NewLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { return p.lookupWorker(n, target) }) } @@ -1419,28 +1419,28 @@ func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *lookup // lookupWorker performs FINDNODE calls against a single node during lookup. func (p *PortalProtocol) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { var ( - dists = lookupDistances(target, destNode.ID()) - nodes = nodesByDistance{target: target} + dists = LookupDistances(target, destNode.ID()) + nodes = NodesByDistance{Target: target} err error ) var r []*enode.Node r, err = p.findNodes(destNode, dists) - if errors.Is(err, errClosed) { + if errors.Is(err, ErrClosed) { return nil, err } for _, n := range r { if n.ID() != p.Self().ID() { - isAdded := p.table.addFoundNode(n, false) + isAdded := p.table.AddFoundNode(n, false) if isAdded { log.Debug("Node added to bucket", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) } else { log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) } - nodes.push(n, portalFindnodesResultLimit) + nodes.Push(n, portalFindnodesResultLimit) } } - return nodes.entries, err + return nodes.Entries, err } func (p *PortalProtocol) offerWorker() { @@ -1496,13 +1496,13 @@ func (p *PortalProtocol) findNodesCloseToContent(contentId []byte, limit int) [] // Lookup performs a recursive lookup for the given target. // It returns the closest nodes to target. 
func (p *PortalProtocol) Lookup(target enode.ID) []*enode.Node { - return p.newLookup(p.closeCtx, target).run() + return p.newLookup(p.closeCtx, target).Run() } // Resolve searches for a specific Node with the given ID and tries to get the most recent // version of the Node record for it. It returns n if the Node could not be resolved. func (p *PortalProtocol) Resolve(n *enode.Node) *enode.Node { - if intable := p.table.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + if intable := p.table.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { n = intable } // Try asking directly. This works if the Node is still responding on the endpoint we have. @@ -1527,7 +1527,7 @@ func (p *PortalProtocol) ResolveNodeId(id enode.ID) *enode.Node { return p.Self() } - n := p.table.getNode(id) + n := p.table.GetNode(id) if n != nil { p.Log.Debug("found Id in table and will request enr from the node", "id", id.String()) // Try asking directly. This works if the Node is still responding on the endpoint we have. @@ -1564,7 +1564,7 @@ func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit i processed[dist] = struct{}{} checkLive := !p.table.cfg.NoFindnodeLivenessCheck - for _, n := range p.table.appendBucketNodes(dist, bn[:0], checkLive) { + for _, n := range p.table.AppendBucketNodes(dist, bn[:0], checkLive) { // Apply some pre-checks to avoid sending invalid nodes. // Note liveness is checked by appendLiveNodes. if netutil.CheckRelayIP(rip, n.IP()) != nil { @@ -1582,7 +1582,7 @@ func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit i func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bool, error) { lookupContext, cancel := context.WithCancel(context.Background()) - resChan := make(chan *traceContentInfoResp, alpha) + resChan := make(chan *traceContentInfoResp, Alpha) hasResult := int32(0) result := ContentInfoResp{} @@ -1600,9 +1600,9 @@ func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bo } }() - newLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { + NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult) - }).run() + }).Run() close(resChan) wg.Wait() @@ -1616,7 +1616,7 @@ func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bo func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*TraceContentResult, error) { lookupContext, cancel := context.WithCancel(context.Background()) // resp channel - resChan := make(chan *traceContentInfoResp, alpha) + resChan := make(chan *traceContentInfoResp, Alpha) hasResult := int32(0) @@ -1633,10 +1633,10 @@ func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*Trac Cancelled: make([]string, 0), } - nodes := p.table.findnodeByID(enode.ID(contentId), bucketSize, false) + nodes := p.table.FindnodeByID(enode.ID(contentId), BucketSize, false) - localResponse := make([]string, 0, len(nodes.entries)) - for _, node := range nodes.entries { + localResponse := make([]string, 0, len(nodes.Entries)) + for _, node := range nodes.Entries { id := "0x" + node.ID().String() localResponse = append(localResponse, id) } @@ -1698,10 +1698,10 @@ func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*Trac } }() - lookup := newLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { + 
lookup := NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) { return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult) }) - lookup.run() + lookup.Run() close(resChan) wg.Wait() diff --git a/p2p/discover/portal_utp.go b/p2p/discover/portal_utp.go index e8c8e8f74ccd..589bd2bd15fe 100644 --- a/p2p/discover/portal_utp.go +++ b/p2p/discover/portal_utp.go @@ -122,7 +122,7 @@ func (p *PortalUtp) packetRouterFunc(buf []byte, id enode.ID, addr *net.UDPAddr) if n, ok := p.discV5.GetCachedNode(addr.String()); ok { //_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf) req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf} - p.discV5.sendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req) + p.discV5.SendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req) return len(buf), nil } else { diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 0ad7f1bef496..547bf3440986 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -39,8 +39,8 @@ import ( ) const ( - alpha = 3 // Kademlia concurrency factor - bucketSize = 16 // Kademlia bucket size + Alpha = 3 // Kademlia concurrency factor + BucketSize = 16 // Kademlia bucket size maxReplacements = 10 // Size of per-bucket replacement list // We keep buckets for the upper 1/15 of distances because @@ -92,9 +92,9 @@ type Table struct { type transport interface { Self() *enode.Node RequestENR(*enode.Node) (*enode.Node, error) - lookupRandom() []*enode.Node - lookupSelf() []*enode.Node - ping(*enode.Node) (seq uint64, err error) + LookupRandom() []*enode.Node + LookupSelf() []*enode.Node + Ping(*enode.Node) (seq uint64, err error) } // bucket contains nodes, ordered by their last activity. the entry @@ -118,7 +118,7 @@ type trackRequestOp struct { success bool } -func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) { +func NewTable(t transport, db *enode.DB, cfg Config) (*Table, error) { cfg = cfg.withDefaults() tab := &Table{ net: t, @@ -196,8 +196,8 @@ func (tab *Table) self() *enode.Node { return tab.net.Self() } -// getNode returns the node with the given ID or nil if it isn't in the table. -func (tab *Table) getNode(id enode.ID) *enode.Node { +// GetNode returns the node with the given ID or nil if it isn't in the table. +func (tab *Table) GetNode(id enode.ID) *enode.Node { tab.mutex.Lock() defer tab.mutex.Unlock() @@ -210,8 +210,8 @@ func (tab *Table) getNode(id enode.ID) *enode.Node { return nil } -// close terminates the network listener and flushes the node database. -func (tab *Table) close() { +// Close terminates the network listener and flushes the node database. +func (tab *Table) Close() { close(tab.closeReq) <-tab.closed } @@ -255,40 +255,40 @@ func (tab *Table) refresh() <-chan struct{} { return done } -// findnodeByID returns the n nodes in the table that are closest to the given id. +// FindnodeByID returns the n nodes in the table that are closest to the given id. // This is used by the FINDNODE/v4 handler. // // The preferLive parameter says whether the caller wants liveness-checked results. If // preferLive is true and the table contains any verified nodes, the result will not // contain unverified nodes. However, if there are no verified nodes at all, the result // will contain unverified nodes. 
-func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance { +func (tab *Table) FindnodeByID(target enode.ID, nresults int, preferLive bool) *NodesByDistance { tab.mutex.Lock() defer tab.mutex.Unlock() // Scan all buckets. There might be a better way to do this, but there aren't that many // buckets, so this solution should be fine. The worst-case complexity of this loop // is O(tab.len() * nresults). - nodes := &nodesByDistance{target: target} - liveNodes := &nodesByDistance{target: target} + nodes := &NodesByDistance{Target: target} + liveNodes := &NodesByDistance{Target: target} for _, b := range &tab.buckets { for _, n := range b.entries { - nodes.push(n.Node, nresults) + nodes.Push(n.Node, nresults) if preferLive && n.isValidatedLive { - liveNodes.push(n.Node, nresults) + liveNodes.Push(n.Node, nresults) } } } - if preferLive && len(liveNodes.entries) > 0 { + if preferLive && len(liveNodes.Entries) > 0 { return liveNodes } return nodes } -// appendBucketNodes adds nodes at the given distance to the result slice. +// AppendBucketNodes adds nodes at the given distance to the result slice. // This is used by the FINDNODE/v5 handler. -func (tab *Table) appendBucketNodes(dist uint, result []*enode.Node, checkLive bool) []*enode.Node { +func (tab *Table) AppendBucketNodes(dist uint, result []*enode.Node, checkLive bool) []*enode.Node { if dist > 256 { return result } @@ -322,12 +322,12 @@ func (tab *Table) len() (n int) { return n } -// addFoundNode adds a node which may not be live. If the bucket has space available, +// AddFoundNode adds a node which may not be live. If the bucket has space available, // adding the node succeeds immediately. Otherwise, the node is added to the replacements // list. // // The caller must not hold tab.mutex. -func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { +func (tab *Table) AddFoundNode(n *enode.Node, forceSetLive bool) bool { op := addNodeOp{node: n, isInbound: false, forceSetLive: forceSetLive} select { case tab.addNodeCh <- op: @@ -337,7 +337,7 @@ func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { } } -// addInboundNode adds a node from an inbound contact. If the bucket has no space, the +// AddInboundNode adds a node from an inbound contact. If the bucket has no space, the // node is added to the replacements list. // // There is an additional safety measure: if the table is still initializing the node is @@ -345,7 +345,7 @@ func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { // repeatedly. // // The caller must not hold tab.mutex. -func (tab *Table) addInboundNode(n *enode.Node) bool { +func (tab *Table) AddInboundNode(n *enode.Node) bool { op := addNodeOp{node: n, isInbound: true} select { case tab.addNodeCh <- op: @@ -363,8 +363,8 @@ func (tab *Table) trackRequest(n *enode.Node, success bool, foundNodes []*enode. } } -// loop is the main loop of Table. -func (tab *Table) loop() { +// Loop is the main loop of Table. +func (tab *Table) Loop() { var ( refresh = time.NewTimer(tab.nextRefreshTime()) refreshDone = make(chan struct{}) // where doRefresh reports completion @@ -447,7 +447,7 @@ func (tab *Table) doRefresh(done chan struct{}) { tab.loadSeedNodes() // Run self lookup to discover new neighbor nodes. - tab.net.lookupSelf() + tab.net.LookupSelf() // The Kademlia paper specifies that the bucket refresh should // perform a lookup in the least recently used bucket. 
We cannot @@ -456,7 +456,7 @@ func (tab *Table) doRefresh(done chan struct{}) { // sha3 preimage that falls into a chosen bucket. // We perform a few lookups with a random target instead. for i := 0; i < 3; i++ { - tab.net.lookupRandom() + tab.net.LookupRandom() } } @@ -542,7 +542,7 @@ func (tab *Table) handleAddNode(req addNodeOp) bool { tab.log.Debug("the node is already in table", "id", req.node.ID()) return false } - if len(b.entries) >= bucketSize { + if len(b.entries) >= BucketSize { // Bucket full, maybe add as replacement. tab.log.Debug("the bucket is full and will add in replacement", "id", req.node.ID()) tab.addReplacement(b, req.node) @@ -697,7 +697,7 @@ func (tab *Table) handleTrackRequest(op trackRequestOp) { // many times, but only if there are enough other nodes in the bucket. This latter // condition specifically exists to make bootstrapping in smaller test networks more // reliable. - if fails >= maxFindnodeFailures && len(b.entries) >= bucketSize/4 { + if fails >= maxFindnodeFailures && len(b.entries) >= BucketSize/4 { tab.deleteInBucket(b, op.node.ID()) } @@ -717,3 +717,32 @@ func pushNode(list []*tableNode, n *tableNode, max int) ([]*tableNode, *tableNod list[0] = n return list, removed } + +func (tab *Table) WaitInit() { + <-tab.initDone +} + +func (tab *Table) NodeIds() [][]string { + tab.mutex.Lock() + defer tab.mutex.Unlock() + nodes := make([][]string, 0) + for _, b := range &tab.buckets { + bucketNodes := make([]string, 0) + for _, n := range b.entries { + bucketNodes = append(bucketNodes, "0x"+n.ID().String()) + } + nodes = append(nodes, bucketNodes) + } + return nodes +} + +func (tab *Table) Config() Config { + return tab.cfg +} + +func (tab *Table) DeleteNode(n *enode.Node) { + tab.mutex.Lock() + defer tab.mutex.Unlock() + b := tab.bucket(n.ID()) + tab.deleteInBucket(b, n.ID()) +} diff --git a/p2p/discover/table_reval.go b/p2p/discover/table_reval.go index 2465fee9066f..844094cbb80a 100644 --- a/p2p/discover/table_reval.go +++ b/p2p/discover/table_reval.go @@ -111,7 +111,7 @@ func (tr *tableRevalidation) startRequest(tab *Table, n *tableNode) { func (tab *Table) doRevalidate(resp revalidationResponse, node *enode.Node) { // Ping the selected node and wait for a pong response. - remoteSeq, err := tab.net.ping(node) + remoteSeq, err := tab.net.Ping(node) resp.didRespond = err == nil // Also fetch record if the node replied and returned a higher sequence number. diff --git a/p2p/discover/table_reval_test.go b/p2p/discover/table_reval_test.go index 360544393439..16357e42aab0 100644 --- a/p2p/discover/table_reval_test.go +++ b/p2p/discover/table_reval_test.go @@ -63,7 +63,7 @@ func TestRevalidation_nodeRemoved(t *testing.T) { tr.handleResponse(tab, resp) // Ensure the node was not re-added to the table. - if tab.getNode(node.ID()) != nil { + if tab.GetNode(node.ID()) != nil { t.Fatal("node was re-added to Table") } if tr.fast.contains(node.ID()) || tr.slow.contains(node.ID()) { diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 8cc4ae33b2eb..63fa152ffc9d 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -59,7 +59,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding Log: testlog.Logger(t, log.LevelTrace), }) defer db.Close() - defer tab.close() + defer tab.Close() <-tab.initDone @@ -79,7 +79,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding transport.dead[replacementNode.ID()] = !newNodeIsResponding // Add replacement node to table. 
- tab.addFoundNode(replacementNode, false) + tab.AddFoundNode(replacementNode, false) t.Log("last:", last.ID()) t.Log("replacement:", replacementNode.ID()) @@ -108,7 +108,7 @@ func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding // Check bucket content. tab.mutex.Lock() defer tab.mutex.Unlock() - wantSize := bucketSize + wantSize := BucketSize if !lastInBucketIsResponding && !newNodeIsResponding { wantSize-- } @@ -150,11 +150,11 @@ func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport, Config{}) defer db.Close() - defer tab.close() + defer tab.Close() for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) - tab.addFoundNode(n, false) + tab.AddFoundNode(n, false) } if tab.len() > tableIPLimit { t.Errorf("too many nodes in table") @@ -167,12 +167,12 @@ func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport, Config{}) defer db.Close() - defer tab.close() + defer tab.Close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) - tab.addFoundNode(n, false) + tab.AddFoundNode(n, false) } if tab.len() > bucketIPLimit { t.Errorf("too many nodes in table") @@ -204,11 +204,11 @@ func TestTable_findnodeByID(t *testing.T) { transport := newPingRecorder() tab, db := newTestTable(transport, Config{}) defer db.Close() - defer tab.close() + defer tab.Close() fillTable(tab, test.All, true) // check that closest(Target, N) returns nodes - result := tab.findnodeByID(test.Target, test.N, false).entries + result := tab.FindnodeByID(test.Target, test.N, false).Entries if hasDuplicates(result) { t.Errorf("result contains duplicates") return false @@ -264,7 +264,7 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { t := &closeTest{ Self: gen(enode.ID{}, rand).(enode.ID), Target: gen(enode.ID{}, rand).(enode.ID), - N: rand.Intn(bucketSize), + N: rand.Intn(BucketSize), } for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { r := new(enr.Record) @@ -279,20 +279,20 @@ func TestTable_addInboundNode(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addFoundNode(n1, false) - tab.addFoundNode(n2, false) + tab.AddFoundNode(n1, false) + tab.AddFoundNode(n2, false) checkBucketContent(t, tab, []*enode.Node{n1, n2}) // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) n2v2 := enode.SignNull(newrec, n2.ID()) - tab.addInboundNode(n2v2) + tab.AddInboundNode(n2v2) checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) // Try updating n2 without sequence number change. The update is accepted @@ -301,7 +301,7 @@ func TestTable_addInboundNode(t *testing.T) { newrec.Set(enr.IP{100, 100, 100, 100}) newrec.SetSeq(n2.Seq()) n2v3 := enode.SignNull(newrec, n2.ID()) - tab.addInboundNode(n2v3) + tab.AddInboundNode(n2v3) checkBucketContent(t, tab, []*enode.Node{n1, n2v3}) } @@ -309,20 +309,20 @@ func TestTable_addFoundNode(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Insert two nodes. 
n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addFoundNode(n1, false) - tab.addFoundNode(n2, false) + tab.AddFoundNode(n1, false) + tab.AddFoundNode(n2, false) checkBucketContent(t, tab, []*enode.Node{n1, n2}) // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) n2v2 := enode.SignNull(newrec, n2.ID()) - tab.addFoundNode(n2v2, false) + tab.AddFoundNode(n2v2, false) checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) // Try updating n2 without a sequence number change. @@ -331,7 +331,7 @@ func TestTable_addFoundNode(t *testing.T) { newrec.Set(enr.IP{100, 100, 100, 100}) newrec.SetSeq(n2.Seq()) n2v3 := enode.SignNull(newrec, n2.ID()) - tab.addFoundNode(n2v3, false) + tab.AddFoundNode(n2v3, false) checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) } @@ -340,18 +340,18 @@ func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Add a v4 node. key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) - tab.addInboundNode(n1) + tab.AddInboundNode(n1) checkBucketContent(t, tab, []*enode.Node{n1}) // Add an updated version with changed IP. // The update will be accepted because it is inbound. n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) - tab.addInboundNode(n1v2) + tab.AddInboundNode(n1v2) checkBucketContent(t, tab, []*enode.Node{n1v2}) } @@ -361,18 +361,18 @@ func TestTable_addFoundNodeV4UpdateReject(t *testing.T) { tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Add a v4 node. key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) - tab.addFoundNode(n1, false) + tab.AddFoundNode(n1, false) checkBucketContent(t, tab, []*enode.Node{n1}) // Add an updated version with changed IP. // The update won't be accepted because it isn't inbound. n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) - tab.addFoundNode(n1v2, false) + tab.AddFoundNode(n1v2, false) checkBucketContent(t, tab, []*enode.Node{n1}) } @@ -407,14 +407,14 @@ func TestTable_revalidateSyncRecord(t *testing.T) { }) <-tab.initDone defer db.Close() - defer tab.close() + defer tab.Close() // Insert a node. var r enr.Record r.Set(enr.IP(net.IP{127, 0, 0, 1})) id := enode.ID{1} n1 := enode.SignNull(&r, id) - tab.addFoundNode(n1, false) + tab.AddFoundNode(n1, false) // Update the node record. r.Set(enr.WithEntry("foo", "bar")) @@ -426,7 +426,7 @@ func TestTable_revalidateSyncRecord(t *testing.T) { waitForRevalidationPing(t, transport, tab, n2.ID()) waitForRevalidationPing(t, transport, tab, n2.ID()) - intable := tab.getNode(id) + intable := tab.GetNode(id) if !reflect.DeepEqual(intable, n2) { t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq()) } @@ -448,22 +448,22 @@ func TestNodesPush(t *testing.T) { // Insert all permutations into lists with size limit 3. 
for _, nodes := range perm { - list := nodesByDistance{target: target} + list := NodesByDistance{Target: target} for _, n := range nodes { - list.push(n, 3) + list.Push(n, 3) } - if !slices.EqualFunc(list.entries, perm[0], nodeIDEqual) { + if !slices.EqualFunc(list.Entries, perm[0], nodeIDEqual) { t.Fatal("not equal") } } // Insert all permutations into lists with size limit 2. for _, nodes := range perm { - list := nodesByDistance{target: target} + list := NodesByDistance{Target: target} for _, n := range nodes { - list.push(n, 2) + list.Push(n, 2) } - if !slices.EqualFunc(list.entries, perm[0][:2], nodeIDEqual) { + if !slices.EqualFunc(list.Entries, perm[0][:2], nodeIDEqual) { t.Fatal("not equal") } } diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index 254471c25a1e..343be71f2f4b 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -45,14 +45,14 @@ func init() { func newTestTable(t transport, cfg Config) (*Table, *enode.DB) { tab, db := newInactiveTestTable(t, cfg) - go tab.loop() + go tab.Loop() return tab, db } // newInactiveTestTable creates a Table without running the main loop. func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) { db, _ := enode.OpenDB("") - tab, _ := newTable(t, db, cfg) + tab, _ := NewTable(t, db, cfg) return tab, db } @@ -110,20 +110,20 @@ func intIP(i int) net.IP { func fillBucket(tab *Table, id enode.ID) (last *tableNode) { ld := enode.LogDist(tab.self().ID(), id) b := tab.bucket(id) - for len(b.entries) < bucketSize { + for len(b.entries) < BucketSize { node := nodeAtDistance(tab.self().ID(), ld, intIP(ld)) - if !tab.addFoundNode(node, false) { + if !tab.AddFoundNode(node, false) { panic("node not added") } } - return b.entries[bucketSize-1] + return b.entries[BucketSize-1] } // fillTable adds nodes the table to the end of their corresponding bucket // if the bucket is not full. The caller must not hold tab.mutex. func fillTable(tab *Table, nodes []*enode.Node, setLive bool) { for _, n := range nodes { - tab.addFoundNode(n, setLive) + tab.AddFoundNode(n, setLive) } } @@ -160,8 +160,8 @@ func (t *pingRecorder) updateRecord(n *enode.Node) { // Stubs to satisfy the transport interface. func (t *pingRecorder) Self() *enode.Node { return nullNode } -func (t *pingRecorder) lookupSelf() []*enode.Node { return nil } -func (t *pingRecorder) lookupRandom() []*enode.Node { return nil } +func (t *pingRecorder) LookupSelf() []*enode.Node { return nil } +func (t *pingRecorder) LookupRandom() []*enode.Node { return nil } func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node { t.mu.Lock() @@ -190,7 +190,7 @@ func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node { } // ping simulates a ping request. 
-func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { +func (t *pingRecorder) Ping(n *enode.Node) (seq uint64, err error) { t.mu.Lock() defer t.mu.Unlock() diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 29a9dd6645e0..f7515fd3a9fa 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -59,8 +59,8 @@ func TestUDPv4_Lookup(t *testing.T) { for _, e := range results { t.Logf(" ld=%d, %x", enode.LogDist(lookupTestnet.target.ID(), e.ID()), e.ID().Bytes()) } - if len(results) != bucketSize { - t.Errorf("wrong number of results: got %d, want %d", len(results), bucketSize) + if len(results) != BucketSize { + t.Errorf("wrong number of results: got %d, want %d", len(results), BucketSize) } checkLookupResults(t, lookupTestnet, results) } diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index 29ae5f2c084d..f1db0d63f234 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -43,8 +43,8 @@ var ( errUnknownNode = errors.New("unknown node") errTimeout = errors.New("RPC timeout") errClockWarp = errors.New("reply deadline too far in the future") - errClosed = errors.New("socket closed") - errLowPort = errors.New("low port") + ErrClosed = errors.New("socket closed") + ErrLowPort = errors.New("low port") errNoUDPEndpoint = errors.New("node has no UDP endpoint") ) @@ -143,12 +143,12 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { log: cfg.Log, } - tab, err := newTable(t, ln.Database(), cfg) + tab, err := NewTable(t, ln.Database(), cfg) if err != nil { return nil, err } t.tab = tab - go tab.loop() + go tab.Loop() t.wg.Add(2) go t.loop() @@ -167,7 +167,7 @@ func (t *UDPv4) Close() { t.cancelCloseCtx() t.conn.Close() t.wg.Wait() - t.tab.close() + t.tab.Close() }) } @@ -179,7 +179,7 @@ func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { return rn } // Check table for the ID, we might have a newer version there. - if intable := t.tab.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + if intable := t.tab.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { n = intable if rn, err := t.RequestENR(n); err == nil { return rn @@ -210,14 +210,14 @@ func (t *UDPv4) ourEndpoint() v4wire.Endpoint { return v4wire.NewEndpoint(addr, uint16(node.TCP())) } -// Ping sends a ping message to the given node. -func (t *UDPv4) Ping(n *enode.Node) error { - _, err := t.ping(n) +// PingWithoutResp sends a ping message to the given node. +func (t *UDPv4) PingWithoutResp(n *enode.Node) error { + _, err := t.Ping(n) return err } // ping sends a ping message to the given node and waits for a reply. -func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { +func (t *UDPv4) Ping(n *enode.Node) (seq uint64, err error) { addr, ok := n.UDPEndpoint() if !ok { return 0, errNoUDPEndpoint @@ -271,7 +271,7 @@ func (t *UDPv4) LookupPubkey(key *ecdsa.PublicKey) []*enode.Node { // case and run the bootstrapping logic. <-t.tab.refresh() } - return t.newLookup(t.closeCtx, v4wire.EncodePubkey(key)).run() + return t.newLookup(t.closeCtx, v4wire.EncodePubkey(key)).Run() } // RandomNodes is an iterator yielding nodes from a random walk of the DHT. @@ -280,25 +280,25 @@ func (t *UDPv4) RandomNodes() enode.Iterator { } // lookupRandom implements transport. -func (t *UDPv4) lookupRandom() []*enode.Node { - return t.newRandomLookup(t.closeCtx).run() +func (t *UDPv4) LookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).Run() } // lookupSelf implements transport. 
-func (t *UDPv4) lookupSelf() []*enode.Node { +func (t *UDPv4) LookupSelf() []*enode.Node { pubkey := v4wire.EncodePubkey(&t.priv.PublicKey) - return t.newLookup(t.closeCtx, pubkey).run() + return t.newLookup(t.closeCtx, pubkey).Run() } -func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { +func (t *UDPv4) newRandomLookup(ctx context.Context) *Lookup { var target v4wire.Pubkey crand.Read(target[:]) return t.newLookup(ctx, target) } -func (t *UDPv4) newLookup(ctx context.Context, targetKey v4wire.Pubkey) *lookup { +func (t *UDPv4) newLookup(ctx context.Context, targetKey v4wire.Pubkey) *Lookup { target := enode.ID(crypto.Keccak256Hash(targetKey[:])) - it := newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { + it := NewLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { addr, ok := n.UDPEndpoint() if !ok { return nil, errNoUDPEndpoint @@ -315,7 +315,7 @@ func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is // active until enough nodes have been received. - nodes := make([]*enode.Node, 0, bucketSize) + nodes := make([]*enode.Node, 0, BucketSize) nreceived := 0 rm := t.pending(toid, toAddrPort.Addr(), v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { reply := r.(*v4wire.Neighbors) @@ -328,7 +328,7 @@ func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire } nodes = append(nodes, n) } - return true, nreceived >= bucketSize + return true, nreceived >= BucketSize }) t.send(toAddrPort, toid, &v4wire.Findnode{ Target: target, @@ -400,7 +400,7 @@ func (t *UDPv4) pending(id enode.ID, ip netip.Addr, ptype byte, callback replyMa case t.addReplyMatcher <- p: // loop will handle it case <-t.closeCtx.Done(): - ch <- errClosed + ch <- ErrClosed } return p } @@ -461,7 +461,7 @@ func (t *UDPv4) loop() { select { case <-t.closeCtx.Done(): for el := plist.Front(); el != nil; el = el.Next() { - el.Value.(*replyMatcher).errc <- errClosed + el.Value.(*replyMatcher).errc <- ErrClosed } return @@ -599,7 +599,7 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) { func (t *UDPv4) nodeFromRPC(sender netip.AddrPort, rn v4wire.Node) (*enode.Node, error) { if rn.UDP <= 1024 { - return nil, errLowPort + return nil, ErrLowPort } if err := netutil.CheckRelayIP(sender.Addr().AsSlice(), rn.IP); err != nil { return nil, err @@ -692,10 +692,10 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port())) if time.Since(t.db.LastPongReceived(n.ID(), from.Addr())) > bondExpiration { t.sendPing(fromID, from, func() { - t.tab.addInboundNode(n) + t.tab.AddInboundNode(n) }) } else { - t.tab.addInboundNode(n) + t.tab.AddInboundNode(n) } // Update node database and endpoint predictor. @@ -747,7 +747,7 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID e // Determine closest nodes. target := enode.ID(crypto.Keccak256Hash(req.Target[:])) preferLive := !t.tab.cfg.NoFindnodeLivenessCheck - closest := t.tab.findnodeByID(target, bucketSize, preferLive).entries + closest := t.tab.FindnodeByID(target, BucketSize, preferLive).Entries // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the packet size limit. 
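The chunked reply described by the comment above amounts to the following loop (a minimal sketch, assuming the handler's from/fromID parameters and the package's nodeToRPC helper; the real handler additionally filters each address with netutil.CheckRelayAddr and always sends at least one packet, even an empty one):

	// Flush a Neighbors packet whenever it fills up; since BucketSize
	// exceeds v4wire.MaxNeighbors, a full reply spans two packets, which
	// is exactly what the findnode test below waits for.
	p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(20 * time.Second).Unix())}
	for _, n := range closest {
		p.Nodes = append(p.Nodes, nodeToRPC(n))
		if len(p.Nodes) == v4wire.MaxNeighbors {
			t.send(from, fromID, &p)
			p.Nodes = p.Nodes[:0]
		}
	}
	if len(p.Nodes) > 0 {
		t.send(from, fromID, &p)
	}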
diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index 1af31f4f1b9b..004fe6d7e80a 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -112,7 +112,7 @@ func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() dgram, err := test.pipe.receive() - if err == errClosed { + if err == ErrClosed { return true } else if err != nil { test.t.Error("packet receive error:", err) @@ -151,7 +151,7 @@ func TestUDPv4_pingTimeout(t *testing.T) { key := newkey() toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} node := enode.NewV4(&key.PublicKey, toaddr.IP, 0, toaddr.Port) - if _, err := test.udp.ping(node); err != errTimeout { + if _, err := test.udp.Ping(node); err != errTimeout { t.Error("expected timeout error, got", err) } } @@ -256,9 +256,9 @@ func TestUDPv4_findnode(t *testing.T) { // put a few nodes into the table. their exact // distribution shouldn't matter much, although we need to // take care not to overflow any bucket. - nodes := &nodesByDistance{target: testTarget.ID()} + nodes := &NodesByDistance{Target: testTarget.ID()} live := make(map[enode.ID]bool) - numCandidates := 2 * bucketSize + numCandidates := 2 * BucketSize for i := 0; i < numCandidates; i++ { key := newkey() ip := net.IP{10, 13, 0, byte(i)} @@ -267,8 +267,8 @@ func TestUDPv4_findnode(t *testing.T) { if i > numCandidates/2 { live[n.ID()] = true } - test.table.addFoundNode(n, live[n.ID()]) - nodes.push(n, numCandidates) + test.table.AddFoundNode(n, live[n.ID()]) + nodes.Push(n, numCandidates) } // ensure there's a bond with the test node, @@ -277,7 +277,7 @@ func TestUDPv4_findnode(t *testing.T) { test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr(), time.Now()) // check that closest neighbors are returned. - expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) + expected := test.table.FindnodeByID(testTarget.ID(), BucketSize, true) test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp}) waitNeighbors := func(want []*enode.Node) { test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) { @@ -287,7 +287,7 @@ func TestUDPv4_findnode(t *testing.T) { } for i, n := range p.Nodes { if n.ID.ID() != want[i].ID() { - t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.entries[i]) + t.Errorf("result mismatch at %d:\n got: %v\n want: %v", i, n, expected.Entries[i]) } if !live[n.ID.ID()] { t.Errorf("result includes dead node %v", n.ID.ID()) @@ -296,7 +296,7 @@ func TestUDPv4_findnode(t *testing.T) { }) } // Receive replies. 
- want := expected.entries + want := expected.Entries if len(want) > v4wire.MaxNeighbors { waitNeighbors(want[:v4wire.MaxNeighbors]) want = want[v4wire.MaxNeighbors:] @@ -644,7 +644,7 @@ func (c *dgramPipe) receive() (dgram, error) { c.cond.Wait() } if c.closed { - return dgram{}, errClosed + return dgram{}, ErrClosed } if timedOut { return dgram{}, errTimeout diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 6383f5e4a731..db66231b02da 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -140,7 +140,7 @@ func ListenV5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { if err != nil { return nil, err } - go t.tab.loop() + go t.tab.Loop() t.wg.Add(2) go t.readLoop() go t.dispatch() @@ -180,7 +180,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { cancelCloseCtx: cancelCloseCtx, } t.talk = newTalkSystem(t) - tab, err := newTable(t, t.db, cfg) + tab, err := NewTable(t, t.db, cfg) if err != nil { return nil, err } @@ -200,20 +200,20 @@ func (t *UDPv5) Close() { t.conn.Close() t.talk.wait() t.wg.Wait() - t.tab.close() + t.tab.Close() }) } -// Ping sends a ping message to the given node. -func (t *UDPv5) Ping(n *enode.Node) error { - _, err := t.ping(n) +// PingWithoutResp sends a ping message to the given node. +func (t *UDPv5) PingWithoutResp(n *enode.Node) error { + _, err := t.Ping(n) return err } // Resolve searches for a specific node with the given ID and tries to get the most recent // version of the node record for it. It returns n if the node could not be resolved. func (t *UDPv5) Resolve(n *enode.Node) *enode.Node { - if intable := t.tab.getNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + if intable := t.tab.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { n = intable } // Try asking directly. This works if the node is still responding on the endpoint we have. @@ -237,7 +237,7 @@ func (t *UDPv5) ResolveNodeId(id enode.ID) *enode.Node { return t.Self() } - n := t.tab.getNode(id) + n := t.tab.GetNode(id) if n != nil { // Try asking directly. This works if the Node is still responding on the endpoint we have. if resp, err := t.RequestENR(n); err == nil { @@ -341,29 +341,29 @@ func (t *UDPv5) RandomNodes() enode.Iterator { // Lookup performs a recursive lookup for the given target. // It returns the closest nodes to target. func (t *UDPv5) Lookup(target enode.ID) []*enode.Node { - return t.newLookup(t.closeCtx, target).run() + return t.newLookup(t.closeCtx, target).Run() } // lookupRandom looks up a random target. // This is needed to satisfy the transport interface. -func (t *UDPv5) lookupRandom() []*enode.Node { - return t.newRandomLookup(t.closeCtx).run() +func (t *UDPv5) LookupRandom() []*enode.Node { + return t.newRandomLookup(t.closeCtx).Run() } // lookupSelf looks up our own node ID. // This is needed to satisfy the transport interface. 
-func (t *UDPv5) lookupSelf() []*enode.Node { - return t.newLookup(t.closeCtx, t.Self().ID()).run() +func (t *UDPv5) LookupSelf() []*enode.Node { + return t.newLookup(t.closeCtx, t.Self().ID()).Run() } -func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup { +func (t *UDPv5) newRandomLookup(ctx context.Context) *Lookup { var target enode.ID crand.Read(target[:]) return t.newLookup(ctx, target) } -func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { - return newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { +func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *Lookup { + return NewLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { return t.lookupWorker(n, target) }) } @@ -371,27 +371,27 @@ func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { // lookupWorker performs FINDNODE calls against a single node during lookup. func (t *UDPv5) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { var ( - dists = lookupDistances(target, destNode.ID()) - nodes = nodesByDistance{target: target} + dists = LookupDistances(target, destNode.ID()) + nodes = NodesByDistance{Target: target} err error ) var r []*enode.Node - r, err = t.findnode(destNode, dists) - if errors.Is(err, errClosed) { + r, err = t.Findnode(destNode, dists) + if errors.Is(err, ErrClosed) { return nil, err } for _, n := range r { if n.ID() != t.Self().ID() { - nodes.push(n, findnodeResultLimit) + nodes.Push(n, findnodeResultLimit) } } - return nodes.entries, err + return nodes.Entries, err } -// lookupDistances computes the distance parameter for FINDNODE calls to dest. +// LookupDistances computes the distance parameter for FINDNODE calls to dest. // It chooses distances adjacent to logdist(target, dest), e.g. for a target // with logdist(target, dest) = 255 the result is [255, 256, 254]. -func lookupDistances(target, dest enode.ID) (dists []uint) { +func LookupDistances(target, dest enode.ID) (dists []uint) { td := enode.LogDist(target, dest) dists = append(dists, uint(td)) for i := 1; len(dists) < lookupRequestLimit; i++ { @@ -406,8 +406,8 @@ func lookupDistances(target, dest enode.ID) (dists []uint) { } // ping calls PING on a node and waits for a PONG response. -func (t *UDPv5) ping(n *enode.Node) (uint64, error) { - pong, err := t.pingInner(n) +func (t *UDPv5) Ping(n *enode.Node) (uint64, error) { + pong, err := t.PingWithResp(n) if err != nil { return 0, err } @@ -415,8 +415,8 @@ func (t *UDPv5) ping(n *enode.Node) (uint64, error) { return pong.ENRSeq, nil } -// pingInner calls PING on a node and waits for a PONG response. -func (t *UDPv5) pingInner(n *enode.Node) (*v5wire.Pong, error) { +// PingWithResp calls PING on a node and waits for a PONG response. +func (t *UDPv5) PingWithResp(n *enode.Node) (*v5wire.Pong, error) { req := &v5wire.Ping{ENRSeq: t.localNode.Node().Seq()} resp := t.callToNode(n, v5wire.PongMsg, req) defer t.callDone(resp) @@ -431,7 +431,7 @@ func (t *UDPv5) pingInner(n *enode.Node) (*v5wire.Pong, error) { // RequestENR requests n's record. func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { - nodes, err := t.findnode(n, []uint{0}) + nodes, err := t.Findnode(n, []uint{0}) if err != nil { return nil, err } @@ -441,8 +441,8 @@ func (t *UDPv5) RequestENR(n *enode.Node) (*enode.Node, error) { return nodes[0], nil } -// findnode calls FINDNODE on a node and waits for responses. 
-func (t *UDPv5) findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) { +// Findnode calls FINDNODE on a node and waits for responses. +func (t *UDPv5) Findnode(n *enode.Node, distances []uint) ([]*enode.Node, error) { resp := t.callToNode(n, v5wire.NodesMsg, &v5wire.Findnode{Distances: distances}) return t.waitForNodes(resp, distances) } @@ -493,7 +493,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s return nil, errors.New("not contained in netrestrict list") } if node.UDP() <= 1024 { - return nil, errLowPort + return nil, ErrLowPort } if distances != nil { nd := enode.LogDist(c.id, node.ID()) @@ -537,7 +537,7 @@ func (t *UDPv5) initCall(c *callV5, responseType byte, packet v5wire.Packet) { select { case t.callCh <- c: case <-t.closeCtx.Done(): - c.err <- errClosed + c.err <- ErrClosed } } @@ -630,12 +630,12 @@ func (t *UDPv5) dispatch() { close(t.readNextCh) for id, queue := range t.callQueue { for _, c := range queue { - c.err <- errClosed + c.err <- ErrClosed } delete(t.callQueue, id) } for id, c := range t.activeCallByNode { - c.err <- errClosed + c.err <- ErrClosed delete(t.activeCallByNode, id) delete(t.activeCallByAuth, c.nonce) } @@ -709,7 +709,7 @@ func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr netip.AddrPort, pack } } -func (t *UDPv5) sendFromAnotherThreadWithNode(node *enode.Node, toAddr netip.AddrPort, packet v5wire.Packet) { +func (t *UDPv5) SendFromAnotherThreadWithNode(node *enode.Node, toAddr netip.AddrPort, packet v5wire.Packet) { select { case t.sendCh <- sendRequest{node.ID(), node, toAddr, packet}: case <-t.closeCtx.Done(): @@ -792,7 +792,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error { } if fromNode != nil { // Handshake succeeded, add to table. - t.tab.addInboundNode(fromNode) + t.tab.AddInboundNode(fromNode) t.putCache(fromAddr.String(), fromNode) } if packet.Kind() != v5wire.WhoareyouPacket { @@ -825,9 +825,9 @@ func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr netip.AddrPort, p v return true } -// getNode looks for a node record in table and database. -func (t *UDPv5) getNode(id enode.ID) *enode.Node { - if n := t.tab.getNode(id); n != nil { +// GetNode looks for a node record in table and database. +func (t *UDPv5) GetNode(id enode.ID) *enode.Node { + if n := t.tab.GetNode(id); n != nil { return n } if n := t.localNode.Database().Node(id); n != nil { @@ -865,7 +865,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) { challenge := &v5wire.Whoareyou{Nonce: p.Nonce} crand.Read(challenge.IDNonce[:]) - if n := t.getNode(fromID); n != nil { + if n := t.GetNode(fromID); n != nil { challenge.Node = n challenge.RecordSeq = n.Seq() } @@ -952,7 +952,7 @@ func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) [ processed[dist] = struct{}{} checkLive := !t.tab.cfg.NoFindnodeLivenessCheck - for _, n := range t.tab.appendBucketNodes(dist, bn[:0], checkLive) { + for _, n := range t.tab.AppendBucketNodes(dist, bn[:0], checkLive) { // Apply some pre-checks to avoid sending invalid nodes. // Note liveness is checked by appendLiveNodes. 
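 			// (CheckRelayAddr below rejects addresses the requester could not
 			// plausibly reach, e.g. LAN addresses relayed to a public requester.)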
if netutil.CheckRelayAddr(rip, n.IPAddr()) != nil { @@ -1014,3 +1014,7 @@ func (t *UDPv5) GetCachedNode(addr string) (*enode.Node, bool) { n, ok := t.cachedAddrNode[addr] return n, ok } + +func (t *UDPv5) Table() *Table { + return t.tab +} diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index 2db9824e9708..3abea16884d3 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -143,7 +143,7 @@ func TestUDPv5_unknownPacket(t *testing.T) { // Make Node known. n := test.getNode(test.remotekey, test.remoteaddr).Node() - test.table.addFoundNode(n, false) + test.table.AddFoundNode(n, false) test.packetIn(&v5wire.Unknown{Nonce: nonce}) test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { @@ -237,7 +237,7 @@ func TestUDPv5_pingCall(t *testing.T) { // This ping times out. go func() { - _, err := test.udp.ping(remote) + _, err := test.udp.Ping(remote) done <- err }() test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {}) @@ -247,7 +247,7 @@ func TestUDPv5_pingCall(t *testing.T) { // This ping works. go func() { - _, err := test.udp.ping(remote) + _, err := test.udp.Ping(remote) done <- err }() test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { @@ -259,7 +259,7 @@ func TestUDPv5_pingCall(t *testing.T) { // This ping gets a reply from the wrong endpoint. go func() { - _, err := test.udp.ping(remote) + _, err := test.udp.Ping(remote) done <- err }() test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { @@ -288,7 +288,7 @@ func TestUDPv5_findnodeCall(t *testing.T) { ) go func() { var err error - response, err = test.udp.findnode(remote, distances) + response, err = test.udp.Findnode(remote, distances) done <- err }() @@ -330,11 +330,11 @@ func TestUDPv5_callResend(t *testing.T) { remote := test.getNode(test.remotekey, test.remoteaddr).Node() done := make(chan error, 2) go func() { - _, err := test.udp.ping(remote) + _, err := test.udp.Ping(remote) done <- err }() go func() { - _, err := test.udp.ping(remote) + _, err := test.udp.Ping(remote) done <- err }() @@ -367,7 +367,7 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) { remote := test.getNode(test.remotekey, test.remoteaddr).Node() done := make(chan error, 1) go func() { - _, err := test.udp.ping(remote) + _, err := test.udp.Ping(remote) done <- err }() @@ -398,7 +398,7 @@ func TestUDPv5_callTimeoutReset(t *testing.T) { done = make(chan error, 1) ) go func() { - _, err := test.udp.findnode(remote, []uint{distance}) + _, err := test.udp.Findnode(remote, []uint{distance}) done <- err }() @@ -535,38 +535,38 @@ func TestUDPv5_talkRequest(t *testing.T) { } } -// This test checks that lookupDistances works. +// This test checks that LookupDistances works. 
func TestUDPv5_lookupDistances(t *testing.T) { test := newUDPV5Test(t) lnID := test.table.self().ID() t.Run("target distance of 1", func(t *testing.T) { node := nodeAtDistance(lnID, 1, intIP(0)) - dists := lookupDistances(lnID, node.ID()) + dists := LookupDistances(lnID, node.ID()) require.Equal(t, []uint{1, 2, 3}, dists) }) t.Run("target distance of 2", func(t *testing.T) { node := nodeAtDistance(lnID, 2, intIP(0)) - dists := lookupDistances(lnID, node.ID()) + dists := LookupDistances(lnID, node.ID()) require.Equal(t, []uint{2, 3, 1}, dists) }) t.Run("target distance of 128", func(t *testing.T) { node := nodeAtDistance(lnID, 128, intIP(0)) - dists := lookupDistances(lnID, node.ID()) + dists := LookupDistances(lnID, node.ID()) require.Equal(t, []uint{128, 129, 127}, dists) }) t.Run("target distance of 255", func(t *testing.T) { node := nodeAtDistance(lnID, 255, intIP(0)) - dists := lookupDistances(lnID, node.ID()) + dists := LookupDistances(lnID, node.ID()) require.Equal(t, []uint{255, 256, 254}, dists) }) t.Run("target distance of 256", func(t *testing.T) { node := nodeAtDistance(lnID, 256, intIP(0)) - dists := lookupDistances(lnID, node.ID()) + dists := LookupDistances(lnID, node.ID()) require.Equal(t, []uint{256, 255, 254}, dists) }) } @@ -817,7 +817,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { exptype := fn.Type().In(0) dgram, err := test.pipe.receive() - if err == errClosed { + if err == ErrClosed { return true } if err == errTimeout { diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index a6cc278bba7e..6f33aa831c48 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -94,7 +94,7 @@ const ( // Should reject packets smaller than minPacketSize. minPacketSize = 63 - maxPacketSize = 1280 + MaxPacketSize = 1280 minMessageSize = 48 // this refers to data after static headers randomPacketMsgSize = 20 @@ -169,7 +169,7 @@ func NewCodec(ln *enode.LocalNode, key *ecdsa.PrivateKey, clock mclock.Clock, pr privkey: key, sc: NewSessionCache(1024, clock), protocolID: DefaultProtocolID, - decbuf: make([]byte, maxPacketSize), + decbuf: make([]byte, MaxPacketSize), } if protocolID != nil { c.protocolID = *protocolID diff --git a/portalnetwork/api.go b/portalnetwork/api.go new file mode 100644 index 000000000000..bc7305ef8b57 --- /dev/null +++ b/portalnetwork/api.go @@ -0,0 +1,543 @@ +package portalnetwork + +import ( + "errors" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" + "github.com/holiman/uint256" +) + +// DiscV5API json-rpc spec +// https://playground.open-rpc.org/?schemaUrl=https://raw.githubusercontent.com/ethereum/portal-network-specs/assembled-spec/jsonrpc/openrpc.json&uiSchema%5BappBar%5D%5Bui:splitView%5D=false&uiSchema%5BappBar%5D%5Bui:input%5D=false&uiSchema%5BappBar%5D%5Bui:examplesDropdown%5D=false +type DiscV5API struct { + DiscV5 *discover.UDPv5 +} + +func NewDiscV5API(discV5 *discover.UDPv5) *DiscV5API { + return &DiscV5API{discV5} +} + +type NodeInfo struct { + NodeId string `json:"nodeId"` + Enr string `json:"enr"` + Ip string `json:"ip"` +} + +type RoutingTableInfo struct { + Buckets [][]string `json:"buckets"` + LocalNodeId string `json:"localNodeId"` +} + +type DiscV5PongResp struct { + EnrSeq uint64 `json:"enrSeq"` + RecipientIP string `json:"recipientIP"` + RecipientPort uint16 `json:"recipientPort"` +} + +type 
PortalPongResp struct {
+	EnrSeq     uint32 `json:"enrSeq"`
+	DataRadius string `json:"dataRadius"`
+}
+
+type ContentInfo struct {
+	Content     string `json:"content"`
+	UtpTransfer bool   `json:"utpTransfer"`
+}
+
+type TraceContentResult struct {
+	Content     string `json:"content"`
+	UtpTransfer bool   `json:"utpTransfer"`
+	Trace       Trace  `json:"trace"`
+}
+
+type Trace struct {
+	Origin       string                   `json:"origin"`       // local node id
+	TargetId     string                   `json:"targetId"`     // target content id
+	ReceivedFrom string                   `json:"receivedFrom"` // the node id the content was received from
+	Responses    map[string]RespByNode    `json:"responses"`    // each queried node id and the node ids it responded with
+	Metadata     map[string]*NodeMetadata `json:"metadata"`     // node id and its metadata object
+	StartedAtMs  int                      `json:"startedAtMs"`  // timestamp of the beginning of this request in milliseconds
+	Cancelled    []string                 `json:"cancelled"`    // the node ids that were queried but whose requests were cancelled
+}
+
+type NodeMetadata struct {
+	Enr      string `json:"enr"`
+	Distance string `json:"distance"`
+}
+
+type RespByNode struct {
+	DurationMs    int32    `json:"durationMs"`
+	RespondedWith []string `json:"respondedWith"`
+}
+
+type Enrs struct {
+	Enrs []string `json:"enrs"`
+}
+
+func (d *DiscV5API) NodeInfo() *NodeInfo {
+	n := d.DiscV5.LocalNode().Node()
+
+	return &NodeInfo{
+		NodeId: "0x" + n.ID().String(),
+		Enr:    n.String(),
+		Ip:     n.IP().String(),
+	}
+}
+
+func (d *DiscV5API) RoutingTableInfo() *RoutingTableInfo {
+	n := d.DiscV5.LocalNode().Node()
+	bucketNodes := d.DiscV5.RoutingTableInfo()
+
+	return &RoutingTableInfo{
+		Buckets:     bucketNodes,
+		LocalNodeId: "0x" + n.ID().String(),
+	}
+}
+
+func (d *DiscV5API) AddEnr(enr string) (bool, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return false, err
+	}
+
+	// immediately add the node to the routing table
+	d.DiscV5.Table().AddInboundNode(n)
+	return true, nil
+}
+
+func (d *DiscV5API) GetEnr(nodeId string) (bool, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return false, err
+	}
+	n := d.DiscV5.Table().GetNode(id)
+	if n == nil {
+		return false, errors.New("record not in local routing table")
+	}
+
+	return true, nil
+}
+
+func (d *DiscV5API) DeleteEnr(nodeId string) (bool, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return false, err
+	}
+
+	n := d.DiscV5.Table().GetNode(id)
+	if n == nil {
+		return false, errors.New("record not in local routing table")
+	}
+
+	d.DiscV5.Table().DeleteNode(n)
+	return true, nil
+}
+
+func (d *DiscV5API) LookupEnr(nodeId string) (string, error) {
+	id, err := enode.ParseID(nodeId)
+	if err != nil {
+		return "", err
+	}
+
+	enr := d.DiscV5.ResolveNodeId(id)
+
+	if enr == nil {
+		return "", errors.New("record not found in DHT lookup")
+	}
+
+	return enr.String(), nil
+}
+
+func (d *DiscV5API) Ping(enr string) (*DiscV5PongResp, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+
+	pong, err := d.DiscV5.PingWithResp(n)
+	if err != nil {
+		return nil, err
+	}
+
+	return &DiscV5PongResp{
+		EnrSeq:        pong.ENRSeq,
+		RecipientIP:   pong.ToIP.String(),
+		RecipientPort: pong.ToPort,
+	}, nil
+}
+
+func (d *DiscV5API) FindNodes(enr string, distances []uint) ([]string, error) {
+	n, err := enode.Parse(enode.ValidSchemes, enr)
+	if err != nil {
+		return nil, err
+	}
+	findNodes, err := d.DiscV5.Findnode(n, distances)
+	if err != nil {
+		return nil, err
+	}
+
+	enrs := make([]string, 0, len(findNodes))
+	for _, r := range findNodes {
+		enrs = append(enrs, r.String())
+	}
+
+	return enrs, nil
+}
+
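+// A hypothetical JSON-RPC request shape for the method above, assuming the
+// API is registered under the "discv5" namespace of the Portal Network
+// JSON-RPC spec (geth's rpc package lower-cases the leading letter of the
+// Go method name):
+//
+//	{"jsonrpc":"2.0","id":1,"method":"discv5_findNodes",
+//	 "params":["enr:-IS4Q...", [256, 255, 254]]}
+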
+func (d *DiscV5API) TalkReq(enr string, protocol string, payload string) (string, error) { + n, err := enode.Parse(enode.ValidSchemes, enr) + if err != nil { + return "", err + } + + req, err := hexutil.Decode(payload) + if err != nil { + return "", err + } + + talkResp, err := d.DiscV5.TalkRequest(n, protocol, req) + if err != nil { + return "", err + } + return hexutil.Encode(talkResp), nil +} + +func (d *DiscV5API) RecursiveFindNodes(nodeId string) ([]string, error) { + findNodes := d.DiscV5.Lookup(enode.HexID(nodeId)) + + enrs := make([]string, 0, len(findNodes)) + for _, r := range findNodes { + enrs = append(enrs, r.String()) + } + + return enrs, nil +} + +type PortalProtocolAPI struct { + portalProtocol *PortalProtocol +} + +func NewPortalAPI(portalProtocol *PortalProtocol) *PortalProtocolAPI { + return &PortalProtocolAPI{ + portalProtocol: portalProtocol, + } +} + +func (p *PortalProtocolAPI) NodeInfo() *NodeInfo { + n := p.portalProtocol.localNode.Node() + + return &NodeInfo{ + NodeId: n.ID().String(), + Enr: n.String(), + Ip: n.IP().String(), + } +} + +func (p *PortalProtocolAPI) RoutingTableInfo() *RoutingTableInfo { + n := p.portalProtocol.localNode.Node() + bucketNodes := p.portalProtocol.RoutingTableInfo() + + return &RoutingTableInfo{ + Buckets: bucketNodes, + LocalNodeId: "0x" + n.ID().String(), + } +} + +func (p *PortalProtocolAPI) AddEnr(enr string) (bool, error) { + p.portalProtocol.Log.Debug("serving AddEnr", "enr", enr) + n, err := enode.ParseForAddEnr(enode.ValidSchemes, enr) + if err != nil { + return false, err + } + p.portalProtocol.AddEnr(n) + return true, nil +} + +func (p *PortalProtocolAPI) AddEnrs(enrs []string) bool { + // Note: unspecified RPC, but useful for our local testnet test + for _, enr := range enrs { + n, err := enode.ParseForAddEnr(enode.ValidSchemes, enr) + if err != nil { + continue + } + p.portalProtocol.AddEnr(n) + } + + return true +} + +func (p *PortalProtocolAPI) GetEnr(nodeId string) (string, error) { + id, err := enode.ParseID(nodeId) + if err != nil { + return "", err + } + + if id == p.portalProtocol.localNode.Node().ID() { + return p.portalProtocol.localNode.Node().String(), nil + } + + n := p.portalProtocol.table.GetNode(id) + if n == nil { + return "", errors.New("record not in local routing table") + } + + return n.String(), nil +} + +func (p *PortalProtocolAPI) DeleteEnr(nodeId string) (bool, error) { + id, err := enode.ParseID(nodeId) + if err != nil { + return false, err + } + + n := p.portalProtocol.table.GetNode(id) + if n == nil { + return false, nil + } + + p.portalProtocol.table.DeleteNode(n) + return true, nil +} + +func (p *PortalProtocolAPI) LookupEnr(nodeId string) (string, error) { + id, err := enode.ParseID(nodeId) + if err != nil { + return "", err + } + + enr := p.portalProtocol.ResolveNodeId(id) + + if enr == nil { + return "", errors.New("record not found in DHT lookup") + } + + return enr.String(), nil +} + +func (p *PortalProtocolAPI) Ping(enr string) (*PortalPongResp, error) { + n, err := enode.Parse(enode.ValidSchemes, enr) + if err != nil { + return nil, err + } + + pong, err := p.portalProtocol.pingInner(n) + if err != nil { + return nil, err + } + + customPayload := &portalwire.PingPongCustomData{} + err = customPayload.UnmarshalSSZ(pong.CustomPayload) + if err != nil { + return nil, err + } + + nodeRadius := new(uint256.Int) + err = nodeRadius.UnmarshalSSZ(customPayload.Radius) + if err != nil { + return nil, err + } + + return &PortalPongResp{ + EnrSeq: uint32(pong.EnrSeq), + DataRadius: nodeRadius.Hex(), 
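+		// DataRadius is decoded from the pong's SSZ custom payload: the
+		// keyspace distance within which this peer claims to store content.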
+ }, nil +} + +func (p *PortalProtocolAPI) FindNodes(enr string, distances []uint) ([]string, error) { + n, err := enode.Parse(enode.ValidSchemes, enr) + if err != nil { + return nil, err + } + findNodes, err := p.portalProtocol.findNodes(n, distances) + if err != nil { + return nil, err + } + + enrs := make([]string, 0, len(findNodes)) + for _, r := range findNodes { + enrs = append(enrs, r.String()) + } + + return enrs, nil +} + +func (p *PortalProtocolAPI) FindContent(enr string, contentKey string) (interface{}, error) { + n, err := enode.Parse(enode.ValidSchemes, enr) + if err != nil { + return nil, err + } + + contentKeyBytes, err := hexutil.Decode(contentKey) + if err != nil { + return nil, err + } + + flag, findContent, err := p.portalProtocol.findContent(n, contentKeyBytes) + if err != nil { + return nil, err + } + + switch flag { + case portalwire.ContentRawSelector: + contentInfo := &ContentInfo{ + Content: hexutil.Encode(findContent.([]byte)), + UtpTransfer: false, + } + p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo) + return contentInfo, nil + case portalwire.ContentConnIdSelector: + contentInfo := &ContentInfo{ + Content: hexutil.Encode(findContent.([]byte)), + UtpTransfer: true, + } + p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo) + return contentInfo, nil + default: + enrs := make([]string, 0) + for _, r := range findContent.([]*enode.Node) { + enrs = append(enrs, r.String()) + } + + p.portalProtocol.Log.Trace("FindContent", "enrs", enrs) + return &Enrs{ + Enrs: enrs, + }, nil + } +} + +func (p *PortalProtocolAPI) Offer(enr string, contentItems [][2]string) (string, error) { + n, err := enode.Parse(enode.ValidSchemes, enr) + if err != nil { + return "", err + } + + entries := make([]*ContentEntry, 0, len(contentItems)) + for _, contentItem := range contentItems { + contentKey, err := hexutil.Decode(contentItem[0]) + if err != nil { + return "", err + } + contentValue, err := hexutil.Decode(contentItem[1]) + if err != nil { + return "", err + } + contentEntry := &ContentEntry{ + ContentKey: contentKey, + Content: contentValue, + } + entries = append(entries, contentEntry) + } + + transientOfferRequest := &TransientOfferRequest{ + Contents: entries, + } + + offerReq := &OfferRequest{ + Kind: TransientOfferRequestKind, + Request: transientOfferRequest, + } + accept, err := p.portalProtocol.offer(n, offerReq) + if err != nil { + return "", err + } + + return hexutil.Encode(accept), nil +} + +func (p *PortalProtocolAPI) RecursiveFindNodes(nodeId string) ([]string, error) { + findNodes := p.portalProtocol.Lookup(enode.HexID(nodeId)) + + enrs := make([]string, 0, len(findNodes)) + for _, r := range findNodes { + enrs = append(enrs, r.String()) + } + + return enrs, nil +} + +func (p *PortalProtocolAPI) RecursiveFindContent(contentKeyHex string) (*ContentInfo, error) { + contentKey, err := hexutil.Decode(contentKeyHex) + if err != nil { + return nil, err + } + contentId := p.portalProtocol.toContentId(contentKey) + + data, err := p.portalProtocol.Get(contentKey, contentId) + if err == nil { + return &ContentInfo{ + Content: hexutil.Encode(data), + UtpTransfer: false, + }, err + } + p.portalProtocol.Log.Warn("find content err", "contextKey", hexutil.Encode(contentKey), "err", err) + + content, utpTransfer, err := p.portalProtocol.ContentLookup(contentKey, contentId) + + if err != nil { + return nil, err + } + + return &ContentInfo{ + Content: hexutil.Encode(content), + UtpTransfer: utpTransfer, + }, err +} + +func (p *PortalProtocolAPI) 
LocalContent(contentKeyHex string) (string, error) { + contentKey, err := hexutil.Decode(contentKeyHex) + if err != nil { + return "", err + } + contentId := p.portalProtocol.ToContentId(contentKey) + content, err := p.portalProtocol.Get(contentKey, contentId) + + if err != nil { + return "", err + } + return hexutil.Encode(content), nil +} + +func (p *PortalProtocolAPI) Store(contentKeyHex string, contextHex string) (bool, error) { + contentKey, err := hexutil.Decode(contentKeyHex) + if err != nil { + return false, err + } + contentId := p.portalProtocol.ToContentId(contentKey) + if !p.portalProtocol.InRange(contentId) { + return false, nil + } + content, err := hexutil.Decode(contextHex) + if err != nil { + return false, err + } + err = p.portalProtocol.Put(contentKey, contentId, content) + if err != nil { + return false, err + } + return true, nil +} + +func (p *PortalProtocolAPI) Gossip(contentKeyHex, contentHex string) (int, error) { + contentKey, err := hexutil.Decode(contentKeyHex) + if err != nil { + return 0, err + } + content, err := hexutil.Decode(contentHex) + if err != nil { + return 0, err + } + id := p.portalProtocol.Self().ID() + return p.portalProtocol.Gossip(&id, [][]byte{contentKey}, [][]byte{content}) +} + +func (p *PortalProtocolAPI) TraceRecursiveFindContent(contentKeyHex string) (*TraceContentResult, error) { + contentKey, err := hexutil.Decode(contentKeyHex) + if err != nil { + return nil, err + } + contentId := p.portalProtocol.toContentId(contentKey) + return p.portalProtocol.TraceContentLookup(contentKey, contentId) +} diff --git a/portalnetwork/nat.go b/portalnetwork/nat.go new file mode 100644 index 000000000000..ca479d7e457d --- /dev/null +++ b/portalnetwork/nat.go @@ -0,0 +1,172 @@ +package portalnetwork + +import ( + "net" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enr" + "github.com/ethereum/go-ethereum/p2p/nat" +) + +const ( + portMapDuration = 10 * time.Minute + portMapRefreshInterval = 8 * time.Minute + portMapRetryInterval = 5 * time.Minute + extipRetryInterval = 2 * time.Minute +) + +type portMapping struct { + protocol string + name string + port int + + // for use by the portMappingLoop goroutine: + extPort int // the mapped port returned by the NAT interface + nextTime mclock.AbsTime +} + +// setupPortMapping starts the port mapping loop if necessary. +// Note: this needs to be called after the LocalNode instance has been set on the server. +func (p *PortalProtocol) setupPortMapping() { + // portMappingRegister will receive up to two values: one for the TCP port if + // listening is enabled, and one more for enabling UDP port mapping if discovery is + // enabled. We make it buffered to avoid blocking setup while a mapping request is in + // progress. + p.portMappingRegister = make(chan *portMapping, 2) + + switch p.NAT.(type) { + case nil: + // No NAT interface configured. + go p.consumePortMappingRequests() + + case nat.ExtIP: + // ExtIP doesn't block, set the IP right away. + ip, _ := p.NAT.ExternalIP() + p.localNode.SetStaticIP(ip) + go p.consumePortMappingRequests() + + case nat.STUN: + // STUN doesn't block, set the IP right away. 
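+		// (As with the ExtIP case there is no mapping protocol to drive, so
+		// mapping requests only need to be drained; see
+		// consumePortMappingRequests below.)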
+		ip, _ := p.NAT.ExternalIP()
+		p.localNode.SetStaticIP(ip)
+		go p.consumePortMappingRequests()
+
+	default:
+		go p.portMappingLoop()
+	}
+}
+
+func (p *PortalProtocol) consumePortMappingRequests() {
+	for {
+		select {
+		case <-p.closeCtx.Done():
+			return
+		case <-p.portMappingRegister:
+		}
+	}
+}
+
+// portMappingLoop manages port mappings for UDP and TCP.
+func (p *PortalProtocol) portMappingLoop() {
+	newLogger := func(proto string, e int, i int) log.Logger {
+		return log.New("proto", proto, "extport", e, "intport", i, "interface", p.NAT)
+	}
+
+	var (
+		mappings  = make(map[string]*portMapping, 2)
+		refresh   = mclock.NewAlarm(p.clock)
+		extip     = mclock.NewAlarm(p.clock)
+		lastExtIP net.IP
+	)
+	extip.Schedule(p.clock.Now())
+	defer func() {
+		refresh.Stop()
+		extip.Stop()
+		for _, m := range mappings {
+			if m.extPort != 0 {
+				log := newLogger(m.protocol, m.extPort, m.port)
+				log.Debug("Deleting port mapping")
+				p.NAT.DeleteMapping(m.protocol, m.extPort, m.port)
+			}
+		}
+	}()
+
+	for {
+		// Schedule refresh of existing mappings.
+		for _, m := range mappings {
+			refresh.Schedule(m.nextTime)
+		}
+
+		select {
+		case <-p.closeCtx.Done():
+			return
+
+		case <-extip.C():
+			extip.Schedule(p.clock.Now().Add(extipRetryInterval))
+			ip, err := p.NAT.ExternalIP()
+			if err != nil {
+				log.Debug("Couldn't get external IP", "err", err, "interface", p.NAT)
+			} else if !ip.Equal(lastExtIP) {
+				log.Debug("External IP changed", "ip", ip, "interface", p.NAT)
+			} else {
+				continue
+			}
+			// Here, we either failed to get the external IP, or it has changed.
+			lastExtIP = ip
+			p.localNode.SetStaticIP(ip)
+			p.Log.Debug("set static ip in nat", "ip", p.localNode.Node().IP().String())
+			// Ensure port mappings are refreshed in case we have moved to a new network.
+			for _, m := range mappings {
+				m.nextTime = p.clock.Now()
+			}
+
+		case m := <-p.portMappingRegister:
+			if m.protocol != "TCP" && m.protocol != "UDP" {
+				panic("unknown NAT protocol name: " + m.protocol)
+			}
+			mappings[m.protocol] = m
+			m.nextTime = p.clock.Now()
+
+		case <-refresh.C():
+			for _, m := range mappings {
+				if p.clock.Now() < m.nextTime {
+					continue
+				}
+
+				external := m.port
+				if m.extPort != 0 {
+					external = m.extPort
+				}
+				log := newLogger(m.protocol, external, m.port)
+
+				log.Trace("Attempting port mapping")
+				port, err := p.NAT.AddMapping(m.protocol, external, m.port, m.name, portMapDuration)
+				if err != nil {
+					log.Debug("Couldn't add port mapping", "err", err)
+					m.extPort = 0
+					m.nextTime = p.clock.Now().Add(portMapRetryInterval)
+					continue
+				}
+				// It was mapped!
+				m.extPort = int(port)
+				m.nextTime = p.clock.Now().Add(portMapRefreshInterval)
+				if external != m.extPort {
+					log = newLogger(m.protocol, m.extPort, m.port)
+					log.Info("NAT mapped alternative port")
+				} else {
+					log.Info("NAT mapped port")
+				}
+
+				// Update port in local ENR.
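+				// The TCP port is written directly into the record, while the
+				// UDP port is a last-resort value that a successful endpoint
+				// prediction can still override.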
+				switch m.protocol {
+				case "TCP":
+					p.localNode.Set(enr.TCP(m.extPort))
+				case "UDP":
+					p.localNode.SetFallbackUDP(m.extPort)
+				}
+			}
+		}
+	}
+}
diff --git a/portalnetwork/portal_protocol.go b/portalnetwork/portal_protocol.go
new file mode 100644
index 000000000000..126acb82a7ee
--- /dev/null
+++ b/portalnetwork/portal_protocol.go
@@ -0,0 +1,1918 @@
+package portalnetwork
+
+import (
+	"bytes"
+	"context"
+	"crypto/ecdsa"
+	crand "crypto/rand"
+	"crypto/sha256"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"math/big"
+	"math/rand"
+	"net"
+	"slices"
+	"sort"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/VictoriaMetrics/fastcache"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/common/hexutil"
+	"github.com/ethereum/go-ethereum/common/mclock"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+	"github.com/ethereum/go-ethereum/p2p/discover"
+	"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/p2p/nat"
+	"github.com/ethereum/go-ethereum/p2p/netutil"
+	"github.com/ethereum/go-ethereum/portalnetwork/portalwire"
+	"github.com/ethereum/go-ethereum/portalnetwork/storage"
+	"github.com/ethereum/go-ethereum/rlp"
+	ssz "github.com/ferranbt/fastssz"
+	"github.com/holiman/uint256"
+	"github.com/optimism-java/utp-go"
+	"github.com/optimism-java/utp-go/libutp"
+	"github.com/prysmaticlabs/go-bitfield"
+	"github.com/tetratelabs/wabin/leb128"
+)
+
+const (
+
+	// A TalkResp message is a response message, so the session is already
+	// established and a regular discv5 packet is assumed for the size
+	// calculation.
+	// Regular message = IV + header + message
+	// talkResp message = rlp: [request-id, response]
+	talkRespOverhead = 16 + // IV size
+		55 + // header size
+		1 + // talkResp msg id
+		3 + // rlp encoding outer list, max length will be encoded in 2 bytes
+		9 + // request id (max = 8) + 1 byte from rlp encoding byte string
+		3 + // rlp encoding response byte string, max length in 2 bytes
+		16 // HMAC
+
+	portalFindnodesResultLimit = 32
+
+	defaultUTPConnectTimeout = 15 * time.Second
+
+	defaultUTPWriteTimeout = 60 * time.Second
+
+	defaultUTPReadTimeout = 60 * time.Second
+
+	// The maximum number of concurrent offers per running Portal wire protocol.
+	// Using the `offerQueue` limits the number of offers sent and thus how
+	// many streams can be started.
+	// TODO:
+	// More thought needs to go into this as it is currently applied on a per
+	// network basis. Keep it simple like that? Or limit it better at the
+	// stream transport level? In the latter case, this might still need to be
+	// checked/blocked at the very start of sending the offer, because
+	// blocking/waiting too long between the received accept message and
+	// actually starting the stream and sending data could cause issues due to
+	// timeouts on the other side. And then there are still limits to be
+	// applied for FindContent and the incoming direction.
+ concurrentOffers = 50 +) + +const ( + TransientOfferRequestKind byte = 0x01 + PersistOfferRequestKind byte = 0x02 +) + +type ClientTag string + +func (c ClientTag) ENRKey() string { return "c" } + +const Tag ClientTag = "shisui" + +var ErrNilContentKey = errors.New("content key cannot be nil") + +var ContentNotFound = storage.ErrContentNotFound + +var ErrEmptyResp = errors.New("empty resp") + +var MaxDistance = hexutil.MustDecode("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + +type ContentElement struct { + Node enode.ID + ContentKeys [][]byte + Contents [][]byte +} + +type ContentEntry struct { + ContentKey []byte + Content []byte +} + +type TransientOfferRequest struct { + Contents []*ContentEntry +} + +type PersistOfferRequest struct { + ContentKeys [][]byte +} + +type OfferRequest struct { + Kind byte + Request interface{} +} + +type OfferRequestWithNode struct { + Request *OfferRequest + Node *enode.Node +} + +type ContentInfoResp struct { + Content []byte + UtpTransfer bool +} + +type traceContentInfoResp struct { + Node *enode.Node + Flag byte + Content any + UtpTransfer bool +} + +type PortalProtocolOption func(p *PortalProtocol) + +type PortalProtocolConfig struct { + BootstrapNodes []*enode.Node + // NodeIP net.IP + ListenAddr string + NetRestrict *netutil.Netlist + NodeRadius *uint256.Int + RadiusCacheSize int + NodeDBPath string + NAT nat.Interface + clock mclock.Clock +} + +func DefaultPortalProtocolConfig() *PortalProtocolConfig { + return &PortalProtocolConfig{ + BootstrapNodes: make([]*enode.Node, 0), + ListenAddr: ":9009", + NetRestrict: nil, + RadiusCacheSize: 32 * 1024 * 1024, + NodeDBPath: "", + clock: mclock.System{}, + } +} + +type PortalProtocol struct { + table *discover.Table + + protocolId string + protocolName string + + DiscV5 *discover.UDPv5 + localNode *enode.LocalNode + Log log.Logger + PrivateKey *ecdsa.PrivateKey + NetRestrict *netutil.Netlist + BootstrapNodes []*enode.Node + conn discover.UDPConn + + Utp *PortalUtp + connIdGen libutp.ConnIdGenerator + + validSchemes enr.IdentityScheme + radiusCache *fastcache.Cache + closeCtx context.Context + cancelCloseCtx context.CancelFunc + storage storage.ContentStorage + toContentId func(contentKey []byte) []byte + + contentQueue chan *ContentElement + offerQueue chan *OfferRequestWithNode + + portMappingRegister chan *portMapping + clock mclock.Clock + NAT nat.Interface + + portalMetrics *portalMetrics +} + +func defaultContentIdFunc(contentKey []byte) []byte { + digest := sha256.Sum256(contentKey) + return digest[:] +} + +func NewPortalProtocol(config *PortalProtocolConfig, protocolId portalwire.ProtocolId, privateKey *ecdsa.PrivateKey, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *PortalUtp, storage storage.ContentStorage, contentQueue chan *ContentElement, opts ...PortalProtocolOption) (*PortalProtocol, error) { + closeCtx, cancelCloseCtx := context.WithCancel(context.Background()) + + protocol := &PortalProtocol{ + protocolId: string(protocolId), + protocolName: protocolId.Name(), + Log: log.New("protocol", protocolId.Name()), + PrivateKey: privateKey, + NetRestrict: config.NetRestrict, + BootstrapNodes: config.BootstrapNodes, + radiusCache: fastcache.New(config.RadiusCacheSize), + closeCtx: closeCtx, + cancelCloseCtx: cancelCloseCtx, + localNode: localNode, + validSchemes: enode.ValidSchemes, + storage: storage, + toContentId: defaultContentIdFunc, + contentQueue: contentQueue, + offerQueue: make(chan *OfferRequestWithNode, concurrentOffers), + conn: 
conn, + DiscV5: discV5, + Utp: utp, + NAT: config.NAT, + clock: config.clock, + connIdGen: libutp.NewConnIdGenerator(), + } + + for _, opt := range opts { + opt(protocol) + } + + if metrics.Enabled { + protocol.portalMetrics = newPortalMetrics(protocolId.Name()) + } + + return protocol, nil +} + +func (p *PortalProtocol) Start() error { + p.setupPortMapping() + + err := p.setupDiscV5AndTable() + if err != nil { + return err + } + + p.DiscV5.RegisterTalkHandler(p.protocolId, p.handleTalkRequest) + if p.Utp != nil { + err = p.Utp.Start() + } + if err != nil { + return err + } + + go p.table.Loop() + + for i := 0; i < concurrentOffers; i++ { + go p.offerWorker() + } + + // wait for both initialization processes to complete + p.DiscV5.Table().WaitInit() + p.table.WaitInit() + return nil +} + +func (p *PortalProtocol) Stop() { + p.cancelCloseCtx() + p.table.Close() + p.DiscV5.Close() + if p.Utp != nil { + p.Utp.Stop() + } +} +func (p *PortalProtocol) RoutingTableInfo() [][]string { + return p.table.NodeIds() +} + +func (p *PortalProtocol) AddEnr(n *enode.Node) { + added := p.table.AddInboundNode(n) + if !added { + p.Log.Warn("add node failed", "id", n.ID(), "ip", n.IPAddr()) + return + } + id := n.ID().String() + p.radiusCache.Set([]byte(id), MaxDistance) +} + +func (p *PortalProtocol) Radius() *uint256.Int { + return p.storage.Radius() +} + +func (p *PortalProtocol) setupUDPListening() error { + laddr := p.conn.LocalAddr().(*net.UDPAddr) + p.localNode.SetFallbackUDP(laddr.Port) + p.Log.Debug("UDP listener up", "addr", laddr) + // TODO: NAT + if !laddr.IP.IsLoopback() && !laddr.IP.IsPrivate() { + p.portMappingRegister <- &portMapping{ + protocol: "UDP", + name: "ethereum portal peer discovery", + port: laddr.Port, + } + } + return nil +} + +func (p *PortalProtocol) setupDiscV5AndTable() error { + err := p.setupUDPListening() + if err != nil { + return err + } + + cfg := discover.Config{ + PrivateKey: p.PrivateKey, + NetRestrict: p.NetRestrict, + Bootnodes: p.BootstrapNodes, + Log: p.Log, + } + + p.table, err = discover.NewTable(p, p.localNode.Database(), cfg) + if err != nil { + return err + } + + return nil +} + +func (p *PortalProtocol) Ping(node *enode.Node) (uint64, error) { + pong, err := p.pingInner(node) + if err != nil { + return 0, err + } + + return pong.EnrSeq, nil +} + +func (p *PortalProtocol) pingInner(node *enode.Node) (*portalwire.Pong, error) { + enrSeq := p.Self().Seq() + radiusBytes, err := p.Radius().MarshalSSZ() + if err != nil { + return nil, err + } + customPayload := &portalwire.PingPongCustomData{ + Radius: radiusBytes, + } + + customPayloadBytes, err := customPayload.MarshalSSZ() + if err != nil { + return nil, err + } + + pingRequest := &portalwire.Ping{ + EnrSeq: enrSeq, + CustomPayload: customPayloadBytes, + } + + p.Log.Trace(">> PING/"+p.protocolName, "protocol", p.protocolName, "ip", p.Self().IP().String(), "source", p.Self().ID(), "target", node.ID(), "ping", pingRequest) + if metrics.Enabled { + p.portalMetrics.messagesSentPing.Mark(1) + } + pingRequestBytes, err := pingRequest.MarshalSSZ() + if err != nil { + return nil, err + } + + talkRequestBytes := make([]byte, 0, len(pingRequestBytes)+1) + talkRequestBytes = append(talkRequestBytes, portalwire.PING) + talkRequestBytes = append(talkRequestBytes, pingRequestBytes...) 
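+	// Portal wire messages ride inside discv5 TALKREQ: a one-byte message id
+	// (PING here) followed by the SSZ-encoded request body.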
+ + talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) + + if err != nil { + return nil, err + } + + p.Log.Trace("<< PONG/"+p.protocolName, "source", p.Self().ID(), "target", node.ID(), "res", talkResp) + if metrics.Enabled { + p.portalMetrics.messagesReceivedPong.Mark(1) + } + + return p.processPong(node, talkResp) +} + +func (p *PortalProtocol) findNodes(node *enode.Node, distances []uint) ([]*enode.Node, error) { + if p.localNode.ID().String() == node.ID().String() { + return make([]*enode.Node, 0), nil + } + + distancesBytes := make([][2]byte, len(distances)) + for i, distance := range distances { + copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), uint16(distance))) + } + + findNodes := &portalwire.FindNodes{ + Distances: distancesBytes, + } + + p.Log.Trace(">> FIND_NODES/"+p.protocolName, "id", node.ID(), "findNodes", findNodes) + if metrics.Enabled { + p.portalMetrics.messagesSentFindNodes.Mark(1) + } + findNodesBytes, err := findNodes.MarshalSSZ() + if err != nil { + p.Log.Error("failed to marshal find nodes request", "err", err) + return nil, err + } + + talkRequestBytes := make([]byte, 0, len(findNodesBytes)+1) + talkRequestBytes = append(talkRequestBytes, portalwire.FINDNODES) + talkRequestBytes = append(talkRequestBytes, findNodesBytes...) + + talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) + if err != nil { + p.Log.Error("failed to send find nodes request", "ip", node.IP().String(), "port", node.UDP(), "err", err) + return nil, err + } + + return p.processNodes(node, talkResp, distances) +} + +func (p *PortalProtocol) findContent(node *enode.Node, contentKey []byte) (byte, interface{}, error) { + findContent := &portalwire.FindContent{ + ContentKey: contentKey, + } + + p.Log.Trace(">> FIND_CONTENT/"+p.protocolName, "id", node.ID(), "findContent", findContent) + if metrics.Enabled { + p.portalMetrics.messagesSentFindContent.Mark(1) + } + findContentBytes, err := findContent.MarshalSSZ() + if err != nil { + p.Log.Error("failed to marshal find content request", "err", err) + return 0xff, nil, err + } + + talkRequestBytes := make([]byte, 0, len(findContentBytes)+1) + talkRequestBytes = append(talkRequestBytes, portalwire.FINDCONTENT) + talkRequestBytes = append(talkRequestBytes, findContentBytes...) + + talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) + if err != nil { + p.Log.Error("failed to send find content request", "ip", node.IP().String(), "port", node.UDP(), "err", err) + return 0xff, nil, err + } + + return p.processContent(node, talkResp) +} + +func (p *PortalProtocol) offer(node *enode.Node, offerRequest *OfferRequest) ([]byte, error) { + contentKeys := getContentKeys(offerRequest) + + offer := &portalwire.Offer{ + ContentKeys: contentKeys, + } + + p.Log.Trace(">> OFFER/"+p.protocolName, "offer", offer) + if metrics.Enabled { + p.portalMetrics.messagesSentOffer.Mark(1) + } + offerBytes, err := offer.MarshalSSZ() + if err != nil { + p.Log.Error("failed to marshal offer request", "err", err) + return nil, err + } + + talkRequestBytes := make([]byte, 0, len(offerBytes)+1) + talkRequestBytes = append(talkRequestBytes, portalwire.OFFER) + talkRequestBytes = append(talkRequestBytes, offerBytes...) 
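+	// Same framing as for PING: the OFFER message id byte followed by the
+	// SSZ body. The ACCEPT reply carries a uTP connection id plus a bitlist
+	// of the wanted content keys; processOffer below acts on both.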
+ + talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) + if err != nil { + p.Log.Error("failed to send offer request", "err", err) + return nil, err + } + + return p.processOffer(node, talkResp, offerRequest) +} + +func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request *OfferRequest) ([]byte, error) { + var err error + if len(resp) == 0 { + return nil, ErrEmptyResp + } + if resp[0] != portalwire.ACCEPT { + return nil, fmt.Errorf("invalid accept response") + } + + p.Log.Info("will process Offer", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) + + accept := &portalwire.Accept{} + err = accept.UnmarshalSSZ(resp[1:]) + if err != nil { + return nil, err + } + + p.Log.Trace("<< ACCEPT/"+p.protocolName, "id", target.ID(), "accept", accept) + if metrics.Enabled { + p.portalMetrics.messagesReceivedAccept.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + var contentKeyLen int + if request.Kind == TransientOfferRequestKind { + contentKeyLen = len(request.Request.(*TransientOfferRequest).Contents) + } else { + contentKeyLen = len(request.Request.(*PersistOfferRequest).ContentKeys) + } + + contentKeyBitlist := bitfield.Bitlist(accept.ContentKeys) + if contentKeyBitlist.Len() != uint64(contentKeyLen) { + return nil, fmt.Errorf("accepted content key bitlist has invalid size, expected %d, got %d", contentKeyLen, contentKeyBitlist.Len()) + } + + if contentKeyBitlist.Count() == 0 { + return nil, nil + } + + connId := binary.BigEndian.Uint16(accept.ConnectionId[:]) + go func(ctx context.Context) { + var conn net.Conn + defer func() { + if conn == nil { + return + } + err := conn.Close() + if err != nil { + p.Log.Error("failed to close connection", "err", err) + } + }() + for { + select { + case <-ctx.Done(): + return + default: + contents := make([][]byte, 0, contentKeyBitlist.Count()) + var content []byte + if request.Kind == TransientOfferRequestKind { + for _, index := range contentKeyBitlist.BitIndices() { + content = request.Request.(*TransientOfferRequest).Contents[index].Content + contents = append(contents, content) + } + } else { + for _, index := range contentKeyBitlist.BitIndices() { + contentKey := request.Request.(*PersistOfferRequest).ContentKeys[index] + contentId := p.toContentId(contentKey) + if contentId != nil { + content, err = p.storage.Get(contentKey, contentId) + if err != nil { + p.Log.Error("failed to get content from storage", "err", err) + contents = append(contents, []byte{}) + } else { + contents = append(contents, content) + } + } else { + contents = append(contents, []byte{}) + } + } + } + + var contentsPayload []byte + contentsPayload, err = encodeContents(contents) + if err != nil { + p.Log.Error("failed to encode contents", "err", err) + return + } + + connctx, conncancel := context.WithTimeout(ctx, defaultUTPConnectTimeout) + conn, err = p.Utp.DialWithCid(connctx, target, libutp.ReceConnId(connId).SendId()) + conncancel() + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailConn.Inc(1) + } + p.Log.Error("failed to dial utp connection", "err", err) + return + } + + err = conn.SetWriteDeadline(time.Now().Add(defaultUTPWriteTimeout)) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailDeadline.Inc(1) + } + 
p.Log.Error("failed to set write deadline", "err", err) + return + } + + var written int + written, err = conn.Write(contentsPayload) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailWrite.Inc(1) + } + p.Log.Error("failed to write to utp connection", "err", err) + return + } + p.Log.Trace(">> CONTENT/"+p.protocolName, "id", target.ID(), "contents", contents, "size", written) + if metrics.Enabled { + p.portalMetrics.messagesSentContent.Mark(1) + p.portalMetrics.utpOutSuccess.Inc(1) + } + return + } + } + }(p.closeCtx) + + return accept.ContentKeys, nil +} + +func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, interface{}, error) { + if len(resp) == 0 { + return 0x00, nil, ErrEmptyResp + } + + if resp[0] != portalwire.CONTENT { + return 0xff, nil, fmt.Errorf("invalid content response") + } + + p.Log.Info("will process content", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) + + switch resp[1] { + case portalwire.ContentRawSelector: + content := &portalwire.Content{} + err := content.UnmarshalSSZ(resp[2:]) + if err != nil { + return 0xff, nil, err + } + + p.Log.Trace("<< CONTENT/"+p.protocolName, "id", target.ID(), "content", content) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + return resp[1], content.Content, nil + case portalwire.ContentConnIdSelector: + connIdMsg := &portalwire.ConnectionId{} + err := connIdMsg.UnmarshalSSZ(resp[2:]) + if err != nil { + return 0xff, nil, err + } + + p.Log.Trace("<< CONTENT_CONNECTION_ID/"+p.protocolName, "id", target.ID(), "resp", common.Bytes2Hex(resp), "connIdMsg", connIdMsg) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + connctx, conncancel := context.WithTimeout(p.closeCtx, defaultUTPConnectTimeout) + connId := binary.BigEndian.Uint16(connIdMsg.Id[:]) + conn, err := p.Utp.DialWithCid(connctx, target, libutp.ReceConnId(connId).SendId()) + defer func() { + if conn == nil { + if metrics.Enabled { + p.portalMetrics.utpInFailConn.Inc(1) + } + return + } + err := conn.Close() + if err != nil { + p.Log.Error("failed to close connection", "err", err) + } + }() + conncancel() + if err != nil { + return 0xff, nil, err + } + + err = conn.SetReadDeadline(time.Now().Add(defaultUTPReadTimeout)) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailDeadline.Inc(1) + } + return 0xff, nil, err + } + // Read ALL the data from the connection until EOF and return it + data, err := io.ReadAll(conn) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailRead.Inc(1) + } + p.Log.Error("failed to read from utp connection", "err", err) + return 0xff, nil, err + } + p.Log.Trace("<< CONTENT/"+p.protocolName, "id", target.ID(), "size", len(data), "data", data) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + p.portalMetrics.utpInSuccess.Inc(1) + } + return resp[1], data, nil + case 
portalwire.ContentEnrsSelector: + enrs := &portalwire.Enrs{} + err := enrs.UnmarshalSSZ(resp[2:]) + + if err != nil { + return 0xff, nil, err + } + + p.Log.Trace("<< CONTENT_ENRS/"+p.protocolName, "id", target.ID(), "enrs", enrs) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + nodes := p.filterNodes(target, enrs.Enrs, nil) + return resp[1], nodes, nil + default: + return 0xff, nil, fmt.Errorf("invalid content response") + } +} + +func (p *PortalProtocol) processNodes(target *enode.Node, resp []byte, distances []uint) ([]*enode.Node, error) { + if len(resp) == 0 { + return nil, ErrEmptyResp + } + + if resp[0] != portalwire.NODES { + return nil, fmt.Errorf("invalid nodes response") + } + + nodesResp := &portalwire.Nodes{} + err := nodesResp.UnmarshalSSZ(resp[1:]) + if err != nil { + return nil, err + } + + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + nodes := p.filterNodes(target, nodesResp.Enrs, distances) + + return nodes, nil +} + +func (p *PortalProtocol) filterNodes(target *enode.Node, enrs [][]byte, distances []uint) []*enode.Node { + var ( + nodes []*enode.Node + seen = make(map[enode.ID]struct{}) + err error + verified = 0 + n *enode.Node + ) + + for _, b := range enrs { + record := &enr.Record{} + err = rlp.DecodeBytes(b, record) + if err != nil { + p.Log.Error("Invalid record in nodes response", "id", target.ID(), "err", err) + continue + } + n, err = p.verifyResponseNode(target, record, distances, seen) + if err != nil { + p.Log.Error("Invalid record in nodes response", "id", target.ID(), "err", err) + continue + } + verified++ + nodes = append(nodes, n) + } + + p.Log.Trace("<< NODES/"+p.protocolName, "id", target.ID(), "total", len(enrs), "verified", verified, "nodes", nodes) + if metrics.Enabled { + p.portalMetrics.messagesReceivedNodes.Mark(1) + } + return nodes +} + +func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwire.Pong, error) { + if len(resp) == 0 { + return nil, ErrEmptyResp + } + if resp[0] != portalwire.PONG { + return nil, fmt.Errorf("invalid pong response") + } + pong := &portalwire.Pong{} + err := pong.UnmarshalSSZ(resp[1:]) + if err != nil { + return nil, err + } + + p.Log.Trace("<< PONG_RESPONSE/"+p.protocolName, "id", target.ID(), "pong", pong) + if metrics.Enabled { + p.portalMetrics.messagesReceivedPong.Mark(1) + } + + customPayload := &portalwire.PingPongCustomData{} + err = customPayload.UnmarshalSSZ(pong.CustomPayload) + if err != nil { + return nil, err + } + + p.Log.Trace("<< PONG_RESPONSE/"+p.protocolName, "id", target.ID(), "pong", pong, "customPayload", customPayload) + if metrics.Enabled { + p.portalMetrics.messagesReceivedPong.Mark(1) + } + isAdded := p.table.AddFoundNode(target, true) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) + } + + 
p.radiusCache.Set([]byte(target.ID().String()), customPayload.Radius) + return pong, nil +} + +func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte { + if n := p.DiscV5.GetNode(id); n != nil { + p.table.AddInboundNode(n) + } + + msgCode := msg[0] + + switch msgCode { + case portalwire.PING: + pingRequest := &portalwire.Ping{} + err := pingRequest.UnmarshalSSZ(msg[1:]) + if err != nil { + p.Log.Error("failed to unmarshal ping request", "err", err) + return nil + } + + p.Log.Trace("<< PING/"+p.protocolName, "protocol", p.protocolName, "source", id, "pingRequest", pingRequest) + if metrics.Enabled { + p.portalMetrics.messagesReceivedPing.Mark(1) + } + resp, err := p.handlePing(id, pingRequest) + if err != nil { + p.Log.Error("failed to handle ping request", "err", err) + return nil + } + + return resp + case portalwire.FINDNODES: + findNodesRequest := &portalwire.FindNodes{} + err := findNodesRequest.UnmarshalSSZ(msg[1:]) + if err != nil { + p.Log.Error("failed to unmarshal find nodes request", "err", err) + return nil + } + + p.Log.Trace("<< FIND_NODES/"+p.protocolName, "protocol", p.protocolName, "source", id, "findNodesRequest", findNodesRequest) + if metrics.Enabled { + p.portalMetrics.messagesReceivedFindNodes.Mark(1) + } + resp, err := p.handleFindNodes(addr, findNodesRequest) + if err != nil { + p.Log.Error("failed to handle find nodes request", "err", err) + return nil + } + + return resp + case portalwire.FINDCONTENT: + findContentRequest := &portalwire.FindContent{} + err := findContentRequest.UnmarshalSSZ(msg[1:]) + if err != nil { + p.Log.Error("failed to unmarshal find content request", "err", err) + return nil + } + + p.Log.Trace("<< FIND_CONTENT/"+p.protocolName, "protocol", p.protocolName, "source", id, "findContentRequest", findContentRequest) + if metrics.Enabled { + p.portalMetrics.messagesReceivedFindContent.Mark(1) + } + resp, err := p.handleFindContent(id, addr, findContentRequest) + if err != nil { + p.Log.Error("failed to handle find content request", "err", err) + return nil + } + + return resp + case portalwire.OFFER: + offerRequest := &portalwire.Offer{} + err := offerRequest.UnmarshalSSZ(msg[1:]) + if err != nil { + p.Log.Error("failed to unmarshal offer request", "err", err) + return nil + } + + p.Log.Trace("<< OFFER/"+p.protocolName, "protocol", p.protocolName, "source", id, "offerRequest", offerRequest) + if metrics.Enabled { + p.portalMetrics.messagesReceivedOffer.Mark(1) + } + resp, err := p.handleOffer(id, addr, offerRequest) + if err != nil { + p.Log.Error("failed to handle offer request", "err", err) + return nil + } + + return resp + } + + return nil +} + +func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, error) { + pingCustomPayload := &portalwire.PingPongCustomData{} + err := pingCustomPayload.UnmarshalSSZ(ping.CustomPayload) + if err != nil { + return nil, err + } + + p.radiusCache.Set([]byte(id.String()), pingCustomPayload.Radius) + + enrSeq := p.Self().Seq() + radiusBytes, err := p.Radius().MarshalSSZ() + if err != nil { + return nil, err + } + pongCustomPayload := &portalwire.PingPongCustomData{ + Radius: radiusBytes, + } + + pongCustomPayloadBytes, err := pongCustomPayload.MarshalSSZ() + if err != nil { + return nil, err + } + + pong := &portalwire.Pong{ + EnrSeq: enrSeq, + CustomPayload: pongCustomPayloadBytes, + } + + p.Log.Trace(">> PONG/"+p.protocolName, "protocol", p.protocolName, "source", id, "pong", pong) + if metrics.Enabled { + p.portalMetrics.messagesSentPong.Mark(1) 
+ } + pongBytes, err := pong.MarshalSSZ() + + if err != nil { + return nil, err + } + + talkRespBytes := make([]byte, 0, len(pongBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.PONG) + talkRespBytes = append(talkRespBytes, pongBytes...) + + return talkRespBytes, nil +} + +func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *portalwire.FindNodes) ([]byte, error) { + distances := make([]uint, len(request.Distances)) + for i, distance := range request.Distances { + distances[i] = uint(ssz.UnmarshallUint16(distance[:])) + } + + nodes := p.collectTableNodes(fromAddr.IP, distances, portalFindnodesResultLimit) + + nodesOverhead := 1 + 1 + 4 // msg id + total + container offset + maxPayloadSize := v5wire.MaxPacketSize - talkRespOverhead - nodesOverhead + enrOverhead := 4 //per added ENR, 4 bytes offset overhead + + enrs := p.truncateNodes(nodes, maxPayloadSize, enrOverhead) + + nodesMsg := &portalwire.Nodes{ + Total: 1, + Enrs: enrs, + } + + p.Log.Trace(">> NODES/"+p.protocolName, "protocol", p.protocolName, "source", fromAddr, "nodes", nodesMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentNodes.Mark(1) + } + nodesMsgBytes, err := nodesMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + talkRespBytes := make([]byte, 0, len(nodesMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.NODES) + talkRespBytes = append(talkRespBytes, nodesMsgBytes...) + + return talkRespBytes, nil +} + +func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, request *portalwire.FindContent) ([]byte, error) { + contentOverhead := 1 + 1 // msg id + SSZ Union selector + maxPayloadSize := v5wire.MaxPacketSize - talkRespOverhead - contentOverhead + enrOverhead := 4 //per added ENR, 4 bytes offset overhead + var err error + contentKey := request.ContentKey + contentId := p.toContentId(contentKey) + if contentId == nil { + return nil, ErrNilContentKey + } + + var content []byte + content, err = p.storage.Get(contentKey, contentId) + if err != nil && !errors.Is(err, ContentNotFound) { + return nil, err + } + + if errors.Is(err, ContentNotFound) { + closestNodes := p.findNodesCloseToContent(contentId, portalFindnodesResultLimit) + for i, n := range closestNodes { + if n.ID() == id { + closestNodes = append(closestNodes[:i], closestNodes[i+1:]...) + break + } + } + + enrs := p.truncateNodes(closestNodes, maxPayloadSize, enrOverhead) + // TODO fix when no content and no enrs found + if len(enrs) == 0 { + enrs = nil + } + + enrsMsg := &portalwire.Enrs{ + Enrs: enrs, + } + + p.Log.Trace(">> CONTENT_ENRS/"+p.protocolName, "protocol", p.protocolName, "source", addr, "enrs", enrsMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentContent.Mark(1) + } + var enrsMsgBytes []byte + enrsMsgBytes, err = enrsMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + contentMsgBytes := make([]byte, 0, len(enrsMsgBytes)+1) + contentMsgBytes = append(contentMsgBytes, portalwire.ContentEnrsSelector) + contentMsgBytes = append(contentMsgBytes, enrsMsgBytes...) + + talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.CONTENT) + talkRespBytes = append(talkRespBytes, contentMsgBytes...) 
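+		// CONTENT replies are framed twice: the first byte is the CONTENT
+		// message selector and the second the SSZ union selector, so this
+		// branch sends CONTENT || ContentEnrsSelector || ssz(enrsMsg).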
+ + return talkRespBytes, nil + } else if len(content) <= maxPayloadSize { + rawContentMsg := &portalwire.Content{ + Content: content, + } + + p.Log.Trace(">> CONTENT_RAW/"+p.protocolName, "protocol", p.protocolName, "source", addr, "content", rawContentMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentContent.Mark(1) + } + + var rawContentMsgBytes []byte + rawContentMsgBytes, err = rawContentMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + contentMsgBytes := make([]byte, 0, len(rawContentMsgBytes)+1) + contentMsgBytes = append(contentMsgBytes, portalwire.ContentRawSelector) + contentMsgBytes = append(contentMsgBytes, rawContentMsgBytes...) + + talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.CONTENT) + talkRespBytes = append(talkRespBytes, contentMsgBytes...) + + return talkRespBytes, nil + } else { + connectionId := p.connIdGen.GenCid(id, false) + + go func(bctx context.Context, connId *libutp.ConnId) { + var conn *utp.Conn + var connectCtx context.Context + var cancel context.CancelFunc + defer func() { + p.connIdGen.Remove(connectionId) + if conn == nil { + return + } + err := conn.Close() + if err != nil { + p.Log.Error("failed to close connection", "err", err) + } + }() + for { + select { + case <-bctx.Done(): + return + default: + p.Log.Debug("will accept find content conn from: ", "nodeId", id.String(), "source", addr, "connId", connId) + connectCtx, cancel = context.WithTimeout(bctx, defaultUTPConnectTimeout) + conn, err = p.Utp.AcceptWithCid(connectCtx, id, connectionId) + cancel() + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailConn.Inc(1) + } + p.Log.Error("failed to accept utp connection for handle find content", "connId", connectionId.SendId(), "err", err) + return + } + + err = conn.SetWriteDeadline(time.Now().Add(defaultUTPWriteTimeout)) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailDeadline.Inc(1) + } + p.Log.Error("failed to set write deadline", "err", err) + return + } + + var n int + n, err = conn.Write(content) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpOutFailWrite.Inc(1) + } + p.Log.Error("failed to write content to utp connection", "err", err) + return + } + + if metrics.Enabled { + p.portalMetrics.utpOutSuccess.Inc(1) + } + p.Log.Trace("wrote content size to utp connection", "n", n) + return + } + } + }(p.closeCtx, connectionId) + + idBuffer := make([]byte, 2) + binary.BigEndian.PutUint16(idBuffer, connectionId.SendId()) + connIdMsg := &portalwire.ConnectionId{ + Id: idBuffer, + } + + p.Log.Trace(">> CONTENT_CONNECTION_ID/"+p.protocolName, "protocol", p.protocolName, "source", addr, "connId", connIdMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentContent.Mark(1) + } + var connIdMsgBytes []byte + connIdMsgBytes, err = connIdMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + contentMsgBytes := make([]byte, 0, len(connIdMsgBytes)+1) + contentMsgBytes = append(contentMsgBytes, portalwire.ContentConnIdSelector) + contentMsgBytes = append(contentMsgBytes, connIdMsgBytes...) + + talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.CONTENT) + talkRespBytes = append(talkRespBytes, contentMsgBytes...) 
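+		// Oversized content is handed off to uTP: the reply carries only a
+		// freshly generated connection id, the requester dials back with that
+		// id, and the goroutine above streams the payload until EOF (matched
+		// by io.ReadAll in processContent on the requesting side).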
+ + return talkRespBytes, nil + } +} + +func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *portalwire.Offer) ([]byte, error) { + var err error + contentKeyBitlist := bitfield.NewBitlist(uint64(len(request.ContentKeys))) + if len(p.contentQueue) >= cap(p.contentQueue) { + acceptMsg := &portalwire.Accept{ + ConnectionId: []byte{0, 0}, + ContentKeys: contentKeyBitlist, + } + + p.Log.Trace(">> ACCEPT/"+p.protocolName, "protocol", p.protocolName, "source", addr, "accept", acceptMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentAccept.Mark(1) + } + var acceptMsgBytes []byte + acceptMsgBytes, err = acceptMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) + talkRespBytes = append(talkRespBytes, acceptMsgBytes...) + + return talkRespBytes, nil + } + + contentKeys := make([][]byte, 0) + for i, contentKey := range request.ContentKeys { + contentId := p.toContentId(contentKey) + if contentId != nil { + if inRange(p.Self().ID(), p.Radius(), contentId) { + if _, err = p.storage.Get(contentKey, contentId); err != nil { + contentKeyBitlist.SetBitAt(uint64(i), true) + contentKeys = append(contentKeys, contentKey) + } + } + } else { + return nil, ErrNilContentKey + } + } + + idBuffer := make([]byte, 2) + if contentKeyBitlist.Count() != 0 { + connectionId := p.connIdGen.GenCid(id, false) + + go func(bctx context.Context, connId *libutp.ConnId) { + var conn *utp.Conn + var connectCtx context.Context + var cancel context.CancelFunc + defer func() { + p.connIdGen.Remove(connectionId) + if conn == nil { + return + } + err := conn.Close() + if err != nil { + p.Log.Error("failed to close connection", "err", err) + } + }() + for { + select { + case <-bctx.Done(): + return + default: + p.Log.Debug("will accept offer conn from: ", "source", addr, "connId", connId) + connectCtx, cancel = context.WithTimeout(bctx, defaultUTPConnectTimeout) + conn, err = p.Utp.AcceptWithCid(connectCtx, id, connectionId) + cancel() + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailConn.Inc(1) + } + p.Log.Error("failed to accept utp connection for handle offer", "connId", connectionId.SendId(), "err", err) + return + } + + err = conn.SetReadDeadline(time.Now().Add(defaultUTPReadTimeout)) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailDeadline.Inc(1) + } + p.Log.Error("failed to set read deadline", "err", err) + return + } + // Read ALL the data from the connection until EOF and return it + var data []byte + data, err = io.ReadAll(conn) + if err != nil { + if metrics.Enabled { + p.portalMetrics.utpInFailRead.Inc(1) + } + p.Log.Error("failed to read from utp connection", "err", err) + return + } + p.Log.Trace("<< OFFER_CONTENT/"+p.protocolName, "id", id, "size", len(data), "data", data) + if metrics.Enabled { + p.portalMetrics.messagesReceivedContent.Mark(1) + } + + err = p.handleOfferedContents(id, contentKeys, data) + if err != nil { + p.Log.Error("failed to handle offered Contents", "err", err) + return + } + + if metrics.Enabled { + p.portalMetrics.utpInSuccess.Inc(1) + } + return + } + } + }(p.closeCtx, connectionId) + + binary.BigEndian.PutUint16(idBuffer, connectionId.SendId()) + } else { + binary.BigEndian.PutUint16(idBuffer, uint16(0)) + } + + acceptMsg := &portalwire.Accept{ + ConnectionId: idBuffer, + ContentKeys: []byte(contentKeyBitlist), + } + + p.Log.Trace(">> ACCEPT/"+p.protocolName, "protocol", p.protocolName, "source", addr, 
"accept", acceptMsg) + if metrics.Enabled { + p.portalMetrics.messagesSentAccept.Mark(1) + } + var acceptMsgBytes []byte + acceptMsgBytes, err = acceptMsg.MarshalSSZ() + if err != nil { + return nil, err + } + + talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) + talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) + talkRespBytes = append(talkRespBytes, acceptMsgBytes...) + + return talkRespBytes, nil +} + +func (p *PortalProtocol) handleOfferedContents(id enode.ID, keys [][]byte, payload []byte) error { + contents, err := decodeContents(payload) + if err != nil { + if metrics.Enabled { + p.portalMetrics.contentDecodedFalse.Inc(1) + } + return err + } + + keyLen := len(keys) + contentLen := len(contents) + if keyLen != contentLen { + if metrics.Enabled { + p.portalMetrics.contentDecodedFalse.Inc(1) + } + return fmt.Errorf("content keys len %d doesn't match content values len %d", keyLen, contentLen) + } + + contentElement := &ContentElement{ + Node: id, + ContentKeys: keys, + Contents: contents, + } + + p.contentQueue <- contentElement + + if metrics.Enabled { + p.portalMetrics.contentDecodedTrue.Inc(1) + } + return nil +} + +func (p *PortalProtocol) Self() *enode.Node { + return p.localNode.Node() +} + +func (p *PortalProtocol) RequestENR(n *enode.Node) (*enode.Node, error) { + nodes, err := p.findNodes(n, []uint{0}) + if err != nil { + return nil, err + } + if len(nodes) != 1 { + return nil, fmt.Errorf("%d nodes in response for distance zero", len(nodes)) + } + return nodes[0], nil +} + +func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, distances []uint, seen map[enode.ID]struct{}) (*enode.Node, error) { + n, err := enode.New(p.validSchemes, r) + if err != nil { + return nil, err + } + if err = netutil.CheckRelayIP(sender.IP(), n.IP()); err != nil { + return nil, err + } + if p.NetRestrict != nil && !p.NetRestrict.Contains(n.IP()) { + return nil, errors.New("not contained in netrestrict list") + } + if n.UDP() <= 1024 { + return nil, discover.ErrLowPort + } + if distances != nil { + nd := enode.LogDist(sender.ID(), n.ID()) + if !slices.Contains(distances, uint(nd)) { + return nil, errors.New("does not match any requested distance") + } + } + if _, ok := seen[n.ID()]; ok { + return nil, fmt.Errorf("duplicate record") + } + seen[n.ID()] = struct{}{} + return n, nil +} + +// LookupRandom looks up a random target. +// This is needed to satisfy the transport interface. +func (p *PortalProtocol) LookupRandom() []*enode.Node { + return p.newRandomLookup(p.closeCtx).Run() +} + +// LookupSelf looks up our own node ID. +// This is needed to satisfy the transport interface. +func (p *PortalProtocol) LookupSelf() []*enode.Node { + return p.newLookup(p.closeCtx, p.Self().ID()).Run() +} + +func (p *PortalProtocol) newRandomLookup(ctx context.Context) *discover.Lookup { + var target enode.ID + _, _ = crand.Read(target[:]) + return p.newLookup(ctx, target) +} + +func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *discover.Lookup { + return discover.NewLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { + return p.lookupWorker(n, target) + }) +} + +// lookupWorker performs FINDNODE calls against a single node during lookup. 
+func (p *PortalProtocol) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { + var ( + dists = discover.LookupDistances(target, destNode.ID()) + nodes = discover.NodesByDistance{Target: target} + err error + ) + var r []*enode.Node + + r, err = p.findNodes(destNode, dists) + if errors.Is(err, discover.ErrClosed) { + return nil, err + } + for _, n := range r { + if n.ID() != p.Self().ID() { + isAdded := p.table.AddFoundNode(n, false) + if isAdded { + log.Debug("Node added to bucket", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) + } else { + log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) + } + nodes.Push(n, portalFindnodesResultLimit) + } + } + return nodes.Entries, err +} + +func (p *PortalProtocol) offerWorker() { + for { + select { + case <-p.closeCtx.Done(): + return + case offerRequestWithNode := <-p.offerQueue: + p.Log.Trace("offerWorker", "offerRequestWithNode", offerRequestWithNode) + _, err := p.offer(offerRequestWithNode.Node, offerRequestWithNode.Request) + if err != nil { + p.Log.Error("failed to offer", "err", err) + } + } + } +} + +func (p *PortalProtocol) truncateNodes(nodes []*enode.Node, maxSize int, enrOverhead int) [][]byte { + res := make([][]byte, 0) + totalSize := 0 + for _, n := range nodes { + enrBytes, err := rlp.EncodeToBytes(n.Record()) + if err != nil { + p.Log.Error("failed to encode n", "err", err) + continue + } + + if totalSize+len(enrBytes)+enrOverhead > maxSize { + break + } else { + res = append(res, enrBytes) + totalSize += len(enrBytes) + } + } + return res +} + +func (p *PortalProtocol) findNodesCloseToContent(contentId []byte, limit int) []*enode.Node { + allNodes := p.table.NodeList() + sort.Slice(allNodes, func(i, j int) bool { + return enode.LogDist(allNodes[i].ID(), enode.ID(contentId)) < enode.LogDist(allNodes[j].ID(), enode.ID(contentId)) + }) + + if len(allNodes) > limit { + allNodes = allNodes[:limit] + } else { + allNodes = allNodes[:] + } + + return allNodes +} + +// Lookup performs a recursive lookup for the given target. +// It returns the closest nodes to target. +func (p *PortalProtocol) Lookup(target enode.ID) []*enode.Node { + return p.newLookup(p.closeCtx, target).Run() +} + +// Resolve searches for a specific Node with the given ID and tries to get the most recent +// version of the Node record for it. It returns n if the Node could not be resolved. +func (p *PortalProtocol) Resolve(n *enode.Node) *enode.Node { + if intable := p.table.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { + n = intable + } + // Try asking directly. This works if the Node is still responding on the endpoint we have. + if resp, err := p.RequestENR(n); err == nil { + return resp + } + // Otherwise do a network lookup. + result := p.Lookup(n.ID()) + for _, rn := range result { + if rn.ID() == n.ID() && rn.Seq() > n.Seq() { + return rn + } + } + return n +} + +// ResolveNodeId searches for a specific Node with the given ID. +// It returns nil if the nodeId could not be resolved. +func (p *PortalProtocol) ResolveNodeId(id enode.ID) *enode.Node { + if id == p.Self().ID() { + p.Log.Debug("Resolve Self Id", "id", id.String()) + return p.Self() + } + + n := p.table.GetNode(id) + if n != nil { + p.Log.Debug("found Id in table and will request enr from the node", "id", id.String()) + // Try asking directly. This works if the Node is still responding on the endpoint we have. 
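+		// If the direct request fails, fall through to a recursive Lookup
+		// below and return the freshest record found for this id.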
+		if resp, err := p.RequestENR(n); err == nil {
+			return resp
+		}
+	}
+
+	// Otherwise do a network lookup.
+	result := p.Lookup(id)
+	for _, rn := range result {
+		if rn.ID() == id {
+			if n != nil && rn.Seq() <= n.Seq() {
+				return n
+			} else {
+				return rn
+			}
+		}
+	}
+
+	return n
+}
+
+func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node {
+	var bn []*enode.Node
+	var nodes []*enode.Node
+	var processed = make(map[uint]struct{})
+	for _, dist := range distances {
+		// Reject duplicate / invalid distances.
+		_, seen := processed[dist]
+		if seen || dist > 256 {
+			continue
+		}
+		processed[dist] = struct{}{}
+
+		checkLive := !p.table.Config().NoFindnodeLivenessCheck
+		for _, n := range p.table.AppendBucketNodes(dist, bn[:0], checkLive) {
+			// Apply some pre-checks to avoid sending invalid nodes.
+			// Note liveness is checked by appendLiveNodes.
+			if netutil.CheckRelayIP(rip, n.IP()) != nil {
+				continue
+			}
+			nodes = append(nodes, n)
+			if len(nodes) >= limit {
+				return nodes
+			}
+		}
+	}
+	return nodes
+}
+
+func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bool, error) {
+	lookupContext, cancel := context.WithCancel(context.Background())
+
+	resChan := make(chan *traceContentInfoResp, discover.Alpha)
+	hasResult := int32(0)
+
+	result := ContentInfoResp{}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	go func() {
+		defer wg.Done()
+		for res := range resChan {
+			if res.Flag != portalwire.ContentEnrsSelector {
+				result.Content = res.Content.([]byte)
+				result.UtpTransfer = res.UtpTransfer
+			}
+		}
+	}()
+
+	discover.NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) {
+		return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult)
+	}).Run()
+	close(resChan)
+
+	wg.Wait()
+	if hasResult == 1 {
+		return result.Content, result.UtpTransfer, nil
+	}
+	defer cancel()
+	return nil, false, ContentNotFound
+}
+
+func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*TraceContentResult, error) {
+	lookupContext, cancel := context.WithCancel(context.Background())
+	// resp channel
+	resChan := make(chan *traceContentInfoResp, discover.Alpha)
+
+	hasResult := int32(0)
+
+	traceContentRes := &TraceContentResult{}
+
+	selfHexId := "0x" + p.Self().ID().String()
+
+	trace := &Trace{
+		Origin:      selfHexId,
+		TargetId:    hexutil.Encode(contentId),
+		StartedAtMs: int(time.Now().UnixMilli()),
+		Responses:   make(map[string]RespByNode),
+		Metadata:    make(map[string]*NodeMetadata),
+		Cancelled:   make([]string, 0),
+	}
+
+	nodes := p.table.FindnodeByID(enode.ID(contentId), discover.BucketSize, false)
+
+	localResponse := make([]string, 0, len(nodes.Entries))
+	for _, node := range nodes.Entries {
+		id := "0x" + node.ID().String()
+		localResponse = append(localResponse, id)
+	}
+	trace.Responses[selfHexId] = RespByNode{
+		DurationMs:    0,
+		RespondedWith: localResponse,
+	}
+
+	dis := p.Distance(p.Self().ID(), enode.ID(contentId))
+
+	trace.Metadata[selfHexId] = &NodeMetadata{
+		Enr:      p.Self().String(),
+		Distance: hexutil.Encode(dis[:]),
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+
+	go func() {
+		defer wg.Done()
+		for res := range resChan {
+			node := res.Node
+			hexId := "0x" + node.ID().String()
+			dis := p.Distance(node.ID(), enode.ID(contentId))
+			p.Log.Debug("receive res", "id", hexId, "flag", res.Flag)
+			trace.Metadata[hexId] = &NodeMetadata{
+				Enr:      node.String(),
+				Distance: hexutil.Encode(dis[:]),
+			}
+			// no content received yet
+			if traceContentRes.Content == "" {
+				if res.Flag == portalwire.ContentRawSelector || res.Flag == portalwire.ContentConnIdSelector {
+					trace.ReceivedFrom = hexId
+					content := res.Content.([]byte)
+					traceContentRes.Content = hexutil.Encode(content)
+					traceContentRes.UtpTransfer = res.UtpTransfer
+					trace.Responses[hexId] = RespByNode{}
+				} else {
+					nodes := res.Content.([]*enode.Node)
+					respByNode := RespByNode{
+						RespondedWith: make([]string, 0, len(nodes)),
+					}
+					for _, node := range nodes {
+						idInner := "0x" + node.ID().String()
+						respByNode.RespondedWith = append(respByNode.RespondedWith, idInner)
+						if _, ok := trace.Metadata[idInner]; !ok {
+							dis := p.Distance(node.ID(), enode.ID(contentId))
+							trace.Metadata[idInner] = &NodeMetadata{
+								Enr:      node.String(),
+								Distance: hexutil.Encode(dis[:]),
+							}
+						}
+						trace.Responses[hexId] = respByNode
+					}
+				}
+			} else {
+				trace.Cancelled = append(trace.Cancelled, hexId)
+			}
+		}
+	}()
+
+	lookup := discover.NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) {
+		return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult)
+	})
+	lookup.Run()
+	close(resChan)
+
+	wg.Wait()
+	if hasResult == 0 {
+		cancel()
+	}
+	traceContentRes.Trace = *trace
+
+	return traceContentRes, nil
+}
+
+func (p *PortalProtocol) contentLookupWorker(n *enode.Node, contentKey []byte, resChan chan<- *traceContentInfoResp, cancel context.CancelFunc, done *int32) ([]*enode.Node, error) {
+	wrappedNode := make([]*enode.Node, 0)
+	flag, content, err := p.findContent(n, contentKey)
+	if err != nil {
+		return nil, err
+	}
+	p.Log.Debug("contentLookupWorker received response", "ip", n.IP().String(), "flag", flag)
+
+	switch flag {
+	case portalwire.ContentRawSelector, portalwire.ContentConnIdSelector:
+		content, ok := content.([]byte)
+		if !ok {
+			return wrappedNode, fmt.Errorf("failed to assert to raw content, value is: %v", content)
+		}
+		res := &traceContentInfoResp{
+			Node:        n,
+			Flag:        flag,
+			Content:     content,
+			UtpTransfer: false,
+		}
+		if flag == portalwire.ContentConnIdSelector {
+			res.UtpTransfer = true
+		}
+		if atomic.CompareAndSwapInt32(done, 0, 1) {
+			p.Log.Debug("contentLookupWorker find content", "ip", n.IP().String(), "port", n.UDP())
+			resChan <- res
+			cancel()
+		}
+		return wrappedNode, err
+	case portalwire.ContentEnrsSelector:
+		nodes, ok := content.([]*enode.Node)
+		if !ok {
+			return wrappedNode, fmt.Errorf("failed to assert to enrs content, value is: %v", content)
+		}
+		resChan <- &traceContentInfoResp{
+			Node:        n,
+			Flag:        flag,
+			Content:     content,
+			UtpTransfer: false,
+		}
+		return nodes, nil
+	}
+	return wrappedNode, nil
+}
+
+func (p *PortalProtocol) ToContentId(contentKey []byte) []byte {
+	return p.toContentId(contentKey)
+}
+
+func (p *PortalProtocol) InRange(contentId []byte) bool {
+	return inRange(p.Self().ID(), p.Radius(), contentId)
+}
+
+func (p *PortalProtocol) Get(contentKey []byte, contentId []byte) ([]byte, error) {
+	content, err := p.storage.Get(contentKey, contentId)
+	p.Log.Trace("get local storage", "contentId", hexutil.Encode(contentId), "content", hexutil.Encode(content), "err", err)
+	return content, err
+}
+
+func (p *PortalProtocol) Put(contentKey []byte, contentId []byte, content []byte) error {
+	err := p.storage.Put(contentKey, contentId, content)
+	p.Log.Trace("put local storage", "contentId", hexutil.Encode(contentId), "content", hexutil.Encode(content), "err", err)
+	return err
+}
+
+func (p *PortalProtocol) GetContent() chan *ContentElement {
+	return p.contentQueue
+}
+
+func (p *PortalProtocol) Gossip(srcNodeId *enode.ID, contentKeys
[][]byte, content [][]byte) (int, error) { + if len(content) == 0 { + return 0, errors.New("empty content") + } + + contentList := make([]*ContentEntry, 0, portalwire.ContentKeysLimit) + for i := 0; i < len(content); i++ { + contentEntry := &ContentEntry{ + ContentKey: contentKeys[i], + Content: content[i], + } + contentList = append(contentList, contentEntry) + } + + contentId := p.toContentId(contentKeys[0]) + if contentId == nil { + return 0, ErrNilContentKey + } + + maxClosestNodes := 4 + maxFartherNodes := 4 + closestLocalNodes := p.findNodesCloseToContent(contentId, 32) + p.Log.Debug("closest local nodes", "count", len(closestLocalNodes)) + + gossipNodes := make([]*enode.Node, 0) + for _, n := range closestLocalNodes { + radius, found := p.radiusCache.HasGet(nil, []byte(n.ID().String())) + if found { + p.Log.Debug("found closest local nodes", "nodeId", n.ID(), "addr", n.IPAddr().String()) + nodeRadius := new(uint256.Int) + err := nodeRadius.UnmarshalSSZ(radius) + if err != nil { + return 0, err + } + if inRange(n.ID(), nodeRadius, contentId) { + if srcNodeId == nil { + gossipNodes = append(gossipNodes, n) + } else if n.ID() != *srcNodeId { + gossipNodes = append(gossipNodes, n) + } + } + } + } + + if len(gossipNodes) == 0 { + return 0, nil + } + + var finalGossipNodes []*enode.Node + if len(gossipNodes) > maxClosestNodes { + fartherNodes := gossipNodes[maxClosestNodes:] + rand.Shuffle(len(fartherNodes), func(i, j int) { + fartherNodes[i], fartherNodes[j] = fartherNodes[j], fartherNodes[i] + }) + finalGossipNodes = append(gossipNodes[:maxClosestNodes], fartherNodes[:min(maxFartherNodes, len(fartherNodes))]...) + } else { + finalGossipNodes = gossipNodes + } + + for _, n := range finalGossipNodes { + transientOfferRequest := &TransientOfferRequest{ + Contents: contentList, + } + + offerRequest := &OfferRequest{ + Kind: TransientOfferRequestKind, + Request: transientOfferRequest, + } + + offerRequestWithNode := &OfferRequestWithNode{ + Node: n, + Request: offerRequest, + } + p.offerQueue <- offerRequestWithNode + } + + return len(finalGossipNodes), nil +} + +func (p *PortalProtocol) Distance(a, b enode.ID) enode.ID { + res := [32]byte{} + for i := range a { + res[i] = a[i] ^ b[i] + } + return res +} + +func inRange(nodeId enode.ID, nodeRadius *uint256.Int, contentId []byte) bool { + distance := enode.LogDist(nodeId, enode.ID(contentId)) + disBig := new(big.Int).SetInt64(int64(distance)) + return nodeRadius.CmpBig(disBig) > 0 +} + +func encodeContents(contents [][]byte) ([]byte, error) { + contentsBytes := make([]byte, 0) + for _, content := range contents { + contentLen := len(content) + contentLenBytes := leb128.EncodeUint32(uint32(contentLen)) + contentsBytes = append(contentsBytes, contentLenBytes...) + contentsBytes = append(contentsBytes, content...) 
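+		// Each item is framed as <LEB128 length><content bytes>: a 3-byte
+		// payload encodes as 0x03 || payload, while a 300-byte payload gets
+		// the two-byte prefix 0xac 0x02. decodeContents reverses this until
+		// it hits EOF.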
+ } + + return contentsBytes, nil +} + +func decodeContents(payload []byte) ([][]byte, error) { + contents := make([][]byte, 0) + buffer := bytes.NewBuffer(payload) + + for { + contentLen, contentLenLen, err := leb128.DecodeUint32(bytes.NewReader(buffer.Bytes())) + if err != nil { + if errors.Is(err, io.EOF) { + return contents, nil + } + return nil, err + } + + buffer.Next(int(contentLenLen)) + + content := make([]byte, contentLen) + _, err = buffer.Read(content) + if err != nil { + if errors.Is(err, io.EOF) { + return contents, nil + } + return nil, err + } + + contents = append(contents, content) + } +} + +func getContentKeys(request *OfferRequest) [][]byte { + if request.Kind == TransientOfferRequestKind { + contentKeys := make([][]byte, 0) + contents := request.Request.(*TransientOfferRequest).Contents + for _, content := range contents { + contentKeys = append(contentKeys, content.ContentKey) + } + + return contentKeys + } else { + return request.Request.(*PersistOfferRequest).ContentKeys + } +} diff --git a/portalnetwork/portal_protocol_metrics.go b/portalnetwork/portal_protocol_metrics.go new file mode 100644 index 000000000000..343d3f4f00f3 --- /dev/null +++ b/portalnetwork/portal_protocol_metrics.go @@ -0,0 +1,67 @@ +package portalnetwork + +import "github.com/ethereum/go-ethereum/metrics" + +type portalMetrics struct { + messagesReceivedAccept metrics.Meter + messagesReceivedNodes metrics.Meter + messagesReceivedFindNodes metrics.Meter + messagesReceivedFindContent metrics.Meter + messagesReceivedContent metrics.Meter + messagesReceivedOffer metrics.Meter + messagesReceivedPing metrics.Meter + messagesReceivedPong metrics.Meter + + messagesSentAccept metrics.Meter + messagesSentNodes metrics.Meter + messagesSentFindNodes metrics.Meter + messagesSentFindContent metrics.Meter + messagesSentContent metrics.Meter + messagesSentOffer metrics.Meter + messagesSentPing metrics.Meter + messagesSentPong metrics.Meter + + utpInFailConn metrics.Counter + utpInFailRead metrics.Counter + utpInFailDeadline metrics.Counter + utpInSuccess metrics.Counter + + utpOutFailConn metrics.Counter + utpOutFailWrite metrics.Counter + utpOutFailDeadline metrics.Counter + utpOutSuccess metrics.Counter + + contentDecodedTrue metrics.Counter + contentDecodedFalse metrics.Counter +} + +func newPortalMetrics(protocolName string) *portalMetrics { + return &portalMetrics{ + messagesReceivedAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/accept", nil), + messagesReceivedNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/nodes", nil), + messagesReceivedFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_nodes", nil), + messagesReceivedFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_content", nil), + messagesReceivedContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/content", nil), + messagesReceivedOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/offer", nil), + messagesReceivedPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/ping", nil), + messagesReceivedPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/pong", nil), + messagesSentAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/accept", nil), + messagesSentNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/nodes", nil), + messagesSentFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_nodes", nil), + messagesSentFindContent: 
metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_content", nil), + messagesSentContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/content", nil), + messagesSentOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/offer", nil), + messagesSentPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/ping", nil), + messagesSentPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/pong", nil), + utpInFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_conn", nil), + utpInFailRead: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_read", nil), + utpInFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_deadline", nil), + utpInSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/success", nil), + utpOutFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_conn", nil), + utpOutFailWrite: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_write", nil), + utpOutFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_deadline", nil), + utpOutSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/success", nil), + contentDecodedTrue: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/true", nil), + contentDecodedFalse: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/false", nil), + } +} diff --git a/portalnetwork/portal_protocol_test.go b/portalnetwork/portal_protocol_test.go new file mode 100644 index 000000000000..fcc79e9d4f5f --- /dev/null +++ b/portalnetwork/portal_protocol_test.go @@ -0,0 +1,514 @@ +package portalnetwork + +import ( + "context" + "crypto/ecdsa" + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" + "github.com/ethereum/go-ethereum/portalnetwork/storage" + "github.com/optimism-java/utp-go" + "github.com/optimism-java/utp-go/libutp" + "github.com/prysmaticlabs/go-bitfield" + "golang.org/x/exp/slices" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + assert "github.com/stretchr/testify/require" +) + +func setupLocalPortalNode(addr string, bootNodes []*enode.Node) (*PortalProtocol, error) { + conf := DefaultPortalProtocolConfig() + conf.NAT = nil + if addr != "" { + conf.ListenAddr = addr + } + if bootNodes != nil { + conf.BootstrapNodes = bootNodes + } + + addr1, err := net.ResolveUDPAddr("udp", conf.ListenAddr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", addr1) + if err != nil { + return nil, err + } + + privKey := newkey() + + discCfg := discover.Config{ + PrivateKey: privKey, + NetRestrict: conf.NetRestrict, + Bootnodes: conf.BootstrapNodes, + } + + nodeDB, err := enode.OpenDB(conf.NodeDBPath) + if err != nil { + return nil, err + } + + localNode := enode.NewLocalNode(nodeDB, privKey) + localNode.SetFallbackIP(net.IP{127, 0, 0, 1}) + localNode.Set(Tag) + + if conf.NAT == nil { + var addrs []net.Addr + addrs, err = net.InterfaceAddrs() + + if err != nil { + return nil, err + } + + for _, address := range addrs { + // check ip addr is loopback addr + if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { + if ipnet.IP.To4() 
!= nil { + localNode.SetStaticIP(ipnet.IP) + break + } + } + } + } + + discV5, err := discover.ListenV5(conn, localNode, discCfg) + if err != nil { + return nil, err + } + utpSocket := NewPortalUtp(context.Background(), conf, discV5, conn) + + contentQueue := make(chan *ContentElement, 50) + portalProtocol, err := NewPortalProtocol( + conf, + portalwire.History, + privKey, + conn, + localNode, + discV5, + utpSocket, + &storage.MockStorage{Db: make(map[string][]byte)}, + contentQueue) + if err != nil { + return nil, err + } + + return portalProtocol, nil +} + +func TestPortalWireProtocolUdp(t *testing.T) { + node1, err := setupLocalPortalNode(":8777", nil) + assert.NoError(t, err) + node1.Log = testlog.Logger(t, log.LvlTrace) + err = node1.Start() + assert.NoError(t, err) + + node2, err := setupLocalPortalNode(":8778", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node2.Log = testlog.Logger(t, log.LvlTrace) + err = node2.Start() + assert.NoError(t, err) + time.Sleep(12 * time.Second) + + node3, err := setupLocalPortalNode(":8779", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node3.Log = testlog.Logger(t, log.LvlTrace) + err = node3.Start() + assert.NoError(t, err) + time.Sleep(12 * time.Second) + + cid1 := libutp.ReceConnId(12) + cid2 := libutp.ReceConnId(116) + cliSendMsgWithCid1 := "there are connection id : 12!" + cliSendMsgWithCid2 := "there are connection id: 116!" + + serverEchoWithCid := "accept connection sends back msg: echo" + + largeTestContent := make([]byte, 1199) + _, err = rand.Read(largeTestContent) + assert.NoError(t, err) + + var workGroup sync.WaitGroup + var acceptGroup sync.WaitGroup + workGroup.Add(4) + acceptGroup.Add(1) + go func() { + var acceptConn *utp.Conn + defer func() { + workGroup.Done() + _ = acceptConn.Close() + }() + acceptConn, err := node1.Utp.AcceptWithCid(context.Background(), node2.localNode.ID(), cid1) + if err != nil { + panic(err) + } + acceptGroup.Done() + buf := make([]byte, 100) + n, err := acceptConn.Read(buf) + if err != nil && err != io.EOF { + panic(err) + } + assert.Equal(t, cliSendMsgWithCid1, string(buf[:n])) + _, err = acceptConn.Write([]byte(serverEchoWithCid)) + if err != nil { + panic(err) + } + }() + go func() { + var connId2Conn net.Conn + defer func() { + workGroup.Done() + _ = connId2Conn.Close() + }() + connId2Conn, err := node1.Utp.AcceptWithCid(context.Background(), node2.localNode.ID(), cid2) + if err != nil { + panic(err) + } + buf := make([]byte, 100) + n, err := connId2Conn.Read(buf) + if err != nil && err != io.EOF { + panic(err) + } + assert.Equal(t, cliSendMsgWithCid2, string(buf[:n])) + + _, err = connId2Conn.Write(largeTestContent) + if err != nil { + panic(err) + } + }() + + go func() { + var connWithConnId net.Conn + defer func() { + workGroup.Done() + if connWithConnId != nil { + _ = connWithConnId.Close() + } + }() + connWithConnId, err = node2.Utp.DialWithCid(context.Background(), node1.localNode.Node(), cid1.SendId()) + if err != nil { + panic(err) + } + _, err = connWithConnId.Write([]byte(cliSendMsgWithCid1)) + if err != nil && err != io.EOF { + panic(err) + } + buf := make([]byte, 100) + n, err := connWithConnId.Read(buf) + if err != nil && err != io.EOF { + panic(err) + } + assert.Equal(t, serverEchoWithCid, string(buf[:n])) + }() + go func() { + var ConnId2Conn net.Conn + defer func() { + workGroup.Done() + if ConnId2Conn != nil { + _ = ConnId2Conn.Close() + } + }() + ConnId2Conn, err = node2.Utp.DialWithCid(context.Background(), node1.localNode.Node(), cid2.SendId()) 
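+		// The dialer uses the send side of the acceptor's connection id pair:
+		// node1 accepts on cid2 (recv id 116) while node2 dials with
+		// cid2.SendId(), so both endpoints land on the same uTP stream.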
+		if err != nil && err != io.EOF {
+			panic(err)
+		}
+		_, err = ConnId2Conn.Write([]byte(cliSendMsgWithCid2))
+		if err != nil {
+			panic(err)
+		}
+
+		data := make([]byte, 0)
+		buf := make([]byte, 1024)
+		for {
+			var n int
+			n, err = ConnId2Conn.Read(buf)
+			if err != nil {
+				if errors.Is(err, io.EOF) {
+					break
+				}
+			}
+			data = append(data, buf[:n]...)
+		}
+		assert.Equal(t, largeTestContent, data)
+	}()
+	workGroup.Wait()
+	node1.Stop()
+	node2.Stop()
+	node3.Stop()
+}
+
+func TestPortalWireProtocol(t *testing.T) {
+	node1, err := setupLocalPortalNode(":7777", nil)
+	assert.NoError(t, err)
+	node1.Log = testlog.Logger(t, log.LevelDebug)
+	err = node1.Start()
+	assert.NoError(t, err)
+
+	node2, err := setupLocalPortalNode(":7778", []*enode.Node{node1.localNode.Node()})
+	assert.NoError(t, err)
+	node2.Log = testlog.Logger(t, log.LevelDebug)
+	err = node2.Start()
+	assert.NoError(t, err)
+
+	time.Sleep(12 * time.Second)
+
+	node3, err := setupLocalPortalNode(":7779", []*enode.Node{node1.localNode.Node()})
+	assert.NoError(t, err)
+	node3.Log = testlog.Logger(t, log.LevelDebug)
+	err = node3.Start()
+	assert.NoError(t, err)
+
+	time.Sleep(12 * time.Second)
+
+	// all three nodes should have discovered each other by now
+	assert.True(t, slices.ContainsFunc(node1.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node2.localNode.Node().ID()
+	}))
+	assert.True(t, slices.ContainsFunc(node1.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node3.localNode.Node().ID()
+	}))
+
+	assert.True(t, slices.ContainsFunc(node2.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node1.localNode.Node().ID()
+	}))
+	assert.True(t, slices.ContainsFunc(node2.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node3.localNode.Node().ID()
+	}))
+
+	assert.True(t, slices.ContainsFunc(node3.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node1.localNode.Node().ID()
+	}))
+	assert.True(t, slices.ContainsFunc(node3.table.NodeList(), func(n *enode.Node) bool {
+		return n.ID() == node2.localNode.Node().ID()
+	}))
+
+	err = node1.storage.Put(nil, node1.toContentId([]byte("test_key")), []byte("test_value"))
+	assert.NoError(t, err)
+
+	flag, content, err := node2.findContent(node1.localNode.Node(), []byte("test_key"))
+	assert.NoError(t, err)
+	assert.Equal(t, portalwire.ContentRawSelector, flag)
+	assert.Equal(t, []byte("test_value"), content)
+
+	flag, content, err = node2.findContent(node3.localNode.Node(), []byte("test_key"))
+	assert.NoError(t, err)
+	assert.Equal(t, portalwire.ContentEnrsSelector, flag)
+	assert.Equal(t, 1, len(content.([]*enode.Node)))
+	assert.Equal(t, node1.localNode.Node().ID(), content.([]*enode.Node)[0].ID())
+
+	// create a 2000-byte slice filled with random data; it is deliberately
+	// larger than the inline payload budget, so the transfer below must fall
+	// back to uTP (ContentConnIdSelector)
+	largeTestContent := make([]byte, 2000)
+	_, err = rand.Read(largeTestContent)
+	assert.NoError(t, err)
+
+	err = node1.storage.Put(nil, node1.toContentId([]byte("large_test_key")), largeTestContent)
+	assert.NoError(t, err)
+
+	flag, content, err = node2.findContent(node1.localNode.Node(), []byte("large_test_key"))
+	assert.NoError(t, err)
+	assert.Equal(t, largeTestContent, content)
+	assert.Equal(t, portalwire.ContentConnIdSelector, flag)
+
+	testEntry1 := &ContentEntry{
+		ContentKey: []byte("test_entry1"),
+		Content:    []byte("test_entry1_content"),
+	}
+
+	testEntry2 := &ContentEntry{
+		ContentKey: []byte("test_entry2"),
+		Content:    []byte("test_entry2_content"),
+	}
+
+	testTransientOfferRequest := &TransientOfferRequest{
+		Contents: []*ContentEntry{testEntry1, testEntry2},
+	}
+
+	offerRequest := &OfferRequest{
+		Kind:    TransientOfferRequestKind,
+		Request:
testTransientOfferRequest, + } + + contentKeys, err := node1.offer(node3.localNode.Node(), offerRequest) + assert.Equal(t, uint64(2), bitfield.Bitlist(contentKeys).Count()) + assert.NoError(t, err) + + contentElement := <-node3.contentQueue + assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) + assert.Equal(t, testEntry1.ContentKey, contentElement.ContentKeys[0]) + assert.Equal(t, testEntry1.Content, contentElement.Contents[0]) + assert.Equal(t, testEntry2.ContentKey, contentElement.ContentKeys[1]) + assert.Equal(t, testEntry2.Content, contentElement.Contents[1]) + + testGossipContentKeys := [][]byte{[]byte("test_gossip_content_keys"), []byte("test_gossip_content_keys2")} + testGossipContent := [][]byte{[]byte("test_gossip_content"), []byte("test_gossip_content2")} + id := node1.Self().ID() + gossip, err := node1.Gossip(&id, testGossipContentKeys, testGossipContent) + assert.NoError(t, err) + assert.Equal(t, 2, gossip) + + contentElement = <-node2.contentQueue + assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) + assert.Equal(t, testGossipContentKeys[0], contentElement.ContentKeys[0]) + assert.Equal(t, testGossipContent[0], contentElement.Contents[0]) + assert.Equal(t, testGossipContentKeys[1], contentElement.ContentKeys[1]) + assert.Equal(t, testGossipContent[1], contentElement.Contents[1]) + + contentElement = <-node3.contentQueue + assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) + assert.Equal(t, testGossipContentKeys[0], contentElement.ContentKeys[0]) + assert.Equal(t, testGossipContent[0], contentElement.Contents[0]) + assert.Equal(t, testGossipContentKeys[1], contentElement.ContentKeys[1]) + assert.Equal(t, testGossipContent[1], contentElement.Contents[1]) + + node1.Stop() + node2.Stop() + node3.Stop() +} + +func TestCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + go func(ctx context.Context) { + defer func() { + t.Log("goroutine cancel") + }() + + time.Sleep(time.Second * 5) + }(ctx) + + cancel() + t.Log("after main cancel") + + time.Sleep(time.Second * 3) +} + +func TestContentLookup(t *testing.T) { + node1, err := setupLocalPortalNode(":17777", nil) + assert.NoError(t, err) + node1.Log = testlog.Logger(t, log.LvlTrace) + err = node1.Start() + assert.NoError(t, err) + + node2, err := setupLocalPortalNode(":17778", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node2.Log = testlog.Logger(t, log.LvlTrace) + err = node2.Start() + assert.NoError(t, err) + fmt.Println(node2.localNode.Node().String()) + + node3, err := setupLocalPortalNode(":17779", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node3.Log = testlog.Logger(t, log.LvlTrace) + err = node3.Start() + assert.NoError(t, err) + + defer func() { + node1.Stop() + node2.Stop() + node3.Stop() + }() + + contentKey := []byte{0x3, 0x4} + content := []byte{0x1, 0x2} + contentId := node1.toContentId(contentKey) + + err = node3.storage.Put(nil, contentId, content) + assert.NoError(t, err) + + res, _, err := node1.ContentLookup(contentKey, contentId) + assert.NoError(t, err) + assert.Equal(t, res, content) + + nonExist := []byte{0x2, 0x4} + res, _, err = node1.ContentLookup(nonExist, node1.toContentId(nonExist)) + assert.Equal(t, ContentNotFound, err) + assert.Nil(t, res) +} + +func TestTraceContentLookup(t *testing.T) { + node1, err := setupLocalPortalNode(":17787", nil) + assert.NoError(t, err) + node1.Log = testlog.Logger(t, log.LvlTrace) + err = node1.Start() + assert.NoError(t, err) + + node2, err := 
setupLocalPortalNode(":17788", []*enode.Node{node1.localNode.Node()}) + assert.NoError(t, err) + node2.Log = testlog.Logger(t, log.LvlTrace) + err = node2.Start() + assert.NoError(t, err) + + node3, err := setupLocalPortalNode(":17789", []*enode.Node{node2.localNode.Node()}) + assert.NoError(t, err) + node3.Log = testlog.Logger(t, log.LvlTrace) + err = node3.Start() + assert.NoError(t, err) + + defer node1.Stop() + defer node2.Stop() + defer node3.Stop() + + contentKey := []byte{0x3, 0x4} + content := []byte{0x1, 0x2} + contentId := node1.toContentId(contentKey) + + err = node1.storage.Put(nil, contentId, content) + assert.NoError(t, err) + + node1Id := hexutil.Encode(node1.Self().ID().Bytes()) + node2Id := hexutil.Encode(node2.Self().ID().Bytes()) + node3Id := hexutil.Encode(node3.Self().ID().Bytes()) + + res, err := node3.TraceContentLookup(contentKey, contentId) + assert.NoError(t, err) + assert.Equal(t, res.Content, hexutil.Encode(content)) + assert.Equal(t, res.UtpTransfer, false) + assert.Equal(t, res.Trace.Origin, node3Id) + assert.Equal(t, res.Trace.TargetId, hexutil.Encode(contentId)) + assert.Equal(t, res.Trace.ReceivedFrom, node1Id) + + // check nodeMeta + node1Meta := res.Trace.Metadata[node1Id] + assert.Equal(t, node1Meta.Enr, node1.Self().String()) + dis := node1.Distance(node1.Self().ID(), enode.ID(contentId)) + assert.Equal(t, node1Meta.Distance, hexutil.Encode(dis[:])) + + node2Meta := res.Trace.Metadata[node2Id] + assert.Equal(t, node2Meta.Enr, node2.Self().String()) + dis = node2.Distance(node2.Self().ID(), enode.ID(contentId)) + assert.Equal(t, node2Meta.Distance, hexutil.Encode(dis[:])) + + node3Meta := res.Trace.Metadata[node3Id] + assert.Equal(t, node3Meta.Enr, node3.Self().String()) + dis = node3.Distance(node3.Self().ID(), enode.ID(contentId)) + assert.Equal(t, node3Meta.Distance, hexutil.Encode(dis[:])) + + // check response + node3Response := res.Trace.Responses[node3Id] + assert.Equal(t, node3Response.RespondedWith, []string{node2Id}) + + node2Response := res.Trace.Responses[node2Id] + assert.Equal(t, node2Response.RespondedWith, []string{node1Id}) + + node1Response := res.Trace.Responses[node1Id] + assert.Equal(t, node1Response.RespondedWith, ([]string)(nil)) +} + +func newkey() *ecdsa.PrivateKey { + key, err := crypto.GenerateKey() + if err != nil { + panic("couldn't generate key: " + err.Error()) + } + return key +} diff --git a/portalnetwork/portal_utp.go b/portalnetwork/portal_utp.go new file mode 100644 index 000000000000..b1b58a7673ca --- /dev/null +++ b/portalnetwork/portal_utp.go @@ -0,0 +1,139 @@ +package portalnetwork + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/discover/v5wire" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/netutil" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" + "github.com/optimism-java/utp-go" + "github.com/optimism-java/utp-go/libutp" + "go.uber.org/zap" +) + +type PortalUtp struct { + ctx context.Context + log log.Logger + discV5 *discover.UDPv5 + conn discover.UDPConn + ListenAddr string + listener *utp.Listener + utpSm *utp.SocketManager + packetRouter *utp.PacketRouter + lAddr *utp.Addr + + startOnce sync.Once +} + +func NewPortalUtp(ctx context.Context, config *PortalProtocolConfig, discV5 *discover.UDPv5, conn discover.UDPConn) *PortalUtp { + return &PortalUtp{ + ctx: ctx, + log: log.New("protocol", "utp", "local", 
conn.LocalAddr().String()),
+		discV5:     discV5,
+		conn:       conn,
+		ListenAddr: config.ListenAddr,
+	}
+}
+
+func (p *PortalUtp) Start() error {
+	var err error
+	// Run the setup synchronously: returning before startOnce completes would
+	// race on p.listener and always report a nil error to the caller.
+	p.startOnce.Do(func() {
+		var logger *zap.Logger
+		if p.log.Enabled(p.ctx, log.LevelDebug) || p.log.Enabled(p.ctx, log.LevelTrace) {
+			logger, err = zap.NewDevelopmentConfig().Build()
+		} else {
+			logger, err = zap.NewProductionConfig().Build()
+		}
+		if err != nil {
+			return
+		}
+
+		laddr := p.getLocalAddr()
+		p.packetRouter = utp.NewPacketRouter(p.packetRouterFunc)
+		p.utpSm, err = utp.NewSocketManagerWithOptions(
+			"utp",
+			laddr,
+			utp.WithContext(p.ctx),
+			utp.WithLogger(logger.Named(p.ListenAddr)),
+			utp.WithPacketRouter(p.packetRouter),
+			utp.WithMaxPacketSize(1145))
+		if err != nil {
+			return
+		}
+		p.listener, err = utp.ListenUTPOptions("utp", (*utp.Addr)(laddr), utp.WithSocketManager(p.utpSm))
+		if err != nil {
+			return
+		}
+		p.lAddr = p.listener.Addr().(*utp.Addr)
+
+		// register discv5 listener
+		p.discV5.RegisterTalkHandler(string(portalwire.Utp), p.handleUtpTalkRequest)
+	})
+
+	return err
+}
+
+func (p *PortalUtp) Stop() {
+	err := p.listener.Close()
+	if err != nil {
+		p.log.Error("failed to close utp listener", "error", err)
+	}
+	p.discV5.Close()
+}
+
+func (p *PortalUtp) DialWithCid(ctx context.Context, dest *enode.Node, connId uint16) (net.Conn, error) {
+	raddr := &utp.Addr{IP: dest.IP(), Port: dest.UDP()}
+	p.log.Debug("will connect to", "nodeId", dest.ID().String(), "connId", connId)
+	conn, err := utp.DialUTPOptions("utp", p.lAddr, raddr, utp.WithContext(ctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(connId))
+	return conn, err
+}
+
+func (p *PortalUtp) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) {
+	raddr := &utp.Addr{IP: dest.IP(), Port: dest.UDP()}
+	p.log.Info("will connect to", "addr", raddr.String())
+	conn, err := utp.DialUTPOptions("utp", p.lAddr, raddr, utp.WithContext(ctx), utp.WithSocketManager(p.utpSm))
+	return conn, err
+}
+
+func (p *PortalUtp) AcceptWithCid(ctx context.Context, nodeId enode.ID, cid *libutp.ConnId) (*utp.Conn, error) {
+	p.log.Debug("will accept from", "nodeId", nodeId.String(), "sendId", cid.SendId(), "recvId", cid.RecvId())
+	return p.listener.AcceptUTPContext(ctx, nodeId, cid)
+}
+
+func (p *PortalUtp) Accept(ctx context.Context) (*utp.Conn, error) {
+	return p.listener.AcceptUTPContext(ctx, enode.ID{}, nil)
+}
+
+func (p *PortalUtp) getLocalAddr() *net.UDPAddr {
+	laddr := p.conn.LocalAddr().(*net.UDPAddr)
+	p.log.Debug("UDP listener up", "addr", laddr)
+	return laddr
+}
+
+func (p *PortalUtp) packetRouterFunc(buf []byte, id enode.ID, addr *net.UDPAddr) (int, error) {
+	p.log.Info("sending data to target", "nodeId", id.String(), "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf))
+
+	if n, ok := p.discV5.GetCachedNode(addr.String()); ok {
+		//_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf)
+		req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf}
+		p.discV5.SendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req)
+
+		return len(buf), nil
+	} else {
+		p.log.Warn("target node not found", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf))
+		return 0, fmt.Errorf("target node not found")
+	}
+}
+
+func (p *PortalUtp) handleUtpTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte {
+	p.log.Trace("receive utp data", "nodeId", id.String(), "addr", addr, "msg-length", len(msg))
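+	// Hand the raw payload to the uTP packet router, which demultiplexes it to
+	// the matching connection by node id; TALKREQ expects a reply, so an empty
+	// TALKRESP body is returned below.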
p.packetRouter.ReceiveMessage(msg, &utp.NodeInfo{Id: id, Addr: addr})
+	return []byte("")
+}
diff --git a/portalnetwork/portalwire/messages.go b/portalnetwork/portalwire/messages.go
new file mode 100644
index 000000000000..c7629604d570
--- /dev/null
+++ b/portalnetwork/portalwire/messages.go
@@ -0,0 +1,336 @@
+package portalwire
+
+import (
+	ssz "github.com/ferranbt/fastssz"
+)
+
+// note: we hand-edited the generated file because the fastssz output could not pass CI, so the go:generate line below is commented out
+///go:generate sszgen --path messages.go --exclude-objs Content,Enrs,ContentKV
+
+// Message codes for the portal protocol.
+const (
+	PING        byte = 0x00
+	PONG        byte = 0x01
+	FINDNODES   byte = 0x02
+	NODES       byte = 0x03
+	FINDCONTENT byte = 0x04
+	CONTENT     byte = 0x05
+	OFFER       byte = 0x06
+	ACCEPT      byte = 0x07
+)
+
+// Content selectors for the portal protocol.
+const (
+	ContentConnIdSelector byte = 0x00
+	ContentRawSelector    byte = 0x01
+	ContentEnrsSelector   byte = 0x02
+)
+
+const (
+	ContentKeysLimit = 64
+	// OfferMessageOverhead is the serialization overhead of an offer message:
+	// 1 byte for the kind enum plus 4 bytes for the offset in the ssz serialization
+	OfferMessageOverhead = 5
+
+	// PerContentKeyOverhead is the per-key overhead: each key in ContentKeysList
+	// carries a uint32 offset, which adds 4 bytes per key when serialized
+	PerContentKeyOverhead = 4
+)
+
+// Protocol IDs for the portal protocol.
+// var (
+// 	StateNetwork             = []byte{0x50, 0x0a}
+// 	HistoryNetwork           = []byte{0x50, 0x0b}
+// 	TxGossipNetwork          = []byte{0x50, 0x0c}
+// 	HeaderGossipNetwork      = []byte{0x50, 0x0d}
+// 	CanonicalIndicesNetwork  = []byte{0x50, 0x0e}
+// 	BeaconLightClientNetwork = []byte{0x50, 0x1a}
+// 	UTPNetwork               = []byte{0x75, 0x74, 0x70}
+// 	Rendezvous               = []byte{0x72, 0x65, 0x6e}
+// )
+
+type ProtocolId []byte
+
+var (
+	State             ProtocolId = []byte{0x50, 0x0A}
+	History           ProtocolId = []byte{0x50, 0x0B}
+	Beacon            ProtocolId = []byte{0x50, 0x0C}
+	CanonicalIndices  ProtocolId = []byte{0x50, 0x0D}
+	VerkleState       ProtocolId = []byte{0x50, 0x0E}
+	TransactionGossip ProtocolId = []byte{0x50, 0x0F}
+	Utp               ProtocolId = []byte{0x75, 0x74, 0x70}
+)
+
+var protocolName = map[string]string{
+	string(State):             "state",
+	string(History):           "history",
+	string(Beacon):            "beacon",
+	string(CanonicalIndices):  "canonical indices",
+	string(VerkleState):       "verkle state",
+	string(TransactionGossip): "transaction gossip",
+}
+
+func (p ProtocolId) Name() string {
+	return protocolName[string(p)]
+}
+
+// const (
+// 	HistoryNetworkName = "history"
+// 	BeaconNetworkName  = "beacon"
+// 	StateNetworkName   = "state"
+// )
+
+// var NetworkNameMap = map[string]string{
+// 	string(StateNetwork):             StateNetworkName,
+// 	string(HistoryNetwork):           HistoryNetworkName,
+// 	string(BeaconLightClientNetwork): BeaconNetworkName,
+// }
+
+type ContentKV struct {
+	ContentKey []byte
+	Content    []byte
+}
+
+// Request messages for the portal protocol.
+type (
+	PingPongCustomData struct {
+		Radius []byte `ssz-size:"32"`
+	}
+
+	Ping struct {
+		EnrSeq        uint64
+		CustomPayload []byte `ssz-max:"2048"`
+	}
+
+	FindNodes struct {
+		Distances [][2]byte `ssz-max:"256,2" ssz-size:"?,2"`
+	}
+
+	FindContent struct {
+		ContentKey []byte `ssz-max:"2048"`
+	}
+
+	Offer struct {
+		ContentKeys [][]byte `ssz-max:"64,2048"`
+	}
+)
+
+// Response messages for the portal protocol.
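+// Each request has a fixed counterpart: Ping is answered by Pong, FindNodes by
+// Nodes and Offer by Accept, while FindContent is answered by one of
+// ConnectionId, Content or Enrs, distinguished by the content selectors above.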
+type ( + Pong struct { + EnrSeq uint64 + CustomPayload []byte `ssz-max:"2048"` + } + + Nodes struct { + Total uint8 + Enrs [][]byte `ssz-max:"32,2048"` + } + + ConnectionId struct { + Id []byte `ssz-size:"2"` + } + + Content struct { + Content []byte `ssz-max:"2048"` + } + + Enrs struct { + Enrs [][]byte `ssz-max:"32,2048"` + } + + Accept struct { + ConnectionId []byte `ssz-size:"2"` + ContentKeys []byte `ssz:"bitlist" ssz-max:"64"` + } +) + +// MarshalSSZ ssz marshals the Content object +func (c *Content) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(c) +} + +// MarshalSSZTo ssz marshals the Content object to a target array +func (c *Content) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Content' + if size := len(c.Content); size > 2048 { + err = ssz.ErrBytesLengthFn("Content.Content", size, 2048) + return + } + dst = append(dst, c.Content...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Content object +func (c *Content) UnmarshalSSZ(buf []byte) error { + var err error + tail := buf + + // Field (0) 'Content' + { + buf = tail[:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(c.Content) == 0 { + c.Content = make([]byte, 0, len(buf)) + } + c.Content = append(c.Content, buf...) + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Content object +func (c *Content) SizeSSZ() (size int) { + // Field (0) 'Content' + return len(c.Content) +} + +// HashTreeRoot ssz hashes the Content object +func (c *Content) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(c) +} + +// HashTreeRootWith ssz hashes the Content object with a hasher +func (c *Content) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Content' + { + elemIndx := hh.Index() + byteLen := uint64(len(c.Content)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(c.Content) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Content object +func (c *Content) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(c) +} + +// MarshalSSZ ssz marshals the Enrs object +func (e *Enrs) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(e) +} + +// MarshalSSZTo ssz marshals the Enrs object to a target array +func (e *Enrs) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(0) + + // Field (0) 'Enrs' + if size := len(e.Enrs); size > 32 { + err = ssz.ErrListTooBigFn("Enrs.Enrs", size, 32) + return + } + { + offset = 4 * len(e.Enrs) + for ii := 0; ii < len(e.Enrs); ii++ { + dst = ssz.WriteOffset(dst, offset) + offset += len(e.Enrs[ii]) + } + } + for ii := 0; ii < len(e.Enrs); ii++ { + if size := len(e.Enrs[ii]); size > 2048 { + err = ssz.ErrBytesLengthFn("Enrs.Enrs[ii]", size, 2048) + return + } + dst = append(dst, e.Enrs[ii]...) + } + + return +} + +// UnmarshalSSZ ssz unmarshals the Enrs object +func (e *Enrs) UnmarshalSSZ(buf []byte) error { + var err error + tail := buf + // Field (0) 'Enrs' + { + buf = tail[:] + num, err := ssz.DecodeDynamicLength(buf, 32) + if err != nil { + return err + } + e.Enrs = make([][]byte, num) + err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(e.Enrs[indx]) == 0 { + e.Enrs[indx] = make([]byte, 0, len(buf)) + } + e.Enrs[indx] = append(e.Enrs[indx], buf...) 
+ return nil + }) + if err != nil { + return err + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Enrs object +func (e *Enrs) SizeSSZ() (size int) { + size = 0 + + // Field (0) 'Enrs' + for ii := 0; ii < len(e.Enrs); ii++ { + size += 4 + size += len(e.Enrs[ii]) + } + + return +} + +// HashTreeRoot ssz hashes the Enrs object +func (e *Enrs) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(e) +} + +// HashTreeRootWith ssz hashes the Enrs object with a hasher +func (e *Enrs) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Enrs' + { + subIndx := hh.Index() + num := uint64(len(e.Enrs)) + if num > 32 { + err = ssz.ErrIncorrectListSize + return + } + for _, elem := range e.Enrs { + { + elemIndx := hh.Index() + byteLen := uint64(len(elem)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.AppendBytes32(elem) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + } + hh.MerkleizeWithMixin(subIndx, num, 32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Enrs object +func (e *Enrs) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(e) +} diff --git a/portalnetwork/portalwire/messages_encoding.go b/portalnetwork/portalwire/messages_encoding.go new file mode 100644 index 000000000000..601150baff1a --- /dev/null +++ b/portalnetwork/portalwire/messages_encoding.go @@ -0,0 +1,957 @@ +// Code generated by fastssz. DO NOT EDIT. +// Hash: 26a61b12807ff78c64a029acdd5bcb580dfe35b7bfbf8bf04ceebae1a3d5cac1 +// Version: 0.1.3 +package portalwire + +import ( + ssz "github.com/ferranbt/fastssz" +) + +// MarshalSSZ ssz marshals the PingPongCustomData object +func (p *PingPongCustomData) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(p) +} + +// MarshalSSZTo ssz marshals the PingPongCustomData object to a target array +func (p *PingPongCustomData) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Radius' + if size := len(p.Radius); size != 32 { + err = ssz.ErrBytesLengthFn("PingPongCustomData.Radius", size, 32) + return + } + dst = append(dst, p.Radius...) + + return +} + +// UnmarshalSSZ ssz unmarshals the PingPongCustomData object +func (p *PingPongCustomData) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 32 { + return ssz.ErrSize + } + + // Field (0) 'Radius' + if cap(p.Radius) == 0 { + p.Radius = make([]byte, 0, len(buf[0:32])) + } + p.Radius = append(p.Radius, buf[0:32]...) 
+ + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the PingPongCustomData object +func (p *PingPongCustomData) SizeSSZ() (size int) { + size = 32 + return +} + +// HashTreeRoot ssz hashes the PingPongCustomData object +func (p *PingPongCustomData) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(p) +} + +// HashTreeRootWith ssz hashes the PingPongCustomData object with a hasher +func (p *PingPongCustomData) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Radius' + if size := len(p.Radius); size != 32 { + err = ssz.ErrBytesLengthFn("PingPongCustomData.Radius", size, 32) + return + } + hh.PutBytes(p.Radius) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the PingPongCustomData object +func (p *PingPongCustomData) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(p) +} + +// MarshalSSZ ssz marshals the Ping object +func (p *Ping) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(p) +} + +// MarshalSSZTo ssz marshals the Ping object to a target array +func (p *Ping) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(12) + + // Field (0) 'EnrSeq' + dst = ssz.MarshalUint64(dst, p.EnrSeq) + + // Offset (1) 'CustomPayload' + dst = ssz.WriteOffset(dst, offset) + offset += len(p.CustomPayload) + + // Field (1) 'CustomPayload' + if size := len(p.CustomPayload); size > 2048 { + err = ssz.ErrBytesLengthFn("Ping.CustomPayload", size, 2048) + return + } + dst = append(dst, p.CustomPayload...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Ping object +func (p *Ping) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 12 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'EnrSeq' + p.EnrSeq = ssz.UnmarshallUint64(buf[0:8]) + + // Offset (1) 'CustomPayload' + if o1 = ssz.ReadOffset(buf[8:12]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 12 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'CustomPayload' + { + buf = tail[o1:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(p.CustomPayload) == 0 { + p.CustomPayload = make([]byte, 0, len(buf)) + } + p.CustomPayload = append(p.CustomPayload, buf...) 
+ } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Ping object +func (p *Ping) SizeSSZ() (size int) { + size = 12 + + // Field (1) 'CustomPayload' + size += len(p.CustomPayload) + + return +} + +// HashTreeRoot ssz hashes the Ping object +func (p *Ping) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(p) +} + +// HashTreeRootWith ssz hashes the Ping object with a hasher +func (p *Ping) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'EnrSeq' + hh.PutUint64(p.EnrSeq) + + // Field (1) 'CustomPayload' + { + elemIndx := hh.Index() + byteLen := uint64(len(p.CustomPayload)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(p.CustomPayload) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Ping object +func (p *Ping) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(p) +} + +// MarshalSSZ ssz marshals the FindNodes object +func (f *FindNodes) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(f) +} + +// MarshalSSZTo ssz marshals the FindNodes object to a target array +func (f *FindNodes) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(4) + + // Offset (0) 'Distances' + dst = ssz.WriteOffset(dst, offset) + offset += len(f.Distances) * 2 + + // Field (0) 'Distances' + if size := len(f.Distances); size > 256 { + err = ssz.ErrListTooBigFn("FindNodes.Distances", size, 256) + return + } + for ii := 0; ii < len(f.Distances); ii++ { + dst = append(dst, f.Distances[ii][:]...) + } + + return +} + +// UnmarshalSSZ ssz unmarshals the FindNodes object +func (f *FindNodes) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 4 { + return ssz.ErrSize + } + + tail := buf + var o0 uint64 + + // Offset (0) 'Distances' + if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { + return ssz.ErrOffset + } + + if o0 < 4 { + return ssz.ErrInvalidVariableOffset + } + + // Field (0) 'Distances' + { + buf = tail[o0:] + num, err := ssz.DivideInt2(len(buf), 2, 256) + if err != nil { + return err + } + f.Distances = make([][2]byte, num) + for ii := 0; ii < num; ii++ { + copy(f.Distances[ii][:], buf[ii*2:(ii+1)*2]) + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the FindNodes object +func (f *FindNodes) SizeSSZ() (size int) { + size = 4 + + // Field (0) 'Distances' + size += len(f.Distances) * 2 + + return +} + +// HashTreeRoot ssz hashes the FindNodes object +func (f *FindNodes) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(f) +} + +// HashTreeRootWith ssz hashes the FindNodes object with a hasher +func (f *FindNodes) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Distances' + { + if size := len(f.Distances); size > 256 { + err = ssz.ErrListTooBigFn("FindNodes.Distances", size, 256) + return + } + subIndx := hh.Index() + for _, i := range f.Distances { + hh.PutBytes(i[:]) + } + numItems := uint64(len(f.Distances)) + hh.MerkleizeWithMixin(subIndx, numItems, 256) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the FindNodes object +func (f *FindNodes) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(f) +} + +// MarshalSSZ ssz marshals the FindContent object +func (f *FindContent) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(f) +} + +// MarshalSSZTo ssz marshals the FindContent object to a target array +func (f *FindContent) MarshalSSZTo(buf 
[]byte) (dst []byte, err error) { + dst = buf + offset := int(4) + + // Offset (0) 'ContentKey' + dst = ssz.WriteOffset(dst, offset) + offset += len(f.ContentKey) + + // Field (0) 'ContentKey' + if size := len(f.ContentKey); size > 2048 { + err = ssz.ErrBytesLengthFn("FindContent.ContentKey", size, 2048) + return + } + dst = append(dst, f.ContentKey...) + + return +} + +// UnmarshalSSZ ssz unmarshals the FindContent object +func (f *FindContent) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 4 { + return ssz.ErrSize + } + + tail := buf + var o0 uint64 + + // Offset (0) 'ContentKey' + if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { + return ssz.ErrOffset + } + + if o0 < 4 { + return ssz.ErrInvalidVariableOffset + } + + // Field (0) 'ContentKey' + { + buf = tail[o0:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(f.ContentKey) == 0 { + f.ContentKey = make([]byte, 0, len(buf)) + } + f.ContentKey = append(f.ContentKey, buf...) + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the FindContent object +func (f *FindContent) SizeSSZ() (size int) { + size = 4 + + // Field (0) 'ContentKey' + size += len(f.ContentKey) + + return +} + +// HashTreeRoot ssz hashes the FindContent object +func (f *FindContent) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(f) +} + +// HashTreeRootWith ssz hashes the FindContent object with a hasher +func (f *FindContent) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'ContentKey' + { + elemIndx := hh.Index() + byteLen := uint64(len(f.ContentKey)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(f.ContentKey) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the FindContent object +func (f *FindContent) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(f) +} + +// MarshalSSZ ssz marshals the Offer object +func (o *Offer) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(o) +} + +// MarshalSSZTo ssz marshals the Offer object to a target array +func (o *Offer) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(4) + + // Offset (0) 'ContentKeys' + dst = ssz.WriteOffset(dst, offset) + for ii := 0; ii < len(o.ContentKeys); ii++ { + offset += 4 + offset += len(o.ContentKeys[ii]) + } + + // Field (0) 'ContentKeys' + if size := len(o.ContentKeys); size > 64 { + err = ssz.ErrListTooBigFn("Offer.ContentKeys", size, 64) + return + } + { + offset = 4 * len(o.ContentKeys) + for ii := 0; ii < len(o.ContentKeys); ii++ { + dst = ssz.WriteOffset(dst, offset) + offset += len(o.ContentKeys[ii]) + } + } + for ii := 0; ii < len(o.ContentKeys); ii++ { + if size := len(o.ContentKeys[ii]); size > 2048 { + err = ssz.ErrBytesLengthFn("Offer.ContentKeys[ii]", size, 2048) + return + } + dst = append(dst, o.ContentKeys[ii]...) 
+ } + + return +} + +// UnmarshalSSZ ssz unmarshals the Offer object +func (o *Offer) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 4 { + return ssz.ErrSize + } + + tail := buf + var o0 uint64 + + // Offset (0) 'ContentKeys' + if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { + return ssz.ErrOffset + } + + if o0 < 4 { + return ssz.ErrInvalidVariableOffset + } + + // Field (0) 'ContentKeys' + { + buf = tail[o0:] + num, err := ssz.DecodeDynamicLength(buf, 64) + if err != nil { + return err + } + o.ContentKeys = make([][]byte, num) + err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(o.ContentKeys[indx]) == 0 { + o.ContentKeys[indx] = make([]byte, 0, len(buf)) + } + o.ContentKeys[indx] = append(o.ContentKeys[indx], buf...) + return nil + }) + if err != nil { + return err + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Offer object +func (o *Offer) SizeSSZ() (size int) { + size = 4 + + // Field (0) 'ContentKeys' + for ii := 0; ii < len(o.ContentKeys); ii++ { + size += 4 + size += len(o.ContentKeys[ii]) + } + + return +} + +// HashTreeRoot ssz hashes the Offer object +func (o *Offer) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(o) +} + +// HashTreeRootWith ssz hashes the Offer object with a hasher +func (o *Offer) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'ContentKeys' + { + subIndx := hh.Index() + num := uint64(len(o.ContentKeys)) + if num > 64 { + err = ssz.ErrIncorrectListSize + return + } + for _, elem := range o.ContentKeys { + { + elemIndx := hh.Index() + byteLen := uint64(len(elem)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.AppendBytes32(elem) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + } + hh.MerkleizeWithMixin(subIndx, num, 64) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Offer object +func (o *Offer) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(o) +} + +// MarshalSSZ ssz marshals the Pong object +func (p *Pong) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(p) +} + +// MarshalSSZTo ssz marshals the Pong object to a target array +func (p *Pong) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(12) + + // Field (0) 'EnrSeq' + dst = ssz.MarshalUint64(dst, p.EnrSeq) + + // Offset (1) 'CustomPayload' + dst = ssz.WriteOffset(dst, offset) + offset += len(p.CustomPayload) + + // Field (1) 'CustomPayload' + if size := len(p.CustomPayload); size > 2048 { + err = ssz.ErrBytesLengthFn("Pong.CustomPayload", size, 2048) + return + } + dst = append(dst, p.CustomPayload...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Pong object +func (p *Pong) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 12 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'EnrSeq' + p.EnrSeq = ssz.UnmarshallUint64(buf[0:8]) + + // Offset (1) 'CustomPayload' + if o1 = ssz.ReadOffset(buf[8:12]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 12 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'CustomPayload' + { + buf = tail[o1:] + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(p.CustomPayload) == 0 { + p.CustomPayload = make([]byte, 0, len(buf)) + } + p.CustomPayload = append(p.CustomPayload, buf...) 
+ } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Pong object +func (p *Pong) SizeSSZ() (size int) { + size = 12 + + // Field (1) 'CustomPayload' + size += len(p.CustomPayload) + + return +} + +// HashTreeRoot ssz hashes the Pong object +func (p *Pong) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(p) +} + +// HashTreeRootWith ssz hashes the Pong object with a hasher +func (p *Pong) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'EnrSeq' + hh.PutUint64(p.EnrSeq) + + // Field (1) 'CustomPayload' + { + elemIndx := hh.Index() + byteLen := uint64(len(p.CustomPayload)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.Append(p.CustomPayload) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Pong object +func (p *Pong) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(p) +} + +// MarshalSSZ ssz marshals the Nodes object +func (n *Nodes) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(n) +} + +// MarshalSSZTo ssz marshals the Nodes object to a target array +func (n *Nodes) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(5) + + // Field (0) 'Total' + dst = ssz.MarshalUint8(dst, n.Total) + + // Offset (1) 'Enrs' + dst = ssz.WriteOffset(dst, offset) + for ii := 0; ii < len(n.Enrs); ii++ { + offset += 4 + offset += len(n.Enrs[ii]) + } + + // Field (1) 'Enrs' + if size := len(n.Enrs); size > 32 { + err = ssz.ErrListTooBigFn("Nodes.Enrs", size, 32) + return + } + { + offset = 4 * len(n.Enrs) + for ii := 0; ii < len(n.Enrs); ii++ { + dst = ssz.WriteOffset(dst, offset) + offset += len(n.Enrs[ii]) + } + } + for ii := 0; ii < len(n.Enrs); ii++ { + if size := len(n.Enrs[ii]); size > 2048 { + err = ssz.ErrBytesLengthFn("Nodes.Enrs[ii]", size, 2048) + return + } + dst = append(dst, n.Enrs[ii]...) + } + + return +} + +// UnmarshalSSZ ssz unmarshals the Nodes object +func (n *Nodes) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 5 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'Total' + n.Total = ssz.UnmarshallUint8(buf[0:1]) + + // Offset (1) 'Enrs' + if o1 = ssz.ReadOffset(buf[1:5]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 5 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'Enrs' + { + buf = tail[o1:] + num, err := ssz.DecodeDynamicLength(buf, 32) + if err != nil { + return err + } + n.Enrs = make([][]byte, num) + err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { + if len(buf) > 2048 { + return ssz.ErrBytesLength + } + if cap(n.Enrs[indx]) == 0 { + n.Enrs[indx] = make([]byte, 0, len(buf)) + } + n.Enrs[indx] = append(n.Enrs[indx], buf...) 
+ return nil + }) + if err != nil { + return err + } + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Nodes object +func (n *Nodes) SizeSSZ() (size int) { + size = 5 + + // Field (1) 'Enrs' + for ii := 0; ii < len(n.Enrs); ii++ { + size += 4 + size += len(n.Enrs[ii]) + } + + return +} + +// HashTreeRoot ssz hashes the Nodes object +func (n *Nodes) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(n) +} + +// HashTreeRootWith ssz hashes the Nodes object with a hasher +func (n *Nodes) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Total' + hh.PutUint8(n.Total) + + // Field (1) 'Enrs' + { + subIndx := hh.Index() + num := uint64(len(n.Enrs)) + if num > 32 { + err = ssz.ErrIncorrectListSize + return + } + for _, elem := range n.Enrs { + { + elemIndx := hh.Index() + byteLen := uint64(len(elem)) + if byteLen > 2048 { + err = ssz.ErrIncorrectListSize + return + } + hh.AppendBytes32(elem) + hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) + } + } + hh.MerkleizeWithMixin(subIndx, num, 32) + } + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Nodes object +func (n *Nodes) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(n) +} + +// MarshalSSZ ssz marshals the ConnectionId object +func (c *ConnectionId) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(c) +} + +// MarshalSSZTo ssz marshals the ConnectionId object to a target array +func (c *ConnectionId) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + + // Field (0) 'Id' + if size := len(c.Id); size != 2 { + err = ssz.ErrBytesLengthFn("ConnectionId.Id", size, 2) + return + } + dst = append(dst, c.Id...) + + return +} + +// UnmarshalSSZ ssz unmarshals the ConnectionId object +func (c *ConnectionId) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size != 2 { + return ssz.ErrSize + } + + // Field (0) 'Id' + if cap(c.Id) == 0 { + c.Id = make([]byte, 0, len(buf[0:2])) + } + c.Id = append(c.Id, buf[0:2]...) + + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the ConnectionId object +func (c *ConnectionId) SizeSSZ() (size int) { + size = 2 + return +} + +// HashTreeRoot ssz hashes the ConnectionId object +func (c *ConnectionId) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(c) +} + +// HashTreeRootWith ssz hashes the ConnectionId object with a hasher +func (c *ConnectionId) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'Id' + if size := len(c.Id); size != 2 { + err = ssz.ErrBytesLengthFn("ConnectionId.Id", size, 2) + return + } + hh.PutBytes(c.Id) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the ConnectionId object +func (c *ConnectionId) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(c) +} + +// MarshalSSZ ssz marshals the Accept object +func (a *Accept) MarshalSSZ() ([]byte, error) { + return ssz.MarshalSSZ(a) +} + +// MarshalSSZTo ssz marshals the Accept object to a target array +func (a *Accept) MarshalSSZTo(buf []byte) (dst []byte, err error) { + dst = buf + offset := int(6) + + // Field (0) 'ConnectionId' + if size := len(a.ConnectionId); size != 2 { + err = ssz.ErrBytesLengthFn("Accept.ConnectionId", size, 2) + return + } + dst = append(dst, a.ConnectionId...) 
+ + // Offset (1) 'ContentKeys' + dst = ssz.WriteOffset(dst, offset) + offset += len(a.ContentKeys) + + // Field (1) 'ContentKeys' + if size := len(a.ContentKeys); size > 64 { + err = ssz.ErrBytesLengthFn("Accept.ContentKeys", size, 64) + return + } + dst = append(dst, a.ContentKeys...) + + return +} + +// UnmarshalSSZ ssz unmarshals the Accept object +func (a *Accept) UnmarshalSSZ(buf []byte) error { + var err error + size := uint64(len(buf)) + if size < 6 { + return ssz.ErrSize + } + + tail := buf + var o1 uint64 + + // Field (0) 'ConnectionId' + if cap(a.ConnectionId) == 0 { + a.ConnectionId = make([]byte, 0, len(buf[0:2])) + } + a.ConnectionId = append(a.ConnectionId, buf[0:2]...) + + // Offset (1) 'ContentKeys' + if o1 = ssz.ReadOffset(buf[2:6]); o1 > size { + return ssz.ErrOffset + } + + if o1 < 6 { + return ssz.ErrInvalidVariableOffset + } + + // Field (1) 'ContentKeys' + { + buf = tail[o1:] + if err = ssz.ValidateBitlist(buf, 64); err != nil { + return err + } + if cap(a.ContentKeys) == 0 { + a.ContentKeys = make([]byte, 0, len(buf)) + } + a.ContentKeys = append(a.ContentKeys, buf...) + } + return err +} + +// SizeSSZ returns the ssz encoded size in bytes for the Accept object +func (a *Accept) SizeSSZ() (size int) { + size = 6 + + // Field (1) 'ContentKeys' + size += len(a.ContentKeys) + + return +} + +// HashTreeRoot ssz hashes the Accept object +func (a *Accept) HashTreeRoot() ([32]byte, error) { + return ssz.HashWithDefaultHasher(a) +} + +// HashTreeRootWith ssz hashes the Accept object with a hasher +func (a *Accept) HashTreeRootWith(hh ssz.HashWalker) (err error) { + indx := hh.Index() + + // Field (0) 'ConnectionId' + if size := len(a.ConnectionId); size != 2 { + err = ssz.ErrBytesLengthFn("Accept.ConnectionId", size, 2) + return + } + hh.PutBytes(a.ConnectionId) + + // Field (1) 'ContentKeys' + if len(a.ContentKeys) == 0 { + err = ssz.ErrEmptyBitlist + return + } + hh.PutBitlist(a.ContentKeys, 64) + + hh.Merkleize(indx) + return +} + +// GetTree ssz hashes the Accept object +func (a *Accept) GetTree() (*ssz.Node, error) { + return ssz.ProofTree(a) +} diff --git a/portalnetwork/portalwire/messages_test.go b/portalnetwork/portalwire/messages_test.go new file mode 100644 index 000000000000..9e266cf41789 --- /dev/null +++ b/portalnetwork/portalwire/messages_test.go @@ -0,0 +1,212 @@ +package portalwire + +import ( + "fmt" + "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/rlp" + ssz "github.com/ferranbt/fastssz" + "github.com/holiman/uint256" + "github.com/prysmaticlabs/go-bitfield" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var maxUint256 = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") + +// https://github.com/ethereum/portal-network-specs/blob/master/portal-wire-test-vectors.md +// we remove the message type here +func TestPingMessage(t *testing.T) { + dataRadius := maxUint256.Sub(maxUint256, uint256.NewInt(1)) + reverseBytes, err := dataRadius.MarshalSSZ() + require.NoError(t, err) + customData := &PingPongCustomData{ + Radius: reverseBytes, + } + dataBytes, err := customData.MarshalSSZ() + assert.NoError(t, err) + ping := &Ping{ + EnrSeq: 1, + CustomPayload: dataBytes, + } + + expected := "0x01000000000000000c000000feffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + + data, err := ping.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} 
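+
+// TestPingMessageDecode is an illustrative sketch (the name is hypothetical,
+// not from the upstream test vectors): it decodes the vector checked in
+// TestPingMessage above and verifies the fields round-trip through the
+// generated UnmarshalSSZ code.
+func TestPingMessageDecode(t *testing.T) {
+	encoded := hexutil.MustDecode("0x01000000000000000c000000feffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
+
+	ping := &Ping{}
+	require.NoError(t, ping.UnmarshalSSZ(encoded))
+	assert.Equal(t, uint64(1), ping.EnrSeq)
+
+	custom := &PingPongCustomData{}
+	require.NoError(t, custom.UnmarshalSSZ(ping.CustomPayload))
+	// The radius is a little-endian uint256, so 2^256-2 starts with byte 0xfe.
+	assert.Equal(t, 32, len(custom.Radius))
+	assert.Equal(t, byte(0xfe), custom.Radius[0])
+}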
+ +func TestPongMessage(t *testing.T) { + dataRadius := maxUint256.Div(maxUint256, uint256.NewInt(2)) + reverseBytes, err := dataRadius.MarshalSSZ() + require.NoError(t, err) + customData := &PingPongCustomData{ + Radius: reverseBytes, + } + + dataBytes, err := customData.MarshalSSZ() + assert.NoError(t, err) + pong := &Pong{ + EnrSeq: 1, + CustomPayload: dataBytes, + } + + expected := "0x01000000000000000c000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" + + data, err := pong.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} + +func TestFindNodesMessage(t *testing.T) { + distances := []uint16{256, 255} + + distancesBytes := make([][2]byte, len(distances)) + for i, distance := range distances { + copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), distance)) + } + + findNode := &FindNodes{ + Distances: distancesBytes, + } + + data, err := findNode.MarshalSSZ() + expected := "0x040000000001ff00" + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} + +func TestNodes(t *testing.T) { + enrs := []string{ + "enr:-HW4QBzimRxkmT18hMKaAL3IcZF1UcfTMPyi3Q1pxwZZbcZVRI8DC5infUAB_UauARLOJtYTxaagKoGmIjzQxO2qUygBgmlkgnY0iXNlY3AyNTZrMaEDymNMrg1JrLQB2KTGtv6MVbcNEVv0AHacwUAPMljNMTg", + "enr:-HW4QNfxw543Ypf4HXKXdYxkyzfcxcO-6p9X986WldfVpnVTQX1xlTnWrktEWUbeTZnmgOuAY_KUhbVV1Ft98WoYUBMBgmlkgnY0iXNlY3AyNTZrMaEDDiy3QkHAxPyOgWbxp5oF1bDdlYE6dLCUUp8xfVw50jU", + } + + enrsBytes := make([][]byte, 0) + for _, enr := range enrs { + n, err := enode.Parse(enode.ValidSchemes, enr) + assert.NoError(t, err) + + enrBytes, err := rlp.EncodeToBytes(n.Record()) + assert.NoError(t, err) + enrsBytes = append(enrsBytes, enrBytes) + } + + testCases := []struct { + name string + input [][]byte + expected string + }{ + { + name: "empty nodes", + input: make([][]byte, 0), + expected: "0x0105000000", + }, + { + name: "two nodes", + input: enrsBytes, + expected: "0x0105000000080000007f000000f875b8401ce2991c64993d7c84c29a00bdc871917551c7d330fca2dd0d69c706596dc655448f030b98a77d4001fd46ae0112ce26d613c5a6a02a81a6223cd0c4edaa53280182696482763489736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138f875b840d7f1c39e376297f81d7297758c64cb37dcc5c3beea9f57f7ce9695d7d5a67553417d719539d6ae4b445946de4d99e680eb8063f29485b555d45b7df16a1850130182696482763489736563703235366b31a1030e2cb74241c0c4fc8e8166f1a79a05d5b0dd95813a74b094529f317d5c39d235", + }, + } + + for _, test := range testCases { + t.Run(test.name, func(t *testing.T) { + nodes := &Nodes{ + Total: 1, + Enrs: test.input, + } + + data, err := nodes.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, test.expected, fmt.Sprintf("0x%x", data)) + }) + } +} + +func TestContent(t *testing.T) { + contentKey := "0x706f7274616c" + + content := &FindContent{ + ContentKey: hexutil.MustDecode(contentKey), + } + expected := "0x04000000706f7274616c" + data, err := content.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + expected = "0x7468652063616b652069732061206c6965" + + contentRes := &Content{ + Content: hexutil.MustDecode("0x7468652063616b652069732061206c6965"), + } + + data, err = contentRes.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + expectData := &Content{} + err = expectData.UnmarshalSSZ(data) + assert.NoError(t, err) + assert.Equal(t, contentRes.Content, expectData.Content) + + enrs := []string{ + 
"enr:-HW4QBzimRxkmT18hMKaAL3IcZF1UcfTMPyi3Q1pxwZZbcZVRI8DC5infUAB_UauARLOJtYTxaagKoGmIjzQxO2qUygBgmlkgnY0iXNlY3AyNTZrMaEDymNMrg1JrLQB2KTGtv6MVbcNEVv0AHacwUAPMljNMTg", + "enr:-HW4QNfxw543Ypf4HXKXdYxkyzfcxcO-6p9X986WldfVpnVTQX1xlTnWrktEWUbeTZnmgOuAY_KUhbVV1Ft98WoYUBMBgmlkgnY0iXNlY3AyNTZrMaEDDiy3QkHAxPyOgWbxp5oF1bDdlYE6dLCUUp8xfVw50jU", + } + + enrsBytes := make([][]byte, 0) + for _, enr := range enrs { + n, err := enode.Parse(enode.ValidSchemes, enr) + assert.NoError(t, err) + + enrBytes, err := rlp.EncodeToBytes(n.Record()) + assert.NoError(t, err) + enrsBytes = append(enrsBytes, enrBytes) + } + + enrsRes := &Enrs{ + Enrs: enrsBytes, + } + + expected = "0x080000007f000000f875b8401ce2991c64993d7c84c29a00bdc871917551c7d330fca2dd0d69c706596dc655448f030b98a77d4001fd46ae0112ce26d613c5a6a02a81a6223cd0c4edaa53280182696482763489736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138f875b840d7f1c39e376297f81d7297758c64cb37dcc5c3beea9f57f7ce9695d7d5a67553417d719539d6ae4b445946de4d99e680eb8063f29485b555d45b7df16a1850130182696482763489736563703235366b31a1030e2cb74241c0c4fc8e8166f1a79a05d5b0dd95813a74b094529f317d5c39d235" + + data, err = enrsRes.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + expectEnrs := &Enrs{} + err = expectEnrs.UnmarshalSSZ(data) + assert.NoError(t, err) + assert.Equal(t, expectEnrs.Enrs, enrsRes.Enrs) +} + +func TestOfferAndAcceptMessage(t *testing.T) { + contentKey := "0x010203" + contentBytes := hexutil.MustDecode(contentKey) + contentKeys := [][]byte{contentBytes} + offer := &Offer{ + ContentKeys: contentKeys, + } + + expected := "0x0400000004000000010203" + + data, err := offer.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) + + contentKeyBitlist := bitfield.NewBitlist(8) + contentKeyBitlist.SetBitAt(0, true) + accept := &Accept{ + ConnectionId: []byte{0x01, 0x02}, + ContentKeys: contentKeyBitlist, + } + + expected = "0x0102060000000101" + + data, err = accept.MarshalSSZ() + assert.NoError(t, err) + assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) +} From cb91f650561b6e6d28f63f1f806317a116b843f5 Mon Sep 17 00:00:00 2001 From: fearlseefe <505380967@qq.com> Date: Wed, 20 Nov 2024 16:39:57 +0800 Subject: [PATCH 03/13] feat: use ethdb to store data part1: impl simple get and put method --- portalnetwork/storage/content_storage.go | 3 + portalnetwork/storage/ethpepple/storage.go | 79 +++++++++++++++++++ .../storage/ethpepple/storage_test.go | 74 +++++++++++++++++ 3 files changed, 156 insertions(+) create mode 100644 portalnetwork/storage/ethpepple/storage.go create mode 100644 portalnetwork/storage/ethpepple/storage_test.go diff --git a/portalnetwork/storage/content_storage.go b/portalnetwork/storage/content_storage.go index 3a01df93f6ed..e726d612e54e 100644 --- a/portalnetwork/storage/content_storage.go +++ b/portalnetwork/storage/content_storage.go @@ -12,6 +12,9 @@ var MaxDistance = uint256.MustFromHex("0xfffffffffffffffffffffffffffffffffffffff type ContentType byte +var RadisuKey = []byte("radius") +var SizeKey = []byte("size") + type ContentKey struct { selector ContentType data []byte diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go new file mode 100644 index 000000000000..b074de1e4c2c --- /dev/null +++ b/portalnetwork/storage/ethpepple/storage.go @@ -0,0 +1,79 @@ +package ethpepple + +import ( + "sync/atomic" + + "github.com/ethereum/go-ethereum/ethdb" + 
"github.com/ethereum/go-ethereum/ethdb/pebble" + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/portalnetwork/storage" + "github.com/holiman/uint256" +) + +var _ storage.ContentStorage = &ContentStorage{} + +type PeppleStorageConfig struct { + StorageCapacityMB uint64 + DB ethdb.KeyValueStore + NodeId enode.ID + NetworkName string +} + +func NewPeppleDB(dataDir string, cache, handles int, namespace string) (ethdb.KeyValueStore, error) { + db, err := pebble.New(dataDir + "/" + namespace, cache, handles, namespace, false) + return db, err +} + +type ContentStorage struct { + nodeId enode.ID + storageCapacityInBytes uint64 + radius atomic.Value + // size uint64 + log log.Logger + db ethdb.KeyValueStore +} + +func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error) { + cs := &ContentStorage{ + nodeId: config.NodeId, + db: config.DB, + storageCapacityInBytes: config.StorageCapacityMB * 1000_000, + log: log.New("storage", config.NetworkName), + } + cs.radius.Store(storage.MaxDistance) + exist, err := cs.db.Has(storage.RadisuKey); + if err != nil { + return nil, err + } + if exist { + radius, err := cs.db.Get(storage.RadisuKey) + if err != nil { + return nil, err + } + dis := uint256.NewInt(0) + err = dis.UnmarshalSSZ(radius) + if err != nil { + return nil, err + } + cs.radius.Store(dis) + } + return cs, nil +} + +// Get implements storage.ContentStorage. +func (c *ContentStorage) Get(contentKey []byte, contentId []byte) ([]byte, error) { + return c.db.Get(contentId) +} + +// Put implements storage.ContentStorage. +func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte) error { + return c.db.Put(contentId, content) +} + +// Radius implements storage.ContentStorage. 
+func (p *ContentStorage) Radius() *uint256.Int { + radius := p.radius.Load() + val := radius.(*uint256.Int) + return val +} diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go new file mode 100644 index 000000000000..ddf69333e593 --- /dev/null +++ b/portalnetwork/storage/ethpepple/storage_test.go @@ -0,0 +1,74 @@ +package ethpepple + +import ( + "os" + "testing" + + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/portalnetwork/storage" + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" +) + +const dataDir = "./node1" +var testRadius = uint256.NewInt(100000) + +func clearNodeData() { + _ = os.RemoveAll(dataDir) +} + +func getTestDb() (storage.ContentStorage, error) { + db, err := NewPeppleDB(dataDir, 100, 100, "history") + if err != nil { + return nil, err + } + config := PeppleStorageConfig{ + DB: db, + StorageCapacityMB: 100, + NodeId: enode.ID{}, + NetworkName: "history", + } + return NewPeppleStorage(config) +} + +func TestReadRadius(t *testing.T) { + db, err := getTestDb() + assert.NoError(t, err) + defer clearNodeData() + assert.True(t, db.Radius().Eq(storage.MaxDistance)) + + data, err := testRadius.MarshalSSZ() + assert.NoError(t, err) + db.Put(nil, storage.RadisuKey, data) + + store := db.(*ContentStorage) + err = store.db.Close() + assert.NoError(t, err) + + db, err = getTestDb() + assert.NoError(t, err) + assert.True(t, db.Radius().Eq(testRadius)) +} + +func TestStorage(t *testing.T) { + db, err := getTestDb() + assert.NoError(t, err) + defer clearNodeData() + testcases := map[string][]byte{ + "test1": []byte("test1"), + "test2": []byte("test2"), + "test3": []byte("test3"), + "test4": []byte("test4"), + } + + for key, value := range testcases { + db.Put(nil, []byte(key), value) + } + + for key, value := range testcases { + val, err := db.Get(nil, []byte(key)) + assert.NoError(t, err) + assert.Equal(t, value, val) + } + +} \ No newline at end of file From ad2fd55e7a700e6bae2e128b275d8b282d6fa5f5 Mon Sep 17 00:00:00 2001 From: fearlseefe <505380967@qq.com> Date: Wed, 20 Nov 2024 18:08:12 +0800 Subject: [PATCH 04/13] fix: lint error --- portalnetwork/storage/ethpepple/storage.go | 14 +++++++------- portalnetwork/storage/ethpepple/storage_test.go | 12 ++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go index b074de1e4c2c..ad4bae90be27 100644 --- a/portalnetwork/storage/ethpepple/storage.go +++ b/portalnetwork/storage/ethpepple/storage.go @@ -21,7 +21,7 @@ type PeppleStorageConfig struct { } func NewPeppleDB(dataDir string, cache, handles int, namespace string) (ethdb.KeyValueStore, error) { - db, err := pebble.New(dataDir + "/" + namespace, cache, handles, namespace, false) + db, err := pebble.New(dataDir+"/"+namespace, cache, handles, namespace, false) return db, err } @@ -30,19 +30,19 @@ type ContentStorage struct { storageCapacityInBytes uint64 radius atomic.Value // size uint64 - log log.Logger - db ethdb.KeyValueStore + log log.Logger + db ethdb.KeyValueStore } func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error) { cs := &ContentStorage{ nodeId: config.NodeId, - db: config.DB, + db: config.DB, storageCapacityInBytes: config.StorageCapacityMB * 1000_000, log: log.New("storage", config.NetworkName), } cs.radius.Store(storage.MaxDistance) - exist, err := cs.db.Has(storage.RadisuKey); + exist, err := cs.db.Has(storage.RadisuKey) if err 
!= nil { return nil, err } @@ -72,8 +72,8 @@ func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte } // Radius implements storage.ContentStorage. -func (p *ContentStorage) Radius() *uint256.Int { - radius := p.radius.Load() +func (c *ContentStorage) Radius() *uint256.Int { + radius := c.radius.Load() val := radius.(*uint256.Int) return val } diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go index ddf69333e593..b7eb6db0f0a2 100644 --- a/portalnetwork/storage/ethpepple/storage_test.go +++ b/portalnetwork/storage/ethpepple/storage_test.go @@ -11,6 +11,7 @@ import ( ) const dataDir = "./node1" + var testRadius = uint256.NewInt(100000) func clearNodeData() { @@ -23,10 +24,10 @@ func getTestDb() (storage.ContentStorage, error) { return nil, err } config := PeppleStorageConfig{ - DB: db, + DB: db, StorageCapacityMB: 100, - NodeId: enode.ID{}, - NetworkName: "history", + NodeId: enode.ID{}, + NetworkName: "history", } return NewPeppleStorage(config) } @@ -35,7 +36,7 @@ func TestReadRadius(t *testing.T) { db, err := getTestDb() assert.NoError(t, err) defer clearNodeData() - assert.True(t, db.Radius().Eq(storage.MaxDistance)) + assert.True(t, db.Radius().Eq(storage.MaxDistance)) data, err := testRadius.MarshalSSZ() assert.NoError(t, err) @@ -70,5 +71,4 @@ func TestStorage(t *testing.T) { assert.NoError(t, err) assert.Equal(t, value, val) } - -} \ No newline at end of file +} From 6388d8b3dbef8fc851df5db1ee5a1b14425c9a30 Mon Sep 17 00:00:00 2001 From: fearlessfe <505380967@qq.com> Date: Thu, 21 Nov 2024 21:54:02 +0800 Subject: [PATCH 05/13] fix: unit test error --- .../storage/ethpepple/storage_test.go | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go index b7eb6db0f0a2..7bf543df9487 100644 --- a/portalnetwork/storage/ethpepple/storage_test.go +++ b/portalnetwork/storage/ethpepple/storage_test.go @@ -1,6 +1,8 @@ package ethpepple import ( + "crypto/rand" + "encoding/hex" "os" "testing" @@ -10,16 +12,24 @@ import ( "github.com/stretchr/testify/assert" ) -const dataDir = "./node1" - var testRadius = uint256.NewInt(100000) -func clearNodeData() { - _ = os.RemoveAll(dataDir) +func clearNodeData(path string) { + _ = os.RemoveAll(path) +} + +func getRandomPath() string { + // gen a random hex string + bytes := make([]byte, 32) + _, err := rand.Read(bytes) + if err != nil { + panic(err) + } + return hex.EncodeToString(bytes) } -func getTestDb() (storage.ContentStorage, error) { - db, err := NewPeppleDB(dataDir, 100, 100, "history") +func getTestDb(path string) (storage.ContentStorage, error) { + db, err := NewPeppleDB(path, 100, 100, "history") if err != nil { return nil, err } @@ -33,9 +43,10 @@ func getTestDb() (storage.ContentStorage, error) { } func TestReadRadius(t *testing.T) { - db, err := getTestDb() + path := getRandomPath() + db, err := getTestDb(path) assert.NoError(t, err) - defer clearNodeData() + defer clearNodeData(path) assert.True(t, db.Radius().Eq(storage.MaxDistance)) data, err := testRadius.MarshalSSZ() @@ -46,15 +57,16 @@ func TestReadRadius(t *testing.T) { err = store.db.Close() assert.NoError(t, err) - db, err = getTestDb() + db, err = getTestDb(path) assert.NoError(t, err) assert.True(t, db.Radius().Eq(testRadius)) } func TestStorage(t *testing.T) { - db, err := getTestDb() + path := getRandomPath() + db, err := getTestDb(path) assert.NoError(t, err) - defer 
clearNodeData() + defer clearNodeData(path) testcases := map[string][]byte{ "test1": []byte("test1"), "test2": []byte("test2"), From ee11806dd7992fbc6806e50566c55c726972b4f2 Mon Sep 17 00:00:00 2001 From: Chen Kai <281165273grape@gmail.com> Date: Mon, 25 Nov 2024 14:18:17 +0800 Subject: [PATCH 06/13] feat:migrate and remove portal wire in p2p package Signed-off-by: Chen Kai <281165273grape@gmail.com> --- cmd/shisui/main.go | 44 +- cmd/utils/flags.go | 2 +- metrics/portal_metrics.go | 153 -- p2p/discover/api.go | 550 ----- p2p/discover/portal_protocol.go | 1930 ----------------- p2p/discover/portal_protocol_metrics.go | 67 - p2p/discover/portal_protocol_test.go | 503 ----- p2p/discover/portal_utp.go | 138 -- portalnetwork/beacon/api.go | 14 +- portalnetwork/beacon/beacon_network.go | 6 +- portalnetwork/beacon/beacon_network_test.go | 5 - portalnetwork/beacon/portal_api.go | 4 +- portalnetwork/beacon/storage.go | 30 +- portalnetwork/beacon/test_utils.go | 16 +- portalnetwork/ethapi/api.go | 2 +- portalnetwork/history/api.go | 14 +- portalnetwork/history/history_network.go | 28 +- portalnetwork/history/history_network_test.go | 16 +- portalnetwork/history/new_storage.go | 443 ++++ portalnetwork/history/storage.go | 5 +- portalnetwork/nat.go | 172 -- portalnetwork/portal_protocol_metrics.go | 67 - portalnetwork/{ => portalwire}/api.go | 13 +- portalnetwork/portalwire/messages.go | 336 --- portalnetwork/portalwire/messages_encoding.go | 957 -------- portalnetwork/portalwire/messages_test.go | 212 -- .../portalwire}/nat.go | 2 +- .../{ => portalwire}/portal_protocol.go | 132 +- .../portalwire/portal_protocol_metrics.go | 217 ++ .../{ => portalwire}/portal_protocol_test.go | 11 +- portalnetwork/{ => portalwire}/portal_utp.go | 7 +- .../portalwire/types.go | 30 +- .../portalwire/types_encoding.go | 0 .../portalwire/types_test.go | 0 portalnetwork/state/api.go | 14 +- portalnetwork/state/network.go | 8 +- portalnetwork/state/network_test.go | 6 +- portalnetwork/state/storage.go | 5 +- 38 files changed, 852 insertions(+), 5307 deletions(-) delete mode 100644 metrics/portal_metrics.go delete mode 100644 p2p/discover/api.go delete mode 100644 p2p/discover/portal_protocol.go delete mode 100644 p2p/discover/portal_protocol_metrics.go delete mode 100644 p2p/discover/portal_protocol_test.go delete mode 100644 p2p/discover/portal_utp.go create mode 100644 portalnetwork/history/new_storage.go delete mode 100644 portalnetwork/nat.go delete mode 100644 portalnetwork/portal_protocol_metrics.go rename portalnetwork/{ => portalwire}/api.go (98%) delete mode 100644 portalnetwork/portalwire/messages.go delete mode 100644 portalnetwork/portalwire/messages_encoding.go delete mode 100644 portalnetwork/portalwire/messages_test.go rename {p2p/discover => portalnetwork/portalwire}/nat.go (99%) rename portalnetwork/{ => portalwire}/portal_protocol.go (93%) create mode 100644 portalnetwork/portalwire/portal_protocol_metrics.go rename portalnetwork/{ => portalwire}/portal_protocol_test.go (98%) rename portalnetwork/{ => portalwire}/portal_utp.go (94%) rename p2p/discover/portalwire/messages.go => portalnetwork/portalwire/types.go (87%) rename p2p/discover/portalwire/messages_encoding.go => portalnetwork/portalwire/types_encoding.go (100%) rename p2p/discover/portalwire/messages_test.go => portalnetwork/portalwire/types_test.go (100%) diff --git a/cmd/shisui/main.go b/cmd/shisui/main.go index 9f369cc4797e..41799c0e4a05 100644 --- a/cmd/shisui/main.go +++ b/cmd/shisui/main.go @@ -27,13 +27,13 @@ import ( 
"github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p/discover" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/portalnetwork/beacon" "github.com/ethereum/go-ethereum/portalnetwork/ethapi" "github.com/ethereum/go-ethereum/portalnetwork/history" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/state" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/ethereum/go-ethereum/portalnetwork/web3" @@ -53,7 +53,7 @@ const ( ) type Config struct { - Protocol *discover.PortalProtocolConfig + Protocol *portalwire.PortalProtocolConfig PrivateKey *ecdsa.PrivateKey RpcAddr string DataDir string @@ -63,8 +63,8 @@ type Config struct { } type Client struct { - DiscV5API *discover.DiscV5API - HistoryNetwork *history.HistoryNetwork + DiscV5API *portalwire.DiscV5API + HistoryNetwork *history.Network BeaconNetwork *beacon.BeaconNetwork StateNetwork *state.StateNetwork Server *http.Server @@ -129,7 +129,7 @@ func shisui(ctx *cli.Context) error { // Start system runtime metrics collection go metrics.CollectProcessMetrics(3 * time.Second) - go metrics.CollectPortalMetrics(5*time.Second, ctx.StringSlice(utils.PortalNetworksFlag.Name), ctx.String(utils.PortalDataDirFlag.Name)) + go portalwire.CollectPortalMetrics(5*time.Second, ctx.StringSlice(utils.PortalNetworksFlag.Name), ctx.String(utils.PortalDataDirFlag.Name)) if metrics.Enabled { storageCapacity = metrics.NewRegisteredGauge("portal/storage_capacity", nil) @@ -229,7 +229,7 @@ func startPortalRpcServer(config Config, conn discover.UDPConn, addr string, cli } server := rpc.NewServer() - discV5API := discover.NewDiscV5API(discV5) + discV5API := portalwire.NewDiscV5API(discV5) err = server.RegisterName("discv5", discV5API) if err != nil { return err @@ -241,9 +241,9 @@ func startPortalRpcServer(config Config, conn discover.UDPConn, addr string, cli if err != nil { return err } - utp := discover.NewPortalUtp(context.Background(), config.Protocol, discV5, conn) + utp := portalwire.NewPortalUtp(context.Background(), config.Protocol, discV5, conn) - var historyNetwork *history.HistoryNetwork + var historyNetwork *history.Network if slices.Contains(config.Networks, portalwire.History.Name()) { historyNetwork, err = initHistory(config, server, conn, localNode, discV5, utp) if err != nil { @@ -305,7 +305,7 @@ func initDiscV5(config Config, conn discover.UDPConn) (*discover.UDPv5, *enode.L localNode := enode.NewLocalNode(nodeDB, config.PrivateKey) - localNode.Set(discover.Tag) + localNode.Set(portalwire.Tag) listenerAddr := conn.LocalAddr().(*net.UDPAddr) nat := config.Protocol.NAT if nat != nil && !listenerAddr.IP.IsLoopback() { @@ -370,7 +370,7 @@ func doPortMapping(natm nat.Interface, ln *enode.LocalNode, addr *net.UDPAddr) { }() } -func initHistory(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *discover.PortalUtp) (*history.HistoryNetwork, error) { +func initHistory(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *portalwire.PortalUtp) (*history.Network, error) { networkName := portalwire.History.Name() db, err := history.NewDB(config.DataDir, networkName) if err != nil { @@ -385,9 +385,9 @@ func initHistory(config Config, server 
*rpc.Server, conn discover.UDPConn, local if err != nil { return nil, err } - contentQueue := make(chan *discover.ContentElement, 50) + contentQueue := make(chan *portalwire.ContentElement, 50) - protocol, err := discover.NewPortalProtocol( + protocol, err := portalwire.NewPortalProtocol( config.Protocol, portalwire.History, config.PrivateKey, @@ -401,7 +401,7 @@ func initHistory(config Config, server *rpc.Server, conn discover.UDPConn, local if err != nil { return nil, err } - historyAPI := discover.NewPortalAPI(protocol) + historyAPI := portalwire.NewPortalAPI(protocol) historyNetworkAPI := history.NewHistoryNetworkAPI(historyAPI) err = server.RegisterName("portal", historyNetworkAPI) if err != nil { @@ -415,7 +415,7 @@ func initHistory(config Config, server *rpc.Server, conn discover.UDPConn, local return historyNetwork, historyNetwork.Start() } -func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *discover.PortalUtp) (*beacon.BeaconNetwork, error) { +func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *portalwire.PortalUtp) (*beacon.BeaconNetwork, error) { dbPath := path.Join(config.DataDir, "beacon") err := os.MkdirAll(dbPath, 0755) if err != nil { @@ -436,9 +436,9 @@ func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localN if err != nil { return nil, err } - contentQueue := make(chan *discover.ContentElement, 50) + contentQueue := make(chan *portalwire.ContentElement, 50) - protocol, err := discover.NewPortalProtocol( + protocol, err := portalwire.NewPortalProtocol( config.Protocol, portalwire.Beacon, config.PrivateKey, @@ -452,7 +452,7 @@ func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localN if err != nil { return nil, err } - portalApi := discover.NewPortalAPI(protocol) + portalApi := portalwire.NewPortalAPI(protocol) beaconAPI := beacon.NewBeaconNetworkAPI(portalApi) err = server.RegisterName("portal", beaconAPI) @@ -464,7 +464,7 @@ func initBeacon(config Config, server *rpc.Server, conn discover.UDPConn, localN return beaconNetwork, beaconNetwork.Start() } -func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *discover.PortalUtp) (*state.StateNetwork, error) { +func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *portalwire.PortalUtp) (*state.StateNetwork, error) { networkName := portalwire.State.Name() db, err := history.NewDB(config.DataDir, networkName) if err != nil { @@ -480,9 +480,9 @@ func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNo return nil, err } stateStore := state.NewStateStorage(contentStorage, db) - contentQueue := make(chan *discover.ContentElement, 50) + contentQueue := make(chan *portalwire.ContentElement, 50) - protocol, err := discover.NewPortalProtocol( + protocol, err := portalwire.NewPortalProtocol( config.Protocol, portalwire.State, config.PrivateKey, @@ -496,7 +496,7 @@ func initState(config Config, server *rpc.Server, conn discover.UDPConn, localNo if err != nil { return nil, err } - api := discover.NewPortalAPI(protocol) + api := portalwire.NewPortalAPI(protocol) stateNetworkAPI := state.NewStateNetworkAPI(api) err = server.RegisterName("portal", stateNetworkAPI) if err != nil { @@ -509,7 +509,7 @@ func initState(config Config, server *rpc.Server, conn 
discover.UDPConn, localNo func getPortalConfig(ctx *cli.Context) (*Config, error) { config := &Config{ - Protocol: discover.DefaultPortalProtocolConfig(), + Protocol: portalwire.DefaultPortalProtocolConfig(), } httpAddr := ctx.String(utils.PortalRPCListenAddrFlag.Name) diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 3547089fe040..17773e4fa905 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -68,11 +68,11 @@ import ( "github.com/ethereum/go-ethereum/miner" "github.com/ethereum/go-ethereum/node" "github.com/ethereum/go-ethereum/p2p" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/triedb" "github.com/ethereum/go-ethereum/triedb/hashdb" diff --git a/metrics/portal_metrics.go b/metrics/portal_metrics.go deleted file mode 100644 index 8d524ffdddf3..000000000000 --- a/metrics/portal_metrics.go +++ /dev/null @@ -1,153 +0,0 @@ -package metrics - -import ( - "database/sql" - "errors" - "os" - "path" - "slices" - "strings" - "time" - - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" -) - -type networkFileMetric struct { - filename string - metric Gauge - file *os.File - network string -} - -type PortalStorageMetrics struct { - RadiusRatio GaugeFloat64 - EntriesCount Gauge - ContentStorageUsage Gauge -} - -const ( - countEntrySql = "SELECT COUNT(1) FROM kvstore;" - contentStorageUsageSql = "SELECT SUM( length(value) ) FROM kvstore;" -) - -// CollectPortalMetrics periodically collects various metrics about system entities. 
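The collector deleted here (re-homed as portalwire.CollectPortalMetrics, per the cmd/shisui/main.go hunk above) works by statting each network's sqlite file on a fixed interval and publishing the size as a gauge. A standalone sketch of that loop, with a plain callback standing in for the real metrics.Gauge:

// sketch.go: periodic file-size sampling, the core of the deleted collector.
package main

import (
	"log"
	"os"
	"time"
)

func pollFileSize(path string, refresh time.Duration, update func(int64)) {
	for {
		if st, err := os.Stat(path); err != nil {
			log.Printf("could not stat %s: %v", path, err)
		} else {
			update(st.Size()) // publish the current on-disk size
		}
		time.Sleep(refresh)
	}
}

func main() {
	// 5s matches the refresh interval main.go passes to the real collector.
	go pollFileSize("history/history.sqlite", 5*time.Second, func(size int64) {
		log.Printf("portal/history/total_storage = %d", size)
	})
	time.Sleep(12 * time.Second) // let a couple of samples land, then exit
}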
-func CollectPortalMetrics(refresh time.Duration, networks []string, dataDir string) { - // Short circuit if the metrics system is disabled - if !Enabled { - return - } - - // Define the various metrics to collect - var ( - historyTotalStorage = GetOrRegisterGauge("portal/history/total_storage", nil) - beaconTotalStorage = GetOrRegisterGauge("portal/beacon/total_storage", nil) - stateTotalStorage = GetOrRegisterGauge("portal/state/total_storage", nil) - ) - - var metricsArr []*networkFileMetric - if slices.Contains(networks, portalwire.History.Name()) { - dbPath := path.Join(dataDir, portalwire.History.Name()) - metricsArr = append(metricsArr, &networkFileMetric{ - filename: path.Join(dbPath, portalwire.History.Name()+".sqlite"), - metric: historyTotalStorage, - network: portalwire.History.Name(), - }) - } - if slices.Contains(networks, portalwire.Beacon.Name()) { - dbPath := path.Join(dataDir, portalwire.Beacon.Name()) - metricsArr = append(metricsArr, &networkFileMetric{ - filename: path.Join(dbPath, portalwire.Beacon.Name()+".sqlite"), - metric: beaconTotalStorage, - network: portalwire.Beacon.Name(), - }) - } - if slices.Contains(networks, portalwire.State.Name()) { - dbPath := path.Join(dataDir, portalwire.State.Name()) - metricsArr = append(metricsArr, &networkFileMetric{ - filename: path.Join(dbPath, portalwire.State.Name()+".sqlite"), - metric: stateTotalStorage, - network: portalwire.State.Name(), - }) - } - - for { - for _, m := range metricsArr { - var err error = nil - if m.file == nil { - m.file, err = os.OpenFile(m.filename, os.O_RDONLY, 0600) - if err != nil { - log.Debug("Could not open file", "network", m.network, "file", m.filename, "metric", "total_storage", "err", err) - } - } - if m.file != nil && err == nil { - stat, err := m.file.Stat() - if err != nil { - log.Warn("Could not get file stat", "network", m.network, "file", m.filename, "metric", "total_storage", "err", err) - } - if err == nil { - m.metric.Update(stat.Size()) - } - } - } - - time.Sleep(refresh) - } -} - -func NewPortalStorageMetrics(network string, db *sql.DB) (*PortalStorageMetrics, error) { - if !Enabled { - return nil, nil - } - - if network != portalwire.History.Name() && network != portalwire.Beacon.Name() && network != portalwire.State.Name() { - log.Debug("Unknow network for metrics", "network", network) - return nil, errors.New("unknow network for metrics") - } - - var countSql string - var contentSql string - if network == portalwire.Beacon.Name() { - countSql = strings.Replace(countEntrySql, "kvstore", "beacon", 1) - contentSql = strings.Replace(contentStorageUsageSql, "kvstore", "beacon", 1) - contentSql = strings.Replace(contentSql, "value", "content_value", 1) - } else { - countSql = countEntrySql - contentSql = contentStorageUsageSql - } - - storageMetrics := &PortalStorageMetrics{} - - storageMetrics.RadiusRatio = NewRegisteredGaugeFloat64("portal/"+network+"/radius_ratio", nil) - storageMetrics.RadiusRatio.Update(1) - - storageMetrics.EntriesCount = NewRegisteredGauge("portal/"+network+"/entry_count", nil) - log.Debug("Counting entities in " + network + " storage for metrics") - var res *int64 = new(int64) - q := db.QueryRow(countSql) - if q.Err() == sql.ErrNoRows { - storageMetrics.EntriesCount.Update(0) - } else if q.Err() != nil { - log.Error("Querry execution error", "network", network, "metric", "entry_count", "err", q.Err()) - return nil, q.Err() - } else { - q.Scan(res) - storageMetrics.EntriesCount.Update(*res) - } - - storageMetrics.ContentStorageUsage = 
NewRegisteredGauge("portal/"+network+"/content_storage", nil) - log.Debug("Counting storage usage (bytes) in " + network + " for metrics") - var res2 *int64 = new(int64) - q2 := db.QueryRow(contentSql) - if q2.Err() == sql.ErrNoRows { - storageMetrics.ContentStorageUsage.Update(0) - } else if q2.Err() != nil { - log.Error("Querry execution error", "network", network, "metric", "entry_count", "err", q2.Err()) - return nil, q2.Err() - } else { - q2.Scan(res2) - storageMetrics.ContentStorageUsage.Update(*res2) - } - - return storageMetrics, nil -} diff --git a/p2p/discover/api.go b/p2p/discover/api.go deleted file mode 100644 index e7fe5c764ba3..000000000000 --- a/p2p/discover/api.go +++ /dev/null @@ -1,550 +0,0 @@ -package discover - -import ( - "errors" - - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/holiman/uint256" -) - -// DiscV5API json-rpc spec -// https://playground.open-rpc.org/?schemaUrl=https://raw.githubusercontent.com/ethereum/portal-network-specs/assembled-spec/jsonrpc/openrpc.json&uiSchema%5BappBar%5D%5Bui:splitView%5D=false&uiSchema%5BappBar%5D%5Bui:input%5D=false&uiSchema%5BappBar%5D%5Bui:examplesDropdown%5D=false -type DiscV5API struct { - DiscV5 *UDPv5 -} - -func NewDiscV5API(discV5 *UDPv5) *DiscV5API { - return &DiscV5API{discV5} -} - -type NodeInfo struct { - NodeId string `json:"nodeId"` - Enr string `json:"enr"` - Ip string `json:"ip"` -} - -type RoutingTableInfo struct { - Buckets [][]string `json:"buckets"` - LocalNodeId string `json:"localNodeId"` -} - -type DiscV5PongResp struct { - EnrSeq uint64 `json:"enrSeq"` - RecipientIP string `json:"recipientIP"` - RecipientPort uint16 `json:"recipientPort"` -} - -type PortalPongResp struct { - EnrSeq uint32 `json:"enrSeq"` - DataRadius string `json:"dataRadius"` -} - -type ContentInfo struct { - Content string `json:"content"` - UtpTransfer bool `json:"utpTransfer"` -} - -type TraceContentResult struct { - Content string `json:"content"` - UtpTransfer bool `json:"utpTransfer"` - Trace Trace `json:"trace"` -} - -type Trace struct { - Origin string `json:"origin"` // local node id - TargetId string `json:"targetId"` // target content id - ReceivedFrom string `json:"receivedFrom"` // the node id of which content from - Responses map[string]RespByNode `json:"responses"` // the node id and there response nodeIds - Metadata map[string]*NodeMetadata `json:"metadata"` // node id and there metadata object - StartedAtMs int `json:"startedAtMs"` // timestamp of the beginning of this request in milliseconds - Cancelled []string `json:"cancelled"` // the node ids which are send but cancelled -} - -type NodeMetadata struct { - Enr string `json:"enr"` - Distance string `json:"distance"` -} - -type RespByNode struct { - DurationMs int32 `json:"durationMs"` - RespondedWith []string `json:"respondedWith"` -} - -type Enrs struct { - Enrs []string `json:"enrs"` -} - -func (d *DiscV5API) NodeInfo() *NodeInfo { - n := d.DiscV5.LocalNode().Node() - - return &NodeInfo{ - NodeId: "0x" + n.ID().String(), - Enr: n.String(), - Ip: n.IP().String(), - } -} - -func (d *DiscV5API) RoutingTableInfo() *RoutingTableInfo { - n := d.DiscV5.LocalNode().Node() - bucketNodes := d.DiscV5.RoutingTableInfo() - - return &RoutingTableInfo{ - Buckets: bucketNodes, - LocalNodeId: "0x" + n.ID().String(), - } -} - -func (d *DiscV5API) AddEnr(enr string) (bool, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return false, 
err - } - - // immediately add the node to the routing table - d.DiscV5.tab.mutex.Lock() - defer d.DiscV5.tab.mutex.Unlock() - d.DiscV5.tab.handleAddNode(addNodeOp{node: n, isInbound: true, forceSetLive: true}) - return true, nil -} - -func (d *DiscV5API) GetEnr(nodeId string) (bool, error) { - id, err := enode.ParseID(nodeId) - if err != nil { - return false, err - } - n := d.DiscV5.tab.GetNode(id) - if n == nil { - return false, errors.New("record not in local routing table") - } - - return true, nil -} - -func (d *DiscV5API) DeleteEnr(nodeId string) (bool, error) { - id, err := enode.ParseID(nodeId) - if err != nil { - return false, err - } - - n := d.DiscV5.tab.GetNode(id) - if n == nil { - return false, errors.New("record not in local routing table") - } - - d.DiscV5.tab.mutex.Lock() - defer d.DiscV5.tab.mutex.Unlock() - b := d.DiscV5.tab.bucket(n.ID()) - d.DiscV5.tab.deleteInBucket(b, n.ID()) - return true, nil -} - -func (d *DiscV5API) LookupEnr(nodeId string) (string, error) { - id, err := enode.ParseID(nodeId) - if err != nil { - return "", err - } - - enr := d.DiscV5.ResolveNodeId(id) - - if enr == nil { - return "", errors.New("record not found in DHT lookup") - } - - return enr.String(), nil -} - -func (d *DiscV5API) Ping(enr string) (*DiscV5PongResp, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return nil, err - } - - pong, err := d.DiscV5.PingWithResp(n) - if err != nil { - return nil, err - } - - return &DiscV5PongResp{ - EnrSeq: pong.ENRSeq, - RecipientIP: pong.ToIP.String(), - RecipientPort: pong.ToPort, - }, nil -} - -func (d *DiscV5API) FindNodes(enr string, distances []uint) ([]string, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return nil, err - } - findNodes, err := d.DiscV5.Findnode(n, distances) - if err != nil { - return nil, err - } - - enrs := make([]string, 0, len(findNodes)) - for _, r := range findNodes { - enrs = append(enrs, r.String()) - } - - return enrs, nil -} - -func (d *DiscV5API) TalkReq(enr string, protocol string, payload string) (string, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return "", err - } - - req, err := hexutil.Decode(payload) - if err != nil { - return "", err - } - - talkResp, err := d.DiscV5.TalkRequest(n, protocol, req) - if err != nil { - return "", err - } - return hexutil.Encode(talkResp), nil -} - -func (d *DiscV5API) RecursiveFindNodes(nodeId string) ([]string, error) { - findNodes := d.DiscV5.Lookup(enode.HexID(nodeId)) - - enrs := make([]string, 0, len(findNodes)) - for _, r := range findNodes { - enrs = append(enrs, r.String()) - } - - return enrs, nil -} - -type PortalProtocolAPI struct { - portalProtocol *PortalProtocol -} - -func NewPortalAPI(portalProtocol *PortalProtocol) *PortalProtocolAPI { - return &PortalProtocolAPI{ - portalProtocol: portalProtocol, - } -} - -func (p *PortalProtocolAPI) NodeInfo() *NodeInfo { - n := p.portalProtocol.localNode.Node() - - return &NodeInfo{ - NodeId: n.ID().String(), - Enr: n.String(), - Ip: n.IP().String(), - } -} - -func (p *PortalProtocolAPI) RoutingTableInfo() *RoutingTableInfo { - n := p.portalProtocol.localNode.Node() - bucketNodes := p.portalProtocol.RoutingTableInfo() - - return &RoutingTableInfo{ - Buckets: bucketNodes, - LocalNodeId: "0x" + n.ID().String(), - } -} - -func (p *PortalProtocolAPI) AddEnr(enr string) (bool, error) { - p.portalProtocol.Log.Debug("serving AddEnr", "enr", enr) - n, err := enode.ParseForAddEnr(enode.ValidSchemes, enr) - if err != nil { - return 
false, err - } - p.portalProtocol.AddEnr(n) - return true, nil -} - -func (p *PortalProtocolAPI) AddEnrs(enrs []string) bool { - // Note: unspecified RPC, but useful for our local testnet test - for _, enr := range enrs { - n, err := enode.ParseForAddEnr(enode.ValidSchemes, enr) - if err != nil { - continue - } - p.portalProtocol.AddEnr(n) - } - - return true -} - -func (p *PortalProtocolAPI) GetEnr(nodeId string) (string, error) { - id, err := enode.ParseID(nodeId) - if err != nil { - return "", err - } - - if id == p.portalProtocol.localNode.Node().ID() { - return p.portalProtocol.localNode.Node().String(), nil - } - - n := p.portalProtocol.table.GetNode(id) - if n == nil { - return "", errors.New("record not in local routing table") - } - - return n.String(), nil -} - -func (p *PortalProtocolAPI) DeleteEnr(nodeId string) (bool, error) { - id, err := enode.ParseID(nodeId) - if err != nil { - return false, err - } - - n := p.portalProtocol.table.GetNode(id) - if n == nil { - return false, nil - } - - p.portalProtocol.table.mutex.Lock() - defer p.portalProtocol.table.mutex.Unlock() - b := p.portalProtocol.table.bucket(n.ID()) - p.portalProtocol.table.deleteInBucket(b, n.ID()) - return true, nil -} - -func (p *PortalProtocolAPI) LookupEnr(nodeId string) (string, error) { - id, err := enode.ParseID(nodeId) - if err != nil { - return "", err - } - - enr := p.portalProtocol.ResolveNodeId(id) - - if enr == nil { - return "", errors.New("record not found in DHT lookup") - } - - return enr.String(), nil -} - -func (p *PortalProtocolAPI) Ping(enr string) (*PortalPongResp, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return nil, err - } - - pong, err := p.portalProtocol.pingInner(n) - if err != nil { - return nil, err - } - - customPayload := &portalwire.PingPongCustomData{} - err = customPayload.UnmarshalSSZ(pong.CustomPayload) - if err != nil { - return nil, err - } - - nodeRadius := new(uint256.Int) - err = nodeRadius.UnmarshalSSZ(customPayload.Radius) - if err != nil { - return nil, err - } - - return &PortalPongResp{ - EnrSeq: uint32(pong.EnrSeq), - DataRadius: nodeRadius.Hex(), - }, nil -} - -func (p *PortalProtocolAPI) FindNodes(enr string, distances []uint) ([]string, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return nil, err - } - findNodes, err := p.portalProtocol.findNodes(n, distances) - if err != nil { - return nil, err - } - - enrs := make([]string, 0, len(findNodes)) - for _, r := range findNodes { - enrs = append(enrs, r.String()) - } - - return enrs, nil -} - -func (p *PortalProtocolAPI) FindContent(enr string, contentKey string) (interface{}, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return nil, err - } - - contentKeyBytes, err := hexutil.Decode(contentKey) - if err != nil { - return nil, err - } - - flag, findContent, err := p.portalProtocol.findContent(n, contentKeyBytes) - if err != nil { - return nil, err - } - - switch flag { - case portalwire.ContentRawSelector: - contentInfo := &ContentInfo{ - Content: hexutil.Encode(findContent.([]byte)), - UtpTransfer: false, - } - p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo) - return contentInfo, nil - case portalwire.ContentConnIdSelector: - contentInfo := &ContentInfo{ - Content: hexutil.Encode(findContent.([]byte)), - UtpTransfer: true, - } - p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo) - return contentInfo, nil - default: - enrs := make([]string, 0) - for _, r := range 
findContent.([]*enode.Node) { - enrs = append(enrs, r.String()) - } - - p.portalProtocol.Log.Trace("FindContent", "enrs", enrs) - return &Enrs{ - Enrs: enrs, - }, nil - } -} - -func (p *PortalProtocolAPI) Offer(enr string, contentItems [][2]string) (string, error) { - n, err := enode.Parse(enode.ValidSchemes, enr) - if err != nil { - return "", err - } - - entries := make([]*ContentEntry, 0, len(contentItems)) - for _, contentItem := range contentItems { - contentKey, err := hexutil.Decode(contentItem[0]) - if err != nil { - return "", err - } - contentValue, err := hexutil.Decode(contentItem[1]) - if err != nil { - return "", err - } - contentEntry := &ContentEntry{ - ContentKey: contentKey, - Content: contentValue, - } - entries = append(entries, contentEntry) - } - - transientOfferRequest := &TransientOfferRequest{ - Contents: entries, - } - - offerReq := &OfferRequest{ - Kind: TransientOfferRequestKind, - Request: transientOfferRequest, - } - accept, err := p.portalProtocol.offer(n, offerReq) - if err != nil { - return "", err - } - - return hexutil.Encode(accept), nil -} - -func (p *PortalProtocolAPI) RecursiveFindNodes(nodeId string) ([]string, error) { - findNodes := p.portalProtocol.Lookup(enode.HexID(nodeId)) - - enrs := make([]string, 0, len(findNodes)) - for _, r := range findNodes { - enrs = append(enrs, r.String()) - } - - return enrs, nil -} - -func (p *PortalProtocolAPI) RecursiveFindContent(contentKeyHex string) (*ContentInfo, error) { - contentKey, err := hexutil.Decode(contentKeyHex) - if err != nil { - return nil, err - } - contentId := p.portalProtocol.toContentId(contentKey) - - data, err := p.portalProtocol.Get(contentKey, contentId) - if err == nil { - return &ContentInfo{ - Content: hexutil.Encode(data), - UtpTransfer: false, - }, err - } - p.portalProtocol.Log.Warn("find content err", "contextKey", hexutil.Encode(contentKey), "err", err) - - content, utpTransfer, err := p.portalProtocol.ContentLookup(contentKey, contentId) - - if err != nil { - return nil, err - } - - return &ContentInfo{ - Content: hexutil.Encode(content), - UtpTransfer: utpTransfer, - }, err -} - -func (p *PortalProtocolAPI) LocalContent(contentKeyHex string) (string, error) { - contentKey, err := hexutil.Decode(contentKeyHex) - if err != nil { - return "", err - } - contentId := p.portalProtocol.ToContentId(contentKey) - content, err := p.portalProtocol.Get(contentKey, contentId) - - if err != nil { - return "", err - } - return hexutil.Encode(content), nil -} - -func (p *PortalProtocolAPI) Store(contentKeyHex string, contextHex string) (bool, error) { - contentKey, err := hexutil.Decode(contentKeyHex) - if err != nil { - return false, err - } - contentId := p.portalProtocol.ToContentId(contentKey) - if !p.portalProtocol.InRange(contentId) { - return false, nil - } - content, err := hexutil.Decode(contextHex) - if err != nil { - return false, err - } - err = p.portalProtocol.Put(contentKey, contentId, content) - if err != nil { - return false, err - } - return true, nil -} - -func (p *PortalProtocolAPI) Gossip(contentKeyHex, contentHex string) (int, error) { - contentKey, err := hexutil.Decode(contentKeyHex) - if err != nil { - return 0, err - } - content, err := hexutil.Decode(contentHex) - if err != nil { - return 0, err - } - id := p.portalProtocol.Self().ID() - return p.portalProtocol.Gossip(&id, [][]byte{contentKey}, [][]byte{content}) -} - -func (p *PortalProtocolAPI) TraceRecursiveFindContent(contentKeyHex string) (*TraceContentResult, error) { - contentKey, err := 
hexutil.Decode(contentKeyHex) - if err != nil { - return nil, err - } - contentId := p.portalProtocol.toContentId(contentKey) - return p.portalProtocol.TraceContentLookup(contentKey, contentId) -} diff --git a/p2p/discover/portal_protocol.go b/p2p/discover/portal_protocol.go deleted file mode 100644 index 8e2129854e73..000000000000 --- a/p2p/discover/portal_protocol.go +++ /dev/null @@ -1,1930 +0,0 @@ -package discover - -import ( - "bytes" - "context" - "crypto/ecdsa" - crand "crypto/rand" - "crypto/sha256" - "encoding/binary" - "errors" - "fmt" - "io" - "math/big" - "math/rand" - "net" - "slices" - "sort" - "sync" - "sync/atomic" - "time" - - "github.com/VictoriaMetrics/fastcache" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/metrics" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/p2p/nat" - "github.com/ethereum/go-ethereum/p2p/netutil" - "github.com/ethereum/go-ethereum/portalnetwork/storage" - "github.com/ethereum/go-ethereum/rlp" - ssz "github.com/ferranbt/fastssz" - "github.com/holiman/uint256" - "github.com/optimism-java/utp-go" - "github.com/optimism-java/utp-go/libutp" - "github.com/prysmaticlabs/go-bitfield" - "github.com/tetratelabs/wabin/leb128" -) - -const ( - - // TalkResp message is a response message so the session is established and a - // regular discv5 packet is assumed for size calculation. - // Regular message = IV + header + message - // talkResp message = rlp: [request-id, response] - talkRespOverhead = 16 + // IV size - 55 + // header size - 1 + // talkResp msg id - 3 + // rlp encoding outer list, max length will be encoded in 2 bytes - 9 + // request id (max = 8) + 1 byte from rlp encoding byte string - 3 + // rlp encoding response byte string, max length in 2 bytes - 16 // HMAC - - portalFindnodesResultLimit = 32 - - defaultUTPConnectTimeout = 15 * time.Second - - defaultUTPWriteTimeout = 60 * time.Second - - defaultUTPReadTimeout = 60 * time.Second - - // These are the concurrent offers per Portal wire protocol that is running. - // Using the `offerQueue` allows for limiting the amount of offers send and - // thus how many streams can be started. - // TODO: - // More thought needs to go into this as it is currently on a per network - // basis. Keep it simple like that? Or limit it better at the stream transport - // level? In the latter case, this might still need to be checked/blocked at - // the very start of sending the offer, because blocking/waiting too long - // between the received accept message and actually starting the stream and - // sending data could give issues due to timeouts on the other side. - // And then there are still limits to be applied also for FindContent and the - // incoming directions. 
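The comment above describes a plain bounded-worker pattern: offers are pushed onto a buffered channel and exactly concurrentOffers workers drain it, so no more than that many outbound streams can be in flight at once. A self-contained sketch of the same shape; offerRequest and doOffer are placeholders for the real types:

// sketch.go: the offerQueue/offerWorker concurrency cap in miniature.
package main

import (
	"fmt"
	"sync"
)

const concurrentOffers = 50

type offerRequest struct{ id int }

func doOffer(req offerRequest) { fmt.Println("offering", req.id) }

func main() {
	queue := make(chan offerRequest, concurrentOffers)
	var wg sync.WaitGroup
	for i := 0; i < concurrentOffers; i++ {
		wg.Add(1)
		go func() { // one worker per slot, mirroring p.offerWorker()
			defer wg.Done()
			for req := range queue {
				doOffer(req)
			}
		}()
	}
	for i := 0; i < 100; i++ {
		queue <- offerRequest{id: i}
	}
	close(queue)
	wg.Wait()
}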
- concurrentOffers = 50 -) - -const ( - TransientOfferRequestKind byte = 0x01 - PersistOfferRequestKind byte = 0x02 -) - -type ClientTag string - -func (c ClientTag) ENRKey() string { return "c" } - -const Tag ClientTag = "shisui" - -var ErrNilContentKey = errors.New("content key cannot be nil") - -var ContentNotFound = storage.ErrContentNotFound - -var ErrEmptyResp = errors.New("empty resp") - -var MaxDistance = hexutil.MustDecode("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - -type ContentElement struct { - Node enode.ID - ContentKeys [][]byte - Contents [][]byte -} - -type ContentEntry struct { - ContentKey []byte - Content []byte -} - -type TransientOfferRequest struct { - Contents []*ContentEntry -} - -type PersistOfferRequest struct { - ContentKeys [][]byte -} - -type OfferRequest struct { - Kind byte - Request interface{} -} - -type OfferRequestWithNode struct { - Request *OfferRequest - Node *enode.Node -} - -type ContentInfoResp struct { - Content []byte - UtpTransfer bool -} - -type traceContentInfoResp struct { - Node *enode.Node - Flag byte - Content any - UtpTransfer bool -} - -type PortalProtocolOption func(p *PortalProtocol) - -type PortalProtocolConfig struct { - BootstrapNodes []*enode.Node - // NodeIP net.IP - ListenAddr string - NetRestrict *netutil.Netlist - NodeRadius *uint256.Int - RadiusCacheSize int - NodeDBPath string - NAT nat.Interface - clock mclock.Clock -} - -func DefaultPortalProtocolConfig() *PortalProtocolConfig { - return &PortalProtocolConfig{ - BootstrapNodes: make([]*enode.Node, 0), - ListenAddr: ":9009", - NetRestrict: nil, - RadiusCacheSize: 32 * 1024 * 1024, - NodeDBPath: "", - clock: mclock.System{}, - } -} - -type PortalProtocol struct { - table *Table - - protocolId string - protocolName string - - DiscV5 *UDPv5 - localNode *enode.LocalNode - Log log.Logger - PrivateKey *ecdsa.PrivateKey - NetRestrict *netutil.Netlist - BootstrapNodes []*enode.Node - conn UDPConn - - Utp *PortalUtp - connIdGen libutp.ConnIdGenerator - - validSchemes enr.IdentityScheme - radiusCache *fastcache.Cache - closeCtx context.Context - cancelCloseCtx context.CancelFunc - storage storage.ContentStorage - toContentId func(contentKey []byte) []byte - - contentQueue chan *ContentElement - offerQueue chan *OfferRequestWithNode - - portMappingRegister chan *portMapping - clock mclock.Clock - NAT nat.Interface - - portalMetrics *portalMetrics -} - -func defaultContentIdFunc(contentKey []byte) []byte { - digest := sha256.Sum256(contentKey) - return digest[:] -} - -func NewPortalProtocol(config *PortalProtocolConfig, protocolId portalwire.ProtocolId, privateKey *ecdsa.PrivateKey, conn UDPConn, localNode *enode.LocalNode, discV5 *UDPv5, utp *PortalUtp, storage storage.ContentStorage, contentQueue chan *ContentElement, opts ...PortalProtocolOption) (*PortalProtocol, error) { - closeCtx, cancelCloseCtx := context.WithCancel(context.Background()) - - protocol := &PortalProtocol{ - protocolId: string(protocolId), - protocolName: protocolId.Name(), - Log: log.New("protocol", protocolId.Name()), - PrivateKey: privateKey, - NetRestrict: config.NetRestrict, - BootstrapNodes: config.BootstrapNodes, - radiusCache: fastcache.New(config.RadiusCacheSize), - closeCtx: closeCtx, - cancelCloseCtx: cancelCloseCtx, - localNode: localNode, - validSchemes: enode.ValidSchemes, - storage: storage, - toContentId: defaultContentIdFunc, - contentQueue: contentQueue, - offerQueue: make(chan *OfferRequestWithNode, concurrentOffers), - conn: conn, - DiscV5: discV5, - Utp: utp, - NAT: 
config.NAT, - clock: config.clock, - connIdGen: libutp.NewConnIdGenerator(), - } - - for _, opt := range opts { - opt(protocol) - } - - if metrics.Enabled { - protocol.portalMetrics = newPortalMetrics(protocolId.Name()) - } - - return protocol, nil -} - -func (p *PortalProtocol) Start() error { - p.setupPortMapping() - - err := p.setupDiscV5AndTable() - if err != nil { - return err - } - - p.DiscV5.RegisterTalkHandler(p.protocolId, p.handleTalkRequest) - if p.Utp != nil { - err = p.Utp.Start() - } - if err != nil { - return err - } - - go p.table.Loop() - - for i := 0; i < concurrentOffers; i++ { - go p.offerWorker() - } - - // wait for both initialization processes to complete - <-p.DiscV5.tab.initDone - <-p.table.initDone - return nil -} - -func (p *PortalProtocol) Stop() { - p.cancelCloseCtx() - p.table.Close() - p.DiscV5.Close() - if p.Utp != nil { - p.Utp.Stop() - } -} -func (p *PortalProtocol) RoutingTableInfo() [][]string { - p.table.mutex.Lock() - defer p.table.mutex.Unlock() - nodes := make([][]string, 0) - for _, b := range &p.table.buckets { - bucketNodes := make([]string, 0) - for _, n := range b.entries { - bucketNodes = append(bucketNodes, "0x"+n.ID().String()) - } - nodes = append(nodes, bucketNodes) - } - p.Log.Trace("routingTableInfo resp:", "nodes", nodes) - return nodes -} - -func (p *PortalProtocol) AddEnr(n *enode.Node) { - // immediately add the node to the routing table - p.table.mutex.Lock() - defer p.table.mutex.Unlock() - added := p.table.handleAddNode(addNodeOp{node: n, isInbound: true, forceSetLive: true}) - if !added { - p.Log.Warn("add node failed", "id", n.ID(), "ip", n.IPAddr()) - return - } - id := n.ID().String() - p.radiusCache.Set([]byte(id), MaxDistance) -} - -func (p *PortalProtocol) Radius() *uint256.Int { - return p.storage.Radius() -} - -func (p *PortalProtocol) setupUDPListening() error { - laddr := p.conn.LocalAddr().(*net.UDPAddr) - p.localNode.SetFallbackUDP(laddr.Port) - p.Log.Debug("UDP listener up", "addr", laddr) - // TODO: NAT - if !laddr.IP.IsLoopback() && !laddr.IP.IsPrivate() { - p.portMappingRegister <- &portMapping{ - protocol: "UDP", - name: "ethereum portal peer discovery", - port: laddr.Port, - } - } - return nil -} - -func (p *PortalProtocol) setupDiscV5AndTable() error { - err := p.setupUDPListening() - if err != nil { - return err - } - - cfg := Config{ - PrivateKey: p.PrivateKey, - NetRestrict: p.NetRestrict, - Bootnodes: p.BootstrapNodes, - Log: p.Log, - } - - p.table, err = NewTable(p, p.localNode.Database(), cfg) - if err != nil { - return err - } - - return nil -} - -func (p *PortalProtocol) Ping(node *enode.Node) (uint64, error) { - pong, err := p.pingInner(node) - if err != nil { - return 0, err - } - - return pong.EnrSeq, nil -} - -func (p *PortalProtocol) pingInner(node *enode.Node) (*portalwire.Pong, error) { - enrSeq := p.Self().Seq() - radiusBytes, err := p.Radius().MarshalSSZ() - if err != nil { - return nil, err - } - customPayload := &portalwire.PingPongCustomData{ - Radius: radiusBytes, - } - - customPayloadBytes, err := customPayload.MarshalSSZ() - if err != nil { - return nil, err - } - - pingRequest := &portalwire.Ping{ - EnrSeq: enrSeq, - CustomPayload: customPayloadBytes, - } - - p.Log.Trace(">> PING/"+p.protocolName, "protocol", p.protocolName, "ip", p.Self().IP().String(), "source", p.Self().ID(), "target", node.ID(), "ping", pingRequest) - if metrics.Enabled { - p.portalMetrics.messagesSentPing.Mark(1) - } - pingRequestBytes, err := pingRequest.MarshalSSZ() - if err != nil { - return nil, err - } - - 
talkRequestBytes := make([]byte, 0, len(pingRequestBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.PING) - talkRequestBytes = append(talkRequestBytes, pingRequestBytes...) - - talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) - - if err != nil { - return nil, err - } - - p.Log.Trace("<< PONG/"+p.protocolName, "source", p.Self().ID(), "target", node.ID(), "res", talkResp) - if metrics.Enabled { - p.portalMetrics.messagesReceivedPong.Mark(1) - } - - return p.processPong(node, talkResp) -} - -func (p *PortalProtocol) findNodes(node *enode.Node, distances []uint) ([]*enode.Node, error) { - if p.localNode.ID().String() == node.ID().String() { - return make([]*enode.Node, 0), nil - } - - distancesBytes := make([][2]byte, len(distances)) - for i, distance := range distances { - copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), uint16(distance))) - } - - findNodes := &portalwire.FindNodes{ - Distances: distancesBytes, - } - - p.Log.Trace(">> FIND_NODES/"+p.protocolName, "id", node.ID(), "findNodes", findNodes) - if metrics.Enabled { - p.portalMetrics.messagesSentFindNodes.Mark(1) - } - findNodesBytes, err := findNodes.MarshalSSZ() - if err != nil { - p.Log.Error("failed to marshal find nodes request", "err", err) - return nil, err - } - - talkRequestBytes := make([]byte, 0, len(findNodesBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.FINDNODES) - talkRequestBytes = append(talkRequestBytes, findNodesBytes...) - - talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) - if err != nil { - p.Log.Error("failed to send find nodes request", "ip", node.IP().String(), "port", node.UDP(), "err", err) - return nil, err - } - - return p.processNodes(node, talkResp, distances) -} - -func (p *PortalProtocol) findContent(node *enode.Node, contentKey []byte) (byte, interface{}, error) { - findContent := &portalwire.FindContent{ - ContentKey: contentKey, - } - - p.Log.Trace(">> FIND_CONTENT/"+p.protocolName, "id", node.ID(), "findContent", findContent) - if metrics.Enabled { - p.portalMetrics.messagesSentFindContent.Mark(1) - } - findContentBytes, err := findContent.MarshalSSZ() - if err != nil { - p.Log.Error("failed to marshal find content request", "err", err) - return 0xff, nil, err - } - - talkRequestBytes := make([]byte, 0, len(findContentBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.FINDCONTENT) - talkRequestBytes = append(talkRequestBytes, findContentBytes...) - - talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) - if err != nil { - p.Log.Error("failed to send find content request", "ip", node.IP().String(), "port", node.UDP(), "err", err) - return 0xff, nil, err - } - - return p.processContent(node, talkResp) -} - -func (p *PortalProtocol) offer(node *enode.Node, offerRequest *OfferRequest) ([]byte, error) { - contentKeys := getContentKeys(offerRequest) - - offer := &portalwire.Offer{ - ContentKeys: contentKeys, - } - - p.Log.Trace(">> OFFER/"+p.protocolName, "offer", offer) - if metrics.Enabled { - p.portalMetrics.messagesSentOffer.Mark(1) - } - offerBytes, err := offer.MarshalSSZ() - if err != nil { - p.Log.Error("failed to marshal offer request", "err", err) - return nil, err - } - - talkRequestBytes := make([]byte, 0, len(offerBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.OFFER) - talkRequestBytes = append(talkRequestBytes, offerBytes...) 
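Every request built in this file is framed the same way before going into a discv5 TALKREQ: a single message-id byte followed by the SSZ-encoded body, as in the four make/append sequences above. A hedged restatement of that framing; frame and peel are names invented here and the id value is illustrative:

// sketch.go: the one-byte-selector framing used for all portal wire messages.
package main

import (
	"errors"
	"fmt"
)

// frame prepends the message id to an already SSZ-encoded body.
func frame(msgID byte, sszBody []byte) []byte {
	out := make([]byte, 0, len(sszBody)+1)
	out = append(out, msgID)
	return append(out, sszBody...)
}

// peel splits a received talk request back into id and body.
func peel(packet []byte) (byte, []byte, error) {
	if len(packet) == 0 {
		return 0, nil, errors.New("empty talk request")
	}
	return packet[0], packet[1:], nil
}

func main() {
	pkt := frame(0x01, []byte{0xde, 0xad}) // 0x01 standing in for a message id
	id, body, _ := peel(pkt)
	fmt.Printf("id=%#x body=%x\n", id, body)
}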
- - talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) - if err != nil { - p.Log.Error("failed to send offer request", "err", err) - return nil, err - } - - return p.processOffer(node, talkResp, offerRequest) -} - -func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request *OfferRequest) ([]byte, error) { - var err error - if len(resp) == 0 { - return nil, ErrEmptyResp - } - if resp[0] != portalwire.ACCEPT { - return nil, fmt.Errorf("invalid accept response") - } - - p.Log.Info("will process Offer", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) - - accept := &portalwire.Accept{} - err = accept.UnmarshalSSZ(resp[1:]) - if err != nil { - return nil, err - } - - p.Log.Trace("<< ACCEPT/"+p.protocolName, "id", target.ID(), "accept", accept) - if metrics.Enabled { - p.portalMetrics.messagesReceivedAccept.Mark(1) - } - isAdded := p.table.AddFoundNode(target, true) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } - var contentKeyLen int - if request.Kind == TransientOfferRequestKind { - contentKeyLen = len(request.Request.(*TransientOfferRequest).Contents) - } else { - contentKeyLen = len(request.Request.(*PersistOfferRequest).ContentKeys) - } - - contentKeyBitlist := bitfield.Bitlist(accept.ContentKeys) - if contentKeyBitlist.Len() != uint64(contentKeyLen) { - return nil, fmt.Errorf("accepted content key bitlist has invalid size, expected %d, got %d", contentKeyLen, contentKeyBitlist.Len()) - } - - if contentKeyBitlist.Count() == 0 { - return nil, nil - } - - connId := binary.BigEndian.Uint16(accept.ConnectionId[:]) - go func(ctx context.Context) { - var conn net.Conn - defer func() { - if conn == nil { - return - } - err := conn.Close() - if err != nil { - p.Log.Error("failed to close connection", "err", err) - } - }() - for { - select { - case <-ctx.Done(): - return - default: - contents := make([][]byte, 0, contentKeyBitlist.Count()) - var content []byte - if request.Kind == TransientOfferRequestKind { - for _, index := range contentKeyBitlist.BitIndices() { - content = request.Request.(*TransientOfferRequest).Contents[index].Content - contents = append(contents, content) - } - } else { - for _, index := range contentKeyBitlist.BitIndices() { - contentKey := request.Request.(*PersistOfferRequest).ContentKeys[index] - contentId := p.toContentId(contentKey) - if contentId != nil { - content, err = p.storage.Get(contentKey, contentId) - if err != nil { - p.Log.Error("failed to get content from storage", "err", err) - contents = append(contents, []byte{}) - } else { - contents = append(contents, content) - } - } else { - contents = append(contents, []byte{}) - } - } - } - - var contentsPayload []byte - contentsPayload, err = encodeContents(contents) - if err != nil { - p.Log.Error("failed to encode contents", "err", err) - return - } - - connctx, conncancel := context.WithTimeout(ctx, defaultUTPConnectTimeout) - conn, err = p.Utp.DialWithCid(connctx, target, libutp.ReceConnId(connId).SendId()) - conncancel() - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpOutFailConn.Inc(1) - } - p.Log.Error("failed to dial utp connection", "err", err) - return - } - - err = conn.SetWriteDeadline(time.Now().Add(defaultUTPWriteTimeout)) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpOutFailDeadline.Inc(1) - } - 
p.Log.Error("failed to set write deadline", "err", err) - return - } - - var written int - written, err = conn.Write(contentsPayload) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpOutFailWrite.Inc(1) - } - p.Log.Error("failed to write to utp connection", "err", err) - return - } - p.Log.Trace(">> CONTENT/"+p.protocolName, "id", target.ID(), "contents", contents, "size", written) - if metrics.Enabled { - p.portalMetrics.messagesSentContent.Mark(1) - p.portalMetrics.utpOutSuccess.Inc(1) - } - return - } - } - }(p.closeCtx) - - return accept.ContentKeys, nil -} - -func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, interface{}, error) { - if len(resp) == 0 { - return 0x00, nil, ErrEmptyResp - } - - if resp[0] != portalwire.CONTENT { - return 0xff, nil, fmt.Errorf("invalid content response") - } - - p.Log.Info("will process content", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) - - switch resp[1] { - case portalwire.ContentRawSelector: - content := &portalwire.Content{} - err := content.UnmarshalSSZ(resp[2:]) - if err != nil { - return 0xff, nil, err - } - - p.Log.Trace("<< CONTENT/"+p.protocolName, "id", target.ID(), "content", content) - if metrics.Enabled { - p.portalMetrics.messagesReceivedContent.Mark(1) - } - isAdded := p.table.AddFoundNode(target, true) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } - return resp[1], content.Content, nil - case portalwire.ContentConnIdSelector: - connIdMsg := &portalwire.ConnectionId{} - err := connIdMsg.UnmarshalSSZ(resp[2:]) - if err != nil { - return 0xff, nil, err - } - - p.Log.Trace("<< CONTENT_CONNECTION_ID/"+p.protocolName, "id", target.ID(), "resp", common.Bytes2Hex(resp), "connIdMsg", connIdMsg) - if metrics.Enabled { - p.portalMetrics.messagesReceivedContent.Mark(1) - } - isAdded := p.table.AddFoundNode(target, true) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } - connctx, conncancel := context.WithTimeout(p.closeCtx, defaultUTPConnectTimeout) - connId := binary.BigEndian.Uint16(connIdMsg.Id[:]) - conn, err := p.Utp.DialWithCid(connctx, target, libutp.ReceConnId(connId).SendId()) - defer func() { - if conn == nil { - if metrics.Enabled { - p.portalMetrics.utpInFailConn.Inc(1) - } - return - } - err := conn.Close() - if err != nil { - p.Log.Error("failed to close connection", "err", err) - } - }() - conncancel() - if err != nil { - return 0xff, nil, err - } - - err = conn.SetReadDeadline(time.Now().Add(defaultUTPReadTimeout)) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpInFailDeadline.Inc(1) - } - return 0xff, nil, err - } - // Read ALL the data from the connection until EOF and return it - data, err := io.ReadAll(conn) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpInFailRead.Inc(1) - } - p.Log.Error("failed to read from utp connection", "err", err) - return 0xff, nil, err - } - p.Log.Trace("<< CONTENT/"+p.protocolName, "id", target.ID(), "size", len(data), "data", data) - if metrics.Enabled { - p.portalMetrics.messagesReceivedContent.Mark(1) - p.portalMetrics.utpInSuccess.Inc(1) - } - return resp[1], data, nil - case 
portalwire.ContentEnrsSelector: - enrs := &portalwire.Enrs{} - err := enrs.UnmarshalSSZ(resp[2:]) - - if err != nil { - return 0xff, nil, err - } - - p.Log.Trace("<< CONTENT_ENRS/"+p.protocolName, "id", target.ID(), "enrs", enrs) - if metrics.Enabled { - p.portalMetrics.messagesReceivedContent.Mark(1) - } - isAdded := p.table.AddFoundNode(target, true) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } - nodes := p.filterNodes(target, enrs.Enrs, nil) - return resp[1], nodes, nil - default: - return 0xff, nil, fmt.Errorf("invalid content response") - } -} - -func (p *PortalProtocol) processNodes(target *enode.Node, resp []byte, distances []uint) ([]*enode.Node, error) { - if len(resp) == 0 { - return nil, ErrEmptyResp - } - - if resp[0] != portalwire.NODES { - return nil, fmt.Errorf("invalid nodes response") - } - - nodesResp := &portalwire.Nodes{} - err := nodesResp.UnmarshalSSZ(resp[1:]) - if err != nil { - return nil, err - } - - isAdded := p.table.AddFoundNode(target, true) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } - nodes := p.filterNodes(target, nodesResp.Enrs, distances) - - return nodes, nil -} - -func (p *PortalProtocol) filterNodes(target *enode.Node, enrs [][]byte, distances []uint) []*enode.Node { - var ( - nodes []*enode.Node - seen = make(map[enode.ID]struct{}) - err error - verified = 0 - n *enode.Node - ) - - for _, b := range enrs { - record := &enr.Record{} - err = rlp.DecodeBytes(b, record) - if err != nil { - p.Log.Error("Invalid record in nodes response", "id", target.ID(), "err", err) - continue - } - n, err = p.verifyResponseNode(target, record, distances, seen) - if err != nil { - p.Log.Error("Invalid record in nodes response", "id", target.ID(), "err", err) - continue - } - verified++ - nodes = append(nodes, n) - } - - p.Log.Trace("<< NODES/"+p.protocolName, "id", target.ID(), "total", len(enrs), "verified", verified, "nodes", nodes) - if metrics.Enabled { - p.portalMetrics.messagesReceivedNodes.Mark(1) - } - return nodes -} - -func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwire.Pong, error) { - if len(resp) == 0 { - return nil, ErrEmptyResp - } - if resp[0] != portalwire.PONG { - return nil, fmt.Errorf("invalid pong response") - } - pong := &portalwire.Pong{} - err := pong.UnmarshalSSZ(resp[1:]) - if err != nil { - return nil, err - } - - p.Log.Trace("<< PONG_RESPONSE/"+p.protocolName, "id", target.ID(), "pong", pong) - if metrics.Enabled { - p.portalMetrics.messagesReceivedPong.Mark(1) - } - - customPayload := &portalwire.PingPongCustomData{} - err = customPayload.UnmarshalSSZ(pong.CustomPayload) - if err != nil { - return nil, err - } - - p.Log.Trace("<< PONG_RESPONSE/"+p.protocolName, "id", target.ID(), "pong", pong, "customPayload", customPayload) - if metrics.Enabled { - p.portalMetrics.messagesReceivedPong.Mark(1) - } - isAdded := p.table.AddFoundNode(target, true) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) - } - - 
p.radiusCache.Set([]byte(target.ID().String()), customPayload.Radius) - return pong, nil -} - -func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte { - if n := p.DiscV5.GetNode(id); n != nil { - p.table.AddInboundNode(n) - } - - msgCode := msg[0] - - switch msgCode { - case portalwire.PING: - pingRequest := &portalwire.Ping{} - err := pingRequest.UnmarshalSSZ(msg[1:]) - if err != nil { - p.Log.Error("failed to unmarshal ping request", "err", err) - return nil - } - - p.Log.Trace("<< PING/"+p.protocolName, "protocol", p.protocolName, "source", id, "pingRequest", pingRequest) - if metrics.Enabled { - p.portalMetrics.messagesReceivedPing.Mark(1) - } - resp, err := p.handlePing(id, pingRequest) - if err != nil { - p.Log.Error("failed to handle ping request", "err", err) - return nil - } - - return resp - case portalwire.FINDNODES: - findNodesRequest := &portalwire.FindNodes{} - err := findNodesRequest.UnmarshalSSZ(msg[1:]) - if err != nil { - p.Log.Error("failed to unmarshal find nodes request", "err", err) - return nil - } - - p.Log.Trace("<< FIND_NODES/"+p.protocolName, "protocol", p.protocolName, "source", id, "findNodesRequest", findNodesRequest) - if metrics.Enabled { - p.portalMetrics.messagesReceivedFindNodes.Mark(1) - } - resp, err := p.handleFindNodes(addr, findNodesRequest) - if err != nil { - p.Log.Error("failed to handle find nodes request", "err", err) - return nil - } - - return resp - case portalwire.FINDCONTENT: - findContentRequest := &portalwire.FindContent{} - err := findContentRequest.UnmarshalSSZ(msg[1:]) - if err != nil { - p.Log.Error("failed to unmarshal find content request", "err", err) - return nil - } - - p.Log.Trace("<< FIND_CONTENT/"+p.protocolName, "protocol", p.protocolName, "source", id, "findContentRequest", findContentRequest) - if metrics.Enabled { - p.portalMetrics.messagesReceivedFindContent.Mark(1) - } - resp, err := p.handleFindContent(id, addr, findContentRequest) - if err != nil { - p.Log.Error("failed to handle find content request", "err", err) - return nil - } - - return resp - case portalwire.OFFER: - offerRequest := &portalwire.Offer{} - err := offerRequest.UnmarshalSSZ(msg[1:]) - if err != nil { - p.Log.Error("failed to unmarshal offer request", "err", err) - return nil - } - - p.Log.Trace("<< OFFER/"+p.protocolName, "protocol", p.protocolName, "source", id, "offerRequest", offerRequest) - if metrics.Enabled { - p.portalMetrics.messagesReceivedOffer.Mark(1) - } - resp, err := p.handleOffer(id, addr, offerRequest) - if err != nil { - p.Log.Error("failed to handle offer request", "err", err) - return nil - } - - return resp - } - - return nil -} - -func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, error) { - pingCustomPayload := &portalwire.PingPongCustomData{} - err := pingCustomPayload.UnmarshalSSZ(ping.CustomPayload) - if err != nil { - return nil, err - } - - p.radiusCache.Set([]byte(id.String()), pingCustomPayload.Radius) - - enrSeq := p.Self().Seq() - radiusBytes, err := p.Radius().MarshalSSZ() - if err != nil { - return nil, err - } - pongCustomPayload := &portalwire.PingPongCustomData{ - Radius: radiusBytes, - } - - pongCustomPayloadBytes, err := pongCustomPayload.MarshalSSZ() - if err != nil { - return nil, err - } - - pong := &portalwire.Pong{ - EnrSeq: enrSeq, - CustomPayload: pongCustomPayloadBytes, - } - - p.Log.Trace(">> PONG/"+p.protocolName, "protocol", p.protocolName, "source", id, "pong", pong) - if metrics.Enabled { - p.portalMetrics.messagesSentPong.Mark(1) 
- } - pongBytes, err := pong.MarshalSSZ() - - if err != nil { - return nil, err - } - - talkRespBytes := make([]byte, 0, len(pongBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.PONG) - talkRespBytes = append(talkRespBytes, pongBytes...) - - return talkRespBytes, nil -} - -func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *portalwire.FindNodes) ([]byte, error) { - distances := make([]uint, len(request.Distances)) - for i, distance := range request.Distances { - distances[i] = uint(ssz.UnmarshallUint16(distance[:])) - } - - nodes := p.collectTableNodes(fromAddr.IP, distances, portalFindnodesResultLimit) - - nodesOverhead := 1 + 1 + 4 // msg id + total + container offset - maxPayloadSize := maxPacketSize - talkRespOverhead - nodesOverhead - enrOverhead := 4 //per added ENR, 4 bytes offset overhead - - enrs := p.truncateNodes(nodes, maxPayloadSize, enrOverhead) - - nodesMsg := &portalwire.Nodes{ - Total: 1, - Enrs: enrs, - } - - p.Log.Trace(">> NODES/"+p.protocolName, "protocol", p.protocolName, "source", fromAddr, "nodes", nodesMsg) - if metrics.Enabled { - p.portalMetrics.messagesSentNodes.Mark(1) - } - nodesMsgBytes, err := nodesMsg.MarshalSSZ() - if err != nil { - return nil, err - } - - talkRespBytes := make([]byte, 0, len(nodesMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.NODES) - talkRespBytes = append(talkRespBytes, nodesMsgBytes...) - - return talkRespBytes, nil -} - -func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, request *portalwire.FindContent) ([]byte, error) { - contentOverhead := 1 + 1 // msg id + SSZ Union selector - maxPayloadSize := maxPacketSize - talkRespOverhead - contentOverhead - enrOverhead := 4 //per added ENR, 4 bytes offset overhead - var err error - contentKey := request.ContentKey - contentId := p.toContentId(contentKey) - if contentId == nil { - return nil, ErrNilContentKey - } - - var content []byte - content, err = p.storage.Get(contentKey, contentId) - if err != nil && !errors.Is(err, ContentNotFound) { - return nil, err - } - - if errors.Is(err, ContentNotFound) { - closestNodes := p.findNodesCloseToContent(contentId, portalFindnodesResultLimit) - for i, n := range closestNodes { - if n.ID() == id { - closestNodes = append(closestNodes[:i], closestNodes[i+1:]...) - break - } - } - - enrs := p.truncateNodes(closestNodes, maxPayloadSize, enrOverhead) - // TODO fix when no content and no enrs found - if len(enrs) == 0 { - enrs = nil - } - - enrsMsg := &portalwire.Enrs{ - Enrs: enrs, - } - - p.Log.Trace(">> CONTENT_ENRS/"+p.protocolName, "protocol", p.protocolName, "source", addr, "enrs", enrsMsg) - if metrics.Enabled { - p.portalMetrics.messagesSentContent.Mark(1) - } - var enrsMsgBytes []byte - enrsMsgBytes, err = enrsMsg.MarshalSSZ() - if err != nil { - return nil, err - } - - contentMsgBytes := make([]byte, 0, len(enrsMsgBytes)+1) - contentMsgBytes = append(contentMsgBytes, portalwire.ContentEnrsSelector) - contentMsgBytes = append(contentMsgBytes, enrsMsgBytes...) - - talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.CONTENT) - talkRespBytes = append(talkRespBytes, contentMsgBytes...) 
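truncateNodes, used in this branch, has to fit the serialized ENRs into whatever is left of a single talk response after the fixed overheads are subtracted, with 4 extra bytes per record for its SSZ offset. Its body is not part of this hunk, so the sketch below is one plausible reading: a greedy pack that stops at the first record that would overflow the budget:

// sketch.go: fitting variable-size records under a byte budget.
package main

import "fmt"

const enrOverhead = 4 // per-ENR offset cost, as in handleFindNodes/handleFindContent

func truncate(sizes []int, maxPayload int) []int {
	kept := make([]int, 0, len(sizes))
	total := 0
	for _, s := range sizes {
		if total+s+enrOverhead > maxPayload {
			break // this record would push the response past the budget
		}
		total += s + enrOverhead
		kept = append(kept, s)
	}
	return kept
}

func main() {
	fmt.Println(truncate([]int{300, 280, 310, 295}, 900)) // [300 280]: third record overflows
}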
- - return talkRespBytes, nil - } else if len(content) <= maxPayloadSize { - rawContentMsg := &portalwire.Content{ - Content: content, - } - - p.Log.Trace(">> CONTENT_RAW/"+p.protocolName, "protocol", p.protocolName, "source", addr, "content", rawContentMsg) - if metrics.Enabled { - p.portalMetrics.messagesSentContent.Mark(1) - } - - var rawContentMsgBytes []byte - rawContentMsgBytes, err = rawContentMsg.MarshalSSZ() - if err != nil { - return nil, err - } - - contentMsgBytes := make([]byte, 0, len(rawContentMsgBytes)+1) - contentMsgBytes = append(contentMsgBytes, portalwire.ContentRawSelector) - contentMsgBytes = append(contentMsgBytes, rawContentMsgBytes...) - - talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.CONTENT) - talkRespBytes = append(talkRespBytes, contentMsgBytes...) - - return talkRespBytes, nil - } else { - connectionId := p.connIdGen.GenCid(id, false) - - go func(bctx context.Context, connId *libutp.ConnId) { - var conn *utp.Conn - var connectCtx context.Context - var cancel context.CancelFunc - defer func() { - p.connIdGen.Remove(connectionId) - if conn == nil { - return - } - err := conn.Close() - if err != nil { - p.Log.Error("failed to close connection", "err", err) - } - }() - for { - select { - case <-bctx.Done(): - return - default: - p.Log.Debug("will accept find content conn from: ", "nodeId", id.String(), "source", addr, "connId", connId) - connectCtx, cancel = context.WithTimeout(bctx, defaultUTPConnectTimeout) - conn, err = p.Utp.AcceptWithCid(connectCtx, id, connectionId) - cancel() - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpOutFailConn.Inc(1) - } - p.Log.Error("failed to accept utp connection for handle find content", "connId", connectionId.SendId(), "err", err) - return - } - - err = conn.SetWriteDeadline(time.Now().Add(defaultUTPWriteTimeout)) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpOutFailDeadline.Inc(1) - } - p.Log.Error("failed to set write deadline", "err", err) - return - } - - var n int - n, err = conn.Write(content) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpOutFailWrite.Inc(1) - } - p.Log.Error("failed to write content to utp connection", "err", err) - return - } - - if metrics.Enabled { - p.portalMetrics.utpOutSuccess.Inc(1) - } - p.Log.Trace("wrote content size to utp connection", "n", n) - return - } - } - }(p.closeCtx, connectionId) - - idBuffer := make([]byte, 2) - binary.BigEndian.PutUint16(idBuffer, connectionId.SendId()) - connIdMsg := &portalwire.ConnectionId{ - Id: idBuffer, - } - - p.Log.Trace(">> CONTENT_CONNECTION_ID/"+p.protocolName, "protocol", p.protocolName, "source", addr, "connId", connIdMsg) - if metrics.Enabled { - p.portalMetrics.messagesSentContent.Mark(1) - } - var connIdMsgBytes []byte - connIdMsgBytes, err = connIdMsg.MarshalSSZ() - if err != nil { - return nil, err - } - - contentMsgBytes := make([]byte, 0, len(connIdMsgBytes)+1) - contentMsgBytes = append(contentMsgBytes, portalwire.ContentConnIdSelector) - contentMsgBytes = append(contentMsgBytes, connIdMsgBytes...) - - talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.CONTENT) - talkRespBytes = append(talkRespBytes, contentMsgBytes...) 
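When the content is too large for a single talk response, the handler instead returns the two-byte uTP connection id prepared above, and the requester dials back with it to stream the payload. On the wire the id is just a big-endian uint16; a stdlib-only sketch of both ends:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	// Serving side: advertise the send id of the freshly generated
	// connection id pair (the value here is arbitrary for illustration).
	var sendId uint16 = 12
	idBuf := make([]byte, 2)
	binary.BigEndian.PutUint16(idBuf, sendId)

	// Requesting side: decode the id and dial back over uTP with it.
	connId := binary.BigEndian.Uint16(idBuf)
	fmt.Println("dial back with uTP connection id", connId)
}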
- - return talkRespBytes, nil - } -} - -func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *portalwire.Offer) ([]byte, error) { - var err error - contentKeyBitlist := bitfield.NewBitlist(uint64(len(request.ContentKeys))) - if len(p.contentQueue) >= cap(p.contentQueue) { - acceptMsg := &portalwire.Accept{ - ConnectionId: []byte{0, 0}, - ContentKeys: contentKeyBitlist, - } - - p.Log.Trace(">> ACCEPT/"+p.protocolName, "protocol", p.protocolName, "source", addr, "accept", acceptMsg) - if metrics.Enabled { - p.portalMetrics.messagesSentAccept.Mark(1) - } - var acceptMsgBytes []byte - acceptMsgBytes, err = acceptMsg.MarshalSSZ() - if err != nil { - return nil, err - } - - talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) - talkRespBytes = append(talkRespBytes, acceptMsgBytes...) - - return talkRespBytes, nil - } - - contentKeys := make([][]byte, 0) - for i, contentKey := range request.ContentKeys { - contentId := p.toContentId(contentKey) - if contentId != nil { - if inRange(p.Self().ID(), p.Radius(), contentId) { - if _, err = p.storage.Get(contentKey, contentId); err != nil { - contentKeyBitlist.SetBitAt(uint64(i), true) - contentKeys = append(contentKeys, contentKey) - } - } - } else { - return nil, ErrNilContentKey - } - } - - idBuffer := make([]byte, 2) - if contentKeyBitlist.Count() != 0 { - connectionId := p.connIdGen.GenCid(id, false) - - go func(bctx context.Context, connId *libutp.ConnId) { - var conn *utp.Conn - var connectCtx context.Context - var cancel context.CancelFunc - defer func() { - p.connIdGen.Remove(connectionId) - if conn == nil { - return - } - err := conn.Close() - if err != nil { - p.Log.Error("failed to close connection", "err", err) - } - }() - for { - select { - case <-bctx.Done(): - return - default: - p.Log.Debug("will accept offer conn from: ", "source", addr, "connId", connId) - connectCtx, cancel = context.WithTimeout(bctx, defaultUTPConnectTimeout) - conn, err = p.Utp.AcceptWithCid(connectCtx, id, connectionId) - cancel() - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpInFailConn.Inc(1) - } - p.Log.Error("failed to accept utp connection for handle offer", "connId", connectionId.SendId(), "err", err) - return - } - - err = conn.SetReadDeadline(time.Now().Add(defaultUTPReadTimeout)) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpInFailDeadline.Inc(1) - } - p.Log.Error("failed to set read deadline", "err", err) - return - } - // Read ALL the data from the connection until EOF and return it - var data []byte - data, err = io.ReadAll(conn) - if err != nil { - if metrics.Enabled { - p.portalMetrics.utpInFailRead.Inc(1) - } - p.Log.Error("failed to read from utp connection", "err", err) - return - } - p.Log.Trace("<< OFFER_CONTENT/"+p.protocolName, "id", id, "size", len(data), "data", data) - if metrics.Enabled { - p.portalMetrics.messagesReceivedContent.Mark(1) - } - - err = p.handleOfferedContents(id, contentKeys, data) - if err != nil { - p.Log.Error("failed to handle offered Contents", "err", err) - return - } - - if metrics.Enabled { - p.portalMetrics.utpInSuccess.Inc(1) - } - return - } - } - }(p.closeCtx, connectionId) - - binary.BigEndian.PutUint16(idBuffer, connectionId.SendId()) - } else { - binary.BigEndian.PutUint16(idBuffer, uint16(0)) - } - - acceptMsg := &portalwire.Accept{ - ConnectionId: idBuffer, - ContentKeys: []byte(contentKeyBitlist), - } - - p.Log.Trace(">> ACCEPT/"+p.protocolName, "protocol", p.protocolName, "source", addr, 
"accept", acceptMsg) - if metrics.Enabled { - p.portalMetrics.messagesSentAccept.Mark(1) - } - var acceptMsgBytes []byte - acceptMsgBytes, err = acceptMsg.MarshalSSZ() - if err != nil { - return nil, err - } - - talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) - talkRespBytes = append(talkRespBytes, acceptMsgBytes...) - - return talkRespBytes, nil -} - -func (p *PortalProtocol) handleOfferedContents(id enode.ID, keys [][]byte, payload []byte) error { - contents, err := decodeContents(payload) - if err != nil { - if metrics.Enabled { - p.portalMetrics.contentDecodedFalse.Inc(1) - } - return err - } - - keyLen := len(keys) - contentLen := len(contents) - if keyLen != contentLen { - if metrics.Enabled { - p.portalMetrics.contentDecodedFalse.Inc(1) - } - return fmt.Errorf("content keys len %d doesn't match content values len %d", keyLen, contentLen) - } - - contentElement := &ContentElement{ - Node: id, - ContentKeys: keys, - Contents: contents, - } - - p.contentQueue <- contentElement - - if metrics.Enabled { - p.portalMetrics.contentDecodedTrue.Inc(1) - } - return nil -} - -func (p *PortalProtocol) Self() *enode.Node { - return p.localNode.Node() -} - -func (p *PortalProtocol) RequestENR(n *enode.Node) (*enode.Node, error) { - nodes, err := p.findNodes(n, []uint{0}) - if err != nil { - return nil, err - } - if len(nodes) != 1 { - return nil, fmt.Errorf("%d nodes in response for distance zero", len(nodes)) - } - return nodes[0], nil -} - -func (p *PortalProtocol) verifyResponseNode(sender *enode.Node, r *enr.Record, distances []uint, seen map[enode.ID]struct{}) (*enode.Node, error) { - n, err := enode.New(p.validSchemes, r) - if err != nil { - return nil, err - } - if err = netutil.CheckRelayIP(sender.IP(), n.IP()); err != nil { - return nil, err - } - if p.NetRestrict != nil && !p.NetRestrict.Contains(n.IP()) { - return nil, errors.New("not contained in netrestrict list") - } - if n.UDP() <= 1024 { - return nil, ErrLowPort - } - if distances != nil { - nd := enode.LogDist(sender.ID(), n.ID()) - if !slices.Contains(distances, uint(nd)) { - return nil, errors.New("does not match any requested distance") - } - } - if _, ok := seen[n.ID()]; ok { - return nil, fmt.Errorf("duplicate record") - } - seen[n.ID()] = struct{}{} - return n, nil -} - -// lookupRandom looks up a random target. -// This is needed to satisfy the transport interface. -func (p *PortalProtocol) LookupRandom() []*enode.Node { - return p.newRandomLookup(p.closeCtx).Run() -} - -// lookupSelf looks up our own node ID. -// This is needed to satisfy the transport interface. -func (p *PortalProtocol) LookupSelf() []*enode.Node { - return p.newLookup(p.closeCtx, p.Self().ID()).Run() -} - -func (p *PortalProtocol) newRandomLookup(ctx context.Context) *Lookup { - var target enode.ID - _, _ = crand.Read(target[:]) - return p.newLookup(ctx, target) -} - -func (p *PortalProtocol) newLookup(ctx context.Context, target enode.ID) *Lookup { - return NewLookup(ctx, p.table, target, func(n *enode.Node) ([]*enode.Node, error) { - return p.lookupWorker(n, target) - }) -} - -// lookupWorker performs FINDNODE calls against a single node during lookup. 
-func (p *PortalProtocol) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { - var ( - dists = LookupDistances(target, destNode.ID()) - nodes = NodesByDistance{Target: target} - err error - ) - var r []*enode.Node - - r, err = p.findNodes(destNode, dists) - if errors.Is(err, ErrClosed) { - return nil, err - } - for _, n := range r { - if n.ID() != p.Self().ID() { - isAdded := p.table.AddFoundNode(n, false) - if isAdded { - log.Debug("Node added to bucket", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) - } else { - log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", n.IP(), "port", n.UDP()) - } - nodes.Push(n, portalFindnodesResultLimit) - } - } - return nodes.Entries, err -} - -func (p *PortalProtocol) offerWorker() { - for { - select { - case <-p.closeCtx.Done(): - return - case offerRequestWithNode := <-p.offerQueue: - p.Log.Trace("offerWorker", "offerRequestWithNode", offerRequestWithNode) - _, err := p.offer(offerRequestWithNode.Node, offerRequestWithNode.Request) - if err != nil { - p.Log.Error("failed to offer", "err", err) - } - } - } -} - -func (p *PortalProtocol) truncateNodes(nodes []*enode.Node, maxSize int, enrOverhead int) [][]byte { - res := make([][]byte, 0) - totalSize := 0 - for _, n := range nodes { - enrBytes, err := rlp.EncodeToBytes(n.Record()) - if err != nil { - p.Log.Error("failed to encode n", "err", err) - continue - } - - if totalSize+len(enrBytes)+enrOverhead > maxSize { - break - } else { - res = append(res, enrBytes) - totalSize += len(enrBytes) - } - } - return res -} - -func (p *PortalProtocol) findNodesCloseToContent(contentId []byte, limit int) []*enode.Node { - allNodes := p.table.NodeList() - sort.Slice(allNodes, func(i, j int) bool { - return enode.LogDist(allNodes[i].ID(), enode.ID(contentId)) < enode.LogDist(allNodes[j].ID(), enode.ID(contentId)) - }) - - if len(allNodes) > limit { - allNodes = allNodes[:limit] - } else { - allNodes = allNodes[:] - } - - return allNodes -} - -// Lookup performs a recursive lookup for the given target. -// It returns the closest nodes to target. -func (p *PortalProtocol) Lookup(target enode.ID) []*enode.Node { - return p.newLookup(p.closeCtx, target).Run() -} - -// Resolve searches for a specific Node with the given ID and tries to get the most recent -// version of the Node record for it. It returns n if the Node could not be resolved. -func (p *PortalProtocol) Resolve(n *enode.Node) *enode.Node { - if intable := p.table.GetNode(n.ID()); intable != nil && intable.Seq() > n.Seq() { - n = intable - } - // Try asking directly. This works if the Node is still responding on the endpoint we have. - if resp, err := p.RequestENR(n); err == nil { - return resp - } - // Otherwise do a network lookup. - result := p.Lookup(n.ID()) - for _, rn := range result { - if rn.ID() == n.ID() && rn.Seq() > n.Seq() { - return rn - } - } - return n -} - -// ResolveNodeId searches for a specific Node with the given ID. -// It returns nil if the nodeId could not be resolved. -func (p *PortalProtocol) ResolveNodeId(id enode.ID) *enode.Node { - if id == p.Self().ID() { - p.Log.Debug("Resolve Self Id", "id", id.String()) - return p.Self() - } - - n := p.table.GetNode(id) - if n != nil { - p.Log.Debug("found Id in table and will request enr from the node", "id", id.String()) - // Try asking directly. This works if the Node is still responding on the endpoint we have. - if resp, err := p.RequestENR(n); err == nil { - return resp - } - } - - // Otherwise do a network lookup. 
-	result := p.Lookup(id)
-	for _, rn := range result {
-		if rn.ID() == id {
-			if n != nil && rn.Seq() <= n.Seq() {
-				return n
-			} else {
-				return rn
-			}
-		}
-	}
-
-	return n
-}
-
-func (p *PortalProtocol) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node {
-	var bn []*enode.Node
-	var nodes []*enode.Node
-	var processed = make(map[uint]struct{})
-	for _, dist := range distances {
-		// Reject duplicate / invalid distances.
-		_, seen := processed[dist]
-		if seen || dist > 256 {
-			continue
-		}
-		processed[dist] = struct{}{}
-
-		checkLive := !p.table.cfg.NoFindnodeLivenessCheck
-		for _, n := range p.table.AppendBucketNodes(dist, bn[:0], checkLive) {
-			// Apply some pre-checks to avoid sending invalid nodes.
-			// Note liveness is checked by appendLiveNodes.
-			if netutil.CheckRelayIP(rip, n.IP()) != nil {
-				continue
-			}
-			nodes = append(nodes, n)
-			if len(nodes) >= limit {
-				return nodes
-			}
-		}
-	}
-	return nodes
-}
-
-func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bool, error) {
-	lookupContext, cancel := context.WithCancel(context.Background())
-
-	resChan := make(chan *traceContentInfoResp, Alpha)
-	hasResult := int32(0)
-
-	result := ContentInfoResp{}
-
-	var wg sync.WaitGroup
-	wg.Add(1)
-
-	go func() {
-		defer wg.Done()
-		for res := range resChan {
-			if res.Flag != portalwire.ContentEnrsSelector {
-				result.Content = res.Content.([]byte)
-				result.UtpTransfer = res.UtpTransfer
-			}
-		}
-	}()
-
-	NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) {
-		return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult)
-	}).Run()
-	close(resChan)
-
-	wg.Wait()
-	if hasResult == 1 {
-		return result.Content, result.UtpTransfer, nil
-	}
-	defer cancel()
-	return nil, false, ContentNotFound
-}
-
-func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*TraceContentResult, error) {
-	lookupContext, cancel := context.WithCancel(context.Background())
-	// resp channel
-	resChan := make(chan *traceContentInfoResp, Alpha)
-
-	hasResult := int32(0)
-
-	traceContentRes := &TraceContentResult{}
-
-	selfHexId := "0x" + p.Self().ID().String()
-
-	trace := &Trace{
-		Origin:      selfHexId,
-		TargetId:    hexutil.Encode(contentId),
-		StartedAtMs: int(time.Now().UnixMilli()),
-		Responses:   make(map[string]RespByNode),
-		Metadata:    make(map[string]*NodeMetadata),
-		Cancelled:   make([]string, 0),
-	}
-
-	nodes := p.table.FindnodeByID(enode.ID(contentId), BucketSize, false)
-
-	localResponse := make([]string, 0, len(nodes.Entries))
-	for _, node := range nodes.Entries {
-		id := "0x" + node.ID().String()
-		localResponse = append(localResponse, id)
-	}
-	trace.Responses[selfHexId] = RespByNode{
-		DurationMs:    0,
-		RespondedWith: localResponse,
-	}
-
-	dis := p.Distance(p.Self().ID(), enode.ID(contentId))
-
-	trace.Metadata[selfHexId] = &NodeMetadata{
-		Enr:      p.Self().String(),
-		Distance: hexutil.Encode(dis[:]),
-	}
-
-	var wg sync.WaitGroup
-	wg.Add(1)
-
-	go func() {
-		defer wg.Done()
-		for res := range resChan {
-			node := res.Node
-			hexId := "0x" + node.ID().String()
-			dis := p.Distance(node.ID(), enode.ID(contentId))
-			p.Log.Debug("receive res", "id", hexId, "flag", res.Flag)
-			trace.Metadata[hexId] = &NodeMetadata{
-				Enr:      node.String(),
-				Distance: hexutil.Encode(dis[:]),
-			}
-			// no content received yet
-			if traceContentRes.Content == "" {
-				if res.Flag == portalwire.ContentRawSelector || res.Flag == portalwire.ContentConnIdSelector {
-					trace.ReceivedFrom = hexId
-					content := res.Content.([]byte)
-					traceContentRes.Content = hexutil.Encode(content)
-					traceContentRes.UtpTransfer = res.UtpTransfer
-					trace.Responses[hexId] = RespByNode{}
-				} else {
-					nodes := res.Content.([]*enode.Node)
-					respByNode := RespByNode{
-						RespondedWith: make([]string, 0, len(nodes)),
-					}
-					for _, node := range nodes {
-						idInner := "0x" + node.ID().String()
-						respByNode.RespondedWith = append(respByNode.RespondedWith, idInner)
-						if _, ok := trace.Metadata[idInner]; !ok {
-							dis := p.Distance(node.ID(), enode.ID(contentId))
-							trace.Metadata[idInner] = &NodeMetadata{
-								Enr:      node.String(),
-								Distance: hexutil.Encode(dis[:]),
-							}
-						}
-						trace.Responses[hexId] = respByNode
-					}
-				}
-			} else {
-				trace.Cancelled = append(trace.Cancelled, hexId)
-			}
-		}
-	}()
-
-	lookup := NewLookup(lookupContext, p.table, enode.ID(contentId), func(n *enode.Node) ([]*enode.Node, error) {
-		return p.contentLookupWorker(n, contentKey, resChan, cancel, &hasResult)
-	})
-	lookup.Run()
-	close(resChan)
-
-	wg.Wait()
-	if hasResult == 0 {
-		cancel()
-	}
-	traceContentRes.Trace = *trace
-
-	return traceContentRes, nil
-}
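Both ContentLookup and TraceContentLookup rely on the same first-result-wins pattern implemented in contentLookupWorker below: parallel workers race, the first to obtain content flips an atomic flag, publishes its result, and cancels the remaining queries. Stripped to its essentials (stdlib-only sketch, not the patch's own types):

package main

import (
	"context"
	"fmt"
	"sync/atomic"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	results := make(chan string, 1)
	var done int32

	// Imagine each goroutine is one contentLookupWorker querying a peer.
	for i := 0; i < 3; i++ {
		go func(i int) {
			select {
			case <-ctx.Done(): // another worker already won
				return
			default:
			}
			if atomic.CompareAndSwapInt32(&done, 0, 1) {
				results <- fmt.Sprintf("content from worker %d", i)
				cancel() // tell the rest to stand down
			}
		}(i)
	}
	fmt.Println(<-results)
}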
-
-func (p *PortalProtocol) contentLookupWorker(n *enode.Node, contentKey []byte, resChan chan<- *traceContentInfoResp, cancel context.CancelFunc, done *int32) ([]*enode.Node, error) {
-	wrappedNode := make([]*enode.Node, 0)
-	flag, content, err := p.findContent(n, contentKey)
-	if err != nil {
-		return nil, err
-	}
-	p.Log.Debug("contentLookupWorker received response", "ip", n.IP().String(), "flag", flag)
-
-	switch flag {
-	case portalwire.ContentRawSelector, portalwire.ContentConnIdSelector:
-		content, ok := content.([]byte)
-		if !ok {
-			return wrappedNode, fmt.Errorf("failed to assert to raw content, value is: %v", content)
-		}
-		res := &traceContentInfoResp{
-			Node:        n,
-			Flag:        flag,
-			Content:     content,
-			UtpTransfer: false,
-		}
-		if flag == portalwire.ContentConnIdSelector {
-			res.UtpTransfer = true
-		}
-		if atomic.CompareAndSwapInt32(done, 0, 1) {
-			p.Log.Debug("contentLookupWorker found content", "ip", n.IP().String(), "port", n.UDP())
-			resChan <- res
-			cancel()
-		}
-		return wrappedNode, err
-	case portalwire.ContentEnrsSelector:
-		nodes, ok := content.([]*enode.Node)
-		if !ok {
-			return wrappedNode, fmt.Errorf("failed to assert to enrs content, value is: %v", content)
-		}
-		resChan <- &traceContentInfoResp{
-			Node:        n,
-			Flag:        flag,
-			Content:     content,
-			UtpTransfer: false,
-		}
-		return nodes, nil
-	}
-	return wrappedNode, nil
-}
-
-func (p *PortalProtocol) ToContentId(contentKey []byte) []byte {
-	return p.toContentId(contentKey)
-}
-
-func (p *PortalProtocol) InRange(contentId []byte) bool {
-	return inRange(p.Self().ID(), p.Radius(), contentId)
-}
-
-func (p *PortalProtocol) Get(contentKey []byte, contentId []byte) ([]byte, error) {
-	content, err := p.storage.Get(contentKey, contentId)
-	p.Log.Trace("get local storage", "contentId", hexutil.Encode(contentId), "content", hexutil.Encode(content), "err", err)
-	return content, err
-}
-
-func (p *PortalProtocol) Put(contentKey []byte, contentId []byte, content []byte) error {
-	err := p.storage.Put(contentKey, contentId, content)
-	p.Log.Trace("put local storage", "contentId", hexutil.Encode(contentId), "content", hexutil.Encode(content), "err", err)
-	return err
-}
-
-func (p *PortalProtocol) GetContent() chan *ContentElement {
-	return p.contentQueue
-}
-
-func (p *PortalProtocol) Gossip(srcNodeId *enode.ID, contentKeys [][]byte, content [][]byte) (int, error) {
-	if len(content) == 0 {
-		return 0, errors.New("empty content")
-	}
-
-	contentList := make([]*ContentEntry,
0, portalwire.ContentKeysLimit) - for i := 0; i < len(content); i++ { - contentEntry := &ContentEntry{ - ContentKey: contentKeys[i], - Content: content[i], - } - contentList = append(contentList, contentEntry) - } - - contentId := p.toContentId(contentKeys[0]) - if contentId == nil { - return 0, ErrNilContentKey - } - - maxClosestNodes := 4 - maxFartherNodes := 4 - closestLocalNodes := p.findNodesCloseToContent(contentId, 32) - p.Log.Debug("closest local nodes", "count", len(closestLocalNodes)) - - gossipNodes := make([]*enode.Node, 0) - for _, n := range closestLocalNodes { - radius, found := p.radiusCache.HasGet(nil, []byte(n.ID().String())) - if found { - p.Log.Debug("found closest local nodes", "nodeId", n.ID(), "addr", n.IPAddr().String()) - nodeRadius := new(uint256.Int) - err := nodeRadius.UnmarshalSSZ(radius) - if err != nil { - return 0, err - } - if inRange(n.ID(), nodeRadius, contentId) { - if srcNodeId == nil { - gossipNodes = append(gossipNodes, n) - } else if n.ID() != *srcNodeId { - gossipNodes = append(gossipNodes, n) - } - } - } - } - - if len(gossipNodes) == 0 { - return 0, nil - } - - var finalGossipNodes []*enode.Node - if len(gossipNodes) > maxClosestNodes { - fartherNodes := gossipNodes[maxClosestNodes:] - rand.Shuffle(len(fartherNodes), func(i, j int) { - fartherNodes[i], fartherNodes[j] = fartherNodes[j], fartherNodes[i] - }) - finalGossipNodes = append(gossipNodes[:maxClosestNodes], fartherNodes[:min(maxFartherNodes, len(fartherNodes))]...) - } else { - finalGossipNodes = gossipNodes - } - - for _, n := range finalGossipNodes { - transientOfferRequest := &TransientOfferRequest{ - Contents: contentList, - } - - offerRequest := &OfferRequest{ - Kind: TransientOfferRequestKind, - Request: transientOfferRequest, - } - - offerRequestWithNode := &OfferRequestWithNode{ - Node: n, - Request: offerRequest, - } - p.offerQueue <- offerRequestWithNode - } - - return len(finalGossipNodes), nil -} - -func (p *PortalProtocol) Distance(a, b enode.ID) enode.ID { - res := [32]byte{} - for i := range a { - res[i] = a[i] ^ b[i] - } - return res -} - -func inRange(nodeId enode.ID, nodeRadius *uint256.Int, contentId []byte) bool { - distance := enode.LogDist(nodeId, enode.ID(contentId)) - disBig := new(big.Int).SetInt64(int64(distance)) - return nodeRadius.CmpBig(disBig) > 0 -} - -func encodeContents(contents [][]byte) ([]byte, error) { - contentsBytes := make([]byte, 0) - for _, content := range contents { - contentLen := len(content) - contentLenBytes := leb128.EncodeUint32(uint32(contentLen)) - contentsBytes = append(contentsBytes, contentLenBytes...) - contentsBytes = append(contentsBytes, content...) 
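The encode/decode pair here defines the uTP payload framing for offered content: each item is an unsigned-LEB128 length prefix followed by the raw bytes. ULEB128 is the same encoding as Go's uvarint, so a stdlib-only round-trip sketch of the format looks like this (the patch itself uses a dedicated leb128 package):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

func main() {
	// Encode: ULEB128 (= uvarint) length prefix, then the content bytes.
	var payload bytes.Buffer
	tmp := make([]byte, binary.MaxVarintLen32)
	for _, item := range [][]byte{[]byte("a"), []byte("bc")} {
		n := binary.PutUvarint(tmp, uint64(len(item)))
		payload.Write(tmp[:n])
		payload.Write(item)
	}

	// Decode: read length prefixes until the reader is exhausted.
	r := bytes.NewReader(payload.Bytes())
	for r.Len() > 0 {
		size, err := binary.ReadUvarint(r)
		if err != nil {
			panic(err)
		}
		item := make([]byte, size)
		if _, err := io.ReadFull(r, item); err != nil {
			panic(err)
		}
		fmt.Printf("item: %q\n", item)
	}
}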
- } - - return contentsBytes, nil -} - -func decodeContents(payload []byte) ([][]byte, error) { - contents := make([][]byte, 0) - buffer := bytes.NewBuffer(payload) - - for { - contentLen, contentLenLen, err := leb128.DecodeUint32(bytes.NewReader(buffer.Bytes())) - if err != nil { - if errors.Is(err, io.EOF) { - return contents, nil - } - return nil, err - } - - buffer.Next(int(contentLenLen)) - - content := make([]byte, contentLen) - _, err = buffer.Read(content) - if err != nil { - if errors.Is(err, io.EOF) { - return contents, nil - } - return nil, err - } - - contents = append(contents, content) - } -} - -func getContentKeys(request *OfferRequest) [][]byte { - if request.Kind == TransientOfferRequestKind { - contentKeys := make([][]byte, 0) - contents := request.Request.(*TransientOfferRequest).Contents - for _, content := range contents { - contentKeys = append(contentKeys, content.ContentKey) - } - - return contentKeys - } else { - return request.Request.(*PersistOfferRequest).ContentKeys - } -} diff --git a/p2p/discover/portal_protocol_metrics.go b/p2p/discover/portal_protocol_metrics.go deleted file mode 100644 index 0bff030f5a6b..000000000000 --- a/p2p/discover/portal_protocol_metrics.go +++ /dev/null @@ -1,67 +0,0 @@ -package discover - -import "github.com/ethereum/go-ethereum/metrics" - -type portalMetrics struct { - messagesReceivedAccept metrics.Meter - messagesReceivedNodes metrics.Meter - messagesReceivedFindNodes metrics.Meter - messagesReceivedFindContent metrics.Meter - messagesReceivedContent metrics.Meter - messagesReceivedOffer metrics.Meter - messagesReceivedPing metrics.Meter - messagesReceivedPong metrics.Meter - - messagesSentAccept metrics.Meter - messagesSentNodes metrics.Meter - messagesSentFindNodes metrics.Meter - messagesSentFindContent metrics.Meter - messagesSentContent metrics.Meter - messagesSentOffer metrics.Meter - messagesSentPing metrics.Meter - messagesSentPong metrics.Meter - - utpInFailConn metrics.Counter - utpInFailRead metrics.Counter - utpInFailDeadline metrics.Counter - utpInSuccess metrics.Counter - - utpOutFailConn metrics.Counter - utpOutFailWrite metrics.Counter - utpOutFailDeadline metrics.Counter - utpOutSuccess metrics.Counter - - contentDecodedTrue metrics.Counter - contentDecodedFalse metrics.Counter -} - -func newPortalMetrics(protocolName string) *portalMetrics { - return &portalMetrics{ - messagesReceivedAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/accept", nil), - messagesReceivedNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/nodes", nil), - messagesReceivedFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_nodes", nil), - messagesReceivedFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_content", nil), - messagesReceivedContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/content", nil), - messagesReceivedOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/offer", nil), - messagesReceivedPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/ping", nil), - messagesReceivedPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/pong", nil), - messagesSentAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/accept", nil), - messagesSentNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/nodes", nil), - messagesSentFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_nodes", nil), - messagesSentFindContent: 
metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_content", nil), - messagesSentContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/content", nil), - messagesSentOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/offer", nil), - messagesSentPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/ping", nil), - messagesSentPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/pong", nil), - utpInFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_conn", nil), - utpInFailRead: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_read", nil), - utpInFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_deadline", nil), - utpInSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/success", nil), - utpOutFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_conn", nil), - utpOutFailWrite: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_write", nil), - utpOutFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_deadline", nil), - utpOutSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/success", nil), - contentDecodedTrue: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/true", nil), - contentDecodedFalse: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/false", nil), - } -} diff --git a/p2p/discover/portal_protocol_test.go b/p2p/discover/portal_protocol_test.go deleted file mode 100644 index c212677dab75..000000000000 --- a/p2p/discover/portal_protocol_test.go +++ /dev/null @@ -1,503 +0,0 @@ -package discover - -import ( - "context" - "crypto/rand" - "errors" - "fmt" - "io" - "net" - "sync" - "testing" - "time" - - "github.com/ethereum/go-ethereum/portalnetwork/storage" - "github.com/optimism-java/utp-go" - "github.com/optimism-java/utp-go/libutp" - "github.com/prysmaticlabs/go-bitfield" - "golang.org/x/exp/slices" - - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/internal/testlog" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" - "github.com/ethereum/go-ethereum/p2p/enode" - assert "github.com/stretchr/testify/require" -) - -func setupLocalPortalNode(addr string, bootNodes []*enode.Node) (*PortalProtocol, error) { - conf := DefaultPortalProtocolConfig() - conf.NAT = nil - if addr != "" { - conf.ListenAddr = addr - } - if bootNodes != nil { - conf.BootstrapNodes = bootNodes - } - - addr1, err := net.ResolveUDPAddr("udp", conf.ListenAddr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", addr1) - if err != nil { - return nil, err - } - - privKey := newkey() - - discCfg := Config{ - PrivateKey: privKey, - NetRestrict: conf.NetRestrict, - Bootnodes: conf.BootstrapNodes, - } - - nodeDB, err := enode.OpenDB(conf.NodeDBPath) - if err != nil { - return nil, err - } - - localNode := enode.NewLocalNode(nodeDB, privKey) - localNode.SetFallbackIP(net.IP{127, 0, 0, 1}) - localNode.Set(Tag) - - if conf.NAT == nil { - var addrs []net.Addr - addrs, err = net.InterfaceAddrs() - - if err != nil { - return nil, err - } - - for _, address := range addrs { - // check ip addr is loopback addr - if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - localNode.SetStaticIP(ipnet.IP) - break - } - } - } - } - - discV5, err := ListenV5(conn, localNode, discCfg) - 
if err != nil { - return nil, err - } - utpSocket := NewPortalUtp(context.Background(), conf, discV5, conn) - - contentQueue := make(chan *ContentElement, 50) - portalProtocol, err := NewPortalProtocol( - conf, - portalwire.History, - privKey, - conn, - localNode, - discV5, - utpSocket, - &storage.MockStorage{Db: make(map[string][]byte)}, - contentQueue) - if err != nil { - return nil, err - } - - return portalProtocol, nil -} - -func TestPortalWireProtocolUdp(t *testing.T) { - node1, err := setupLocalPortalNode(":8777", nil) - assert.NoError(t, err) - node1.Log = testlog.Logger(t, log.LvlTrace) - err = node1.Start() - assert.NoError(t, err) - - node2, err := setupLocalPortalNode(":8778", []*enode.Node{node1.localNode.Node()}) - assert.NoError(t, err) - node2.Log = testlog.Logger(t, log.LvlTrace) - err = node2.Start() - assert.NoError(t, err) - time.Sleep(12 * time.Second) - - node3, err := setupLocalPortalNode(":8779", []*enode.Node{node1.localNode.Node()}) - assert.NoError(t, err) - node3.Log = testlog.Logger(t, log.LvlTrace) - err = node3.Start() - assert.NoError(t, err) - time.Sleep(12 * time.Second) - - cid1 := libutp.ReceConnId(12) - cid2 := libutp.ReceConnId(116) - cliSendMsgWithCid1 := "there are connection id : 12!" - cliSendMsgWithCid2 := "there are connection id: 116!" - - serverEchoWithCid := "accept connection sends back msg: echo" - - largeTestContent := make([]byte, 1199) - _, err = rand.Read(largeTestContent) - assert.NoError(t, err) - - var workGroup sync.WaitGroup - var acceptGroup sync.WaitGroup - workGroup.Add(4) - acceptGroup.Add(1) - go func() { - var acceptConn *utp.Conn - defer func() { - workGroup.Done() - _ = acceptConn.Close() - }() - acceptConn, err := node1.Utp.AcceptWithCid(context.Background(), node2.localNode.ID(), cid1) - if err != nil { - panic(err) - } - acceptGroup.Done() - buf := make([]byte, 100) - n, err := acceptConn.Read(buf) - if err != nil && err != io.EOF { - panic(err) - } - assert.Equal(t, cliSendMsgWithCid1, string(buf[:n])) - _, err = acceptConn.Write([]byte(serverEchoWithCid)) - if err != nil { - panic(err) - } - }() - go func() { - var connId2Conn net.Conn - defer func() { - workGroup.Done() - _ = connId2Conn.Close() - }() - connId2Conn, err := node1.Utp.AcceptWithCid(context.Background(), node2.localNode.ID(), cid2) - if err != nil { - panic(err) - } - buf := make([]byte, 100) - n, err := connId2Conn.Read(buf) - if err != nil && err != io.EOF { - panic(err) - } - assert.Equal(t, cliSendMsgWithCid2, string(buf[:n])) - - _, err = connId2Conn.Write(largeTestContent) - if err != nil { - panic(err) - } - }() - - go func() { - var connWithConnId net.Conn - defer func() { - workGroup.Done() - if connWithConnId != nil { - _ = connWithConnId.Close() - } - }() - connWithConnId, err = node2.Utp.DialWithCid(context.Background(), node1.localNode.Node(), cid1.SendId()) - if err != nil { - panic(err) - } - _, err = connWithConnId.Write([]byte(cliSendMsgWithCid1)) - if err != nil && err != io.EOF { - panic(err) - } - buf := make([]byte, 100) - n, err := connWithConnId.Read(buf) - if err != nil && err != io.EOF { - panic(err) - } - assert.Equal(t, serverEchoWithCid, string(buf[:n])) - }() - go func() { - var ConnId2Conn net.Conn - defer func() { - workGroup.Done() - if ConnId2Conn != nil { - _ = ConnId2Conn.Close() - } - }() - ConnId2Conn, err = node2.Utp.DialWithCid(context.Background(), node1.localNode.Node(), cid2.SendId()) - if err != nil && err != io.EOF { - panic(err) - } - _, err = ConnId2Conn.Write([]byte(cliSendMsgWithCid2)) - if err != nil { - 
panic(err)
-		}
-
-		data := make([]byte, 0)
-		buf := make([]byte, 1024)
-		for {
-			var n int
-			n, err = ConnId2Conn.Read(buf)
-			if err != nil {
-				if errors.Is(err, io.EOF) {
-					break
-				}
-			}
-			data = append(data, buf[:n]...)
-		}
-		assert.Equal(t, largeTestContent, data)
-	}()
-	workGroup.Wait()
-	node1.Stop()
-	node2.Stop()
-	node3.Stop()
-}
-
-func TestPortalWireProtocol(t *testing.T) {
-	node1, err := setupLocalPortalNode(":7777", nil)
-	assert.NoError(t, err)
-	node1.Log = testlog.Logger(t, log.LevelDebug)
-	err = node1.Start()
-	assert.NoError(t, err)
-
-	node2, err := setupLocalPortalNode(":7778", []*enode.Node{node1.localNode.Node()})
-	assert.NoError(t, err)
-	node2.Log = testlog.Logger(t, log.LevelDebug)
-	err = node2.Start()
-	assert.NoError(t, err)
-
-	time.Sleep(12 * time.Second)
-
-	node3, err := setupLocalPortalNode(":7779", []*enode.Node{node1.localNode.Node()})
-	assert.NoError(t, err)
-	node3.Log = testlog.Logger(t, log.LevelDebug)
-	err = node3.Start()
-	assert.NoError(t, err)
-
-	time.Sleep(12 * time.Second)
-
-	slices.ContainsFunc(node1.table.NodeList(), func(n *enode.Node) bool {
-		return n.ID() == node2.localNode.Node().ID()
-	})
-	slices.ContainsFunc(node1.table.NodeList(), func(n *enode.Node) bool {
-		return n.ID() == node3.localNode.Node().ID()
-	})
-
-	slices.ContainsFunc(node2.table.NodeList(), func(n *enode.Node) bool {
-		return n.ID() == node1.localNode.Node().ID()
-	})
-	slices.ContainsFunc(node2.table.NodeList(), func(n *enode.Node) bool {
-		return n.ID() == node3.localNode.Node().ID()
-	})
-
-	slices.ContainsFunc(node3.table.NodeList(), func(n *enode.Node) bool {
-		return n.ID() == node1.localNode.Node().ID()
-	})
-	slices.ContainsFunc(node3.table.NodeList(), func(n *enode.Node) bool {
-		return n.ID() == node2.localNode.Node().ID()
-	})
-
-	err = node1.storage.Put(nil, node1.toContentId([]byte("test_key")), []byte("test_value"))
-	assert.NoError(t, err)
-
-	flag, content, err := node2.findContent(node1.localNode.Node(), []byte("test_key"))
-	assert.NoError(t, err)
-	assert.Equal(t, portalwire.ContentRawSelector, flag)
-	assert.Equal(t, []byte("test_value"), content)
-
-	flag, content, err = node2.findContent(node3.localNode.Node(), []byte("test_key"))
-	assert.NoError(t, err)
-	assert.Equal(t, portalwire.ContentEnrsSelector, flag)
-	assert.Equal(t, 1, len(content.([]*enode.Node)))
-	assert.Equal(t, node1.localNode.Node().ID(), content.([]*enode.Node)[0].ID())
-
-	// create a byte slice of length 2000 and fill it with random data
-	// this will be used as a test content
-	largeTestContent := make([]byte, 2000)
-	_, err = rand.Read(largeTestContent)
-	assert.NoError(t, err)
-
-	err = node1.storage.Put(nil, node1.toContentId([]byte("large_test_key")), largeTestContent)
-	assert.NoError(t, err)
-
-	flag, content, err = node2.findContent(node1.localNode.Node(), []byte("large_test_key"))
-	assert.NoError(t, err)
-	assert.Equal(t, largeTestContent, content)
-	assert.Equal(t, portalwire.ContentConnIdSelector, flag)
-
-	testEntry1 := &ContentEntry{
-		ContentKey: []byte("test_entry1"),
-		Content:    []byte("test_entry1_content"),
-	}
-
-	testEntry2 := &ContentEntry{
-		ContentKey: []byte("test_entry2"),
-		Content:    []byte("test_entry2_content"),
-	}
-
-	testTransientOfferRequest := &TransientOfferRequest{
-		Contents: []*ContentEntry{testEntry1, testEntry2},
-	}
-
-	offerRequest := &OfferRequest{
-		Kind:    TransientOfferRequestKind,
-		Request: testTransientOfferRequest,
-	}
-
-	contentKeys, err := node1.offer(node3.localNode.Node(), offerRequest)
-	assert.Equal(t, uint64(2),
bitfield.Bitlist(contentKeys).Count()) - assert.NoError(t, err) - - contentElement := <-node3.contentQueue - assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) - assert.Equal(t, testEntry1.ContentKey, contentElement.ContentKeys[0]) - assert.Equal(t, testEntry1.Content, contentElement.Contents[0]) - assert.Equal(t, testEntry2.ContentKey, contentElement.ContentKeys[1]) - assert.Equal(t, testEntry2.Content, contentElement.Contents[1]) - - testGossipContentKeys := [][]byte{[]byte("test_gossip_content_keys"), []byte("test_gossip_content_keys2")} - testGossipContent := [][]byte{[]byte("test_gossip_content"), []byte("test_gossip_content2")} - id := node1.Self().ID() - gossip, err := node1.Gossip(&id, testGossipContentKeys, testGossipContent) - assert.NoError(t, err) - assert.Equal(t, 2, gossip) - - contentElement = <-node2.contentQueue - assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) - assert.Equal(t, testGossipContentKeys[0], contentElement.ContentKeys[0]) - assert.Equal(t, testGossipContent[0], contentElement.Contents[0]) - assert.Equal(t, testGossipContentKeys[1], contentElement.ContentKeys[1]) - assert.Equal(t, testGossipContent[1], contentElement.Contents[1]) - - contentElement = <-node3.contentQueue - assert.Equal(t, node1.localNode.Node().ID(), contentElement.Node) - assert.Equal(t, testGossipContentKeys[0], contentElement.ContentKeys[0]) - assert.Equal(t, testGossipContent[0], contentElement.Contents[0]) - assert.Equal(t, testGossipContentKeys[1], contentElement.ContentKeys[1]) - assert.Equal(t, testGossipContent[1], contentElement.Contents[1]) - - node1.Stop() - node2.Stop() - node3.Stop() -} - -func TestCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - - go func(ctx context.Context) { - defer func() { - t.Log("goroutine cancel") - }() - - time.Sleep(time.Second * 5) - }(ctx) - - cancel() - t.Log("after main cancel") - - time.Sleep(time.Second * 3) -} - -func TestContentLookup(t *testing.T) { - node1, err := setupLocalPortalNode(":17777", nil) - assert.NoError(t, err) - node1.Log = testlog.Logger(t, log.LvlTrace) - err = node1.Start() - assert.NoError(t, err) - - node2, err := setupLocalPortalNode(":17778", []*enode.Node{node1.localNode.Node()}) - assert.NoError(t, err) - node2.Log = testlog.Logger(t, log.LvlTrace) - err = node2.Start() - assert.NoError(t, err) - fmt.Println(node2.localNode.Node().String()) - - node3, err := setupLocalPortalNode(":17779", []*enode.Node{node1.localNode.Node()}) - assert.NoError(t, err) - node3.Log = testlog.Logger(t, log.LvlTrace) - err = node3.Start() - assert.NoError(t, err) - - defer func() { - node1.Stop() - node2.Stop() - node3.Stop() - }() - - contentKey := []byte{0x3, 0x4} - content := []byte{0x1, 0x2} - contentId := node1.toContentId(contentKey) - - err = node3.storage.Put(nil, contentId, content) - assert.NoError(t, err) - - res, _, err := node1.ContentLookup(contentKey, contentId) - assert.NoError(t, err) - assert.Equal(t, res, content) - - nonExist := []byte{0x2, 0x4} - res, _, err = node1.ContentLookup(nonExist, node1.toContentId(nonExist)) - assert.Equal(t, ContentNotFound, err) - assert.Nil(t, res) -} - -func TestTraceContentLookup(t *testing.T) { - node1, err := setupLocalPortalNode(":17787", nil) - assert.NoError(t, err) - node1.Log = testlog.Logger(t, log.LvlTrace) - err = node1.Start() - assert.NoError(t, err) - - node2, err := setupLocalPortalNode(":17788", []*enode.Node{node1.localNode.Node()}) - assert.NoError(t, err) - node2.Log = testlog.Logger(t, log.LvlTrace) - err 
= node2.Start() - assert.NoError(t, err) - - node3, err := setupLocalPortalNode(":17789", []*enode.Node{node2.localNode.Node()}) - assert.NoError(t, err) - node3.Log = testlog.Logger(t, log.LvlTrace) - err = node3.Start() - assert.NoError(t, err) - - defer node1.Stop() - defer node2.Stop() - defer node3.Stop() - - contentKey := []byte{0x3, 0x4} - content := []byte{0x1, 0x2} - contentId := node1.toContentId(contentKey) - - err = node1.storage.Put(nil, contentId, content) - assert.NoError(t, err) - - node1Id := hexutil.Encode(node1.Self().ID().Bytes()) - node2Id := hexutil.Encode(node2.Self().ID().Bytes()) - node3Id := hexutil.Encode(node3.Self().ID().Bytes()) - - res, err := node3.TraceContentLookup(contentKey, contentId) - assert.NoError(t, err) - assert.Equal(t, res.Content, hexutil.Encode(content)) - assert.Equal(t, res.UtpTransfer, false) - assert.Equal(t, res.Trace.Origin, node3Id) - assert.Equal(t, res.Trace.TargetId, hexutil.Encode(contentId)) - assert.Equal(t, res.Trace.ReceivedFrom, node1Id) - - // check nodeMeta - node1Meta := res.Trace.Metadata[node1Id] - assert.Equal(t, node1Meta.Enr, node1.Self().String()) - dis := node1.Distance(node1.Self().ID(), enode.ID(contentId)) - assert.Equal(t, node1Meta.Distance, hexutil.Encode(dis[:])) - - node2Meta := res.Trace.Metadata[node2Id] - assert.Equal(t, node2Meta.Enr, node2.Self().String()) - dis = node2.Distance(node2.Self().ID(), enode.ID(contentId)) - assert.Equal(t, node2Meta.Distance, hexutil.Encode(dis[:])) - - node3Meta := res.Trace.Metadata[node3Id] - assert.Equal(t, node3Meta.Enr, node3.Self().String()) - dis = node3.Distance(node3.Self().ID(), enode.ID(contentId)) - assert.Equal(t, node3Meta.Distance, hexutil.Encode(dis[:])) - - // check response - node3Response := res.Trace.Responses[node3Id] - assert.Equal(t, node3Response.RespondedWith, []string{node2Id}) - - node2Response := res.Trace.Responses[node2Id] - assert.Equal(t, node2Response.RespondedWith, []string{node1Id}) - - node1Response := res.Trace.Responses[node1Id] - assert.Equal(t, node1Response.RespondedWith, ([]string)(nil)) -} diff --git a/p2p/discover/portal_utp.go b/p2p/discover/portal_utp.go deleted file mode 100644 index 589bd2bd15fe..000000000000 --- a/p2p/discover/portal_utp.go +++ /dev/null @@ -1,138 +0,0 @@ -package discover - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" - "github.com/ethereum/go-ethereum/p2p/discover/v5wire" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/netutil" - "github.com/optimism-java/utp-go" - "github.com/optimism-java/utp-go/libutp" - "go.uber.org/zap" -) - -type PortalUtp struct { - ctx context.Context - log log.Logger - discV5 *UDPv5 - conn UDPConn - ListenAddr string - listener *utp.Listener - utpSm *utp.SocketManager - packetRouter *utp.PacketRouter - lAddr *utp.Addr - - startOnce sync.Once -} - -func NewPortalUtp(ctx context.Context, config *PortalProtocolConfig, discV5 *UDPv5, conn UDPConn) *PortalUtp { - return &PortalUtp{ - ctx: ctx, - log: log.New("protocol", "utp", "local", conn.LocalAddr().String()), - discV5: discV5, - conn: conn, - ListenAddr: config.ListenAddr, - } -} - -func (p *PortalUtp) Start() error { - var err error - go p.startOnce.Do(func() { - var logger *zap.Logger - if p.log.Enabled(p.ctx, log.LevelDebug) || p.log.Enabled(p.ctx, log.LevelTrace) { - logger, err = zap.NewDevelopmentConfig().Build() - } else { - logger, err = zap.NewProductionConfig().Build() - 
} - if err != nil { - return - } - - laddr := p.getLocalAddr() - p.packetRouter = utp.NewPacketRouter(p.packetRouterFunc) - p.utpSm, err = utp.NewSocketManagerWithOptions( - "utp", - laddr, - utp.WithContext(p.ctx), - utp.WithLogger(logger.Named(p.ListenAddr)), - utp.WithPacketRouter(p.packetRouter), - utp.WithMaxPacketSize(1145)) - if err != nil { - return - } - p.listener, err = utp.ListenUTPOptions("utp", (*utp.Addr)(laddr), utp.WithSocketManager(p.utpSm)) - if err != nil { - return - } - p.lAddr = p.listener.Addr().(*utp.Addr) - - // register discv5 listener - p.discV5.RegisterTalkHandler(string(portalwire.Utp), p.handleUtpTalkRequest) - }) - - return err -} - -func (p *PortalUtp) Stop() { - err := p.listener.Close() - if err != nil { - p.log.Error("close utp listener has error", "error", err) - } - p.discV5.Close() -} - -func (p *PortalUtp) DialWithCid(ctx context.Context, dest *enode.Node, connId uint16) (net.Conn, error) { - raddr := &utp.Addr{IP: dest.IP(), Port: dest.UDP()} - p.log.Debug("will connect to: ", "nodeId", dest.ID().String(), "connId", connId) - conn, err := utp.DialUTPOptions("utp", p.lAddr, raddr, utp.WithContext(ctx), utp.WithSocketManager(p.utpSm), utp.WithConnId(connId)) - return conn, err -} - -func (p *PortalUtp) Dial(ctx context.Context, dest *enode.Node) (net.Conn, error) { - raddr := &utp.Addr{IP: dest.IP(), Port: dest.UDP()} - p.log.Info("will connect to: ", "addr", raddr.String()) - conn, err := utp.DialUTPOptions("utp", p.lAddr, raddr, utp.WithContext(ctx), utp.WithSocketManager(p.utpSm)) - return conn, err -} - -func (p *PortalUtp) AcceptWithCid(ctx context.Context, nodeId enode.ID, cid *libutp.ConnId) (*utp.Conn, error) { - p.log.Debug("will accept from: ", "nodeId", nodeId.String(), "sendId", cid.SendId(), "recvId", cid.RecvId()) - return p.listener.AcceptUTPContext(ctx, nodeId, cid) -} - -func (p *PortalUtp) Accept(ctx context.Context) (*utp.Conn, error) { - return p.listener.AcceptUTPContext(ctx, enode.ID{}, nil) -} - -func (p *PortalUtp) getLocalAddr() *net.UDPAddr { - laddr := p.conn.LocalAddr().(*net.UDPAddr) - p.log.Debug("UDP listener up", "addr", laddr) - return laddr -} - -func (p *PortalUtp) packetRouterFunc(buf []byte, id enode.ID, addr *net.UDPAddr) (int, error) { - p.log.Info("will send to target data", "nodeId", id.String(), "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf)) - - if n, ok := p.discV5.GetCachedNode(addr.String()); ok { - //_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf) - req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf} - p.discV5.SendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req) - - return len(buf), nil - } else { - p.log.Warn("not found target node info", "ip", addr.IP.To4().String(), "port", addr.Port, "bufLength", len(buf)) - return 0, fmt.Errorf("not found target node id") - } -} - -func (p *PortalUtp) handleUtpTalkRequest(id enode.ID, addr *net.UDPAddr, msg []byte) []byte { - p.log.Trace("receive utp data", "nodeId", id.String(), "addr", addr, "msg-length", len(msg)) - p.packetRouter.ReceiveMessage(msg, &utp.NodeInfo{Id: id, Addr: addr}) - return []byte("") -} diff --git a/portalnetwork/beacon/api.go b/portalnetwork/beacon/api.go index d62ed49bf4f6..e45978048e20 100644 --- a/portalnetwork/beacon/api.go +++ b/portalnetwork/beacon/api.go @@ -1,14 +1,14 @@ package beacon import ( - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" ) 
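From here on the hunks are mechanical: every consumer swaps its p2p/discover import for portalnetwork/portalwire and re-qualifies the shared types. For any downstream package the whole migration looks like:

package example

// Before this patch the portal types lived in p2p/discover:
//
//	import "github.com/ethereum/go-ethereum/p2p/discover"
//	var proto *discover.PortalProtocol
//
// After it, the same types come from the new portalwire package.
import "github.com/ethereum/go-ethereum/portalnetwork/portalwire"

var proto *portalwire.PortalProtocol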
type API struct { - *discover.PortalProtocolAPI + *portalwire.PortalProtocolAPI } -func (p *API) BeaconRoutingTableInfo() *discover.RoutingTableInfo { +func (p *API) BeaconRoutingTableInfo() *portalwire.RoutingTableInfo { return p.RoutingTableInfo() } @@ -28,7 +28,7 @@ func (p *API) BeaconLookupEnr(nodeId string) (string, error) { return p.LookupEnr(nodeId) } -func (p *API) BeaconPing(enr string) (*discover.PortalPongResp, error) { +func (p *API) BeaconPing(enr string) (*portalwire.PortalPongResp, error) { return p.Ping(enr) } @@ -48,7 +48,7 @@ func (p *API) BeaconRecursiveFindNodes(nodeId string) ([]string, error) { return p.RecursiveFindNodes(nodeId) } -func (p *API) BeaconGetContent(contentKeyHex string) (*discover.ContentInfo, error) { +func (p *API) BeaconGetContent(contentKeyHex string) (*portalwire.ContentInfo, error) { return p.RecursiveFindContent(contentKeyHex) } @@ -64,11 +64,11 @@ func (p *API) BeaconGossip(contentKeyHex, contentHex string) (int, error) { return p.Gossip(contentKeyHex, contentHex) } -func (p *API) BeaconTraceGetContent(contentKeyHex string) (*discover.TraceContentResult, error) { +func (p *API) BeaconTraceGetContent(contentKeyHex string) (*portalwire.TraceContentResult, error) { return p.TraceRecursiveFindContent(contentKeyHex) } -func NewBeaconNetworkAPI(BeaconAPI *discover.PortalProtocolAPI) *API { +func NewBeaconNetworkAPI(BeaconAPI *portalwire.PortalProtocolAPI) *API { return &API{ BeaconAPI, } diff --git a/portalnetwork/beacon/beacon_network.go b/portalnetwork/beacon/beacon_network.go index e5518d98ebff..16587f954459 100644 --- a/portalnetwork/beacon/beacon_network.go +++ b/portalnetwork/beacon/beacon_network.go @@ -9,7 +9,7 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" ssz "github.com/ferranbt/fastssz" "github.com/protolambda/zrnt/eth2/beacon/common" @@ -28,7 +28,7 @@ const ( ) type BeaconNetwork struct { - portalProtocol *discover.PortalProtocol + portalProtocol *portalwire.PortalProtocol spec *common.Spec log log.Logger closeCtx context.Context @@ -36,7 +36,7 @@ type BeaconNetwork struct { lightClient *ConsensusLightClient } -func NewBeaconNetwork(portalProtocol *discover.PortalProtocol) *BeaconNetwork { +func NewBeaconNetwork(portalProtocol *portalwire.PortalProtocol) *BeaconNetwork { ctx, cancel := context.WithCancel(context.Background()) return &BeaconNetwork{ diff --git a/portalnetwork/beacon/beacon_network_test.go b/portalnetwork/beacon/beacon_network_test.go index 9acd391535f5..de45d936bec5 100644 --- a/portalnetwork/beacon/beacon_network_test.go +++ b/portalnetwork/beacon/beacon_network_test.go @@ -8,11 +8,6 @@ import ( "github.com/stretchr/testify/require" ) -type Entry struct { - ContentKey string `yaml:"content_key"` - ContentValue string `yaml:"content_value"` -} - func TestLightClientBootstrapValidation(t *testing.T) { bootstrap, err := GetLightClientBootstrap(0) require.NoError(t, err) diff --git a/portalnetwork/beacon/portal_api.go b/portalnetwork/beacon/portal_api.go index 8f4ddad728ca..322c6a4f8075 100644 --- a/portalnetwork/beacon/portal_api.go +++ b/portalnetwork/beacon/portal_api.go @@ -5,7 +5,7 @@ import ( "errors" "time" - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" 
"github.com/protolambda/zrnt/eth2/beacon/common" "github.com/protolambda/ztyp/codec" @@ -17,7 +17,7 @@ const BeaconGenesisTime uint64 = 1606824023 var _ ConsensusAPI = &PortalLightApi{} type PortalLightApi struct { - portalProtocol *discover.PortalProtocol + portalProtocol *portalwire.PortalProtocol spec *common.Spec } diff --git a/portalnetwork/beacon/storage.go b/portalnetwork/beacon/storage.go index b83eb28cc0dc..aca35d67611d 100644 --- a/portalnetwork/beacon/storage.go +++ b/portalnetwork/beacon/storage.go @@ -4,9 +4,11 @@ import ( "bytes" "context" "database/sql" + "errors" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/holiman/uint256" "github.com/protolambda/zrnt/eth2/beacon/common" "github.com/protolambda/ztyp/codec" @@ -16,7 +18,7 @@ import ( const BytesInMB uint64 = 1000 * 1000 -type BeaconStorage struct { +type Storage struct { storageCapacityInBytes uint64 db *sql.DB log log.Logger @@ -24,17 +26,17 @@ type BeaconStorage struct { cache *beaconStorageCache } -var portalStorageMetrics *metrics.PortalStorageMetrics +var portalStorageMetrics *portalwire.PortalStorageMetrics type beaconStorageCache struct { OptimisticUpdate []byte FinalityUpdate []byte } -var _ storage.ContentStorage = &BeaconStorage{} +var _ storage.ContentStorage = &Storage{} func NewBeaconStorage(config storage.PortalStorageConfig) (storage.ContentStorage, error) { - bs := &BeaconStorage{ + bs := &Storage{ storageCapacityInBytes: config.StorageCapacityMB * BytesInMB, db: config.DB, log: log.New("beacon_storage"), @@ -46,7 +48,7 @@ func NewBeaconStorage(config storage.PortalStorageConfig) (storage.ContentStorag } var err error - portalStorageMetrics, err = metrics.NewPortalStorageMetrics(config.NetworkName, config.DB) + portalStorageMetrics, err = portalwire.NewPortalStorageMetrics(config.NetworkName, config.DB) if err != nil { return nil, err } @@ -54,7 +56,7 @@ func NewBeaconStorage(config storage.PortalStorageConfig) (storage.ContentStorag return bs, nil } -func (bs *BeaconStorage) setup() error { +func (bs *Storage) setup() error { if _, err := bs.db.Exec(CreateQueryDBBeacon); err != nil { return err } @@ -64,7 +66,7 @@ func (bs *BeaconStorage) setup() error { return nil } -func (bs *BeaconStorage) Get(contentKey []byte, contentId []byte) ([]byte, error) { +func (bs *Storage) Get(contentKey []byte, contentId []byte) ([]byte, error) { switch storage.ContentType(contentKey[0]) { case LightClientBootstrap: return bs.getContentValue(contentId) @@ -89,7 +91,7 @@ func (bs *BeaconStorage) Get(contentKey []byte, contentId []byte) ([]byte, error return nil, nil } -func (bs *BeaconStorage) Put(contentKey []byte, contentId []byte, content []byte) error { +func (bs *Storage) Put(contentKey []byte, contentId []byte, content []byte) error { switch storage.ContentType(contentKey[0]) { case LightClientBootstrap: return bs.putContentValue(contentId, contentKey, content) @@ -129,20 +131,20 @@ func (bs *BeaconStorage) Put(contentKey []byte, contentId []byte, content []byte return nil } -func (bs *BeaconStorage) Radius() *uint256.Int { +func (bs *Storage) Radius() *uint256.Int { return storage.MaxDistance } -func (bs *BeaconStorage) getContentValue(contentId []byte) ([]byte, error) { +func (bs *Storage) getContentValue(contentId []byte) ([]byte, error) { res := make([]byte, 0) err := bs.db.QueryRowContext(context.Background(), ContentValueLookupQueryBeacon, contentId).Scan(&res) - if err == sql.ErrNoRows { + if errors.Is(err, 
sql.ErrNoRows) { return nil, storage.ErrContentNotFound } return res, err } -func (bs *BeaconStorage) getLcUpdateValueByRange(start, end uint64) ([]byte, error) { +func (bs *Storage) getLcUpdateValueByRange(start, end uint64) ([]byte, error) { // LightClientUpdateRange := make([]ForkedLightClientUpdate, 0) var lightClientUpdateRange LightClientUpdateRange rows, err := bs.db.QueryContext(context.Background(), LCUpdateLookupQueryByRange, start, end) @@ -182,7 +184,7 @@ func (bs *BeaconStorage) getLcUpdateValueByRange(start, end uint64) ([]byte, err return buf.Bytes(), nil } -func (bs *BeaconStorage) putContentValue(contentId, contentKey, value []byte) error { +func (bs *Storage) putContentValue(contentId, contentKey, value []byte) error { length := 32 + len(contentKey) + len(value) _, err := bs.db.ExecContext(context.Background(), InsertQueryBeacon, contentId, contentKey, value, length) if metrics.Enabled && err == nil { @@ -192,7 +194,7 @@ func (bs *BeaconStorage) putContentValue(contentId, contentKey, value []byte) er return err } -func (bs *BeaconStorage) putLcUpdate(period uint64, value []byte) error { +func (bs *Storage) putLcUpdate(period uint64, value []byte) error { _, err := bs.db.ExecContext(context.Background(), InsertLCUpdateQuery, period, value, 0, len(value)) if metrics.Enabled && err == nil { portalStorageMetrics.EntriesCount.Inc(1) diff --git a/portalnetwork/beacon/test_utils.go b/portalnetwork/beacon/test_utils.go index 5235b79ffdf4..b3ee60bca067 100644 --- a/portalnetwork/beacon/test_utils.go +++ b/portalnetwork/beacon/test_utils.go @@ -9,8 +9,8 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/discover" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" ssz "github.com/ferranbt/fastssz" "github.com/golang/snappy" @@ -22,7 +22,7 @@ import ( ) func SetupBeaconNetwork(addr string, bootNodes []*enode.Node) (*BeaconNetwork, error) { - conf := discover.DefaultPortalProtocolConfig() + conf := portalwire.DefaultPortalProtocolConfig() if addr != "" { conf.ListenAddr = addr } @@ -57,17 +57,17 @@ func SetupBeaconNetwork(addr string, bootNodes []*enode.Node) (*BeaconNetwork, e localNode := enode.NewLocalNode(nodeDB, privKey) localNode.SetFallbackIP(net.IP{127, 0, 0, 1}) - localNode.Set(discover.Tag) + localNode.Set(portalwire.Tag) discV5, err := discover.ListenV5(conn, localNode, discCfg) if err != nil { return nil, err } - contentQueue := make(chan *discover.ContentElement, 50) + contentQueue := make(chan *portalwire.ContentElement, 50) - utpSocket := discover.NewPortalUtp(context.Background(), conf, discV5, conn) - portalProtocol, err := discover.NewPortalProtocol(conf, portalwire.Beacon, privKey, conn, localNode, discV5, utpSocket, &storage.MockStorage{Db: make(map[string][]byte)}, contentQueue) + utpSocket := portalwire.NewPortalUtp(context.Background(), conf, discV5, conn) + portalProtocol, err := portalwire.NewPortalProtocol(conf, portalwire.Beacon, privKey, conn, localNode, discV5, utpSocket, &storage.MockStorage{Db: make(map[string][]byte)}, contentQueue) if err != nil { return nil, err } @@ -232,11 +232,11 @@ func BuildHistoricalSummariesProof(beaconState deneb.BeaconState) ([][]byte, err leavesBytes = append(leavesBytes, dest) } - tree, err := ssz.TreeFromChunks(leavesBytes) + chunks, err := ssz.TreeFromChunks(leavesBytes) if err != nil { return nil, err } - proof, err 
:= tree.Prove(59) + proof, err := chunks.Prove(59) if err != nil { return nil, err } diff --git a/portalnetwork/ethapi/api.go b/portalnetwork/ethapi/api.go index 585d382eeeac..c1a422c26322 100644 --- a/portalnetwork/ethapi/api.go +++ b/portalnetwork/ethapi/api.go @@ -60,7 +60,7 @@ func marshalReceipt(receipt *types.Receipt, blockHash common.Hash, blockNumber u } type API struct { - History *history.HistoryNetwork + History *history.Network ChainID *big.Int } diff --git a/portalnetwork/history/api.go b/portalnetwork/history/api.go index bafdb4ea7a6f..26a184a70736 100644 --- a/portalnetwork/history/api.go +++ b/portalnetwork/history/api.go @@ -1,14 +1,14 @@ package history import ( - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" ) type API struct { - *discover.PortalProtocolAPI + *portalwire.PortalProtocolAPI } -func (p *API) HistoryRoutingTableInfo() *discover.RoutingTableInfo { +func (p *API) HistoryRoutingTableInfo() *portalwire.RoutingTableInfo { return p.RoutingTableInfo() } @@ -28,7 +28,7 @@ func (p *API) HistoryLookupEnr(nodeId string) (string, error) { return p.LookupEnr(nodeId) } -func (p *API) HistoryPing(enr string) (*discover.PortalPongResp, error) { +func (p *API) HistoryPing(enr string) (*portalwire.PortalPongResp, error) { return p.Ping(enr) } @@ -48,7 +48,7 @@ func (p *API) HistoryRecursiveFindNodes(nodeId string) ([]string, error) { return p.RecursiveFindNodes(nodeId) } -func (p *API) HistoryGetContent(contentKeyHex string) (*discover.ContentInfo, error) { +func (p *API) HistoryGetContent(contentKeyHex string) (*portalwire.ContentInfo, error) { return p.RecursiveFindContent(contentKeyHex) } @@ -64,11 +64,11 @@ func (p *API) HistoryGossip(contentKeyHex, contentHex string) (int, error) { return p.Gossip(contentKeyHex, contentHex) } -func (p *API) HistoryTraceGetContent(contentKeyHex string) (*discover.TraceContentResult, error) { +func (p *API) HistoryTraceGetContent(contentKeyHex string) (*portalwire.TraceContentResult, error) { return p.TraceRecursiveFindContent(contentKeyHex) } -func NewHistoryNetworkAPI(historyAPI *discover.PortalProtocolAPI) *API { +func NewHistoryNetworkAPI(historyAPI *portalwire.PortalProtocolAPI) *API { return &API{ historyAPI, } diff --git a/portalnetwork/history/history_network.go b/portalnetwork/history/history_network.go index e102902299e8..b6824d2e9f34 100644 --- a/portalnetwork/history/history_network.go +++ b/portalnetwork/history/history_network.go @@ -13,7 +13,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/trie" @@ -61,18 +61,18 @@ func (c *ContentKey) encode() []byte { return res } -type HistoryNetwork struct { - portalProtocol *discover.PortalProtocol +type Network struct { + portalProtocol *portalwire.PortalProtocol masterAccumulator *MasterAccumulator closeCtx context.Context closeFunc context.CancelFunc log log.Logger } -func NewHistoryNetwork(portalProtocol *discover.PortalProtocol, accu *MasterAccumulator) *HistoryNetwork { +func NewHistoryNetwork(portalProtocol *portalwire.PortalProtocol, accu *MasterAccumulator) *Network { ctx, cancel := context.WithCancel(context.Background()) - return &HistoryNetwork{ + return &Network{ portalProtocol: portalProtocol, masterAccumulator: accu, closeCtx: ctx, 
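 		// closeCtx is cancelled by Stop (through closeFunc) so the
 		// background content-processing loop can shut down cleanly.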
@@ -81,7 +81,7 @@ func NewHistoryNetwork(portalProtocol *discover.PortalProtocol, accu *MasterAccu } } -func (h *HistoryNetwork) Start() error { +func (h *Network) Start() error { err := h.portalProtocol.Start() if err != nil { return err @@ -91,7 +91,7 @@ func (h *HistoryNetwork) Start() error { return nil } -func (h *HistoryNetwork) Stop() { +func (h *Network) Stop() { h.closeFunc() h.portalProtocol.Stop() } @@ -99,7 +99,7 @@ func (h *HistoryNetwork) Stop() { // Currently doing 4 retries on lookups but only when the validation fails. const requestRetries = 4 -func (h *HistoryNetwork) GetBlockHeader(blockHash []byte) (*types.Header, error) { +func (h *Network) GetBlockHeader(blockHash []byte) (*types.Header, error) { contentKey := newContentKey(BlockHeaderType, blockHash).encode() contentId := h.portalProtocol.ToContentId(contentKey) h.log.Trace("contentKey convert to contentId", "contentKey", hexutil.Encode(contentKey), "contentId", hexutil.Encode(contentId)) @@ -155,7 +155,7 @@ func (h *HistoryNetwork) GetBlockHeader(blockHash []byte) (*types.Header, error) return nil, storage.ErrContentNotFound } -func (h *HistoryNetwork) GetBlockBody(blockHash []byte) (*types.Body, error) { +func (h *Network) GetBlockBody(blockHash []byte) (*types.Body, error) { header, err := h.GetBlockHeader(blockHash) if err != nil { return nil, err @@ -206,7 +206,7 @@ func (h *HistoryNetwork) GetBlockBody(blockHash []byte) (*types.Body, error) { return nil, storage.ErrContentNotFound } -func (h *HistoryNetwork) GetReceipts(blockHash []byte) ([]*types.Receipt, error) { +func (h *Network) GetReceipts(blockHash []byte) ([]*types.Receipt, error) { header, err := h.GetBlockHeader(blockHash) if err != nil { return nil, err @@ -255,7 +255,7 @@ func (h *HistoryNetwork) GetReceipts(blockHash []byte) ([]*types.Receipt, error) return nil, storage.ErrContentNotFound } -func (h *HistoryNetwork) verifyHeader(header *types.Header, proof BlockHeaderProof) (bool, error) { +func (h *Network) verifyHeader(header *types.Header, proof BlockHeaderProof) (bool, error) { return h.masterAccumulator.VerifyHeader(*header, proof) } @@ -457,7 +457,7 @@ func ToPortalReceipts(receipts []*types.Receipt) (*PortalReceipts, error) { return &PortalReceipts{Receipts: res}, nil } -func (h *HistoryNetwork) processContentLoop(ctx context.Context) { +func (h *Network) processContentLoop(ctx context.Context) { contentChan := h.portalProtocol.GetContent() for { select { @@ -488,7 +488,7 @@ func (h *HistoryNetwork) processContentLoop(ctx context.Context) { } } -func (h *HistoryNetwork) validateContent(contentKey []byte, content []byte) error { +func (h *Network) validateContent(contentKey []byte, content []byte) error { switch ContentType(contentKey[0]) { case BlockHeaderType: headerWithProof, err := DecodeBlockHeaderWithProof(content) @@ -559,7 +559,7 @@ func (h *HistoryNetwork) validateContent(contentKey []byte, content []byte) erro return errors.New("unknown content type") } -func (h *HistoryNetwork) validateContents(contentKeys [][]byte, contents [][]byte) error { +func (h *Network) validateContents(contentKeys [][]byte, contents [][]byte) error { for i, content := range contents { contentKey := contentKeys[i] err := h.validateContent(contentKey, content) diff --git a/portalnetwork/history/history_network_test.go b/portalnetwork/history/history_network_test.go index 38b816e9cb10..72ced80a2013 100644 --- a/portalnetwork/history/history_network_test.go +++ b/portalnetwork/history/history_network_test.go @@ -16,8 +16,8 @@ import ( 
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/discover" - "github.com/ethereum/go-ethereum/p2p/discover/portalwire" "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" @@ -241,7 +241,7 @@ func TestValidateContents(t *testing.T) { func TestValidateContentForCancun(t *testing.T) { master, err := NewMasterAccumulator() require.NoError(t, err) - historyNetwork := &HistoryNetwork{ + historyNetwork := &Network{ masterAccumulator: &master, } @@ -287,12 +287,12 @@ func parseBlockHeaderKeyContent() ([]contentEntry, error) { return res, nil } -func genHistoryNetwork(addr string, bootNodes []*enode.Node) (*HistoryNetwork, error) { +func genHistoryNetwork(addr string, bootNodes []*enode.Node) (*Network, error) { glogger := log.NewGlogHandler(log.NewTerminalHandler(os.Stderr, true)) slogVerbosity := log.FromLegacyLevel(5) glogger.Verbosity(slogVerbosity) log.SetDefault(log.NewLogger(glogger)) - conf := discover.DefaultPortalProtocolConfig() + conf := portalwire.DefaultPortalProtocolConfig() if addr != "" { conf.ListenAddr = addr } @@ -327,16 +327,16 @@ func genHistoryNetwork(addr string, bootNodes []*enode.Node) (*HistoryNetwork, e localNode := enode.NewLocalNode(nodeDB, privKey) localNode.SetFallbackIP(net.IP{127, 0, 0, 1}) - localNode.Set(discover.Tag) + localNode.Set(portalwire.Tag) discV5, err := discover.ListenV5(conn, localNode, discCfg) if err != nil { return nil, err } - contentQueue := make(chan *discover.ContentElement, 50) - utpSocket := discover.NewPortalUtp(context.Background(), conf, discV5, conn) - portalProtocol, err := discover.NewPortalProtocol(conf, portalwire.History, privKey, conn, localNode, discV5, utpSocket, &storage.MockStorage{Db: make(map[string][]byte)}, contentQueue) + contentQueue := make(chan *portalwire.ContentElement, 50) + utpSocket := portalwire.NewPortalUtp(context.Background(), conf, discV5, conn) + portalProtocol, err := portalwire.NewPortalProtocol(conf, portalwire.History, privKey, conn, localNode, discV5, utpSocket, &storage.MockStorage{Db: make(map[string][]byte)}, contentQueue) if err != nil { return nil, err } diff --git a/portalnetwork/history/new_storage.go b/portalnetwork/history/new_storage.go new file mode 100644 index 000000000000..be9ccd38383c --- /dev/null +++ b/portalnetwork/history/new_storage.go @@ -0,0 +1,443 @@ +package history + +// +//import ( +// "encoding/binary" +// "errors" +// "fmt" +// "path" +// "sync/atomic" +// +// "github.com/cockroachdb/pebble" +// "github.com/ethereum/go-ethereum/log" +// "github.com/ethereum/go-ethereum/metrics" +// "github.com/ethereum/go-ethereum/p2p/enode" +// "github.com/ethereum/go-ethereum/portalnetwork/storage" +// "github.com/holiman/uint256" +//) +// +//const ( +// contentDeletionFraction = 0.05 +// prefixContent = byte(0x01) // prefixContent + distance + contentId -> content +// prefixDistanceSize = byte(0x02) // prefixDistanceSize + distance -> total size +//) +// +//type ContentStorage struct { +// nodeId enode.ID +// storageCapacityInBytes uint64 +// radius atomic.Value +// db *pebble.DB +// log log.Logger +//} +// +//func NewHistoryStorage(config storage.PortalStorageConfig) (storage.ContentStorage, error) { +// dbPath := path.Join(config.DataDir, config.NetworkName) +// +// opts := &pebble.Options{ +// MaxOpenFiles: 1000, +// } +// +// db, err := pebble.Open(dbPath, 
opts) +// if err != nil { +// return nil, err +// } +// +// cs := &ContentStorage{ +// nodeId: config.NodeId, +// db: db, +// storageCapacityInBytes: config.StorageCapacityMB * 1000000, +// log: log.New("storage", config.NetworkName), +// } +// cs.radius.Store(storage.MaxDistance) +// cs.setRadiusToFarthestDistance() +// +// return cs, nil +//} +// +//func makeKey(prefix byte, distance []byte, contentId []byte) []byte { +// if contentId == nil { +// key := make([]byte, 1+len(distance)) +// key[0] = prefix +// copy(key[1:], distance) +// return key +// } +// key := make([]byte, 1+len(distance)+len(contentId)) +// key[0] = prefix +// copy(key[1:], distance) +// copy(key[1+len(distance):], contentId) +// return key +//} +// +//func (p *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte) error { +// distance := xor(contentId, p.nodeId[:]) +// key := makeKey(prefixContent, distance, contentId) +// +// batch := p.db.NewBatch() +// defer batch.Close() +// +// // Update content +// if err := batch.Set(key, content, pebble.Sync); err != nil { +// return err +// } +// +// // Update distance size index +// sizeKey := makeKey(prefixDistanceSize, distance, nil) +// var currentSize uint64 +// if value, closer, err := p.db.Get(sizeKey); err == nil { +// currentSize = binary.BigEndian.Uint64(value) +// closer.Close() +// } +// +// newSize := currentSize + uint64(len(content)) +// sizeBytes := make([]byte, 8) +// binary.BigEndian.PutUint64(sizeBytes, newSize) +// +// if err := batch.Set(sizeKey, sizeBytes, pebble.Sync); err != nil { +// return err +// } +// +// if err := batch.Commit(pebble.Sync); err != nil { +// return err +// } +// +// if size, _ := p.UsedSize(); size > p.storageCapacityInBytes { +// if _, err := p.deleteContentFraction(contentDeletionFraction); err != nil { +// p.log.Warn("failed to delete oversize content", "err", err) +// } +// } +// +// if metrics.Enabled { +// portalStorageMetrics.EntriesCount.Inc(1) +// portalStorageMetrics.ContentStorageUsage.Inc(int64(len(content))) +// } +// +// return nil +//} +// +//func (p *ContentStorage) Get(contentKey []byte, contentId []byte) ([]byte, error) { +// distance := xor(contentId, p.nodeId[:]) +// key := makeKey(prefixContent, distance, contentId) +// +// value, closer, err := p.db.Get(key) +// if err == pebble.ErrNotFound { +// return nil, storage.ErrContentNotFound +// } +// if err != nil { +// return nil, err +// } +// defer closer.Close() +// +// return value, nil +//} +// +//func (p *ContentStorage) deleteContentFraction(fraction float64) (deleteCount int, err error) { +// if fraction <= 0 || fraction >= 1 { +// return 0, errors.New("fraction should be between 0 and 1") +// } +// +// totalSize, err := p.ContentSize() +// if err != nil { +// return 0, err +// } +// +// targetSize := uint64(float64(totalSize) * fraction) +// deletedSize := uint64(0) +// count := 0 +// +// iter := p.db.NewIter(&pebble.IterOptions{ +// LowerBound: []byte{prefixContent}, +// UpperBound: []byte{prefixContent + 1}, +// }) +// defer iter.Close() +// +// batch := p.db.NewBatch() +// defer batch.Close() +// +// for iter.Last(); iter.Valid() && deletedSize < targetSize; iter.Prev() { +// key := iter.Key() +// value := iter.Value() +// distance := key[1:33] +// +// // Delete content +// if err := batch.Delete(key, nil); err != nil { +// return count, err +// } +// +// // Update distance size index +// sizeKey := makeKey(prefixDistanceSize, distance, nil) +// var currentSize uint64 +// sizeValue, closer, err := p.db.Get(sizeKey) +// if err == nil { 
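+//			// size-index entries hold one big-endian uint64 per distance:
+//			// the total bytes stored at that distance (written by Put).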
+//			currentSize = binary.BigEndian.Uint64(sizeValue)
+//			closer.Close()
+//		}
+//
+//		newSize := currentSize - uint64(len(value))
+//		if newSize == 0 {
+//			if err := batch.Delete(sizeKey, nil); err != nil {
+//				return count, err
+//			}
+//		} else {
+//			sizeBytes := make([]byte, 8)
+//			binary.BigEndian.PutUint64(sizeBytes, newSize)
+//			if err := batch.Set(sizeKey, sizeBytes, nil); err != nil {
+//				return count, err
+//			}
+//		}
+//
+//		deletedSize += uint64(len(value))
+//		count++
+//
+//		if batch.Len() >= 1000 {
+//			if err := batch.Commit(pebble.Sync); err != nil {
+//				return count, err
+//			}
+//			batch = p.db.NewBatch()
+//		}
+//	}
+//	if batch.Len() > 0 {
+//		if err := batch.Commit(pebble.Sync); err != nil {
+//			return count, err
+//		}
+//	}
+//
+//	if iter.Valid() {
+//		key := iter.Key()
+//		distance := key[1:33]
+//		dis := uint256.NewInt(0)
+//		if err := dis.UnmarshalSSZ(distance); err != nil {
+//			return count, err
+//		}
+//		p.radius.Store(dis)
+//	}
+//
+//	return count, nil
+//}
+//
+//func (p *ContentStorage) UsedSize() (uint64, error) {
+//	var totalSize uint64
+//	iter := p.db.NewIter(&pebble.IterOptions{
+//		LowerBound: []byte{prefixDistanceSize},
+//		UpperBound: []byte{prefixDistanceSize + 1},
+//	})
+//	defer iter.Close()
+//
+//	for iter.First(); iter.Valid(); iter.Next() {
+//		size := binary.BigEndian.Uint64(iter.Value())
+//		totalSize += size
+//	}
+//
+//	return totalSize, nil
+//}
+//
+//func (p *ContentStorage) ContentSize() (uint64, error) {
+//	return p.UsedSize()
+//}
+//
+//func (p *ContentStorage) ContentCount() (uint64, error) {
+//	var count uint64
+//	iter, _ := p.db.NewIter(&pebble.IterOptions{
+//		LowerBound: []byte{prefixContent},
+//		UpperBound: []byte{prefixContent + 1},
+//	})
+//	defer iter.Close()
+//
+//	for iter.First(); iter.Valid(); iter.Next() {
+//		count++
+//	}
+//
+//	return count, nil
+//}
+//
+//func (p *ContentStorage) Radius() *uint256.Int {
+//	radius := p.radius.Load()
+//	val := radius.(*uint256.Int)
+//	return val
+//}
+//func (p *ContentStorage) GetLargestDistance() (*uint256.Int, error) {
+//	iter := p.db.NewIter(&pebble.IterOptions{
+//		LowerBound: []byte{prefixContent},
+//		UpperBound: []byte{prefixContent + 1},
+//	})
+//	defer iter.Close()
+//
+//	if !iter.Last() {
+//		return nil, fmt.Errorf("no content found")
+//	}
+//
+//	key := iter.Key()
+//	distance := key[1:33]
+//
+//	res := uint256.NewInt(0)
+//	err := res.UnmarshalSSZ(distance)
+//	return res, err
+//}
+//
+//func (p *ContentStorage) EstimateNewRadius(currentRadius *uint256.Int) (*uint256.Int, error) {
+//	currentSize, err := p.UsedSize()
+//	if err != nil {
+//		return nil, err
+//	}
+//
+//	sizeRatio := currentSize / p.storageCapacityInBytes
+//	if sizeRatio > 0 {
+//		newRadius := new(uint256.Int).Div(currentRadius, uint256.NewInt(sizeRatio))
+//
+//		if metrics.Enabled {
+//			ratio := new(uint256.Int).Mul(newRadius, uint256.NewInt(100))
+//			ratio.Mod(ratio, storage.MaxDistance)
+//			portalStorageMetrics.RadiusRatio.Update(ratio.Float64() / 100)
+//		}
+//
+//		return newRadius, nil
+//	}
+//	return currentRadius, nil
+//}
+//
+//func (p *ContentStorage) setRadiusToFarthestDistance() {
+//	largestDistance, err := p.GetLargestDistance()
+//	if err != nil {
+//		p.log.Error("failed to get farthest distance", "err", err)
+//		return
+//	}
+//	p.radius.Store(largestDistance)
+//}
+//func (p *ContentStorage) ForcePrune(radius *uint256.Int) error {
+//	batch := p.db.NewBatch()
+//	defer batch.Close()
+//
+//	iter := p.db.NewIter(&pebble.IterOptions{
+//		LowerBound: 
[]byte{prefixContent}, +// UpperBound: []byte{prefixContent + 1}, +// }) +// defer iter.Close() +// +// var deletedSize int64 +// deleteCount := 0 +// +// for iter.First(); iter.Valid(); iter.Next() { +// key := iter.Key() +// value := iter.Value() +// distance := key[1:33] +// +// dis := uint256.NewInt(0) +// if err := dis.UnmarshalSSZ(distance); err != nil { +// return err +// } +// +// if dis.Cmp(radius) > 0 { +// // Delete content +// if err := batch.Delete(key, nil); err != nil { +// return err +// } +// +// // Update distance size index +// sizeKey := makeKey(prefixDistanceSize, distance, nil) +// var currentSize uint64 +// if sizeValue, closer, err := p.db.Get(sizeKey); err == nil { +// currentSize = binary.BigEndian.Uint64(sizeValue) +// closer.Close() +// } +// +// newSize := currentSize - uint64(len(value)) +// if newSize == 0 { +// if err := batch.Delete(sizeKey, nil); err != nil { +// return err +// } +// } else { +// sizeBytes := make([]byte, 8) +// binary.BigEndian.PutUint64(sizeBytes, newSize) +// if err := batch.Set(sizeKey, sizeBytes, nil); err != nil { +// return err +// } +// } +// +// deletedSize += int64(len(value)) +// deleteCount++ +// } +// +// if batch.Len() >= 1000 { +// if err := batch.Commit(pebble.Sync); err != nil { +// return err +// } +// batch = p.db.NewBatch() +// } +// } +// if batch.Len() > 0 { +// if err := batch.Commit(pebble.Sync); err != nil { +// return err +// } +// } +// +// if metrics.Enabled { +// portalStorageMetrics.EntriesCount.Dec(int64(deleteCount)) +// portalStorageMetrics.ContentStorageUsage.Dec(deletedSize) +// } +// +// return nil +//} +// +//func (p *ContentStorage) ReclaimSpace() error { +// return p.db.Compact([]byte{prefixContent}, []byte{prefixContent + 1}, true) +//} +// +//func (p *ContentStorage) Close() error { +// return p.db.Close() +//} +// +//func (p *ContentStorage) SizeByKey(contentId []byte) (uint64, error) { +// distance := xor(contentId, p.nodeId[:]) +// key := makeKey(prefixContent, distance, contentId) +// +// value, closer, err := p.db.Get(key) +// if err == pebble.ErrNotFound { +// return 0, nil +// } +// if err != nil { +// return 0, err +// } +// defer closer.Close() +// +// return uint64(len(value)), nil +//} +// +//func (p *ContentStorage) SizeByKeys(ids [][]byte) (uint64, error) { +// var totalSize uint64 +// +// for _, id := range ids { +// size, err := p.SizeByKey(id) +// if err != nil { +// return 0, err +// } +// totalSize += size +// } +// +// return totalSize, nil +//} +// +//func (p *ContentStorage) SizeOutRadius(radius *uint256.Int) (uint64, error) { +// var totalSize uint64 +// +// iter := p.db.NewIter(&pebble.IterOptions{ +// LowerBound: []byte{prefixDistanceSize}, +// UpperBound: []byte{prefixDistanceSize + 1}, +// }) +// defer iter.Close() +// +// for iter.First(); iter.Valid(); iter.Next() { +// key := iter.Key() +// distance := key[1:33] +// +// dis := uint256.NewInt(0) +// if err := dis.UnmarshalSSZ(distance); err != nil { +// return 0, err +// } +// +// if dis.Cmp(radius) > 0 { +// size := binary.BigEndian.Uint64(iter.Value()) +// totalSize += size +// } +// } +// +// return totalSize, nil +//} diff --git a/portalnetwork/history/storage.go b/portalnetwork/history/storage.go index 28632cdab6b5..8e0260ee5e3f 100644 --- a/portalnetwork/history/storage.go +++ b/portalnetwork/history/storage.go @@ -16,6 +16,7 @@ import ( "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p/enode" + 
"github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/holiman/uint256" "github.com/mattn/go-sqlite3" @@ -57,7 +58,7 @@ type ContentStorage struct { log log.Logger } -var portalStorageMetrics *metrics.PortalStorageMetrics +var portalStorageMetrics *portalwire.PortalStorageMetrics func xor(contentId, nodeId []byte) []byte { // length of contentId maybe not 32bytes @@ -123,7 +124,7 @@ func NewHistoryStorage(config storage.PortalStorageConfig) (storage.ContentStora // necessary to test NetworkName==history because state also initialize HistoryStorage if strings.ToLower(config.NetworkName) == "history" { - portalStorageMetrics, err = metrics.NewPortalStorageMetrics(config.NetworkName, config.DB) + portalStorageMetrics, err = portalwire.NewPortalStorageMetrics(config.NetworkName, config.DB) if err != nil { return nil, err } diff --git a/portalnetwork/nat.go b/portalnetwork/nat.go deleted file mode 100644 index ca479d7e457d..000000000000 --- a/portalnetwork/nat.go +++ /dev/null @@ -1,172 +0,0 @@ -package portalnetwork - -import ( - "net" - "time" - - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/p2p/nat" -) - -const ( - portMapDuration = 10 * time.Minute - portMapRefreshInterval = 8 * time.Minute - portMapRetryInterval = 5 * time.Minute - extipRetryInterval = 2 * time.Minute -) - -type portMapping struct { - protocol string - name string - port int - - // for use by the portMappingLoop goroutine: - extPort int // the mapped port returned by the NAT interface - nextTime mclock.AbsTime -} - -// setupPortMapping starts the port mapping loop if necessary. -// Note: this needs to be called after the LocalNode instance has been set on the server. -func (p *PortalProtocol) setupPortMapping() { - // portMappingRegister will receive up to two values: one for the TCP port if - // listening is enabled, and one more for enabling UDP port mapping if discovery is - // enabled. We make it buffered to avoid blocking setup while a mapping request is in - // progress. - p.portMappingRegister = make(chan *portMapping, 2) - - switch p.NAT.(type) { - case nil: - // No NAT interface configured. - go p.consumePortMappingRequests() - - case nat.ExtIP: - // ExtIP doesn't block, set the IP right away. - ip, _ := p.NAT.ExternalIP() - p.localNode.SetStaticIP(ip) - go p.consumePortMappingRequests() - - case nat.STUN: - // STUN doesn't block, set the IP right away. - ip, _ := p.NAT.ExternalIP() - p.localNode.SetStaticIP(ip) - go p.consumePortMappingRequests() - - default: - go p.portMappingLoop() - } -} - -func (p *PortalProtocol) consumePortMappingRequests() { - for { - select { - case <-p.closeCtx.Done(): - return - case <-p.portMappingRegister: - } - } -} - -// portMappingLoop manages port mappings for UDP and TCP. 
-func (p *PortalProtocol) portMappingLoop() { - newLogger := func(proto string, e int, i int) log.Logger { - return log.New("proto", proto, "extport", e, "intport", i, "interface", p.NAT) - } - - var ( - mappings = make(map[string]*portMapping, 2) - refresh = mclock.NewAlarm(p.clock) - extip = mclock.NewAlarm(p.clock) - lastExtIP net.IP - ) - extip.Schedule(p.clock.Now()) - defer func() { - refresh.Stop() - extip.Stop() - for _, m := range mappings { - if m.extPort != 0 { - log := newLogger(m.protocol, m.extPort, m.port) - log.Debug("Deleting port mapping") - p.NAT.DeleteMapping(m.protocol, m.extPort, m.port) - } - } - }() - - for { - // Schedule refresh of existing mappings. - for _, m := range mappings { - refresh.Schedule(m.nextTime) - } - - select { - case <-p.closeCtx.Done(): - return - - case <-extip.C(): - extip.Schedule(p.clock.Now().Add(extipRetryInterval)) - ip, err := p.NAT.ExternalIP() - if err != nil { - log.Debug("Couldn't get external IP", "err", err, "interface", p.NAT) - } else if !ip.Equal(lastExtIP) { - log.Debug("External IP changed", "ip", extip, "interface", p.NAT) - } else { - continue - } - // Here, we either failed to get the external IP, or it has changed. - lastExtIP = ip - p.localNode.SetStaticIP(ip) - p.Log.Debug("set static ip in nat", "ip", p.localNode.Node().IP().String()) - // Ensure port mappings are refreshed in case we have moved to a new network. - for _, m := range mappings { - m.nextTime = p.clock.Now() - } - - case m := <-p.portMappingRegister: - if m.protocol != "TCP" && m.protocol != "UDP" { - panic("unknown NAT protocol name: " + m.protocol) - } - mappings[m.protocol] = m - m.nextTime = p.clock.Now() - - case <-refresh.C(): - for _, m := range mappings { - if p.clock.Now() < m.nextTime { - continue - } - - external := m.port - if m.extPort != 0 { - external = m.extPort - } - log := newLogger(m.protocol, external, m.port) - - log.Trace("Attempting port mapping") - port, err := p.NAT.AddMapping(m.protocol, external, m.port, m.name, portMapDuration) - if err != nil { - log.Debug("Couldn't add port mapping", "err", err) - m.extPort = 0 - m.nextTime = p.clock.Now().Add(portMapRetryInterval) - continue - } - // It was mapped! - m.extPort = int(port) - m.nextTime = p.clock.Now().Add(portMapRefreshInterval) - if external != m.extPort { - log = newLogger(m.protocol, m.extPort, m.port) - log.Info("NAT mapped alternative port") - } else { - log.Info("NAT mapped port") - } - - // Update port in local ENR. 
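-				// The TCP port is set as a static ENR entry, while the UDP
-				// port is only a fallback: an endpoint verified through
-				// discovery still takes precedence.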
- switch m.protocol { - case "TCP": - p.localNode.Set(enr.TCP(m.extPort)) - case "UDP": - p.localNode.SetFallbackUDP(m.extPort) - } - } - } - } -} diff --git a/portalnetwork/portal_protocol_metrics.go b/portalnetwork/portal_protocol_metrics.go deleted file mode 100644 index 343d3f4f00f3..000000000000 --- a/portalnetwork/portal_protocol_metrics.go +++ /dev/null @@ -1,67 +0,0 @@ -package portalnetwork - -import "github.com/ethereum/go-ethereum/metrics" - -type portalMetrics struct { - messagesReceivedAccept metrics.Meter - messagesReceivedNodes metrics.Meter - messagesReceivedFindNodes metrics.Meter - messagesReceivedFindContent metrics.Meter - messagesReceivedContent metrics.Meter - messagesReceivedOffer metrics.Meter - messagesReceivedPing metrics.Meter - messagesReceivedPong metrics.Meter - - messagesSentAccept metrics.Meter - messagesSentNodes metrics.Meter - messagesSentFindNodes metrics.Meter - messagesSentFindContent metrics.Meter - messagesSentContent metrics.Meter - messagesSentOffer metrics.Meter - messagesSentPing metrics.Meter - messagesSentPong metrics.Meter - - utpInFailConn metrics.Counter - utpInFailRead metrics.Counter - utpInFailDeadline metrics.Counter - utpInSuccess metrics.Counter - - utpOutFailConn metrics.Counter - utpOutFailWrite metrics.Counter - utpOutFailDeadline metrics.Counter - utpOutSuccess metrics.Counter - - contentDecodedTrue metrics.Counter - contentDecodedFalse metrics.Counter -} - -func newPortalMetrics(protocolName string) *portalMetrics { - return &portalMetrics{ - messagesReceivedAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/accept", nil), - messagesReceivedNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/nodes", nil), - messagesReceivedFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_nodes", nil), - messagesReceivedFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_content", nil), - messagesReceivedContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/content", nil), - messagesReceivedOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/offer", nil), - messagesReceivedPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/ping", nil), - messagesReceivedPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/pong", nil), - messagesSentAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/accept", nil), - messagesSentNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/nodes", nil), - messagesSentFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_nodes", nil), - messagesSentFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_content", nil), - messagesSentContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/content", nil), - messagesSentOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/offer", nil), - messagesSentPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/ping", nil), - messagesSentPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/pong", nil), - utpInFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_conn", nil), - utpInFailRead: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_read", nil), - utpInFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_deadline", nil), - utpInSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/success", nil), - utpOutFailConn: 
metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_conn", nil), - utpOutFailWrite: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_write", nil), - utpOutFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_deadline", nil), - utpOutSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/success", nil), - contentDecodedTrue: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/true", nil), - contentDecodedFalse: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/false", nil), - } -} diff --git a/portalnetwork/api.go b/portalnetwork/portalwire/api.go similarity index 98% rename from portalnetwork/api.go rename to portalnetwork/portalwire/api.go index bc7305ef8b57..68698888c0ee 100644 --- a/portalnetwork/api.go +++ b/portalnetwork/portalwire/api.go @@ -1,4 +1,4 @@ -package portalnetwork +package portalwire import ( "errors" @@ -6,7 +6,6 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/holiman/uint256" ) @@ -73,7 +72,7 @@ type RespByNode struct { RespondedWith []string `json:"respondedWith"` } -type Enrs struct { +type EnrsResp struct { Enrs []string `json:"enrs"` } @@ -328,7 +327,7 @@ func (p *PortalProtocolAPI) Ping(enr string) (*PortalPongResp, error) { return nil, err } - customPayload := &portalwire.PingPongCustomData{} + customPayload := &PingPongCustomData{} err = customPayload.UnmarshalSSZ(pong.CustomPayload) if err != nil { return nil, err @@ -381,14 +380,14 @@ func (p *PortalProtocolAPI) FindContent(enr string, contentKey string) (interfac } switch flag { - case portalwire.ContentRawSelector: + case ContentRawSelector: contentInfo := &ContentInfo{ Content: hexutil.Encode(findContent.([]byte)), UtpTransfer: false, } p.portalProtocol.Log.Trace("FindContent", "contentInfo", contentInfo) return contentInfo, nil - case portalwire.ContentConnIdSelector: + case ContentConnIdSelector: contentInfo := &ContentInfo{ Content: hexutil.Encode(findContent.([]byte)), UtpTransfer: true, @@ -402,7 +401,7 @@ func (p *PortalProtocolAPI) FindContent(enr string, contentKey string) (interfac } p.portalProtocol.Log.Trace("FindContent", "enrs", enrs) - return &Enrs{ + return &EnrsResp{ Enrs: enrs, }, nil } diff --git a/portalnetwork/portalwire/messages.go b/portalnetwork/portalwire/messages.go deleted file mode 100644 index c7629604d570..000000000000 --- a/portalnetwork/portalwire/messages.go +++ /dev/null @@ -1,336 +0,0 @@ -package portalwire - -import ( - ssz "github.com/ferranbt/fastssz" -) - -// note: We changed the generated file since fastssz issues which can't be passed by the CI, so we commented the go:generate line -///go:generate sszgen --path messages.go --exclude-objs Content,Enrs,ContentKV - -// Message codes for the portal protocol. -const ( - PING byte = 0x00 - PONG byte = 0x01 - FINDNODES byte = 0x02 - NODES byte = 0x03 - FINDCONTENT byte = 0x04 - CONTENT byte = 0x05 - OFFER byte = 0x06 - ACCEPT byte = 0x07 -) - -// Content selectors for the portal protocol. 
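-// A CONTENT response is a one-byte union selector followed by the chosen
-// variant: a uTP connection id (the payload is too large for a single
-// response and follows out of band), the raw content bytes, or a list of
-// ENRs closer to the requested key. E.g. 0x05 0x00 plus a 2-byte id
-// announces a uTP transfer.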
-const ( - ContentConnIdSelector byte = 0x00 - ContentRawSelector byte = 0x01 - ContentEnrsSelector byte = 0x02 -) - -const ( - ContentKeysLimit = 64 - // OfferMessageOverhead overhead of content message is a result of 1byte for kind enum, and - // 4 bytes for offset in ssz serialization - OfferMessageOverhead = 5 - - // PerContentKeyOverhead each key in ContentKeysList has uint32 offset which results in 4 bytes per - // key overhead when serialized - PerContentKeyOverhead = 4 -) - -// Protocol IDs for the portal protocol. -// var ( -// StateNetwork = []byte{0x50, 0x0a} -// HistoryNetwork = []byte{0x50, 0x0b} -// TxGossipNetwork = []byte{0x50, 0x0c} -// HeaderGossipNetwork = []byte{0x50, 0x0d} -// CanonicalIndicesNetwork = []byte{0x50, 0x0e} -// BeaconLightClientNetwork = []byte{0x50, 0x1a} -// UTPNetwork = []byte{0x75, 0x74, 0x70} -// Rendezvous = []byte{0x72, 0x65, 0x6e} -// ) - -type ProtocolId []byte - -var ( - State ProtocolId = []byte{0x50, 0x0A} - History ProtocolId = []byte{0x50, 0x0B} - Beacon ProtocolId = []byte{0x50, 0x0C} - CanonicalIndices ProtocolId = []byte{0x50, 0x0D} - VerkleState ProtocolId = []byte{0x50, 0x0E} - TransactionGossip ProtocolId = []byte{0x50, 0x0F} - Utp ProtocolId = []byte{0x75, 0x74, 0x70} -) - -var protocalName = map[string]string{ - string(State): "state", - string(History): "history", - string(Beacon): "beacon", - string(CanonicalIndices): "canonical indices", - string(VerkleState): "verkle state", - string(TransactionGossip): "transaction gossip", -} - -func (p ProtocolId) Name() string { - return protocalName[string(p)] -} - -// const ( -// HistoryNetworkName = "history" -// BeaconNetworkName = "beacon" -// StateNetworkName = "state" -// ) - -// var NetworkNameMap = map[string]string{ -// string(StateNetwork): StateNetworkName, -// string(HistoryNetwork): HistoryNetworkName, -// string(BeaconLightClientNetwork): BeaconNetworkName, -// } - -type ContentKV struct { - ContentKey []byte - Content []byte -} - -// Request messages for the portal protocol. -type ( - PingPongCustomData struct { - Radius []byte `ssz-size:"32"` - } - - Ping struct { - EnrSeq uint64 - CustomPayload []byte `ssz-max:"2048"` - } - - FindNodes struct { - Distances [][2]byte `ssz-max:"256,2" ssz-size:"?,2"` - } - - FindContent struct { - ContentKey []byte `ssz-max:"2048"` - } - - Offer struct { - ContentKeys [][]byte `ssz-max:"64,2048"` - } -) - -// Response messages for the portal protocol. -type ( - Pong struct { - EnrSeq uint64 - CustomPayload []byte `ssz-max:"2048"` - } - - Nodes struct { - Total uint8 - Enrs [][]byte `ssz-max:"32,2048"` - } - - ConnectionId struct { - Id []byte `ssz-size:"2"` - } - - Content struct { - Content []byte `ssz-max:"2048"` - } - - Enrs struct { - Enrs [][]byte `ssz-max:"32,2048"` - } - - Accept struct { - ConnectionId []byte `ssz-size:"2"` - ContentKeys []byte `ssz:"bitlist" ssz-max:"64"` - } -) - -// MarshalSSZ ssz marshals the Content object -func (c *Content) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(c) -} - -// MarshalSSZTo ssz marshals the Content object to a target array -func (c *Content) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - - // Field (0) 'Content' - if size := len(c.Content); size > 2048 { - err = ssz.ErrBytesLengthFn("Content.Content", size, 2048) - return - } - dst = append(dst, c.Content...) 
- - return -} - -// UnmarshalSSZ ssz unmarshals the Content object -func (c *Content) UnmarshalSSZ(buf []byte) error { - var err error - tail := buf - - // Field (0) 'Content' - { - buf = tail[:] - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(c.Content) == 0 { - c.Content = make([]byte, 0, len(buf)) - } - c.Content = append(c.Content, buf...) - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Content object -func (c *Content) SizeSSZ() (size int) { - // Field (0) 'Content' - return len(c.Content) -} - -// HashTreeRoot ssz hashes the Content object -func (c *Content) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(c) -} - -// HashTreeRootWith ssz hashes the Content object with a hasher -func (c *Content) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'Content' - { - elemIndx := hh.Index() - byteLen := uint64(len(c.Content)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.Append(c.Content) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Content object -func (c *Content) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(c) -} - -// MarshalSSZ ssz marshals the Enrs object -func (e *Enrs) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(e) -} - -// MarshalSSZTo ssz marshals the Enrs object to a target array -func (e *Enrs) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(0) - - // Field (0) 'Enrs' - if size := len(e.Enrs); size > 32 { - err = ssz.ErrListTooBigFn("Enrs.Enrs", size, 32) - return - } - { - offset = 4 * len(e.Enrs) - for ii := 0; ii < len(e.Enrs); ii++ { - dst = ssz.WriteOffset(dst, offset) - offset += len(e.Enrs[ii]) - } - } - for ii := 0; ii < len(e.Enrs); ii++ { - if size := len(e.Enrs[ii]); size > 2048 { - err = ssz.ErrBytesLengthFn("Enrs.Enrs[ii]", size, 2048) - return - } - dst = append(dst, e.Enrs[ii]...) - } - - return -} - -// UnmarshalSSZ ssz unmarshals the Enrs object -func (e *Enrs) UnmarshalSSZ(buf []byte) error { - var err error - tail := buf - // Field (0) 'Enrs' - { - buf = tail[:] - num, err := ssz.DecodeDynamicLength(buf, 32) - if err != nil { - return err - } - e.Enrs = make([][]byte, num) - err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(e.Enrs[indx]) == 0 { - e.Enrs[indx] = make([]byte, 0, len(buf)) - } - e.Enrs[indx] = append(e.Enrs[indx], buf...) 
- return nil - }) - if err != nil { - return err - } - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Enrs object -func (e *Enrs) SizeSSZ() (size int) { - size = 0 - - // Field (0) 'Enrs' - for ii := 0; ii < len(e.Enrs); ii++ { - size += 4 - size += len(e.Enrs[ii]) - } - - return -} - -// HashTreeRoot ssz hashes the Enrs object -func (e *Enrs) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(e) -} - -// HashTreeRootWith ssz hashes the Enrs object with a hasher -func (e *Enrs) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'Enrs' - { - subIndx := hh.Index() - num := uint64(len(e.Enrs)) - if num > 32 { - err = ssz.ErrIncorrectListSize - return - } - for _, elem := range e.Enrs { - { - elemIndx := hh.Index() - byteLen := uint64(len(elem)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.AppendBytes32(elem) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - } - hh.MerkleizeWithMixin(subIndx, num, 32) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Enrs object -func (e *Enrs) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(e) -} diff --git a/portalnetwork/portalwire/messages_encoding.go b/portalnetwork/portalwire/messages_encoding.go deleted file mode 100644 index 601150baff1a..000000000000 --- a/portalnetwork/portalwire/messages_encoding.go +++ /dev/null @@ -1,957 +0,0 @@ -// Code generated by fastssz. DO NOT EDIT. -// Hash: 26a61b12807ff78c64a029acdd5bcb580dfe35b7bfbf8bf04ceebae1a3d5cac1 -// Version: 0.1.3 -package portalwire - -import ( - ssz "github.com/ferranbt/fastssz" -) - -// MarshalSSZ ssz marshals the PingPongCustomData object -func (p *PingPongCustomData) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(p) -} - -// MarshalSSZTo ssz marshals the PingPongCustomData object to a target array -func (p *PingPongCustomData) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - - // Field (0) 'Radius' - if size := len(p.Radius); size != 32 { - err = ssz.ErrBytesLengthFn("PingPongCustomData.Radius", size, 32) - return - } - dst = append(dst, p.Radius...) - - return -} - -// UnmarshalSSZ ssz unmarshals the PingPongCustomData object -func (p *PingPongCustomData) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size != 32 { - return ssz.ErrSize - } - - // Field (0) 'Radius' - if cap(p.Radius) == 0 { - p.Radius = make([]byte, 0, len(buf[0:32])) - } - p.Radius = append(p.Radius, buf[0:32]...) 
- - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the PingPongCustomData object -func (p *PingPongCustomData) SizeSSZ() (size int) { - size = 32 - return -} - -// HashTreeRoot ssz hashes the PingPongCustomData object -func (p *PingPongCustomData) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(p) -} - -// HashTreeRootWith ssz hashes the PingPongCustomData object with a hasher -func (p *PingPongCustomData) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'Radius' - if size := len(p.Radius); size != 32 { - err = ssz.ErrBytesLengthFn("PingPongCustomData.Radius", size, 32) - return - } - hh.PutBytes(p.Radius) - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the PingPongCustomData object -func (p *PingPongCustomData) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(p) -} - -// MarshalSSZ ssz marshals the Ping object -func (p *Ping) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(p) -} - -// MarshalSSZTo ssz marshals the Ping object to a target array -func (p *Ping) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(12) - - // Field (0) 'EnrSeq' - dst = ssz.MarshalUint64(dst, p.EnrSeq) - - // Offset (1) 'CustomPayload' - dst = ssz.WriteOffset(dst, offset) - offset += len(p.CustomPayload) - - // Field (1) 'CustomPayload' - if size := len(p.CustomPayload); size > 2048 { - err = ssz.ErrBytesLengthFn("Ping.CustomPayload", size, 2048) - return - } - dst = append(dst, p.CustomPayload...) - - return -} - -// UnmarshalSSZ ssz unmarshals the Ping object -func (p *Ping) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 12 { - return ssz.ErrSize - } - - tail := buf - var o1 uint64 - - // Field (0) 'EnrSeq' - p.EnrSeq = ssz.UnmarshallUint64(buf[0:8]) - - // Offset (1) 'CustomPayload' - if o1 = ssz.ReadOffset(buf[8:12]); o1 > size { - return ssz.ErrOffset - } - - if o1 < 12 { - return ssz.ErrInvalidVariableOffset - } - - // Field (1) 'CustomPayload' - { - buf = tail[o1:] - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(p.CustomPayload) == 0 { - p.CustomPayload = make([]byte, 0, len(buf)) - } - p.CustomPayload = append(p.CustomPayload, buf...) 
- } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Ping object -func (p *Ping) SizeSSZ() (size int) { - size = 12 - - // Field (1) 'CustomPayload' - size += len(p.CustomPayload) - - return -} - -// HashTreeRoot ssz hashes the Ping object -func (p *Ping) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(p) -} - -// HashTreeRootWith ssz hashes the Ping object with a hasher -func (p *Ping) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'EnrSeq' - hh.PutUint64(p.EnrSeq) - - // Field (1) 'CustomPayload' - { - elemIndx := hh.Index() - byteLen := uint64(len(p.CustomPayload)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.Append(p.CustomPayload) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Ping object -func (p *Ping) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(p) -} - -// MarshalSSZ ssz marshals the FindNodes object -func (f *FindNodes) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(f) -} - -// MarshalSSZTo ssz marshals the FindNodes object to a target array -func (f *FindNodes) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(4) - - // Offset (0) 'Distances' - dst = ssz.WriteOffset(dst, offset) - offset += len(f.Distances) * 2 - - // Field (0) 'Distances' - if size := len(f.Distances); size > 256 { - err = ssz.ErrListTooBigFn("FindNodes.Distances", size, 256) - return - } - for ii := 0; ii < len(f.Distances); ii++ { - dst = append(dst, f.Distances[ii][:]...) - } - - return -} - -// UnmarshalSSZ ssz unmarshals the FindNodes object -func (f *FindNodes) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 4 { - return ssz.ErrSize - } - - tail := buf - var o0 uint64 - - // Offset (0) 'Distances' - if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { - return ssz.ErrOffset - } - - if o0 < 4 { - return ssz.ErrInvalidVariableOffset - } - - // Field (0) 'Distances' - { - buf = tail[o0:] - num, err := ssz.DivideInt2(len(buf), 2, 256) - if err != nil { - return err - } - f.Distances = make([][2]byte, num) - for ii := 0; ii < num; ii++ { - copy(f.Distances[ii][:], buf[ii*2:(ii+1)*2]) - } - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the FindNodes object -func (f *FindNodes) SizeSSZ() (size int) { - size = 4 - - // Field (0) 'Distances' - size += len(f.Distances) * 2 - - return -} - -// HashTreeRoot ssz hashes the FindNodes object -func (f *FindNodes) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(f) -} - -// HashTreeRootWith ssz hashes the FindNodes object with a hasher -func (f *FindNodes) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'Distances' - { - if size := len(f.Distances); size > 256 { - err = ssz.ErrListTooBigFn("FindNodes.Distances", size, 256) - return - } - subIndx := hh.Index() - for _, i := range f.Distances { - hh.PutBytes(i[:]) - } - numItems := uint64(len(f.Distances)) - hh.MerkleizeWithMixin(subIndx, numItems, 256) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the FindNodes object -func (f *FindNodes) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(f) -} - -// MarshalSSZ ssz marshals the FindContent object -func (f *FindContent) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(f) -} - -// MarshalSSZTo ssz marshals the FindContent object to a target array -func (f *FindContent) MarshalSSZTo(buf 
[]byte) (dst []byte, err error) { - dst = buf - offset := int(4) - - // Offset (0) 'ContentKey' - dst = ssz.WriteOffset(dst, offset) - offset += len(f.ContentKey) - - // Field (0) 'ContentKey' - if size := len(f.ContentKey); size > 2048 { - err = ssz.ErrBytesLengthFn("FindContent.ContentKey", size, 2048) - return - } - dst = append(dst, f.ContentKey...) - - return -} - -// UnmarshalSSZ ssz unmarshals the FindContent object -func (f *FindContent) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 4 { - return ssz.ErrSize - } - - tail := buf - var o0 uint64 - - // Offset (0) 'ContentKey' - if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { - return ssz.ErrOffset - } - - if o0 < 4 { - return ssz.ErrInvalidVariableOffset - } - - // Field (0) 'ContentKey' - { - buf = tail[o0:] - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(f.ContentKey) == 0 { - f.ContentKey = make([]byte, 0, len(buf)) - } - f.ContentKey = append(f.ContentKey, buf...) - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the FindContent object -func (f *FindContent) SizeSSZ() (size int) { - size = 4 - - // Field (0) 'ContentKey' - size += len(f.ContentKey) - - return -} - -// HashTreeRoot ssz hashes the FindContent object -func (f *FindContent) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(f) -} - -// HashTreeRootWith ssz hashes the FindContent object with a hasher -func (f *FindContent) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'ContentKey' - { - elemIndx := hh.Index() - byteLen := uint64(len(f.ContentKey)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.Append(f.ContentKey) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the FindContent object -func (f *FindContent) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(f) -} - -// MarshalSSZ ssz marshals the Offer object -func (o *Offer) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(o) -} - -// MarshalSSZTo ssz marshals the Offer object to a target array -func (o *Offer) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(4) - - // Offset (0) 'ContentKeys' - dst = ssz.WriteOffset(dst, offset) - for ii := 0; ii < len(o.ContentKeys); ii++ { - offset += 4 - offset += len(o.ContentKeys[ii]) - } - - // Field (0) 'ContentKeys' - if size := len(o.ContentKeys); size > 64 { - err = ssz.ErrListTooBigFn("Offer.ContentKeys", size, 64) - return - } - { - offset = 4 * len(o.ContentKeys) - for ii := 0; ii < len(o.ContentKeys); ii++ { - dst = ssz.WriteOffset(dst, offset) - offset += len(o.ContentKeys[ii]) - } - } - for ii := 0; ii < len(o.ContentKeys); ii++ { - if size := len(o.ContentKeys[ii]); size > 2048 { - err = ssz.ErrBytesLengthFn("Offer.ContentKeys[ii]", size, 2048) - return - } - dst = append(dst, o.ContentKeys[ii]...) 
- } - - return -} - -// UnmarshalSSZ ssz unmarshals the Offer object -func (o *Offer) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 4 { - return ssz.ErrSize - } - - tail := buf - var o0 uint64 - - // Offset (0) 'ContentKeys' - if o0 = ssz.ReadOffset(buf[0:4]); o0 > size { - return ssz.ErrOffset - } - - if o0 < 4 { - return ssz.ErrInvalidVariableOffset - } - - // Field (0) 'ContentKeys' - { - buf = tail[o0:] - num, err := ssz.DecodeDynamicLength(buf, 64) - if err != nil { - return err - } - o.ContentKeys = make([][]byte, num) - err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(o.ContentKeys[indx]) == 0 { - o.ContentKeys[indx] = make([]byte, 0, len(buf)) - } - o.ContentKeys[indx] = append(o.ContentKeys[indx], buf...) - return nil - }) - if err != nil { - return err - } - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Offer object -func (o *Offer) SizeSSZ() (size int) { - size = 4 - - // Field (0) 'ContentKeys' - for ii := 0; ii < len(o.ContentKeys); ii++ { - size += 4 - size += len(o.ContentKeys[ii]) - } - - return -} - -// HashTreeRoot ssz hashes the Offer object -func (o *Offer) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(o) -} - -// HashTreeRootWith ssz hashes the Offer object with a hasher -func (o *Offer) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'ContentKeys' - { - subIndx := hh.Index() - num := uint64(len(o.ContentKeys)) - if num > 64 { - err = ssz.ErrIncorrectListSize - return - } - for _, elem := range o.ContentKeys { - { - elemIndx := hh.Index() - byteLen := uint64(len(elem)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.AppendBytes32(elem) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - } - hh.MerkleizeWithMixin(subIndx, num, 64) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Offer object -func (o *Offer) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(o) -} - -// MarshalSSZ ssz marshals the Pong object -func (p *Pong) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(p) -} - -// MarshalSSZTo ssz marshals the Pong object to a target array -func (p *Pong) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(12) - - // Field (0) 'EnrSeq' - dst = ssz.MarshalUint64(dst, p.EnrSeq) - - // Offset (1) 'CustomPayload' - dst = ssz.WriteOffset(dst, offset) - offset += len(p.CustomPayload) - - // Field (1) 'CustomPayload' - if size := len(p.CustomPayload); size > 2048 { - err = ssz.ErrBytesLengthFn("Pong.CustomPayload", size, 2048) - return - } - dst = append(dst, p.CustomPayload...) - - return -} - -// UnmarshalSSZ ssz unmarshals the Pong object -func (p *Pong) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 12 { - return ssz.ErrSize - } - - tail := buf - var o1 uint64 - - // Field (0) 'EnrSeq' - p.EnrSeq = ssz.UnmarshallUint64(buf[0:8]) - - // Offset (1) 'CustomPayload' - if o1 = ssz.ReadOffset(buf[8:12]); o1 > size { - return ssz.ErrOffset - } - - if o1 < 12 { - return ssz.ErrInvalidVariableOffset - } - - // Field (1) 'CustomPayload' - { - buf = tail[o1:] - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(p.CustomPayload) == 0 { - p.CustomPayload = make([]byte, 0, len(buf)) - } - p.CustomPayload = append(p.CustomPayload, buf...) 
- } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Pong object -func (p *Pong) SizeSSZ() (size int) { - size = 12 - - // Field (1) 'CustomPayload' - size += len(p.CustomPayload) - - return -} - -// HashTreeRoot ssz hashes the Pong object -func (p *Pong) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(p) -} - -// HashTreeRootWith ssz hashes the Pong object with a hasher -func (p *Pong) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'EnrSeq' - hh.PutUint64(p.EnrSeq) - - // Field (1) 'CustomPayload' - { - elemIndx := hh.Index() - byteLen := uint64(len(p.CustomPayload)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.Append(p.CustomPayload) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Pong object -func (p *Pong) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(p) -} - -// MarshalSSZ ssz marshals the Nodes object -func (n *Nodes) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(n) -} - -// MarshalSSZTo ssz marshals the Nodes object to a target array -func (n *Nodes) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(5) - - // Field (0) 'Total' - dst = ssz.MarshalUint8(dst, n.Total) - - // Offset (1) 'Enrs' - dst = ssz.WriteOffset(dst, offset) - for ii := 0; ii < len(n.Enrs); ii++ { - offset += 4 - offset += len(n.Enrs[ii]) - } - - // Field (1) 'Enrs' - if size := len(n.Enrs); size > 32 { - err = ssz.ErrListTooBigFn("Nodes.Enrs", size, 32) - return - } - { - offset = 4 * len(n.Enrs) - for ii := 0; ii < len(n.Enrs); ii++ { - dst = ssz.WriteOffset(dst, offset) - offset += len(n.Enrs[ii]) - } - } - for ii := 0; ii < len(n.Enrs); ii++ { - if size := len(n.Enrs[ii]); size > 2048 { - err = ssz.ErrBytesLengthFn("Nodes.Enrs[ii]", size, 2048) - return - } - dst = append(dst, n.Enrs[ii]...) - } - - return -} - -// UnmarshalSSZ ssz unmarshals the Nodes object -func (n *Nodes) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 5 { - return ssz.ErrSize - } - - tail := buf - var o1 uint64 - - // Field (0) 'Total' - n.Total = ssz.UnmarshallUint8(buf[0:1]) - - // Offset (1) 'Enrs' - if o1 = ssz.ReadOffset(buf[1:5]); o1 > size { - return ssz.ErrOffset - } - - if o1 < 5 { - return ssz.ErrInvalidVariableOffset - } - - // Field (1) 'Enrs' - { - buf = tail[o1:] - num, err := ssz.DecodeDynamicLength(buf, 32) - if err != nil { - return err - } - n.Enrs = make([][]byte, num) - err = ssz.UnmarshalDynamic(buf, num, func(indx int, buf []byte) (err error) { - if len(buf) > 2048 { - return ssz.ErrBytesLength - } - if cap(n.Enrs[indx]) == 0 { - n.Enrs[indx] = make([]byte, 0, len(buf)) - } - n.Enrs[indx] = append(n.Enrs[indx], buf...) 
- return nil - }) - if err != nil { - return err - } - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Nodes object -func (n *Nodes) SizeSSZ() (size int) { - size = 5 - - // Field (1) 'Enrs' - for ii := 0; ii < len(n.Enrs); ii++ { - size += 4 - size += len(n.Enrs[ii]) - } - - return -} - -// HashTreeRoot ssz hashes the Nodes object -func (n *Nodes) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(n) -} - -// HashTreeRootWith ssz hashes the Nodes object with a hasher -func (n *Nodes) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'Total' - hh.PutUint8(n.Total) - - // Field (1) 'Enrs' - { - subIndx := hh.Index() - num := uint64(len(n.Enrs)) - if num > 32 { - err = ssz.ErrIncorrectListSize - return - } - for _, elem := range n.Enrs { - { - elemIndx := hh.Index() - byteLen := uint64(len(elem)) - if byteLen > 2048 { - err = ssz.ErrIncorrectListSize - return - } - hh.AppendBytes32(elem) - hh.MerkleizeWithMixin(elemIndx, byteLen, (2048+31)/32) - } - } - hh.MerkleizeWithMixin(subIndx, num, 32) - } - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Nodes object -func (n *Nodes) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(n) -} - -// MarshalSSZ ssz marshals the ConnectionId object -func (c *ConnectionId) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(c) -} - -// MarshalSSZTo ssz marshals the ConnectionId object to a target array -func (c *ConnectionId) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - - // Field (0) 'Id' - if size := len(c.Id); size != 2 { - err = ssz.ErrBytesLengthFn("ConnectionId.Id", size, 2) - return - } - dst = append(dst, c.Id...) - - return -} - -// UnmarshalSSZ ssz unmarshals the ConnectionId object -func (c *ConnectionId) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size != 2 { - return ssz.ErrSize - } - - // Field (0) 'Id' - if cap(c.Id) == 0 { - c.Id = make([]byte, 0, len(buf[0:2])) - } - c.Id = append(c.Id, buf[0:2]...) - - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the ConnectionId object -func (c *ConnectionId) SizeSSZ() (size int) { - size = 2 - return -} - -// HashTreeRoot ssz hashes the ConnectionId object -func (c *ConnectionId) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(c) -} - -// HashTreeRootWith ssz hashes the ConnectionId object with a hasher -func (c *ConnectionId) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'Id' - if size := len(c.Id); size != 2 { - err = ssz.ErrBytesLengthFn("ConnectionId.Id", size, 2) - return - } - hh.PutBytes(c.Id) - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the ConnectionId object -func (c *ConnectionId) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(c) -} - -// MarshalSSZ ssz marshals the Accept object -func (a *Accept) MarshalSSZ() ([]byte, error) { - return ssz.MarshalSSZ(a) -} - -// MarshalSSZTo ssz marshals the Accept object to a target array -func (a *Accept) MarshalSSZTo(buf []byte) (dst []byte, err error) { - dst = buf - offset := int(6) - - // Field (0) 'ConnectionId' - if size := len(a.ConnectionId); size != 2 { - err = ssz.ErrBytesLengthFn("Accept.ConnectionId", size, 2) - return - } - dst = append(dst, a.ConnectionId...) 
- - // Offset (1) 'ContentKeys' - dst = ssz.WriteOffset(dst, offset) - offset += len(a.ContentKeys) - - // Field (1) 'ContentKeys' - if size := len(a.ContentKeys); size > 64 { - err = ssz.ErrBytesLengthFn("Accept.ContentKeys", size, 64) - return - } - dst = append(dst, a.ContentKeys...) - - return -} - -// UnmarshalSSZ ssz unmarshals the Accept object -func (a *Accept) UnmarshalSSZ(buf []byte) error { - var err error - size := uint64(len(buf)) - if size < 6 { - return ssz.ErrSize - } - - tail := buf - var o1 uint64 - - // Field (0) 'ConnectionId' - if cap(a.ConnectionId) == 0 { - a.ConnectionId = make([]byte, 0, len(buf[0:2])) - } - a.ConnectionId = append(a.ConnectionId, buf[0:2]...) - - // Offset (1) 'ContentKeys' - if o1 = ssz.ReadOffset(buf[2:6]); o1 > size { - return ssz.ErrOffset - } - - if o1 < 6 { - return ssz.ErrInvalidVariableOffset - } - - // Field (1) 'ContentKeys' - { - buf = tail[o1:] - if err = ssz.ValidateBitlist(buf, 64); err != nil { - return err - } - if cap(a.ContentKeys) == 0 { - a.ContentKeys = make([]byte, 0, len(buf)) - } - a.ContentKeys = append(a.ContentKeys, buf...) - } - return err -} - -// SizeSSZ returns the ssz encoded size in bytes for the Accept object -func (a *Accept) SizeSSZ() (size int) { - size = 6 - - // Field (1) 'ContentKeys' - size += len(a.ContentKeys) - - return -} - -// HashTreeRoot ssz hashes the Accept object -func (a *Accept) HashTreeRoot() ([32]byte, error) { - return ssz.HashWithDefaultHasher(a) -} - -// HashTreeRootWith ssz hashes the Accept object with a hasher -func (a *Accept) HashTreeRootWith(hh ssz.HashWalker) (err error) { - indx := hh.Index() - - // Field (0) 'ConnectionId' - if size := len(a.ConnectionId); size != 2 { - err = ssz.ErrBytesLengthFn("Accept.ConnectionId", size, 2) - return - } - hh.PutBytes(a.ConnectionId) - - // Field (1) 'ContentKeys' - if len(a.ContentKeys) == 0 { - err = ssz.ErrEmptyBitlist - return - } - hh.PutBitlist(a.ContentKeys, 64) - - hh.Merkleize(indx) - return -} - -// GetTree ssz hashes the Accept object -func (a *Accept) GetTree() (*ssz.Node, error) { - return ssz.ProofTree(a) -} diff --git a/portalnetwork/portalwire/messages_test.go b/portalnetwork/portalwire/messages_test.go deleted file mode 100644 index 9e266cf41789..000000000000 --- a/portalnetwork/portalwire/messages_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package portalwire - -import ( - "fmt" - "testing" - - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/rlp" - ssz "github.com/ferranbt/fastssz" - "github.com/holiman/uint256" - "github.com/prysmaticlabs/go-bitfield" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var maxUint256 = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") - -// https://github.com/ethereum/portal-network-specs/blob/master/portal-wire-test-vectors.md -// we remove the message type here -func TestPingMessage(t *testing.T) { - dataRadius := maxUint256.Sub(maxUint256, uint256.NewInt(1)) - reverseBytes, err := dataRadius.MarshalSSZ() - require.NoError(t, err) - customData := &PingPongCustomData{ - Radius: reverseBytes, - } - dataBytes, err := customData.MarshalSSZ() - assert.NoError(t, err) - ping := &Ping{ - EnrSeq: 1, - CustomPayload: dataBytes, - } - - expected := "0x01000000000000000c000000feffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" - - data, err := ping.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) 
-} - -func TestPongMessage(t *testing.T) { - dataRadius := maxUint256.Div(maxUint256, uint256.NewInt(2)) - reverseBytes, err := dataRadius.MarshalSSZ() - require.NoError(t, err) - customData := &PingPongCustomData{ - Radius: reverseBytes, - } - - dataBytes, err := customData.MarshalSSZ() - assert.NoError(t, err) - pong := &Pong{ - EnrSeq: 1, - CustomPayload: dataBytes, - } - - expected := "0x01000000000000000c000000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" - - data, err := pong.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) -} - -func TestFindNodesMessage(t *testing.T) { - distances := []uint16{256, 255} - - distancesBytes := make([][2]byte, len(distances)) - for i, distance := range distances { - copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), distance)) - } - - findNode := &FindNodes{ - Distances: distancesBytes, - } - - data, err := findNode.MarshalSSZ() - expected := "0x040000000001ff00" - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) -} - -func TestNodes(t *testing.T) { - enrs := []string{ - "enr:-HW4QBzimRxkmT18hMKaAL3IcZF1UcfTMPyi3Q1pxwZZbcZVRI8DC5infUAB_UauARLOJtYTxaagKoGmIjzQxO2qUygBgmlkgnY0iXNlY3AyNTZrMaEDymNMrg1JrLQB2KTGtv6MVbcNEVv0AHacwUAPMljNMTg", - "enr:-HW4QNfxw543Ypf4HXKXdYxkyzfcxcO-6p9X986WldfVpnVTQX1xlTnWrktEWUbeTZnmgOuAY_KUhbVV1Ft98WoYUBMBgmlkgnY0iXNlY3AyNTZrMaEDDiy3QkHAxPyOgWbxp5oF1bDdlYE6dLCUUp8xfVw50jU", - } - - enrsBytes := make([][]byte, 0) - for _, enr := range enrs { - n, err := enode.Parse(enode.ValidSchemes, enr) - assert.NoError(t, err) - - enrBytes, err := rlp.EncodeToBytes(n.Record()) - assert.NoError(t, err) - enrsBytes = append(enrsBytes, enrBytes) - } - - testCases := []struct { - name string - input [][]byte - expected string - }{ - { - name: "empty nodes", - input: make([][]byte, 0), - expected: "0x0105000000", - }, - { - name: "two nodes", - input: enrsBytes, - expected: "0x0105000000080000007f000000f875b8401ce2991c64993d7c84c29a00bdc871917551c7d330fca2dd0d69c706596dc655448f030b98a77d4001fd46ae0112ce26d613c5a6a02a81a6223cd0c4edaa53280182696482763489736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138f875b840d7f1c39e376297f81d7297758c64cb37dcc5c3beea9f57f7ce9695d7d5a67553417d719539d6ae4b445946de4d99e680eb8063f29485b555d45b7df16a1850130182696482763489736563703235366b31a1030e2cb74241c0c4fc8e8166f1a79a05d5b0dd95813a74b094529f317d5c39d235", - }, - } - - for _, test := range testCases { - t.Run(test.name, func(t *testing.T) { - nodes := &Nodes{ - Total: 1, - Enrs: test.input, - } - - data, err := nodes.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, test.expected, fmt.Sprintf("0x%x", data)) - }) - } -} - -func TestContent(t *testing.T) { - contentKey := "0x706f7274616c" - - content := &FindContent{ - ContentKey: hexutil.MustDecode(contentKey), - } - expected := "0x04000000706f7274616c" - data, err := content.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) - - expected = "0x7468652063616b652069732061206c6965" - - contentRes := &Content{ - Content: hexutil.MustDecode("0x7468652063616b652069732061206c6965"), - } - - data, err = contentRes.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) - - expectData := &Content{} - err = expectData.UnmarshalSSZ(data) - assert.NoError(t, err) - assert.Equal(t, contentRes.Content, expectData.Content) - - enrs := []string{ - 
"enr:-HW4QBzimRxkmT18hMKaAL3IcZF1UcfTMPyi3Q1pxwZZbcZVRI8DC5infUAB_UauARLOJtYTxaagKoGmIjzQxO2qUygBgmlkgnY0iXNlY3AyNTZrMaEDymNMrg1JrLQB2KTGtv6MVbcNEVv0AHacwUAPMljNMTg", - "enr:-HW4QNfxw543Ypf4HXKXdYxkyzfcxcO-6p9X986WldfVpnVTQX1xlTnWrktEWUbeTZnmgOuAY_KUhbVV1Ft98WoYUBMBgmlkgnY0iXNlY3AyNTZrMaEDDiy3QkHAxPyOgWbxp5oF1bDdlYE6dLCUUp8xfVw50jU", - } - - enrsBytes := make([][]byte, 0) - for _, enr := range enrs { - n, err := enode.Parse(enode.ValidSchemes, enr) - assert.NoError(t, err) - - enrBytes, err := rlp.EncodeToBytes(n.Record()) - assert.NoError(t, err) - enrsBytes = append(enrsBytes, enrBytes) - } - - enrsRes := &Enrs{ - Enrs: enrsBytes, - } - - expected = "0x080000007f000000f875b8401ce2991c64993d7c84c29a00bdc871917551c7d330fca2dd0d69c706596dc655448f030b98a77d4001fd46ae0112ce26d613c5a6a02a81a6223cd0c4edaa53280182696482763489736563703235366b31a103ca634cae0d49acb401d8a4c6b6fe8c55b70d115bf400769cc1400f3258cd3138f875b840d7f1c39e376297f81d7297758c64cb37dcc5c3beea9f57f7ce9695d7d5a67553417d719539d6ae4b445946de4d99e680eb8063f29485b555d45b7df16a1850130182696482763489736563703235366b31a1030e2cb74241c0c4fc8e8166f1a79a05d5b0dd95813a74b094529f317d5c39d235" - - data, err = enrsRes.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) - - expectEnrs := &Enrs{} - err = expectEnrs.UnmarshalSSZ(data) - assert.NoError(t, err) - assert.Equal(t, expectEnrs.Enrs, enrsRes.Enrs) -} - -func TestOfferAndAcceptMessage(t *testing.T) { - contentKey := "0x010203" - contentBytes := hexutil.MustDecode(contentKey) - contentKeys := [][]byte{contentBytes} - offer := &Offer{ - ContentKeys: contentKeys, - } - - expected := "0x0400000004000000010203" - - data, err := offer.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) - - contentKeyBitlist := bitfield.NewBitlist(8) - contentKeyBitlist.SetBitAt(0, true) - accept := &Accept{ - ConnectionId: []byte{0x01, 0x02}, - ContentKeys: contentKeyBitlist, - } - - expected = "0x0102060000000101" - - data, err = accept.MarshalSSZ() - assert.NoError(t, err) - assert.Equal(t, expected, fmt.Sprintf("0x%x", data)) -} diff --git a/p2p/discover/nat.go b/portalnetwork/portalwire/nat.go similarity index 99% rename from p2p/discover/nat.go rename to portalnetwork/portalwire/nat.go index 0202ef3c6637..6e70e72a5686 100644 --- a/p2p/discover/nat.go +++ b/portalnetwork/portalwire/nat.go @@ -1,4 +1,4 @@ -package discover +package portalwire import ( "net" diff --git a/portalnetwork/portal_protocol.go b/portalnetwork/portalwire/portal_protocol.go similarity index 93% rename from portalnetwork/portal_protocol.go rename to portalnetwork/portalwire/portal_protocol.go index 126acb82a7ee..03c89ec81f47 100644 --- a/portalnetwork/portal_protocol.go +++ b/portalnetwork/portalwire/portal_protocol.go @@ -1,4 +1,4 @@ -package portalnetwork +package portalwire import ( "bytes" @@ -31,7 +31,6 @@ import ( "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/nat" "github.com/ethereum/go-ethereum/p2p/netutil" - "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/ethereum/go-ethereum/rlp" ssz "github.com/ferranbt/fastssz" @@ -142,8 +141,7 @@ type traceContentInfoResp struct { type PortalProtocolOption func(p *PortalProtocol) type PortalProtocolConfig struct { - BootstrapNodes []*enode.Node - // NodeIP net.IP + BootstrapNodes []*enode.Node ListenAddr string NetRestrict *netutil.Netlist NodeRadius *uint256.Int @@ -203,7 +201,7 @@ func 
defaultContentIdFunc(contentKey []byte) []byte { return digest[:] } -func NewPortalProtocol(config *PortalProtocolConfig, protocolId portalwire.ProtocolId, privateKey *ecdsa.PrivateKey, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *PortalUtp, storage storage.ContentStorage, contentQueue chan *ContentElement, opts ...PortalProtocolOption) (*PortalProtocol, error) { +func NewPortalProtocol(config *PortalProtocolConfig, protocolId ProtocolId, privateKey *ecdsa.PrivateKey, conn discover.UDPConn, localNode *enode.LocalNode, discV5 *discover.UDPv5, utp *PortalUtp, storage storage.ContentStorage, contentQueue chan *ContentElement, opts ...PortalProtocolOption) (*PortalProtocol, error) { closeCtx, cancelCloseCtx := context.WithCancel(context.Background()) protocol := &PortalProtocol{ @@ -340,13 +338,13 @@ func (p *PortalProtocol) Ping(node *enode.Node) (uint64, error) { return pong.EnrSeq, nil } -func (p *PortalProtocol) pingInner(node *enode.Node) (*portalwire.Pong, error) { +func (p *PortalProtocol) pingInner(node *enode.Node) (*Pong, error) { enrSeq := p.Self().Seq() radiusBytes, err := p.Radius().MarshalSSZ() if err != nil { return nil, err } - customPayload := &portalwire.PingPongCustomData{ + customPayload := &PingPongCustomData{ Radius: radiusBytes, } @@ -355,7 +353,7 @@ func (p *PortalProtocol) pingInner(node *enode.Node) (*portalwire.Pong, error) { return nil, err } - pingRequest := &portalwire.Ping{ + pingRequest := &Ping{ EnrSeq: enrSeq, CustomPayload: customPayloadBytes, } @@ -370,7 +368,7 @@ func (p *PortalProtocol) pingInner(node *enode.Node) (*portalwire.Pong, error) { } talkRequestBytes := make([]byte, 0, len(pingRequestBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.PING) + talkRequestBytes = append(talkRequestBytes, PING) talkRequestBytes = append(talkRequestBytes, pingRequestBytes...) talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) @@ -397,7 +395,7 @@ func (p *PortalProtocol) findNodes(node *enode.Node, distances []uint) ([]*enode copy(distancesBytes[i][:], ssz.MarshalUint16(make([]byte, 0), uint16(distance))) } - findNodes := &portalwire.FindNodes{ + findNodes := &FindNodes{ Distances: distancesBytes, } @@ -412,7 +410,7 @@ func (p *PortalProtocol) findNodes(node *enode.Node, distances []uint) ([]*enode } talkRequestBytes := make([]byte, 0, len(findNodesBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.FINDNODES) + talkRequestBytes = append(talkRequestBytes, FINDNODES) talkRequestBytes = append(talkRequestBytes, findNodesBytes...) talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) @@ -425,7 +423,7 @@ func (p *PortalProtocol) findNodes(node *enode.Node, distances []uint) ([]*enode } func (p *PortalProtocol) findContent(node *enode.Node, contentKey []byte) (byte, interface{}, error) { - findContent := &portalwire.FindContent{ + findContent := &FindContent{ ContentKey: contentKey, } @@ -440,7 +438,7 @@ func (p *PortalProtocol) findContent(node *enode.Node, contentKey []byte) (byte, } talkRequestBytes := make([]byte, 0, len(findContentBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.FINDCONTENT) + talkRequestBytes = append(talkRequestBytes, FINDCONTENT) talkRequestBytes = append(talkRequestBytes, findContentBytes...) 
talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) @@ -455,7 +453,7 @@ func (p *PortalProtocol) findContent(node *enode.Node, contentKey []byte) (byte, func (p *PortalProtocol) offer(node *enode.Node, offerRequest *OfferRequest) ([]byte, error) { contentKeys := getContentKeys(offerRequest) - offer := &portalwire.Offer{ + offer := &Offer{ ContentKeys: contentKeys, } @@ -470,7 +468,7 @@ func (p *PortalProtocol) offer(node *enode.Node, offerRequest *OfferRequest) ([] } talkRequestBytes := make([]byte, 0, len(offerBytes)+1) - talkRequestBytes = append(talkRequestBytes, portalwire.OFFER) + talkRequestBytes = append(talkRequestBytes, OFFER) talkRequestBytes = append(talkRequestBytes, offerBytes...) talkResp, err := p.DiscV5.TalkRequest(node, p.protocolId, talkRequestBytes) @@ -487,13 +485,13 @@ func (p *PortalProtocol) processOffer(target *enode.Node, resp []byte, request * if len(resp) == 0 { return nil, ErrEmptyResp } - if resp[0] != portalwire.ACCEPT { + if resp[0] != ACCEPT { return nil, fmt.Errorf("invalid accept response") } p.Log.Info("will process Offer", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) - accept := &portalwire.Accept{} + accept := &Accept{} err = accept.UnmarshalSSZ(resp[1:]) if err != nil { return nil, err @@ -621,15 +619,15 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, return 0x00, nil, ErrEmptyResp } - if resp[0] != portalwire.CONTENT { + if resp[0] != CONTENT { return 0xff, nil, fmt.Errorf("invalid content response") } p.Log.Info("will process content", "id", target.ID(), "ip", target.IP().To4().String(), "port", target.UDP()) switch resp[1] { - case portalwire.ContentRawSelector: - content := &portalwire.Content{} + case ContentRawSelector: + content := &Content{} err := content.UnmarshalSSZ(resp[2:]) if err != nil { return 0xff, nil, err @@ -646,8 +644,8 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, log.Debug("Node added to replacements list", "protocol", p.protocolName, "node", target.IP(), "port", target.UDP()) } return resp[1], content.Content, nil - case portalwire.ContentConnIdSelector: - connIdMsg := &portalwire.ConnectionId{} + case ContentConnIdSelector: + connIdMsg := &ConnectionId{} err := connIdMsg.UnmarshalSSZ(resp[2:]) if err != nil { return 0xff, nil, err @@ -705,8 +703,8 @@ func (p *PortalProtocol) processContent(target *enode.Node, resp []byte) (byte, p.portalMetrics.utpInSuccess.Inc(1) } return resp[1], data, nil - case portalwire.ContentEnrsSelector: - enrs := &portalwire.Enrs{} + case ContentEnrsSelector: + enrs := &Enrs{} err := enrs.UnmarshalSSZ(resp[2:]) if err != nil { @@ -735,11 +733,11 @@ func (p *PortalProtocol) processNodes(target *enode.Node, resp []byte, distances return nil, ErrEmptyResp } - if resp[0] != portalwire.NODES { + if resp[0] != NODES { return nil, fmt.Errorf("invalid nodes response") } - nodesResp := &portalwire.Nodes{} + nodesResp := &Nodes{} err := nodesResp.UnmarshalSSZ(resp[1:]) if err != nil { return nil, err @@ -788,14 +786,14 @@ func (p *PortalProtocol) filterNodes(target *enode.Node, enrs [][]byte, distance return nodes } -func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwire.Pong, error) { +func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*Pong, error) { if len(resp) == 0 { return nil, ErrEmptyResp } - if resp[0] != portalwire.PONG { + if resp[0] != PONG { return nil, fmt.Errorf("invalid pong response") } - pong := &portalwire.Pong{} + pong 
:= &Pong{} err := pong.UnmarshalSSZ(resp[1:]) if err != nil { return nil, err @@ -806,7 +804,7 @@ func (p *PortalProtocol) processPong(target *enode.Node, resp []byte) (*portalwi p.portalMetrics.messagesReceivedPong.Mark(1) } - customPayload := &portalwire.PingPongCustomData{} + customPayload := &PingPongCustomData{} err = customPayload.UnmarshalSSZ(pong.CustomPayload) if err != nil { return nil, err @@ -835,8 +833,8 @@ func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg [ msgCode := msg[0] switch msgCode { - case portalwire.PING: - pingRequest := &portalwire.Ping{} + case PING: + pingRequest := &Ping{} err := pingRequest.UnmarshalSSZ(msg[1:]) if err != nil { p.Log.Error("failed to unmarshal ping request", "err", err) @@ -854,8 +852,8 @@ func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg [ } return resp - case portalwire.FINDNODES: - findNodesRequest := &portalwire.FindNodes{} + case FINDNODES: + findNodesRequest := &FindNodes{} err := findNodesRequest.UnmarshalSSZ(msg[1:]) if err != nil { p.Log.Error("failed to unmarshal find nodes request", "err", err) @@ -873,8 +871,8 @@ func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg [ } return resp - case portalwire.FINDCONTENT: - findContentRequest := &portalwire.FindContent{} + case FINDCONTENT: + findContentRequest := &FindContent{} err := findContentRequest.UnmarshalSSZ(msg[1:]) if err != nil { p.Log.Error("failed to unmarshal find content request", "err", err) @@ -892,8 +890,8 @@ func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg [ } return resp - case portalwire.OFFER: - offerRequest := &portalwire.Offer{} + case OFFER: + offerRequest := &Offer{} err := offerRequest.UnmarshalSSZ(msg[1:]) if err != nil { p.Log.Error("failed to unmarshal offer request", "err", err) @@ -916,8 +914,8 @@ func (p *PortalProtocol) handleTalkRequest(id enode.ID, addr *net.UDPAddr, msg [ return nil } -func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, error) { - pingCustomPayload := &portalwire.PingPongCustomData{} +func (p *PortalProtocol) handlePing(id enode.ID, ping *Ping) ([]byte, error) { + pingCustomPayload := &PingPongCustomData{} err := pingCustomPayload.UnmarshalSSZ(ping.CustomPayload) if err != nil { return nil, err @@ -930,7 +928,7 @@ func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, if err != nil { return nil, err } - pongCustomPayload := &portalwire.PingPongCustomData{ + pongCustomPayload := &PingPongCustomData{ Radius: radiusBytes, } @@ -939,7 +937,7 @@ func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, return nil, err } - pong := &portalwire.Pong{ + pong := &Pong{ EnrSeq: enrSeq, CustomPayload: pongCustomPayloadBytes, } @@ -955,13 +953,13 @@ func (p *PortalProtocol) handlePing(id enode.ID, ping *portalwire.Ping) ([]byte, } talkRespBytes := make([]byte, 0, len(pongBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.PONG) + talkRespBytes = append(talkRespBytes, PONG) talkRespBytes = append(talkRespBytes, pongBytes...) 
return talkRespBytes, nil } -func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *portalwire.FindNodes) ([]byte, error) { +func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *FindNodes) ([]byte, error) { distances := make([]uint, len(request.Distances)) for i, distance := range request.Distances { distances[i] = uint(ssz.UnmarshallUint16(distance[:])) @@ -975,7 +973,7 @@ func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *portalw enrs := p.truncateNodes(nodes, maxPayloadSize, enrOverhead) - nodesMsg := &portalwire.Nodes{ + nodesMsg := &Nodes{ Total: 1, Enrs: enrs, } @@ -990,13 +988,13 @@ func (p *PortalProtocol) handleFindNodes(fromAddr *net.UDPAddr, request *portalw } talkRespBytes := make([]byte, 0, len(nodesMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.NODES) + talkRespBytes = append(talkRespBytes, NODES) talkRespBytes = append(talkRespBytes, nodesMsgBytes...) return talkRespBytes, nil } -func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, request *portalwire.FindContent) ([]byte, error) { +func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, request *FindContent) ([]byte, error) { contentOverhead := 1 + 1 // msg id + SSZ Union selector maxPayloadSize := v5wire.MaxPacketSize - talkRespOverhead - contentOverhead enrOverhead := 4 //per added ENR, 4 bytes offset overhead @@ -1028,7 +1026,7 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque enrs = nil } - enrsMsg := &portalwire.Enrs{ + enrsMsg := &Enrs{ Enrs: enrs, } @@ -1043,16 +1041,16 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque } contentMsgBytes := make([]byte, 0, len(enrsMsgBytes)+1) - contentMsgBytes = append(contentMsgBytes, portalwire.ContentEnrsSelector) + contentMsgBytes = append(contentMsgBytes, ContentEnrsSelector) contentMsgBytes = append(contentMsgBytes, enrsMsgBytes...) talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.CONTENT) + talkRespBytes = append(talkRespBytes, CONTENT) talkRespBytes = append(talkRespBytes, contentMsgBytes...) return talkRespBytes, nil } else if len(content) <= maxPayloadSize { - rawContentMsg := &portalwire.Content{ + rawContentMsg := &Content{ Content: content, } @@ -1068,11 +1066,11 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque } contentMsgBytes := make([]byte, 0, len(rawContentMsgBytes)+1) - contentMsgBytes = append(contentMsgBytes, portalwire.ContentRawSelector) + contentMsgBytes = append(contentMsgBytes, ContentRawSelector) contentMsgBytes = append(contentMsgBytes, rawContentMsgBytes...) talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.CONTENT) + talkRespBytes = append(talkRespBytes, CONTENT) talkRespBytes = append(talkRespBytes, contentMsgBytes...) 
return talkRespBytes, nil @@ -1140,7 +1138,7 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque idBuffer := make([]byte, 2) binary.BigEndian.PutUint16(idBuffer, connectionId.SendId()) - connIdMsg := &portalwire.ConnectionId{ + connIdMsg := &ConnectionId{ Id: idBuffer, } @@ -1155,22 +1153,22 @@ func (p *PortalProtocol) handleFindContent(id enode.ID, addr *net.UDPAddr, reque } contentMsgBytes := make([]byte, 0, len(connIdMsgBytes)+1) - contentMsgBytes = append(contentMsgBytes, portalwire.ContentConnIdSelector) + contentMsgBytes = append(contentMsgBytes, ContentConnIdSelector) contentMsgBytes = append(contentMsgBytes, connIdMsgBytes...) talkRespBytes := make([]byte, 0, len(contentMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.CONTENT) + talkRespBytes = append(talkRespBytes, CONTENT) talkRespBytes = append(talkRespBytes, contentMsgBytes...) return talkRespBytes, nil } } -func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *portalwire.Offer) ([]byte, error) { +func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *Offer) ([]byte, error) { var err error contentKeyBitlist := bitfield.NewBitlist(uint64(len(request.ContentKeys))) if len(p.contentQueue) >= cap(p.contentQueue) { - acceptMsg := &portalwire.Accept{ + acceptMsg := &Accept{ ConnectionId: []byte{0, 0}, ContentKeys: contentKeyBitlist, } @@ -1186,7 +1184,7 @@ func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *po } talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) + talkRespBytes = append(talkRespBytes, ACCEPT) talkRespBytes = append(talkRespBytes, acceptMsgBytes...) return talkRespBytes, nil @@ -1284,7 +1282,7 @@ func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *po binary.BigEndian.PutUint16(idBuffer, uint16(0)) } - acceptMsg := &portalwire.Accept{ + acceptMsg := &Accept{ ConnectionId: idBuffer, ContentKeys: []byte(contentKeyBitlist), } @@ -1300,7 +1298,7 @@ func (p *PortalProtocol) handleOffer(id enode.ID, addr *net.UDPAddr, request *po } talkRespBytes := make([]byte, 0, len(acceptMsgBytes)+1) - talkRespBytes = append(talkRespBytes, portalwire.ACCEPT) + talkRespBytes = append(talkRespBytes, ACCEPT) talkRespBytes = append(talkRespBytes, acceptMsgBytes...) 
return talkRespBytes, nil @@ -1581,7 +1579,7 @@ func (p *PortalProtocol) ContentLookup(contentKey, contentId []byte) ([]byte, bo go func() { defer wg.Done() for res := range resChan { - if res.Flag != portalwire.ContentEnrsSelector { + if res.Flag != ContentEnrsSelector { result.Content = res.Content.([]byte) result.UtpTransfer = res.UtpTransfer } @@ -1656,7 +1654,7 @@ func (p *PortalProtocol) TraceContentLookup(contentKey, contentId []byte) (*Trac } // no content return if traceContentRes.Content == "" { - if res.Flag == portalwire.ContentRawSelector || res.Flag == portalwire.ContentConnIdSelector { + if res.Flag == ContentRawSelector || res.Flag == ContentConnIdSelector { trace.ReceivedFrom = hexId content := res.Content.([]byte) traceContentRes.Content = hexutil.Encode(content) @@ -1710,7 +1708,7 @@ func (p *PortalProtocol) contentLookupWorker(n *enode.Node, contentKey []byte, r p.Log.Debug("traceContentLookupWorker reveice response", "ip", n.IP().String(), "flag", flag) switch flag { - case portalwire.ContentRawSelector, portalwire.ContentConnIdSelector: + case ContentRawSelector, ContentConnIdSelector: content, ok := content.([]byte) if !ok { return wrapedNode, fmt.Errorf("failed to assert to raw content, value is: %v", content) @@ -1721,7 +1719,7 @@ func (p *PortalProtocol) contentLookupWorker(n *enode.Node, contentKey []byte, r Content: content, UtpTransfer: false, } - if flag == portalwire.ContentConnIdSelector { + if flag == ContentConnIdSelector { res.UtpTransfer = true } if atomic.CompareAndSwapInt32(done, 0, 1) { @@ -1730,7 +1728,7 @@ func (p *PortalProtocol) contentLookupWorker(n *enode.Node, contentKey []byte, r cancel() } return wrapedNode, err - case portalwire.ContentEnrsSelector: + case ContentEnrsSelector: nodes, ok := content.([]*enode.Node) if !ok { return wrapedNode, fmt.Errorf("failed to assert to enrs content, value is: %v", content) @@ -1775,7 +1773,7 @@ func (p *PortalProtocol) Gossip(srcNodeId *enode.ID, contentKeys [][]byte, conte return 0, errors.New("empty content") } - contentList := make([]*ContentEntry, 0, portalwire.ContentKeysLimit) + contentList := make([]*ContentEntry, 0, ContentKeysLimit) for i := 0; i < len(content); i++ { contentEntry := &ContentEntry{ ContentKey: contentKeys[i], diff --git a/portalnetwork/portalwire/portal_protocol_metrics.go b/portalnetwork/portalwire/portal_protocol_metrics.go new file mode 100644 index 000000000000..58faf84586b3 --- /dev/null +++ b/portalnetwork/portalwire/portal_protocol_metrics.go @@ -0,0 +1,217 @@ +package portalwire + +import ( + "database/sql" + "errors" + "os" + "path" + "slices" + "strings" + "time" + + "github.com/ethereum/go-ethereum/log" + "github.com/ethereum/go-ethereum/metrics" +) + +type portalMetrics struct { + messagesReceivedAccept metrics.Meter + messagesReceivedNodes metrics.Meter + messagesReceivedFindNodes metrics.Meter + messagesReceivedFindContent metrics.Meter + messagesReceivedContent metrics.Meter + messagesReceivedOffer metrics.Meter + messagesReceivedPing metrics.Meter + messagesReceivedPong metrics.Meter + + messagesSentAccept metrics.Meter + messagesSentNodes metrics.Meter + messagesSentFindNodes metrics.Meter + messagesSentFindContent metrics.Meter + messagesSentContent metrics.Meter + messagesSentOffer metrics.Meter + messagesSentPing metrics.Meter + messagesSentPong metrics.Meter + + utpInFailConn metrics.Counter + utpInFailRead metrics.Counter + utpInFailDeadline metrics.Counter + utpInSuccess metrics.Counter + + utpOutFailConn metrics.Counter + utpOutFailWrite metrics.Counter + 
utpOutFailDeadline metrics.Counter + utpOutSuccess metrics.Counter + + contentDecodedTrue metrics.Counter + contentDecodedFalse metrics.Counter +} + +func newPortalMetrics(protocolName string) *portalMetrics { + return &portalMetrics{ + messagesReceivedAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/accept", nil), + messagesReceivedNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/nodes", nil), + messagesReceivedFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_nodes", nil), + messagesReceivedFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/find_content", nil), + messagesReceivedContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/content", nil), + messagesReceivedOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/offer", nil), + messagesReceivedPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/ping", nil), + messagesReceivedPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/received/pong", nil), + messagesSentAccept: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/accept", nil), + messagesSentNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/nodes", nil), + messagesSentFindNodes: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_nodes", nil), + messagesSentFindContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/find_content", nil), + messagesSentContent: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/content", nil), + messagesSentOffer: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/offer", nil), + messagesSentPing: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/ping", nil), + messagesSentPong: metrics.NewRegisteredMeter("portal/"+protocolName+"/sent/pong", nil), + utpInFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_conn", nil), + utpInFailRead: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_read", nil), + utpInFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/fail_deadline", nil), + utpInSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/inbound/success", nil), + utpOutFailConn: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_conn", nil), + utpOutFailWrite: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_write", nil), + utpOutFailDeadline: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/fail_deadline", nil), + utpOutSuccess: metrics.NewRegisteredCounter("portal/"+protocolName+"/utp/outbound/success", nil), + contentDecodedTrue: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/true", nil), + contentDecodedFalse: metrics.NewRegisteredCounter("portal/"+protocolName+"/content/decoded/false", nil), + } +} + +type networkFileMetric struct { + filename string + metric metrics.Gauge + file *os.File + network string +} + +type PortalStorageMetrics struct { + RadiusRatio metrics.GaugeFloat64 + EntriesCount metrics.Gauge + ContentStorageUsage metrics.Gauge +} + +const ( + countEntrySql = "SELECT COUNT(1) FROM kvstore;" + contentStorageUsageSql = "SELECT SUM( length(value) ) FROM kvstore;" +) + +// CollectPortalMetrics periodically collects various metrics about system entities. 
+func CollectPortalMetrics(refresh time.Duration, networks []string, dataDir string) {
+	// Short circuit if the metrics system is disabled
+	if !metrics.Enabled {
+		return
+	}
+
+	// Define the various metrics to collect
+	var (
+		historyTotalStorage = metrics.GetOrRegisterGauge("portal/history/total_storage", nil)
+		beaconTotalStorage  = metrics.GetOrRegisterGauge("portal/beacon/total_storage", nil)
+		stateTotalStorage   = metrics.GetOrRegisterGauge("portal/state/total_storage", nil)
+	)
+
+	var metricsArr []*networkFileMetric
+	if slices.Contains(networks, History.Name()) {
+		dbPath := path.Join(dataDir, History.Name())
+		metricsArr = append(metricsArr, &networkFileMetric{
+			filename: path.Join(dbPath, History.Name()+".sqlite"),
+			metric:   historyTotalStorage,
+			network:  History.Name(),
+		})
+	}
+	if slices.Contains(networks, Beacon.Name()) {
+		dbPath := path.Join(dataDir, Beacon.Name())
+		metricsArr = append(metricsArr, &networkFileMetric{
+			filename: path.Join(dbPath, Beacon.Name()+".sqlite"),
+			metric:   beaconTotalStorage,
+			network:  Beacon.Name(),
+		})
+	}
+	if slices.Contains(networks, State.Name()) {
+		dbPath := path.Join(dataDir, State.Name())
+		metricsArr = append(metricsArr, &networkFileMetric{
+			filename: path.Join(dbPath, State.Name()+".sqlite"),
+			metric:   stateTotalStorage,
+			network:  State.Name(),
+		})
+	}
+
+	for {
+		for _, m := range metricsArr {
+			var err error
+			if m.file == nil {
+				m.file, err = os.OpenFile(m.filename, os.O_RDONLY, 0600)
+				if err != nil {
+					log.Debug("Could not open file", "network", m.network, "file", m.filename, "metric", "total_storage", "err", err)
+				}
+			}
+			if m.file != nil && err == nil {
+				stat, err := m.file.Stat()
+				if err != nil {
+					log.Warn("Could not get file stat", "network", m.network, "file", m.filename, "metric", "total_storage", "err", err)
+				}
+				if err == nil {
+					m.metric.Update(stat.Size())
+				}
+			}
+		}
+
+		time.Sleep(refresh)
+	}
+}
+
+func NewPortalStorageMetrics(network string, db *sql.DB) (*PortalStorageMetrics, error) {
+	if !metrics.Enabled {
+		return nil, nil
+	}
+
+	if network != History.Name() && network != Beacon.Name() && network != State.Name() {
+		log.Debug("Unknown network for metrics", "network", network)
+		return nil, errors.New("unknown network for metrics")
+	}
+
+	var countSql string
+	var contentSql string
+	if network == Beacon.Name() {
+		countSql = strings.Replace(countEntrySql, "kvstore", "beacon", 1)
+		contentSql = strings.Replace(contentStorageUsageSql, "kvstore", "beacon", 1)
+		contentSql = strings.Replace(contentSql, "value", "content_value", 1)
+	} else {
+		countSql = countEntrySql
+		contentSql = contentStorageUsageSql
+	}
+
+	storageMetrics := &PortalStorageMetrics{}
+
+	storageMetrics.RadiusRatio = metrics.NewRegisteredGaugeFloat64("portal/"+network+"/radius_ratio", nil)
+	storageMetrics.RadiusRatio.Update(1)
+
+	storageMetrics.EntriesCount = metrics.NewRegisteredGauge("portal/"+network+"/entry_count", nil)
+	log.Debug("Counting entries in " + network + " storage for metrics")
+	var res = new(int64)
+	q := db.QueryRow(countSql)
+	if errors.Is(q.Err(), sql.ErrNoRows) {
+		storageMetrics.EntriesCount.Update(0)
+	} else if q.Err() != nil {
+		log.Error("Query execution error", "network", network, "metric", "entry_count", "err", q.Err())
+		return nil, q.Err()
+	} else {
+		q.Scan(res)
+		storageMetrics.EntriesCount.Update(*res)
+	}
+
+	storageMetrics.ContentStorageUsage = metrics.NewRegisteredGauge("portal/"+network+"/content_storage", nil)
+	log.Debug("Counting storage usage (bytes) in " + network + " for metrics")
+	var res2 = new(int64)
+	q2 := db.QueryRow(contentSql)
+	if errors.Is(q2.Err(), sql.ErrNoRows) {
+		storageMetrics.ContentStorageUsage.Update(0)
+	} else if q2.Err() != nil {
+		log.Error("Query execution error", "network", network, "metric", "content_storage", "err", q2.Err())
+		return nil, q2.Err()
+	} else {
+		q2.Scan(res2)
+		storageMetrics.ContentStorageUsage.Update(*res2)
+	}
+
+	return storageMetrics, nil
+}
diff --git a/portalnetwork/portal_protocol_test.go b/portalnetwork/portalwire/portal_protocol_test.go
similarity index 98%
rename from portalnetwork/portal_protocol_test.go
rename to portalnetwork/portalwire/portal_protocol_test.go
index fcc79e9d4f5f..1c78d3810c4d 100644
--- a/portalnetwork/portal_protocol_test.go
+++ b/portalnetwork/portalwire/portal_protocol_test.go
@@ -1,4 +1,4 @@
-package portalnetwork
+package portalwire
 
 import (
 	"context"
@@ -14,7 +14,6 @@ import (
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/p2p/discover"
-	"github.com/ethereum/go-ethereum/portalnetwork/portalwire"
 	"github.com/ethereum/go-ethereum/portalnetwork/storage"
 	"github.com/optimism-java/utp-go"
 	"github.com/optimism-java/utp-go/libutp"
@@ -92,7 +91,7 @@ func setupLocalPortalNode(addr string, bootNodes []*enode.Node) (*PortalProtocol
 	contentQueue := make(chan *ContentElement, 50)
 	portalProtocol, err := NewPortalProtocol(
 		conf,
-		portalwire.History,
+		History,
 		privKey,
 		conn,
 		localNode,
@@ -297,12 +296,12 @@ func TestPortalWireProtocol(t *testing.T) {
 
 	flag, content, err := node2.findContent(node1.localNode.Node(), []byte("test_key"))
 	assert.NoError(t, err)
-	assert.Equal(t, portalwire.ContentRawSelector, flag)
+	assert.Equal(t, ContentRawSelector, flag)
 	assert.Equal(t, []byte("test_value"), content)
 
 	flag, content, err = node2.findContent(node3.localNode.Node(), []byte("test_key"))
 	assert.NoError(t, err)
-	assert.Equal(t, portalwire.ContentEnrsSelector, flag)
+	assert.Equal(t, ContentEnrsSelector, flag)
 	assert.Equal(t, 1, len(content.([]*enode.Node)))
 	assert.Equal(t, node1.localNode.Node().ID(), content.([]*enode.Node)[0].ID())
 
@@ -318,7 +317,7 @@ func TestPortalWireProtocol(t *testing.T) {
 	flag, content, err = node2.findContent(node1.localNode.Node(), []byte("large_test_key"))
 	assert.NoError(t, err)
 	assert.Equal(t, largeTestContent, content)
-	assert.Equal(t, portalwire.ContentConnIdSelector, flag)
+	assert.Equal(t, ContentConnIdSelector, flag)
 
 	testEntry1 := &ContentEntry{
 		ContentKey: []byte("test_entry1"),
diff --git a/portalnetwork/portal_utp.go b/portalnetwork/portalwire/portal_utp.go
similarity index 94%
rename from portalnetwork/portal_utp.go
rename to portalnetwork/portalwire/portal_utp.go
index b1b58a7673ca..487fe6ca455b 100644
--- a/portalnetwork/portal_utp.go
+++ b/portalnetwork/portalwire/portal_utp.go
@@ -1,4 +1,4 @@
-package portalnetwork
+package portalwire
 
 import (
 	"context"
@@ -12,7 +12,6 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/netutil"
-	"github.com/ethereum/go-ethereum/portalnetwork/portalwire"
 	"github.com/optimism-java/utp-go"
 	"github.com/optimism-java/utp-go/libutp"
 	"go.uber.org/zap"
@@ -74,7 +73,7 @@ func (p *PortalUtp) Start() error {
 		p.lAddr = p.listener.Addr().(*utp.Addr)
 
 		// register discv5 listener
-		p.discV5.RegisterTalkHandler(string(portalwire.Utp), p.handleUtpTalkRequest)
+		p.discV5.RegisterTalkHandler(string(Utp), p.handleUtpTalkRequest)
 	})
 
 	return err
@@ -122,7 +121,7 @@ func (p *PortalUtp) packetRouterFunc(buf []byte, id enode.ID,
addr *net.UDPAddr) if n, ok := p.discV5.GetCachedNode(addr.String()); ok { //_, err := p.DiscV5.TalkRequestToID(id, addr, string(portalwire.UTPNetwork), buf) - req := &v5wire.TalkRequest{Protocol: string(portalwire.Utp), Message: buf} + req := &v5wire.TalkRequest{Protocol: string(Utp), Message: buf} p.discV5.SendFromAnotherThreadWithNode(n, netip.AddrPortFrom(netutil.IPToAddr(addr.IP), uint16(addr.Port)), req) return len(buf), nil diff --git a/p2p/discover/portalwire/messages.go b/portalnetwork/portalwire/types.go similarity index 87% rename from p2p/discover/portalwire/messages.go rename to portalnetwork/portalwire/types.go index c7629604d570..2b1f45f8db58 100644 --- a/p2p/discover/portalwire/messages.go +++ b/portalnetwork/portalwire/types.go @@ -5,7 +5,7 @@ import ( ) // note: We changed the generated file since fastssz issues which can't be passed by the CI, so we commented the go:generate line -///go:generate sszgen --path messages.go --exclude-objs Content,Enrs,ContentKV +///go:generate sszgen --path types.go --exclude-objs Content,Enrs,ContentKV // Message codes for the portal protocol. const ( @@ -37,18 +37,6 @@ const ( PerContentKeyOverhead = 4 ) -// Protocol IDs for the portal protocol. -// var ( -// StateNetwork = []byte{0x50, 0x0a} -// HistoryNetwork = []byte{0x50, 0x0b} -// TxGossipNetwork = []byte{0x50, 0x0c} -// HeaderGossipNetwork = []byte{0x50, 0x0d} -// CanonicalIndicesNetwork = []byte{0x50, 0x0e} -// BeaconLightClientNetwork = []byte{0x50, 0x1a} -// UTPNetwork = []byte{0x75, 0x74, 0x70} -// Rendezvous = []byte{0x72, 0x65, 0x6e} -// ) - type ProtocolId []byte var ( @@ -61,7 +49,7 @@ var ( Utp ProtocolId = []byte{0x75, 0x74, 0x70} ) -var protocalName = map[string]string{ +var protocolName = map[string]string{ string(State): "state", string(History): "history", string(Beacon): "beacon", @@ -71,21 +59,9 @@ var protocalName = map[string]string{ } func (p ProtocolId) Name() string { - return protocalName[string(p)] + return protocolName[string(p)] } -// const ( -// HistoryNetworkName = "history" -// BeaconNetworkName = "beacon" -// StateNetworkName = "state" -// ) - -// var NetworkNameMap = map[string]string{ -// string(StateNetwork): StateNetworkName, -// string(HistoryNetwork): HistoryNetworkName, -// string(BeaconLightClientNetwork): BeaconNetworkName, -// } - type ContentKV struct { ContentKey []byte Content []byte diff --git a/p2p/discover/portalwire/messages_encoding.go b/portalnetwork/portalwire/types_encoding.go similarity index 100% rename from p2p/discover/portalwire/messages_encoding.go rename to portalnetwork/portalwire/types_encoding.go diff --git a/p2p/discover/portalwire/messages_test.go b/portalnetwork/portalwire/types_test.go similarity index 100% rename from p2p/discover/portalwire/messages_test.go rename to portalnetwork/portalwire/types_test.go diff --git a/portalnetwork/state/api.go b/portalnetwork/state/api.go index 87afdf73b92f..91de19bdf94c 100644 --- a/portalnetwork/state/api.go +++ b/portalnetwork/state/api.go @@ -1,14 +1,14 @@ package state import ( - "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" ) type API struct { - *discover.PortalProtocolAPI + *portalwire.PortalProtocolAPI } -func (p *API) StateRoutingTableInfo() *discover.RoutingTableInfo { +func (p *API) StateRoutingTableInfo() *portalwire.RoutingTableInfo { return p.RoutingTableInfo() } @@ -28,7 +28,7 @@ func (p *API) StateLookupEnr(nodeId string) (string, error) { return p.LookupEnr(nodeId) } -func (p *API) StatePing(enr string) 
(*discover.PortalPongResp, error) { +func (p *API) StatePing(enr string) (*portalwire.PortalPongResp, error) { return p.Ping(enr) } @@ -48,7 +48,7 @@ func (p *API) StateRecursiveFindNodes(nodeId string) ([]string, error) { return p.RecursiveFindNodes(nodeId) } -func (p *API) StateGetContent(contentKeyHex string) (*discover.ContentInfo, error) { +func (p *API) StateGetContent(contentKeyHex string) (*portalwire.ContentInfo, error) { return p.RecursiveFindContent(contentKeyHex) } @@ -64,11 +64,11 @@ func (p *API) StateGossip(contentKeyHex, contentHex string) (int, error) { return p.Gossip(contentKeyHex, contentHex) } -func (p *API) StateTraceGetContent(contentKeyHex string) (*discover.TraceContentResult, error) { +func (p *API) StateTraceGetContent(contentKeyHex string) (*portalwire.TraceContentResult, error) { return p.TraceRecursiveFindContent(contentKeyHex) } -func NewStateNetworkAPI(portalProtocolAPI *discover.PortalProtocolAPI) *API { +func NewStateNetworkAPI(portalProtocolAPI *portalwire.PortalProtocolAPI) *API { return &API{ portalProtocolAPI, } diff --git a/portalnetwork/state/network.go b/portalnetwork/state/network.go index 25008d78a1a1..2dcfcf17215d 100644 --- a/portalnetwork/state/network.go +++ b/portalnetwork/state/network.go @@ -10,8 +10,8 @@ import ( "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/portalnetwork/history" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/trie" @@ -21,7 +21,7 @@ import ( ) type StateNetwork struct { - portalProtocol *discover.PortalProtocol + portalProtocol *portalwire.PortalProtocol closeCtx context.Context closeFunc context.CancelFunc log log.Logger @@ -29,7 +29,7 @@ type StateNetwork struct { client *rpc.Client } -func NewStateNetwork(portalProtocol *discover.PortalProtocol, client *rpc.Client) *StateNetwork { +func NewStateNetwork(portalProtocol *portalwire.PortalProtocol, client *rpc.Client) *StateNetwork { ctx, cancel := context.WithCancel(context.Background()) return &StateNetwork{ portalProtocol: portalProtocol, @@ -196,7 +196,7 @@ func (h *StateNetwork) getStateRoot(blockHash common.Bytes32) (common.Bytes32, e contentKey = append(contentKey, blockHash[:]...) 
arg := hexutil.Encode(contentKey) - res := &discover.ContentInfo{} + res := &portalwire.ContentInfo{} err := h.client.CallContext(ctx, res, "portal_historyGetContent", arg) if err != nil { return common.Bytes32{}, err diff --git a/portalnetwork/state/network_test.go b/portalnetwork/state/network_test.go index 24504e7f3ba6..3a2bc5bfa171 100644 --- a/portalnetwork/state/network_test.go +++ b/portalnetwork/state/network_test.go @@ -6,8 +6,8 @@ import ( "testing" "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/portalnetwork/history" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/rpc" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" @@ -37,7 +37,7 @@ type MockAPI struct { header string } -func (p *MockAPI) HistoryGetContent(contentKeyHex string) (*discover.ContentInfo, error) { +func (p *MockAPI) HistoryGetContent(contentKeyHex string) (*portalwire.ContentInfo, error) { headerWithProof := &history.BlockHeaderWithProof{ Header: hexutil.MustDecode(p.header), Proof: &history.BlockHeaderProof{ @@ -49,7 +49,7 @@ func (p *MockAPI) HistoryGetContent(contentKeyHex string) (*discover.ContentInfo if err != nil { return nil, err } - return &discover.ContentInfo{ + return &portalwire.ContentInfo{ Content: hexutil.Encode(data), UtpTransfer: false, }, nil diff --git a/portalnetwork/state/storage.go b/portalnetwork/state/storage.go index 466c1707d6a2..1d2f12324d80 100644 --- a/portalnetwork/state/storage.go +++ b/portalnetwork/state/storage.go @@ -9,6 +9,7 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" + "github.com/ethereum/go-ethereum/portalnetwork/portalwire" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/holiman/uint256" "github.com/protolambda/ztyp/codec" @@ -27,7 +28,7 @@ type StateStorage struct { log log.Logger } -var portalStorageMetrics *metrics.PortalStorageMetrics +var portalStorageMetrics *portalwire.PortalStorageMetrics func NewStateStorage(store storage.ContentStorage, db *sql.DB) *StateStorage { storage := &StateStorage{ @@ -37,7 +38,7 @@ func NewStateStorage(store storage.ContentStorage, db *sql.DB) *StateStorage { } var err error - portalStorageMetrics, err = metrics.NewPortalStorageMetrics("state", db) + portalStorageMetrics, err = portalwire.NewPortalStorageMetrics("state", db) if err != nil { return nil } From 1daf86ed848ef9826012b6fe0c8d90d939ae1756 Mon Sep 17 00:00:00 2001 From: fearlseefe <505380967@qq.com> Date: Thu, 21 Nov 2024 18:30:17 +0800 Subject: [PATCH 07/13] feat: delete content and set radius --- portalnetwork/storage/ethpepple/maxheap.go | 39 +++++ .../storage/ethpepple/maxheap_test.go | 33 ++++ portalnetwork/storage/ethpepple/storage.go | 148 +++++++++++++++++- .../storage/ethpepple/storage_test.go | 118 ++++++++++---- 4 files changed, 304 insertions(+), 34 deletions(-) create mode 100644 portalnetwork/storage/ethpepple/maxheap.go create mode 100644 portalnetwork/storage/ethpepple/maxheap_test.go diff --git a/portalnetwork/storage/ethpepple/maxheap.go b/portalnetwork/storage/ethpepple/maxheap.go new file mode 100644 index 000000000000..a88112bc24fd --- /dev/null +++ b/portalnetwork/storage/ethpepple/maxheap.go @@ -0,0 +1,39 @@ +package ethpepple + +import ( + "bytes" +) + +const maxItem = 250_000 // every item has 40 bytes, so the heap most have 10MB + +type Item struct { + Distance []byte + ValueSize uint64 +} + +type MaxHeap []Item 
+ +func (m MaxHeap) Len() int { + return len(m) +} + +func (m MaxHeap) Less(i, j int) bool { + // Compare Distance as byte slices + return bytes.Compare(m[i].Distance, m[j].Distance) > 0 +} + +func (m MaxHeap) Swap(i, j int) { + m[i], m[j] = m[j], m[i] +} + +func (m *MaxHeap) Pop() interface{} { + old := *m + n := len(old) + item := old[n-1] + *m = old[0 : n-1] + return item +} + +func (m *MaxHeap) Push(x interface{}) { + *m = append(*m, x.(Item)) +} diff --git a/portalnetwork/storage/ethpepple/maxheap_test.go b/portalnetwork/storage/ethpepple/maxheap_test.go new file mode 100644 index 000000000000..65fe82f6e244 --- /dev/null +++ b/portalnetwork/storage/ethpepple/maxheap_test.go @@ -0,0 +1,33 @@ +package ethpepple + +import ( + "container/heap" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMaxHeap(t *testing.T) { + expectValueSize := []uint64{40, 30, 20, 10} + // Create a heap and initialize it with some items + h := &MaxHeap{ + {Distance: []byte{1}, ValueSize: 10}, + {Distance: []byte{2}, ValueSize: 20}, + {Distance: []byte{3}, ValueSize: 30}, + } + heap.Init(h) + + // Push a new item into the heap + heap.Push(h, Item{Distance: []byte{4}, ValueSize: 40}) + heap.Push(h, Item{Distance: []byte{5}, ValueSize: 50}) + + removed := heap.Remove(h, 0) + assert.Equal(t, removed.(Item).ValueSize, uint64(50)) + + len := h.Len() + // Pop and print the largest element + for i := 0; i < len; i++ { + item := heap.Pop(h).(Item) + assert.Equal(t, item.ValueSize, expectValueSize[i]) + } +} diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go index ad4bae90be27..31e19d7e7bb9 100644 --- a/portalnetwork/storage/ethpepple/storage.go +++ b/portalnetwork/storage/ethpepple/storage.go @@ -1,7 +1,12 @@ package ethpepple import ( + "bytes" + "container/heap" + "encoding/binary" + "sync" "sync/atomic" + "time" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/ethdb/pebble" @@ -11,6 +16,8 @@ import ( "github.com/holiman/uint256" ) +const contentDeletionFraction = 0.05 // 5% of the content will be deleted when the storage capacity is hit and radius gets adjusted. 
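// [Editor's note] A standalone sketch, not part of the patch: it shows how the
// contentDeletionFraction constant above translates into a pruning budget, and how a
// max-heap ordered by XOR distance (like the MaxHeap type in this commit) yields
// farthest-first eviction. Names and the one-byte toy distances are illustrative only.
package main

import (
	"container/heap"
	"fmt"
)

type item struct {
	distance byte   // toy 1-byte XOR distance; the patch uses 32-byte slices
	size     uint64 // stored value size in bytes
}

type maxHeap []item

func (m maxHeap) Len() int            { return len(m) }
func (m maxHeap) Less(i, j int) bool  { return m[i].distance > m[j].distance } // farthest first
func (m maxHeap) Swap(i, j int)       { m[i], m[j] = m[j], m[i] }
func (m *maxHeap) Push(x interface{}) { *m = append(*m, x.(item)) }
func (m *maxHeap) Pop() interface{} {
	old := *m
	it := old[len(old)-1]
	*m = old[:len(old)-1]
	return it
}

func main() {
	const capacityBytes = 1_000_000
	const deletionFraction = 0.05
	// Bytes to free once capacity is hit: 5% of 1 MB = 50_000.
	budget := uint64(float64(capacityBytes) * deletionFraction)

	h := &maxHeap{
		{distance: 3, size: 20_000},
		{distance: 9, size: 40_000},
		{distance: 1, size: 30_000},
	}
	heap.Init(h)

	var freed uint64
	for h.Len() > 0 && freed < budget {
		// Evict the content farthest from the local node first; the node's
		// radius can then shrink to the last evicted distance.
		it := heap.Pop(h).(item)
		freed += it.size
		fmt.Printf("evict distance=%d size=%d (freed %d of %d)\n", it.distance, it.size, freed, budget)
	}
}

// Running this evicts distance 9 (40 KB) and then distance 3 (20 KB), at which
// point the freed total exceeds the 50 KB budget and pruning stops.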
+ var _ storage.ContentStorage = &ContentStorage{} type PeppleStorageConfig struct { @@ -29,9 +36,13 @@ type ContentStorage struct { nodeId enode.ID storageCapacityInBytes uint64 radius atomic.Value - // size uint64 - log log.Logger - db ethdb.KeyValueStore + log log.Logger + db ethdb.KeyValueStore + size uint64 + sizeChan chan uint64 + sizeMutex sync.RWMutex + isPruning bool + pruneDoneChan chan uint64 // finish prune and get the pruned size } func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error) { @@ -40,6 +51,8 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error db: config.DB, storageCapacityInBytes: config.StorageCapacityMB * 1000_000, log: log.New("storage", config.NetworkName), + sizeChan: make(chan uint64, 100), + pruneDoneChan: make(chan uint64, 1), } cs.radius.Store(storage.MaxDistance) exist, err := cs.db.Has(storage.RadisuKey) @@ -58,6 +71,21 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error } cs.radius.Store(dis) } + + exist, err = cs.db.Has(storage.SizeKey) + if err != nil { + return nil, err + } + if exist { + val, err := cs.db.Get(storage.SizeKey) + if err != nil { + return nil, err + } + size := binary.BigEndian.Uint64(val) + // init stage, no need to use lock + cs.size = size + } + go cs.saveCapacity() return cs, nil } @@ -68,6 +96,8 @@ func (c *ContentStorage) Get(contentKey []byte, contentId []byte) ([]byte, error // Put implements storage.ContentStorage. func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte) error { + length := uint64(len(contentId)) + uint64(len(content)) + c.sizeChan <- length return c.db.Put(contentId, content) } @@ -77,3 +107,115 @@ func (c *ContentStorage) Radius() *uint256.Int { val := radius.(*uint256.Int) return val } + +func (c *ContentStorage) saveCapacity() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + sizeChanged := false + buf := make([]byte, 8) // uint64 + + for { + select { + case <-ticker.C: + if sizeChanged { + binary.BigEndian.PutUint64(buf, c.size) + c.db.Put(storage.SizeKey, buf) + sizeChanged = false + } + case size := <-c.sizeChan: + c.log.Debug("reveice size %v", size) + c.sizeMutex.Lock() + c.size += size + c.sizeMutex.Unlock() + sizeChanged = true + if c.size > c.storageCapacityInBytes { + if !c.isPruning { + c.isPruning = true + go c.prune() + } + } + case prunedSize := <-c.pruneDoneChan: + c.isPruning = false + c.size -= prunedSize + sizeChanged = true + } + } +} + +func (c *ContentStorage) prune() { + var distance = []byte{} + + h := &MaxHeap{} + heap.Init(h) + + expectSize := uint64(float64(c.storageCapacityInBytes) * contentDeletionFraction) + + var curentSize uint64 = 0 + + defer func() { + c.pruneDoneChan <- curentSize + }() + // get the keys to be deleted order by distance desc + iterator := c.db.NewIterator(nil, nil) + defer iterator.Release() + for iterator.Next() { + key := iterator.Key() + if bytes.Equal(key, storage.SizeKey) || bytes.Equal(key, storage.RadisuKey) { + continue + } + val := iterator.Value() + size := uint64(len(val)) + + distance := xor(key, c.nodeId[:]) + heap.Push(h, Item{ + Distance: distance, + ValueSize: size, + }) + if h.Len() > maxItem { + heap.Remove(h, h.Len()-1) + } + } + iterator.Release() + // delete the keys + for h.Len() > 0 { + if curentSize > expectSize { + break + } + item := heap.Pop(h) + val := item.(Item) + distance = val.Distance + key := xor(val.Distance, c.nodeId[:]) + if err := c.db.Delete(key); err != nil { + c.log.Error("failed to 
delete key %v, err: %v", key, err) + continue + } + curentSize += val.ValueSize + } + + dis := uint256.NewInt(0) + err := dis.UnmarshalSSZ(distance) + if err != nil { + c.log.Error("failed to parse the radius key %v, err is %v", distance, err) + } + c.radius.Store(dis) + err = c.db.Put(storage.RadisuKey, distance) + + if err != nil { + c.log.Error("failed to save the radius key %v, err is %v", distance, err) + } +} + +func xor(contentId, nodeId []byte) []byte { + // length of contentId maybe not 32bytes + padding := make([]byte, 32) + if len(contentId) != len(nodeId) { + copy(padding, contentId) + } else { + padding = contentId + } + res := make([]byte, len(padding)) + for i := range padding { + res[i] = padding[i] ^ nodeId[i] + } + return res +} diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go index 7bf543df9487..33276d2f5cb8 100644 --- a/portalnetwork/storage/ethpepple/storage_test.go +++ b/portalnetwork/storage/ethpepple/storage_test.go @@ -2,11 +2,10 @@ package ethpepple import ( "crypto/rand" - "encoding/hex" - "os" "testing" + "time" - "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/holiman/uint256" "github.com/stretchr/testify/assert" @@ -14,39 +13,28 @@ import ( var testRadius = uint256.NewInt(100000) -func clearNodeData(path string) { - _ = os.RemoveAll(path) -} - -func getRandomPath() string { - // gen a random hex string - bytes := make([]byte, 32) - _, err := rand.Read(bytes) - if err != nil { - panic(err) +func genBytes(length int) []byte { + res := make([]byte, length) + for i := 0; i < length; i++ { + res[i] = byte(i) } - return hex.EncodeToString(bytes) + return res } -func getTestDb(path string) (storage.ContentStorage, error) { - db, err := NewPeppleDB(path, 100, 100, "history") - if err != nil { - return nil, err - } +func getTestDb() (storage.ContentStorage, error) { + db := memorydb.New() config := PeppleStorageConfig{ DB: db, - StorageCapacityMB: 100, - NodeId: enode.ID{}, + StorageCapacityMB: 1, + NodeId: uint256.NewInt(0).Bytes32(), NetworkName: "history", } return NewPeppleStorage(config) } func TestReadRadius(t *testing.T) { - path := getRandomPath() - db, err := getTestDb(path) + db, err := getTestDb() assert.NoError(t, err) - defer clearNodeData(path) assert.True(t, db.Radius().Eq(storage.MaxDistance)) data, err := testRadius.MarshalSSZ() @@ -56,17 +44,11 @@ func TestReadRadius(t *testing.T) { store := db.(*ContentStorage) err = store.db.Close() assert.NoError(t, err) - - db, err = getTestDb(path) - assert.NoError(t, err) - assert.True(t, db.Radius().Eq(testRadius)) } func TestStorage(t *testing.T) { - path := getRandomPath() - db, err := getTestDb(path) + db, err := getTestDb() assert.NoError(t, err) - defer clearNodeData(path) testcases := map[string][]byte{ "test1": []byte("test1"), "test2": []byte("test2"), @@ -84,3 +66,77 @@ func TestStorage(t *testing.T) { assert.Equal(t, value, val) } } + +func TestXor(t *testing.T) { + nodeId := uint256.NewInt(0).Bytes32() + bs := make([]byte, 32) + rand.Read(bs) + dis := xor(bs, nodeId[:]) + assert.Equal(t, bs, dis) + + nodeId2 := uint256.NewInt(2).Bytes32() + dis = xor(bs, nodeId2[:]) + assert.Equal(t, bs, xor(dis, nodeId2[:])) +} + +// the capacity is 1MB, so prune will delete over 50Kb content +func TestPrune(t *testing.T) { + db, err := getTestDb() + assert.NoError(t, err) + // the nodeId is zeros, so contentKey and contentId is the same + 
testcases := []struct {
+		contentKey  [32]byte
+		content     []byte
+		shouldPrune bool
+	}{
+		{
+			contentKey:  uint256.NewInt(1).Bytes32(),
+			content:     genBytes(900_000),
+			shouldPrune: false,
+		},
+		{
+			contentKey:  uint256.NewInt(2).Bytes32(),
+			content:     genBytes(40_000),
+			shouldPrune: false,
+		},
+		{
+			contentKey:  uint256.NewInt(3).Bytes32(),
+			content:     genBytes(20_000),
+			shouldPrune: false,
+		},
+		{
+			contentKey:  uint256.NewInt(4).Bytes32(),
+			content:     genBytes(20_000),
+			shouldPrune: false,
+		},
+		{
+			contentKey:  uint256.NewInt(5).Bytes32(),
+			content:     genBytes(20_000),
+			shouldPrune: true,
+		},
+		{
+			contentKey:  uint256.NewInt(6).Bytes32(),
+			content:     genBytes(20_000),
+			shouldPrune: true,
+		},
+		{
+			contentKey:  uint256.NewInt(7).Bytes32(),
+			content:     genBytes(20_000),
+			shouldPrune: true,
+		},
+	}
+
+	for _, val := range testcases {
+		db.Put(val.contentKey[:], val.contentKey[:], val.content)
+	}
+	// // wait to prune done
+	time.Sleep(5 * time.Second)
+	for _, val := range testcases {
+		content, err := db.Get(val.contentKey[:], val.contentKey[:])
+		if !val.shouldPrune {
+			assert.Equal(t, val.content, content)
+		} else {
+			assert.Error(t, err)
+		}
+	}
+}

From 7664f84dd7b2b3f11f3b7c89e55039198d456446 Mon Sep 17 00:00:00 2001
From: fearlessfe <505380967@qq.com>
Date: Mon, 25 Nov 2024 08:27:08 +0800
Subject: [PATCH 08/13] feat: use pebble instead of ethdb/pebble

---
 portalnetwork/storage/ethpepple/maxheap.go | 39 ----
 .../storage/ethpepple/maxheap_test.go | 33 ---
 portalnetwork/storage/ethpepple/storage.go | 213 ++++++++++++------
 .../storage/ethpepple/storage_test.go | 146 ++++++++----
 4 files changed, 239 insertions(+), 192 deletions(-)
 delete mode 100644 portalnetwork/storage/ethpepple/maxheap.go
 delete mode 100644 portalnetwork/storage/ethpepple/maxheap_test.go

diff --git a/portalnetwork/storage/ethpepple/maxheap.go b/portalnetwork/storage/ethpepple/maxheap.go
deleted file mode 100644
index a88112bc24fd..000000000000
--- a/portalnetwork/storage/ethpepple/maxheap.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package ethpepple
-
-import (
-	"bytes"
-)
-
-const maxItem = 250_000 // every item has 40 bytes, so the heap most have 10MB
-
-type Item struct {
-	Distance  []byte
-	ValueSize uint64
-}
-
-type MaxHeap []Item
-
-func (m MaxHeap) Len() int {
-	return len(m)
-}
-
-func (m MaxHeap) Less(i, j int) bool {
-	// Compare Distance as byte slices
-	return bytes.Compare(m[i].Distance, m[j].Distance) > 0
-}
-
-func (m MaxHeap) Swap(i, j int) {
-	m[i], m[j] = m[j], m[i]
-}
-
-func (m *MaxHeap) Pop() interface{} {
-	old := *m
-	n := len(old)
-	item := old[n-1]
-	*m = old[0 : n-1]
-	return item
-}
-
-func (m *MaxHeap) Push(x interface{}) {
-	*m = append(*m, x.(Item))
-}
diff --git a/portalnetwork/storage/ethpepple/maxheap_test.go b/portalnetwork/storage/ethpepple/maxheap_test.go
deleted file mode 100644
index 65fe82f6e244..000000000000
--- a/portalnetwork/storage/ethpepple/maxheap_test.go
+++ /dev/null
@@ -1,33 +0,0 @@
-package ethpepple
-
-import (
-	"container/heap"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestMaxHeap(t *testing.T) {
-	expectValueSize := []uint64{40, 30, 20, 10}
-	// Create a heap and initialize it with some items
-	h := &MaxHeap{
-		{Distance: []byte{1}, ValueSize: 10},
-		{Distance: []byte{2}, ValueSize: 20},
-		{Distance: []byte{3}, ValueSize: 30},
-	}
-	heap.Init(h)
-
-	// Push a new item into the heap
-	heap.Push(h, Item{Distance: []byte{4}, ValueSize: 40})
-	heap.Push(h, Item{Distance: []byte{5}, ValueSize: 50})
-
-	removed := heap.Remove(h, 0)
-	assert.Equal(t, 
removed.(Item).ValueSize, uint64(50)) - - len := h.Len() - // Pop and print the largest element - for i := 0; i < len; i++ { - item := heap.Pop(h).(Item) - assert.Equal(t, item.ValueSize, expectValueSize[i]) - } -} diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go index 31e19d7e7bb9..b943914690cc 100644 --- a/portalnetwork/storage/ethpepple/storage.go +++ b/portalnetwork/storage/ethpepple/storage.go @@ -1,34 +1,119 @@ package ethpepple import ( - "bytes" - "container/heap" "encoding/binary" + "runtime" "sync" "sync/atomic" "time" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/ethdb/pebble" + "github.com/cockroachdb/pebble" + "github.com/cockroachdb/pebble/bloom" + "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/holiman/uint256" ) -const contentDeletionFraction = 0.05 // 5% of the content will be deleted when the storage capacity is hit and radius gets adjusted. +const ( + // minCache is the minimum amount of memory in megabytes to allocate to pebble + // read and write caching, split half and half. + minCache = 16 + + // minHandles is the minimum number of files handles to allocate to the open + // database files. + minHandles = 16 + + // 5% of the content will be deleted when the storage capacity is hit and radius gets adjusted. + contentDeletionFraction = 0.05 +) var _ storage.ContentStorage = &ContentStorage{} type PeppleStorageConfig struct { StorageCapacityMB uint64 - DB ethdb.KeyValueStore + DB *pebble.DB NodeId enode.ID NetworkName string } -func NewPeppleDB(dataDir string, cache, handles int, namespace string) (ethdb.KeyValueStore, error) { - db, err := pebble.New(dataDir+"/"+namespace, cache, handles, namespace, false) +func NewPeppleDB(dataDir string, cache, handles int, namespace string) (*pebble.DB, error) { + // Ensure we have some minimal caching and file guarantees + if cache < minCache { + cache = minCache + } + if handles < minHandles { + handles = minHandles + } + logger := log.New("database", namespace) + logger.Info("Allocated cache and file handles", "cache", common.StorageSize(cache*1024*1024), "handles", handles) + + // The max memtable size is limited by the uint32 offsets stored in + // internal/arenaskl.node, DeferredBatchOp, and flushableBatchEntry. + // + // - MaxUint32 on 64-bit platforms; + // - MaxInt on 32-bit platforms. + // + // It is used when slices are limited to Uint32 on 64-bit platforms (the + // length limit for slices is naturally MaxInt on 32-bit platforms). + // + // Taken from https://github.com/cockroachdb/pebble/blob/master/internal/constants/constants.go + maxMemTableSize := (1<<31)<<(^uint(0)>>63) - 1 + + // Two memory tables is configured which is identical to leveldb, + // including a frozen memory table and another live one. + memTableLimit := 2 + memTableSize := cache * 1024 * 1024 / 2 / memTableLimit + + // The memory table size is currently capped at maxMemTableSize-1 due to a + // known bug in the pebble where maxMemTableSize is not recognized as a + // valid size. + // + // TODO use the maxMemTableSize as the maximum table size once the issue + // in pebble is fixed. + if memTableSize >= maxMemTableSize { + memTableSize = maxMemTableSize - 1 + } + opt := &pebble.Options{ + // Pebble has a single combined cache area and the write + // buffers are taken from this too. 
Assign all available + // memory allowance for cache. + Cache: pebble.NewCache(int64(cache * 1024 * 1024)), + MaxOpenFiles: handles, + + // The size of memory table(as well as the write buffer). + // Note, there may have more than two memory tables in the system. + MemTableSize: uint64(memTableSize), + + // MemTableStopWritesThreshold places a hard limit on the size + // of the existent MemTables(including the frozen one). + // Note, this must be the number of tables not the size of all memtables + // according to https://github.com/cockroachdb/pebble/blob/master/options.go#L738-L742 + // and to https://github.com/cockroachdb/pebble/blob/master/db.go#L1892-L1903. + MemTableStopWritesThreshold: memTableLimit, + + // The default compaction concurrency(1 thread), + // Here use all available CPUs for faster compaction. + MaxConcurrentCompactions: runtime.NumCPU, + + // Per-level options. Options for at least one level must be specified. The + // options for the last level are used for all subsequent levels. + Levels: []pebble.LevelOptions{ + {TargetFileSize: 2 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + {TargetFileSize: 4 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + {TargetFileSize: 8 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + {TargetFileSize: 16 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + {TargetFileSize: 32 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + {TargetFileSize: 64 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + {TargetFileSize: 128 * 1024 * 1024, FilterPolicy: bloom.FilterPolicy(10)}, + }, + ReadOnly: false, + } + // Disable seek compaction explicitly. Check https://github.com/ethereum/go-ethereum/pull/20130 + // for more details. + opt.Experimental.ReadSamplingMultiplier = -1 + db, err := pebble.Open(dataDir+"/"+namespace, opt) return db, err } @@ -37,12 +122,13 @@ type ContentStorage struct { storageCapacityInBytes uint64 radius atomic.Value log log.Logger - db ethdb.KeyValueStore + db *pebble.DB size uint64 sizeChan chan uint64 sizeMutex sync.RWMutex isPruning bool pruneDoneChan chan uint64 // finish prune and get the pruned size + writeOptions *pebble.WriteOptions } func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error) { @@ -53,17 +139,14 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error log: log.New("storage", config.NetworkName), sizeChan: make(chan uint64, 100), pruneDoneChan: make(chan uint64, 1), + writeOptions: &pebble.WriteOptions{Sync: false}, } cs.radius.Store(storage.MaxDistance) - exist, err := cs.db.Has(storage.RadisuKey) - if err != nil { + radius, _, err := cs.db.Get(storage.RadisuKey) + if err != nil && err != pebble.ErrNotFound { return nil, err } - if exist { - radius, err := cs.db.Get(storage.RadisuKey) - if err != nil { - return nil, err - } + if err == nil { dis := uint256.NewInt(0) err = dis.UnmarshalSSZ(radius) if err != nil { @@ -72,15 +155,11 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error cs.radius.Store(dis) } - exist, err = cs.db.Has(storage.SizeKey) - if err != nil { + val, _, err := cs.db.Get(storage.SizeKey) + if err != nil && err != pebble.ErrNotFound { return nil, err } - if exist { - val, err := cs.db.Get(storage.SizeKey) - if err != nil { - return nil, err - } + if err == nil { size := binary.BigEndian.Uint64(val) // init stage, no need to use lock cs.size = size @@ -91,14 +170,24 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error // Get implements 
storage.ContentStorage. func (c *ContentStorage) Get(contentKey []byte, contentId []byte) ([]byte, error) { - return c.db.Get(contentId) + distance := xor(contentId, c.nodeId[:]) + data, closer, err := c.db.Get(distance) + if err != nil && err != pebble.ErrNotFound { + return nil, err + } + if err == pebble.ErrNotFound { + return nil, storage.ErrContentNotFound + } + closer.Close() + return data, nil } // Put implements storage.ContentStorage. func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte) error { length := uint64(len(contentId)) + uint64(len(content)) c.sizeChan <- length - return c.db.Put(contentId, content) + distance := xor(contentId, c.nodeId[:]) + return c.db.Set(distance, content, c.writeOptions) } // Radius implements storage.ContentStorage. @@ -119,7 +208,10 @@ func (c *ContentStorage) saveCapacity() { case <-ticker.C: if sizeChanged { binary.BigEndian.PutUint64(buf, c.size) - c.db.Put(storage.SizeKey, buf) + err := c.db.Set(storage.SizeKey, buf, c.writeOptions) + if err != nil { + c.log.Error("save capacity failed", "error", err) + } sizeChanged = false } case size := <-c.sizeChan: @@ -143,65 +235,40 @@ func (c *ContentStorage) saveCapacity() { } func (c *ContentStorage) prune() { - var distance = []byte{} - - h := &MaxHeap{} - heap.Init(h) - expectSize := uint64(float64(c.storageCapacityInBytes) * contentDeletionFraction) - var curentSize uint64 = 0 defer func() { c.pruneDoneChan <- curentSize }() // get the keys to be deleted order by distance desc - iterator := c.db.NewIterator(nil, nil) - defer iterator.Release() - for iterator.Next() { - key := iterator.Key() - if bytes.Equal(key, storage.SizeKey) || bytes.Equal(key, storage.RadisuKey) { - continue - } - val := iterator.Value() - size := uint64(len(val)) - - distance := xor(key, c.nodeId[:]) - heap.Push(h, Item{ - Distance: distance, - ValueSize: size, - }) - if h.Len() > maxItem { - heap.Remove(h, h.Len()-1) - } + iter, err := c.db.NewIter(nil) + if err != nil { + c.log.Error("get iter failed", "error", err) + return } - iterator.Release() - // delete the keys - for h.Len() > 0 { - if curentSize > expectSize { + + batch := c.db.NewBatch() + for iter.Last(); iter.Valid(); iter.Prev() { + if curentSize < expectSize { + batch.Delete(iter.Key(), nil) + curentSize += uint64(len(iter.Key())) + uint64(len(iter.Value())) + } else { + distance := iter.Key() + c.db.Set(storage.RadisuKey, distance, c.writeOptions) + dis := uint256.NewInt(0) + err = dis.UnmarshalSSZ(distance) + if err != nil { + c.log.Error("unmarshal distance failed", "error", err) + } + c.radius.Store(dis) break } - item := heap.Pop(h) - val := item.(Item) - distance = val.Distance - key := xor(val.Distance, c.nodeId[:]) - if err := c.db.Delete(key); err != nil { - c.log.Error("failed to delete key %v, err: %v", key, err) - continue - } - curentSize += val.ValueSize } - - dis := uint256.NewInt(0) - err := dis.UnmarshalSSZ(distance) - if err != nil { - c.log.Error("failed to parse the radius key %v, err is %v", distance, err) - } - c.radius.Store(dis) - err = c.db.Put(storage.RadisuKey, distance) - + err = batch.Commit(&pebble.WriteOptions{Sync: true}) if err != nil { - c.log.Error("failed to save the radius key %v, err is %v", distance, err) + c.log.Error("prune batch commit failed", "error", err) + return } } diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go index 33276d2f5cb8..08fd8bc14a57 100644 --- a/portalnetwork/storage/ethpepple/storage_test.go +++ 
b/portalnetwork/storage/ethpepple/storage_test.go @@ -1,18 +1,14 @@ package ethpepple import ( - "crypto/rand" "testing" "time" - "github.com/ethereum/go-ethereum/ethdb/memorydb" "github.com/ethereum/go-ethereum/portalnetwork/storage" "github.com/holiman/uint256" "github.com/stretchr/testify/assert" ) -var testRadius = uint256.NewInt(100000) - func genBytes(length int) []byte { res := make([]byte, length) for i := 0; i < length; i++ { @@ -21,68 +17,119 @@ func genBytes(length int) []byte { return res } -func getTestDb() (storage.ContentStorage, error) { - db := memorydb.New() - config := PeppleStorageConfig{ - DB: db, - StorageCapacityMB: 1, - NodeId: uint256.NewInt(0).Bytes32(), - NetworkName: "history", - } - return NewPeppleStorage(config) +func TestNewPeppleDB(t *testing.T) { + db, err := NewPeppleDB(t.TempDir(), 16, 16, "test") + assert.NoError(t, err) + defer db.Close() + + assert.NotNil(t, db) } -func TestReadRadius(t *testing.T) { - db, err := getTestDb() +func setupTestStorage(t *testing.T) storage.ContentStorage { + db, err := NewPeppleDB(t.TempDir(), 16, 16, "test") assert.NoError(t, err) - assert.True(t, db.Radius().Eq(storage.MaxDistance)) + t.Cleanup(func() { db.Close() }) - data, err := testRadius.MarshalSSZ() - assert.NoError(t, err) - db.Put(nil, storage.RadisuKey, data) + config := PeppleStorageConfig{ + StorageCapacityMB: 1, + DB: db, + NodeId: uint256.NewInt(0).Bytes32(), + NetworkName: "test", + } - store := db.(*ContentStorage) - err = store.db.Close() + storage, err := NewPeppleStorage(config) assert.NoError(t, err) + return storage } -func TestStorage(t *testing.T) { - db, err := getTestDb() - assert.NoError(t, err) - testcases := map[string][]byte{ - "test1": []byte("test1"), - "test2": []byte("test2"), - "test3": []byte("test3"), - "test4": []byte("test4"), - } +func TestContentStoragePutAndGet(t *testing.T) { + db := setupTestStorage(t) - for key, value := range testcases { - db.Put(nil, []byte(key), value) + testCases := []struct { + contentKey []byte + contentId []byte + content []byte + }{ + {[]byte("key1"), []byte("id1"), []byte("content1")}, + {[]byte("key2"), []byte("id2"), []byte("content2")}, } - for key, value := range testcases { - val, err := db.Get(nil, []byte(key)) + for _, tc := range testCases { + err := db.Put(tc.contentKey, tc.contentId, tc.content) + assert.NoError(t, err) + + got, err := db.Get(tc.contentKey, tc.contentId) assert.NoError(t, err) - assert.Equal(t, value, val) + assert.Equal(t, tc.content, got) } } -func TestXor(t *testing.T) { - nodeId := uint256.NewInt(0).Bytes32() - bs := make([]byte, 32) - rand.Read(bs) - dis := xor(bs, nodeId[:]) - assert.Equal(t, bs, dis) +func TestRadius(t *testing.T) { + db := setupTestStorage(t) + radius := db.Radius() + assert.NotNil(t, radius) + assert.True(t, radius.Eq(storage.MaxDistance)) +} + +// func TestPrune(t *testing.T) { +// db, err := NewPeppleDB(t.TempDir(), 16, 16, "test") +// assert.NoError(t, err) +// defer db.Close() + +// config := PeppleStorageConfig{ +// StorageCapacityMB: 1, // 1MB capacity +// DB: db, +// NodeId: uint256.NewInt(0).Bytes32(), +// NetworkName: "test", +// } + +// storage, err := NewPeppleStorage(config) +// assert.NoError(t, err) + +// // Add content exceeding capacity +// largeContent := make([]byte, 900_000) // 900KB +// err = storage.Put([]byte("key1"), []byte("id1"), largeContent) +// assert.NoError(t, err) + +// smallContent := make([]byte, 200_000) // 200KB +// err = storage.Put([]byte("key2"), []byte("id2"), smallContent) +// assert.NoError(t, err) + +// // 
Wait for prune to complete +// time.Sleep(6 * time.Second) + +// // Verify content after pruning +// _, err = storage.Get([]byte("key2"), []byte("id2")) +// assert.Error(t, err) // Should be pruned +// } + +func TestXOR(t *testing.T) { + testCases := []struct { + contentId []byte + nodeId []byte + expected []byte + }{ + { + contentId: []byte{0x01}, + nodeId: make([]byte, 32), + expected: append([]byte{0x01}, make([]byte, 31)...), + }, + { + contentId: []byte{0xFF}, + nodeId: []byte{0x0F}, + expected: append([]byte{0xF0}, make([]byte, 31)...), + }, + } - nodeId2 := uint256.NewInt(2).Bytes32() - dis = xor(bs, nodeId2[:]) - assert.Equal(t, bs, xor(dis, nodeId2[:])) + for _, tc := range testCases { + result := xor(tc.contentId, tc.nodeId) + assert.Equal(t, tc.expected, result) + } } // the capacity is 1MB, so prune will delete over 50Kb content func TestPrune(t *testing.T) { - db, err := getTestDb() - assert.NoError(t, err) + db := setupTestStorage(t) // the nodeId is zeros, so contentKey and contentId is the same testcases := []struct { contentKey [32]byte @@ -130,7 +177,7 @@ func TestPrune(t *testing.T) { db.Put(val.contentKey[:], val.contentKey[:], val.content) } // // wait to prune done - time.Sleep(5 * time.Second) + time.Sleep(2 * time.Second) for _, val := range testcases { content, err := db.Get(val.contentKey[:], val.contentKey[:]) if !val.shouldPrune { @@ -139,4 +186,9 @@ func TestPrune(t *testing.T) { assert.Error(t, err) } } + radius := db.Radius() + data, err := radius.MarshalSSZ() + assert.NoError(t, err) + actual := uint256.NewInt(4).Bytes32() + assert.Equal(t, data, actual) } From c1326356f8c093d67987992736cdb6c8dcfed167 Mon Sep 17 00:00:00 2001 From: fearlseefe <505380967@qq.com> Date: Mon, 25 Nov 2024 10:32:24 +0800 Subject: [PATCH 09/13] fix: test error --- .../storage/ethpepple/storage_test.go | 36 ++----------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go index 08fd8bc14a57..0154f07dfa07 100644 --- a/portalnetwork/storage/ethpepple/storage_test.go +++ b/portalnetwork/storage/ethpepple/storage_test.go @@ -71,38 +71,6 @@ func TestRadius(t *testing.T) { assert.True(t, radius.Eq(storage.MaxDistance)) } -// func TestPrune(t *testing.T) { -// db, err := NewPeppleDB(t.TempDir(), 16, 16, "test") -// assert.NoError(t, err) -// defer db.Close() - -// config := PeppleStorageConfig{ -// StorageCapacityMB: 1, // 1MB capacity -// DB: db, -// NodeId: uint256.NewInt(0).Bytes32(), -// NetworkName: "test", -// } - -// storage, err := NewPeppleStorage(config) -// assert.NoError(t, err) - -// // Add content exceeding capacity -// largeContent := make([]byte, 900_000) // 900KB -// err = storage.Put([]byte("key1"), []byte("id1"), largeContent) -// assert.NoError(t, err) - -// smallContent := make([]byte, 200_000) // 200KB -// err = storage.Put([]byte("key2"), []byte("id2"), smallContent) -// assert.NoError(t, err) - -// // Wait for prune to complete -// time.Sleep(6 * time.Second) - -// // Verify content after pruning -// _, err = storage.Get([]byte("key2"), []byte("id2")) -// assert.Error(t, err) // Should be pruned -// } - func TestXOR(t *testing.T) { testCases := []struct { contentId []byte @@ -117,7 +85,7 @@ func TestXOR(t *testing.T) { { contentId: []byte{0xFF}, nodeId: []byte{0x0F}, - expected: append([]byte{0xF0}, make([]byte, 31)...), + expected: []byte{0xF0}, }, } @@ -190,5 +158,5 @@ func TestPrune(t *testing.T) { data, err := radius.MarshalSSZ() assert.NoError(t, 
err) actual := uint256.NewInt(4).Bytes32() - assert.Equal(t, data, actual) + assert.Equal(t, data, actual[:]) } From 2682bd5526ac36493d23e3d6359f45bde796ae24 Mon Sep 17 00:00:00 2001 From: fearlessfe <505380967@qq.com> Date: Wed, 27 Nov 2024 08:27:20 +0800 Subject: [PATCH 10/13] feat: turn async prune to sync prune --- portalnetwork/storage/ethpepple/storage.go | 65 ++++++++----------- .../storage/ethpepple/storage_test.go | 8 +-- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go index b943914690cc..897918bbbb94 100644 --- a/portalnetwork/storage/ethpepple/storage.go +++ b/portalnetwork/storage/ethpepple/storage.go @@ -3,7 +3,6 @@ package ethpepple import ( "encoding/binary" "runtime" - "sync" "sync/atomic" "time" @@ -124,10 +123,8 @@ type ContentStorage struct { log log.Logger db *pebble.DB size uint64 - sizeChan chan uint64 - sizeMutex sync.RWMutex - isPruning bool - pruneDoneChan chan uint64 // finish prune and get the pruned size + sizeChan chan struct{} + capacityChan chan uint64 writeOptions *pebble.WriteOptions } @@ -137,8 +134,8 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error db: config.DB, storageCapacityInBytes: config.StorageCapacityMB * 1000_000, log: log.New("storage", config.NetworkName), - sizeChan: make(chan uint64, 100), - pruneDoneChan: make(chan uint64, 1), + sizeChan: make(chan struct{}, 1), + capacityChan: make(chan uint64, 100), writeOptions: &pebble.WriteOptions{Sync: false}, } cs.radius.Store(storage.MaxDistance) @@ -164,6 +161,7 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error // init stage, no need to use lock cs.size = size } + cs.sizeChan <- struct{}{} go cs.saveCapacity() return cs, nil } @@ -185,7 +183,16 @@ func (c *ContentStorage) Get(contentKey []byte, contentId []byte) ([]byte, error // Put implements storage.ContentStorage. 
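A note on the locking scheme in the rewritten Put below: a one-slot channel seeded with a single token acts as a binary semaphore, so size accounting and the now-synchronous prune run under mutual exclusion without a mutex. A minimal standalone sketch of that pattern (the names here are illustrative, not taken from the patch):

package main

import "fmt"

func main() {
	sem := make(chan struct{}, 1)
	sem <- struct{}{} // seed one token: the "unlocked" state

	var size uint64

	<-sem             // acquire: take the only token; other goroutines block here
	size += 100       // critical section, standing in for Put's size accounting
	sem <- struct{}{} // release: hand the token back

	fmt.Println(size)
}

NewPeppleStorage performs the seeding send (cs.sizeChan <- struct{}{}) once at startup, which is why Put can receive before its first send.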
func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte) error { length := uint64(len(contentId)) + uint64(len(content)) - c.sizeChan <- length + <-c.sizeChan + c.size += length + if c.size > c.storageCapacityInBytes { + err := c.prune() + if err != nil { + c.sizeChan <- struct{}{} + return err + } + } + c.sizeChan <- struct{}{} distance := xor(contentId, c.nodeId[:]) return c.db.Set(distance, content, c.writeOptions) } @@ -200,52 +207,36 @@ func (c *ContentStorage) Radius() *uint256.Int { func (c *ContentStorage) saveCapacity() { ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() - sizeChanged := false + capacityChanged := false + var currentCapacity uint64 = 0 buf := make([]byte, 8) // uint64 for { select { case <-ticker.C: - if sizeChanged { - binary.BigEndian.PutUint64(buf, c.size) + if capacityChanged { + binary.BigEndian.PutUint64(buf, currentCapacity) err := c.db.Set(storage.SizeKey, buf, c.writeOptions) if err != nil { c.log.Error("save capacity failed", "error", err) } - sizeChanged = false - } - case size := <-c.sizeChan: - c.log.Debug("reveice size %v", size) - c.sizeMutex.Lock() - c.size += size - c.sizeMutex.Unlock() - sizeChanged = true - if c.size > c.storageCapacityInBytes { - if !c.isPruning { - c.isPruning = true - go c.prune() - } + capacityChanged = false } - case prunedSize := <-c.pruneDoneChan: - c.isPruning = false - c.size -= prunedSize - sizeChanged = true + case capacity := <-c.capacityChan: + capacityChanged = true + currentCapacity = capacity } } } -func (c *ContentStorage) prune() { +func (c *ContentStorage) prune() error { expectSize := uint64(float64(c.storageCapacityInBytes) * contentDeletionFraction) var curentSize uint64 = 0 - defer func() { - c.pruneDoneChan <- curentSize - }() // get the keys to be deleted order by distance desc iter, err := c.db.NewIter(nil) if err != nil { - c.log.Error("get iter failed", "error", err) - return + return err } batch := c.db.NewBatch() @@ -259,7 +250,7 @@ func (c *ContentStorage) prune() { dis := uint256.NewInt(0) err = dis.UnmarshalSSZ(distance) if err != nil { - c.log.Error("unmarshal distance failed", "error", err) + return err } c.radius.Store(dis) break @@ -267,9 +258,9 @@ func (c *ContentStorage) prune() { } err = batch.Commit(&pebble.WriteOptions{Sync: true}) if err != nil { - c.log.Error("prune batch commit failed", "error", err) - return + return err } + return nil } func xor(contentId, nodeId []byte) []byte { diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go index 0154f07dfa07..44e3cd301dac 100644 --- a/portalnetwork/storage/ethpepple/storage_test.go +++ b/portalnetwork/storage/ethpepple/storage_test.go @@ -117,12 +117,12 @@ func TestPrune(t *testing.T) { { contentKey: uint256.NewInt(3).Bytes32(), content: genBytes(20_000), - shouldPrune: false, + shouldPrune: true, }, { contentKey: uint256.NewInt(4).Bytes32(), content: genBytes(20_000), - shouldPrune: false, + shouldPrune: true, }, { contentKey: uint256.NewInt(5).Bytes32(), @@ -132,12 +132,12 @@ func TestPrune(t *testing.T) { { contentKey: uint256.NewInt(6).Bytes32(), content: genBytes(20_000), - shouldPrune: true, + shouldPrune: false, }, { contentKey: uint256.NewInt(7).Bytes32(), content: genBytes(20_000), - shouldPrune: true, + shouldPrune: false, }, } From 70e712b462cff90cd6c33b4fd3560e18e8847a45 Mon Sep 17 00:00:00 2001 From: fearlseefe <505380967@qq.com> Date: Wed, 27 Nov 2024 13:30:34 +0800 Subject: [PATCH 11/13] feat: add test for prune --- 
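Note: with this patch, Put rejects content whose XOR distance falls outside the node's current radius (storage.ErrInsufficientRadius) instead of accepting everything and pruning afterwards. A hypothetical condensed view of the resulting write path, assuming the uint256 and storage packages used in this series (function and parameter names here are illustrative):

func putSketch(distance, radius *uint256.Int, newSize, capacity uint64) error {
	// Reject before writing anything: content outside the radius is never stored.
	if !radius.Gt(distance) {
		return storage.ErrInsufficientRadius
	}
	// (the real Put then writes the row keyed by distance and takes the size token)
	if newSize > capacity {
		// over capacity: the real Put prunes synchronously before releasing the token
	}
	return nil
}
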
portalnetwork/storage/content_storage.go | 1 + portalnetwork/storage/ethpepple/storage.go | 28 +++++++++- .../storage/ethpepple/storage_test.go | 53 ++++++++++--------- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/portalnetwork/storage/content_storage.go b/portalnetwork/storage/content_storage.go index e726d612e54e..a5d0d2b26d84 100644 --- a/portalnetwork/storage/content_storage.go +++ b/portalnetwork/storage/content_storage.go @@ -7,6 +7,7 @@ import ( ) var ErrContentNotFound = fmt.Errorf("content not found") +var ErrInsufficientRadius = fmt.Errorf("insufficient radius") var MaxDistance = uint256.MustFromHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff") diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go index 897918bbbb94..922e1a9e8e83 100644 --- a/portalnetwork/storage/ethpepple/storage.go +++ b/portalnetwork/storage/ethpepple/storage.go @@ -182,6 +182,20 @@ func (c *ContentStorage) Get(contentKey []byte, contentId []byte) ([]byte, error // Put implements storage.ContentStorage. func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte) error { + distance := xor(contentId, c.nodeId[:]) + valid, err := c.inRadius(distance) + if err != nil { + return err + } + if !valid { + return storage.ErrInsufficientRadius + } + + err = c.db.Set(distance, content, c.writeOptions) + if err != nil { + return err + } + length := uint64(len(contentId)) + uint64(len(content)) <-c.sizeChan c.size += length @@ -193,8 +207,7 @@ func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte } } c.sizeChan <- struct{}{} - distance := xor(contentId, c.nodeId[:]) - return c.db.Set(distance, content, c.writeOptions) + return nil } // Radius implements storage.ContentStorage. 
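Because rows are keyed by xor(contentId, nodeId), pebble's sorted key space doubles as a distance index; prune (its tail is the context of the next hunk) walks it backwards so the farthest content is deleted first and the radius can be read off the first key it keeps. A sketch of that traversal, assuming the cockroachdb/pebble API this series uses (the helper name is illustrative):

func walkFarthestFirst(db *pebble.DB) error {
	iter, err := db.NewIter(nil)
	if err != nil {
		return err
	}
	defer iter.Close()
	// Keys are XOR distances, so starting at the last key and moving backwards
	// visits the farthest-stored content first — the deletion order prune wants.
	for iter.Last(); iter.Valid(); iter.Prev() {
		_ = iter.Key()
	}
	return iter.Error()
}
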
@@ -263,6 +276,17 @@ func (c *ContentStorage) prune() error {
 	return nil
 }
 
+func (c *ContentStorage) inRadius(distance []byte) (bool, error) {
+	dis := uint256.NewInt(0)
+	err := dis.UnmarshalSSZ(distance)
+	if err != nil {
+		return false, err
+	}
+	val := c.radius.Load()
+	radius := val.(*uint256.Int)
+	return radius.Gt(dis), nil
+}
+
 func xor(contentId, nodeId []byte) []byte {
 	// length of contentId maybe not 32bytes
 	padding := make([]byte, 32)
diff --git a/portalnetwork/storage/ethpepple/storage_test.go b/portalnetwork/storage/ethpepple/storage_test.go
index 44e3cd301dac..9aaca6f6d0df 100644
--- a/portalnetwork/storage/ethpepple/storage_test.go
+++ b/portalnetwork/storage/ethpepple/storage_test.go
@@ -2,7 +2,6 @@ package ethpepple
 
 import (
 	"testing"
-	"time"
 
 	"github.com/ethereum/go-ethereum/portalnetwork/storage"
 	"github.com/holiman/uint256"
 	"github.com/stretchr/testify/assert"
 )
@@ -102,61 +101,63 @@ func TestPrune(t *testing.T) {
 	testcases := []struct {
 		contentKey  [32]byte
 		content     []byte
-		shouldPrune bool
+		outOfRadius bool
+		err         error
 	}{
 		{
-			contentKey:  uint256.NewInt(1).Bytes32(),
-			content:     genBytes(900_000),
-			shouldPrune: false,
+			contentKey: uint256.NewInt(1).Bytes32(),
+			content:    genBytes(900_000),
 		},
 		{
-			contentKey:  uint256.NewInt(2).Bytes32(),
-			content:     genBytes(40_000),
-			shouldPrune: false,
+			contentKey: uint256.NewInt(2).Bytes32(),
+			content:    genBytes(40_000),
 		},
 		{
-			contentKey:  uint256.NewInt(3).Bytes32(),
-			content:     genBytes(20_000),
-			shouldPrune: true,
+			contentKey: uint256.NewInt(3).Bytes32(),
+			content:    genBytes(20_000),
+			err:        storage.ErrContentNotFound,
 		},
 		{
-			contentKey:  uint256.NewInt(4).Bytes32(),
-			content:     genBytes(20_000),
-			shouldPrune: true,
+			contentKey: uint256.NewInt(4).Bytes32(),
+			content:    genBytes(20_000),
+			err:        storage.ErrContentNotFound,
 		},
 		{
-			contentKey:  uint256.NewInt(5).Bytes32(),
-			content:     genBytes(20_000),
-			shouldPrune: true,
+			contentKey: uint256.NewInt(5).Bytes32(),
+			content:    genBytes(20_000),
+			err:        storage.ErrContentNotFound,
 		},
 		{
 			contentKey:  uint256.NewInt(6).Bytes32(),
 			content:     genBytes(20_000),
-			shouldPrune: false,
+			err:         storage.ErrInsufficientRadius,
+			outOfRadius: true,
 		},
 		{
 			contentKey:  uint256.NewInt(7).Bytes32(),
 			content:     genBytes(20_000),
-			shouldPrune: false,
+			err:         storage.ErrInsufficientRadius,
+			outOfRadius: true,
 		},
 	}
 
 	for _, val := range testcases {
-		db.Put(val.contentKey[:], val.contentKey[:], val.content)
+		err := db.Put(val.contentKey[:], val.contentKey[:], val.content)
+		if err != nil {
+			assert.Equal(t, val.err, err)
+		}
 	}
-	// // wait to prune done
-	time.Sleep(2 * time.Second)
 	for _, val := range testcases {
 		content, err := db.Get(val.contentKey[:], val.contentKey[:])
-		if !val.shouldPrune {
+		if err == nil {
 			assert.Equal(t, val.content, content)
-		} else {
-			assert.Error(t, err)
+		} else if !val.outOfRadius {
+			assert.Equal(t, val.err, err)
 		}
 	}
 	radius := db.Radius()
 	data, err := radius.MarshalSSZ()
 	assert.NoError(t, err)
-	actual := uint256.NewInt(4).Bytes32()
+	actual := uint256.NewInt(2).Bytes32()
 	assert.Equal(t, data, actual[:])
 }

From 6bae115934bc0636fd105580d85fb28db7837f0b Mon Sep 17 00:00:00 2001
From: fearlseefe <505380967@qq.com>
Date: Wed, 27 Nov 2024 13:50:14 +0800
Subject: [PATCH 12/13] feat: change size after deleting data

---
 portalnetwork/storage/ethpepple/storage.go | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go
index 922e1a9e8e83..ac1f662412b2 100644
--- a/portalnetwork/storage/ethpepple/storage.go
+++ 
b/portalnetwork/storage/ethpepple/storage.go @@ -205,6 +205,8 @@ func (c *ContentStorage) Put(contentKey []byte, contentId []byte, content []byte c.sizeChan <- struct{}{} return err } + } else { + c.capacityChan <- c.size } c.sizeChan <- struct{}{} return nil @@ -273,6 +275,8 @@ func (c *ContentStorage) prune() error { if err != nil { return err } + c.size -= curentSize + c.capacityChan <- c.size - curentSize return nil } From df48088bcf3b577da5a7cd67f2adbca91e5641b7 Mon Sep 17 00:00:00 2001 From: fearlseefe <505380967@qq.com> Date: Wed, 27 Nov 2024 14:21:58 +0800 Subject: [PATCH 13/13] feat: delete radius in storage --- portalnetwork/storage/content_storage.go | 1 - portalnetwork/storage/ethpepple/storage.go | 31 +++++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/portalnetwork/storage/content_storage.go b/portalnetwork/storage/content_storage.go index a5d0d2b26d84..075b4a6bf4b6 100644 --- a/portalnetwork/storage/content_storage.go +++ b/portalnetwork/storage/content_storage.go @@ -13,7 +13,6 @@ var MaxDistance = uint256.MustFromHex("0xfffffffffffffffffffffffffffffffffffffff type ContentType byte -var RadisuKey = []byte("radius") var SizeKey = []byte("size") type ContentKey struct { diff --git a/portalnetwork/storage/ethpepple/storage.go b/portalnetwork/storage/ethpepple/storage.go index ac1f662412b2..ae7d269aba1a 100644 --- a/portalnetwork/storage/ethpepple/storage.go +++ b/portalnetwork/storage/ethpepple/storage.go @@ -1,6 +1,7 @@ package ethpepple import ( + "bytes" "encoding/binary" "runtime" "sync/atomic" @@ -139,28 +140,32 @@ func NewPeppleStorage(config PeppleStorageConfig) (storage.ContentStorage, error writeOptions: &pebble.WriteOptions{Sync: false}, } cs.radius.Store(storage.MaxDistance) - radius, _, err := cs.db.Get(storage.RadisuKey) + + val, _, err := cs.db.Get(storage.SizeKey) if err != nil && err != pebble.ErrNotFound { return nil, err } if err == nil { + size := binary.BigEndian.Uint64(val) + // init stage, no need to use lock + cs.size = size + } + + iter, err := cs.db.NewIter(nil) + if err != nil { + return nil, err + } + defer iter.Close() + if iter.Last() && iter.Valid() { + distance := iter.Key() dis := uint256.NewInt(0) - err = dis.UnmarshalSSZ(radius) + err = dis.UnmarshalSSZ(distance) if err != nil { return nil, err } cs.radius.Store(dis) } - val, _, err := cs.db.Get(storage.SizeKey) - if err != nil && err != pebble.ErrNotFound { - return nil, err - } - if err == nil { - size := binary.BigEndian.Uint64(val) - // init stage, no need to use lock - cs.size = size - } cs.sizeChan <- struct{}{} go cs.saveCapacity() return cs, nil @@ -256,12 +261,14 @@ func (c *ContentStorage) prune() error { batch := c.db.NewBatch() for iter.Last(); iter.Valid(); iter.Prev() { + if bytes.Equal(iter.Key(), storage.SizeKey) { + continue + } if curentSize < expectSize { batch.Delete(iter.Key(), nil) curentSize += uint64(len(iter.Key())) + uint64(len(iter.Value())) } else { distance := iter.Key() - c.db.Set(storage.RadisuKey, distance, c.writeOptions) dis := uint256.NewInt(0) err = dis.UnmarshalSSZ(distance) if err != nil {