Skip to content

Commit

Permalink
lint code (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
pablo-ruth authored May 3, 2023
1 parent 86eb726 commit 9d89a80
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 45 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@
# End of https://www.toptal.com/developers/gitignore/api/go
k8s-dashboard-auth-proxy
kubeconfig
.vscode
.vscode
localhost.crt
localhost.key
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func main() {
debug := flag.Bool("debug", false, "Debug mode")
flag.Parse()

// Check that loginURL
// Check login URL
if *loginURL == "" {
fmt.Println("Login URL must be set")
os.Exit(1)
Expand All @@ -42,14 +42,14 @@ func main() {
os.Exit(1)
}
default:
fmt.Println("Auth provider must be 'adfs' or 'tanzu'")
fmt.Println("Auth provider must be 'aws-adfs' or 'tanzu'")
os.Exit(1)
}

// Server requests
err = proxy.Server(*dashboardURL, authProvider, *debug)
if err != nil {
fmt.Printf("failed to start proxy: %s", err)
fmt.Printf("Failed to start proxy: %s", err)
os.Exit(1)
}
}
21 changes: 10 additions & 11 deletions provider/provider_aws_adfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@ import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strings"
"time"

"github.com/Versent/saml2aws/pkg/awsconfig"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/versent/saml2aws"
"github.com/versent/saml2aws/pkg/awsconfig"
"github.com/versent/saml2aws/pkg/cfg"
"github.com/versent/saml2aws/pkg/creds"
"sigs.k8s.io/aws-iam-authenticator/pkg/token"
)

const SessionDuration = 3600

type ProviderAwsAdfs struct {
LoginURL string
ClusterID string
Expand Down Expand Up @@ -68,7 +71,7 @@ func (p *ProviderAwsAdfs) Login(user, password string) (string, map[string]strin
account.Provider = "ADFS"
account.MFA = "Auto"
account.AmazonWebservicesURN = "urn:amazon:webservices"
account.SessionDuration = 36000
account.SessionDuration = SessionDuration

// Create a new SAML client
client, err := saml2aws.NewSAMLClient(account)
Expand Down Expand Up @@ -108,16 +111,15 @@ func (p *ProviderAwsAdfs) Login(user, password string) (string, map[string]strin

// Extract account ID from role ARN
// Example: arn:aws:iam::123456789012:role/role-name
if len(strings.Split(role.RoleARN, ":")) != 6 {
return "", map[string]string{}, fmt.Errorf("invalid role ARN: %s", role.RoleARN)
}
re := regexp.MustCompile(`^arn:aws:iam::(\d+):role\/([\w+=,.@-]{1,64})$`)
match := re.FindSubmatch([]byte(role.RoleARN))

if len(strings.Split(role.RoleARN, "/")) != 2 {
if len(match) != 3 {
return "", map[string]string{}, fmt.Errorf("invalid role ARN: %s", role.RoleARN)
}

awsRole := AWSRole{
Name: fmt.Sprintf("%s:%s", strings.Split(role.RoleARN, ":")[4], strings.Split(role.RoleARN, "/")[1]),
Name: fmt.Sprintf("%s:%s", match[1], match[2]),
ARN: role.RoleARN,
Principal: role.PrincipalARN,
}
Expand All @@ -128,9 +130,6 @@ func (p *ProviderAwsAdfs) Login(user, password string) (string, map[string]strin
}

b64Role := base64.StdEncoding.EncodeToString(jsonRole)
if err != nil {
return "", map[string]string{}, fmt.Errorf("error encoding role: %w", err)
}

roles[awsRole.Name] = b64Role
}
Expand Down Expand Up @@ -166,7 +165,7 @@ func (p *ProviderAwsAdfs) AssumeRole(SAMLAssertion, role string) (AWSCreds, erro
PrincipalArn: aws.String(awsRole.Principal),
RoleArn: aws.String(awsRole.ARN),
SAMLAssertion: aws.String(SAMLAssertion),
DurationSeconds: aws.Int64(int64(36000)),
DurationSeconds: aws.Int64(SessionDuration),
}
resp, err := svc.AssumeRoleWithSAML(assumeInput)
if err != nil {
Expand Down
14 changes: 11 additions & 3 deletions provider/provider_tanzu.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,18 @@ func (p *ProviderTanzu) Login(user, password string) (string, error) {
client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}

// Create JSON payload
payload := fmt.Sprintf("{\"guest_cluster_name\":\"%s\"}", p.GuestCluster)
payloadStruct := struct {
GuestClusterName string `json:"guest_cluster_name"`
}{
GuestClusterName: p.GuestCluster,
}
payload, err := json.Marshal(payloadStruct)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}

// Create login request
req, err := http.NewRequest("POST", p.LoginURL, strings.NewReader(payload))
req, err := http.NewRequest("POST", p.LoginURL, strings.NewReader(string(payload)))
if err != nil {
return "", fmt.Errorf("failed to create login request: %w", err)
}
Expand All @@ -78,7 +86,7 @@ func (p *ProviderTanzu) Login(user, password string) (string, error) {
defer resp.Body.Close()

// Check HTTP code for login succeeded
if resp.StatusCode != 200 {
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("login failed with HTTP code %d", resp.StatusCode)
}

Expand Down
38 changes: 19 additions & 19 deletions proxy/loginHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ func loginGetHandler(w http.ResponseWriter, r *http.Request) {
// Parse login template
tmpl, err := template.New("login").Parse(loginPageTemplate)
if err != nil {
log.Printf("failed to parse login page template: %s", err)
log.Printf("failed to parse login page template: %v", err)
return
}

// Execute template with login error if provided in URL
err = tmpl.ExecuteTemplate(w, "login", loginErrorMessage)
if err != nil {
log.Printf("failed to execute login page template: %s", err)
log.Printf("failed to execute login page template: %v", err)
return
}
}
Expand All @@ -62,22 +62,22 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter

if username == "" || password == "" {
log.Printf("username or password not provided")
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Username or password not provided")), http.StatusFound)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Username or password not provided")), http.StatusFound)
return
}

// Authenticate user
assertion, roles, err := adfsProvider.Login(username, password)
if err != nil {
log.Printf("failed to authenticate user: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Authentication failed")), http.StatusFound)
log.Printf("failed to authenticate user: %v", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Authentication failed")), http.StatusFound)
return
}

// Parse role template
tmpl, err := template.New("role").Parse(rolePageTemplate)
if err != nil {
log.Printf("failed to parse role page template: %s", err)
log.Printf("failed to parse role page template: %v", err)
return
}

Expand All @@ -90,8 +90,8 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter
Roles: roles,
})
if err != nil {
log.Printf("failed to execute role page template: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to execute role page template")), http.StatusFound)
log.Printf("failed to execute role page template: %v", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Failed to execute role page template")), http.StatusFound)
return
}
case "role":
Expand All @@ -101,31 +101,31 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter
role := r.FormValue("role")
if role == "" {
log.Printf("role not provided")
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Role not provided")), http.StatusFound)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Role not provided")), http.StatusFound)
return
}

// Check if assertion is provided
assertion := r.FormValue("assertion")
if assertion == "" {
log.Printf("assertion not provided")
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Assertion not provided")), http.StatusFound)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Assertion not provided")), http.StatusFound)
return
}

// Assume role with SAML assertion
creds, err := adfsProvider.AssumeRole(assertion, role)
if err != nil {
log.Printf("failed to assume role: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to assume role")), http.StatusFound)
log.Printf("failed to assume role: %v", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Failed to assume role")), http.StatusFound)
return
}

// Marshal credentials to base64 encoded JSON and store them in proxy_aws_creds cookie
jsonCreds, err := json.Marshal(creds)
if err != nil {
log.Printf("failed to marshal credentials: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to marshal credentials")), http.StatusFound)
log.Printf("failed to marshal credentials: %v", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Failed to marshal credentials")), http.StatusFound)
return
}
b64JsonCreds := base64.StdEncoding.EncodeToString(jsonCreds)
Expand All @@ -134,8 +134,8 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter
// Get token from credentials
token, err := adfsProvider.Token(creds)
if err != nil {
log.Printf("failed to get token: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Failed to get token")), http.StatusFound)
log.Printf("failed to get token: %v", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Failed to get token")), http.StatusFound)
return
}

Expand All @@ -155,16 +155,16 @@ func loginPostHandler(authProvider provider.Provider) func(w http.ResponseWriter

if username == "" || password == "" {
log.Printf("username or password not provided")
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Username or password not provided")), http.StatusFound)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Username or password not provided")), http.StatusFound)
return
}

// Authenticate user
tanzuProvider := authProvider
token, err := tanzuProvider.Login(username, password)
if err != nil {
log.Printf("failed to authenticate user: %s", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%s", url.QueryEscape("Authentication failed")), http.StatusFound)
log.Printf("failed to authenticate user: %v", err)
http.Redirect(w, r, fmt.Sprintf("/login?error=%v", url.QueryEscape("Authentication failed")), http.StatusFound)
return
}

Expand Down
2 changes: 1 addition & 1 deletion proxy/logoutHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func logoutGetHandler(w http.ResponseWriter, r *http.Request) {
// Call token cookie deletion helper
err := deleteTokenCookie(w, r)
if err != nil {
fmt.Printf("deleting token cookie: %s\n", err)
fmt.Printf("deleting token cookie: %v\n", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
Expand Down
14 changes: 7 additions & 7 deletions proxy/proxyHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func proxyHandler(target string, authProvider provider.Provider) func(w http.Res
// get token or redirect to login
token, err := getTokenCookie(r)
if err != nil {
log.Printf("failed to get cookie: %s", err)
log.Printf("failed to get cookie: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
Expand All @@ -36,30 +36,30 @@ func proxyHandler(target string, authProvider provider.Provider) func(w http.Res
// Get cookier proxy_aws_creds
b64Creds, err := r.Cookie("proxy_aws_creds")
if err != nil || b64Creds.Value == "" {
log.Printf("failed to get cookie proxy_aws_creds: %s", err)
log.Printf("failed to get cookie proxy_aws_creds: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}

// Extract creds from cookie
decodedCreds, err := base64.StdEncoding.DecodeString(b64Creds.Value)
if err != nil {
log.Printf("failed to decode cookie proxy_aws_creds: %s", err)
log.Printf("failed to decode cookie proxy_aws_creds: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
var creds provider.AWSCreds
err = json.Unmarshal(decodedCreds, &creds)
if err != nil {
log.Printf("failed to unmarshal cookie proxy_aws_creds: %s", err)
log.Printf("failed to unmarshal cookie proxy_aws_creds: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}

// Try to refresh token with existing creds
newToken, err := authProvider.Token(creds)
if err != nil {
log.Printf("failed to refresh token: %s", err)
log.Printf("failed to refresh token: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
Expand All @@ -78,7 +78,7 @@ func proxyHandler(target string, authProvider provider.Provider) func(w http.Res
break
}

log.Printf("failed to check if token is valid: %s", err)
log.Printf("failed to check if token is valid: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return

Expand All @@ -87,7 +87,7 @@ func proxyHandler(target string, authProvider provider.Provider) func(w http.Res
// create the reverse proxy
url, err := url.Parse(target)
if err != nil {
log.Printf("failed to parse target URL: %s", err)
log.Printf("failed to parse target URL: %v", err)
http.Redirect(w, r, "/login", http.StatusFound)
return
}
Expand Down

0 comments on commit 9d89a80

Please sign in to comment.