-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore(dashboard): Make CSWSH protection explicit
In the websocket handshake handler for the tap API we were relying on Gorilla's default behavior for validating that requests came from the same host, to protect against Cross-Site WebSocket Hijacking (CSWSH). This change only makes that validation explicit, instead of relying on default behavior, as a best practice.
- Loading branch information
Showing
2 changed files
with
51 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package srv | ||
|
||
import ( | ||
"net/http" | ||
"net/url" | ||
"unicode/utf8" | ||
) | ||
|
||
// checkSameOrigin returns true if the origin is not set or is equal to the request host. | ||
// Copied from gorilla/websocket. | ||
func checkSameOrigin(r *http.Request) bool { | ||
origin := r.Header["Origin"] | ||
if len(origin) == 0 { | ||
return true | ||
} | ||
u, err := url.Parse(origin[0]) | ||
if err != nil { | ||
return false | ||
} | ||
return equalASCIIFold(u.Host, r.Host) | ||
} | ||
|
||
// equalASCIIFold returns true if s is equal to t with ASCII case folding as | ||
// defined in RFC 4790. | ||
// Copied from gorilla/websocket. | ||
func equalASCIIFold(s, t string) bool { | ||
for s != "" && t != "" { | ||
sr, size := utf8.DecodeRuneInString(s) | ||
s = s[size:] | ||
tr, size := utf8.DecodeRuneInString(t) | ||
t = t[size:] | ||
if sr == tr { | ||
continue | ||
} | ||
if 'A' <= sr && sr <= 'Z' { | ||
sr = sr + 'a' - 'A' | ||
} | ||
if 'A' <= tr && tr <= 'Z' { | ||
tr = tr + 'a' - 'A' | ||
} | ||
if sr != tr { | ||
return false | ||
} | ||
} | ||
return s == t | ||
} |