From 0d07882e22b910ec107915857e7285eb687eaef4 Mon Sep 17 00:00:00 2001 From: Ben Davies Date: Fri, 18 Nov 2016 12:13:16 -0400 Subject: [PATCH] Sockets: Go rewrite WIP Fixes #2943 --- package.json | 9 +- rooms.js | 5 +- sockets.go | 523 +++++++++++++++++++++++++++++++++++++ sockets.js | 666 ++++++++++++++++-------------------------------- sockets_test.go | 249 ++++++++++++++++++ users.js | 16 +- 6 files changed, 1000 insertions(+), 468 deletions(-) create mode 100644 sockets.go create mode 100644 sockets_test.go diff --git a/package.json b/package.json index fe81ad1ff8883..b32f3f559a9cf 100644 --- a/package.json +++ b/package.json @@ -3,17 +3,10 @@ "preferGlobal": true, "description": "The server for the Pokémon Showdown battle simulator", "version": "0.10.2", - "dependencies": { - "sockjs": "0.3.18" - }, "optionalDependencies": { "cloud-env": "0.1.1", "http-proxy": "0.10.0", - "nodemailer": "1.4.0", - "node-static": "0.7.7" - }, - "nonDefaultDependencies": { - "ofe": "0.1.2" + "nodemailer": "1.4.0" }, "engines": { "node": ">=6.0.0" diff --git a/rooms.js b/rooms.js index 7881493a7c6c8..76b3a0482570d 100644 --- a/rooms.js +++ b/rooms.js @@ -1556,7 +1556,10 @@ class ChatRoom extends Room { if (this.users[user.userid]) return user; if (user.named) { - this.reportJoin('j', user.getIdentity(this.id)); + // Prevents a race condition where this message would send before + // Connection#joinRoom has a chance to finish, preventing it from + // reaching users joining empty rooms. 
+ process.nextTick(() => this.reportJoin('j', user.getIdentity(this.id))); } this.users[user.userid] = user; diff --git a/sockets.go b/sockets.go new file mode 100644 index 0000000000000..f21dc46ef3f03 --- /dev/null +++ b/sockets.go @@ -0,0 +1,523 @@ +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "log" + "net" + "net/http" + "os" + "path/filepath" + // "reflect" + "regexp" + // "runtime" + "strings" + "sync" + // "unsafe" + + // TODO: use the stable version of sockjs-go once it includes + // sockjs.Session.Request(). + "github.com/gorilla/mux" + "github.com/igm/sockjs-go/sockjs" +) + +// Silence IPC when unit testing. +var production bool = true + +// IPC delimiter +const EOT byte = 3 + +type SSLOptions struct { + Cert string `json:"cert"` + Key string `json:"key"` +} + +type SSL struct { + Port string `json:"port"` + Options SSLOptions `json:"options"` +} + +type Config struct { + Workers int `json:"Workers"` + Port string `json:"Port"` + BindAddress string `json:"BindAddress"` + SSL SSL `json:"SSL"` +} + +func NewConfig(envVar string) (c Config, err error) { + configEnv := os.Getenv(envVar) + err = json.Unmarshal([]byte(configEnv), &c) + return +} + +type Payload struct { + Command string + Params []string +} + +func NewPayload(command string, params ...string) Payload { + p := Payload{ + Command: command, + Params: params} + return p +} + +func (p Payload) WriteToStdout() { + output, _ := json.Marshal(p) + if production { + os.Stdout.Write(append(output, EOT)) + } +} + +func (p Payload) WriteToSocket() { + switch p.Command { + case "!": + if err := sm.SocketRemove(p.Params[0], true); err != nil { + panic(err) + } + case ">": + if err := sm.SocketSend(p.Params[0], p.Params[1]); err != nil { + panic(err) + } + case "+": + sm.ChannelAdd(p.Params[0], p.Params[1]) + case "-": + if err := sm.ChannelRemove(p.Params[0], p.Params[1]); err != nil { + panic(err) + } + case "#": + if err := sm.ChannelSend(p.Params[0], p.Params[1]); err != nil { + panic(err) 
+ } + case ".": + if err := sm.SubchannelMove(p.Params[0], p.Params[1], p.Params[2]); err != nil { + panic(err) + } + case ":": + if err := sm.SubchannelSend(p.Params[0], p.Params[1]); err != nil { + panic(err) + } + } +} + +type Job struct { + Payload Payload +} + +var JobQueue = make(chan Job) + +type Worker struct { + WorkerPool chan chan Job + JobChannel chan Job + quit chan bool +} + +func NewWorker(workerPool chan chan Job) Worker { + return Worker{ + WorkerPool: workerPool, + JobChannel: make(chan Job), + quit: make(chan bool)} +} + +func (w Worker) Start() { + go func() { + for { + w.WorkerPool <- w.JobChannel + select { + case job := <-w.JobChannel: + job.Payload.WriteToSocket() + case <-w.quit: + return + } + } + }() +} + +func (w Worker) Stop() { + go func() { + w.quit <- true + }() +} + +type Dispatcher struct { + WorkerPool chan chan Job + MaxWorkers int +} + +func NewDispatcher(maxWorkers int) *Dispatcher { + return &Dispatcher{ + WorkerPool: make(chan chan Job, maxWorkers), + MaxWorkers: maxWorkers} +} + +func (d *Dispatcher) Run() { + for i := 0; i < d.MaxWorkers; i++ { + worker := NewWorker(d.WorkerPool) + worker.Start() + } + + go d.dispatch() +} + +func (d *Dispatcher) dispatch() { + for { + select { + case job := <-JobQueue: + go func(job Job) { + jobChannel := <-d.WorkerPool + jobChannel <- job + }(job) + } + } +} + +// socketMultiplexer acts as a wrapper for sockjs.handler, exposing its map of +// SockJS sessions to allow the parent process to be able to send messages to +// all users in a room via channels, or either side of a battle and the +// users in its audience via subchannels. Channels are maps of socket IDs to +// subchannel IDs. 
+type socketMultiplexer struct { + smux sync.Mutex + sockets map[string]sockjs.Session + cmux sync.Mutex + channels map[string]map[string]string + scre *regexp.Regexp +} + +func newSocketMultiplexer() *socketMultiplexer { + return &socketMultiplexer{ + sockets: make(map[string]sockjs.Session), + channels: make(map[string]map[string]string), + scre: regexp.MustCompile("\n|split\n([^\n]*)\n([^\n]*)\n([^\n]*)\n[^\n]*")} +} + +func (sm *socketMultiplexer) SocketAdd(s sockjs.Session) error { + sm.smux.Lock() + defer sm.smux.Unlock() + + id := s.ID() + if _, ok := sm.sockets[id]; ok { + return fmt.Errorf("sockets: error adding socket: collision at ID %v", id) + } + + sm.sockets[id] = s + + // FIXME: payload params is missing the socket's protocol! Screw with + // reflect and unsafe to get it, since it's a private field. + req := s.Request() + ip, _, _ := net.SplitHostPort(req.RemoteAddr) + ips := req.Header.Get("X-Forwarded-For") + NewPayload("*", id, ip, ips).WriteToStdout() + return nil +} + +func (sm *socketMultiplexer) SocketRemove(sid string, forced bool) error { + sm.smux.Lock() + defer sm.smux.Unlock() + + sm.cmux.Lock() + for cid, c := range sm.channels { + if _, ok := c[sid]; ok { + delete(c, sid) + if (len(c) == 0) { + delete((*sm).channels, cid) + } + } + } + sm.cmux.Unlock() + + s, ok := sm.sockets[sid]; + if !ok { + return fmt.Errorf("sockets: failed to remove socket of ID %v: does not exist", sid) + } + + delete((*sm).sockets, sid) + + if forced { + s.Close(2010, "Normal closure") + } else { + // The parent process doesn't know that the socket was closed. Poke it + // so it can clean up any relevant connections. 
+ NewPayload("!", sid).WriteToStdout() + } + + return nil +} + +func (sm *socketMultiplexer) SocketSend(sid string, msg string) error { + sm.smux.Lock() + defer sm.smux.Unlock() + + s, ok := sm.sockets[sid]; + if !ok { + return fmt.Errorf("sockets: error sending to socket of ID %v: does not exist (%v)", sid, msg) + } + + err := s.Send(msg) + return err +} + +func (sm *socketMultiplexer) ChannelAdd(cid string, sid string) { + sm.cmux.Lock() + defer sm.cmux.Unlock() + + c, ok := sm.channels[cid] + if !ok { + c = make(map[string]string) + sm.channels[cid] = c + } + c[sid] = "0" +} + +func (sm *socketMultiplexer) ChannelRemove(cid string, sid string) error { + sm.cmux.Lock() + defer sm.cmux.Unlock() + + c, ok := sm.channels[cid] + if ok { + if _, ok := c[sid]; !ok { + return fmt.Errorf("sockets: failed to remove socket of ID %v from channel of ID %v: socket does not exist", sid, cid) + } + } else { + return fmt.Errorf("sockets: failed to remove socket of ID %v from channel of ID %v: channel does not exist", sid, cid) + } + + delete(c, sid) + if len(c) == 0 { + delete((*sm).channels, cid) + } + + return nil +} + +func (sm *socketMultiplexer) ChannelSend(cid string, msg string) error { + sm.cmux.Lock() + c, ok := sm.channels[cid] + if !ok { + sm.cmux.Unlock() + return fmt.Errorf("sockets: failed to send to channel of ID %v: does not exist", cid) + } + sm.cmux.Unlock() + + sm.smux.Lock() + defer sm.smux.Unlock() + + for sid, _ := range c { + s, ok := sm.sockets[sid] + if !ok { + delete(c, sid) + continue + } + + if err := s.Send(msg); err != nil { + return fmt.Errorf("sockets: failed to send to channel of ID %v: %v", cid, err) + } + } + + return nil +} + +func (sm *socketMultiplexer) SubchannelMove(cid string, scid string, sid string) error { + sm.cmux.Lock() + defer sm.cmux.Unlock() + + c, ok := sm.channels[cid] + if !ok { + return fmt.Errorf("sockets: failed to move socket of ID %v to subchannel of ID %v in channel of ID %v: channel does not exist", sid, scid, cid) + } 
+ + c[sid] = scid + return nil +} + +func (sm *socketMultiplexer) SubchannelSend(cid string, msg string) error { + sm.cmux.Lock() + defer sm.cmux.Unlock() + + c, ok := sm.channels[cid] + if !ok { + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: channel does not exist", cid) + } + + var scmsgs [3][]string + msgs := sm.scre.FindStringSubmatch(msg) + for i := 0; i < len(msgs); i++ { + switch i % 3 { + case 0: + scmsgs[0] = append(scmsgs[0], "\n" + msgs[i]) + case 1: + scmsgs[1] = append(scmsgs[1], "\n" + msgs[i]) + case 2: + scmsgs[2] = append(scmsgs[2], "\n" + msgs[i]) + } + } + + sm.smux.Lock() + defer sm.smux.Unlock() + + for sid, scid := range c { + s, ok := sm.sockets[sid] + if !ok { + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: socket of ID %v in subchannel of ID %v does not exist", cid, sid, scid) + } + + switch scid { + case "0": + for _, scmsg := range scmsgs[0] { + s.Send(scmsg) + } + case "1": + for _, scmsg := range scmsgs[1] { + s.Send(scmsg) + } + case "2": + for _, scmsg := range scmsgs[2] { + s.Send(scmsg) + } + default: + return fmt.Errorf("sockets: failed to broadcast to subchannels in channel of ID %v: socket of ID %v has unknown subchannel ID: %v", cid, sid, scid) + } + } + + return nil +} + +var sm *socketMultiplexer + +func sockJSHandler(s sockjs.Session) { + if err := sm.SocketAdd(s); err != nil { + log.Fatal(err) + } + + id := s.ID() + for { + if msg, err := s.Recv(); err == nil { + msglen := len(msg) + + // Drop blank messages (DDOS?). + if msglen == 0 { + continue + } + + // Drop messages over 100KB + if msglen > 102400 { + fmt.Printf("sockets: dropping %vKB client message: %v", msglen / 1024, msg[:160]) + continue + } + + // Drop legacy JSON messages (123 is "{" as type byte). + if msg[0] == 123 { + continue + } + + // Drop invalid messages. 
+ if idx := strings.Index(msg, "|"); idx < 0 || idx == msglen - 1 { + continue + } + + NewPayload("<", id, msg).WriteToStdout() + continue + } + + break + } + + sm.SocketRemove(id, false) +} + +func parse(s *bufio.Scanner, pch chan Payload) { + var buf []byte + for s.Scan() { + token := s.Bytes() + buf = append(buf, token...) + + var p Payload + if err := json.Unmarshal(buf, &p); err == nil { + pch <- p + return + } + continue + } + + if err := s.Err(); err != nil { + log.Fatal("sockets: error reading IPC input: %v", err) + } +} + +func main() { + // Create config struct from PS_CONFIG env variable as defined by the + // parent process from relevant settings in config.js. + config, err := NewConfig("PS_CONFIG") + if err != nil { + log.Fatal("sockets: failed to read config from $PS_CONFIG") + } + + // Spawn goroutine workers. + dispatcher := NewDispatcher(config.Workers) + dispatcher.Run() + + // Set up our static servers. + // TODO: serve the proper 404 page for invalid routes. + r := mux.NewRouter() + + staticDir, _ := filepath.Abs("./static") + r.Handle("/", http.FileServer(http.Dir(staticDir))) + + customCSSDir, _ := filepath.Abs("./config") + r.Handle("/custom.css", http.FileServer(http.Dir(customCSSDir))) + + avatarDir, _ := filepath.Abs("./config/avatars") + r.PathPrefix("/avatars/"). + Handler(http.FileServer(http.Dir(avatarDir))) + + // Set up our SockJS server. + sm = newSocketMultiplexer() + opts := sockjs.Options{ + SockJSURL: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + Websocket: true, + HeartbeatDelay: sockjs.DefaultOptions.HeartbeatDelay, + DisconnectDelay: sockjs.DefaultOptions.DisconnectDelay, + JSessionID: sockjs.DefaultOptions.JSessionID} + r.PathPrefix("/showdown"). + Handler(sockjs.NewHandler("/showdown", opts, sockJSHandler)) + + // Begin serving over HTTP... 
+ go func() { + port := config.Port + fmt.Printf("sockets: now serving on http://%v%v/", config.BindAddress, port) + log.Fatal(http.ListenAndServe(port, r)) + }() + + // ...and HTTPS, if so configured. + cert := config.SSL.Options.Cert + key := config.SSL.Options.Key + if cert != "" && key != "" { + go func() { + port := config.SSL.Port + fmt.Printf("sockets: now serving on https://%v%v/", config.BindAddress, port) + log.Fatal(http.ListenAndServeTLS(port, cert, key, r)) + }() + } + + // Finally, listen for any messages sent through IPC by the parent + // process until either process is killed. + scanner := bufio.NewScanner(os.Stdin) + scanner.Split(func (data []byte, atEOF bool) (advance int, token []byte, err error) { + for i := 0; i < len(data); i++ { + if data[i] == EOT { + return i + 1, data[:i], nil + } + } + + return 0, data, bufio.ErrFinalToken + }) + + pch := make(chan Payload) + for { + go parse(scanner, pch) + payload := <-pch + job := Job{Payload: payload} + JobQueue <- job + } +} diff --git a/sockets.js b/sockets.js index 729ce6fd55365..eac7983860573 100644 --- a/sockets.js +++ b/sockets.js @@ -13,488 +13,262 @@ 'use strict'; -const cluster = require('cluster'); -global.Config = require('./config/config'); - -if (cluster.isMaster) { - cluster.setupMaster({ - exec: require('path').resolve(__dirname, 'sockets'), - }); - - let workers = exports.workers = {}; - - let spawnWorker = exports.spawnWorker = function () { - let worker = cluster.fork({PSPORT: Config.port, PSBINDADDR: Config.bindaddress || '', PSNOSSL: Config.ssl ? 
0 : 1}); - let id = worker.id; - workers[id] = worker; - worker.on('message', data => { - // console.log('master received: ' + data); - switch (data.charAt(0)) { - case '*': { - // *socketid, ip, protocol - // connect - let nlPos = data.indexOf('\n'); - let nlPos2 = data.indexOf('\n', nlPos + 1); - Users.socketConnect(worker, id, data.slice(1, nlPos), data.slice(nlPos + 1, nlPos2), data.slice(nlPos2 + 1)); - break; - } - - case '!': { - // !socketid - // disconnect - Users.socketDisconnect(worker, id, data.substr(1)); - break; - } - - case '<': { - // { + // Respawn if the child process killed itself. + if (!signal) process.nextTick(() => Worker.spawn()); }); - }; - - cluster.on('disconnect', worker => { - // worker crashed, try our best to clean up - require('./crashlogger')(new Error("Worker " + worker.id + " abruptly died"), "The main process"); - // this could get called during cleanup; prevent it from crashing - worker.send = () => {}; - - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; + this.process.stdin.on('error', () => {}); + this.process.stdin.on('drain', () => { + while (this.queue.length) { + let args = this.queue[0]; + let res = this.send.apply(this, args); + if (res) { + this.queue.shift(); + } else { + // Wait for next drain event before continuing. + break; + } } }); - console.error("" + count + " connections were lost."); - // don't delete the worker, so we can investigate it if necessary. + this.process.stdout.setEncoding('utf8'); + this.process.stdout.on('data', data => { + let payloads = data.split(EOT); + for (let payload of payloads) { + if (!payload) continue; + + if (payload.startsWith('sockets:')) { + // Data received wasn't a proper payload, so it was + // probably intended to be logged to console instead. 
+ console.log(payload); + continue; + } - // attempt to recover - spawnWorker(); - }); + let idx = payload.indexOf('{'); + if (idx > 0) { + // For whatever reason, messages written to stdout from the + // child process are prefixed by the number of messages + // written to it since it was spawned. + payload = payload.substr(idx); + } - exports.listen = function (port, bindAddress, workerCount) { - if (port !== undefined && !isNaN(port)) { - Config.port = port; - Config.ssl = null; - } else { - port = Config.port; - // Autoconfigure the app when running in cloud hosting environments: - try { - let cloudenv = require('cloud-env'); - bindAddress = cloudenv.get('IP', bindAddress); - port = cloudenv.get('PORT', port); - } catch (e) {} - } - if (bindAddress !== undefined) { - Config.bindaddress = bindAddress; - } - if (workerCount === undefined) { - workerCount = (Config.workers !== undefined ? Config.workers : 1); - } - for (let i = 0; i < workerCount; i++) { - spawnWorker(); - } - }; - - exports.killWorker = function (worker) { - let idd = worker.id + '-'; - let count = 0; - Users.connections.forEach((connection, connectionid) => { - if (connectionid.substr(idd.length) === idd) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; + payload = JSON.parse(payload); + let command = payload.Command; + let params = payload.Params; + switch (command) { + case '*': + this.onSocketConnect(...params); + break; + case '!': + this.onSocketDisconnect(...params); + break; + case '<': + this.onSocketReceive(...params); + break; + default: + console.error(`sockets: parent process received job with unknown command type ${command} from child process: ${params}`); + break; + } } }); - try { - worker.kill(); - } catch (e) {} - delete workers[worker.id]; - return count; - }; - - exports.killPid = function (pid) { - pid = '' + pid; - for (let id in workers) { - let worker = workers[id]; - if (pid === '' + worker.process.pid) { - return this.killWorker(worker); - } 
- } - return false; - }; - - exports.socketSend = function (worker, socketid, message) { - worker.send('>' + socketid + '\n' + message); - }; - exports.socketDisconnect = function (worker, socketid) { - worker.send('!' + socketid); - }; - - exports.channelBroadcast = function (channelid, message) { - for (let workerid in workers) { - workers[workerid].send('#' + channelid + '\n' + message); - } - }; - exports.channelSend = function (worker, channelid, message) { - worker.send('#' + channelid + '\n' + message); - }; - exports.channelAdd = function (worker, channelid, socketid) { - worker.send('+' + channelid + '\n' + socketid); - }; - exports.channelRemove = function (worker, channelid, socketid) { - worker.send('-' + channelid + '\n' + socketid); - }; - - exports.subchannelBroadcast = function (channelid, message) { - for (let workerid in workers) { - workers[workerid].send(':' + channelid + '\n' + message); - } - }; - exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { - worker.send('.' + channelid + '\n' + subchannelid + '\n' + socketid); - }; -} else { - // is worker - - if (process.env.PSPORT) Config.port = +process.env.PSPORT; - if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; - if (+process.env.PSNOSSL) Config.ssl = null; - - // ofe is optional - // if installed, it will heap dump if the process runs out of memory - try { - require('ofe').call(); - } catch (e) {} - - // Static HTTP server - - // This handles the custom CSS and custom avatar features, and also - // redirects yourserver:8001 to yourserver-8001.psim.us - - // It's optional if you don't need these features. 
- global.Dnsbl = require('./dnsbl'); + this.process.stderr.setEncoding('utf8'); + this.process.stderr.once('data', data => { + require('./crashlogger')(new Error(data), `Worker ${this.id}`); - if (Config.crashguard) { - // graceful crash - process.on('uncaughtException', err => { - require('./crashlogger')(err, 'Socket process ' + cluster.worker.id + ' (' + process.pid + ')', true); - }); - } - - let app = require('http').createServer(); - let appssl; - if (Config.ssl) { - appssl = require('https').createServer(Config.ssl.options); - } - try { - let nodestatic = require('node-static'); - let cssserver = new nodestatic.Server('./config'); - let avatarserver = new nodestatic.Server('./config/avatars'); - let staticserver = new nodestatic.Server('./static'); - let staticRequestHandler = (request, response) => { - // console.log("static rq: " + request.socket.remoteAddress + ":" + request.socket.remotePort + " -> " + request.socket.localAddress + ":" + request.socket.localPort + " - " + request.method + " " + request.url + " " + request.httpVersion + " - " + request.rawHeaders.join('|')); - request.resume(); - request.addListener('end', () => { - if (Config.customhttpresponse && - Config.customhttpresponse(request, response)) { - return; + let {id} = this; + let count = 0; + Users.connections.forEach(connection => { + if (connection.worker === this) { + Users.socketDisconnect(this, id, connection.socketid); + count++; } - let server; - if (request.url === '/custom.css') { - server = cssserver; - } else if (request.url.substr(0, 9) === '/avatars/') { - request.url = request.url.substr(8); - server = avatarserver; - } else { - if (/^\/([A-Za-z0-9][A-Za-z0-9-]*)\/?$/.test(request.url)) { - request.url = '/'; - } - server = staticserver; - } - server.serve(request, response, (e, res) => { - if (e && (e.status === 404)) { - staticserver.serveFile('404.html', 404, {}, request, response); - } - }); }); - }; - app.on('request', staticRequestHandler); - if (appssl) { - 
appssl.on('request', staticRequestHandler); - } - } catch (e) { - console.log('Could not start node-static - try `npm install` if you want to use it'); - } - - // SockJS server - - // This is the main server that handles users connecting to our server - // and doing things on our server. + console.error(`${count} connections were lost.`); - let sockjs = require('sockjs'); - - let server = sockjs.createServer({ - sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", - log: (severity, message) => { - if (severity === 'error') console.log('ERROR: ' + message); - }, - prefix: '/showdown', - }); + // Leave the worker in the workers map so it can be investigated + // later, and try respawning it once the process finishes exiting. + }); + } - let sockets = {}; - let channels = {}; - let subchannels = {}; - - // Deal with phantom connections. - let sweepClosedSockets = function () { - for (let s in sockets) { - if (sockets[s].protocol === 'xhr-streaming' && - sockets[s]._session && - sockets[s]._session.recv) { - sockets[s]._session.recv.didClose(); - } + onSocketConnect(socketid, remoteAddress, header, protocol) { +// console.log(`sockets: socket connect (${socketid}, ${remoteAddress}, ${header}, ${protocol})`); - // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while - // it is an object for normal users. Under normal circumstances, those properties should only be - // `null` when the timeout has already been called, but somehow it's not happening for some connections. - // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually - // on those connections kills those connections. For a bit of background, this timeout is the timeout - // that sockjs sets to wait for users to reconnect within that time to continue their session. 
- if (sockets[s]._session && - sockets[s]._session.to_tref && - !sockets[s]._session.to_tref._idlePrev) { - sockets[s]._session.timeout_cb(); - } - } - }; - let interval = setInterval(sweepClosedSockets, 1000 * 60 * 10); // eslint-disable-line no-unused-vars - - process.on('message', data => { - // console.log('worker received: ' + data); - let socket = null, socketid = ''; - let channel = null, channelid = ''; - let subchannel = null, subchannelid = ''; - - switch (data.charAt(0)) { - case '$': // $code - eval(data.substr(1)); - break; - - case '!': // !socketid - // destroy - socketid = data.substr(1); - socket = sockets[socketid]; - if (!socket) return; - socket.end(); - // After sending the FIN packet, we make sure the I/O is totally blocked for this socket - socket.destroy(); - delete sockets[socketid]; - for (channelid in channels) { - delete channels[channelid][socketid]; + let ip = remoteAddress; + if (header && isTrustedProxyIp(remoteAddress)) { + let ips = header.split(','); + for (let i = ips.length; i--;) { + ip = ips[i].trim(); + if (!isTrustedProxyIp(ip)) break; } - break; - - case '>': { - // >socketid, message - // message - let nlLoc = data.indexOf('\n'); - socket = sockets[data.substr(1, nlLoc - 1)]; - if (!socket) return; - socket.write(data.substr(nlLoc + 1)); - break; } - case '#': { - // #channelid, message - // message to channel - let nlLoc = data.indexOf('\n'); - channel = channels[data.substr(1, nlLoc - 1)]; - let message = data.substr(nlLoc + 1); - for (socketid in channel) { - channel[socketid].write(message); - } - break; - } + Users.socketConnect(this, this.id, socketid, ip, protocol); + } - case '+': { - // +channelid, socketid - // add to channel - let nlLoc = data.indexOf('\n'); - socketid = data.substr(nlLoc + 1); - socket = sockets[socketid]; - if (!socket) return; - channelid = data.substr(1, nlLoc - 1); - channel = channels[channelid]; - if (!channel) channel = channels[channelid] = Object.create(null); - channel[socketid] = 
socket; - break; - } + onSocketDisconnect(socketid) { +// console.log(`sockets: socket disconnect (${socketid})`); + Users.socketDisconnect(this, this.id, socketid); + } - case '-': { - // -channelid, socketid - // remove from channel - let nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels[channelid]; - if (!channel) return; - socketid = data.slice(nlLoc + 1); - delete channel[socketid]; - if (subchannels[channelid]) delete subchannels[channelid][socketid]; - let isEmpty = true; - for (let socketid in channel) { // eslint-disable-line no-unused-vars - isEmpty = false; - break; - } - if (isEmpty) { - delete channels[channelid]; - delete subchannels[channelid]; - } - break; - } + onSocketReceive(socketid, message) { +// console.log(`sockets: socket receive (${socketid}, ${message})`); + Users.socketReceive(this, this.id, socketid, message); + } - case '.': { - // .channelid, subchannelid, socketid - // move subchannel - let nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - let nlLoc2 = data.indexOf('\n', nlLoc + 1); - subchannelid = data.slice(nlLoc + 1, nlLoc2); - socketid = data.slice(nlLoc2 + 1); - - subchannel = subchannels[channelid]; - if (!subchannel) subchannel = subchannels[channelid] = Object.create(null); - if (subchannelid === '0') { - delete subchannel[socketid]; - } else { - subchannel[socketid] = subchannelid; - } - break; - } + isDead() { + return this.process.exitCode !== null || this.process.signalCode !== null; + } - case ':': { - // :channelid, message - // message to subchannel - let nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels[channelid]; - subchannel = subchannels[channelid]; - let message = data.substr(nlLoc + 1); - let messages = [null, null, null]; - for (socketid in channel) { - switch (subchannel ? 
subchannel[socketid] : '0') { - case '1': - if (!messages[1]) { - messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); - } - channel[socketid].write(messages[1]); - break; - case '2': - if (!messages[2]) { - messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); - } - channel[socketid].write(messages[2]); - break; - default: - if (!messages[0]) { - messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); - } - channel[socketid].write(messages[0]); - break; - } - } - break; - } + kill(signal = 'SIGTERM') { + this.process.kill(signal); + } - default: - } - }); + send(command, ...params) { + if (this.isDead()) return false; - process.on('disconnect', () => { - process.exit(); - }); + let payload = `${JSON.stringify({Command: command, Params: params})}${EOT}`; + let res = this.process.stdin.write(payload); + if (!res) this.queue.push([command, params]); + return res; + } + + static spawn(port = Config.port, bindAddress = Config.bindaddress, workerCount = Config.workers) { + // Don't spawn another child process if one is already alive -- it will + // always crash when it launches its servers otherwise! + for (let [id, worker] of workers) { // eslint-disable-line no-unused-vars + if (!worker.isDead()) return false; + } - // this is global so it can be hotpatched if necessary - let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); - let socketCounter = 0; - server.on('connection', socket => { - if (!socket) { - // For reasons that are not entirely clear, SockJS sometimes triggers - // this event with a null `socket` argument. - return; - } else if (!socket.remoteAddress) { - // This condition occurs several times per day. It may be a SockJS bug. 
+ let {ssl} = Config; + let id = workers.size; + if (port !== undefined && !isNaN(port)) { + port = Config.port; + } else { try { - socket.end(); + let cloudenv = require('cloud-env'); + port = cloudenv.get('PORT', port); + bindAddress = cloudenv.get('IP', bindAddress); } catch (e) {} - return; } - let socketid = socket.id = (++socketCounter); - - sockets[socket.id] = socket; - if (isTrustedProxyIp(socket.remoteAddress)) { - let ips = (socket.headers['x-forwarded-for'] || '').split(','); - let ip; - while ((ip = ips.pop())) { - ip = ip.trim(); - if (!isTrustedProxyIp(ip)) { - socket.remoteAddress = ip; - break; - } + if (ssl && typeof ssl === 'object' && !Array.isArray(ssl)) { + try { + ssl = JSON.stringify(ssl); + } catch (e) { + ssl = null; } - } - - process.send('*' + socketid + '\n' + socket.remoteAddress + '\n' + socket.protocol); - socket.on('data', message => { - // drop empty messages (DDoS?) - if (!message) return; - // drop messages over 100KB - if (message.length > 100000) { - console.log("Dropping client message " + (message.length / 1024) + " KB..."); - console.log(message.slice(0, 160)); - return; + if (ssl !== null && typeof ssl.port === 'number' && !isNaN(ssl.port)) { + ssl.port = `:${ssl.port}` } - // drop legacy JSON messages - if (typeof message !== 'string' || message.charAt(0) === '{') return; - // drop blank messages (DDoS?) - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0 || pipeIndex === message.length - 1) return; + } else { + ssl = null; + } - process.send('<' + socketid + '\n' + message); - }); + if (workerCount === undefined) workerCount = 1; - socket.on('close', () => { - process.send('!' 
+ socketid); - delete sockets[socketid]; - for (let channelid in channels) { - delete channels[channelid][socketid]; - } - }); - }); - server.installHandlers(app, {}); - if (!Config.bindaddress) Config.bindaddress = '0.0.0.0'; - app.listen(Config.port, Config.bindaddress); - console.log('Worker ' + cluster.worker.id + ' now listening on ' + Config.bindaddress + ':' + Config.port); - - if (appssl) { - server.installHandlers(appssl, {}); - appssl.listen(Config.ssl.port, Config.bindaddress); - console.log('Worker ' + cluster.worker.id + ' now listening for SSL on port ' + Config.ssl.port); + let worker = new Worker({id, port, bindAddress, ssl, workerCount}); + return worker; } +} - console.log('Test your server at http://' + (Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress) + ':' + Config.port); +exports.Worker = Worker; + +exports.listen = (...args) => { + let worker = Worker.spawn(...args); + if (worker) workers.set(worker.id, worker); +}; +exports.spawnWorker = () => { + let worker = Worker.spawn(); + if (worker) workers.set(worker.id, worker); +}; +exports.killWorker = worker => { + let {id} = worker; + let count = 0; + Users.connections.forEach(connection => { + if (connection.worker === worker) { + Users.socketDisconnect(worker, id, connection.socketid); + count++; + } + }); - require('./repl').start('sockets-', cluster.worker.id + '-' + process.pid, cmd => eval(cmd)); -} + worker.kill(); + workers.delete(id); + return count; +}; + +exports.socketSend = (worker, socketid, message) => { +// console.log(`sockets: sending to socket of ID ${socketid}: ${message}`); + worker.send('>', socketid, message); +}; +exports.socketDisconnect = (worker, socketid) => { +// console.log(`sockets: disconnecting socket of ID ${socketid}`); + worker.send('!', socketid); +}; + +exports.channelSend = (worker, channelid, message) => { +// console.log(`sockets: sending to channel of ID ${channelid}: ${message}`); + worker.send('#', channelid, message); +}; 
+exports.channelBroadcast = (channelid, message) => { +// console.log(`sockets: broadcasting to channel of ID ${channelid}: ${message}`); + let worker = workers.get(workers.size - 1); + if (worker) worker.send('#', channelid, message); +}; +exports.channelAdd = (worker, channelid, socketid) => { +// console.log(`sockets: adding socket of ID ${socketid} to channel of ID ${channelid}`); + worker.send('+', channelid, socketid); +}; +exports.channelRemove = (worker, channelid, socketid) => { +// console.log(`sockets: removing socket of ID ${socketid} from channel of ID ${channelid}`); + worker.send('-', channelid, socketid); +}; + +exports.subchannelBroadcast = (channelid, message) => { +// console.log(`sockets: broadcasting to subchannels in channel of ID ${channelid}: ${message}`); + let worker = workers.get(workers.size - 1); + if (worker) worker.send(':', channelid, message); +}; +exports.subchannelMove = (worker, channelid, subchannelid, socketid) => { +// console.log(`sockets: moving socketid of ${socketid} to subchannel ${subchannelid} in channel of ID ${channelid}`); + worker.send('.', channelid, subchannelid, socketid); +}; diff --git a/sockets_test.go b/sockets_test.go new file mode 100644 index 0000000000000..8795b58cf9994 --- /dev/null +++ b/sockets_test.go @@ -0,0 +1,249 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/igm/sockjs-go/sockjs" +) + +// Mock of sockjs.Session for socketsMultiplexer tests. 
+type MockSession struct {
+	id    string
+	req   *http.Request
+	state sockjs.SessionState
+}
+
+func NewMockSession(id string) *MockSession {
+	return &MockSession{
+		id:    id,
+		req:   httptest.NewRequest("GET", "http://localhost:8000/showdown", nil),
+		state: sockjs.SessionActive}
+}
+
+func (ms *MockSession) ID() string {
+	return ms.id
+}
+
+func (ms *MockSession) Request() *http.Request {
+	return ms.req
+}
+
+func (ms *MockSession) Recv() (string, error) {
+	return "", nil
+}
+
+func (ms *MockSession) Send(msg string) error {
+	return nil
+}
+
+func (ms *MockSession) Close(status uint32, reason string) error {
+	ms.state = sockjs.SessionClosed
+	return nil
+}
+
+func (ms *MockSession) GetSessionState() sockjs.SessionState {
+	return ms.state
+}
+
+func (ms *MockSession) ServeHTTP(rw http.ResponseWriter, req *http.Request) {}
+
+func pipeErr(ech chan error, err error) {
+	go func() {
+		ech <- err
+	}()
+}
+
+func scrubSM() {
+	for sid := range sm.sockets {
+		delete(sm.sockets, sid)
+	}
+	for cid := range sm.channels {
+		delete(sm.channels, cid)
+	}
+}
+
+func Test_newSocketMultiplexer(t *testing.T) {
+	sm = newSocketMultiplexer()
+	if sm.sockets == nil {
+		t.Errorf("SM sockets map does not exist")
+	}
+	if sm.channels == nil {
+		t.Errorf("SM channels map does not exist")
+	}
+}
+
+func Test_socketMultiplexer_SocketAdd(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	if _, ok := sm.sockets[s.ID()]; !ok {
+		t.Errorf("SM SocketAdd failed to add the session to sockets map")
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_SocketSend(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	if err := sm.SocketSend(s.ID(), ""); err != nil {
+		t.Errorf("SM SocketSend failed: %v", err)
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_SocketRemove(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	sm.SocketRemove(s.ID(), true)
+	if _, ok := sm.sockets[s.ID()]; ok {
+		t.Errorf("SM SocketRemove failed to remove the session from the sockets map")
+	}
+
+	// Check for race conditions.
+	ech := make(chan error)
+	for i := 0; i < 100; i += 1 {
+		id := fmt.Sprint(i)
+		digits := 8 - len(id)
+		id = strings.Repeat("a", digits) + id
+		ms := NewMockSession(id)
+		s := sockjs.Session(ms)
+		go func() {
+			pipeErr(ech, sm.SocketAdd(s))
+			pipeErr(ech, sm.SocketSend(s.ID(), ""))
+			pipeErr(ech, sm.SocketRemove(s.ID(), true))
+		}()
+	}
+
+	for i := 0; i < 300; i += 1 {
+		if err := <-ech; err != nil {
+			t.Errorf("SM sockets race condition in add/remove/send: %v", err)
+			break
+		}
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_ChannelAdd(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	sm.ChannelAdd("global", s.ID())
+	if _, ok := sm.channels["global"]; !ok {
+		t.Errorf("SM ChannelAdd failed to add channel to channels map")
+	}
+	if _, ok := sm.channels["global"][s.ID()]; !ok {
+		t.Errorf("SM ChannelAdd failed to add socket to new channel")
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_ChannelSend(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	sm.ChannelAdd("global", s.ID())
+	if err := sm.ChannelSend("global", ""); err != nil {
+		t.Errorf("SM ChannelSend failed to send message: %v", err)
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_ChannelRemove(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	ms2 := NewMockSession("aaaaaaa1")
+	s := sockjs.Session(ms)
+	s2 := sockjs.Session(ms2)
+	sm.SocketAdd(s)
+	sm.SocketAdd(s2)
+	sm.ChannelAdd("global", s.ID())
+	sm.ChannelAdd("global", s2.ID())
+	sm.ChannelRemove("global", s2.ID())
+	if _, ok := sm.channels["global"][s2.ID()]; ok {
+		t.Errorf("SM ChannelRemove failed to remove socket from channel")
+	}
+
+	sm.ChannelRemove("global", s.ID())
+	if _, ok := sm.channels["global"]; ok {
+		t.Errorf("SM ChannelRemove failed to remove channel when removing its last socket")
+	}
+
+	// Check for race conditions.
+	ech := make(chan error)
+	for i := 0; i < 100; i += 1 {
+		id := fmt.Sprintf("battle-ou-%v", i)
+		go func() {
+			sm.ChannelAdd(id, s.ID())
+			pipeErr(ech, sm.ChannelSend(id, ""))
+			pipeErr(ech, sm.ChannelRemove(id, s.ID()))
+		}()
+	}
+
+	for i := 0; i < 200; i += 1 {
+		err := <-ech
+		if err != nil {
+			t.Errorf("SM channel race condition in add/remove/send: %v", err)
+			break
+		}
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_SubchannelMove(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	sm.ChannelAdd("battle-ou-1", s.ID())
+	sm.SubchannelMove("battle-ou-1", "1", s.ID())
+	if scid := sm.channels["battle-ou-1"][s.ID()]; scid != "1" {
+		t.Errorf("SM SubchannelMove failed to move socket to new subchannel")
+	}
+
+	scrubSM()
+}
+
+func Test_socketMultiplexer_SubchannelSend(t *testing.T) {
+	ms := NewMockSession("aaaaaaa0")
+	s := sockjs.Session(ms)
+	sm.SocketAdd(s)
+	sm.ChannelAdd("battle-ou-1", s.ID())
+	if err := sm.SubchannelSend("battle-ou-1", ""); err != nil {
+		t.Errorf("SM SubchannelSend failed to send to subchannel: %v", err)
+	}
+
+	// Check for race conditions.
+ ech := make(chan error) + for i := 0; i < 100; i += 1 { + id := "battle-ou-" + fmt.Sprint(i) + go func() { + sm.ChannelAdd(id, s.ID()) + pipeErr(ech, sm.SubchannelSend(id, "\n|split\n|c|@Morfent|sup0\n|c|@Morfent|sup1\n|c|@Morfent|sup2\n")) + pipeErr(ech, sm.ChannelRemove(id, s.ID())) + }() + } + + for i := 0; i < 200; i += 1 { + err := <-ech + if err != nil { + t.Errorf("SM channel race condition in add/remove/send: %v", err) + break + } + } + + scrubSM() +} + +func init() { + production = false +} diff --git a/users.js b/users.js index c768b3d3cfe29..5ca8961fac00f 100644 --- a/users.js +++ b/users.js @@ -1471,7 +1471,8 @@ Users.socketConnect = function (worker, workerid, socketid, ip, protocol) { let banned = Punishments.checkIpBanned(connection); if (banned) { - return connection.destroy(); + setImmediate(() => connection.destroy()); + return; } // Emergency mode connections logging if (Config.emergency) { @@ -1520,22 +1521,11 @@ Users.socketReceive = function (worker, workerid, socketid, message) { let connection = connections.get(id); if (!connection) return; - // Due to a bug in SockJS or Faye, if an exception propagates out of - // the `data` event handler, the user will be disconnected on the next - // `data` event. To prevent this, we log exceptions and prevent them - // from propagating out of this function. - - // drop legacy JSON messages - if (message.charAt(0) === '{') return; - - // drop invalid messages without a pipe character - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0) return; - const user = connection.user; if (!user) return; // The client obviates the room id when sending messages to Lobby by default + const pipeIndex = message.indexOf('|'); const roomId = message.substr(0, pipeIndex) || (Rooms.lobby || Rooms.global).id; message = message.slice(pipeIndex + 1);