Skip to content

Commit

Permalink
Allow HTTP clients to be cancelled by passing along context (#565)
Browse files Browse the repository at this point in the history
This replaces all calls of "http.NewRequest" with
"http.NewRequestWithContext", to ensure that the HTTP client stops what
it's doing immediately when the context gets canceled.

The main motivation here is to better handle network connectivity problems
to the pganalyze snapshot API, which otherwise make it seem like the
collector hangs when trying to cancel a test run with CTRL+C (SIGINT).
  • Loading branch information
lfittl authored Jul 12, 2024
1 parent 6403715 commit 44bb410
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 46 deletions.
5 changes: 3 additions & 2 deletions grant/default.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package grant

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand All @@ -10,8 +11,8 @@ import (
"github.com/pganalyze/collector/util"
)

func GetDefaultGrant(server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.Grant, error) {
req, err := http.NewRequest("GET", server.Config.APIBaseURL+"/v2/snapshots/grant", nil)
func GetDefaultGrant(ctx context.Context, server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.Grant, error) {
req, err := http.NewRequestWithContext(ctx, "GET", server.Config.APIBaseURL+"/v2/snapshots/grant", nil)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectApiConnection, err.Error())
return state.Grant{}, err
Expand Down
5 changes: 3 additions & 2 deletions grant/logs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package grant

import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
Expand All @@ -10,8 +11,8 @@ import (
"github.com/pganalyze/collector/util"
)

func GetLogsGrant(server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.GrantLogs, error) {
req, err := http.NewRequest("GET", server.Config.APIBaseURL+"/v2/snapshots/grant_logs", nil)
func GetLogsGrant(ctx context.Context, server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.GrantLogs, error) {
req, err := http.NewRequestWithContext(ctx, "GET", server.Config.APIBaseURL+"/v2/snapshots/grant_logs", nil)
if err != nil {
return state.GrantLogs{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion input/full.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func CollectFull(ctx context.Context, server *state.Server, connection *sql.DB,
}

if globalCollectionOpts.CollectSystemInformation {
ps.System = system.GetSystemState(server, logger, globalCollectionOpts)
ps.System = system.GetSystemState(ctx, server, logger, globalCollectionOpts)
}

server.SetLogTimezone(ts.Settings)
Expand Down
33 changes: 17 additions & 16 deletions input/system/crunchy_bridge/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package crunchy_bridge

import (
"context"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -74,8 +75,8 @@ type DiskUsageMetrics struct {
WalSize uint64
}

func (c *Client) NewRequest(method string, path string) (*http.Request, error) {
req, err := http.NewRequest(method, c.BaseURL+path, nil)
func (c *Client) NewRequest(ctx context.Context, method string, path string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, c.BaseURL+path, nil)
if err != nil {
return nil, err
}
Expand All @@ -86,8 +87,8 @@ func (c *Client) NewRequest(method string, path string) (*http.Request, error) {
return req, nil
}

func (c *Client) GetClusterInfo() (*ClusterInfo, error) {
req, err := c.NewRequest("GET", "/clusters/"+c.ClusterID)
func (c *Client) GetClusterInfo(ctx context.Context) (*ClusterInfo, error) {
req, err := c.NewRequest(ctx, "GET", "/clusters/"+c.ClusterID)
if err != nil {
return nil, err
}
Expand All @@ -114,8 +115,8 @@ func (c *Client) GetClusterInfo() (*ClusterInfo, error) {
return &clusterInfo, err
}

func (c *Client) getMetrics(name string) (*MetricViews, error) {
req, err := c.NewRequest("GET", fmt.Sprintf("/metric-views/%s?cluster_id=%s&period=15m", name, c.ClusterID))
func (c *Client) getMetrics(ctx context.Context, name string) (*MetricViews, error) {
req, err := c.NewRequest(ctx, "GET", fmt.Sprintf("/metric-views/%s?cluster_id=%s&period=15m", name, c.ClusterID))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -143,8 +144,8 @@ func (c *Client) getMetrics(name string) (*MetricViews, error) {
return &metricViews, nil
}

func (c *Client) GetCPUMetrics() (*CPUMetrics, error) {
metricViews, err := c.getMetrics("cpu")
func (c *Client) GetCPUMetrics(ctx context.Context) (*CPUMetrics, error) {
metricViews, err := c.getMetrics(ctx, "cpu")
if err != nil {
return nil, err
}
Expand All @@ -166,8 +167,8 @@ func (c *Client) GetCPUMetrics() (*CPUMetrics, error) {
return &metrics, err
}

func (c *Client) GetMemoryMetrics() (*MemoryMetrics, error) {
metricViews, err := c.getMetrics("memory")
func (c *Client) GetMemoryMetrics(ctx context.Context) (*MemoryMetrics, error) {
metricViews, err := c.getMetrics(ctx, "memory")
if err != nil {
return nil, err
}
Expand All @@ -185,8 +186,8 @@ func (c *Client) GetMemoryMetrics() (*MemoryMetrics, error) {
return &metrics, err
}

func (c *Client) GetIOPSMetrics() (*IOPSMetrics, error) {
metricViews, err := c.getMetrics("iops")
func (c *Client) GetIOPSMetrics(ctx context.Context) (*IOPSMetrics, error) {
metricViews, err := c.getMetrics(ctx, "iops")
if err != nil {
return nil, err
}
Expand All @@ -204,8 +205,8 @@ func (c *Client) GetIOPSMetrics() (*IOPSMetrics, error) {
return &metrics, err
}

func (c *Client) GetLoadAverageMetrics() (*LoadAverageMetrics, error) {
metricViews, err := c.getMetrics("load-average")
func (c *Client) GetLoadAverageMetrics(ctx context.Context) (*LoadAverageMetrics, error) {
metricViews, err := c.getMetrics(ctx, "load-average")
if err != nil {
return nil, err
}
Expand All @@ -221,8 +222,8 @@ func (c *Client) GetLoadAverageMetrics() (*LoadAverageMetrics, error) {
return &metrics, err
}

func (c *Client) GetDiskUsageMetrics() (*DiskUsageMetrics, error) {
metricViews, err := c.getMetrics("disk-usage")
func (c *Client) GetDiskUsageMetrics(ctx context.Context) (*DiskUsageMetrics, error) {
metricViews, err := c.getMetrics(ctx, "disk-usage")
if err != nil {
return nil, err
}
Expand Down
7 changes: 4 additions & 3 deletions input/system/crunchy_bridge/system.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package crunchy_bridge

import (
"context"
"time"

"github.com/pganalyze/collector/input/system/selfhosted"
Expand All @@ -9,7 +10,7 @@ import (
)

// GetSystemState - Gets system information about a Crunchy Bridge instance
func GetSystemState(server *state.Server, logger *util.Logger) (system state.SystemState) {
func GetSystemState(ctx context.Context, server *state.Server, logger *util.Logger) (system state.SystemState) {
config := server.Config
// With Crunchy Bridge, we are assuming that the collector is deployed on Container Apps,
// which run directly on the database server. Most of the metrics can be obtained
Expand All @@ -25,7 +26,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys
}
client := Client{Client: *config.HTTPClientWithRetry, BaseURL: apiBaseURL, BearerToken: config.CrunchyBridgeAPIKey, ClusterID: config.CrunchyBridgeClusterID}

clusterInfo, err := client.GetClusterInfo()
clusterInfo, err := client.GetClusterInfo(ctx)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting cluster info: %s", err)
logger.PrintError("CrunchyBridge/System: Encountered error when getting cluster info %v\n", err)
Expand All @@ -44,7 +45,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys
system.Info.CrunchyBridge.CreatedAt = parsedCreatedAt
}

diskUsageMetrics, err := client.GetDiskUsageMetrics()
diskUsageMetrics, err := client.GetDiskUsageMetrics(ctx)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting cluster disk usage metrics: %s", err)
logger.PrintError("CrunchyBridge/System: Encountered error when getting cluster disk usage metrics %v\n", err)
Expand Down
2 changes: 1 addition & 1 deletion input/system/heroku/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func processSystemMetrics(ctx context.Context, timestamp time.Time, content []by

prefixedLogger := logger.WithPrefix(server.Config.SectionName)

grant, err := grant.GetDefaultGrant(server, globalCollectionOpts, prefixedLogger)
grant, err := grant.GetDefaultGrant(ctx, server, globalCollectionOpts, prefixedLogger)
if err != nil {
prefixedLogger.PrintError("Could not get default grant for system snapshot: %s", err)
return
Expand Down
6 changes: 3 additions & 3 deletions input/system/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func DownloadLogFiles(ctx context.Context, server *state.Server, globalCollectio
}

// GetSystemState - Retrieves a system snapshot for this system and returns it
func GetSystemState(server *state.Server, logger *util.Logger, globalCollectionOpts state.CollectionOpts) (system state.SystemState) {
func GetSystemState(ctx context.Context, server *state.Server, logger *util.Logger, globalCollectionOpts state.CollectionOpts) (system state.SystemState) {
config := server.Config
dbHost := config.GetDbHost()
if config.SystemType == "amazon_rds" {
Expand All @@ -48,12 +48,12 @@ func GetSystemState(server *state.Server, logger *util.Logger, globalCollectionO
system.Info.Type = state.HerokuSystem
server.SelfTest.MarkCollectionAspectNotAvailable(state.CollectionAspectSystemStats, "not available on this platform")
} else if config.SystemType == "crunchy_bridge" {
system = crunchy_bridge.GetSystemState(server, logger)
system = crunchy_bridge.GetSystemState(ctx, server, logger)
} else if config.SystemType == "aiven" {
system.Info.Type = state.AivenSystem
server.SelfTest.MarkCollectionAspectNotAvailable(state.CollectionAspectSystemStats, "not available on this platform")
} else if config.SystemType == "tembo" {
system = tembo.GetSystemState(server, logger)
system = tembo.GetSystemState(ctx, server, logger)
} else if dbHost == "" || dbHost == "localhost" || dbHost == "127.0.0.1" || config.AlwaysCollectSystemData {
system = selfhosted.GetSystemState(server, logger)
} else {
Expand Down
25 changes: 13 additions & 12 deletions input/system/tembo/system.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tembo

import (
"context"
"encoding/json"
"net/http"
"net/url"
Expand Down Expand Up @@ -38,7 +39,7 @@ type Metric struct {
}

// GetSystemState - Gets system information for a Tembo Cloud instance
func GetSystemState(server *state.Server, logger *util.Logger) (system state.SystemState) {
func GetSystemState(ctx context.Context, server *state.Server, logger *util.Logger) (system state.SystemState) {
system.Info.Type = state.TemboSystem
config := server.Config
headers := map[string]string{
Expand All @@ -51,7 +52,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys

// Get CPU usage percentage
query := "sum(node_namespace_pod_container:container_cpu_usage_seconds_total:sum_irate{ namespace=\"" + config.TemboNamespace + "\"}) / sum(kube_pod_container_resource_requests{job=\"kube-state-metrics\", namespace=\"" + config.TemboNamespace + "\", resource=\"cpu\"})"
cpuUsage, err := getFloat64(query, metricsUrl, client, headers)
cpuUsage, err := getFloat64(ctx, query, metricsUrl, client, headers)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting CPU info: %v", err)
logger.PrintError("Tembo/System: Encountered error when getting CPU info %v\n", err)
Expand All @@ -68,7 +69,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys

// Get total memory
query = "sum(max by(pod) (kube_pod_container_resource_requests{job=\"kube-state-metrics\", namespace=\"" + config.TemboNamespace + "\", resource=\"memory\"}))"
memoryTotalBytes, err := getUint64(query, metricsUrl, client, headers)
memoryTotalBytes, err := getUint64(ctx, query, metricsUrl, client, headers)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting memory info: %v", err)
logger.PrintError("Tembo/System: Encountered error when getting memory info %v\n", err)
Expand All @@ -77,7 +78,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys

// Get available memory
query = "sum(max by(pod) (kube_pod_container_resource_requests{job=\"kube-state-metrics\", namespace=\"" + config.TemboNamespace + "\", resource=\"memory\"})) - sum(container_memory_working_set_bytes{job=\"kubelet\", metrics_path=\"/metrics/cadvisor\", namespace=\"" + config.TemboNamespace + "\",container!=\"\", image!=\"\"})"
memoryAvailableBytes, err := getUint64(query, metricsUrl, client, headers)
memoryAvailableBytes, err := getUint64(ctx, query, metricsUrl, client, headers)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting memory info: %v", err)
logger.PrintError("Tembo/System: Encountered error when getting memory info %v\n", err)
Expand All @@ -91,7 +92,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys
// Get disk capacity
// Note this does not yet handle multiple volume claims in cases like HA
query = "kubelet_volume_stats_capacity_bytes{namespace=\"" + config.TemboNamespace + "\", persistentvolumeclaim=~\"" + config.TemboNamespace + "-1" + "\"}"
diskCapacity, err := getUint64(query, metricsUrl, client, headers)
diskCapacity, err := getUint64(ctx, query, metricsUrl, client, headers)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting disk info: %v", err)
logger.PrintError("Tembo/System: Encountered error when getting disk info %v\n", err)
Expand All @@ -101,7 +102,7 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys
// Get disk available
// Note this does not yet handle multiple volume claims in cases like HA
query = "kubelet_volume_stats_available_bytes{namespace=\"" + config.TemboNamespace + "\", persistentvolumeclaim=~\"" + config.TemboNamespace + "-1" + "\"}"
diskAvailable, err := getUint64(query, metricsUrl, client, headers)
diskAvailable, err := getUint64(ctx, query, metricsUrl, client, headers)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectSystemStats, "error getting disk info: %v", err)
logger.PrintError("Tembo/System: Encountered error when getting disk info %v\n", err)
Expand All @@ -121,8 +122,8 @@ func GetSystemState(server *state.Server, logger *util.Logger) (system state.Sys
return
}

func getFloat64(query string, metricsUrl string, client http.Client, headers map[string]string) (float64, error) {
res, err := getSystemInfo(metricsUrl, query, client, headers)
func getFloat64(ctx context.Context, query string, metricsUrl string, client http.Client, headers map[string]string) (float64, error) {
res, err := getSystemInfo(ctx, metricsUrl, query, client, headers)
if err != nil {
return 0, err
}
Expand All @@ -142,8 +143,8 @@ func getFloat64(query string, metricsUrl string, client http.Client, headers map
return value, nil
}

func getUint64(query string, metricsUrl string, client http.Client, headers map[string]string) (uint64, error) {
res, err := getSystemInfo(metricsUrl, query, client, headers)
func getUint64(ctx context.Context, query string, metricsUrl string, client http.Client, headers map[string]string) (uint64, error) {
res, err := getSystemInfo(ctx, metricsUrl, query, client, headers)
if err != nil {
return 0, err
}
Expand All @@ -163,12 +164,12 @@ func getUint64(query string, metricsUrl string, client http.Client, headers map[
return value, nil
}

func getSystemInfo(metricsUrl string, query string, client http.Client, headers map[string]string) (Response, error) {
func getSystemInfo(ctx context.Context, metricsUrl string, query string, client http.Client, headers map[string]string) (Response, error) {
encodedQuery := url.QueryEscape(query)

metricsUrl = metricsUrl + encodedQuery

req, err := http.NewRequest("GET", metricsUrl, nil)
req, err := http.NewRequestWithContext(ctx, "GET", metricsUrl, nil)
if err != nil {
return Response{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion runner/activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func processActivityForServer(ctx context.Context, server *state.Server, globalC
}

if !globalCollectionOpts.ForceEmptyGrant {
newGrant, err = grant.GetDefaultGrant(server, globalCollectionOpts, logger)
newGrant, err = grant.GetDefaultGrant(ctx, server, globalCollectionOpts, logger)
if err != nil {
return newState, false, errors.Wrap(err, "could not get default grant for activity snapshot")
}
Expand Down
2 changes: 1 addition & 1 deletion runner/full.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func processServer(ctx context.Context, server *state.Server, globalCollectionOp

if !globalCollectionOpts.ForceEmptyGrant {
// Note: In case of server errors, we should reuse the old grant if its still recent (i.e. less than 50 minutes ago)
newGrant, err = grant.GetDefaultGrant(server, globalCollectionOpts, logger)
newGrant, err = grant.GetDefaultGrant(ctx, server, globalCollectionOpts, logger)
if err != nil {
if server.Grant.Valid {
logger.PrintVerbose("Could not acquire snapshot grant, reusing previous grant: %s", err)
Expand Down
8 changes: 4 additions & 4 deletions runner/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func downloadLogsForServerWithLocksAndCallbacks(ctx context.Context, wg *sync.Wa
}

func downloadLogsForServer(ctx context.Context, server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (state.PersistedLogState, bool, error) {
grant, err := getLogsGrant(server, globalCollectionOpts, logger)
grant, err := getLogsGrant(ctx, server, globalCollectionOpts, logger)
if err != nil || !grant.Valid {
return server.LogPrevState, false, err
}
Expand Down Expand Up @@ -254,7 +254,7 @@ func processLogStream(ctx context.Context, server *state.Server, logLines []stat
return tooFreshLogLines
}

grant, err := getLogsGrant(server, globalCollectionOpts, logger)
grant, err := getLogsGrant(ctx, server, globalCollectionOpts, logger)
if err != nil {
// Note we intentionally discard log lines here (and in the other
// error case below), because the HTTP client already retries to work
Expand All @@ -276,8 +276,8 @@ func processLogStream(ctx context.Context, server *state.Server, logLines []stat
return tooFreshLogLines
}

func getLogsGrant(server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (logGrant state.GrantLogs, err error) {
logGrant, err = grant.GetLogsGrant(server, globalCollectionOpts, logger)
func getLogsGrant(ctx context.Context, server *state.Server, globalCollectionOpts state.CollectionOpts, logger *util.Logger) (logGrant state.GrantLogs, err error) {
logGrant, err = grant.GetLogsGrant(ctx, server, globalCollectionOpts, logger)
if err != nil {
server.SelfTest.MarkCollectionAspectError(state.CollectionAspectLogs, "error getting log grant: %s", err)
return state.GrantLogs{Valid: false}, errors.Wrap(err, "could not get log grant")
Expand Down

0 comments on commit 44bb410

Please sign in to comment.