Skip to content

Commit

Permalink
chore(dashboard): Make CSWSH protection explicit
Browse files Browse the repository at this point in the history
In the websocket handshake handler for the tap API we were relying on Gorilla's default behavior for validating that requests came from the same host, to protect against Cross-Site WebSocket Hijacking (CSWSH). This change only makes that validation explicit, instead of relying on default behavior, as a best practice.
  • Loading branch information
alpeb committed Jan 10, 2025
1 parent d1b2aa5 commit f03280c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
5 changes: 5 additions & 0 deletions web/srv/api_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ var (
websocketUpgrader = websocket.Upgrader{
ReadBufferSize: maxMessageSize,
WriteBufferSize: maxMessageSize,
// Only allows requests from the same host, to prevent CSRF attacks.
// This is the default behavior in gorilla/websocket even if the
// CheckOrigin field is not set, but we make it explicit here as a good
// practice, avoiding relying on default behavior.
CheckOrigin: checkSameOrigin,
}

// Checks whose description matches the following regexp won't be included
Expand Down
46 changes: 46 additions & 0 deletions web/srv/check_same_origin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package srv

import (
"net/http"
"net/url"
"unicode/utf8"
)

// checkSameOrigin returns true if the origin is not set or is equal to the request host.
// Copied from gorilla/websocket.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return equalASCIIFold(u.Host, r.Host)
}

// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
// Copied from gorilla/websocket.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}

0 comments on commit f03280c

Please sign in to comment.