From 28c44fb54c3d943eec5ed410d7a3adedebdcc19f Mon Sep 17 00:00:00 2001 From: Steph Samson Date: Sat, 19 Oct 2024 01:49:36 +0900 Subject: [PATCH] feat: persist protocols sets and fix hanging context --- cmd/honeypot/main.go | 34 ++++++++--- db/client_db.go | 18 +++--- ...9_add_protocols_to_requests_table.down.sql | 2 + ...019_add_protocols_to_requests_table.up.sql | 7 +++ ...20_alter_insert_requests_function.down.sql | 5 ++ ...0020_alter_insert_requests_function.up.sql | 59 +++++++++++++++++++ db/models/requests.go | 12 +++- queen.go | 52 +++++++++++----- 8 files changed, 155 insertions(+), 34 deletions(-) create mode 100644 db/migrations/000019_add_protocols_to_requests_table.down.sql create mode 100644 db/migrations/000019_add_protocols_to_requests_table.up.sql create mode 100644 db/migrations/000020_alter_insert_requests_function.down.sql create mode 100644 db/migrations/000020_alter_insert_requests_function.up.sql diff --git a/cmd/honeypot/main.go b/cmd/honeypot/main.go index 644ef7e..6037ff6 100644 --- a/cmd/honeypot/main.go +++ b/cmd/honeypot/main.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "syscall" + "time" logging "github.com/ipfs/go-log/v2" "github.com/probe-lab/ants-watch" @@ -25,6 +26,7 @@ func main() { flag.Parse() ctx, cancel := context.WithCancel(context.Background()) + defer cancel() sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) @@ -40,14 +42,32 @@ func main() { panic(err) } - go queen.Run(ctx) - + errChan := make(chan error, 1) go func() { - sig := <-sigChan - logger.Infof("Received signal: %s, shutting down...", sig) - cancel() + logger.Debugln("Starting Queen.Run") + errChan <- queen.Run(ctx) + logger.Debugln("Queen.Run completed") }() - <-ctx.Done() - logger.Info("Context canceled, queen stopped") + select { + case err := <-errChan: + if err != nil { + logger.Errorf("Queen.Run returned an error: %v", err) + } else { + logger.Debugln("Queen.Run completed successfully") + } + case sig := <-sigChan: + logger.Infof("Received signal: %v, initiating shutdown...", sig) + } + + cancel() + + select { + case <-errChan: + logger.Debugln("Queen.Run stopped after context cancellation") + case <-time.After(30 * time.Second): + logger.Warnln("Timeout waiting for Queen.Run to stop") + } + + logger.Debugln("Work is done") } diff --git a/db/client_db.go b/db/client_db.go index 972dff9..e6e173b 100644 --- a/db/client_db.go +++ b/db/client_db.go @@ -595,12 +595,12 @@ func BulkInsertRequests(ctx context.Context, db *sql.DB, requests []models.Reque i := 1 for _, request := range requests { - valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5, i+6)) - valueArgs = append(valueArgs, request.RequestStartedAt, request.RequestType, request.AntMultihash, request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, request.AgentVersion) - i += 7 + valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5, i+6, i+7)) + valueArgs = append(valueArgs, request.RequestStartedAt, request.RequestType, request.AntMultihash, request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, request.AgentVersion, request.Protocols) + i += 8 } - stmt := fmt.Sprintf("INSERT INTO requests_denormalized (request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version) VALUES %s RETURNING id;", + stmt := fmt.Sprintf("INSERT INTO requests_denormalized (request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version, protocols) VALUES %s RETURNING id;", strings.Join(valueStrings, ", ")) rows, err := queries.Raw(stmt, valueArgs...).QueryContext(ctx, db) @@ -613,7 +613,7 @@ func BulkInsertRequests(ctx context.Context, db *sql.DB, requests []models.Reque } func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) error { - rows, err := db.Query("SELECT id, request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version FROM requests_denormalized WHERE normalized_at IS NULL") + rows, err := db.QueryContext(ctx, "SELECT id, request_started_at, request_type, ant_multihash, peer_multihash, key_multihash, multi_addresses, agent_version, protocols FROM requests_denormalized WHERE normalized_at IS NULL") if err != nil { return err } @@ -621,7 +621,7 @@ func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) erro for rows.Next() { var request models.RequestsDenormalized - if err := rows.Scan(&request.ID, &request.RequestStartedAt, &request.RequestType, &request.AntMultihash, &request.PeerMultihash, &request.KeyMultihash, &request.MultiAddresses, &request.AgentVersion); err != nil { + if err := rows.Scan(&request.ID, &request.RequestStartedAt, &request.RequestType, &request.AntMultihash, &request.PeerMultihash, &request.KeyMultihash, &request.MultiAddresses, &request.AgentVersion, &request.Protocols); err != nil { return err } @@ -633,14 +633,14 @@ func NormalizeRequests(ctx context.Context, db *sql.DB, dbClient *DBClient) erro request.PeerMultihash, request.KeyMultihash, request.MultiAddresses, - request.AgentVersion, // agent versions - nil, // protocol sets + request.AgentVersion, + request.Protocols, ) if err != nil { return fmt.Errorf("failed to normalize request ID %d: %w, timestamp: %v", request.ID, err, request.RequestStartedAt) } - _, err = db.Exec("UPDATE requests_denormalized SET normalized_at = NOW() WHERE id = $1", request.ID) + _, err = db.ExecContext(ctx, "UPDATE requests_denormalized SET normalized_at = NOW() WHERE id = $1", request.ID) if err != nil { return fmt.Errorf("failed to update normalized_at for request ID %d: %w", request.ID, err) } diff --git a/db/migrations/000019_add_protocols_to_requests_table.down.sql b/db/migrations/000019_add_protocols_to_requests_table.down.sql new file mode 100644 index 0000000..708341f --- /dev/null +++ b/db/migrations/000019_add_protocols_to_requests_table.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE requests + DROP COLUMN IF EXISTS protocols_set_id; diff --git a/db/migrations/000019_add_protocols_to_requests_table.up.sql b/db/migrations/000019_add_protocols_to_requests_table.up.sql new file mode 100644 index 0000000..1b9ec24 --- /dev/null +++ b/db/migrations/000019_add_protocols_to_requests_table.up.sql @@ -0,0 +1,7 @@ +BEGIN; + +ALTER TABLE requests + ADD COLUMN protocols_set_id INT, + ADD CONSTRAINT fk_requests_protocols_set_id FOREIGN KEY (protocols_set_id) REFERENCES protocols_sets (id) ON DELETE SET NULL; + +COMMIT; diff --git a/db/migrations/000020_alter_insert_requests_function.down.sql b/db/migrations/000020_alter_insert_requests_function.down.sql new file mode 100644 index 0000000..6bd8cfe --- /dev/null +++ b/db/migrations/000020_alter_insert_requests_function.down.sql @@ -0,0 +1,5 @@ +BEGIN; + +DROP FUNCTION IF EXISTS insert_request; + +COMMIT; diff --git a/db/migrations/000020_alter_insert_requests_function.up.sql b/db/migrations/000020_alter_insert_requests_function.up.sql new file mode 100644 index 0000000..d807cf8 --- /dev/null +++ b/db/migrations/000020_alter_insert_requests_function.up.sql @@ -0,0 +1,59 @@ +BEGIN; + +CREATE OR REPLACE FUNCTION insert_request( + new_timestamp TIMESTAMPTZ, + new_request_type message_type, + new_ant TEXT, + new_multi_hash TEXT, + new_key_multi_hash TEXT, + new_multi_addresses TEXT[], + new_agent_version_id INT, + new_protocols_set_id INT +) RETURNS RECORD AS +$insert_request$ +DECLARE + new_multi_addresses_ids INT[]; + new_request_id INT; + new_peer_id INT; + new_ant_id INT; + new_key_id INT; +BEGIN + SELECT upsert_peer( + new_multi_hash, + new_agent_version_id, + new_protocols_set_id, + new_timestamp + ) INTO new_peer_id; + + SELECT id INTO new_ant_id + FROM peers + WHERE multi_hash = new_ant; + + SELECT insert_key(new_key_multi_hash) INTO new_key_id; + + SELECT array_agg(id) FROM upsert_multi_addresses(new_multi_addresses) INTO new_multi_addresses_ids; + + DELETE + FROM peers_x_multi_addresses pxma + WHERE peer_id = new_peer_id; + + INSERT INTO peers_x_multi_addresses (peer_id, multi_address_id) + SELECT new_peer_id, new_multi_address_id + FROM unnest(new_multi_addresses_ids) new_multi_address_id + ON CONFLICT DO NOTHING; + + INSERT INTO requests (timestamp, request_type, ant_id, peer_id, key_id, multi_address_ids, protocols_set_id) + SELECT new_timestamp, + new_request_type, + new_ant_id, + new_peer_id, + new_key_id, + new_multi_addresses_ids, + new_protocols_set_id + RETURNING id INTO new_request_id; + + RETURN ROW(new_peer_id, new_request_id, new_key_id); +END; +$insert_request$ LANGUAGE plpgsql; + +COMMIT; diff --git a/db/models/requests.go b/db/models/requests.go index af9b797..f5f9128 100644 --- a/db/models/requests.go +++ b/db/models/requests.go @@ -14,6 +14,7 @@ import ( "time" "github.com/friendsofgo/errors" + "github.com/volatiletech/null/v8" "github.com/volatiletech/sqlboiler/v4/boil" "github.com/volatiletech/sqlboiler/v4/queries" "github.com/volatiletech/sqlboiler/v4/queries/qm" @@ -31,6 +32,7 @@ type Request struct { PeerID int64 `boil:"peer_id" json:"peer_id" toml:"peer_id" yaml:"peer_id"` KeyID int `boil:"key_id" json:"key_id" toml:"key_id" yaml:"key_id"` MultiAddressIds types.Int64Array `boil:"multi_address_ids" json:"multi_address_ids,omitempty" toml:"multi_address_ids" yaml:"multi_address_ids,omitempty"` + ProtocolsSetID null.Int `boil:"protocols_set_id" json:"protocols_set_id,omitempty" toml:"protocols_set_id" yaml:"protocols_set_id,omitempty"` R *requestR `boil:"-" json:"-" toml:"-" yaml:"-"` L requestL `boil:"-" json:"-" toml:"-" yaml:"-"` @@ -44,6 +46,7 @@ var RequestColumns = struct { PeerID string KeyID string MultiAddressIds string + ProtocolsSetID string }{ ID: "id", Timestamp: "timestamp", @@ -52,6 +55,7 @@ var RequestColumns = struct { PeerID: "peer_id", KeyID: "key_id", MultiAddressIds: "multi_address_ids", + ProtocolsSetID: "protocols_set_id", } var RequestTableColumns = struct { @@ -62,6 +66,7 @@ var RequestTableColumns = struct { PeerID string KeyID string MultiAddressIds string + ProtocolsSetID string }{ ID: "requests.id", Timestamp: "requests.timestamp", @@ -70,6 +75,7 @@ var RequestTableColumns = struct { PeerID: "requests.peer_id", KeyID: "requests.key_id", MultiAddressIds: "requests.multi_address_ids", + ProtocolsSetID: "requests.protocols_set_id", } // Generated where @@ -85,6 +91,7 @@ var RequestWhere = struct { PeerID whereHelperint64 KeyID whereHelperint MultiAddressIds whereHelpertypes_Int64Array + ProtocolsSetID whereHelpernull_Int }{ ID: whereHelperint{field: "\"requests\".\"id\""}, Timestamp: whereHelpertime_Time{field: "\"requests\".\"timestamp\""}, @@ -93,6 +100,7 @@ var RequestWhere = struct { PeerID: whereHelperint64{field: "\"requests\".\"peer_id\""}, KeyID: whereHelperint{field: "\"requests\".\"key_id\""}, MultiAddressIds: whereHelpertypes_Int64Array{field: "\"requests\".\"multi_address_ids\""}, + ProtocolsSetID: whereHelpernull_Int{field: "\"requests\".\"protocols_set_id\""}, } // RequestRels is where relationship names are stored. @@ -112,9 +120,9 @@ func (*requestR) NewStruct() *requestR { type requestL struct{} var ( - requestAllColumns = []string{"id", "timestamp", "request_type", "ant_id", "peer_id", "key_id", "multi_address_ids"} + requestAllColumns = []string{"id", "timestamp", "request_type", "ant_id", "peer_id", "key_id", "multi_address_ids", "protocols_set_id"} requestColumnsWithoutDefault = []string{"timestamp", "request_type", "ant_id", "peer_id", "key_id"} - requestColumnsWithDefault = []string{"id", "multi_address_ids"} + requestColumnsWithDefault = []string{"id", "multi_address_ids", "protocols_set_id"} requestPrimaryKeyColumns = []string{"id", "timestamp"} requestGeneratedColumns = []string{"id"} ) diff --git a/queen.go b/queen.go index 8330ed7..aef3d9e 100644 --- a/queen.go +++ b/queen.go @@ -16,6 +16,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" "github.com/probe-lab/go-libdht/kad" "github.com/probe-lab/go-libdht/kad/key" @@ -83,7 +84,7 @@ func NewQueen(ctx context.Context, dbConnString string, keysDbPath string, nPort mmc: mmc, uclient: getUdgerClient(), resolveBatchSize: getBatchSize(), - resolveBatchTime: getBatchSize(), + resolveBatchTime: getBatchTime(), } if nPorts != 0 { @@ -182,7 +183,10 @@ func (q *Queen) freePort(port uint16) { } } -func (q *Queen) Run(ctx context.Context) { +func (q *Queen) Run(ctx context.Context) error { + logger.Debugln("Queen.Run started") + defer logger.Debugln("Queen.Run completing") + go q.consumeAntsLogs(ctx) crawlTime := time.NewTicker(CRAWL_INTERVAL) @@ -195,14 +199,17 @@ func (q *Queen) Run(ctx context.Context) { for { select { + case <-ctx.Done(): + logger.Debugln("Queen.Run done..") + q.persistLiveAntsKeys() + return ctx.Err() case <-crawlTime.C: q.routine(ctx) case <-normalizationTime.C: go q.normalizeRequests(ctx) - // time.Sleep(10 * time.Second) - case <-ctx.Done(): - q.persistLiveAntsKeys() - return + default: + // busy-loop guard + time.Sleep(100 * time.Millisecond) } } } @@ -215,6 +222,18 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) { for { select { + + case <-ctx.Done(): + logger.Debugln("Gracefully shutting down ants...") + logger.Debugln("Number of requests remaining to be inserted:", len(requests)) + if len(requests) > 0 { + err := db.BulkInsertRequests(context.Background(), q.dbc.Handler, requests) + if err != nil { + logger.Fatalf("Error inserting remaining requests: %v", err) + } + } + return + case log := <-q.antsLogs: reqType := kadpb.Message_MessageType(log.Type).String() maddrs := q.peerstore.Addrs(log.Requester) @@ -225,7 +244,9 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) { } else { agent = peerstoreAgent.(string) } - // protocols, _ := q.peerstore.GetProtocols(log.Requester) + + protocols, _ := q.peerstore.GetProtocols(log.Requester) + protocolsAsStr := protocol.ConvertToStrings(protocols) request := models.RequestsDenormalized{ RequestStartedAt: log.Timestamp, @@ -235,12 +256,13 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) { KeyMultihash: log.Target.B58String(), MultiAddresses: db.MaddrsToAddrs(maddrs), AgentVersion: null.StringFrom(agent), + Protocols: protocolsAsStr, } requests = append(requests, request) if len(requests) >= q.resolveBatchSize { err = db.BulkInsertRequests(ctx, q.dbc.Handler, requests) if err != nil { - logger.Fatalf("Error inserting requests: %v", err) + logger.Errorf("Error inserting requests: %v", err) } requests = requests[:0] } @@ -254,16 +276,12 @@ func (q *Queen) consumeAntsLogs(ctx context.Context) { requests = requests[:0] } - case <-ctx.Done(): - if len(requests) > 0 { - err := db.BulkInsertRequests(ctx, q.dbc.Handler, requests) - if err != nil { - logger.Fatalf("Error inserting remaining requests: %v", err) - } - } - return + default: + // against busy-looping since <-q.antsLogs is a busy chan + time.Sleep(10 * time.Millisecond) } } + } func (q *Queen) normalizeRequests(ctx context.Context) { @@ -281,11 +299,13 @@ func (q *Queen) normalizeRequests(ctx context.Context) { } func (q *Queen) persistLiveAntsKeys() { + logger.Debugln("Persisting live ants keys") antsKeys := make([]crypto.PrivKey, 0, len(q.ants)) for _, ant := range q.ants { antsKeys = append(antsKeys, ant.Host.Peerstore().PrivKey(ant.Host.ID())) } q.keysDB.MatchingKeys(nil, antsKeys) + logger.Debugf("Number of antsKeys persisted: %d", len(antsKeys)) } func (q *Queen) routine(ctx context.Context) {