Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(dashboard): Make CSWSH protection explicit #13548

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions web/srv/api_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ var (
websocketUpgrader = websocket.Upgrader{
ReadBufferSize: maxMessageSize,
WriteBufferSize: maxMessageSize,
// Only allows requests from the same host, to prevent CSRF attacks.
// This is the default behavior in gorilla/websocket even if the
// CheckOrigin field is not set, but we make it explicit here as a good
// practice, avoiding relying on default behavior.
CheckOrigin: checkSameOrigin,
}

// Checks whose description matches the following regexp won't be included
Expand Down
46 changes: 46 additions & 0 deletions web/srv/check_same_origin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package srv

import (
"net/http"
"net/url"
"unicode/utf8"
)

// checkSameOrigin returns true if the origin is not set or is equal to the request host.
// Copied from gorilla/websocket.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return equalASCIIFold(u.Host, r.Host)
}

// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
// Copied from gorilla/websocket.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
Loading