diff --git a/lib/auth/auth.go b/lib/auth/auth.go index e490553c6bbaa..6554fdf8b5188 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1512,17 +1512,13 @@ func (a *Server) runPeriodicOperations() { if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { logger.WarnContext(a.closeCtx, "failed to update SSH server hostname", "error", err) } - } - - // If the hostname has been replaced by a sanitized version, revert it back to the original - // if the original is valid under the most recent rules. - if oldHostname, ok := srv.GetLabel(replacedHostnameLabel); ok && validServerHostname(oldHostname) { - switch s := srv.(type) { - case *types.ServerV2: - s.Spec.Hostname = oldHostname - delete(s.Metadata.Labels, replacedHostnameLabel) - default: - return false, trace.BadParameter("invalid server provided") + } else if oldHostname, ok := srv.GetLabel(replacedHostnameLabel); ok && validServerHostname(oldHostname) { + // If the hostname has been replaced by a sanitized version, revert it back to the original + // if the original is valid under the most recent rules. + logger := a.logger.With("server", srv.GetName(), "old_hostname", oldHostname, "sanitized_hostname", srv.GetHostname()) + if err := restoreSanitizedHostname(srv); err != nil { + logger.WarnContext(a.closeCtx, "failed to restore sanitized static SSH server hostname", "error", err) + return false, nil } if _, err := a.Services.UpdateNode(a.closeCtx, srv); err != nil && !trace.IsCompareFailed(err) { log.Warnf("Failed to update node hostname: %v", err) @@ -5688,6 +5684,26 @@ func sanitizeHostname(server types.Server) error { return nil } +// restoreSanitizedHostname restores the original hostname of a server and removes the label. +func restoreSanitizedHostname(server types.Server) error { + oldHostname, ok := server.GetLabels()[replacedHostnameLabel] + // if the label is not present or the hostname is invalid under the most recent rules, do nothing. + if !ok || !validServerHostname(oldHostname) { + return nil + } + + switch s := server.(type) { + case *types.ServerV2: + // restore the original hostname and remove the label. + s.Spec.Hostname = oldHostname + delete(s.Metadata.Labels, replacedHostnameLabel) + default: + return trace.BadParameter("invalid server provided") + } + + return nil +} + // UpsertNode implements [services.Presence] by delegating to [Server.Services] // and potentially emitting a [usagereporter] event. func (a *Server) UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) {