Skip to content

Commit

Permalink
Merge pull request #71 from surajvshankar/tls-support
Browse files Browse the repository at this point in the history
Support TLS when using the Mysql driver
  • Loading branch information
szkiba authored Sep 6, 2024
2 parents a3010c8 + b72a1e0 commit e38fd45
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 7 deletions.
28 changes: 27 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ Supported RDBMSs: `mysql`, `postgres`, `sqlite3`, `sqlserver`, `azuresql`, `clic
directory for usage. Other RDBMSs are not supported, see
[details below](#support-for-other-rdbmss).

## Table of Contents
- [Build](#build)
- [Development](#development)
- [Example](#example)
- [TLS support](#tls-support)
- [Support for other RDBMSs](#support-for-other-rdbmss)
- [Docker](#docker)

## Build

Expand Down Expand Up @@ -116,10 +123,29 @@ default ✓ [======================================] 1 VUs 00m00.0s/10m0s 1/1
iterations...........: 1 15.292228/s
```
## See also
#### See also
- [Load Testing SQL Databases with k6](https://k6.io/blog/load-testing-sql-databases-with-k6/)
### TLS Support
Presently, TLS support is available only for the MySQL driver.
To enable TLS support, call `sql.loadTLS` from the script, before calling `sql.open`. [mysql_secure_test.js](examples/mysql_secure_test.js) is an example.
`sql.loadTLS` accepts the following options:
```javascript
sql.loadTLS({
enableTLS: true,
insecureSkipTLSverify: true,
minVersion: sql.TLS_1_2,
// Possible values: sql.TLS_1_0, sql.TLS_1_1, sql.TLS_1_2, sql.TLS_1_3
caCertFile: '/filepath/to/ca.pem',
clientCertFile: '/filepath/to/client-cert.pem',
clientKeyFile: '/filepath/to/client-key.pem',
});
```
### Support for other RDBMSs
Note that this project is not accepting support for additional SQL implementations
Expand Down
37 changes: 37 additions & 0 deletions examples/mysql_secure_test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import sql from 'k6/x/sql';

sql.loadTLS({
enableTLS: true,
insecureSkipTLSverify: true,
minVersion: sql.TLS_1_2,
// Possible values: sql.TLS_1_0, sql.TLS_1_1, sql.TLS_1_2, sql.TLS_1_3
caCertFile: 'ca.pem',
clientCertFile: 'client-cert.pem',
clientKeyFile: 'client-key.pem',
});

const db = sql.open('mysql', 'root:password@tcp(localhost:3306)/mysql')

export function setup() {
db.exec(`
CREATE TABLE IF NOT EXISTS keyvalues (
id INT(6) UNSIGNED AUTO_INCREMENT PRIMARY KEY,
\`key\` VARCHAR(50) NOT NULL,
value VARCHAR(50) NULL
);
`);
}

export function teardown() {
db.close();
}

export default function () {
db.exec("INSERT INTO keyvalues (`key`, value) VALUES('plugin-name', 'k6-plugin-sql');");

let results = sql.query(db, "SELECT * FROM keyvalues WHERE `key` = ?;", 'plugin-name');
for (const row of results) {
// Convert array of ASCII integers into strings. See https://github.com/grafana/xk6-sql/issues/12
console.log(`key: ${String.fromCharCode(...row.key)}, value: ${String.fromCharCode(...row.value)}`);
}
}
147 changes: 141 additions & 6 deletions sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,36 @@
package sql

import (
"crypto/tls"
"crypto/x509"
dbsql "database/sql"
"encoding/json"
"fmt"
"os"
"strings"

"github.com/go-sql-driver/mysql"
"github.com/grafana/sobek"

// Blank imports required for initialization of drivers
_ "github.com/ClickHouse/clickhouse-go/v2"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/microsoft/go-mssqldb"
_ "github.com/microsoft/go-mssqldb/azuread"
"go.k6.io/k6/js/common"
"go.k6.io/k6/js/modules"
"go.k6.io/k6/lib/netext"
)

// supportedTLSVersions is a map of TLS versions to their numeric values.
var supportedTLSVersions = map[string]uint16{ //nolint: gochecknoglobals
netext.TLS_1_0: tls.VersionTLS10,
netext.TLS_1_1: tls.VersionTLS11,
netext.TLS_1_2: tls.VersionTLS12,
netext.TLS_1_3: tls.VersionTLS13,
}

func init() {
modules.Register("k6/x/sql", new(RootModule))
}
Expand All @@ -25,7 +42,9 @@ type RootModule struct{}

// SQL represents an instance of the SQL module for every VU.
type SQL struct {
vu modules.VU
vu modules.VU
exports *sobek.Object
tlsConfig TLSConfig
}

// Ensure the interfaces are implemented correctly.
Expand All @@ -37,13 +56,34 @@ var (
// NewModuleInstance implements the modules.Module interface to return
// a new instance for each VU.
func (*RootModule) NewModuleInstance(vu modules.VU) modules.Instance {
return &SQL{vu: vu}
runtime := vu.Runtime()

moduleInstance := &SQL{
vu: vu,
exports: runtime.NewObject(),
tlsConfig: TLSConfig{},
}
// Export constants to the JS code.
moduleInstance.defineConstants()

mustExport := func(name string, value interface{}) {
if err := moduleInstance.exports.Set(name, value); err != nil {
common.Throw(runtime, err)
}
}
mustExport("loadTLS", moduleInstance.LoadTLS)
mustExport("open", moduleInstance.Open)
mustExport("query", moduleInstance.Query)

return moduleInstance
}

// Exports implements the modules.Instance interface and returns the exports
// of the JS module.
func (sql *SQL) Exports() modules.Exports {
return modules.Exports{Default: sql}
return modules.Exports{
Default: sql.exports,
}
}

// KeyValue is a simple key-value pair.
Expand All @@ -58,22 +98,117 @@ func contains(array []string, element string) bool {
return false
}

// defineConstants defines the constants that can be used in the JS code.
func (sql *SQL) defineConstants() {
runtime := sql.vu.Runtime()
mustAddProp := func(name string, val interface{}) {
err := sql.exports.DefineDataProperty(
name, runtime.ToValue(val), sobek.FLAG_FALSE, sobek.FLAG_FALSE, sobek.FLAG_TRUE,
)
if err != nil {
common.Throw(runtime, err)
}
}

// TLS versions
mustAddProp("TLS_1_0", netext.TLS_1_0)
mustAddProp("TLS_1_1", netext.TLS_1_1)
mustAddProp("TLS_1_2", netext.TLS_1_2)
mustAddProp("TLS_1_3", netext.TLS_1_3)
}

// TLSConfig contains all the TLS configuration options passed between the JS and Go code.
type TLSConfig struct {
EnableTLS bool `json:"enableTLS"`
InsecureSkipTLSverify bool `json:"insecureSkipTLSverify"`
MinVersion string `json:"minVersion"`
CAcertFile string `json:"caCertFile"`
ClientCertFile string `json:"clientCertFile"`
ClientKeyFile string `json:"clientKeyFile"`
}

// LoadTLS loads the TLS configuration for the SQL module.
func (sql *SQL) LoadTLS(params map[string]interface{}) {
runtime := sql.vu.Runtime()
var tlsConfig *TLSConfig
if b, err := json.Marshal(params); err != nil {
common.Throw(runtime, err)
} else {
if err := json.Unmarshal(b, &tlsConfig); err != nil {
common.Throw(runtime, err)
}
}
if _, ok := supportedTLSVersions[tlsConfig.MinVersion]; !ok {
common.Throw(runtime, fmt.Errorf("unsupported TLS version: %s", tlsConfig.MinVersion))
}
sql.tlsConfig = *tlsConfig
}

// Open establishes a connection to the specified database type using
// the provided connection string.
func (*SQL) Open(database string, connectionString string) (*dbsql.DB, error) {
func (sql *SQL) Open(database string, connectionString string) (*dbsql.DB, error) {
supportedDatabases := []string{"mysql", "postgres", "sqlite3", "sqlserver", "azuresql", "clickhouse"}
if !contains(supportedDatabases, database) {
return nil, fmt.Errorf("database %s is not supported", database)
}

if database == "mysql" && sql.tlsConfig.EnableTLS {
const tlsConfigKey = "custom"
if err := registerTLS(tlsConfigKey, sql.tlsConfig); err != nil {
return nil, err
}
connectionString = prefixConnectionString(connectionString, tlsConfigKey)
}

db, err := dbsql.Open(database, connectionString)
if err != nil {
return nil, err
}

return db, nil
}

// prefixConnectionString prefixes the connection string with the TLS configuration key.
func prefixConnectionString(connectionString string, tlsConfigKey string) string {
tlsParam := fmt.Sprintf("tls=%s", tlsConfigKey)
if strings.Contains(connectionString, tlsParam) {
return connectionString
}
var separator string
if strings.Contains(connectionString, "?") {
separator = "&"
} else {
separator = "?"
}
return fmt.Sprintf("%s%s%s", connectionString, separator, tlsParam)
}

// registerTLS loads the ca-cert and registers the TLS configuration with the MySQL driver.
func registerTLS(tlsConfigKey string, tlsConfig TLSConfig) error {
rootCAs := x509.NewCertPool()
pem, err := os.ReadFile(tlsConfig.CAcertFile) //nolint: forbidigo
if err != nil {
return err
}
if ok := rootCAs.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append PEM")
}

clientCerts := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair(tlsConfig.ClientCertFile, tlsConfig.ClientKeyFile)
if err != nil {
return err
}
clientCerts = append(clientCerts, certs)

mysqlTLSConfig := &tls.Config{
RootCAs: rootCAs,
Certificates: clientCerts,
MinVersion: supportedTLSVersions[tlsConfig.MinVersion],
InsecureSkipVerify: tlsConfig.InsecureSkipTLSverify, // #nosec G402
}
return mysql.RegisterTLSConfig(tlsConfigKey, mysqlTLSConfig)
}

// Query executes the provided query string against the database, while
// providing results as a slice of KeyValue instance(s) if available.
func (*SQL) Query(db *dbsql.DB, query string, args ...interface{}) ([]KeyValue, error) {
Expand Down
41 changes: 41 additions & 0 deletions sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/grafana/sobek"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.k6.io/k6/js/common"
"go.k6.io/k6/js/modulestest"
Expand Down Expand Up @@ -94,3 +95,43 @@ func setupTestEnv(t *testing.T) *sobek.Runtime {

return rt
}

func TestPrefixConnectionString(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
connectionString string
want string
}{
{
name: "HappyPath",
connectionString: "root:password@tcp(localhost:3306)/mysql",
want: "root:password@tcp(localhost:3306)/mysql?tls=custom",
},
{
name: "WithExistingParams",
connectionString: "root:password@tcp(localhost:3306)/mysql?param=value",
want: "root:password@tcp(localhost:3306)/mysql?param=value&tls=custom",
},
{
name: "WithExistingTLSparam",
connectionString: "root:password@tcp(localhost:3306)/mysql?tls=custom",
want: "root:password@tcp(localhost:3306)/mysql?tls=custom",
},
{
name: "WithExistingTLSparam",
connectionString: "root:password@tcp(localhost:3306)/mysql?tls=notcustom",
want: "root:password@tcp(localhost:3306)/mysql?tls=notcustom&tls=custom",
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := prefixConnectionString(tc.connectionString, "custom")
assert.Equal(t, tc.want, got)
})
}
}

0 comments on commit e38fd45

Please sign in to comment.