From 8060c86d2015ea9fdb0afcb4efc88fbf3951b78d Mon Sep 17 00:00:00 2001 From: uriyage <78144248+uriyage@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:03:30 +0200 Subject: [PATCH] Offload TLS negotiation to I/O threads (#1338) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## TLS Negotiation Offloading to I/O Threads ### Overview This PR introduces the ability to offload TLS handshake negotiations to I/O threads, significantly improving performance under high TLS connection loads. ### Key Changes - Added infrastructure to offload TLS negotiations to I/O threads - Refactored SSL event handling to allow I/O threads to modify conn flags. - Introduced new connection flag to identify client connections ### Performance Impact Testing with 650 clients with SET commands and 160 new TLS connections per second in the background: #### Throughput Impact of new TLS connections - **With Offloading**: Minimal impact (1050K → 990K ops/sec) - **Without Offloading**: Significant drop (1050K → 670K ops/sec) #### New Connection Rate - **With Offloading**: - 1,757 conn/sec - **Without Offloading**: - 477 conn/sec ### Implementation Details 1. **Main Thread**: - Initiates negotiation-offload jobs to I/O threads - Adds connections to pending-read clients list (using existing read offload mechanism) - Post-negotiation handling: - Creates read/write events if needed for incomplete negotiations - Calls accept handler for completed negotiations 2. 
**I/O Thread**: - Performs TLS negotiation - Updates connection flags based on negotiation result Related issue:https://github.com/valkey-io/valkey/issues/761 --------- Signed-off-by: Uri Yagelnik Signed-off-by: ranshid <88133677+ranshid@users.noreply.github.com> Co-authored-by: ranshid <88133677+ranshid@users.noreply.github.com> Co-authored-by: Madelyn Olson --- .github/workflows/daily.yml | 38 ++++++++++ src/connection.h | 5 +- src/io_threads.c | 52 ++++++++++++++ src/io_threads.h | 1 + src/networking.c | 6 ++ src/server.c | 2 + src/server.h | 1 + src/tls.c | 139 ++++++++++++++++++------------------ 8 files changed, 174 insertions(+), 70 deletions(-) diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index 44386f5ffd..e1d577b51b 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -375,6 +375,44 @@ jobs: if: true && !contains(github.event.inputs.skiptests, 'cluster') run: ./runtest-cluster --io-threads ${{github.event.inputs.cluster_test_args}} + test-ubuntu-tls-io-threads: + runs-on: ubuntu-latest + if: | + (github.event_name == 'workflow_dispatch' || + (github.event_name == 'schedule' && github.repository == 'valkey-io/valkey') || + (github.event_name == 'pull_request' && (contains(github.event.pull_request.labels.*.name, 'run-extra-tests') || github.event.pull_request.base.ref != 'unstable'))) && + !contains(github.event.inputs.skipjobs, 'tls') && !contains(github.event.inputs.skipjobs, 'iothreads') + timeout-minutes: 14400 + steps: + - name: prep + if: github.event_name == 'workflow_dispatch' + run: | + echo "GITHUB_REPOSITORY=${{github.event.inputs.use_repo}}" >> $GITHUB_ENV + echo "GITHUB_HEAD_REF=${{github.event.inputs.use_git_ref}}" >> $GITHUB_ENV + echo "skipjobs: ${{github.event.inputs.skipjobs}}" + echo "skiptests: ${{github.event.inputs.skiptests}}" + echo "test_args: ${{github.event.inputs.test_args}}" + echo "cluster_test_args: ${{github.event.inputs.cluster_test_args}}" + - uses: 
actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + with: + repository: ${{ env.GITHUB_REPOSITORY }} + ref: ${{ env.GITHUB_HEAD_REF }} + - name: make + run: | + make BUILD_TLS=yes SERVER_CFLAGS='-Werror' + - name: testprep + run: | + sudo apt-get install tcl8.6 tclx tcl-tls + ./utils/gen-test-certs.sh + - name: test + if: true && !contains(github.event.inputs.skiptests, 'valkey') + run: | + ./runtest --io-threads --tls --accurate --verbose --tags network --dump-logs ${{github.event.inputs.test_args}} + - name: cluster tests + if: true && !contains(github.event.inputs.skiptests, 'cluster') + run: | + ./runtest-cluster --io-threads --tls ${{github.event.inputs.cluster_test_args}} + test-ubuntu-reclaim-cache: runs-on: ubuntu-latest if: | diff --git a/src/connection.h b/src/connection.h index 8a2775ee34..fd7e0910cf 100644 --- a/src/connection.h +++ b/src/connection.h @@ -54,8 +54,9 @@ typedef enum { CONN_STATE_ERROR } ConnectionState; -#define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */ -#define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */ +#define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */ +#define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */ +#define CONN_FLAG_ALLOW_ACCEPT_OFFLOAD (1 << 2) /* Connection accept can be offloaded to IO threads. */ #define CONN_TYPE_SOCKET "tcp" #define CONN_TYPE_UNIX "unix" diff --git a/src/io_threads.c b/src/io_threads.c index 3865eb77c3..90f5b88700 100644 --- a/src/io_threads.c +++ b/src/io_threads.c @@ -561,3 +561,55 @@ void trySendPollJobToIOThreads(void) { aeSetPollProtect(server.el, 1); IOJobQueue_push(jq, IOThreadPoll, server.el); } + +static void ioThreadAccept(void *data) { + client *c = (client *)data; + connAccept(c->conn, NULL); + c->io_read_state = CLIENT_COMPLETED_IO; +} + +/* + * Attempts to offload an Accept operation (currently used for TLS accept) for a client + * connection to I/O threads. 
+ * + * Returns: + * C_OK - If the accept operation was successfully queued for processing + * C_ERR - If the connection is not eligible for offloading + * + * Parameters: + * conn - The connection object to perform the accept operation on + */ +int trySendAcceptToIOThreads(connection *conn) { + if (server.io_threads_num <= 1) { + return C_ERR; + } + + if (!(conn->flags & CONN_FLAG_ALLOW_ACCEPT_OFFLOAD)) { + return C_ERR; + } + + client *c = connGetPrivateData(conn); + if (c->io_read_state != CLIENT_IDLE) { + return C_OK; + } + + if (server.active_io_threads_num <= 1) { + return C_ERR; + } + + size_t thread_id = (c->id % (server.active_io_threads_num - 1)) + 1; + IOJobQueue *job_queue = &io_jobs[thread_id]; + + if (IOJobQueue_isFull(job_queue)) { + return C_ERR; + } + + c->io_read_state = CLIENT_PENDING_IO; + c->flag.pending_read = 1; + listLinkNodeTail(server.clients_pending_io_read, &c->pending_read_list_node); + connSetPostponeUpdateState(c->conn, 1); + server.stat_io_accept_offloaded++; + IOJobQueue_push(job_queue, ioThreadAccept, c); + + return C_OK; +} diff --git a/src/io_threads.h b/src/io_threads.h index 8818f08588..a3ff582a77 100644 --- a/src/io_threads.h +++ b/src/io_threads.h @@ -13,5 +13,6 @@ int tryOffloadFreeArgvToIOThreads(client *c, int argc, robj **argv); void adjustIOThreadsByEventLoad(int numevents, int increase_only); void drainIOThreadsQueue(void); void trySendPollJobToIOThreads(void); +int trySendAcceptToIOThreads(connection *conn); #endif /* IO_THREADS_H */ diff --git a/src/networking.c b/src/networking.c index 16147ff0ba..9f36f24275 100644 --- a/src/networking.c +++ b/src/networking.c @@ -134,6 +134,7 @@ client *createClient(connection *conn) { if (server.tcpkeepalive) connKeepAlive(conn, server.tcpkeepalive); connSetReadHandler(conn, readQueryFromClient); connSetPrivateData(conn, c); + conn->flags |= CONN_FLAG_ALLOW_ACCEPT_OFFLOAD; } c->buf = zmalloc_usable(PROTO_REPLY_CHUNK_BYTES, &c->buf_usable_size); selectDb(c, 0); @@ -4805,9 +4806,14 
@@ int processIOThreadsReadDone(void) { processed++; server.stat_io_reads_processed++; + /* Save the current conn state, as connUpdateState may modify it */ + int in_accept_state = (connGetState(c->conn) == CONN_STATE_ACCEPTING); connSetPostponeUpdateState(c->conn, 0); connUpdateState(c->conn); + /* In accept state, no client's data was read - stop here. */ + if (in_accept_state) continue; + /* On read error - stop here. */ if (handleReadResult(c) == C_ERR) { continue; diff --git a/src/server.c b/src/server.c index 5275fed4b9..3cdec9fa9b 100644 --- a/src/server.c +++ b/src/server.c @@ -2645,6 +2645,7 @@ void resetServerStats(void) { server.stat_total_reads_processed = 0; server.stat_io_writes_processed = 0; server.stat_io_freed_objects = 0; + server.stat_io_accept_offloaded = 0; server.stat_poll_processed_by_io_threads = 0; server.stat_total_writes_processed = 0; server.stat_client_qbuf_limit_disconnections = 0; @@ -5915,6 +5916,7 @@ sds genValkeyInfoString(dict *section_dict, int all_sections, int everything) { "io_threaded_reads_processed:%lld\r\n", server.stat_io_reads_processed, "io_threaded_writes_processed:%lld\r\n", server.stat_io_writes_processed, "io_threaded_freed_objects:%lld\r\n", server.stat_io_freed_objects, + "io_threaded_accept_processed:%lld\r\n", server.stat_io_accept_offloaded, "io_threaded_poll_processed:%lld\r\n", server.stat_poll_processed_by_io_threads, "io_threaded_total_prefetch_batches:%lld\r\n", server.stat_total_prefetch_batches, "io_threaded_total_prefetch_entries:%lld\r\n", server.stat_total_prefetch_entries, diff --git a/src/server.h b/src/server.h index 783871b856..841db70614 100644 --- a/src/server.h +++ b/src/server.h @@ -1869,6 +1869,7 @@ struct valkeyServer { long long stat_io_reads_processed; /* Number of read events processed by IO threads */ long long stat_io_writes_processed; /* Number of write events processed by IO threads */ long long stat_io_freed_objects; /* Number of objects freed by IO threads */ + long long 
stat_io_accept_offloaded; /* Number of offloaded accepts */ long long stat_poll_processed_by_io_threads; /* Total number of poll jobs processed by IO */ long long stat_total_reads_processed; /* Total number of read events processed */ long long stat_total_writes_processed; /* Total number of write events processed */ diff --git a/src/tls.c b/src/tls.c index 48b75553de..11e6143561 100644 --- a/src/tls.c +++ b/src/tls.c @@ -32,6 +32,7 @@ #include "server.h" #include "connhelpers.h" #include "adlist.h" +#include "io_threads.h" #if (USE_OPENSSL == 1 /* BUILD_YES */) || ((USE_OPENSSL == 2 /* BUILD_MODULE */) && (BUILD_TLS_MODULE == 2)) @@ -437,16 +438,13 @@ static ConnectionType CT_TLS; * */ -typedef enum { - WANT_READ = 1, - WANT_WRITE -} WantIOType; - #define TLS_CONN_FLAG_READ_WANT_WRITE (1 << 0) #define TLS_CONN_FLAG_WRITE_WANT_READ (1 << 1) #define TLS_CONN_FLAG_FD_SET (1 << 2) #define TLS_CONN_FLAG_POSTPONE_UPDATE_STATE (1 << 3) #define TLS_CONN_FLAG_HAS_PENDING (1 << 4) +#define TLS_CONN_FLAG_ACCEPT_ERROR (1 << 5) +#define TLS_CONN_FLAG_ACCEPT_SUCCESS (1 << 6) typedef struct tls_connection { connection c; @@ -514,20 +512,26 @@ static connection *connCreateAcceptedTLS(int fd, void *priv) { return (connection *)conn; } +static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler); static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask); static void updateSSLEvent(tls_connection *conn); +static void clearTLSWantFlags(tls_connection *conn) { + conn->flags &= ~(TLS_CONN_FLAG_WRITE_WANT_READ | TLS_CONN_FLAG_READ_WANT_WRITE); +} + /* Process the return code received from OpenSSL> - * Update the want parameter with expected I/O. + * Update the conn flags with the WANT_READ/WANT_WRITE flags. * Update the connection's error state if a real error has occurred. * Returns an SSL error code, or 0 if no further handling is required. 
*/ -static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *want) { +static int handleSSLReturnCode(tls_connection *conn, int ret_value) { + clearTLSWantFlags(conn); if (ret_value <= 0) { int ssl_err = SSL_get_error(conn->ssl, ret_value); switch (ssl_err) { - case SSL_ERROR_WANT_WRITE: *want = WANT_WRITE; return 0; - case SSL_ERROR_WANT_READ: *want = WANT_READ; return 0; + case SSL_ERROR_WANT_WRITE: conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; return 0; + case SSL_ERROR_WANT_READ: conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; return 0; case SSL_ERROR_SYSCALL: conn->c.last_errno = errno; if (conn->ssl_error) zfree(conn->ssl_error); @@ -563,11 +567,8 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update } if (ret_value <= 0) { - WantIOType want = 0; int ssl_err; - if (!(ssl_err = handleSSLReturnCode(conn, ret_value, &want))) { - if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; - if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; + if (!(ssl_err = handleSSLReturnCode(conn, ret_value))) { if (update_event) updateSSLEvent(conn); errno = EAGAIN; return -1; @@ -585,19 +586,17 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update return ret_value; } -static void registerSSLEvent(tls_connection *conn, WantIOType want) { +static void registerSSLEvent(tls_connection *conn) { int mask = aeGetFileEvents(server.el, conn->c.fd); - switch (want) { - case WANT_READ: + if (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ) { if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn); - break; - case WANT_WRITE: + } else if (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE) { if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, 
conn); - break; - default: serverAssert(0); break; + } else { + serverAssert(0); } } @@ -650,12 +649,47 @@ static void updateSSLEvent(tls_connection *conn) { if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); } +static int TLSHandleAcceptResult(tls_connection *conn, int call_handler_on_error) { + serverAssert(conn->c.state == CONN_STATE_ACCEPTING); + if (conn->flags & TLS_CONN_FLAG_ACCEPT_SUCCESS) { + conn->c.state = CONN_STATE_CONNECTED; + } else if (conn->flags & TLS_CONN_FLAG_ACCEPT_ERROR) { + conn->c.state = CONN_STATE_ERROR; + if (!call_handler_on_error) return C_ERR; + } else { + /* Still pending accept */ + registerSSLEvent(conn); + return C_OK; + } + + /* call accept handler */ + if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_ERR; + conn->c.conn_handler = NULL; + return C_OK; +} + static void updateSSLState(connection *conn_) { tls_connection *conn = (tls_connection *)conn_; + + if (conn->c.state == CONN_STATE_ACCEPTING) { + if (TLSHandleAcceptResult(conn, 1) == C_ERR || conn->c.state != CONN_STATE_CONNECTED) return; + } + updateSSLEvent(conn); updatePendingData(conn); } +static void TLSAccept(void *_conn) { + tls_connection *conn = (tls_connection *)_conn; + ERR_clear_error(); + int ret = SSL_accept(conn->ssl); + if (ret > 0) { + conn->flags |= TLS_CONN_FLAG_ACCEPT_SUCCESS; + } else if (handleSSLReturnCode(conn, ret)) { + conn->flags |= TLS_CONN_FLAG_ACCEPT_ERROR; + } +} + static void tlsHandleEvent(tls_connection *conn, int mask) { int ret, conn_error; @@ -676,10 +710,8 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { } ret = SSL_connect(conn->ssl); if (ret <= 0) { - WantIOType want = 0; - if (!handleSSLReturnCode(conn, ret, &want)) { - registerSSLEvent(conn, want); - + if (!handleSSLReturnCode(conn, ret)) { + registerSSLEvent(conn); /* Avoid hitting UpdateSSLEvent, which knows nothing * of what SSL_connect() wants and instead looks at our * R/W handlers. 
@@ -698,27 +730,7 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { conn->c.conn_handler = NULL; break; case CONN_STATE_ACCEPTING: - ERR_clear_error(); - ret = SSL_accept(conn->ssl); - if (ret <= 0) { - WantIOType want = 0; - if (!handleSSLReturnCode(conn, ret, &want)) { - /* Avoid hitting UpdateSSLEvent, which knows nothing - * of what SSL_connect() wants and instead looks at our - * R/W handlers. - */ - registerSSLEvent(conn, want); - return; - } - - /* If not handled, it's an error */ - conn->c.state = CONN_STATE_ERROR; - } else { - conn->c.state = CONN_STATE_CONNECTED; - } - - if (!callHandler((connection *)conn, conn->c.conn_handler)) return; - conn->c.conn_handler = NULL; + if (connTLSAccept((connection *)conn, NULL) == C_ERR || conn->c.state != CONN_STATE_CONNECTED) return; break; case CONN_STATE_CONNECTED: { int call_read = ((mask & AE_READABLE) && conn->c.read_handler) || @@ -740,20 +752,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) { int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER; if (!invert && call_read) { - conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *)conn, conn->c.read_handler)) return; } /* Fire the writable event. */ if (call_write) { - conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; if (!callHandler((connection *)conn, conn->c.write_handler)) return; } /* If we have to invert the call, fire the readable event now * after the writable one. 
*/ if (invert && call_read) { - conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; if (!callHandler((connection *)conn, conn->c.read_handler)) return; } updatePendingData(conn); @@ -845,31 +854,25 @@ static void connTLSClose(connection *conn_) { static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { tls_connection *conn = (tls_connection *)_conn; - int ret; - if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR; - ERR_clear_error(); - + int call_handler_on_error = 1; /* Try to accept */ - conn->c.conn_handler = accept_handler; - ret = SSL_accept(conn->ssl); - - if (ret <= 0) { - WantIOType want = 0; - if (!handleSSLReturnCode(conn, ret, &want)) { - registerSSLEvent(conn, want); /* We'll fire back */ - return C_OK; - } else { - conn->c.state = CONN_STATE_ERROR; - return C_ERR; - } + if (accept_handler) { + conn->c.conn_handler = accept_handler; + call_handler_on_error = 0; } - conn->c.state = CONN_STATE_CONNECTED; - if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_OK; - conn->c.conn_handler = NULL; + /* We're in IO thread - just call accept and return, the main thread will handle the rest */ + if (!inMainThread()) { + TLSAccept(conn); + return C_OK; + } - return C_OK; + /* Try to offload accept to IO threads */ + if (trySendAcceptToIOThreads(_conn) == C_OK) return C_OK; + + TLSAccept(conn); + return TLSHandleAcceptResult(conn, call_handler_on_error); } static int connTLSConnect(connection *conn_,