Skip to content

Commit

Permalink
Let GetUserID and GetOrganizationID panic (#878)
Browse files Browse the repository at this point in the history
  • Loading branch information
FoseFx authored Oct 28, 2024
1 parent baffb92 commit 2fcd785
Show file tree
Hide file tree
Showing 23 changed files with 116 additions and 280 deletions.
38 changes: 30 additions & 8 deletions libs/common/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,44 @@ func SessionValidUntil(ctx context.Context) (time.Time, error) {
}
}

func GetUserID(ctx context.Context) (uuid.UUID, error) {
// MaybeGetUserID can be used instead of MustGetUserID
func MaybeGetUserID(ctx context.Context) *uuid.UUID {
res, ok := ctx.Value(userIDKey{}).(uuid.UUID)
if !ok {
return uuid.UUID{}, status.Error(codes.Internal, "userID not in context, set up auth")
} else {
return res, nil
return nil
}
return &res
}

// MustGetUserID panics, if context does not contain the userIDKey,
// which should have been set in the auth middleware
// Also see MaybeGetUserID, if you can not ensure that
func MustGetUserID(ctx context.Context) uuid.UUID {
res := MaybeGetUserID(ctx)
if res == nil {
panic("MustGetUserID called, but userID not in context, set up auth for this handler!")
}
return *res
}

func GetOrganizationID(ctx context.Context) (uuid.UUID, error) {
// MaybeGetOrganizationID can be used instead of MustGetOrganizationID
func MaybeGetOrganizationID(ctx context.Context) *uuid.UUID {
res, ok := ctx.Value(organizationIDKey{}).(uuid.UUID)
if !ok {
return uuid.UUID{}, status.Error(codes.Internal, "organizationID not in context, set up auth")
} else {
return res, nil
return nil
}
return &res
}

// MustGetOrganizationID panics, if context does not contain the organizationIDKey,
// which should have been set in the auth middleware
// Also see MaybeGetOrganizationID, if you can not ensure that
func MustGetOrganizationID(ctx context.Context) uuid.UUID {
res := MaybeGetOrganizationID(ctx)
if res == nil {
panic("MustGetOrganizationID called, but organizationID not in context, set up auth for this handler!")
}
return *res
}

// SetupAuth sets up auth, such that GetIDTokenVerifier and GetOAuthConfig work
Expand Down
2 changes: 1 addition & 1 deletion libs/common/hwgrpc/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func authInterceptor(ctx context.Context) (context.Context, error) {
return nil, status.Errorf(codes.Internal, "invalid userID")
}

// attach userID to the context, so we can get it in a handler using GetUserID()
// attach userID to the context, so we can get it in a handler using MustGetUserID()
ctx = auth.ContextWithUserID(ctx, userID)

// attach userID to the current span (should be the auth interceptor span)
Expand Down
2 changes: 1 addition & 1 deletion libs/common/hwgrpc/organization_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func organizationInterceptor(ctx context.Context) (context.Context, error) {
return nil, status.Errorf(codes.Internal, "invalid organizationID")
}

// attach organizationID to the context, so we can get it in a handler using GetOrganizationID()
// attach organizationID to the context, so we can get it in a handler using MustGetOrganizationID()
ctx = auth.ContextWithOrganizationID(ctx, organizationID)

// attach organizationID to the current span
Expand Down
20 changes: 10 additions & 10 deletions libs/hwes/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,29 +260,29 @@ func (e *Event) GetJsonData(data interface{}) error {
return json.Unmarshal(e.Data, data)
}

// SetCommitterFromCtx injects the UserID from the passed context via common.GetUserID().
// SetCommitterFromCtx injects the UserID from the passed context via auth.MustGetUserID().
func (e *Event) SetCommitterFromCtx(ctx context.Context) error {
userID, err := auth.GetUserID(ctx)
if err != nil {
userID := auth.MaybeGetUserID(ctx)
if userID == nil {
// don't set a user, if no user is available
return nil //nolint:nilerr
return nil
}

e.CommitterUserID = &userID
e.CommitterUserID = userID

telemetry.SetSpanStr(ctx, "committerUserID", e.CommitterUserID.String())
return nil
}

// SetOrganizationFromCtx injects the OrganizationID from the passed context via common.GetOrganizationID().
// SetOrganizationFromCtx injects the OrganizationID from the passed context via common.MustGetOrganizationID().
func (e *Event) SetOrganizationFromCtx(ctx context.Context) error {
organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
organizationID := auth.MaybeGetOrganizationID(ctx)
if organizationID == nil {
// don't set an org, if no org is available
return nil //nolint:nilerr
return nil
}

e.OrganizationID = &organizationID
e.OrganizationID = organizationID

if _, err := uuid.Parse(e.OrganizationID.String()); err != nil {
return fmt.Errorf("SetOrganizationFromCtx: cant parse organization uid: %w", err)
Expand Down
3 changes: 2 additions & 1 deletion libs/telemetry/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package telemetry
import (
"context"
"errors"
"github.com/prometheus/client_golang/prometheus/promauto"
"hwutil"
"net/http"
"os"
"time"

"github.com/prometheus/client_golang/prometheus/promauto"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,8 @@ func NewCreatePropertyCommandHandler(as hwes.AggregateStore, authz hwauthz.AuthZ
setID *string,
fieldTypeData *models.FieldTypeData,
) (version common.ConsistencyToken, err error) {
user, err := perm.UserFromCtx(ctx)
if err != nil {
return 0, err
}

organization, err := perm.OrganizationFromCtx(ctx)
if err != nil {
return 0, err
}
user := perm.UserFromCtx(ctx)
organization := perm.OrganizationFromCtx(ctx)

check := hwauthz.NewPermissionCheck(user, perm.OrganizationCanUserCreateProperty, organization)
if err = authz.Must(ctx, check); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ func NewUpdatePropertyCommandHandler(as hwes.AggregateStore, authz hwauthz.AuthZ
removeOptions []string,
isArchived *bool,
) (common.ConsistencyToken, error) {
user, err := perm.UserFromCtx(ctx)
if err != nil {
return 0, err
}
user := perm.UserFromCtx(ctx)

check := hwauthz.NewPermissionCheck(user, perm.PropertyCanUserUpdate, perm.Property(propertyID))
if err = authz.Must(ctx, check); err != nil {
if err := authz.Must(ctx, check); err != nil {
return 0, err
}

Expand Down
18 changes: 6 additions & 12 deletions services/property-svc/internal/property/perm/permission.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,19 @@ type User uuid.UUID
func (t User) Type() string { return "user" }
func (t User) ID() string { return uuid.UUID(t).String() }

func UserFromCtx(ctx context.Context) (User, error) {
userID, err := auth.GetUserID(ctx)
if err != nil {
return User{}, err
}
return User(userID), nil
func UserFromCtx(ctx context.Context) User {
userID := auth.MustGetUserID(ctx)
return User(userID)
}

type Organization uuid.UUID

func (p Organization) Type() string { return "organization" }
func (p Organization) ID() string { return uuid.UUID(p).String() }

func OrganizationFromCtx(ctx context.Context) (Organization, error) {
organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return Organization{}, err
}
return Organization(organizationID), nil
func OrganizationFromCtx(ctx context.Context) Organization {
organizationID := auth.MustGetOrganizationID(ctx)
return Organization(organizationID)
}

// Direct Relations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ type GetPropertiesQueryHandler func(

func NewGetPropertiesQueryHandler(authz hwauthz.AuthZ) GetPropertiesQueryHandler {
return func(ctx context.Context, subjectType *pb.SubjectType) ([]*models.PropertyWithConsistency, error) {
user, err := perm.UserFromCtx(ctx)
if err != nil {
return nil, err
}
user := perm.UserFromCtx(ctx)

propertyRepo := property_repo.New(hwdb.GetDB())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ type GetPropertyByIDQueryHandler func(

func NewGetPropertyByIDQueryHandler(authz hwauthz.AuthZ) GetPropertyByIDQueryHandler {
return func(ctx context.Context, propertyID uuid.UUID) (*models.Property, common.ConsistencyToken, error) {
user, err := perm.UserFromCtx(ctx)
if err != nil {
return nil, 0, err
}
user := perm.UserFromCtx(ctx)

// Verify user is allowed to see this property
check := hwauthz.NewPermissionCheck(user, perm.PropertyCanUserGet, perm.Property(propertyID))
if err = authz.Must(ctx, check); err != nil {
if err := authz.Must(ctx, check); err != nil {
return nil, 0, err
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,11 @@ func NewIsPropertyAlwaysIncludedForViewSourceHandler(authz hwauthz.AuthZ) IsProp
subjectType pb.SubjectType,
propertyID uuid.UUID,
) (bool, error) {
user, err := perm.UserFromCtx(ctx)
if err != nil {
return false, err
}
user := perm.UserFromCtx(ctx)

// Is user allowed to see this property?
check := hwauthz.NewPermissionCheck(user, perm.PropertyCanUserGet, perm.Property(propertyID))
if err = authz.Must(ctx, check); err != nil {
if err := authz.Must(ctx, check); err != nil {
return false, err
}

Expand Down
25 changes: 5 additions & 20 deletions services/task-svc/internal/bed/bed.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ func (ServiceServer) CreateBed(ctx context.Context, req *pb.CreateBedRequest) (*
log := zlog.Ctx(ctx)
bedRepo := bed_repo.New(hwdb.GetDB())

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

roomId, err := uuid.Parse(req.GetRoomId())
if err != nil {
Expand Down Expand Up @@ -82,10 +79,7 @@ func (ServiceServer) CreateBed(ctx context.Context, req *pb.CreateBedRequest) (*
func (ServiceServer) GetBed(ctx context.Context, req *pb.GetBedRequest) (*pb.GetBedResponse, error) {
bedRepo := bed_repo.New(hwdb.GetDB())

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

id, err := uuid.Parse(req.GetId())
if err != nil {
Expand Down Expand Up @@ -149,10 +143,7 @@ func (ServiceServer) GetBedByPatient(
func (ServiceServer) GetBeds(ctx context.Context, _ *pb.GetBedsRequest) (*pb.GetBedsResponse, error) {
bedRepo := bed_repo.New(hwdb.GetDB())

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

beds, err := bedRepo.GetBedsForOrganization(ctx, bed_repo.GetBedsForOrganizationParams{
OrganizationID: organizationID,
Expand Down Expand Up @@ -181,10 +172,7 @@ func (ServiceServer) GetBedsByRoom(
return nil, status.Error(codes.InvalidArgument, err.Error())
}

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

bedRepo := bed_repo.New(hwdb.GetDB())

Expand Down Expand Up @@ -244,10 +232,7 @@ func (ServiceServer) DeleteBed(ctx context.Context, req *pb.DeleteBedRequest) (*
log := zlog.Ctx(ctx)
bedRepo := bed_repo.New(hwdb.GetDB())

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

bedID, err := uuid.Parse(req.GetId())
if err != nil {
Expand Down
30 changes: 6 additions & 24 deletions services/task-svc/internal/patient/patient.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ func (ServiceServer) CreatePatient(

// TODO: Auth

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

patient_id, err := patientRepo.CreatePatient(ctx, patient_repo.CreatePatientParams{
OrganizationID: organizationID,
Expand Down Expand Up @@ -147,10 +144,7 @@ func (ServiceServer) GetPatientsByWard(

// TODO: Auth

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

wardID, err := uuid.Parse(req.GetWardId())
if err != nil {
Expand Down Expand Up @@ -245,10 +239,7 @@ func (ServiceServer) GetRecentPatients(
patientRepo := patient_repo.New(hwdb.GetDB())
log := zlog.Ctx(ctx)

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

// TODO: Auth

Expand Down Expand Up @@ -350,10 +341,7 @@ func (ServiceServer) AssignBed(ctx context.Context, req *pb.AssignBedRequest) (*

// TODO: Auth

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

bedID, err := uuid.Parse(req.GetBedId())
if err != nil {
Expand Down Expand Up @@ -460,10 +448,7 @@ func (ServiceServer) GetPatientDetails(

// TODO: Auth

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

patientID, err := uuid.Parse(req.GetId())
if err != nil {
Expand Down Expand Up @@ -562,10 +547,7 @@ func (ServiceServer) GetPatientList(
return nil, status.Error(codes.InvalidArgument, err.Error())
}

organizationID, err := auth.GetOrganizationID(ctx)
if err != nil {
return nil, err
}
organizationID := auth.MustGetOrganizationID(ctx)

rows, err := patientRepo.GetPatientsWithTasksBedAndRoomForOrganization(ctx, organizationID)
err = hwdb.Error(ctx, err)
Expand Down
Loading

0 comments on commit 2fcd785

Please sign in to comment.