Skip to content

Commit

Permalink
Feat custom user handle (#1978)
Browse files Browse the repository at this point in the history
Add a custom user handle to a webauthn credential

---------

Co-authored-by: bjoern-m <[email protected]>
  • Loading branch information
FreddyDevelop and bjoern-m authored Dec 5, 2024
1 parent e172e05 commit 21fd1d4
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 57 deletions.
7 changes: 3 additions & 4 deletions backend/flow_api/flow/shared/hook_issue_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ func (h IssueSession) Execute(c flowpilot.HookExecutionContext) error {
return errors.New("user_id not found in stash")
}

emails, err := deps.Persister.GetEmailPersisterWithConnection(deps.Tx).FindByUserId(userId)
userModel, err := deps.Persister.GetUserPersisterWithConnection(deps.Tx).Get(userId)
if err != nil {
return fmt.Errorf("failed to fetch emails from db: %w", err)
return fmt.Errorf("failed to fetch user from db: %w", err)
}

var emailDTO *dto.EmailJwt

if email := emails.GetPrimary(); email != nil {
if email := userModel.Emails.GetPrimary(); email != nil {
emailDTO = dto.JwtFromEmailModel(email)
}

Expand Down
86 changes: 51 additions & 35 deletions backend/flow_api/services/webauthn.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package services

import (
"encoding/base64"
"errors"
"fmt"
"github.com/go-webauthn/webauthn/protocol"
Expand Down Expand Up @@ -178,45 +177,38 @@ func (s *webauthnService) VerifyAssertionResponse(p VerifyAssertionResponseParam
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
}

sessionDataModel, err := s.persister.GetWebauthnSessionDataPersister().Get(p.SessionDataID)
sessionDataModel, err := s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Get(p.SessionDataID)
if err != nil {
return nil, fmt.Errorf("failed to get session data from db: %w", err)
}

var userID uuid.UUID
if p.IsMFA {
userID = sessionDataModel.UserId
} else {
userID, err = uuid.FromBytes(credentialAssertionData.Response.UserHandle)
if err != nil {
return nil, fmt.Errorf("failed to parse user id from user handle: %w", err)
}
}

userModel, err := s.persister.GetUserPersister().Get(userID)
credentialModel, err := s.persister.GetWebauthnCredentialPersister().Get(credentialAssertionData.ID)
if err != nil {
return nil, fmt.Errorf("failed to fetch user from db: %w", err)
return nil, fmt.Errorf("failed to get webauthncredential from db: %w", err)
}

if userModel == nil {
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
if credentialModel == nil {
return nil, ErrInvalidWebauthnCredential
}

cred := userModel.GetWebauthnCredentialById(credentialAssertionData.ID)
if cred != nil && (!p.IsMFA && cred.MFAOnly) {
if !p.IsMFA && credentialModel.MFAOnly {
return nil, ErrInvalidWebauthnCredentialMFAOnly
}

webAuthnUser, userModel, err := s.GetWebAuthnUser(p.Tx, *credentialModel)
if err != nil {
return nil, err
}

discoverableUserHandler := func(rawID, userHandle []byte) (webauthn.User, error) {
return userModel, nil
return webAuthnUser, nil
}

sessionData := sessionDataModel.ToSessionData()
var credential *webauthn.Credential
if p.IsMFA {
credential, err = s.cfg.Webauthn.Handler.ValidateLogin(userModel, *sessionData, credentialAssertionData)
_, err = s.cfg.Webauthn.Handler.ValidateLogin(webAuthnUser, *sessionData, credentialAssertionData)
} else {
credential, err = s.cfg.Webauthn.Handler.ValidateDiscoverableLogin(
_, err = s.cfg.Webauthn.Handler.ValidateDiscoverableLogin(
discoverableUserHandler,
*sessionData,
credentialAssertionData,
Expand All @@ -226,19 +218,16 @@ func (s *webauthnService) VerifyAssertionResponse(p VerifyAssertionResponseParam
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
}

encodedCredentialId := base64.RawURLEncoding.EncodeToString(credential.ID)
if credentialModel := userModel.GetWebauthnCredentialById(encodedCredentialId); credentialModel != nil {
now := time.Now().UTC()
flags := credentialAssertionData.Response.AuthenticatorData.Flags
now := time.Now().UTC()
flags := credentialAssertionData.Response.AuthenticatorData.Flags

credentialModel.LastUsedAt = &now
credentialModel.BackupState = flags.HasBackupState()
credentialModel.BackupEligible = flags.HasBackupEligible()
credentialModel.LastUsedAt = &now
credentialModel.BackupState = flags.HasBackupState()
credentialModel.BackupEligible = flags.HasBackupEligible()

err = s.persister.GetWebauthnCredentialPersisterWithConnection(p.Tx).Update(*credentialModel)
if err != nil {
return nil, fmt.Errorf("failed to update webauthn credential: %w", err)
}
err = s.persister.GetWebauthnCredentialPersisterWithConnection(p.Tx).Update(*credentialModel)
if err != nil {
return nil, fmt.Errorf("failed to update webauthn credential: %w", err)
}

err = s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Delete(*sessionDataModel)
Expand Down Expand Up @@ -279,11 +268,10 @@ func (s *webauthnService) generateCreationOptions(p GenerateCreationOptionsParam

err = s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Create(*sessionDataModel)
if err != nil {
return nil, nil, fmt.Errorf("failed to store session data to the db: %W", err)
return nil, nil, fmt.Errorf("failed to store session data to the db: %w", err)
}

return sessionDataModel, options, nil

}

func (s *webauthnService) GenerateCreationOptionsSecurityKey(p GenerateCreationOptionsParams) (*models.WebauthnSessionData, *protocol.CredentialCreation, error) {
Expand Down Expand Up @@ -354,3 +342,31 @@ func (s *webauthnService) VerifyAttestationResponse(p VerifyAttestationResponseP

return credential, nil
}

func (s *webauthnService) GetWebAuthnUser(tx *pop.Connection, credential models.WebauthnCredential) (webauthn.User, *models.User, error) {
user, err := s.persister.GetUserPersisterWithConnection(tx).Get(credential.UserId)
if err != nil {
return nil, nil, fmt.Errorf("failed to fetch user from db: %w", err)
}
if user == nil {
return nil, nil, ErrInvalidWebauthnCredential
}

if credential.UserHandle != nil {
return &webauthnUserWithCustomUserHandle{
CustomUserHandle: []byte(credential.UserHandle.Handle),
User: *user,
}, user, nil
}

return user, user, err
}

type webauthnUserWithCustomUserHandle struct {
models.User
CustomUserHandle []byte
}

func (u *webauthnUserWithCustomUserHandle) WebAuthnID() []byte {
return u.CustomUserHandle
}
2 changes: 1 addition & 1 deletion backend/handler/webauthn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ var userId = "ec4ef049-5b88-4321-a173-21b0eff06a04"
type sessionManager struct {
}

func (s sessionManager) GenerateJWT(_ uuid.UUID, _ *dto.EmailJwt) (string, jwt.Token, error) {
func (s sessionManager) GenerateJWT(_ uuid.UUID, _ *dto.EmailJwt, _ ...session.JWTOptions) (string, jwt.Token, error) {
return userId, nil, nil
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
drop_foreign_key("webauthn_credentials", "webauthn_credential_user_handle_fkey", {"if_exists": false})
drop_column("webauthn_credentials", "user_handle_id")
drop_table("webauthn_credential_user_handles")


Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
create_table("webauthn_credential_user_handles") {
t.Column("id", "uuid", {primary: true})
t.Column("user_id", "uuid", {"null": false})
t.Column("handle", "string", {"null": false, "unique": true})
t.Timestamps()
t.Index(["id", "user_id"], {"unique": true})
t.ForeignKey("user_id", {"users": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"})
}

add_column("webauthn_credentials", "user_handle_id", "uuid", { "null": true })
add_foreign_key("webauthn_credentials", "user_handle_id", {"webauthn_credential_user_handles": ["id"]}, {
"on_delete": "set null",
"on_update": "cascade",
})

sql("ALTER TABLE webauthn_credentials ADD CONSTRAINT webauthn_credential_user_handle_fkey FOREIGN KEY (user_handle_id, user_id) REFERENCES webauthn_credential_user_handles(id, user_id) ON DELETE NO ACTION ON UPDATE CASCADE;")
30 changes: 16 additions & 14 deletions backend/persistence/models/webauthn_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@ import (

// WebauthnCredential is used by pop to map your webauthn_credentials database table to your go code.
type WebauthnCredential struct {
ID string `db:"id" json:"id"`
Name *string `db:"name" json:"name"`
UserId uuid.UUID `db:"user_id" json:"user_id"`
PublicKey string `db:"public_key" json:"public_key"`
AttestationType string `db:"attestation_type" json:"attestation_type"`
AAGUID uuid.UUID `db:"aaguid" json:"aaguid"`
SignCount int `db:"sign_count" json:"sign_count"`
LastUsedAt *time.Time `db:"last_used_at" json:"last_used_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Transports Transports `has_many:"webauthn_credential_transports" json:"transports"`
BackupEligible bool `db:"backup_eligible" json:"backup_eligible"`
BackupState bool `db:"backup_state" json:"backup_state"`
MFAOnly bool `db:"mfa_only" json:"mfa_only"`
ID string `db:"id" json:"id"`
Name *string `db:"name" json:"name"`
UserId uuid.UUID `db:"user_id" json:"user_id"`
PublicKey string `db:"public_key" json:"public_key"`
AttestationType string `db:"attestation_type" json:"attestation_type"`
AAGUID uuid.UUID `db:"aaguid" json:"aaguid"`
SignCount int `db:"sign_count" json:"sign_count"`
LastUsedAt *time.Time `db:"last_used_at" json:"last_used_at"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
Transports Transports `has_many:"webauthn_credential_transports" json:"transports"`
BackupEligible bool `db:"backup_eligible" json:"backup_eligible"`
BackupState bool `db:"backup_state" json:"backup_state"`
MFAOnly bool `db:"mfa_only" json:"mfa_only"`
UserHandleID *uuid.UUID `db:"user_handle_id" json:"-"`
UserHandle *WebauthnCredentialUserHandle `belongs_to:"webauthn_credential_user_handle" fk_id:"webauthn_credential_user_handle_fkey" json:"user_handle,omitempty"`
}

type WebauthnCredentials []WebauthnCredential
Expand Down
28 changes: 28 additions & 0 deletions backend/persistence/models/webauthn_credential_user_handle.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package models

import (
"github.com/gobuffalo/pop/v6"
"github.com/gobuffalo/validate/v3"
"github.com/gobuffalo/validate/v3/validators"
"github.com/gofrs/uuid"
"time"
)

type WebauthnCredentialUserHandle struct {
ID uuid.UUID `db:"id" json:"id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
Handle string `db:"handle" json:"handle"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}

// Validate gets run every time you call a "pop.Validate*" (pop.ValidateAndSave, pop.ValidateAndCreate, pop.ValidateAndUpdate) method.
func (userHandle *WebauthnCredentialUserHandle) Validate(tx *pop.Connection) (*validate.Errors, error) {
return validate.Validate(
&validators.UUIDIsPresent{Name: "ID", Field: userHandle.ID},
&validators.UUIDIsPresent{Name: "UserId", Field: userHandle.UserID},
&validators.StringIsPresent{Name: "handle", Field: userHandle.Handle},
&validators.TimeIsPresent{Name: "CreatedAt", Field: userHandle.CreatedAt},
&validators.TimeIsPresent{Name: "UpdatedAt", Field: userHandle.UpdatedAt},
), nil
}
2 changes: 1 addition & 1 deletion backend/persistence/webauthn_credential_persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func NewWebauthnCredentialPersister(db *pop.Connection) WebauthnCredentialPersis

func (p *webauthnCredentialPersister) Get(id string) (*models.WebauthnCredential, error) {
credential := models.WebauthnCredential{}
err := p.db.Find(&credential, id)
err := p.db.Eager().Find(&credential, id)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
Expand Down
16 changes: 14 additions & 2 deletions backend/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

type Manager interface {
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt) (string, jwt.Token, error)
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error)
Verify(string) (jwt.Token, error)
GenerateCookie(token string) (*http.Cookie, error)
DeleteCookie() (*http.Cookie, error)
Expand Down Expand Up @@ -90,7 +90,7 @@ func NewManager(jwkManager hankoJwk.Manager, config config.Config) (Manager, err
}

// GenerateJWT creates a new session JWT for the given user
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt) (string, jwt.Token, error) {
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error) {
sessionID, err := uuid.NewV4()
if err != nil {
return "", nil, err
Expand All @@ -109,6 +109,10 @@ func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt) (string, jw
_ = token.Set("email", &email)
}

for _, opt := range opts {
opt(token)
}

if m.issuer != "" {
_ = token.Set(jwt.IssuerKey, m.issuer)
}
Expand Down Expand Up @@ -158,3 +162,11 @@ func (m *manager) DeleteCookie() (*http.Cookie, error) {
MaxAge: -1,
}, nil
}

type JWTOptions func(token jwt.Token)

func WithValue(key string, value interface{}) JWTOptions {
return func(jwt jwt.Token) {
_ = jwt.Set(key, value)
}
}

0 comments on commit 21fd1d4

Please sign in to comment.