diff --git a/web/srv/api_handlers.go b/web/srv/api_handlers.go index 69ccbba153fe5..e95f42e8ac371 100644 --- a/web/srv/api_handlers.go +++ b/web/srv/api_handlers.go @@ -45,6 +45,11 @@ var ( websocketUpgrader = websocket.Upgrader{ ReadBufferSize: maxMessageSize, WriteBufferSize: maxMessageSize, + // Only allows requests from the same host, to prevent CSRF attacks. + // This is the default behavior in gorilla/websocket even if the + // CheckOrigin field is not set, but we make it explicit here as a good + // practice, avoiding relying on default behavior. + CheckOrigin: checkSameOrigin, } // Checks whose description matches the following regexp won't be included diff --git a/web/srv/check_same_origin.go b/web/srv/check_same_origin.go new file mode 100644 index 0000000000000..9cf62c62f3aae --- /dev/null +++ b/web/srv/check_same_origin.go @@ -0,0 +1,46 @@ +package srv + +import ( + "net/http" + "net/url" + "unicode/utf8" +) + +// checkSameOrigin returns true if the origin is not set or is equal to the request host. +// Copied from gorilla/websocket. +func checkSameOrigin(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + return equalASCIIFold(u.Host, r.Host) +} + +// equalASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +// Copied from gorilla/websocket. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +}