From 80f503ab6a5e2e29ff87b176edd67ee35914a277 Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Mon, 13 Jan 2025 17:44:02 -0500 Subject: [PATCH] Optimize vulnerability host counts (#24914) --- changes/22364-vuln-cron | 1 + cmd/fleet/cron.go | 4 +- server/config/config.go | 7 + server/datastore/mysql/vulnerabilities.go | 251 +++++++++++++----- .../datastore/mysql/vulnerabilities_test.go | 20 +- server/fleet/datastore.go | 5 +- server/mock/datastore_mock.go | 6 +- server/service/integration_core_test.go | 4 +- server/service/integration_enterprise_test.go | 4 +- 9 files changed, 216 insertions(+), 86 deletions(-) create mode 100644 changes/22364-vuln-cron diff --git a/changes/22364-vuln-cron b/changes/22364-vuln-cron new file mode 100644 index 000000000000..a63ff4b1e2c9 --- /dev/null +++ b/changes/22364-vuln-cron @@ -0,0 +1 @@ +* fixed issue where the vulnerabilities cron was failing in large environments due to large SQL queries \ No newline at end of file diff --git a/cmd/fleet/cron.go b/cmd/fleet/cron.go index fe5027d511d2..be7a4ab7744e 100644 --- a/cmd/fleet/cron.go +++ b/cmd/fleet/cron.go @@ -102,10 +102,12 @@ func cronVulnerabilities( return fmt.Errorf("scanning vulnerabilities: %w", err) } + start := time.Now() level.Info(logger).Log("msg", "updating vulnerability host counts") - if err := ds.UpdateVulnerabilityHostCounts(ctx); err != nil { + if err := ds.UpdateVulnerabilityHostCounts(ctx, config.MaxConcurrency); err != nil { return fmt.Errorf("updating vulnerability host counts: %w", err) } + level.Info(logger).Log("msg", "vulnerability host counts updated", "took", time.Since(start).Seconds()) } return nil diff --git a/server/config/config.go b/server/config/config.go index 9a523d70d950..42cd7e1faab3 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -518,6 +518,7 @@ type VulnerabilitiesConfig struct { DisableDataSync bool `json:"disable_data_sync" yaml:"disable_data_sync"` RecentVulnerabilityMaxAge time.Duration `json:"recent_vulnerability_max_age" yaml:"recent_vulnerability_max_age"` DisableWinOSVulnerabilities bool `json:"disable_win_os_vulnerabilities" yaml:"disable_win_os_vulnerabilities"` + MaxConcurrency int `json:"max_concurrency" yaml:"max_concurrency"` } // UpgradesConfig defines configs related to fleet server upgrades. @@ -1309,6 +1310,11 @@ func (man Manager) addConfigs() { false, "Don't sync installed Windows updates nor perform Windows OS vulnerability processing.", ) + man.addConfigInt( + "vulnerabilities.max_concurrency", + 5, + "Maximum number of concurrent database queries to use for processing vulnerabilities.", + ) // Upgrades man.addConfigBool("upgrades.allow_missing_migrations", false, @@ -1580,6 +1586,7 @@ func (man Manager) LoadConfig() FleetConfig { DisableDataSync: man.getConfigBool("vulnerabilities.disable_data_sync"), RecentVulnerabilityMaxAge: man.getConfigDuration("vulnerabilities.recent_vulnerability_max_age"), DisableWinOSVulnerabilities: man.getConfigBool("vulnerabilities.disable_win_os_vulnerabilities"), + MaxConcurrency: man.getConfigInt("vulnerabilities.max_concurrency"), }, Upgrades: UpgradesConfig{ AllowMissingMigrations: man.getConfigBool("upgrades.allow_missing_migrations"), diff --git a/server/datastore/mysql/vulnerabilities.go b/server/datastore/mysql/vulnerabilities.go index 96cefc19b571..5bfcb572d673 100644 --- a/server/datastore/mysql/vulnerabilities.go +++ b/server/datastore/mysql/vulnerabilities.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "sync" "time" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" @@ -342,30 +343,193 @@ func (ds *Datastore) CountVulnerabilities(ctx context.Context, opt fleet.VulnLis return count, nil } -func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error { +func (ds *Datastore) distinctCVEs(ctx context.Context) ([]string, error) { + uniqueCVEQuery := ` + SELECT DISTINCT cve FROM ( + SELECT cve FROM software_cve + UNION + SELECT cve FROM operating_system_vulnerabilities + ) AS combined_cves; + ` + + var cves []string + err := sqlx.SelectContext(ctx, ds.reader(ctx), &cves, uniqueCVEQuery) + if err != nil { + return nil, err + } + return cves, nil +} + +type CountScope int + +const ( + GlobalCount CountScope = iota + NoTeamCount + TeamCount +) + +func (ds *Datastore) batchFetchVulnerabilityCounts( + ctx context.Context, + scope CountScope, + maxRoutines int, +) ([]hostCount, error) { + const ( + batchSize = 20 + ) + + // Fetch distinct CVEs + allCVEs, err := ds.distinctCVEs(ctx) + if err != nil { + return nil, err + } + + query := getVulnHostCountQuery(scope) + if query == "" { + return nil, ctxerr.Errorf(ctx, "invalid scope: %d", scope) + } + + var ( + hostCounts []hostCount + mu sync.Mutex + wg sync.WaitGroup + sem = make(chan struct{}, maxRoutines) + errChan = make(chan error, len(allCVEs)/batchSize+1) + ) + + // Process CVEs in batches concurrently + for i := 0; i < len(allCVEs); i += batchSize { + end := i + batchSize + if end > len(allCVEs) { + end = len(allCVEs) + } + + batchCVEs := allCVEs[i:end] + wg.Add(1) + sem <- struct{}{} // Acquire semaphore + + go func(cves []string) { + defer wg.Done() + defer func() { <-sem }() // Release semaphore + + counts, err := ds.fetchBatchCounts(ctx, cves, query) + if err != nil { + errChan <- err + return + } + + mu.Lock() + hostCounts = append(hostCounts, counts...) + mu.Unlock() + }(batchCVEs) + } + + wg.Wait() + close(errChan) + + // Check for errors + for err := range errChan { + if err != nil { + return nil, err + } + } + + return hostCounts, nil +} + +// fetchBatchCounts executes the query for a batch of CVEs. +func (ds *Datastore) fetchBatchCounts( + ctx context.Context, + batchCVEs []string, + scopeConfig string, +) ([]hostCount, error) { + query, args, err := sqlx.In(scopeConfig, batchCVEs, batchCVEs) + if err != nil { + return nil, err + } + + var counts []hostCount + err = sqlx.SelectContext(ctx, ds.reader(ctx), &counts, query, args...) + if err != nil { + return nil, err + } + + return counts, nil +} + +// getScopeConfig returns the query configuration for the given scope. +func getVulnHostCountQuery(scope CountScope) string { + switch scope { + case GlobalCount: + return ` + SELECT 0 as team_id, 1 as global_stats, combined_results.cve, COUNT(*) AS host_count + FROM ( + SELECT sc.cve, hs.host_id + FROM software_cve sc + INNER JOIN host_software hs ON sc.software_id = hs.software_id + WHERE sc.cve IN (?) + + UNION + + SELECT osv.cve, hos.host_id + FROM operating_system_vulnerabilities osv + INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id + WHERE osv.cve IN (?) + ) AS combined_results + GROUP BY cve + ` + case NoTeamCount: + return ` + SELECT 0 as team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count + FROM ( + SELECT sc.cve, hs.host_id + FROM software_cve sc + INNER JOIN host_software hs ON sc.software_id = hs.software_id + WHERE sc.cve IN (?) + + UNION + + SELECT osv.cve, hos.host_id + FROM operating_system_vulnerabilities osv + INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id + WHERE osv.cve IN (?) + ) AS combined_results + INNER JOIN hosts h ON combined_results.host_id = h.id + WHERE h.team_id IS NULL + GROUP BY cve + ` + case TeamCount: + return ` + SELECT h.team_id as team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count + FROM ( + SELECT sc.cve, hs.host_id + FROM software_cve sc + INNER JOIN host_software hs ON sc.software_id = hs.software_id + WHERE sc.cve IN (?) + + UNION + + SELECT osv.cve, hos.host_id + FROM operating_system_vulnerabilities osv + INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id + WHERE osv.cve IN (?) + ) AS combined_results + INNER JOIN hosts h ON combined_results.host_id = h.id + WHERE h.team_id IS NOT NULL + GROUP BY h.team_id, combined_results.cve + ` + default: + return "" + } +} + +func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error { // set all counts to 0 to later identify rows to delete _, err := ds.writer(ctx).ExecContext(ctx, "UPDATE vulnerability_host_counts SET host_count = 0") if err != nil { return ctxerr.Wrap(ctx, err, "initializing vulnerability host counts") } - globalSelectStmt := ` - SELECT 0 as team_id, 1 as global_stats, cve, COUNT(*) AS host_count - FROM ( - SELECT sc.cve, hs.host_id - FROM software_cve sc - INNER JOIN host_software hs ON sc.software_id = hs.software_id - - UNION - - SELECT osv.cve, hos.host_id - FROM operating_system_vulnerabilities osv - INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id - ) AS combined_results - GROUP BY cve; - ` - - globalHostCounts, err := ds.fetchHostCounts(ctx, globalSelectStmt) + globalHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, GlobalCount, maxRoutines) if err != nil { return ctxerr.Wrap(ctx, err, "fetching global vulnerability host counts") } @@ -375,25 +539,7 @@ func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error { return ctxerr.Wrap(ctx, err, "inserting global vulnerability host counts") } - teamSelectStmt := ` - SELECT h.team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count - FROM ( - SELECT hs.host_id, sc.cve - FROM software_cve sc - INNER JOIN host_software hs ON sc.software_id = hs.software_id - - UNION - - SELECT hos.host_id, osv.cve - FROM operating_system_vulnerabilities osv - INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id - ) AS combined_results - INNER JOIN hosts h ON combined_results.host_id = h.id - WHERE h.team_id IS NOT NULL - GROUP BY h.team_id, combined_results.cve - ` - - teamHostCounts, err := ds.fetchHostCounts(ctx, teamSelectStmt) + teamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, TeamCount, maxRoutines) if err != nil { return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts") } @@ -403,27 +549,9 @@ func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error { return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts") } - noTeamSelectStmt := ` - SELECT 0 as team_id, 0 as global_stats, cve, COUNT(*) AS host_count - FROM ( - SELECT hs.host_id, sc.cve - FROM software_cve sc - INNER JOIN host_software hs ON sc.software_id = hs.software_id - - UNION - - SELECT hos.host_id, osv.cve - FROM operating_system_vulnerabilities osv - INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id - ) AS combined_results - INNER JOIN hosts h ON combined_results.host_id = h.id - WHERE h.team_id IS NULL - GROUP BY cve - ` - - noTeamHostCounts, err := ds.fetchHostCounts(ctx, noTeamSelectStmt) + noTeamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, NoTeamCount, maxRoutines) if err != nil { - return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts") + return ctxerr.Wrap(ctx, err, "fetching no team vulnerability host counts") } err = ds.batchInsertHostCounts(ctx, noTeamHostCounts) @@ -455,15 +583,6 @@ func (ds *Datastore) cleanupVulnerabilityHostCounts(ctx context.Context) error { return nil } -func (ds *Datastore) fetchHostCounts(ctx context.Context, query string) ([]hostCount, error) { - var hostCounts []hostCount - err := sqlx.SelectContext(ctx, ds.reader(ctx), &hostCounts, query) - if err != nil { - return nil, err - } - return hostCounts, nil -} - func (ds *Datastore) batchInsertHostCounts(ctx context.Context, counts []hostCount) error { if len(counts) == 0 { return nil diff --git a/server/datastore/mysql/vulnerabilities_test.go b/server/datastore/mysql/vulnerabilities_test.go index 112a4db90fb5..3444059fb3c8 100644 --- a/server/datastore/mysql/vulnerabilities_test.go +++ b/server/datastore/mysql/vulnerabilities_test.go @@ -596,7 +596,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { }, fleet.MSRCSource) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) list, _, err = ds.ListVulnerabilities(context.Background(), fleet.VulnListOptions{}) @@ -611,7 +611,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { err = ds.UpdateHostOperatingSystem(context.Background(), host2.ID, windowsOS) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) list, _, err = ds.ListVulnerabilities(context.Background(), fleet.VulnListOptions{}) @@ -626,7 +626,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { err = ds.UpdateHostOperatingSystem(context.Background(), host3.ID, macOS) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // assert no new vulns @@ -641,7 +641,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { }, fleet.NVDSource) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) list, _, err = ds.ListVulnerabilities(context.Background(), fleet.VulnListOptions{}) @@ -666,7 +666,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { }, fleet.NVDSource) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) list, _, err = ds.ListVulnerabilities(context.Background(), fleet.VulnListOptions{}) @@ -684,7 +684,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { err = ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // global counts should not change @@ -721,7 +721,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { require.NoError(t, err) } - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // global counts should not change @@ -764,7 +764,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { require.NoError(t, err) } - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // no change to team1 counts @@ -796,7 +796,7 @@ func testInsertVulnerabilityCounts(t *testing.T, ds *Datastore) { _, err = ds.UpdateHostSoftware(context.Background(), host1.ID, []fleet.Software{}) require.NoError(t, err) - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // global counts reduced @@ -880,7 +880,7 @@ func testVulnerabilityHostCountBatchInserts(t *testing.T, ds *Datastore) { } // update host counts - err = ds.UpdateVulnerabilityHostCounts(context.Background()) + err = ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // assert host counts diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 58a6bc91b509..085e1cbc6a87 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -1044,8 +1044,9 @@ type Datastore interface { // CountVulnerabilities returns the number of unique vulnerabilities based on the provided // options. CountVulnerabilities(ctx context.Context, opt VulnListOptions) (uint, error) - // UpdateVulnerabilityHostCounts updates hosts counts for all vulnerabilities. - UpdateVulnerabilityHostCounts(ctx context.Context) error + // UpdateVulnerabilityHostCounts updates hosts counts for all vulnerabilities. maxRoutines signifies the number of + // goroutines to use for processing parallel database queries. + UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error // IsCVEKnownToFleet checks if the provided CVE is known to Fleet. IsCVEKnownToFleet(ctx context.Context, cve string) (bool, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 171abbf0d415..5cfc3f5aee17 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -735,7 +735,7 @@ type VulnerabilityFunc func(ctx context.Context, cve string, teamID *uint, inclu type CountVulnerabilitiesFunc func(ctx context.Context, opt fleet.VulnListOptions) (uint, error) -type UpdateVulnerabilityHostCountsFunc func(ctx context.Context) error +type UpdateVulnerabilityHostCountsFunc func(ctx context.Context, maxRoutines int) error type IsCVEKnownToFleetFunc func(ctx context.Context, cve string) (bool, error) @@ -5479,11 +5479,11 @@ func (s *DataStore) CountVulnerabilities(ctx context.Context, opt fleet.VulnList return s.CountVulnerabilitiesFunc(ctx, opt) } -func (s *DataStore) UpdateVulnerabilityHostCounts(ctx context.Context) error { +func (s *DataStore) UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error { s.mu.Lock() s.UpdateVulnerabilityHostCountsFuncInvoked = true s.mu.Unlock() - return s.UpdateVulnerabilityHostCountsFunc(ctx) + return s.UpdateVulnerabilityHostCountsFunc(ctx, maxRoutines) } func (s *DataStore) IsCVEKnownToFleet(ctx context.Context, cve string) (bool, error) { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 97a78e1d404d..c84a0de3d835 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -9113,7 +9113,7 @@ func (s *integrationTestSuite) TestListVulnerabilities() { }) require.NoError(t, err) - err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + err = s.ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) // test list @@ -9220,7 +9220,7 @@ func (s *integrationTestSuite) TestListVulnerabilities() { err = s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{host.ID}) require.NoError(t, err) - err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + err = s.ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "team_id", fmt.Sprintf("%d", team.ID)) diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 21a1599cb459..9000b7e29c35 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -4486,7 +4486,7 @@ func (s *integrationEnterpriseTestSuite) TestListVulnerabilities() { }) require.NoError(t, err) - err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + err = s.ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) @@ -4550,7 +4550,7 @@ func (s *integrationEnterpriseTestSuite) TestListVulnerabilities() { err = s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{host.ID}) require.NoError(t, err) - err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + err = s.ds.UpdateVulnerabilityHostCounts(context.Background(), 5) require.NoError(t, err) s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "team_id", fmt.Sprintf("%d", team.ID))