From d0110b335041f812227c19257ea49b3de28cd4db Mon Sep 17 00:00:00 2001 From: Ben Davies Date: Sun, 14 May 2017 02:05:37 -0300 Subject: [PATCH] Sockets: Go rewrite Work in progress. I haven't written documentation for the Go code or confirmed whether this works as intended on Windows yet. This turned into a pretty substantial refactor of the Node version of sockets-related code to not only to make using Go child processes possible, but to optimize it and make unit testing of it and any code dependent on it possible to write entirely synchronously. sockets.js and sockets-workers.js are now written in Typescript, though they won't be able to be transpiled until after Config, Users, Dnsbl, and Monitor work with it as well. Fixes #2943 --- config/config-example.js | 20 +- dev-tools/sockets.js | 35 + package-lock.json | 55 +- package.json | 6 + pokemon-showdown | 128 +++- sockets-workers.js | 558 ++++++++++++++++ sockets.js | 1089 ++++++++++++++++++------------- sockets/lib/commands.go | 87 +++ sockets/lib/commands_test.go | 26 + sockets/lib/config.go | 38 ++ sockets/lib/config_test.go | 36 + sockets/lib/ipc.go | 139 ++++ sockets/lib/ipc_test.go | 38 ++ sockets/lib/master.go | 82 +++ sockets/lib/master_test.go | 85 +++ sockets/lib/multiplexer.go | 371 +++++++++++ sockets/lib/multiplexer_test.go | 131 ++++ sockets/main.go | 125 ++++ test/application/sockets.js | 278 +++----- tsconfig.json | 6 +- users.js | 83 ++- 21 files changed, 2726 insertions(+), 690 deletions(-) create mode 100644 dev-tools/sockets.js create mode 100644 sockets-workers.js create mode 100644 sockets/lib/commands.go create mode 100644 sockets/lib/commands_test.go create mode 100644 sockets/lib/config.go create mode 100644 sockets/lib/config_test.go create mode 100644 sockets/lib/ipc.go create mode 100644 sockets/lib/ipc_test.go create mode 100644 sockets/lib/master.go create mode 100644 sockets/lib/master_test.go create mode 100644 sockets/lib/multiplexer.go create mode 100644 sockets/lib/multiplexer_test.go create mode 100644 sockets/main.go diff --git a/config/config-example.js b/config/config-example.js index 8a68709144c91..0d205bfe68850 100644 --- a/config/config-example.js +++ b/config/config-example.js @@ -3,6 +3,24 @@ // The server port - the port to run Pokemon Showdown under exports.port = 8000; +// The server bind address - the address to run Pokemon Showdown under +// This should be left set to 0.0.0.0 unless you know what you are doing. +exports.bindaddress = '0.0.0.0'; + +// workers - the number of sockets workers to spawn +// This should not be set any higher than the number of cores available on +// the server's CPU(s). This can be checked from a REPL using +// require('os').cpus().length if you are unsure. +exports.workers = 1; + +// golang - toggle using Go instead of Node for sockets workers +// Node workers are more unstable at handling connections because of bugs in +// sockjs-node, but sending/receiving messages over connections on Go workers +// is slightly slower due to the extra work involved in performing IPC with +// them safely. This should be left set to false unless you know what you are +// doing. +exports.golang = false; + // proxyip - proxy IPs with trusted X-Forwarded-For headers // This can be either false (meaning not to trust any proxies) or an array // of strings. Each string should be either an IP address or a subnet given @@ -10,7 +28,7 @@ exports.port = 8000; // know what you are doing. exports.proxyip = false; -// ofe - write heapdumps if sockets.js workers run out of memory. +// ofe - write heapdumps if Node sockets workers run out of memory // If you wish to enable this, you will need to install ofe, as it is not a // installed by default: // $ npm install --no-save ofe diff --git a/dev-tools/sockets.js b/dev-tools/sockets.js new file mode 100644 index 0000000000000..a4e79d77e12c7 --- /dev/null +++ b/dev-tools/sockets.js @@ -0,0 +1,35 @@ +'use strict'; + +const {Session, SockJSConnection} = require('sockjs/lib/transport'); + +const chars = 'abcdefghijklmnopqrstuvwxyz1234567890-'; +let sessionidCount = 0; + +/** + * @return string + */ +function generateSessionid() { + let ret = ''; + let idx = sessionidCount; + for (let i = 0; i < 8; i++) { + ret = chars[idx % chars.length] + ret; + idx = idx / chars.length | 0; + } + sessionidCount++; + return ret; +} + +/** + * @param {string} sessionid + * @param {{options: {{}}} config + * @return SockJSConnection + */ +exports.createSocket = function (sessionid = generateSessionid(), config = {options: {}}) { + let session = new Session(sessionid, config); + let socket = new SockJSConnection(session); + socket.remoteAddress = '127.0.0.1'; + socket.protocol = 'websocket'; + return socket; +}; + +// TODO: move worker mocks here, use require('../sockets-workers').Multiplexer to stub IPC diff --git a/package-lock.json b/package-lock.json index fa4850361ed51..fb3a278ada8af 100644 --- a/package-lock.json +++ b/package-lock.json @@ -4,12 +4,30 @@ "lockfileVersion": 1, "requires": true, "dependencies": { + "@types/cloud-env": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/@types/cloud-env/-/cloud-env-0.2.0.tgz", + "integrity": "sha512-18AOYo8HyJYEmKcTt/wD4e6HGEInmKLEOWrE6qL8ran0DfyCbmxzR3R1sJKv4XjPHJOyyiXV4bH6tEnuwSgSBQ==", + "dev": true + }, + "@types/mime": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/@types/mime/-/mime-1.3.1.tgz", + "integrity": "sha512-rek8twk9C58gHYqIrUlJsx8NQMhlxqHzln9Z9ODqiNgv3/s+ZwIrfr+djqzsnVM12xe9hL98iJ20lj2RvCBv6A==", + "dev": true + }, "@types/node": { "version": "8.0.1", "resolved": "https://registry.npmjs.org/@types/node/-/node-8.0.1.tgz", "integrity": "sha512-bys2VRs6H7HP8S26aHgPWSiSX7q81TToe5HSSvl5bQjoSElQ2SwbGw2p6/DSDb7Vr0oKhewFao9ZuTn8DSag9Q==", "dev": true }, + "@types/node-static": { + "version": "0.7.0", + "resolved": "https://registry.npmjs.org/@types/node-static/-/node-static-0.7.0.tgz", + "integrity": "sha512-4SImtzapcVt+rQEAKVtbT0eh2D895DKnyrRkDgcSpw+LNnol9zlJPcU6yDvjWrEV/6nBSPQqzY0AP69v5v2iEQ==", + "dev": true + }, "@types/nodemailer": { "version": "1.3.33", "resolved": "https://registry.npmjs.org/@types/nodemailer/-/nodemailer-1.3.33.tgz", @@ -40,6 +58,24 @@ "@types/nodemailer": "1.3.33" } }, + "@types/ofe": { + "version": "0.5.0", + "resolved": "https://registry.npmjs.org/@types/ofe/-/ofe-0.5.0.tgz", + "integrity": "sha512-d/yCVOHDKVLQxzXLjU3cXol59ctzrfKkjzcWrGfy3fnacZUUbs8fTBWjqRcY+dr4v/yVNOy+jLL3q2/OM1LOCQ==", + "dev": true + }, + "@types/sockjs": { + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@types/sockjs/-/sockjs-0.3.31.tgz", + "integrity": "sha512-6d+6cH187jHWoUP07fzDQkC1fATu7TXMNJaCKwMwpcjkRvAK4T7bDw8sukL3rE8A/ZEPmo94YghsgCkI2/V8kw==", + "dev": true + }, + "acorn": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-5.0.3.tgz", + "integrity": "sha1-xGDfCEkUY/AozLguqzcwvwEIez0=", + "dev": true + }, "acorn-jsx": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-3.0.1.tgz", @@ -719,7 +755,7 @@ "globals": { "version": "9.18.0", "resolved": "https://registry.npmjs.org/globals/-/globals-9.18.0.tgz", - "integrity": "sha512-S0nG3CLEQiY/ILxqtztTWH/3iRRdyBLw6KMDxnKMchrtbj2OFmehVh0WUCfW3DUrIgx/qFrJPICrq4Z4sTR9UQ==", + "integrity": "sha1-qjiWs+abSH8X4x7SFD1pqOMMLYo=", "dev": true }, "globby": { @@ -1335,6 +1371,23 @@ "safe-buffer": "5.0.1" } }, + "string-width": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-1.0.2.tgz", + "integrity": "sha1-EYvfW4zcUaKn5w0hHgfisLmxB9M=", + "dev": true, + "requires": { + "code-point-at": "1.1.0", + "is-fullwidth-code-point": "1.0.0", + "strip-ansi": "3.0.1" + } + }, + "string_decoder": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.0.2.tgz", + "integrity": "sha1-sp4fThEl+pehA4K4pTNze3SR4Xk=", + "dev": true + }, "strip-ansi": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-3.0.1.tgz", diff --git a/package.json b/package.json index 5b07145685375..0035ab4a9fd63 100644 --- a/package.json +++ b/package.json @@ -47,6 +47,12 @@ "private": true, "license": "MIT", "devDependencies": { + "@types/cloud-env": "^0.2.0", + "@types/node": "^8.0.1", + "@types/node-static": "^0.7.0", + "@types/nodemailer": "^1.3.33", + "@types/ofe": "^0.5.0", + "@types/sockjs": "^0.3.31", "eslint": "^4.0.0", "mocha": "^3.0.0", "@types/node": "^8.0.1", diff --git a/pokemon-showdown b/pokemon-showdown index 5d4fd0e9b4c83..1c98398de1b53 100755 --- a/pokemon-showdown +++ b/pokemon-showdown @@ -39,27 +39,119 @@ try { ); } -if (!process.argv[2] || /^[0-9]+$/.test(process.argv[2])) { - // Start the server. We manually load app.js so it can be configured to run as - // the main module, rather than this file being considered the main module. - // This ensures any dependencies that were just installed can be found when - // running on Windows and avoids any other potential side effects of the main - // module not being app.js like it is assumed to be. - // - // The port the server should host on can be passed using the second argument - // when launching with this file the same way app.js normally allows, e.g. to - // host on port 9000: - // $ ./pokemon-showdown 9000 - - require('module')._load('./app', module, true); -} else switch (process.argv[2]) { +// ALlow arguments passed to the launch script to be evaluated as commands. +let [, , arg2, arg3, arg4] = process.argv; +if (arg2 && /^[0-9]$/.test(arg2)) { + switch (arg2) { case 'generate-team': const Dex = require('./sim/dex'); global.toId = Dex.getId; - const seed = process.argv[4] ? process.argv[4].split(',').map(Number) : undefined; - console.log(Dex.packTeam(Dex.generateTeam(process.argv[3], seed))); - break; + const seed = arg4 ? arg4.split(',').map(Number) : undefined; + console.log(Dex.packTeam(Dex.generateTeam(arg3, seed))); + process.exit(0); default: - console.error('Unrecognized command: ' + process.argv[2]); + console.error(`Unrecognized command: ${arg2}`); process.exit(1); + } } + +// If evaluating commands wasn't the point of running this script, let's launch +// the server. + +// Check if the server is configured to use Go, and ensure the required +// environment variables and dependencies are available if that is the case + +let config; +try { + config = require('./config/config'); +} catch (e) {} + +if (config && config.golang) { + // GOPATH and GOROOT are optional to a degree, but we need them in order + // to be able to handle Go dependencies. Since Go only cares about the + // first path in the list, so will we. + const GOPATH = child_process.execSync('go env GOPATH', {stdio: null, encoding: 'utf8'}) + .trim() + .split(path.delimiter)[0] + .replace(/^"(.*)"$/, '$1'); + if (!GOPATH) { + // Should never happen, but it does on Bash on Ubuntu on Windows. + console.error('There is no $GOPATH environment variable set. Run:'); + console.error('$ go help GOPATH'); + console.error('For more information on how to configure it.'); + process.exit(0); + } + + const dependencies = ['github.com/gorilla/mux', 'github.com/igm/sockjs-go/sockjs']; + let packages = child_process.execSync('go list all', {stdio: null, encoding: 'utf8'}); + for (let dep of dependencies) { + if (!packages.includes(dep)) { + console.log(`Dependency ${dep} is not installed. Fetching...`); + child_process.execSync(`go get ${dep}`, {stdio: 'inherit'}); + } + } + + let stat; + let needsSrcDir = false; + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + } catch (e) { + needsSrcDir = true; + } finally { + if (stat && !stat.isDirectory()) { + needsSrcDir = true; + } + } + + let srcPath = path.resolve(process.cwd(), 'sockets'); + let tarPath = path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown/sockets'); + if (needsSrcDir) { + try { + fs.mkdirSync(path.resolve(GOPATH, 'src/github.com/Zarel')); + fs.mkdirSync(path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown')); + } catch (e) { + console.error(e); + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${srcPath} to ${tarPath}`); + process.exit(0); + } + } + + try { + stat = fs.lstatSync(path.resolve(GOPATH, 'src/github.com/Zarel/Pokemon-Showdown/sockets')); + } catch (e) {} + + if (!stat || !stat.isSymbolicLink()) { + // Windows requires administrator privileges to make symlinks, so we + // make junctions instead. For our purposes they're compatible enough + // with symlinks on UNIX-like OSes. + let symlinkType = (process.platform === 'win32') ? 'junction' : 'dir'; + try { + fs.symlinkSync(srcPath, tarPath, symlinkType); + } catch (e) { + console.error(`Cannot make go source directory for the sockets library files! Symlink them manually from ${srcPath} to ${tarPath}`); + process.exit(0); + } + } + + console.log('Building Go source libs...'); + try { + child_process.execSync('go install github.com/Zarel/Pokemon-Showdown/sockets', {stdio: 'inherit'}); + } catch (e) { + // Go will show the errors that caused compiling Go's files to fail, so + // there's no reason to bother logging anything of our own. + process.exit(0); + } +} + +// Start the server. We manually load app.js so it can be configured to run as +// the main module, rather than this file being considered the main module. +// This ensures any dependencies that were just installed can be found when +// running on Windows and avoids any other potential side effects of the main +// module not being app.js like it is assumed to be. +// +// The port the server should host on can be passed using the second argument +// when launching with this file the same way app.js normally allows, e.g. to +// host on port 9000: +// $ ./pokemon-showdown 9000 + +require('module')._load('./app', module, true); diff --git a/sockets-workers.js b/sockets-workers.js new file mode 100644 index 0000000000000..678d252ea626f --- /dev/null +++ b/sockets-workers.js @@ -0,0 +1,558 @@ +/** + * Connections + * Pokemon Showdown - http://pokemonshowdown.com/ + * + * Abstraction layer for multi-process SockJS connections. + * + * This file handles all the communications between the users' browsers and + * the main process. + * + * @license MIT license + */ + +'use strict'; + +const cluster = require('cluster'); +const fs = require('fs'); + +// IPC command tokens +const EVAL = '$'; +const SOCKET_CONNECT = '*'; +const SOCKET_DISCONNECT = '!'; +const SOCKET_RECEIVE = '<'; +const SOCKET_SEND = '>'; +const CHANNEL_ADD = '+'; +const CHANNEL_REMOVE = '-'; +const CHANNEL_BROADCAST = '#'; +const SUBCHANNEL_MOVE = '.'; +const SUBCHANNEL_BROADCAST = ':'; + +// Subchannel IDs +const DEFAULT_SUBCHANNEL_ID = '0'; +const P1_SUBCHANNEL_ID = '1'; +const P2_SUBCHANNEL_ID = '2'; + +// Regex for splitting subchannel broadcasts between subchannels. +const SUBCHANNEL_MESSAGE_REGEX = /\|split\n([^\n]*)\n([^\n]*)\n([^\n]*)\n[^\n]*/g; + +/** + * Manages the worker's state for sockets, channels, and + * subchannels. This is responsible for parsing all outgoing and incoming + * messages. + */ +class Multiplexer { + constructor() { + /** @type {number} */ + this.socketCounter = 0; + /** @type {Map} */ + this.sockets = new Map(); + /** @type {Map>} */ + this.channels = new Map(); + /** @type {?NodeJS.Timer} */ + this.cleanupInterval = setInterval(() => this.sweepClosedSockets(), 10 * 60 * 1000); + } + + /** + * Mitigates a potential bug in SockJS or Faye-Websocket where + * sockets fail to emit a 'close' event after having disconnected. + */ + sweepClosedSockets() { + this.sockets.forEach(socket => { + if (socket.protocol === 'xhr-streaming' && + socket._session && + socket._session.recv) { + socket._session.recv.didClose(); + } + + // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while + // it is an object for normal users. Under normal circumstances, those properties should only be + // `null` when the timeout has already been called, but somehow it's not happening for some connections. + // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually + // on those connections kills those connections. For a bit of background, this timeout is the timeout + // that sockjs sets to wait for users to reconnect within that time to continue their session. + if (socket._session && + socket._session.to_tref && + !socket._session.to_tref._idlePrev) { + socket._session.timeout_cb(); + } + }); + + // Don't bother deleting the sockets from our map; their close event + // handler will deal with it. + } + + /** + * Sends an IPC message to the parent process. + * + * @param {string} token + * @param {string[]} params + */ + sendUpstream(token, ...params) { + let message = `${token}${params.join('\n')}`; + if (process.send) process.send(message); + } + + /** + * Parses the params in a downstream message sent as a + * command. + * + * @param {string} params + * @param {number} count + * @return {string[]} + */ + parseParams(params, count) { + let i = 0; + let idx = 0; + let ret = []; + while (i++ < count) { + let newIdx = params.indexOf('\n', idx); + if (newIdx < 0) { + // No remaining newlines; just use the rest of the string as + // the last parametre. + ret.push(params.slice(idx)); + break; + } + + let param = params.slice(idx, newIdx); + if (i === count) { + // We reached the number of parametres needed, but there is + // still some remaining string left. Glue it to the last one. + param += `\n${params.slice(newIdx + 1)}`; + } else { + idx = newIdx + 1; + } + + ret.push(param); + } + + return ret; + } + + /** + * Parses downstream messages. + * + * @param {string} data + * @return {boolean} + */ + receiveDownstream(data) { + // console.log(`worker received: ${data}`); + let token = data.charAt(0); + let params = data.substr(1); + switch (token) { + case EVAL: + return this.onEval(params); + case SOCKET_DISCONNECT: + return this.onSocketDisconnect(params); + case SOCKET_SEND: + // @ts-ignore + return this.onSocketSend(...this.parseParams(params, 2)); + case CHANNEL_ADD: + // @ts-ignore + return this.onChannelAdd(...this.parseParams(params, 2)); + case CHANNEL_REMOVE: + // @ts-ignore + return this.onChannelRemove(...this.parseParams(params, 2)); + case CHANNEL_BROADCAST: + // @ts-ignore + return this.onChannelBroadcast(...this.parseParams(params, 2)); + case SUBCHANNEL_MOVE: + // @ts-ignore + return this.onSubchannelMove(...this.parseParams(params, 3)); + case SUBCHANNEL_BROADCAST: + // @ts-ignore + return this.onSubchannelBroadcast(...this.parseParams(params, 2)); + default: + console.error(`Sockets: attempted to send unknown IPC message with token ${token}: ${params}`); + return false; + } + } + + /** + * Safely tries to destroy a socket's connection. + * + * @param {any} socket + */ + tryDestroySocket(socket) { + try { + socket.end(); + socket.destroy(); + } catch (e) {} + } + + /** + * Eval handler for downstream messages. + * + * @param {string} expr + * @return {boolean} + */ + onEval(expr) { + try { + eval(expr); + return true; + } catch (e) {} + return false; + } + + /** + * Sockets.socketConnect message handler. + * + * @param {any} socket + * @return {boolean} + */ + onSocketConnect(socket) { + if (!socket) return false; + if (!socket.remoteAddress) { + this.tryDestroySocket(socket); + return false; + } + + let socketid = '' + this.socketCounter++; + let ip = socket.remoteAddress; + let ips = socket.headers['x-forwarded-for'] || ''; + this.sockets.set(socketid, socket); + this.sendUpstream(SOCKET_CONNECT, socketid, ip, ips, socket.protocol); + + socket.on('data', /** @param {string} message */ message => { + this.onSocketReceive(socketid, message); + }); + + socket.on('close', () => { + this.sendUpstream(SOCKET_DISCONNECT, socketid); + this.sockets.delete(socketid); + this.channels.forEach((channel, channelid) => { + if (!channel || !channel.has(socketid)) return; + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + }); + }); + + return true; + } + + /** + * Sockets.socketDisconnect message handler. + * @param {string} socketid + * @return {boolean} + */ + onSocketDisconnect(socketid) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + this.tryDestroySocket(socket); + return true; + } + + /** + * Sockets.socketSend message handler. + * + * @param {string} socketid + * @param {string} message + * @return {boolean} + */ + onSocketSend(socketid, message) { + let socket = this.sockets.get(socketid); + if (!socket) return false; + + socket.write(message); + return true; + } + + /** + * onmessage event handler for sockets. Passes the message + * upstream. + * + * @param {string} socketid + * @param {string} message + * @return {boolean} + */ + onSocketReceive(socketid, message) { + // Drop empty messages (DDOS?). + if (!message) return false; + + // Drop >100KB messages. + if (message.length > 100 * 1024) { + console.log(`Dropping client message ${message.length / 1024}KB...`); + console.log(message.slice(0, 160)); + return false; + } + + // Drop legacy JSON messages. + if ((typeof message !== 'string') || message.startsWith('{')) return false; + + // Drop invalid messages (again, DDOS?). + if (message.endsWith('|') || !message.includes('|')) return false; + + this.sendUpstream(SOCKET_RECEIVE, socketid, message); + return true; + } + + /** + * Sockets.channelAdd message handler. + * + * @param {string} channelid + * @param {string} socketid + * @return {boolean} + */ + onChannelAdd(channelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = this.channels.get(channelid); + if (!channel || channel.has(socketid)) return false; + channel.set(socketid, DEFAULT_SUBCHANNEL_ID); + } else { + let channel = new Map([[socketid, DEFAULT_SUBCHANNEL_ID]]); + this.channels.set(channelid, channel); + } + + return true; + } + + /** + * Sockets.channelRemove message handler. + * + * @param {string} channelid + * @param {string} socketid + * @return {boolean} + */ + onChannelRemove(channelid, socketid) { + let channel = this.channels.get(channelid); + if (!channel || !channel.has(socketid)) return false; + + channel.delete(socketid); + if (!channel.size) this.channels.delete(channelid); + return true; + } + + /** + * Sockets.channelSend and Sockets.channelBroadcast message + * handler. + * + * @param {string} channelid + * @param {string} message + * @return {boolean} + */ + onChannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + channel.forEach( + /** + * @param {string} subchannelid + * @param {string} socketid + */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + socket.write(message); + } + ); + + return true; + } + + /** + * Sockets.subchannelMove message handler. + * + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + * @return {boolean} + */ + onSubchannelMove(channelid, subchannelid, socketid) { + if (!this.sockets.has(socketid)) return false; + + if (this.channels.has(channelid)) { + let channel = this.channels.get(channelid); + if (channel) channel.set(socketid, subchannelid); + } else { + let channel = new Map([[socketid, subchannelid]]); + this.channels.set(channelid, channel); + } + + return true; + } + + /** + * Sockets.subchannelBroadcast message handler. + * + * @param {string} channelid + * @param {string} message + * @return {boolean} + */ + onSubchannelBroadcast(channelid, message) { + let channel = this.channels.get(channelid); + if (!channel) return false; + + let msgs = {}; + channel.forEach( + /** + * @param {string} subchannelid + * @param {string} socketid + */ + (subchannelid, socketid) => { + let socket = this.sockets.get(socketid); + if (!socket) return; + + if (!(subchannelid in msgs)) { + switch (subchannelid) { + case DEFAULT_SUBCHANNEL_ID: + msgs[subchannelid] = message.replace(SUBCHANNEL_MESSAGE_REGEX, '$1'); + break; + case P1_SUBCHANNEL_ID: + msgs[subchannelid] = message.replace(SUBCHANNEL_MESSAGE_REGEX, '$2'); + break; + case P2_SUBCHANNEL_ID: + msgs[subchannelid] = message.replace(SUBCHANNEL_MESSAGE_REGEX, '$3'); + break; + } + } + + socket.write(msgs[subchannelid]); + } + ); + + return true; + } + + /** + * Cleans up the properties of the multiplexer once an internal message + * from the parent process dictates that the worker disconnect. We can't + * use the 'disconnect' handler for this because at that point the worker + * is already disconnected. + */ + destroy() { + // @ts-ignore + clearInterval(this.cleanupInterval); + this.cleanupInterval = null; + this.sockets.forEach(socket => this.tryDestroySocket(socket)); + this.sockets.clear(); + this.channels.clear(); + } +} + +if (cluster.isWorker) { + // @ts-ignore + global.Config = require('./config/config'); + + // @ts-ignore + if (process.env.PSPORT) Config.port = +process.env.PSPORT; + // @ts-ignore + if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; + // @ts-ignore + if (+process.env.PSNOSSL) Config.ssl = null; + + if (Config.ofe) { + try { + require.resolve('ofe'); + } catch (e) { + if (e.code !== 'MODULE_NOT_FOUND') throw e; // should never happen + throw new Error( + 'ofe is not installed, but it is a required dependency if Config.ofe is set to true! ' + + 'Run npm install ofe and restart the server.' + ); + } + + // Create a heapdump if the process runs out of memory. + require('ofe').call(); + } + + // Graceful crash. + process.on('uncaughtException', err => { + if (Config.crashguard) require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); + }); + + let app = require('http').createServer(); + let appssl = Config.ssl ? require('https').createServer(Config.ssl.options) : null; + + const StaticServer = require('node-static').Server; + const roomidRegex = /^\/[A-Za-z0-9][A-Za-z0-9-]*\/?$/; + const cssServer = new StaticServer('./config'); + const avatarServer = new StaticServer('./config/avatars'); + const staticServer = new StaticServer('./static'); + /** + * @param {any} req + * @param {any} res + */ + const staticRequestHandler = (req, res) => { + // console.log(`static rq: ${req.socket.remoteAddress}:${req.socket.remotePort} -> ${req.socket.localAddress}:${req.socket.localPort} - ${req.method} ${req.url} ${request.httpVersion} - ${req.rawHeaders.join('|')}`); + req.resume(); + req.addListener('end', () => { + if (Config.customhttpresponse && + Config.customhttpresponse(req, res)) { + return; + } + + let server = staticServer; + if (req.url === '/custom.css') { + server = cssServer; + } else if (req.url.startsWith('/avatars/')) { + req.url = req.url.substr(8); + server = avatarServer; + } else if (roomidRegex.test(req.url)) { + req.url = '/'; + } + + server.serve(req, res, e => { + // @ts-ignore + if (e && e.status === 404) { + staticServer.serveFile('404.html', 404, {}, req, res); + } + }); + }); + }; + + app.on('request', staticRequestHandler); + if (appssl) appssl.on('request', staticRequestHandler); + + // Launch the SockJS server. + const sockjs = require('sockjs'); + const server = sockjs.createServer({ + sockjs_url: '//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js', + log(severity, message) { + if (severity === 'error') console.error(`Sockets worker SockJS error: ${message}`); + }, + prefix: '/showdown', + }); + + // Instantiate SockJS' multiplexer. This takes messages received downstream + // from the parent process and distributes them across the sockets they are + // targeting, as well as handling user disconnects and passing user + // messages upstream. + const multiplexer = new Multiplexer(); + + process.on('message', data => { + multiplexer.receiveDownstream(data); + }); + + // Clean up any remaining connections on disconnect. If this isn't done, + // the process will not exit until any remaining connections have been destroyed. + // Afterwards, the worker process will die on its own. + process.once('disconnect', () => { + multiplexer.destroy(); + app.close(); + if (appssl) appssl.close(); + }); + + server.on('connection', /** @param {any} socket */ socket => { + multiplexer.onSocketConnect(socket); + }); + + server.installHandlers(app, {}); + app.listen(Config.port, Config.bindaddress); + if (appssl) { + // @ts-ignore + server.installHandlers(appssl, {}); + appssl.listen(Config.ssl.port, Config.bindaddress); + } + + require('./repl').start( + `sockets-${cluster.worker.id}-${process.pid}`, + /** @param {string} cmd */ + cmd => eval(cmd) + ); +} + +module.exports = { + SUBCHANNEL_MESSAGE_REGEX, + Multiplexer, +}; diff --git a/sockets.js b/sockets.js index 2964f4befd933..0cd8f94257e56 100644 --- a/sockets.js +++ b/sockets.js @@ -4,517 +4,714 @@ * * Abstraction layer for multi-process SockJS connections. * - * This file handles all the communications between the users' - * browsers, the networking processes, and users.js in the - * main process. + * This file handles all the communications between the networking processes + * and users.js. * * @license MIT license */ 'use strict'; +const child_process = require('child_process'); const cluster = require('cluster'); -global.Config = require('./config/config'); +const EventEmitter = require('events'); +const path = require('path'); if (cluster.isMaster) { cluster.setupMaster({ - exec: require('path').resolve(__dirname, 'sockets'), + exec: path.resolve(process.cwd(), 'sockets-workers'), }); +} - const workers = exports.workers = new Map(); - - const spawnWorker = exports.spawnWorker = function () { - let worker = cluster.fork({PSPORT: Config.port, PSBINDADDR: Config.bindaddress || '0.0.0.0', PSNOSSL: Config.ssl ? 0 : 1}); - let id = worker.id; - workers.set(id, worker); - worker.on('message', data => { - // console.log('master received: ' + data); - switch (data.charAt(0)) { - case '*': { - // *socketid, ip, protocol - // connect - let nlPos = data.indexOf('\n'); - let nlPos2 = data.indexOf('\n', nlPos + 1); - Users.socketConnect(worker, id, data.slice(1, nlPos), data.slice(nlPos + 1, nlPos2), data.slice(nlPos2 + 1)); - break; - } +/** + * IPC delimiter byte. This byte must stringify as a hexadeimal + * escape code when stringified as JSON to prevent messages from being able to + * contain the byte itself. + * + * @type {string} + */ +const DELIM = '\x03'; - case '!': { - // !socketid - // disconnect - Users.socketDisconnect(worker, id, data.substr(1)); - break; - } +/** + * Map of worker IDs to worker wrappers. + * + * @type {Map} + */ +const workers = new Map(); - case '<': { - // this.onListen()); + worker.on('message', /** @param {string} data */ data => this.onMessage(data)); + worker.once('error', /** @param {?Error} err */ err => this.onError(err)); + worker.once('exit', + /** + * @param {any} worker + * @param {?number} code + * @param {?string} status + */ + (worker, code, status) => this.onExit(worker, code, status) + ); + } - default: - // unhandled - } - }); + /** + * Worker process getter + * + * @return {any} + */ + get process() { + return this.worker.process; + } - return worker; - }; - - cluster.on('exit', (worker, code, signal) => { - if (code === null && signal === 'SIGTERM') { - // worker was killed by Sockets.killWorker or Sockets.killPid - } else { - // worker crashed, try our best to clean up - require('./crashlogger')(new Error(`Worker ${worker.id} abruptly died`), "The main process"); - - // this could get called during cleanup; prevent it from crashing - // note: overwriting Worker#send is unnecessary in Node.js v7.0.0 and above - // see https://github.com/nodejs/node/commit/8c53d2fe9f102944cc1889c4efcac7a86224cf0a - worker.send = () => {}; - - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; - } - }); - console.error(`${count} connections were lost.`); - } + /** + * Worker exitedAfterDisconnect getter + * + * @return {boolean | void} + */ + get exitedAfterDisconnect() { + return this.worker.exitedAfterDisconnect; + } - // don't delete the worker, so we can investigate it if necessary. + /** + * Worker suicide getter + * + * @return {boolean | void} + */ + get suicide() { + return this.worker.exitedAfterDisconnect; + } - // attempt to recover - spawnWorker(); - }); + /** + * Worker#disconnect wrapper + * + */ + disconnect() { + return this.worker.disconnect(); + } - exports.listen = function (port, bindAddress, workerCount) { - if (port !== undefined && !isNaN(port)) { - Config.port = port; - Config.ssl = null; - } else { - port = Config.port; - - // Autoconfigure when running in cloud environments. - try { - const cloudenv = require('cloud-env'); - bindAddress = cloudenv.get('IP', bindAddress); - port = cloudenv.get('PORT', port); - } catch (e) {} - } - if (bindAddress !== undefined) { - Config.bindaddress = bindAddress; - } - if (workerCount === undefined) { - workerCount = (Config.workers !== undefined ? Config.workers : 1); - } - for (let i = 0; i < workerCount; i++) { - spawnWorker(); + /** + * Worker#kill wrapper + * + * @param {string=} signal + */ + kill(signal) { + return this.worker.kill(signal); + } + + /** + * Worker#destroy wrapper + * + * @param {string=} signal + */ + destroy(signal) { + return this.worker.kill(signal); + } + + /** + * Worker#send wrapper + * + * @param {string} message + * @return {boolean} + */ + send(message) { + return this.worker.send(message); + } + + /** + * Worker#isConnected wrapper + * + * @return {boolean} + */ + isConnected() { + return this.worker.isConnected(); + } + + /** + * Worker#isDead wrapper + * + * @return {boolean} + */ + isDead() { + return this.worker.isDead(); + } + + /** + * 'listening' event handler for the worker. Logs which + * hostname and worker ID is listening to console. + */ + onListen() { + console.log(`Worker ${this.id} now listening on ${Config.bindaddress}:${Config.port}`); + if (Config.ssl) console.log(`Worker ${this.id} now listening for SSL on port ${Config.ssl.port}`); + console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); + } + + /** + * 'message' event handler for the worker. Parses which type + * of command the incoming IPC message is calling, then passes its + * parametres to the appropriate method to handle. + * + * @param {string} data + */ + onMessage(data) { + // console.log(`master received: ${data}`); + let token = data.charAt(0); + let params = data.substr(1); + switch (token) { + case '*': + this.onSocketConnect(params); + break; + case '!': + this.onSocketDisconnect(params); + break; + case '<': + this.onSocketReceive(params); + break; + default: + console.error(`Sockets: received unknown IPC message with token ${token}: ${params}`); + break; } - }; - - exports.killWorker = function (worker) { - let count = 0; - Users.connections.forEach(connection => { - if (connection.worker === worker) { - Users.socketDisconnect(worker, worker.id, connection.socketid); - count++; + } + + /** + * Socket connection message handler. + * + * @param {string} params + */ + onSocketConnect(params) { + let [socketid, ip, header, protocol] = params.split('\n'); + + if (this.isTrustedProxyIp(ip)) { + let ips = header.split(','); + for (let i = ips.length; i--;) { + let proxy = ips[i].trim(); + if (proxy && !this.isTrustedProxyIp(proxy)) { + ip = proxy; + break; + } } - }); - console.log(`${count} connections were lost.`); + } - try { - worker.kill('SIGTERM'); - } catch (e) {} - workers.delete(worker.id); + Users.socketConnect(this, this.id, socketid, ip, protocol); + } - return count; - }; + /** + * Socket disconnect handler. + * + * @param {string} socketid + */ + onSocketDisconnect(socketid) { + Users.socketDisconnect(this, this.id, socketid); + } - exports.killPid = function (pid) { - pid = '' + pid; - for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars - if (pid === '' + worker.process.pid) { - return this.killWorker(worker); + /** + * Socket message receive handler. + * + * @param {string} params + */ + onSocketReceive(params) { + let idx = params.indexOf('\n'); + let socketid = params.substr(0, idx); + let message = params.substr(idx + 1); + Users.socketReceive(this, this.id, socketid, message); + } + + /** + * Worker 'error' event handler. + * + * @param {?Error} err + */ + onError(err) { + this.error = err; + } + + /** + * Worker 'exit' event handler. + * + * @param {any} worker + * @param {?number} code + * @param {?string} signal + */ + onExit(worker, code, signal) { + if (code === null && signal !== null) { + // Worker was killed by Sockets.killWorker or Sockets.killPid. + console.warn(`Worker ${this.id} was forcibly killed with the signal ${signal}`); + workers.delete(worker.id); + } else if (code === 0 && signal === null) { + console.warn(`Worker ${this.id} died, but returned a successful exit code.`); + workers.delete(worker.id); + } else if (code !== null && code > 0) { + // Worker crashed. + if (this.error) { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died with the following stack trace: ${this.error.stack}`), 'The main process'); + } else { + require('./crashlogger')(new Error(`Worker ${this.id} abruptly died`), 'The main process'); } - } - return false; - }; - - exports.socketSend = function (worker, socketid, message) { - worker.send(`>${socketid}\n${message}`); - }; - exports.socketDisconnect = function (worker, socketid) { - worker.send(`!${socketid}`); - }; - - exports.channelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`#${channelid}\n${message}`); - }); - }; - exports.channelSend = function (worker, channelid, message) { - worker.send(`#${channelid}\n${message}`); - }; - exports.channelAdd = function (worker, channelid, socketid) { - worker.send(`+${channelid}\n${socketid}`); - }; - exports.channelRemove = function (worker, channelid, socketid) { - worker.send(`-${channelid}\n${socketid}`); - }; - - exports.subchannelBroadcast = function (channelid, message) { - workers.forEach(worker => { - worker.send(`:${channelid}\n${message}`); - }); - }; - exports.subchannelMove = function (worker, channelid, subchannelid, socketid) { - worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); - }; -} else { - // is worker - - if (process.env.PSPORT) Config.port = +process.env.PSPORT; - if (process.env.PSBINDADDR) Config.bindaddress = process.env.PSBINDADDR; - if (+process.env.PSNOSSL) Config.ssl = null; - - if (Config.ofe) { - try { - require.resolve('ofe'); - } catch (e) { - if (e.code !== 'MODULE_NOT_FOUND') throw e; // should never happen - throw new Error( - 'ofe is not installed, but it is a required dependency if Config.ofe is set to true! ' + - 'Run npm install ofe and restart the server.' - ); + // Don't delete the worker - keep it for inspection. } - // Create a heapdump if the process runs out of memory. - require('ofe').call(); + if (this.isConnected()) this.disconnect(); + // FIXME: this is a bad hack to get around a race condition in + // Connection#onDiscconnect sending room deinit messages after already + // having removed the sockets from their channels. + // @ts-ignore + this.send = () => {}; + + let count = Users.socketDisconnectAll(this); + console.log(`${count} connections were lost.`); + + spawnWorker(); } +} - // Static HTTP server +/** + * A mock Worker class for Go child processes. Similarly to + * Node.js workers, it uses a TCP net server to perform IPC. After launching + * the server, it will spawn the Go child process and wait for it to make a + * connection to the worker's server before performing IPC with it. + */ +class GoWorker extends EventEmitter { + /** + * @param {number} id + */ + constructor(id) { + super(); + + /** @type {number} */ + this.id = id; + /** @type {boolean | void} */ + this.exitedAfterDisconnect = undefined; + + /** @type {string[]} */ + this.obuf = []; + /** @type {string} */ + this.ibuf = ''; + /** @type {?Error} */ + this.error = null; + + /** @type {any} */ + this.process = null; + /** @type {any} */ + this.connection = null; + /** @type {any} */ + this.server = require('net').createServer(); + this.server.once('connection', /** @param {any} connection */ connection => this.onChildConnect(connection)); + this.server.on('error', () => {}); + this.server.listen(() => process.nextTick(() => this.spawnChild())); + } - // This handles the custom CSS and custom avatar features, and also - // redirects yourserver:8001 to yourserver-8001.psim.us + /** + * Worker#disconnect mock + */ + disconnect() { + if (this.isConnected()) this.connection.destroy(); + } - // It's optional if you don't need these features. + /** + * Worker#kill mock + * + * @param {string} [signal = 'SIGTERM'] + */ + kill(signal = 'SIGTERM') { + if (this.process) this.process.kill(signal); + } - global.Dnsbl = require('./dnsbl'); + /** + * Worker#destroy mock + * + * @param {string=} signal + */ + destroy(signal) { + return this.kill(signal); + } - if (Config.crashguard) { - // graceful crash - process.on('uncaughtException', err => { - require('./crashlogger')(err, `Socket process ${cluster.worker.id} (${process.pid})`, true); - }); + /** + * Worker#send mock + * + * @param {string} message + * @return {boolean} + */ + send(message) { + if (!this.isConnected()) { + this.obuf.push(message); + return false; + } + + if (this.obuf.length) { + this.obuf.splice(0).forEach(msg => { + this.connection.write(JSON.stringify(msg) + DELIM); + }); + } + + return this.connection.write(JSON.stringify(message) + DELIM); } - let app = require('http').createServer(); - let appssl = Config.ssl ? require('https').createServer(Config.ssl.options) : null; - - // Static server - const StaticServer = require('node-static').Server; - const roomidRegex = /^\/(?:[A-Za-z0-9][A-Za-z0-9-]*)\/?$/; - const cssServer = new StaticServer('./config'); - const avatarServer = new StaticServer('./config/avatars'); - const staticServer = new StaticServer('./static'); - const staticRequestHandler = (req, res) => { - // console.log(`static rq: ${req.socket.remoteAddress}:${req.socket.remotePort} -> ${req.socket.localAddress}:${req.socket.localPort} - ${req.method} ${req.url} ${req.httpVersion} - ${req.rawHeaders.join('|')}`); - req.resume(); - req.addListener('end', () => { - if (Config.customhttpresponse && - Config.customhttpresponse(req, res)) { - return; - } + /** + * Worker#isConnected mock + * + * @return {boolean} + */ + isConnected() { + return this.connection && !this.connection.destroyed; + } - let server = staticServer; - if (req.url === '/custom.css') { - server = cssServer; - } else if (req.url.startsWith('/avatars/')) { - req.url = req.url.substr(8); - server = avatarServer; - } else if (roomidRegex.test(req.url)) { - req.url = '/'; - } + /** + * Worker#isDead mock + * + * @return {boolean} + */ + isDead() { + return this.connection && !this.connection.destroyed; + } - server.serve(req, res, e => { - if (e && (e.status === 404)) { - staticServer.serveFile('404.html', 404, {}, req, res); + /** + * Spawns the Go child process. Once the process has started, it will make + * a connection to the worker's TCP server. + */ + spawnChild() { + const GOPATH = child_process.execSync('go env GOPATH', {stdio: null, encoding: 'utf8'}) + .trim() + .split(path.delimiter)[0] + .replace(/^"(.*)"$/, '$1'); + + this.process = child_process.spawn( + path.resolve(GOPATH, 'bin/sockets'), [], { + env: { + PS_IPC_PORT: `:${this.server.address().port}`, + PS_CONFIG: JSON.stringify({ + workers: Config.workers || 1, + port: `:${Config.port || 8000}`, + bindAddress: Config.bindaddress || '0.0.0.0', + ssl: Config.ssl ? Object.assign({}, Config.ssl, {port: `:${Config.ssl.port}`}) : null, + }), + }, + stdio: ['inherit', 'inherit', 'pipe'], + } + ); + + this.process.once('exit', /** @param {any[]} args */ (...args) => { + // Clean up the IPC server. + this.server.close(() => { + // @ts-ignore + if (this.server._eventsCount <= 2) { + // The child process died before ever opening the IPC + // connection and sending any messages over it. Let's avoid + // getting trapped in an endless loop of respawns and crashes + // if it crashed. + if (this.error) throw this.error; } + this.emit('exit', this, ...args); }); }); - }; - app.on('request', staticRequestHandler); - if (appssl) appssl.on('request', staticRequestHandler); - - // SockJS server - - // This is the main server that handles users connecting to our server - // and doing things on our server. + this.process.stderr.setEncoding('utf8'); + this.process.stderr.once('data', /** @param {string} data */ data => { + this.error = new Error(data); + this.emit('error', this.error); + }); + } - const sockjs = require('sockjs'); - const server = sockjs.createServer({ - sockjs_url: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", - log: (severity, message) => { - if (severity === 'error') console.log('ERROR: ' + message); - }, - prefix: '/showdown', - }); + /** + * 'connection' event handler for the TCP server. Begins the parsing of + * incoming IPC messages. + * @param {any} connection + */ + onChildConnect(connection) { + this.connection = connection; + this.connection.setEncoding('utf8'); + this.connection.on('data', /** @param {string} data */ data => { + let idx = data.lastIndexOf(DELIM); + if (idx < 0) { + // Very long message... + this.ibuf += data; + return; + } - const sockets = new Map(); - const channels = new Map(); - const subchannels = new Map(); - - // Deal with phantom connections. - const sweepClosedSockets = () => { - sockets.forEach(socket => { - if (socket.protocol === 'xhr-streaming' && - socket._session && - socket._session.recv) { - socket._session.recv.didClose(); + // Because of how Node handles TCP connections, we can + // receive any number of messages, and they may not + // be guaranteed to be complete. + let messages = this.ibuf; + this.ibuf = ''; + if (idx === data.length - 1) { + messages += data.slice(0, -1); + } else { + messages += data.slice(0, idx); + this.ibuf += data.slice(idx + 1); } - // A ghost connection's `_session.to_tref._idlePrev` (and `_idleNext`) property is `null` while - // it is an object for normal users. Under normal circumstances, those properties should only be - // `null` when the timeout has already been called, but somehow it's not happening for some connections. - // Simply calling `_session.timeout_cb` (the function bound to the aformentioned timeout) manually - // on those connections kills those connections. For a bit of background, this timeout is the timeout - // that sockjs sets to wait for users to reconnect within that time to continue their session. - if (socket._session && - socket._session.to_tref && - !socket._session.to_tref._idlePrev) { - socket._session.timeout_cb(); + for (let message of messages.split(DELIM)) { + this.emit('message', JSON.parse(message)); } }); - }; - const interval = setInterval(sweepClosedSockets, 1000 * 60 * 10); // eslint-disable-line no-unused-vars - - process.on('message', data => { - // console.log('worker received: ' + data); - let socket = null; - let socketid = ''; - let channel = null; - let channelid = ''; - let subchannel = null; - let subchannelid = ''; - let nlLoc = -1; - let message = ''; - - switch (data.charAt(0)) { - case '$': // $code - eval(data.substr(1)); - break; + this.connection.on('error', () => {}); - case '!': // !socketid - // destroy - socketid = data.substr(1); - socket = sockets.get(socketid); - if (!socket) return; - socket.destroy(); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - break; - - case '>': - // >socketid, message - // message - nlLoc = data.indexOf('\n'); - socketid = data.substr(1, nlLoc - 1); - socket = sockets.get(socketid); - if (!socket) return; - message = data.substr(nlLoc + 1); - socket.write(message); - break; - - case '#': - // #channelid, message - // message to channel - nlLoc = data.indexOf('\n'); - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) return; - message = data.substr(nlLoc + 1); - channel.forEach(socket => socket.write(message)); - break; + process.nextTick(() => this.emit('listening')); + } +} - case '+': - // +channelid, socketid - // add to channel - nlLoc = data.indexOf('\n'); - socketid = data.substr(nlLoc + 1); - socket = sockets.get(socketid); - if (!socket) return; - channelid = data.substr(1, nlLoc - 1); - channel = channels.get(channelid); - if (!channel) { - channel = new Map(); - channels.set(channelid, channel); - } - channel.set(socketid, socket); - break; +/** + * Worker ID counter. We don't use cluster's internal counter so + * Config.golang can be freely changed while the server is still running. + * + * @type {number} + */ +let nextWorkerid = 1; - case '-': - // -channelid, socketid - // remove from channel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - socketid = data.slice(nlLoc + 1); - channel.delete(socketid); - subchannel = subchannels.get(channelid); - if (subchannel) subchannel.delete(socketid); - if (!channel.size) { - channels.delete(channelid); - if (subchannel) subchannels.delete(channelid); - } - break; +/** + * Config.golang cache. Checked when spawning new workers to + * ensure that Node and Go workers will not try to run at the same time. + * + * @type {boolean} + */ +let golangCache = !!Config.golang; - case '.': - // .channelid, subchannelid, socketid - // move subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - let nlLoc2 = data.indexOf('\n', nlLoc + 1); - subchannelid = data.slice(nlLoc + 1, nlLoc2); - socketid = data.slice(nlLoc2 + 1); - - subchannel = subchannels.get(channelid); - if (!subchannel) { - subchannel = new Map(); - subchannels.set(channelid, subchannel); +/** + * Spawns a new worker. + * + * @return {WorkerWrapper} + */ +function spawnWorker() { + if (golangCache === !Config.golang) { + // Config settings were changed. Make sure none of the wrong kind of + // worker is already listening. + let workerType = Config.golang ? GoWorker : cluster.Worker; + for (let [workerid, worker] of workers) { + if (worker.isConnected() && !(worker.worker instanceof workerType)) { + let oldType = golangCache ? 'Go' : 'Node'; + let newType = Config.golang ? 'Go' : 'Node'; + throw new Error( + `Sockets: worker of ID ${workerid} is a ${oldType} worker, but config was changed to spawn ${newType} ones! + Set Config.golang back to ${golangCache} or kill all active workers before attempting to spawn more.` + ); } - if (subchannelid === '0') { - subchannel.delete(socketid); - } else { - subchannel.set(socketid, subchannelid); + } + golangCache = !!Config.golang; + } else if (golangCache) { + // Prevent spawning multiple Go child processes by accident. + for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars + if (worker.isConnected() && worker.worker instanceof GoWorker) { + throw new Error('Sockets: multiple Go child processes cannot be spawned!'); } - break; - - case ':': - // :channelid, message - // message to subchannel - nlLoc = data.indexOf('\n'); - channelid = data.slice(1, nlLoc); - channel = channels.get(channelid); - if (!channel) return; - - let messages = [null, null, null]; - message = data.substr(nlLoc + 1); - subchannel = subchannels.get(channelid); - channel.forEach((socket, socketid) => { - switch (subchannel ? subchannel.get(socketid) : '0') { - case '1': - if (!messages[1]) { - messages[1] = message.replace(/\n\|split\n[^\n]*\n([^\n]*)\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[1]); - break; - case '2': - if (!messages[2]) { - messages[2] = message.replace(/\n\|split\n[^\n]*\n[^\n]*\n([^\n]*)\n[^\n]*/g, '\n$1'); - } - socket.write(messages[2]); - break; - default: - if (!messages[0]) { - messages[0] = message.replace(/\n\|split\n([^\n]*)\n[^\n]*\n[^\n]*\n[^\n]*/g, '\n$1'); - } - socket.write(messages[0]); - break; - } - }); - break; } - }); + } - // Clean up any remaining connections on disconnect. If this isn't done, - // the process will not exit until any remaining connections have been destroyed. - // Afterwards, the worker process will die on its own. - process.once('disconnect', () => { - sockets.forEach(socket => { - try { - socket.destroy(); - } catch (e) {} + let worker; + if (golangCache) { + worker = new GoWorker(nextWorkerid); + } else { + worker = cluster.fork({ + PSPORT: Config.port, + PSBINDADDR: Config.bindaddress || '0.0.0.0', + PSNOSSL: Config.ssl ? 0 : 1, }); - sockets.clear(); - channels.clear(); - subchannels.clear(); - app.close(); - if (appssl) appssl.close(); - }); + } - // this is global so it can be hotpatched if necessary - let isTrustedProxyIp = Dnsbl.checker(Config.proxyip); - let socketCounter = 0; - server.on('connection', socket => { - if (!socket) { - // For reasons that are not entirely clear, SockJS sometimes triggers - // this event with a null `socket` argument. - return; - } else if (!socket.remoteAddress) { - // This condition occurs several times per day. It may be a SockJS bug. - try { - socket.destroy(); - } catch (e) {} - return; - } + let wrapper = new WorkerWrapper(worker, nextWorkerid++); + workers.set(wrapper.id, wrapper); + return wrapper; +} - let socketid = socket.id = '' + (++socketCounter); - sockets.set(socketid, socket); +/** + * Initializes the configured number of worker processes. + * + * @param {any} port + * @param {any} bindAddress + * @param {any} workerCount + */ +function listen(port, bindAddress, workerCount) { + if (port !== undefined && !isNaN(port)) { + Config.port = port; + Config.ssl = null; + } else { + port = Config.port; + // Autoconfigure the app when running in cloud hosting environments: + try { + const cloudenv = require('cloud-env'); + bindAddress = cloudenv.get('IP', bindAddress); + port = cloudenv.get('PORT', port); + } catch (e) {} + } + if (bindAddress !== undefined) { + Config.bindaddress = bindAddress; + } - if (isTrustedProxyIp(socket.remoteAddress)) { - let ips = (socket.headers['x-forwarded-for'] || '').split(','); - let ip; - while ((ip = ips.pop())) { - ip = ip.trim(); - if (!isTrustedProxyIp(ip)) { - socket.remoteAddress = ip; - break; - } - } - } + // Go only uses one child process since it does not share FD handles for + // serving like Node.js workers do. Workers are instead used to limit the + // number of concurrent requests that can be handled at once in the child + // process. + if (golangCache) { + spawnWorker(); + return; + } - process.send(`*${socketid}\n${socket.remoteAddress}\n${socket.protocol}`); + if (workerCount === undefined) { + workerCount = (Config.workers !== undefined ? Config.workers : 1); + } + for (let i = 0; i < workerCount; i++) { + spawnWorker(); + } +} - socket.on('data', message => { - // drop empty messages (DDoS?) - if (!message) return; - // drop messages over 100KB - if (message.length > 100000) { - console.log(`Dropping client message ${message.length / 1024} KB...`); - console.log(message.slice(0, 160)); - return; - } - // drop legacy JSON messages - if (typeof message !== 'string' || message.startsWith('{')) return; - // drop blank messages (DDoS?) - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0 || pipeIndex === message.length - 1) return; +/** + * Kills a worker process using the given worker object. + * + * @param {WorkerWrapper} worker + * @return {number} + */ +function killWorker(worker) { + let count = Users.socketDisconnectAll(worker); + console.log(`${count} connections were lost.`); + try { + worker.kill('SIGTERM'); + } catch (e) {} + workers.delete(worker.id); + return count; +} - process.send(`<${socketid}\n${message}`); - }); +/** + * Kills a worker process using the given worker PID. + * + * @param {number} pid + * @return {number | false} + */ +function killPid(pid) { + for (let [workerid, worker] of workers) { // eslint-disable-line no-unused-vars + if (pid === worker.process.pid) { + return killWorker(worker); + } + } + return false; +} - socket.on('close', () => { - process.send(`!${socketid}`); - sockets.delete(socketid); - channels.forEach(channel => channel.delete(socketid)); - }); +/** + * Sends a message to a socket in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} socketid + * @param {string} message + */ +function socketSend(worker, socketid, message) { + worker.send(`>${socketid}\n${message}`); +} + +/** + * Forcefully disconnects a socket in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} socketid + */ +function socketDisconnect(worker, socketid) { + worker.send(`!${socketid}`); +} + +/** + * Broadcasts a message to all sockets in a given channel across + * all workers. + * + * @param {string} channelid + * @param {string} message + */ +function channelBroadcast(channelid, message) { + workers.forEach(worker => { + worker.send(`#${channelid}\n${message}`); }); - server.installHandlers(app, {}); - app.listen(Config.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening on ${Config.bindaddress}:${Config.port}`); +} - if (appssl) { - server.installHandlers(appssl, {}); - appssl.listen(Config.ssl.port, Config.bindaddress); - console.log(`Worker ${cluster.worker.id} now listening for SSL on port ${Config.ssl.port}`); - } +/** + * Broadcasts a message to all sockets in a given channel and a + * given worker. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} message + */ +function channelSend(worker, channelid, message) { + worker.send(`#${channelid}\n${message}`); +} - console.log(`Test your server at http://${Config.bindaddress === '0.0.0.0' ? 'localhost' : Config.bindaddress}:${Config.port}`); +/** + * Adds a socket to a given channel in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} socketid + */ +function channelAdd(worker, channelid, socketid) { + worker.send(`+${channelid}\n${socketid}`); +} - require('./repl').start(`sockets-${cluster.worker.id}-${process.pid}`, cmd => eval(cmd)); +/** + * Removes a socket from a given channel in a given worker by ID. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} socketid + */ +function channelRemove(worker, channelid, socketid) { + worker.send(`-${channelid}\n${socketid}`); } + +/** + * Broadcasts a message to be demuxed into three separate messages + * across three subchannels in a given channel across all workers. + * + * @param {string} channelid + * @param {string} message + */ +function subchannelBroadcast(channelid, message) { + workers.forEach(worker => { + worker.send(`:${channelid}\n${message}`); + }); +} + +/** + * Moves a given socket to a different subchannel in a channel by + * ID in the given worker. + * + * @param {WorkerWrapper} worker + * @param {string} channelid + * @param {string} subchannelid + * @param {string} socketid + */ +function subchannelMove(worker, channelid, subchannelid, socketid) { + worker.send(`.${channelid}\n${subchannelid}\n${socketid}`); +} + +module.exports = { + WorkerWrapper, + GoWorker, + + workers, + spawnWorker, + listen, + killWorker, + killPid, + + socketSend, + socketDisconnect, + channelBroadcast, + channelSend, + channelAdd, + channelRemove, + subchannelBroadcast, + subchannelMove, +}; diff --git a/sockets/lib/commands.go b/sockets/lib/commands.go new file mode 100644 index 0000000000000..fc48c24c1dbb7 --- /dev/null +++ b/sockets/lib/commands.go @@ -0,0 +1,87 @@ +/** + * Commands + * https://pokemonshowdown.com/ + * + * Commands are an abstraction over IPC messages sent to and received from the + * parent process. Each message follows a specific syntax: a one character + * token, followed by any number of parametres separated by newlines. Commands + * give the multiplexer and IPC connection a simple way to determine which + * struct it's meant to be handled by, before enqueueing it to be distributed + * to workers to finally process their payload concurrently. + */ + +package sockets + +import "strings" + +// IPC message types +const ( + SOCKET_CONNECT byte = '*' + SOCKET_DISCONNECT byte = '!' + SOCKET_RECEIVE byte = '<' + SOCKET_SEND byte = '>' + CHANNEL_ADD byte = '+' + CHANNEL_REMOVE byte = '-' + CHANNEL_BROADCAST byte = '#' + SUBCHANNEL_MOVE byte = '.' + SUBCHANNEL_BROADCAST byte = ':' +) + +var PARAM_COUNTS = map[byte]int{ + SOCKET_CONNECT: 4, + SOCKET_DISCONNECT: 1, + SOCKET_RECEIVE: 2, + SOCKET_SEND: 2, + CHANNEL_ADD: 2, + CHANNEL_REMOVE: 2, + CHANNEL_BROADCAST: 2, + SUBCHANNEL_MOVE: 3, + SUBCHANNEL_BROADCAST: 2, +} + +type Command struct { + token byte // Token designating the type of command. + params []string // The command parametre list, parsed. + target CommandIO // The target to process this command. +} + +// The multiplexer and the IPC connection both implement this interface. Its +// purpose is solely to allow the two structs to be used in Command. +type CommandIO interface { + Process(*Command) error // Invokes one of its methods using the command's token and parametres. +} + +func NewCommand(msg string, target CommandIO) *Command { + token := msg[0] + count := PARAM_COUNTS[token] + params := strings.SplitN(msg[1:], "\n", count) + return &Command{ + token: token, + params: params, + target: target, + } +} + +func BuildCommand(target CommandIO, token byte, params ...string) *Command { + return &Command{ + token: token, + params: params, + target: target, + } +} + +func (c *Command) Token() byte { + return c.token +} + +func (c *Command) Params() []string { + return c.params +} + +func (c *Command) Message() string { + return string(c.token) + strings.Join(c.params, "\n") +} + +func (c *Command) Process() error { + return c.target.Process(c) +} diff --git a/sockets/lib/commands_test.go b/sockets/lib/commands_test.go new file mode 100644 index 0000000000000..cb85b37fb9e33 --- /dev/null +++ b/sockets/lib/commands_test.go @@ -0,0 +1,26 @@ +package sockets + +import "testing" + +type testTarget struct { + CommandIO +} + +func TestCommands(t *testing.T) { + tokens := []byte{ + SOCKET_CONNECT, + SOCKET_DISCONNECT, + SOCKET_RECEIVE, + SOCKET_SEND, + CHANNEL_ADD, + CHANNEL_REMOVE, + CHANNEL_BROADCAST, + SUBCHANNEL_MOVE, + SUBCHANNEL_BROADCAST, + } + + cmds := make([]*Command, len(tokens)) + for i, token := range tokens { + cmds[i] = NewCommand(string(token)+"1\n2\n3\n4", testTarget{}) + } +} diff --git a/sockets/lib/config.go b/sockets/lib/config.go new file mode 100644 index 0000000000000..98bb907507abb --- /dev/null +++ b/sockets/lib/config.go @@ -0,0 +1,38 @@ +/** + * Config + * https://pokemonshowdown.com/ + * + * Config is a struct representing the config settings the parent process + * passed to us by stringifying pertinent settings as JSON and assigning it to + * the $PS_CONFIG environment variable. + */ + +package sockets + +import ( + "encoding/json" + "os" +) + +type sslcert struct { + Cert string `json:"cert"` // Path to the SSL certificate. + Key string `json:"key"` // Path to the SSL key. +} + +type sslconf struct { + Port string `json:"port"` // HTTPS server port. + Options sslcert `json:"options,omitempty"` // SSL config settings. +} + +type config struct { + Workers int `json:"workers"` // Number of workers for the master to spawn. + Port string `json:"port"` // HTTP server port. + BindAddress string `json:"bindAddress"` // HTTP/HTTPS server(s) hostname. + SSL sslconf `json:"ssl,omitempty"` // HTTPS config settings. +} + +func NewConfig(envVar string) (c config, err error) { + configEnv := os.Getenv(envVar) + err = json.Unmarshal([]byte(configEnv), &c) + return +} diff --git a/sockets/lib/config_test.go b/sockets/lib/config_test.go new file mode 100644 index 0000000000000..34495d055f39b --- /dev/null +++ b/sockets/lib/config_test.go @@ -0,0 +1,36 @@ +package sockets + +import ( + "encoding/json" + "testing" +) + +func TestConfig(t *testing.T) { + var c config + cj := []byte(`{"workers": 1, "port": ":8000", "bindAddress": "0.0.0.0", "ssl": null}`) + err := json.Unmarshal(cj, &c) + if err != nil { + t.Errorf("Sockets: failed to parse config JSON with SSL being null: %v", err) + } + if c.SSL.Port != "" || c.SSL.Options.Cert != "" || c.SSL.Options.Key != "" { + t.Errorf("Sockets: config failed to omit null SSL config") + } + + c.SSL = sslconf{ + Port: ":443", + Options: sslcert{ + Cert: "", + Key: "", + }, + } + + cj, _ = json.Marshal(c) + if err != nil { + t.Errorf("Sockets: failed to stringify config JSON: %v", err) + } + + err = json.Unmarshal(cj, &c) + if err != nil { + t.Errorf("Sockets: failed to parse config JSON containing SSL config") + } +} diff --git a/sockets/lib/ipc.go b/sockets/lib/ipc.go new file mode 100644 index 0000000000000..4347245b07b81 --- /dev/null +++ b/sockets/lib/ipc.go @@ -0,0 +1,139 @@ +/** + * IPC - Inter-Process Communication + * https://pokemonshowdown.com/ + * + * This handles all communication between us and the parent process. The parent + * process creates a local TCP server using a random port. The port is passed + * down to us through the $PS_IPC_PORT environment variable. A TCP connection + * to the parent's server is created, allowing us to send messages back and + * forth. + */ + +package sockets + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" + "time" +) + +// This must be a byte that stringifies to either a hexadecimal escape code. +// Otherwise, it would be possible for someone to send a message with the +// delimiter and break up messages. +const DELIM byte = '\x03' + +type Connection struct { + addr *net.TCPAddr // Parent process' TCP server address. + conn *net.TCPConn // Connection to the parent process' TCP server. + mux *Multiplexer // Target for commands originating from here. + listening bool // Whether or not this is connected and listening for IPC messages. +} + +func NewConnection(envVar string) (*Connection, error) { + port := os.Getenv(envVar) + addr, err := net.ResolveTCPAddr("tcp", "localhost"+port) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to parse TCP address to connect to the parent process with: %v", err) + } + + conn, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, fmt.Errorf("Sockets: failed to connect to TCP server: %v", err) + } + + c := &Connection{ + addr: addr, + conn: conn, + listening: false, + } + + return c, nil +} + +func (c *Connection) Listening() bool { + return c.listening +} + +func (c *Connection) Listen(mux *Multiplexer) { + if c.listening { + return + } + + c.mux = mux + c.listening = true + + go func() { + reader := bufio.NewReader(c.conn) + for { + token, err := reader.ReadBytes(DELIM) + if len(token) == 0 || err != nil { + continue + } + + var msg string + err = json.Unmarshal(token[:len(token)-1], &msg) + if err != nil { + continue + } + + go func() { + cmd := NewCommand(msg, c.mux) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + }() +} + +// Final step in evaluating commands targeted at the IPC connection. +func (c *Connection) Process(cmd *Command) error { + // fmt.Printf("Sockets => IPC: %v\n", cmd.Message()) + if !c.listening { + return fmt.Errorf("Sockets: can't process connection commands when the connection isn't listening yet") + } + + msg := cmd.Message() + _, err := c.write(msg) + return err +} + +func (c *Connection) Close() error { + if !c.listening { + return nil + } + + return c.conn.Close() +} + +func (c *Connection) write(msg string) (int, error) { + if !c.listening { + return 0, fmt.Errorf("Sockets: can't write messages over a connection that isn't listening yet...") + } + + data, err := json.Marshal(msg) + if err != nil { + return 0, fmt.Errorf("Sockets: failed to parse upstream IPC message: %v", err) + } + + // The max allowed length for a message that Multiplexer.socketReceive will + // not drop is short enough for us not to need to buffer here. + return c.conn.Write(append(data, DELIM)) +} + +func (c *Connection) SendCommand(token byte, args ...string) error { + if !c.listening { + return fmt.Errorf("Sockets: cannot send commands to the parent process before connecting") + } + + go func() { + cmd := BuildCommand(c.mux, token, args...) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + return nil +} diff --git a/sockets/lib/ipc_test.go b/sockets/lib/ipc_test.go new file mode 100644 index 0000000000000..4981aaddbe9a4 --- /dev/null +++ b/sockets/lib/ipc_test.go @@ -0,0 +1,38 @@ +package sockets + +import ( + "net" + "os" + "testing" +) + +func TestConnection(t *testing.T) { + port := ":3000" + ln, err := net.Listen("tcp", "localhost"+port) + defer ln.Close() + if err != nil { + t.Errorf("Sockets: failed to launch TCP server on port %v: %v", port, err) + } + + envVar := "PS_IPC_PORT" + err = os.Setenv(envVar, port) + if err != nil { + t.Errorf("Sockets: failed to set %v environment variable: %v", envVar, port) + } + + conn, err := NewConnection(envVar) + defer conn.Close() + if err != nil { + t.Errorf("%v", err) + } + + mux := NewMultiplexer() + mux.Listen(conn) + conn.Listen(mux) + + cmd := BuildCommand(mux, SOCKET_SEND, "0", "|ayy lmao") + err = conn.Process(cmd) + if err != nil { + t.Errorf("%v", err) + } +} diff --git a/sockets/lib/master.go b/sockets/lib/master.go new file mode 100644 index 0000000000000..4b5217067f5cc --- /dev/null +++ b/sockets/lib/master.go @@ -0,0 +1,82 @@ +/** + * Master - Master/Worker pattern implementation + * https://pokemonshowdown.com/ + * + * This makes it possible to parse messages from sockets and the parent process + * concurrently. A command queue stores commands created by the multiplexer and + * IPC connection. The master contains a pool of command channels belonging to + * workers. Once a command is available in the queue, the master takes a + * worker's command channel from its pool and enqueues it. The worker takes the + * command and processes it before enqueueing its command channel back into the + * master's pool. The workers are distributed round-robin, much like Node's + * cluster module (when not using Windows). + */ + +package sockets + +// A global command channel for the multiplexer and IPC connection to enqueue +// their new commands to be processed by the workers. +var CmdQueue = make(chan *Command) + +type master struct { + wpool chan chan *Command // Pool of worker command queues. + count int // Number of workers. +} + +func NewMaster(count int) *master { + wpool := make(chan chan *Command, count) + return &master{ + wpool: wpool, + count: count, + } +} + +// Create the initial set of workers and make them listen before the master. +func (m *master) Spawn() { + for i := 0; i < m.count; i++ { + w := newWorker(m.wpool) + w.listen() + } +} + +// Listen for new commands to remove from the command queue and pass to the +// first available worker. +func (m *master) Listen() { + for { + cmd := <-CmdQueue + cmdch := <-m.wpool + cmdch <- cmd + } +} + +type worker struct { + wpool chan chan *Command // The master's pool of worker command queues. + cmdch chan *Command // Queue for incoming commands from CmdQueue. + quit chan bool // Channel used to kill the worker when needed. +} + +func newWorker(wpool chan chan *Command) *worker { + cmdch := make(chan *Command) + quit := make(chan bool) + return &worker{ + wpool: wpool, + cmdch: cmdch, + quit: quit, + } +} + +func (w *worker) listen() { + go func() { + for { + w.wpool <- w.cmdch + select { + case cmd := <-w.cmdch: + // Invokes *Multiplexer.Process or *Connection.Process, where + // the command is finally handled and used to update state. + cmd.target.Process(cmd) + case <-w.quit: + return + } + } + }() +} diff --git a/sockets/lib/master_test.go b/sockets/lib/master_test.go new file mode 100644 index 0000000000000..7fdcb163ef5b5 --- /dev/null +++ b/sockets/lib/master_test.go @@ -0,0 +1,85 @@ +package sockets + +import ( + "net" + "net/http" + "strconv" + "testing" + + "github.com/igm/sockjs-go/sockjs" +) + +type testSocket struct { + sockjs.Session +} + +func (ts testSocket) Send(msg string) error { + return nil +} + +func (ts testSocket) Close(code uint32, signal string) error { + return nil +} + +func (ts testSocket) Request() *http.Request { + return &http.Request{} +} + +func TestMasterListen(t *testing.T) { + t.Parallel() + ln, _ := net.Listen("tcp", "localhost:3000") + defer ln.Close() + + conn, _ := NewConnection("PS_IPC_PORT") + defer conn.Close() + + mux := NewMultiplexer() + mux.Listen(conn) + conn.Listen(mux) + + m := NewMaster(4) + m.Spawn() + go m.Listen() + + for i := 0; i < m.count*250; i++ { + id := strconv.Itoa(i) + t.Run("Worker/Multiplexer command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := strconv.FormatUint(mux.nsid, 10) + mux.sockets[sid] = testSocket{} + mux.nsid++ + mux.smux.Unlock() + + cmd := BuildCommand(mux, SOCKET_DISCONNECT, sid) + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to multiplexer") + } + + mux.socketRemove(sid, true) + }(id, mux, conn) + }) + t.Run("Worker/Connection command #"+id, func(t *testing.T) { + go func(id string, mux *Multiplexer, conn *Connection) { + mux.smux.Lock() + sid := strconv.FormatUint(mux.nsid, 10) + mux.sockets[sid] = testSocket{} + mux.nsid++ + mux.smux.Unlock() + + cmd := BuildCommand(conn, SOCKET_CONNECT, sid, "0.0.0.0", "", "websocket") + cmd.Process() + if len(CmdQueue) != 0 { + t.Error("Sockets: master failed to pass command struct from worker to connection") + } + + mux.socketRemove(sid, true) + }(id, mux, conn) + }) + } + + for len(m.wpool) > 0 { + <-m.wpool + } +} diff --git a/sockets/lib/multiplexer.go b/sockets/lib/multiplexer.go new file mode 100644 index 0000000000000..265ca6b5821fe --- /dev/null +++ b/sockets/lib/multiplexer.go @@ -0,0 +1,371 @@ +/** + * Multiplexer - Socket/Channel/Subchannel state machine + * https://pokemonshowdown.com/ + * + * This keeps track of the sockets that connect to the SockJS server. Sockets + * are stored in the multiplexer to allow the parent process to manipulate them + * as it pleases. Channels represent rooms in the parent process; subchannels + * split battle rooms into three groups: side 1, side 2, and spectators. + * Certain messages will display differently depending on which subchannel the + * user's socket is in. + */ + +package sockets + +import ( + "fmt" + "net" + "path" + "regexp" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/igm/sockjs-go/sockjs" +) + +// Subchannel IDs +const ( + DEFAULT_SUBCHANNEL_ID byte = '0' + P1_SUBCHANNEL_ID byte = '1' + P2_SUBCHANNEL_ID byte = '2' +) + +// Map of socket IDs to subchannel IDs. +type Channel map[string]byte + +type Multiplexer struct { + nsid uint64 // Socket ID counter. + sockets map[string]sockjs.Session // Map of socket IDs to sockets. + smux sync.RWMutex // nsid and sockets mutex. + channels map[string]Channel // Map of channel (i.e. room) IDs to channels. + cmux sync.RWMutex // channels mutex. + scre *regexp.Regexp // Regex for splitting subchannel broadcasts into their three messages. + conn *Connection // Target for commands originating from here. +} + +func NewMultiplexer() *Multiplexer { + sockets := make(map[string]sockjs.Session) + channels := make(map[string]Channel) + scre := regexp.MustCompile(`\|split\n([^\n]*)\n([^\n]*)\n([^\n]*)\n[^\n]*`) + return &Multiplexer{ + sockets: sockets, + channels: channels, + scre: scre, + } +} + +func (m *Multiplexer) Listen(conn *Connection) { + m.conn = conn +} + +func (m *Multiplexer) Process(cmd *Command) (err error) { + // fmt.Printf("IPC => Sockets: %v\n", cmd.Message()) + params := cmd.Params() + + // Parse the command's params and call the appropriate method. + switch token := cmd.Token(); token { + case SOCKET_DISCONNECT: + sid := params[0] + err = m.socketRemove(sid, true) + case SOCKET_SEND: + sid := params[0] + msg := params[1] + err = m.socketSend(sid, msg) + case SOCKET_RECEIVE: + sid := params[0] + msg := params[1] + err = m.socketReceive(sid, msg) + case CHANNEL_ADD: + cid := params[0] + sid := params[1] + err = m.channelAdd(cid, sid) + case CHANNEL_REMOVE: + cid := params[0] + sid := params[1] + err = m.channelRemove(cid, sid) + case CHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.channelBroadcast(cid, msg) + case SUBCHANNEL_MOVE: + cid := params[0] + scid := params[1][0] + sid := params[2] + err = m.subchannelMove(cid, scid, sid) + case SUBCHANNEL_BROADCAST: + cid := params[0] + msg := params[1] + err = m.subchannelBroadcast(cid, msg) + default: + err = fmt.Errorf("Sockets: received unknown message of type %v: %v", cmd.Token(), cmd.Message()) + } + + if err != nil { + // Something went wrong somewhere, but it's likely a timing issue from + // the parent process. Let's just log the error instead of crashing. + fmt.Printf("%v\n", err) + } + + return +} + +func (m *Multiplexer) socketAdd(s sockjs.Session) (sid string) { + nsid := atomic.LoadUint64(&m.nsid) + sid = strconv.FormatUint(nsid, 10) + atomic.AddUint64(&m.nsid, 1) + + m.smux.Lock() + m.sockets[sid] = s + m.smux.Unlock() + + if m.conn.Listening() { + req := s.Request() + ip, _, _ := net.SplitHostPort(req.RemoteAddr) + ips := req.Header.Get("X-Forwarded-For") + protocol := path.Base(req.URL.Path) + + go func() { + cmd := BuildCommand(m.conn, SOCKET_CONNECT, sid, ip, ips, protocol) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + + return +} + +func (m *Multiplexer) socketRemove(sid string, forced bool) error { + m.cmux.Lock() + for cid, c := range m.channels { + if _, ok := c[sid]; ok { + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + } + } + m.cmux.Unlock() + + m.smux.Lock() + defer m.smux.Unlock() + + s, ok := m.sockets[sid] + if ok { + delete((*m).sockets, sid) + } else { + return fmt.Errorf("Sockets: attempted to remove non-existent socket of ID %v", sid) + } + + if forced { + s.Close(1000, "Normal closure") + } else { + // User-initiated disconnect. Poke the parent process to clean up. + if m.conn.Listening() { + go func() { + cmd := BuildCommand(m.conn, SOCKET_DISCONNECT, sid) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + } + + return nil +} + +func (m *Multiplexer) socketReceive(sid string, msg string) error { + m.smux.RLock() + defer m.smux.RUnlock() + + if _, ok := m.sockets[sid]; ok { + // Drop empty messages (DDOS?). + if len(msg) == 0 { + return nil + } + + // Drop >100KB messages. + if len(msg) > 100*1024 { + fmt.Printf("Dropping client message %vKB...\n%v\n", len(msg)/1024, msg[:160]) + return nil + } + + // Drop legacy JSON messages. + if strings.HasPrefix(msg, "{") { + return nil + } + + // Drop invalid messages (again, DDOS?). + if strings.HasSuffix(msg, "|") || !strings.Contains(msg, "|") { + return nil + } + + if m.conn.Listening() { + go func() { + cmd := BuildCommand(m.conn, SOCKET_RECEIVE, sid, msg) + CmdQueue <- cmd + }() + + time.Sleep(1 * time.Nanosecond) + } + + return nil + } + + // This should never happen. If it does, it's likely a SockJS bug. + return fmt.Errorf("Sockets: received message for a non-existent socket of ID %v: %v", sid, msg) +} + +func (m *Multiplexer) socketSend(sid string, msg string) error { + m.smux.RLock() + defer m.smux.RUnlock() + + if s, ok := m.sockets[sid]; ok { + s.Send(msg) + return nil + } + + return fmt.Errorf("Sockets: attempted to send to non-existent socket of ID %v: %v", sid, msg) +} + +func (m *Multiplexer) channelAdd(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + c = make(Channel) + m.channels[cid] = c + } + + c[sid] = DEFAULT_SUBCHANNEL_ID + + return nil +} + +func (m *Multiplexer) channelRemove(cid string, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if ok { + if _, ok = c[sid]; !ok { + return fmt.Errorf("Sockets: failed to remove non-existent socket of ID %v from channel %v", sid, cid) + } + } else { + // This happens on user-initiated disconnect. Mitigate until this race + // condition is fixed. + return nil + } + + delete(c, sid) + if len(c) == 0 { + delete((*m).channels, cid) + } + + return nil +} + +func (m *Multiplexer) channelBroadcast(cid string, msg string) error { + m.cmux.RLock() + defer m.cmux.RUnlock() + + c, ok := m.channels[cid] + if !ok { + // This happens occasionally when the sole user in a room leaves. + // Mitigate until this race condition is fixed. + return nil + } + + m.smux.RLock() + defer m.smux.RUnlock() + + for sid := range c { + var s sockjs.Session + if s, ok = m.sockets[sid]; ok { + s.Send(msg) + } else { + return fmt.Errorf("Sockets: attempted to broadcast to non-existent socket of ID %v in channel %v: %v", sid, cid, msg) + } + } + + return nil +} + +func (m *Multiplexer) subchannelMove(cid string, scid byte, sid string) error { + m.cmux.Lock() + defer m.cmux.Unlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to move socket of ID %v in non-existent channel %v to subchannel %v", sid, cid, scid) + } + + c[sid] = scid + return nil +} + +func (m *Multiplexer) subchannelBroadcast(cid string, msg string) error { + m.cmux.RLock() + defer m.cmux.RUnlock() + + c, ok := m.channels[cid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, which doesn't exist: %v", cid, msg) + } + + m.smux.RLock() + defer m.smux.RUnlock() + + msgs := make(map[byte]string) + for sid, scid := range c { + s, ok := m.sockets[sid] + if !ok { + return fmt.Errorf("Sockets: attempted to broadcast to subchannels in channel %v, but socket of ID %v doesn't exist: %v", cid, sid, msg) + } + + if _, ok := msgs[scid]; !ok { + switch scid { + case DEFAULT_SUBCHANNEL_ID: + msgs[scid] = m.scre.ReplaceAllString(msg, "$1") + case P1_SUBCHANNEL_ID: + msgs[scid] = m.scre.ReplaceAllString(msg, "$2") + case P2_SUBCHANNEL_ID: + msgs[scid] = m.scre.ReplaceAllString(msg, "$3") + } + } + + s.Send(msgs[scid]) + } + + return nil +} + +// This is the HTTP handler for the SockJS server. This is where new sockets +// arrive for us to use. +func (m *Multiplexer) Handler(s sockjs.Session) { + sid := m.socketAdd(s) + for { + msg, err := s.Recv() + if err != nil { + if err == sockjs.ErrSessionNotOpen { + // User disconnected. + } else { + fmt.Printf("Sockets: SockJS error on message receive for socket of ID %v: %v\n", sid, err) + } + break + } + + if err = m.socketReceive(sid, msg); err != nil { + fmt.Printf("%v\n", err) + break + } + } + + if err := m.socketRemove(sid, false); err != nil { + fmt.Printf("%v\n", err) + } +} diff --git a/sockets/lib/multiplexer_test.go b/sockets/lib/multiplexer_test.go new file mode 100644 index 0000000000000..66ba1558cdb68 --- /dev/null +++ b/sockets/lib/multiplexer_test.go @@ -0,0 +1,131 @@ +package sockets + +import ( + "fmt" + "net" + "testing" +) + +func TestMultiplexer(t *testing.T) { + port := ":3000" + ln, _ := net.Listen("tcp", "localhost"+port) + defer ln.Close() + + conn, _ := NewConnection("PS_IPC_PORT") + defer conn.Close() + mux := NewMultiplexer() + mux.Listen(conn) + // Do not make the connection listen. + + ts := testSocket{} + t.Run("socketAdd", func(t *testing.T) { + sid := mux.socketAdd(ts) + if len(mux.sockets) != 1 { + t.Errorf("Sockets: expected sockets length to be %v, but is actually %v", 1, len(mux.sockets)) + } + delete((*mux).sockets, sid) + }) + t.Run("socketRemove", func(t *testing.T) { + sid := mux.socketAdd(ts) + if err := mux.socketRemove(sid, true); err != nil { + t.Errorf("%v", err) + } + if len(mux.sockets) != 0 { + t.Errorf("Sockets: expected sockets length to be %v, but is actually %v", 0, len(mux.sockets)) + } + if err := mux.socketRemove(sid, true); err == nil { + t.Errorf("Sockets: did not remove socket of ID %v on socket remove", sid) + } + + sid = mux.socketAdd(ts) + if err := mux.socketRemove(sid, false); err != nil { + t.Errorf("%v", err) + } + if err := mux.socketRemove(sid, false); err == nil { + t.Errorf("Sockets: did not remove socket of ID %v on socket remove", sid) + } + }) + t.Run("socketSend", func(t *testing.T) { + sid := mux.socketAdd(ts) + if err := mux.socketSend(sid, ">global\n|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("channelAdd", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + if err := mux.channelAdd(cid, sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 1 { + t.Errorf("Sockets: expected channels length to be %v, but is actually %v", 1, len(mux.channels)) + } + if err := mux.channelAdd(cid, sid); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove(cid, sid) + mux.socketRemove(sid, true) + }) + t.Run("channelRemove", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + if err := mux.channelRemove(cid, sid); err != nil { + t.Errorf("%v", err) + } + if len(mux.channels) != 0 { + t.Errorf("Sockets: expected channels length to be %v, but is actually %v", 0, len(mux.channels)) + } + if err := mux.channelRemove(cid, sid); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("channelBroadcast", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + if err := mux.channelBroadcast(cid, "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove(cid, sid) + if err := mux.channelBroadcast(cid, "|raw|ayy lmao"); err != nil { + t.Errorf("%v", err) + } + mux.socketRemove(sid, true) + }) + t.Run("subchannelMove", func(t *testing.T) { + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + if err := mux.subchannelMove(cid, P1_SUBCHANNEL_ID, sid); err != nil { + t.Errorf("%v", err) + } + if scid := mux.channels[cid][sid]; scid != P1_SUBCHANNEL_ID { + t.Errorf("Sockets: expected subchannel for socket of ID %v in channel %v to be %v, but is actually %v", sid, cid, P1_SUBCHANNEL_ID, scid) + } + mux.channelRemove(cid, sid) + mux.socketRemove(sid, true) + }) + t.Run("subchannelBroadcast", func(t *testing.T) { + msg := "|split\n0\n1\n2\n|\n|split\n3\n4\n5\n|" + scids := []byte{DEFAULT_SUBCHANNEL_ID, P1_SUBCHANNEL_ID, P2_SUBCHANNEL_ID} + for idx, scid := range scids { + amsg := mux.scre.ReplaceAllString(msg, fmt.Sprintf("$%v", idx+1)) + if emsg := fmt.Sprintf("%v\n%v", idx, idx+3); emsg != amsg { + t.Errorf("Sockets: expected broadcast to subchannel of ID %v to be %v, but is actually %v", string(scid), emsg, amsg) + } + } + + sid := mux.socketAdd(ts) + cid := "global" + mux.channelAdd(cid, sid) + mux.subchannelMove(cid, P1_SUBCHANNEL_ID, sid) + if err := mux.subchannelBroadcast(cid, msg); err != nil { + t.Errorf("%v", err) + } + mux.channelRemove(cid, sid) + mux.socketRemove(sid, true) + }) +} diff --git a/sockets/main.go b/sockets/main.go new file mode 100644 index 0000000000000..63a64de1f8f98 --- /dev/null +++ b/sockets/main.go @@ -0,0 +1,125 @@ +package main + +import ( + "crypto/tls" + "io/ioutil" + "log" + "net" + "net/http" + "path/filepath" + + "github.com/Zarel/Pokemon-Showdown/sockets/lib" + + "github.com/gorilla/mux" + "github.com/igm/sockjs-go/sockjs" +) + +func main() { + // Parse our config settings passed through the $PS_CONFIG environment + // variable by the parent process. + config, err := sockets.NewConfig("PS_CONFIG") + if err != nil { + log.Fatalf("Sockets: failed to read parent's config settings from environment: %v") + } + + // Instantiate the socket multiplexer and IPC struct. + smux := sockets.NewMultiplexer() + conn, err := sockets.NewConnection("PS_IPC_PORT") + if err != nil { + log.Fatalf("%v", err) + } + defer conn.Close() + + // Begin listening for incoming messages from sockets and the TCP + // connection to the parent process. For now, they'll just get enqueued + // for workers to manage later. + smux.Listen(conn) + conn.Listen(smux) + + // Set up routing. + r := mux.NewRouter() + + opts := sockjs.Options{ + SockJSURL: "//play.pokemonshowdown.com/js/lib/sockjs-1.1.1-nwjsfix.min.js", + Websocket: true, + HeartbeatDelay: sockjs.DefaultOptions.HeartbeatDelay, + DisconnectDelay: sockjs.DefaultOptions.DisconnectDelay, + JSessionID: sockjs.DefaultOptions.JSessionID, + } + + r.PathPrefix("/showdown"). + Handler(sockjs.NewHandler("/showdown", opts, smux.Handler)) + + customCssDir, _ := filepath.Abs("./config") + r.Handle("/custom.css", http.FileServer(http.Dir(customCssDir))) + + avatarDir, _ := filepath.Abs("./config/avatars") + r.PathPrefix("/avatars/"). + Handler(http.StripPrefix("/avatars/", http.FileServer(http.Dir(avatarDir)))) + + indexPath, _ := filepath.Abs("./static/index.html") + r.PathPrefix("/{roomid:[A-Za-z0-9][A-Za-z0-9-]*}"). + HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeFile(w, r, indexPath) + }) + + notFoundPath, _ := filepath.Abs("./static/404.html") + notFoundPage, _ := ioutil.ReadFile(notFoundPath) + r.NotFoundHandler = + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write(notFoundPage) + }) + + staticDir, _ := filepath.Abs("./static") + r.Handle("/", http.FileServer(http.Dir(staticDir))) + + // Begin serving over HTTP. + go func(ba string, port string) { + addr, err := net.ResolveTCPAddr("tcp4", ba+port) + if err != nil { + log.Fatalf("Sockets: failed to resolve the TCP address of the parent's server: %v", err) + } + + ln, err := net.ListenTCP("tcp4", addr) + defer ln.Close() + if err != nil { + log.Fatalf("Sockets: failed to listen over HTTP: %v", err) + } + + err = http.Serve(ln, r) + log.Fatalf("Sockets: HTTP server failed: %v", err) + }(config.BindAddress, config.Port) + + // Begin serving over HTTPS if configured to do so. + if config.SSL.Options.Cert != "" && config.SSL.Options.Key != "" { + go func(ba string, port string, cert string, key string) { + certs, err := tls.LoadX509KeyPair(cert, key) + if err != nil { + log.Fatalf("Sockets: failed to load certificate and key files for TLS: %v", err) + } + + srv := &http.Server{ + Handler: r, + Addr: ba + port, + TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, + } + + var ln net.Listener + ln, err = tls.Listen("tcp4", srv.Addr, srv.TLSConfig) + if err != nil { + log.Fatalf("Sockets: failed to listen over HTTPS: %v", err) + } + + defer ln.Close() + err = http.Serve(ln, r) + log.Fatalf("Sockets: HTTPS server failed: %v", err) + }(config.BindAddress, config.SSL.Port, config.SSL.Options.Cert, config.SSL.Options.Key) + } + + // Finally, spawn workers.to pipe messages received at the multiplexer or + // IPC connection to each other concurrently. + master := sockets.NewMaster(config.Workers) + master.Spawn() + master.Listen() +} diff --git a/test/application/sockets.js b/test/application/sockets.js index 7fc963ebe9bf2..7c3182e55cbdf 100644 --- a/test/application/sockets.js +++ b/test/application/sockets.js @@ -1,212 +1,96 @@ 'use strict'; const assert = require('assert'); -const cluster = require('cluster'); -describe.skip('Sockets', function () { - const spawnWorker = () => ( - new Promise(resolve => { - let worker = Sockets.spawnWorker(); - worker.removeAllListeners('message'); - resolve(worker); - }) - ); +let sockets; +describe('Sockets workers', function () { before(function () { - cluster.settings.silent = true; - cluster.removeAllListeners('disconnect'); + sockets = require('../../sockets-workers'); + + this.mux = new sockets.Multiplexer(); + clearInterval(this.mux.cleanupInterval); + this.mux.cleanupInterval = null; + + this.socket = require('../../dev-tools/sockets').createSocket(); }); afterEach(function () { - Sockets.workers.forEach((worker, workerid) => { - worker.kill(); - Sockets.workers.delete(workerid); - }); + this.mux.socketCounter = 0; + this.mux.sockets.clear(); + this.mux.channels.clear(); + }); + + after(function () { + this.mux.tryDestroySocket(this.socket); + this.socket = null; + this.mux = null; + }); + + it('should parse more than two params', function () { + let params = '1\n1\n0\n'; + let ret = this.mux.parseParams(params, 4); + assert.deepStrictEqual(ret, ['1', '1', '0', '']); + }); + + it('should parse params with multiple newlines', function () { + let params = '0\n|1\n|2'; + let ret = this.mux.parseParams(params, 2); + assert.deepStrictEqual(ret, ['0', '|1\n|2']); + }); + + it('should add sockets on connect', function () { + let res = this.mux.onSocketConnect(this.socket); + assert.ok(res); + }); + + it('should remove sockets on disconnect', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onSocketDisconnect('0', this.socket); + assert.ok(res); + }); + + it('should add sockets to channels', function () { + this.mux.onSocketConnect(this.socket); + let res = this.mux.onChannelAdd('global', '0'); + assert.ok(res); + res = this.mux.onChannelAdd('global', '0'); + assert.ok(!res); + this.mux.channels.set('lobby', new Map()); + res = this.mux.onChannelAdd('lobby', '0'); + assert.ok(res); + }); + + it('should remove sockets from channels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onChannelRemove('global', '0'); + assert.ok(res); + res = this.mux.onChannelRemove('global', '0'); + assert.ok(!res); }); - describe('master', function () { - it('should be able to spawn workers', function () { - Sockets.spawnWorker(); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to spawn workers on listen', function () { - Sockets.listen(0, '127.0.0.1', 1); - assert.strictEqual(Sockets.workers.size, 1); - }); - - it('should be able to kill workers', function () { - return spawnWorker().then(worker => { - Sockets.killWorker(worker); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); - - it('should be able to kill workers by PID', function () { - return spawnWorker().then(worker => { - Sockets.killPid(worker.process.pid); - assert.strictEqual(Sockets.workers.size, 0); - }); - }); + it('should move sockets to subchannels', function () { + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + let res = this.mux.onSubchannelMove('global', '1', '0'); + assert.ok(res); }); - describe('workers', function () { - // This composes a sequence of HOFs that send a message to a worker, - // wait for its response, then return the worker for the next function - // to use. - const chain = (eventHandler, msg) => worker => { - worker.once('message', eventHandler(worker)); - msg = msg || `$ - const {Session} = require('sockjs/lib/transport'); - const socket = new Session('aaaaaaaa', server); - socket.remoteAddress = '127.0.0.1'; - if (!('headers' in socket)) socket.headers = {}; - socket.headers['x-forwarded-for'] = ''; - socket.protocol = 'websocket'; - socket.write = msg => process.send(msg); - server.emit('connection', socket);`; - worker.send(msg); - return worker; - }; - - const spawnSocket = eventHandler => spawnWorker().then(chain(eventHandler)); - - it('should allow sockets to connect', function () { - return spawnSocket(worker => data => { - let cmd = data.charAt(0); - let [sid, ip, protocol] = data.substr(1).split('\n'); - assert.strictEqual(cmd, '*'); - assert.strictEqual(sid, '1'); - assert.strictEqual(ip, '127.0.0.1'); - assert.strictEqual(protocol, 'websocket'); - }); - }); - - it('should allow sockets to disconnect', function () { - let querySocket; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - querySocket = `$ - let socket = sockets.get(${sid}); - process.send(!socket);`; - Sockets.socketDisconnect(worker, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySocket)); - }); - - it('should allow sockets to send messages', function () { - let msg = 'ayy lmao'; - let socketSend; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - socketSend = `>${sid}\n${msg}`; - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, socketSend)); - }); - - it('should allow sockets to receive messages', function () { - let sid; - let msg; - let mockReceive; - return spawnSocket(worker => data => { - sid = data.substr(1, data.indexOf('\n')); - msg = '|/cmd rooms'; - mockReceive = `$ - let socket = sockets.get(${sid}); - socket.emit('data', ${msg});`; - }).then(chain(worker => data => { - let cmd = data.charAt(0); - let params = data.substr(1).split('\n'); - assert.strictEqual(cmd, '<'); - assert.strictEqual(sid, params[0]); - assert.strictEqual(msg, params[1]); - }, mockReceive)); - }); - - it('should create a channel for the first socket to get added to it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - let channel = channels.get(${cid}); - process.send(channel && channel.has(${sid}));`; - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); - - it('should remove a channel if the last socket gets removed from it', function () { - let queryChannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'global'; - queryChannel = `$ - process.send(!sockets.has(${sid}) && !channels.has(${cid}));`; - Sockets.channelAdd(worker, cid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, queryChannel)); - }); - - it('should send to all sockets in a channel', function () { - let msg = 'ayy lmao'; - let cid = 'global'; - let channelSend = `#${cid}\n${msg}`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - Sockets.channelAdd(worker, cid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, channelSend)); - }); - - it('should create a subchannel when moving a socket to it', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels[${cid}]; - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); - - it('should remove a subchannel when removing its last socket', function () { - let querySubchannel; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let cid = 'battle-ou-1'; - let scid = '1'; - querySubchannel = `$ - let subchannel = subchannels.get(${cid}); - process.send(!!subchannel && (subchannel.get(${sid}) === ${scid}));`; - Sockets.subchannelMove(worker, cid, scid, sid); - Sockets.channelRemove(worker, cid, sid); - }).then(chain(worker => data => { - assert.ok(data); - }, querySubchannel)); - }); - - it('should send to sockets in a subchannel', function () { - let cid = 'battle-ou-1'; - let msg = 'ayy lmao'; - let subchannelSend = `.${cid}\n\n|split\n\n${msg}\n\n`; - return spawnSocket(worker => data => { - let sid = data.substr(1, data.indexOf('\n')); - let scid = '1'; - Sockets.subchannelMove(worker, cid, scid, sid); - }).then(chain(worker => data => { - assert.strictEqual(data, msg); - }, subchannelSend)); - }); + it('should broadcast to subchannels', function () { + let messages = '|split\n0\n1\n2\n|\n|split\n3\n4\n5\n|'; + for (let i = 0; i < 3; i++) { + let message = messages.replace(sockets.SUBCHANNEL_MESSAGE_REGEX, `$${i + 1}`); + assert.strictEqual(message, `${i}\n${i + 3}`); + } + + this.mux.onSocketConnect(this.socket); + this.mux.onChannelAdd('global', '0'); + this.mux.onSubchannelMove('global', '1', '0'); + let res = this.mux.onSubchannelBroadcast('global', messages); + assert.ok(res); + this.mux.onChannelRemove('global', '0'); + res = this.mux.onSubchannelBroadcast('global', messages); + assert.ok(!res); }); }); diff --git a/tsconfig.json b/tsconfig.json index c18e73cfacafd..885538cca7c2b 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -15,11 +15,13 @@ "./sim/prng.js", "./crashlogger.js", "./dnsbl.js", + "./fs.js", "./ladders-matchmaker.js", "./monitor.js", - "./repl.js", - "./fs.js", "./process-manager.js", + "./repl.js", + "./sockets.js", + "./sockets-workers.js", "./verifier.js" ] } diff --git a/users.js b/users.js index 61a886c57cb0b..5114cc4c02aea 100644 --- a/users.js +++ b/users.js @@ -1545,23 +1545,32 @@ Users.pruneInactiveTimer = setInterval(() => { * Routing *********************************************************/ +/** + * Creates a user and connection object for a new socket and sends a challenge + * string to the user for authentication. + * @param {WorkerWrapper} worker + * @param {number} workerid + * @param {string} socketid + * @param {string} ip + * @param {string} protocol + */ Users.socketConnect = function (worker, workerid, socketid, ip, protocol) { - let id = '' + workerid + '-' + socketid; + let id = `${workerid}-${socketid}`; let connection = new Connection(id, worker, socketid, null, ip, protocol); connections.set(id, connection); let banned = Punishments.checkIpBanned(connection); - if (banned) { - return connection.destroy(); - } + if (banned) return connection.destroy(); + // Emergency mode connections logging if (Config.emergency) { - FS('logs/cons.emergency.log').append('[' + ip + ']\n'); + FS('logs/cons.emergency.log').append(`[${ip}]\n`); } let user = new User(connection); connection.user = user; Punishments.checkIp(user, connection); + // Generate 1024-bit challenge string. require('crypto').randomBytes(128, (err, buffer) => { if (err) { @@ -1582,17 +1591,29 @@ Users.socketConnect = function (worker, workerid, socketid, ip, protocol) { user.joinRoom('global', connection); }; +/** + * Forcefully disconnects a socket. + * @param {WorkerWrapper} worker + * @param {number} workerid + * @param {string} socketid + */ Users.socketDisconnect = function (worker, workerid, socketid) { - let id = '' + workerid + '-' + socketid; - + let id = `${workerid}-${socketid}`; let connection = connections.get(id); if (!connection) return; + connection.onDisconnect(); }; -Users.socketReceive = function (worker, workerid, socketid, message) { - let id = '' + workerid + '-' + socketid; - +/** + * Parses a chat message received by a socket. + * @param {WorkerWrapper} worker + * @param {number} workerid + * @param {string} socketid + * @param {string} data + */ +Users.socketReceive = function (worker, workerid, socketid, data) { + let id = `${workerid}-${socketid}`; let connection = connections.get(id); if (!connection) return; @@ -1601,36 +1622,31 @@ Users.socketReceive = function (worker, workerid, socketid, message) { // `data` event. To prevent this, we log exceptions and prevent them // from propagating out of this function. - // drop legacy JSON messages - if (message.charAt(0) === '{') return; - - // drop invalid messages without a pipe character - let pipeIndex = message.indexOf('|'); - if (pipeIndex < 0) return; - - const user = connection.user; + let {user} = connection; if (!user) return; // The client obviates the room id when sending messages to Lobby by default - const roomId = message.substr(0, pipeIndex) || (Rooms.lobby || Rooms.global).id; - message = message.slice(pipeIndex + 1); - - const room = Rooms(roomId); + let pipeIndex = data.indexOf('|'); + let roomid = data.substr(0, pipeIndex) || (Rooms.lobby || Rooms.global).id; + let message = data.slice(pipeIndex + 1); + let room = Rooms(roomid); if (!room) return; + if (Chat.multiLinePattern.test(message)) { user.chat(message, room, connection); return; } - const lines = message.split('\n'); + let lines = message.split('\n'); if (!lines[lines.length - 1]) lines.pop(); if (lines.length > (user.isStaff ? THROTTLE_MULTILINE_WARN_STAFF : THROTTLE_MULTILINE_WARN)) { connection.popup(`You're sending too many lines at once. Try using a paste service like [[Pastebin]].`); return; } + // Emergency logging if (Config.emergency) { - FS('logs/emergency.log').append(`[${user} (${connection.ip})] ${roomId}|${message}\n`); + FS('logs/emergency.log').append(`[${user} (${connection.ip})] ${data}\n`); } let startTime = Date.now(); @@ -1639,6 +1655,23 @@ Users.socketReceive = function (worker, workerid, socketid, message) { } let deltaTime = Date.now() - startTime; if (deltaTime > 1000) { - Monitor.warn(`[slow] ${deltaTime}ms - ${user.name} <${connection.ip}>: ${roomId}|${message}`); + Monitor.warn(`[slow] ${deltaTime}ms - ${user.name} <${connection.ip}>: ${data}`); } }; + +/** + * Clears all connections whose sockets were contained by a + * worker. Called after a worker's process crashes or gets killed. + * @param {WorkerWrapper} worker + * @return {number} + */ +Users.socketDisconnectAll = function (worker) { + let count = 0; + connections.forEach(connection => { + if (connection.worker === worker) { + Users.socketDisconnect(worker, worker.id, connection.socketid); + count++; + } + }); + return count; +};