Skip to content

Commit

Permalink
Convert lib/reversetunnel to use slog
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Dec 17, 2024
1 parent 7384627 commit 1708df7
Show file tree
Hide file tree
Showing 15 changed files with 356 additions and 311 deletions.
67 changes: 36 additions & 31 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/api/constants"
Expand All @@ -42,6 +42,7 @@ import (
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnel/track"
"github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

type AgentState string
Expand Down Expand Up @@ -113,8 +114,8 @@ type agentConfig struct {
// clock is use to get the current time. Mock clocks can be used for
// testing.
clock clockwork.Clock
// log is an optional logger.
log logrus.FieldLogger
// logger is an optional logger.
logger *slog.Logger
// localAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
localAuthAddresses []string
Expand Down Expand Up @@ -145,12 +146,13 @@ func (c *agentConfig) checkAndSetDefaults() error {
if c.clock == nil {
c.clock = clockwork.NewRealClock()
}
if c.log == nil {
c.log = logrus.New()
if c.logger == nil {
c.logger = slog.Default()
}
c.log = c.log.
WithField("leaseID", c.lease.ID()).
WithField("target", c.addr.String())
c.logger = c.logger.With(
"lease_id", c.lease.ID(),
"target", c.addr.String(),
)

return nil
}
Expand Down Expand Up @@ -284,7 +286,10 @@ func (a *agent) updateState(state AgentState) (AgentState, error) {

prevState := a.state
a.state = state
a.log.Debugf("Changing state %s -> %s.", prevState, state)
a.logger.DebugContext(a.ctx, "Agent state updated",
"previous_state", prevState,
"current_state", state,
)

if a.agentConfig.stateCallback != nil {
go a.agentConfig.stateCallback(a.state)
Expand All @@ -296,7 +301,7 @@ func (a *agent) updateState(state AgentState) (AgentState, error) {
// Start starts an agent returning after successfully connecting and sending
// the first heartbeat.
func (a *agent) Start(ctx context.Context) error {
a.log.Debugf("Starting agent %v", a.addr)
a.logger.DebugContext(ctx, "Starting agent", "addr", a.addr.FullAddress())

var err error
defer func() {
Expand Down Expand Up @@ -325,7 +330,7 @@ func (a *agent) Start(ctx context.Context) error {
a.wg.Add(1)
go func() {
if err := a.handleGlobalRequests(a.ctx, a.client.GlobalRequests()); err != nil {
a.log.WithError(err).Debug("Failed to handle global requests.")
a.logger.DebugContext(a.ctx, "Failed to handle global requests", "error", err)
}
a.wg.Done()
a.Stop()
Expand All @@ -336,7 +341,7 @@ func (a *agent) Start(ctx context.Context) error {
a.wg.Add(1)
go func() {
if err := a.handleDrainChannels(); err != nil {
a.log.WithError(err).Debug("Failed to handle drainable channels.")
a.logger.DebugContext(a.ctx, "Failed to handle drainable channels", "error", err)
}
a.wg.Done()
a.Stop()
Expand All @@ -345,7 +350,7 @@ func (a *agent) Start(ctx context.Context) error {
a.wg.Add(1)
go func() {
if err := a.handleChannels(); err != nil {
a.log.WithError(err).Debug("Failed to handle channels.")
a.logger.DebugContext(a.ctx, "Failed to handle channels", "error", err)
}
a.wg.Done()
a.Stop()
Expand Down Expand Up @@ -460,23 +465,23 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
case versionRequest:
version, err := a.versionGetter.getVersion(ctx)
if err != nil {
a.log.WithError(err).Warnf("Failed to retrieve auth version in response to %v request.", r.Type)
a.logger.WarnContext(ctx, "Failed to retrieve auth version in response to x-teleport-version request", "error", err)
if err := a.client.Reply(r, false, []byte("Failed to retrieve auth version")); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to x-teleport-version request", "error", err)
continue
}
}

if err := a.client.Reply(r, true, []byte(version)); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to x-teleport-version request", "error", err)
continue
}
case reconnectRequest:
a.log.Debugf("Received reconnect advisory request from proxy.")
a.logger.DebugContext(ctx, "Received reconnect advisory request from proxy")
if r.WantReply {
err := a.client.Reply(r, true, nil)
if err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to [email protected] request", "error", err)
}
}

Expand All @@ -487,7 +492,7 @@ func (a *agent) handleGlobalRequests(ctx context.Context, requests <-chan *ssh.R
// This handles keep-alive messages and matches the behavior of OpenSSH.
err := a.client.Reply(r, false, nil)
if err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.logger.DebugContext(ctx, "Failed to reply to global request", "request_type", r.Type, "error", err)
continue
}
}
Expand Down Expand Up @@ -555,10 +560,10 @@ func (a *agent) handleDrainChannels() error {
bytes, _ := a.clock.Now().UTC().MarshalText()
_, err := a.hbChannel.SendRequest(a.ctx, "ping", false, bytes)
if err != nil {
a.log.Error(err)
a.logger.ErrorContext(a.ctx, "failed to sing ping request", "error", err)
return trace.Wrap(err)
}
a.log.Debugf("Ping -> %v.", a.client.RemoteAddr())
a.logger.DebugContext(a.ctx, "Sent ping request", "target_addr", logutils.StringerAttr(a.client.RemoteAddr()))
// Handle transport requests.
case nch := <-a.transportC:
if nch == nil {
Expand All @@ -567,15 +572,15 @@ func (a *agent) handleDrainChannels() error {
if a.isDraining() {
err := nch.Reject(ssh.ConnectionFailed, "agent connection is draining")
if err != nil {
a.log.WithError(err).Warningf("Failed to reject transport channel.")
a.logger.WarnContext(a.ctx, "Failed to reject transport channel", "error", err)
}
continue
}

a.log.Debugf("Transport request: %v.", nch.ChannelType())
a.logger.DebugContext(a.ctx, "Received trransport request", "channel_type", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Warningf("Failed to accept transport request: %v.", err)
a.logger.WarnContext(a.ctx, "Failed to accept transport request", "error", err)
continue
}

Expand All @@ -601,10 +606,10 @@ func (a *agent) handleChannels() error {
if nch == nil {
continue
}
a.log.Debugf("Discovery request channel opened: %v.", nch.ChannelType())
a.logger.DebugContext(a.ctx, "Discovery request channel opened", "channel_type", nch.ChannelType())
ch, req, err := nch.Accept()
if err != nil {
a.log.Warningf("Failed to accept discovery channel request: %v.", err)
a.logger.WarnContext(a.ctx, "Failed to accept discovery channel request", "error", err)
continue
}

Expand All @@ -624,11 +629,11 @@ func (a *agent) handleChannels() error {
// ch : SSH channel which received "teleport-transport" out-of-band request
// reqC : request payload
func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
a.log.Debugf("handleDiscovery requests channel.")
a.logger.DebugContext(a.ctx, "handleDiscovery requests channel")
sshutils.DiscardChannelData(ch)
defer func() {
if err := ch.Close(); err != nil {
a.log.Warnf("Failed to close discovery channel: %v", err)
a.logger.WarnContext(a.ctx, "Failed to close discovery channel", "error", err)
}
}()

Expand All @@ -639,17 +644,17 @@ func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
return
case req = <-reqC:
if req == nil {
a.log.Infof("Connection closed, returning")
a.logger.InfoContext(a.ctx, "Connection closed, returning")
return
}

var r discoveryRequest
if err := json.Unmarshal(req.Payload, &r); err != nil {
a.log.WithError(err).Warn("Bad payload")
a.logger.WarnContext(a.ctx, "Received discovery request with bad payload", "error", err)
return
}

a.log.Debugf("Received discovery request: %s", &r)
a.logger.DebugContext(a.ctx, "Received discovery request", "discovery_request", logutils.StringerAttr(&r))
a.tracker.TrackExpected(r.TrackProxies()...)
}
}
Expand Down
10 changes: 5 additions & 5 deletions lib/reversetunnel/agent_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ package reversetunnel

import (
"context"
"log/slog"
"strings"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

apidefaults "github.com/gravitational/teleport/api/defaults"
Expand Down Expand Up @@ -55,7 +55,7 @@ type agentDialer struct {
authMethods []ssh.AuthMethod
fips bool
options []proxy.DialerOptionFunc
log logrus.FieldLogger
logger *slog.Logger
isClaimed func(principals ...string) bool
}

Expand All @@ -65,7 +65,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
dialer := proxy.DialerFromEnvironment(addr.Addr, d.options...)
pconn, err := dialer.DialTimeout(ctx, addr.AddrNetwork, addr.Addr, apidefaults.DefaultIOTimeout)
if err != nil {
d.log.WithError(err).Debugf("Failed to dial %s.", addr.Addr)
d.logger.DebugContext(ctx, "Failed to dial", "error", err, "target_addr", addr.Addr)
return nil, trace.Wrap(err)
}

Expand All @@ -75,7 +75,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
GetHostCheckers: d.hostCheckerFunc(ctx),
OnCheckCert: func(c *ssh.Certificate) error {
if d.isClaimed != nil && d.isClaimed(c.ValidPrincipals...) {
d.log.Debugf("Aborting SSH handshake because the proxy %q is already claimed by some other agent.", c.ValidPrincipals[0])
d.logger.DebugContext(ctx, "Aborting SSH handshake because the proxy is already claimed by some other agent.", "proxy_id", c.ValidPrincipals[0])
// the error message must end with
// [proxyAlreadyClaimedError] to be recognized by
// [isProxyAlreadyClaimed]
Expand All @@ -88,7 +88,7 @@ func (d *agentDialer) DialContext(ctx context.Context, addr utils.NetAddr) (SSHC
FIPS: d.fips,
})
if err != nil {
d.log.Debugf("Failed to create host key callback for %v: %v.", addr.Addr, err)
d.logger.DebugContext(ctx, "Failed to create host key callback", "target_addr", addr.Addr, "error", err)
return nil, trace.Wrap(err)
}

Expand Down
3 changes: 1 addition & 2 deletions lib/reversetunnel/agent_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"testing"

"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"

Expand Down Expand Up @@ -90,7 +89,7 @@ func TestAgentCertChecker(t *testing.T) {
dialer := agentDialer{
client: &fakeClient{caKey: ca.PublicKey()},
authMethods: []ssh.AuthMethod{ssh.PublicKeys(signer)},
log: logrus.New(),
logger: utils.NewSlogLoggerForTests(),
}

_, err = dialer.DialContext(context.Background(), *utils.MustParseAddr(sshServer.Addr()))
Expand Down
Loading

0 comments on commit 1708df7

Please sign in to comment.