diff --git a/src/unix/sqUnixOpenSSL.c b/src/unix/sqUnixOpenSSL.c index 4021095..aed4b13 100644 --- a/src/unix/sqUnixOpenSSL.c +++ b/src/unix/sqUnixOpenSSL.c @@ -11,6 +11,7 @@ typedef struct sqSSL { char *certName; char *peerName; + char *serverName; SSL_METHOD *method; SSL_CTX *ctx; @@ -34,12 +35,12 @@ static sqSSL *sslFromHandle(sqInt handle) { /* sqCopyBioSSL: Copies data from a BIO into an out buffer */ sqInt sqCopyBioSSL(sqSSL *ssl, BIO *bio, char *dstBuf, sqInt dstLen) { - int nbytes = BIO_ctrl_pending(bio); + int nbytes = BIO_ctrl_pending(bio); - if(ssl->loglevel) printf("sqCopyBioSSL: %d bytes pending; buffer size %ld\n", - nbytes, (long)dstLen); - if(nbytes > dstLen) return -1; - return BIO_read(bio, dstBuf, dstLen); + if(ssl->loglevel) printf("sqCopyBioSSL: %d bytes pending; buffer size %ld\n", + nbytes, (long)dstLen); + if(nbytes > dstLen) return -1; + return BIO_read(bio, dstBuf, dstLen); } /* sqSetupSSL: Common SSL setup tasks */ @@ -60,16 +61,16 @@ sqInt sqSetupSSL(sqSSL *ssl, int server) { if(ssl->certName) { if(ssl->loglevel) printf("sqSetupSSL: Using cert file %s\n", ssl->certName); if(SSL_CTX_use_certificate_file(ssl->ctx, ssl->certName, SSL_FILETYPE_PEM)<=0) - ERR_print_errors_fp(stderr); + ERR_print_errors_fp(stderr); if(SSL_CTX_use_PrivateKey_file(ssl->ctx, ssl->certName, SSL_FILETYPE_PEM)<=0) - ERR_print_errors_fp(stderr); + ERR_print_errors_fp(stderr); } /* Set up trusted CA */ if(ssl->loglevel) printf("sqSetupSSL: No root CA given; using default verify paths\n"); if(SSL_CTX_set_default_verify_paths(ssl->ctx) <=0) - ERR_print_errors_fp(stderr); + ERR_print_errors_fp(stderr); if(ssl->loglevel) printf("sqSetupSSL: Creating SSL\n"); ssl->ssl = SSL_new(ssl->ctx); @@ -184,6 +185,13 @@ sqInt sqConnectSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt if(ssl->loglevel) printf("sqConnectSSL: BIO_write failed\n"); return SQSSL_GENERIC_ERROR; } + + /* if a server name is provided, use it */ + if(ssl->serverName) { + if(ssl->loglevel) printf("sqSetupSSL: Using server name %s\n", ssl->serverName); + SSL_set_tlsext_host_name(ssl->ssl, ssl->serverName); + } + if(ssl->loglevel) printf("sqConnectSSL: SSL_connect\n"); result = SSL_connect(ssl->ssl); if(result <= 0) { @@ -206,8 +214,8 @@ sqInt sqConnectSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt /* Fail if no cert received. */ if(cert) { X509_NAME_get_text_by_NID(X509_get_subject_name(cert), - NID_commonName, peerName, - sizeof(peerName)); + NID_commonName, peerName, + sizeof(peerName)); if(ssl->loglevel) printf("sqConnectSSL: peerName = %s\n", peerName); ssl->peerName = strdup(peerName); X509_free(cert); @@ -289,18 +297,18 @@ sqInt sqAcceptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt if(ssl->loglevel) printf("sqAcceptSSL: cert = %lx\n", (long)cert); if(cert) { - X509_NAME_get_text_by_NID(X509_get_subject_name(cert), - NID_commonName, peerName, - sizeof(peerName)); - if(ssl->loglevel) printf("sqAcceptSSL: peerName = %s\n", peerName); - ssl->peerName = strdup(peerName); - X509_free(cert); - - /* Check the result of verification */ - result = SSL_get_verify_result(ssl->ssl); - if(ssl->loglevel) printf("sqAcceptSSL: SSL_get_verify_result = %d\n", result); - /* FIXME: Figure out the actual failure reason */ - ssl->certFlags = result ? SQSSL_OTHER_ISSUE : SQSSL_OK; + X509_NAME_get_text_by_NID(X509_get_subject_name(cert), + NID_commonName, peerName, + sizeof(peerName)); + if(ssl->loglevel) printf("sqAcceptSSL: peerName = %s\n", peerName); + ssl->peerName = strdup(peerName); + X509_free(cert); + + /* Check the result of verification */ + result = SSL_get_verify_result(ssl->ssl); + if(ssl->loglevel) printf("sqAcceptSSL: SSL_get_verify_result = %d\n", result); + /* FIXME: Figure out the actual failure reason */ + ssl->certFlags = result ? SQSSL_OTHER_ISSUE : SQSSL_OK; } else { ssl->certFlags = SQSSL_NO_CERTIFICATE; } @@ -348,11 +356,11 @@ sqInt sqDecryptSSL(sqInt handle, char* srcBuf, sqInt srcLen, char *dstBuf, sqInt if(nbytes != srcLen) return SQSSL_GENERIC_ERROR; nbytes = SSL_read(ssl->ssl, dstBuf, dstLen); if(nbytes <= 0) { - int error = SSL_get_error(ssl->ssl, nbytes); - if(error != SSL_ERROR_WANT_READ && error != SSL_ERROR_ZERO_RETURN) { - return SQSSL_GENERIC_ERROR; - } - nbytes = 0; + int error = SSL_get_error(ssl->ssl, nbytes); + if(error != SSL_ERROR_WANT_READ && error != SSL_ERROR_ZERO_RETURN) { + return SQSSL_GENERIC_ERROR; + } + nbytes = 0; } return nbytes; } @@ -368,8 +376,9 @@ char* sqGetStringPropertySSL(sqInt handle, int propID) { if(ssl == NULL) return NULL; switch(propID) { - case SQSSL_PROP_PEERNAME: return ssl->peerName; - case SQSSL_PROP_CERTNAME: return ssl->certName; + case SQSSL_PROP_PEERNAME: return ssl->peerName; + case SQSSL_PROP_CERTNAME: return ssl->certName; + case SQSSL_PROP_SERVERNAME: return ssl->serverName; default: if(ssl->loglevel) printf("sqGetStringPropertySSL: Unknown property ID %d\n", propID); return NULL; @@ -382,7 +391,7 @@ char* sqGetStringPropertySSL(sqInt handle, int propID) { handle - the ssl handle propID - the property id to retrieve propName - the property string - propLen - the length of the property string + propLen - the length of the property string Returns: Non-zero if successful. */ sqInt sqSetStringPropertySSL(sqInt handle, int propID, char *propName, sqInt propLen) { @@ -392,15 +401,23 @@ sqInt sqSetStringPropertySSL(sqInt handle, int propID, char *propName, sqInt pro if(ssl == NULL) return 0; if(propLen) { - property = calloc(1, propLen+1); - memcpy(property, propName, propLen); + property = malloc(propLen + 1); + memcpy(property, propName, propLen); + property[propLen] = '\0'; }; if(ssl->loglevel) printf("sqSetStringPropertySSL(%d): %s\n", propID, property); switch(propID) { - case SQSSL_PROP_CERTNAME: ssl->certName = property; break; - default: + case SQSSL_PROP_CERTNAME: + if (ssl->certName) free(ssl->certName); + ssl->certName = property; + break; + case SQSSL_PROP_SERVERNAME: + if (ssl->serverName) free(ssl->serverName); + ssl->serverName = property; + break; + default: if(property) free(property); if(ssl->loglevel) printf("sqSetStringPropertySSL: Unknown property ID %d\n", propID); return 0;