Skip to content

Commit

Permalink
Merge pull request #1862 from tkleczek/fix-rfc-errors
Browse files Browse the repository at this point in the history
Improve auth flow error handling
  • Loading branch information
nabokihms authored Aug 2, 2021
2 parents 766fc7a + 4ffaa60 commit 3fac2ab
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 86 deletions.
25 changes: 10 additions & 15 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return
}
}
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
s.renderError(r, w, http.StatusBadRequest, "Connector ID does not match a valid Connector")
return
}

Expand Down Expand Up @@ -187,21 +187,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
authReq, err := s.parseAuthorizationRequest(r)
if err != nil {
s.logger.Errorf("Failed to parse authorization request: %v", err)
status := http.StatusInternalServerError

// If this is an authErr, let's let it handle the error, or update the HTTP
// status code
if err, ok := err.(*authErr); ok {
if handler, ok := err.Handle(); ok {
// client_id and redirect_uri checked out and we can redirect back to
// the client with the error.
handler.ServeHTTP(w, r)
return
}
status = err.Status()

switch authErr := err.(type) {
case *redirectedAuthErr:
authErr.Handler().ServeHTTP(w, r)
case *displayedAuthErr:
s.renderError(r, w, authErr.Status, err.Error())
default:
panic("unsupported error type")
}

s.renderError(r, w, status, err.Error())
return
}

Expand Down Expand Up @@ -770,7 +765,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
case grantTypePassword:
s.withClientFromStorage(w, r, s.handlePasswordGrant)
default:
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
}
}

Expand Down
99 changes: 50 additions & 49 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,35 @@ import (

// TODO(ericchiang): clean this file up and figure out more idiomatic error handling.

// authErr is an error response to an authorization request.
// See: https://tools.ietf.org/html/rfc6749#section-4.1.2.1
type authErr struct {

// displayedAuthErr is an error that should be displayed to the user as a web page
type displayedAuthErr struct {
Status int
Description string
}

func (err *displayedAuthErr) Error() string {
return err.Description
}

func newDisplayedErr(status int, format string, a ...interface{}) *displayedAuthErr {
return &displayedAuthErr{status, fmt.Sprintf(format, a...)}
}

// redirectedAuthErr is an error that should be reported back to the client by 302 redirect
type redirectedAuthErr struct {
State string
RedirectURI string
Type string
Description string
}

func (err *authErr) Status() int {
if err.State == errServerError {
return http.StatusInternalServerError
}
return http.StatusBadRequest
}

func (err *authErr) Error() string {
func (err *redirectedAuthErr) Error() string {
return err.Description
}

func (err *authErr) Handle() (http.Handler, bool) {
// Didn't get a valid redirect URI.
if err.RedirectURI == "" {
return nil, false
}

func (err *redirectedAuthErr) Handler() http.Handler {
hf := func(w http.ResponseWriter, r *http.Request) {
v := url.Values{}
v.Add("state", err.State)
Expand All @@ -70,7 +73,7 @@ func (err *authErr) Handle() (http.Handler, bool) {
}
http.Redirect(w, r, redirectURI, http.StatusSeeOther)
}
return http.HandlerFunc(hf), true
return http.HandlerFunc(hf)
}

func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) error {
Expand Down Expand Up @@ -102,7 +105,6 @@ const (
errUnsupportedGrantType = "unsupported_grant_type"
errInvalidGrant = "invalid_grant"
errInvalidClient = "invalid_client"
errInvalidConnectorID = "invalid_connector_id"
)

const (
Expand Down Expand Up @@ -408,12 +410,12 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
// parse the initial request from the OAuth2 client.
func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) {
if err := r.ParseForm(); err != nil {
return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."}
return nil, newDisplayedErr(http.StatusBadRequest, "Failed to parse request.")
}
q := r.Form
redirectURI, err := url.QueryUnescape(q.Get("redirect_uri"))
if err != nil {
return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
return nil, newDisplayedErr(http.StatusBadRequest, "No redirect_uri provided.")
}

clientID := q.Get("client_id")
Expand All @@ -434,45 +436,44 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
client, err := s.storage.GetClient(clientID)
if err != nil {
if err == storage.ErrNotFound {
description := fmt.Sprintf("Invalid client_id (%q).", clientID)
return nil, &authErr{"", "", errUnauthorizedClient, description}
return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID)
}
s.logger.Errorf("Failed to get client: %v", err)
return nil, &authErr{"", "", errServerError, ""}
}

if connectorID != "" {
connectors, err := s.storage.ListConnectors()
if err != nil {
return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
}
if !validateConnectorID(connectors, connectorID) {
return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
}
return nil, newDisplayedErr(http.StatusInternalServerError, "Database error.")
}

if !validateRedirectURI(client, redirectURI) {
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return nil, &authErr{"", "", errInvalidRequest, description}
return nil, newDisplayedErr(http.StatusBadRequest, "Unregistered redirect_uri (%q).", redirectURI)
}
if redirectURI == deviceCallbackURI && client.Public {
redirectURI = s.issuerURL.Path + deviceCallbackURI
}

// From here on out, we want to redirect back to the client with an error.
newErr := func(typ, format string, a ...interface{}) *authErr {
return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
newRedirectedErr := func(typ, format string, a ...interface{}) *redirectedAuthErr {
return &redirectedAuthErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
}

if connectorID != "" {
connectors, err := s.storage.ListConnectors()
if err != nil {
s.logger.Errorf("Failed to list connectors: %v", err)
return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors")
}
if !validateConnectorID(connectors, connectorID) {
return nil, newRedirectedErr(errInvalidRequest, "Invalid ConnectorID")
}
}

// dex doesn't support request parameter and must return request_not_supported error
// https://openid.net/specs/openid-connect-core-1_0.html#6.1
if q.Get("request") != "" {
return nil, newErr(errRequestNotSupported, "Server does not support request parameter.")
return nil, newRedirectedErr(errRequestNotSupported, "Server does not support request parameter.")
}

if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
return nil, newErr(errInvalidRequest, description)
return nil, newRedirectedErr(errInvalidRequest, description)
}

var (
Expand All @@ -494,21 +495,21 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques

isTrusted, err := s.validateCrossClientTrust(clientID, peerID)
if err != nil {
return nil, newErr(errServerError, "Internal server error.")
return nil, newRedirectedErr(errServerError, "Internal server error.")
}
if !isTrusted {
invalidScopes = append(invalidScopes, scope)
}
}
}
if !hasOpenIDScope {
return nil, newErr(errInvalidScope, `Missing required scope(s) ["openid"].`)
return nil, newRedirectedErr(errInvalidScope, `Missing required scope(s) ["openid"].`)
}
if len(unrecognized) > 0 {
return nil, newErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized)
return nil, newRedirectedErr(errInvalidScope, "Unrecognized scope(s) %q", unrecognized)
}
if len(invalidScopes) > 0 {
return nil, newErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes)
return nil, newRedirectedErr(errInvalidScope, "Client can't request scope(s) %q", invalidScopes)
}

var rt struct {
Expand All @@ -526,37 +527,37 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
case responseTypeToken:
rt.token = true
default:
return nil, newErr(errInvalidRequest, "Invalid response type %q", responseType)
return nil, newRedirectedErr(errInvalidRequest, "Invalid response type %q", responseType)
}

if !s.supportedResponseTypes[responseType] {
return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
return nil, newRedirectedErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
}
}

if len(responseTypes) == 0 {
return nil, newErr(errInvalidRequest, "No response_type provided")
return nil, newRedirectedErr(errInvalidRequest, "No response_type provided")
}

if rt.token && !rt.code && !rt.idToken {
// "token" can't be provided by its own.
//
// https://openid.net/specs/openid-connect-core-1_0.html#Authentication
return nil, newErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'")
return nil, newRedirectedErr(errInvalidRequest, "Response type 'token' must be provided with type 'id_token' and/or 'code'")
}
if !rt.code {
// Either "id_token token" or "id_token" has been provided which implies the
// implicit flow. Implicit flow requires a nonce value.
//
// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
if nonce == "" {
return nil, newErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.")
return nil, newRedirectedErr(errInvalidRequest, "Response type 'token' requires a 'nonce' value.")
}
}
if rt.token {
if redirectURI == redirectURIOOB {
err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
return nil, newErr(errInvalidRequest, err)
return nil, newRedirectedErr(errInvalidRequest, err)
}
}

Expand Down
54 changes: 34 additions & 20 deletions server/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"

"github.com/dexidp/dex/storage"
Expand All @@ -27,8 +26,7 @@ func TestParseAuthorizationRequest(t *testing.T) {

queryParams map[string]string

wantErr bool
exactError *authErr
expectedError error
}{
{
name: "normal request",
Expand Down Expand Up @@ -78,7 +76,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"response_type": "code",
"scope": "openid email profile",
},
wantErr: true,
expectedError: &displayedAuthErr{Status: http.StatusNotFound},
},
{
name: "invalid redirect uri",
Expand All @@ -95,7 +93,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"response_type": "code",
"scope": "openid email profile",
},
wantErr: true,
expectedError: &displayedAuthErr{Status: http.StatusBadRequest},
},
{
name: "implicit flow",
Expand Down Expand Up @@ -128,7 +126,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"response_type": "code id_token",
"scope": "openid email profile",
},
wantErr: true,
expectedError: &redirectedAuthErr{Type: errUnsupportedResponseType},
},
{
name: "only token response type",
Expand All @@ -145,7 +143,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"response_type": "token",
"scope": "openid email profile",
},
wantErr: true,
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
},
{
name: "choose connector_id",
Expand Down Expand Up @@ -197,7 +195,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"response_type": "code id_token",
"scope": "openid email profile",
},
wantErr: true,
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
},
{
name: "PKCE code_challenge_method plain",
Expand Down Expand Up @@ -269,7 +267,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"code_challenge_method": "invalid_method",
"scope": "openid email profile",
},
wantErr: true,
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
},
{
name: "No response type",
Expand All @@ -287,12 +285,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
"code_challenge_method": "plain",
"scope": "openid email profile",
},
wantErr: true,
exactError: &authErr{
RedirectURI: "https://example.com/bar",
Type: "invalid_request",
Description: "No response_type provided",
},
expectedError: &redirectedAuthErr{Type: errInvalidRequest},
},
}

Expand Down Expand Up @@ -321,13 +314,34 @@ func TestParseAuthorizationRequest(t *testing.T) {
}

_, err := server.parseAuthorizationRequest(req)
if tc.wantErr {
require.Error(t, err)
if tc.exactError != nil {
require.Equal(t, tc.exactError, err)
if tc.expectedError == nil {
if err != nil {
t.Errorf("%s: expected no error", tc.name)
}
} else {
require.NoError(t, err)
switch expectedErr := tc.expectedError.(type) {
case *redirectedAuthErr:
e, ok := err.(*redirectedAuthErr)
if !ok {
t.Fatalf("%s: expected redirectedAuthErr error", tc.name)
}
if e.Type != expectedErr.Type {
t.Errorf("%s: expected error type %v, got %v", tc.name, expectedErr.Type, e.Type)
}
if e.RedirectURI != tc.queryParams["redirect_uri"] {
t.Errorf("%s: expected error to be returned in redirect to %v", tc.name, tc.queryParams["redirect_uri"])
}
case *displayedAuthErr:
e, ok := err.(*displayedAuthErr)
if !ok {
t.Fatalf("%s: expected displayedAuthErr error", tc.name)
}
if e.Status != expectedErr.Status {
t.Errorf("%s: expected http status %v, got %v", tc.name, expectedErr.Status, e.Status)
}
default:
t.Fatalf("%s: unsupported error type", tc.name)
}
}
}()
}
Expand Down
Loading

0 comments on commit 3fac2ab

Please sign in to comment.