Skip to content

Commit

Permalink
Optimize vulnerability host counts (#24914)
Browse files Browse the repository at this point in the history
  • Loading branch information
mostlikelee authored Jan 13, 2025
1 parent d15d2e3 commit 80f503a
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 86 deletions.
1 change: 1 addition & 0 deletions changes/22364-vuln-cron
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* Fixed an issue where the vulnerabilities cron failed in large environments due to large SQL queries.
4 changes: 3 additions & 1 deletion cmd/fleet/cron.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ func cronVulnerabilities(
return fmt.Errorf("scanning vulnerabilities: %w", err)
}

start := time.Now()
level.Info(logger).Log("msg", "updating vulnerability host counts")
if err := ds.UpdateVulnerabilityHostCounts(ctx); err != nil {
if err := ds.UpdateVulnerabilityHostCounts(ctx, config.MaxConcurrency); err != nil {
return fmt.Errorf("updating vulnerability host counts: %w", err)
}
level.Info(logger).Log("msg", "vulnerability host counts updated", "took", time.Since(start).Seconds())
}

return nil
Expand Down
7 changes: 7 additions & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@ type VulnerabilitiesConfig struct {
DisableDataSync bool `json:"disable_data_sync" yaml:"disable_data_sync"`
RecentVulnerabilityMaxAge time.Duration `json:"recent_vulnerability_max_age" yaml:"recent_vulnerability_max_age"`
DisableWinOSVulnerabilities bool `json:"disable_win_os_vulnerabilities" yaml:"disable_win_os_vulnerabilities"`
MaxConcurrency int `json:"max_concurrency" yaml:"max_concurrency"`
}

// UpgradesConfig defines configs related to fleet server upgrades.
Expand Down Expand Up @@ -1309,6 +1310,11 @@ func (man Manager) addConfigs() {
false,
"Don't sync installed Windows updates nor perform Windows OS vulnerability processing.",
)
man.addConfigInt(
"vulnerabilities.max_concurrency",
5,
"Maximum number of concurrent database queries to use for processing vulnerabilities.",
)

// Upgrades
man.addConfigBool("upgrades.allow_missing_migrations", false,
Expand Down Expand Up @@ -1580,6 +1586,7 @@ func (man Manager) LoadConfig() FleetConfig {
DisableDataSync: man.getConfigBool("vulnerabilities.disable_data_sync"),
RecentVulnerabilityMaxAge: man.getConfigDuration("vulnerabilities.recent_vulnerability_max_age"),
DisableWinOSVulnerabilities: man.getConfigBool("vulnerabilities.disable_win_os_vulnerabilities"),
MaxConcurrency: man.getConfigInt("vulnerabilities.max_concurrency"),
},
Upgrades: UpgradesConfig{
AllowMissingMigrations: man.getConfigBool("upgrades.allow_missing_migrations"),
Expand Down
251 changes: 185 additions & 66 deletions server/datastore/mysql/vulnerabilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
"time"

"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
Expand Down Expand Up @@ -342,30 +343,193 @@ func (ds *Datastore) CountVulnerabilities(ctx context.Context, opt fleet.VulnLis
return count, nil
}

func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error {
// distinctCVEs returns every CVE identifier that appears in either the
// software_cve table or the operating_system_vulnerabilities table, with each
// identifier listed once. The query is issued against the reader connection.
func (ds *Datastore) distinctCVEs(ctx context.Context) ([]string, error) {
	const stmt = `
SELECT DISTINCT cve FROM (
SELECT cve FROM software_cve
UNION
SELECT cve FROM operating_system_vulnerabilities
) AS combined_cves;
`

	var result []string
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &result, stmt); err != nil {
		return nil, err
	}
	return result, nil
}

// CountScope selects which set of hosts a vulnerability host count query
// aggregates over.
type CountScope int

const (
	// GlobalCount counts hosts across the whole deployment (rows are emitted
	// with global_stats = 1).
	GlobalCount CountScope = iota
	// NoTeamCount counts only hosts that are not assigned to a team
	// (hosts.team_id IS NULL).
	NoTeamCount
	// TeamCount counts hosts per team, for hosts that are assigned to a team
	// (hosts.team_id IS NOT NULL).
	TeamCount
)

// batchFetchVulnerabilityCounts computes host counts per CVE for the given
// scope. It loads the list of distinct CVEs, splits it into fixed-size batches
// and runs the scope's count query once per batch, keeping at most maxRoutines
// queries in flight at the same time.
//
// maxRoutines values below 1 are treated as 1. The order of the returned
// counts is not deterministic because batches complete concurrently.
//
// It returns an error if the scope is unknown, if listing the CVEs fails, if
// the context is cancelled, or if any batch query fails. On the first batch
// failure the remaining work is cancelled and that first error is returned.
func (ds *Datastore) batchFetchVulnerabilityCounts(
	ctx context.Context,
	scope CountScope,
	maxRoutines int,
) ([]hostCount, error) {
	const batchSize = 20

	// Validate the scope up front so an invalid scope does not cost a query.
	query := getVulnHostCountQuery(scope)
	if query == "" {
		return nil, ctxerr.Errorf(ctx, "invalid scope: %d", scope)
	}

	// A zero-capacity semaphore would block the dispatch loop forever (nothing
	// is receiving yet when the first slot is acquired) and a negative capacity
	// makes make() panic, so always allow at least one worker.
	if maxRoutines < 1 {
		maxRoutines = 1
	}

	allCVEs, err := ds.distinctCVEs(ctx)
	if err != nil {
		return nil, err
	}

	// Derive a cancellable context so that the first failing batch stops the
	// queries that are still running or waiting to be dispatched.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		hostCounts []hostCount
		firstErr   error
		mu         sync.Mutex // guards hostCounts and firstErr
		wg         sync.WaitGroup
		sem        = make(chan struct{}, maxRoutines)
	)

dispatch:
	for start := 0; start < len(allCVEs); start += batchSize {
		end := start + batchSize
		if end > len(allCVEs) {
			end = len(allCVEs)
		}

		// Acquire a worker slot, but give up as soon as the context is done so
		// no further batches are scheduled after a failure or cancellation.
		select {
		case sem <- struct{}{}:
		case <-ctx.Done():
			break dispatch
		}

		wg.Add(1)
		go func(cves []string) {
			defer wg.Done()
			defer func() { <-sem }() // release the worker slot

			counts, err := ds.fetchBatchCounts(ctx, cves, query)

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				// Keep only the first error; later ones are usually just the
				// cancellation triggered here.
				if firstErr == nil {
					firstErr = err
					cancel()
				}
				return
			}
			hostCounts = append(hostCounts, counts...)
		}(allCVEs[start:end])
	}

	wg.Wait()

	if firstErr != nil {
		return nil, firstErr
	}
	// The dispatch loop may have stopped early because the caller's context was
	// cancelled before any worker observed it; do not return partial results.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	return hostCounts, nil
}

// fetchBatchCounts runs a host count query for a single batch of CVEs.
//
// scopeConfig is the SQL text of the count query. batchCVEs is bound twice via
// sqlx.In, so the query is expected to contain two `IN (?)` placeholders. The
// query is issued against the reader connection.
//
// An empty batch returns no counts and does not hit the database.
func (ds *Datastore) fetchBatchCounts(
	ctx context.Context,
	batchCVEs []string,
	scopeConfig string,
) ([]hostCount, error) {
	// sqlx.In returns an error for empty slices; there is nothing to count for
	// an empty batch, so short-circuit instead of surfacing that error.
	if len(batchCVEs) == 0 {
		return nil, nil
	}

	query, args, err := sqlx.In(scopeConfig, batchCVEs, batchCVEs)
	if err != nil {
		return nil, err
	}

	var counts []hostCount
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &counts, query, args...); err != nil {
		return nil, err
	}

	return counts, nil
}

// getVulnHostCountQuery returns the SQL used to count hosts per CVE for the
// given scope, or an empty string when the scope is not recognized.
//
// Every returned query contains two `IN (?)` placeholders (one for software
// CVEs, one for operating system CVEs) and selects the columns team_id,
// global_stats, cve and host_count.
func getVulnHostCountQuery(scope CountScope) string {
	// Derived table shared by all scopes: the (cve, host_id) pairs coming from
	// software and from operating systems, restricted to the requested CVEs.
	const combinedResults = `FROM (
SELECT sc.cve, hs.host_id
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
WHERE sc.cve IN (?)
UNION
SELECT osv.cve, hos.host_id
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
WHERE osv.cve IN (?)
) AS combined_results
`
	// Join used by the scopes that filter on the host's team.
	const joinHosts = `INNER JOIN hosts h ON combined_results.host_id = h.id
`

	var query string
	switch scope {
	case GlobalCount:
		query = `
SELECT 0 as team_id, 1 as global_stats, combined_results.cve, COUNT(*) AS host_count
` + combinedResults + `GROUP BY cve
`
	case NoTeamCount:
		query = `
SELECT 0 as team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count
` + combinedResults + joinHosts + `WHERE h.team_id IS NULL
GROUP BY cve
`
	case TeamCount:
		query = `
SELECT h.team_id as team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count
` + combinedResults + joinHosts + `WHERE h.team_id IS NOT NULL
GROUP BY h.team_id, combined_results.cve
`
	}
	return query
}

func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context, maxRoutines int) error {
// set all counts to 0 to later identify rows to delete
_, err := ds.writer(ctx).ExecContext(ctx, "UPDATE vulnerability_host_counts SET host_count = 0")
if err != nil {
return ctxerr.Wrap(ctx, err, "initializing vulnerability host counts")
}

globalSelectStmt := `
SELECT 0 as team_id, 1 as global_stats, cve, COUNT(*) AS host_count
FROM (
SELECT sc.cve, hs.host_id
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
UNION
SELECT osv.cve, hos.host_id
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
) AS combined_results
GROUP BY cve;
`

globalHostCounts, err := ds.fetchHostCounts(ctx, globalSelectStmt)
globalHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, GlobalCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching global vulnerability host counts")
}
Expand All @@ -375,25 +539,7 @@ func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error {
return ctxerr.Wrap(ctx, err, "inserting global vulnerability host counts")
}

teamSelectStmt := `
SELECT h.team_id, 0 as global_stats, combined_results.cve, COUNT(*) AS host_count
FROM (
SELECT hs.host_id, sc.cve
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
UNION
SELECT hos.host_id, osv.cve
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
) AS combined_results
INNER JOIN hosts h ON combined_results.host_id = h.id
WHERE h.team_id IS NOT NULL
GROUP BY h.team_id, combined_results.cve
`

teamHostCounts, err := ds.fetchHostCounts(ctx, teamSelectStmt)
teamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, TeamCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts")
}
Expand All @@ -403,27 +549,9 @@ func (ds *Datastore) UpdateVulnerabilityHostCounts(ctx context.Context) error {
return ctxerr.Wrap(ctx, err, "inserting team vulnerability host counts")
}

noTeamSelectStmt := `
SELECT 0 as team_id, 0 as global_stats, cve, COUNT(*) AS host_count
FROM (
SELECT hs.host_id, sc.cve
FROM software_cve sc
INNER JOIN host_software hs ON sc.software_id = hs.software_id
UNION
SELECT hos.host_id, osv.cve
FROM operating_system_vulnerabilities osv
INNER JOIN host_operating_system hos ON hos.os_id = osv.operating_system_id
) AS combined_results
INNER JOIN hosts h ON combined_results.host_id = h.id
WHERE h.team_id IS NULL
GROUP BY cve
`

noTeamHostCounts, err := ds.fetchHostCounts(ctx, noTeamSelectStmt)
noTeamHostCounts, err := ds.batchFetchVulnerabilityCounts(ctx, NoTeamCount, maxRoutines)
if err != nil {
return ctxerr.Wrap(ctx, err, "fetching team vulnerability host counts")
return ctxerr.Wrap(ctx, err, "fetching no team vulnerability host counts")
}

err = ds.batchInsertHostCounts(ctx, noTeamHostCounts)
Expand Down Expand Up @@ -455,15 +583,6 @@ func (ds *Datastore) cleanupVulnerabilityHostCounts(ctx context.Context) error {
return nil
}

// fetchHostCounts runs the given query against the reader connection and scans
// the resulting rows into a slice of hostCount.
func (ds *Datastore) fetchHostCounts(ctx context.Context, query string) ([]hostCount, error) {
	var rows []hostCount
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &rows, query); err != nil {
		return nil, err
	}
	return rows, nil
}

func (ds *Datastore) batchInsertHostCounts(ctx context.Context, counts []hostCount) error {
if len(counts) == 0 {
return nil
Expand Down
Loading

0 comments on commit 80f503a

Please sign in to comment.