diff --git a/README.md b/README.md index 33ce2c9..5001ab5 100644 --- a/README.md +++ b/README.md @@ -3,12 +3,14 @@ This is a Redis client library for [k6](https://github.com/grafana/k6), implemented as an extension using the [xk6](https://github.com/grafana/xk6) system. -| :exclamation: This extension is going under heavy changes and is about to make it to k6's core. USE AT YOUR OWN RISK! | -| --------------------------------------------------------------------------------------------------------------------- | +| :exclamation: This extension is an experimental project, and breaking changes are possible. USE AT YOUR OWN RISK! | +|-------------------------------------------------------------------------------------------------------------------| ## Build -To build a `k6` binary with this extension, first ensure you have the prerequisites: +This extension is available as an [experimental k6 module](https://k6.io/docs/javascript-api/k6-experimental/redis/) since k6 v0.40.0, so you don't need to build it with xk6 yourself, and can use it with the main k6 binary. Note that your script must import `k6/experimental/redis` instead of `k6/x/redis` if you're using the module bundled in k6. + +However, if you prefer to build it from source using xk6, first ensure you have the prerequisites: - [Go toolchain](https://go101.org/article/go-toolchain.html) - Git @@ -39,11 +41,9 @@ with Redis in a seemingly synchronous manner. For instance, if you were to depend on values stored in Redis to perform HTTP calls, those HTTP calls should be made in the context of the Redis promise chain: ```javascript -// Instantiate a new redis client -const redisClient = new redis.Client({ - addrs: __ENV.REDIS_ADDRS.split(",") || new Array("localhost:6379"), // in the form of "host:port", separated by commas - password: __ENV.REDIS_PASSWORD || "", -}) +// Instantiate a new Redis client using a URL. +// The connection will be established on the first command call. +const client = new redis.Client('redis://localhost:6379'); export default function() { // Once the SRANDMEMBER operation is succesfull, @@ -51,9 +51,9 @@ export default function() { // set member value to the caller of the resolve callback. // // The next promise performs the synchronous HTTP call, and - // returns a promise to the next operation, which uses the + // returns a promise to the next operation, which uses the // passed URL value to store some data in redis. - redisClient.srandmember('client_ids') + client.srandmember('client_ids') .then((randomID) => { const url = `https://my.url/${randomID}` const res = http.get(url) @@ -63,134 +63,138 @@ export default function() { // return a promise resolving to the URL return url }) - .then((url) => redisClient.hincrby('k6_crocodile_fetched', url, 1)) + .then((url) => client.hincrby('k6_crocodile_fetched', url, 1)); } ``` -## Example test scripts +You can see more complete examples in the [/examples](/examples) directory. + + +### Single-node client -In this example we demonstrate two scenarios: one load testing a redis instance, another using redis as an external data store used throughout the test itself. +As shown in the above example, the simplest way to create a new `Client` instance that connects to a single Redis server is by passing a URL string. It must be in the format: + +``` +redis[s]://[[username][:password]@][host][:port][/db-number] +``` +A client can also be instantiated using an object, for more flexibility: ```javascript -import { check } from "k6"; -import http from "k6/http"; -import redis from "k6/x/redis"; -import exec from "k6/execution"; -import { textSummary } from "https://jslib.k6.io/k6-summary/0.0.1/index.js"; - -export const options = { - scenarios: { - redisPerformance: { - executor: "shared-iterations", - vus: 10, - iterations: 200, - exec: "measureRedisPerformance", - }, - usingRedisData: { - executor: "shared-iterations", - vus: 10, - iterations: 200, - exec: "measureUsingRedisData", - }, +const client = new redis.Client({ + socket: { + host: 'localhost', + port: 6379, }, -}; + username: 'someusername', + password: 'somepassword', +}); +``` + + +### Cluster client -// Instantiate a new redis client -const redisClient = new redis.Client({ - addrs: __ENV.REDIS_ADDRS.split(",") || new Array("localhost:6379"), // in the form of "host:port", separated by commas - password: __ENV.REDIS_PASSWORD || "", +You can connect to a cluster of Redis servers by using the `cluster` property, and passing 2 or more node URLs: +```javascript +const client = new redis.Client({ + cluster: { + // Cluster options + maxRedirects: 3, + readOnly: true, + routeByLatency: true, + routeRandomly: true, + nodes: ['redis://host1:6379', 'redis://host2:6379'] + } }); +``` -// Prepare an array of crocodile ids for later use -// in the context of the measureUsingRedisData function. -const crocodileIDs = new Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); - -export function measureRedisPerformance() { - // VUs are executed in a parallel fashion, - // thus, to ensure that parallel VUs are not - // modifying the same key at the same time, - // we use keys indexed by the VU id. - const key = `foo-${exec.vu.idInTest}`; - - redisClient - .set(`foo-${exec.vu.idInTest}`, 1) - .then(() => redisClient.get(`foo-${exec.vu.idInTest}`)) - .then((value) => redisClient.incrBy(`foo-${exec.vu.idInTest}`, value)) - .then((_) => redisClient.del(`foo-${exec.vu.idInTest}`)) - .then((_) => redisClient.exists(`foo-${exec.vu.idInTest}`)) - .then((exists) => { - if (exists !== 0) { - throw new Error("foo should have been deleted"); +Or the same as above, but using node objects: +```javascript +const client = new redis.Client({ + cluster: { + nodes: [ + { + socket: { + host: 'host1', + port: 6379, + } + }, + { + socket: { + host: 'host2', + port: 6379, + } } - }); -} + ] + } +}); +``` -export function setup() { - redisClient.sadd("crocodile_ids", ...crocodileIDs); -} -export function measureUsingRedisData() { - // Pick a random crocodile id from the dedicated redis set, - // we have filled in setup(). - redisClient - .srandmember("crocodile_ids") - .then((randomID) => { - const url = `https://test-api.k6.io/public/crocodiles/${randomID}`; - const res = http.get(url); - - check(res, { - "status is 200": (r) => r.status === 200, - "content-type is application/json": (r) => - r.headers["content-type"] === "application/json", - }); - - return url; - }) - .then((url) => redisClient.hincrby("k6_crocodile_fetched", url, 1)); -} +### Sentinel (failover) client -export function teardown() { - redisClient.del("crocodile_ids"); -} +A [Redis Sentinel](https://redis.io/docs/management/sentinel/) provides high availability features, as an alternative to a Redis cluster. -export function handleSummary(data) { - redisClient - .hgetall("k6_crocodile_fetched") - .then((fetched) => Object.assign(data, { k6_crocodile_fetched: fetched })) - .then((data) => - redisClient.set(`k6_report_${Date.now()}`, JSON.stringify(data)) - ) - .then(() => redisClient.del("k6_crocodile_fetched")); - - return { - stdout: textSummary(data, { indent: " ", enableColors: true }), - }; -} +You can connect to a sentinel instance by setting additional options in the object passed to the `Client` constructor: +```javascript +const client = new redis.Client({ + username: 'someusername', + password: 'somepassword', + socket: { + host: 'localhost', + port: 6379, + }, + // Sentinel options + masterName: 'masterhost', + sentinelUsername: 'sentineluser', + sentinelPassword: 'sentinelpass', +}); ``` -Result output: -```shell -$ ./k6 run test.js +### TLS + +A TLS connection can be established in a couple of ways. + +If the server has a certificate signed by a public Certificate Authority, you can use the `rediss` URL scheme: +```javascript +const client = new redis.Client('rediss://example.com'); +``` + +Otherwise, you can supply your own self-signed certificate in PEM format using the `socket.tls` object: +```javascript +const client = new redis.Client({ + socket: { + host: 'localhost', + port: 6379, + tls: { + ca: [open('ca.crt')], + } + }, +}); +``` - /\ |‾‾| /‾‾/ /‾‾/ - /\ / \ | |/ / / / - / \/ \ | ( / ‾‾\ - / \ | |\ \ | (‾) | - / __________ \ |__| \__\ \_____/ .io +Note that for self-signed certificates, [k6's `insecureSkipTLSVerify` option](https://k6.io/docs/using-k6/k6-options/reference/#insecure-skip-tls-verify) must be enabled (set to `true`). - execution: local - script: test.js - output: - - scenarios: (100.00%) 1 scenario, 10 max VUs, 1m30s max duration (incl. graceful stop): - * default: 10 looping VUs for 1m0s (gracefulStop: 30s) +### TLS client authentication (mTLS) +You can also enable mTLS by setting two additional properties in the `socket.tls` object: -running (1m00.1s), 00/10 VUs, 4954 complete and 0 interrupted iterations -default ✓ [======================================] 10 VUs 1m0s +```javascript +const client = new redis.Client({ + socket: { + host: 'localhost', + port: 6379, + tls: { + ca: [open('ca.crt')], + cert: open('client.crt'), // client certificate + key: open('client.key'), // client private key + } + }, +}); ``` + ## API xk6-redis exposes a subset of Redis' [commands](https://redis.io/commands) the core team judged relevant in the context of k6 scripts. @@ -228,7 +232,7 @@ xk6-redis exposes a subset of Redis' [commands](https://redis.io/commands) the c | **LSET** | `lset(key: string, index: number, element: string)` | Sets the list element at `index` to `element`. | On **success**, the promise **resolves** with `"OK"`. If the list does not exist, or the index is out of bounds, the promise is **rejected** with an error. | | **LREM** | `lrem(key: string, count: number, value: string) => Promise` | Removes the first `count` occurrences of `value` from the list stored at `key`. If `count` is positive, elements are removed from the beginning of the list. If `count` is negative, elements are removed from the end of the list. If `count` is zero, all elements matching `value` are removed. | On **success**, the promise **resolves** with the number of removed elements. If the list does not exist, the promise is **rejected** with an error. | | **LLEN** | `llen(key: string) => Promise` | Returns the length of the list stored at `key`. If `key` does not exist, it is interpreted as an empty list and 0 is returned. | On **success**, the promise **resolves** with the length of the list at `key`. If the list does not exist, the promise is **rejected** with an error. | - + ### Hash field operations | Redis Command | Module function signature | Description | Returns | diff --git a/examples/loadtest.js b/examples/loadtest.js index b0c7b22..33562d2 100644 --- a/examples/loadtest.js +++ b/examples/loadtest.js @@ -21,11 +21,11 @@ export const options = { }, }; -// Instantiate a new redis client -const redisClient = new redis.Client({ - addrs: __ENV.REDIS_ADDRS.split(",") || new Array("localhost:6379"), // in the form of "host:port", separated by commas - password: __ENV.REDIS_PASSWORD || "", -}); +// Instantiate a new Redis client using a URL +const redisClient = new redis.Client( + // URL in the form of redis[s]://[[username][:password]@][host][:port][/db-number + __ENV.REDIS_URL || "redis://localhost:6379", +); // Prepare an array of crocodile ids for later use // in the context of the measureUsingRedisData function. diff --git a/examples/tls/README.md b/examples/tls/README.md new file mode 100644 index 0000000..839de32 --- /dev/null +++ b/examples/tls/README.md @@ -0,0 +1,8 @@ +# How to run a k6 test against a Redis test server with TLS + +1. Move in the docker folder `cd docker` +2. Run `sh gen-test-certs.sh` to generate custom TLS certificates that the docker container will use. +3. Run `docker-compose up` to start the Redis server with TLS enabled. +4. Connect to it with `redis-cli --tls --cert ./tests/tls/redis.crt --key ./tests/tls/redis.key --cacert ./tests/tls/ca.crt` and run `AUTH tjkbZ8jrwz3pGiku` to authenticate, and verify that the redis server is properly set up. +5. Build the k6 binary with `xk6 build --with github.com/k6io/xk6-redis=.` +5. Run `./k6 run loadtest-tls.js` to run the k6 load test with TLS enabled. \ No newline at end of file diff --git a/examples/tls/docker/docker-compose.yml b/examples/tls/docker/docker-compose.yml new file mode 100644 index 0000000..6f32ebf --- /dev/null +++ b/examples/tls/docker/docker-compose.yml @@ -0,0 +1,26 @@ +version: "3.3" + +services: + redis: + image: docker.io/bitnami/redis:7.0.8 + user: root + restart: always + environment: + - ALLOW_EMPTY_PASSWORD=false + - REDIS_PASSWORD=tjkbZ8jrwz3pGiku + - REDIS_DISABLE_COMMANDS=FLUSHDB,FLUSHALL + - REDIS_EXTRA_FLAGS=--loglevel verbose --tls-auth-clients optional + - REDIS_TLS_ENABLED=yes + - REDIS_TLS_PORT=6379 + - REDIS_TLS_CERT_FILE=/tls/redis.crt + - REDIS_TLS_KEY_FILE=/tls/redis.key + - REDIS_TLS_CA_FILE=/tls/ca.crt + ports: + - "6379:6379" + volumes: + - redis_data:/bitnami/redis/data + - ./tests/tls:/tls + +volumes: + redis_data: + driver: local diff --git a/examples/tls/docker/gen-test-certs.sh b/examples/tls/docker/gen-test-certs.sh new file mode 100755 index 0000000..b6c7eb8 --- /dev/null +++ b/examples/tls/docker/gen-test-certs.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# Generate some test certificates which are used by the regression test suite: +# +# tests/tls/ca.{crt,key} Self signed CA certificate. +# tests/tls/redis.{crt,key} A certificate with no key usage/policy restrictions. +# tests/tls/client.{crt,key} A certificate restricted for SSL client usage. +# tests/tls/server.{crt,key} A certificate restricted for SSL server usage. +# tests/tls/redis.dh DH Params file. + +generate_cert() { + local name=$1 + local cn="$2" + local opts="$3" + + local keyfile=tests/tls/${name}.key + local certfile=tests/tls/${name}.crt + + [ -f $keyfile ] || openssl genrsa -out $keyfile 2048 + openssl req \ + -new -sha256 \ + -subj "/O=Redis Test/CN=$cn" \ + -key $keyfile | + openssl x509 \ + -req -sha256 \ + -CA tests/tls/ca.crt \ + -CAkey tests/tls/ca.key \ + -CAserial tests/tls/ca.txt \ + -CAcreateserial \ + -days 365 \ + $opts \ + -out $certfile +} + +mkdir -p tests/tls +[ -f tests/tls/ca.key ] || openssl genrsa -out tests/tls/ca.key 4096 +openssl req \ + -x509 -new -nodes -sha256 \ + -key tests/tls/ca.key \ + -days 3650 \ + -subj '/O=Redis Test/CN=Certificate Authority' \ + -out tests/tls/ca.crt + +cat >tests/tls/openssl.cnf <<_END_ +[ server_cert ] +keyUsage = digitalSignature, keyEncipherment +nsCertType = server + +[ client_cert ] +keyUsage = digitalSignature, keyEncipherment +nsCertType = client +_END_ + +generate_cert server "Server-only" "-extfile tests/tls/openssl.cnf -extensions server_cert" +generate_cert client "Client-only" "-extfile tests/tls/openssl.cnf -extensions client_cert" +generate_cert redis "Generic-cert" + +[ -f tests/tls/redis.dh ] || openssl dhparam -out tests/tls/redis.dh 2048 diff --git a/examples/tls/loadtest-tls.js b/examples/tls/loadtest-tls.js new file mode 100644 index 0000000..0221f15 --- /dev/null +++ b/examples/tls/loadtest-tls.js @@ -0,0 +1,42 @@ +import redis from "k6/x/redis"; +import exec from "k6/execution"; + +export const options = { + vus: 10, + iterations: 200, + insecureSkipTLSVerify: true, +}; + +// Instantiate a new Redis client using a URL +// const client = new redis.Client('rediss://localhost:6379') +const client = new redis.Client({ + password: "tjkbZ8jrwz3pGiku", + socket:{ + host: "localhost", + port: 6379, + tls: { + ca: [open('docker/tests/tls/ca.crt')], + cert: open('docker/tests/tls/client.crt'), // client cert + key: open('docker/tests/tls/client.key'), // client private key + } + } +}); + +export default async function () { + // VUs are executed in a parallel fashion, + // thus, to ensure that parallel VUs are not + // modifying the same key at the same time, + // we use keys indexed by the VU id. + const key = `foo-${exec.vu.idInTest}`; + + await client.set(`foo-${exec.vu.idInTest}`, 1) + + let value = await client.get(`foo-${exec.vu.idInTest}`) + value = await client.incrBy(`foo-${exec.vu.idInTest}`, value) + + await client.del(`foo-${exec.vu.idInTest}`) + const exists = await client.exists(`foo-${exec.vu.idInTest}`) + if (exists !== 0) { + throw new Error("foo should have been deleted"); + } +} diff --git a/redis/client.go b/redis/client.go index cb9dad4..ebc1d68 100644 --- a/redis/client.go +++ b/redis/client.go @@ -1,13 +1,17 @@ package redis import ( + "context" + "crypto/tls" "fmt" + "net" "time" "github.com/dop251/goja" "github.com/redis/go-redis/v9" "go.k6.io/k6/js/common" "go.k6.io/k6/js/modules" + "go.k6.io/k6/lib" ) // Client represents the Client constructor (i.e. `new redis.Client()`) and @@ -1077,16 +1081,39 @@ func (c *Client) connect() error { return nil } - // If k6 has a TLSConfig set in its state, use - // it has redis' client TLSConfig too. - if vuState.TLSConfig != nil { - c.redisOptions.TLSConfig = vuState.TLSConfig + tlsCfg := c.redisOptions.TLSConfig + if tlsCfg != nil && vuState.TLSConfig != nil { + // Merge k6 TLS configuration with the one we received from the + // Client constructor. This will need adjusting depending on which + // options we want to expose in the Redis module, and how we want + // the override to work. + tlsCfg.InsecureSkipVerify = vuState.TLSConfig.InsecureSkipVerify + tlsCfg.CipherSuites = vuState.TLSConfig.CipherSuites + tlsCfg.MinVersion = vuState.TLSConfig.MinVersion + tlsCfg.MaxVersion = vuState.TLSConfig.MaxVersion + tlsCfg.Renegotiation = vuState.TLSConfig.Renegotiation + tlsCfg.KeyLogWriter = vuState.TLSConfig.KeyLogWriter + tlsCfg.Certificates = append(tlsCfg.Certificates, vuState.TLSConfig.Certificates...) + + // TODO: Merge vuState.TLSConfig.RootCAs with + // c.redisOptions.TLSConfig. k6 currently doesn't allow setting + // this, so it doesn't matter right now, but these should be merged. + // I couldn't find a way to do this with the x509.CertPool API + // though... + + // In order to preserve the underlying effects of the [netext.Dialer], such + // as handling blocked hostnames, or handling hostname resolution, we override + // the redis client's dialer with our own function which uses the VU's [netext.Dialer] + // and manually upgrades the connection to TLS. + // + // See Pull Request's #17 [discussion] for more details. + // + // [discussion]: https://github.com/grafana/xk6-redis/pull/17#discussion_r1369707388 + c.redisOptions.Dialer = c.upgradeDialerToTLS(vuState.Dialer, tlsCfg) + } else { + c.redisOptions.Dialer = vuState.Dialer.DialContext } - // use k6's lib.DialerContexter function has redis' - // client Dialer - c.redisOptions.Dialer = vuState.Dialer.DialContext - // Replace the internal redis client instance with a new // one using our custom options. c.redisClient = redis.NewUniversalClient(c.redisOptions) @@ -1126,3 +1153,39 @@ func (c *Client) isSupportedType(offset int, args ...interface{}) error { return nil } + +// DialContextFunc is a function that can be used to dial a connection to a redis server. +type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error) + +// upgradeDialerToTLS returns a DialContextFunc that uses the provided dialer to +// establish a connection, and then upgrades it to TLS using the provided config. +// +// We use this function to make sure the k6 [netext.Dialer], our redis module uses to establish +// the connection and handle network-related options such as blocked hostnames, +// or hostname resolution, but we also want to use the TLS configuration provided +// by the user. +func (c *Client) upgradeDialerToTLS(dialer lib.DialContexter, config *tls.Config) DialContextFunc { + return func(ctx context.Context, network string, addr string) (net.Conn, error) { + // Use netext.Dialer to establish the connection + rawConn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + // Upgrade the connection to TLS if needed + tlsConn := tls.Client(rawConn, config) + err = tlsConn.Handshake() + if err != nil { + if closeErr := rawConn.Close(); closeErr != nil { + return nil, fmt.Errorf("failed to close connection after TLS handshake error: %w", closeErr) + } + + return nil, err + } + + // Overwrite rawConn with the TLS connection + rawConn = tlsConn + + return rawConn, nil + } +} diff --git a/redis/client_test.go b/redis/client_test.go index d675c8e..de63fc2 100644 --- a/redis/client_test.go +++ b/redis/client_test.go @@ -19,6 +19,199 @@ import ( "gopkg.in/guregu/null.v3" ) +func TestClientConstructor(t *testing.T) { + t.Parallel() + + testCases := []struct { + name, arg, expErr string + }{ + { + name: "ok/url/tcp", + arg: "'redis://user:pass@localhost:6379/0'", + }, + { + name: "ok/url/tls", + arg: "'rediss://somesecurehost'", + }, + { + name: "ok/object/single", + arg: `{ + clientName: 'myclient', + username: 'user', + password: 'pass', + socket: { + host: 'localhost', + port: 6379, + } + }`, + }, + { + name: "ok/object/single_tls", + arg: `{ + socket: { + host: 'localhost', + port: 6379, + tls: { + ca: ['...'], + } + } + }`, + }, + { + name: "ok/object/cluster_urls", + arg: `{ + cluster: { + maxRedirects: 3, + readOnly: true, + routeByLatency: true, + routeRandomly: true, + nodes: ['redis://host1:6379', 'redis://host2:6379'] + } + }`, + }, + { + name: "ok/object/cluster_objects", + arg: `{ + cluster: { + nodes: [ + { + username: 'user', + password: 'pass', + socket: { + host: 'host1', + port: 6379, + }, + }, + { + username: 'user', + password: 'pass', + socket: { + host: 'host2', + port: 6379, + }, + } + ] + } + }`, + }, + { + name: "ok/object/sentinel", + arg: `{ + username: 'user', + password: 'pass', + socket: { + host: 'localhost', + port: 6379, + }, + masterName: 'masterhost', + sentinelUsername: 'sentineluser', + sentinelPassword: 'sentinelpass', + }`, + }, + { + name: "err/empty", + arg: "", + expErr: "must specify one argument", + }, + { + name: "err/url/missing_scheme", + arg: "'localhost:6379'", + expErr: "invalid URL scheme", + }, + { + name: "err/url/invalid_scheme", + arg: "'https://localhost:6379'", + expErr: "invalid options; reason: redis: invalid URL scheme: https", + }, + { + name: "err/object/unknown_field", + arg: "{addrs: ['localhost:6379']}", + expErr: `invalid options; reason: json: unknown field "addrs"`, + }, + { + name: "err/object/empty_socket", + arg: `{ + username: 'user', + password: 'pass', + }`, + expErr: "invalid options; reason: empty socket options", + }, + { + name: "err/object/cluster_wrong_type", + arg: `{ + cluster: { + nodes: 1, + } + }`, + expErr: `invalid options; reason: cluster nodes property must be an array; got int64`, + }, + { + name: "err/object/cluster_wrong_type_internal", + arg: `{ + cluster: { + nodes: [1, 2], + } + }`, + expErr: `invalid options; reason: cluster nodes array must contain string or object elements; got int64`, + }, + { + name: "err/object/cluster_empty", + arg: `{ + cluster: { + nodes: [] + } + }`, + expErr: `invalid options; reason: cluster nodes property cannot be empty`, + }, + { + name: "err/object/cluster_inconsistent_option", + arg: `{ + cluster: { + nodes: [ + { + username: 'user1', + password: 'pass', + socket: { + host: 'host1', + port: 6379, + }, + }, + { + username: 'user2', + password: 'pass', + socket: { + host: 'host2', + port: 6379, + }, + } + ] + } + }`, + expErr: `invalid options; reason: inconsistent username option: user1 != user2`, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ts := newTestSetup(t) + script := fmt.Sprintf("new Client(%s);", tc.arg) + gotScriptErr := ts.ev.Start(func() error { + _, err := ts.rt.RunString(script) + return err + }) + if tc.expErr != "" { + require.Error(t, gotScriptErr) + assert.Contains(t, gotScriptErr.Error(), tc.expErr) + } else { + assert.NoError(t, gotScriptErr) + } + }) + } +} + func TestClientSet(t *testing.T) { t.Parallel() @@ -43,9 +236,7 @@ func TestClientSet(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.set("existing_key", "new_value") .then(res => { if (res !== "OK") { throw 'unexpected value for set result: ' + res } }) @@ -94,9 +285,7 @@ func TestClientGet(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.get("existing_key") .then(res => { if (res !== "old_value") { throw 'unexpected value for get result: ' + res } }) @@ -140,9 +329,7 @@ func TestClientGetSet(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.getSet("existing_key", "new_value") .then(res => { if (res !== "old_value") { throw 'unexpected value for getSet result: ' + res } }) @@ -183,9 +370,7 @@ func TestClientDel(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.del("key1", "key2", "nonexisting_key") .then(res => { if (res !== 2) { throw 'unexpected value for del result: ' + res } }) @@ -223,9 +408,7 @@ func TestClientGetDel(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.getDel("existing_key") .then(res => { if (res !== "old_value") { throw 'unexpected value for getDel result: ' + res } }) @@ -264,9 +447,7 @@ func TestClientExists(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.exists("existing_key", "nonexisting_key") .then(res => { if (res !== 1) { throw 'unexpected value for exists result: ' + res } }) @@ -306,9 +487,7 @@ func TestClientIncr(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.incr("existing_key") .then(res => { if (res !== 11) { throw 'unexpected value for existing key incr result: ' + res } }) @@ -357,9 +536,7 @@ func TestClientIncrBy(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.incrBy("existing_key", 10) .then(res => { if (res !== 20) { throw 'unexpected value for incrBy result: ' + res } }) @@ -402,9 +579,7 @@ func TestClientDecr(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.decr("existing_key") .then(res => { if (res !== 9) { throw 'unexpected value for decr result: ' + res } }) @@ -453,9 +628,7 @@ func TestClientDecrBy(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.decrBy("existing_key", 2) .then(res => { if (res !== 8) { throw 'unexpected value for decrBy result: ' + res } }) @@ -499,9 +672,7 @@ func TestClientRandomKey(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.randomKey() .then( @@ -540,9 +711,7 @@ func TestClientMget(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.mget("existing_key", "non_existing_key") .then( @@ -586,9 +755,7 @@ func TestClientExpire(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.expire("expires_key", 10) .then(res => { if (res !== true) { throw 'unexpected value for expire result: ' + res } }) @@ -629,9 +796,7 @@ func TestClientTTL(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.ttl("expires_key") .then(res => { if (res !== 10) { throw 'unexpected value for expire result: ' + res } }) @@ -672,9 +837,7 @@ func TestClientPersist(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.persist("expires_key") .then(res => { if (res !== true) { throw 'unexpected value for expire result: ' + res } }) @@ -718,9 +881,7 @@ func TestClientLPush(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.lpush("existing_list", "second", "first") .then(res => { if (res !== 3) { throw 'unexpected value for lpush result: ' + res } }) @@ -764,9 +925,7 @@ func TestClientRPush(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.rpush("existing_list", "second", "third") .then(res => { if (res !== 3) { throw 'unexpected value for rpush result: ' + res } }) @@ -809,9 +968,7 @@ func TestClientLPop(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.lpop("existing_list") .then(res => { if (res !== "first") { throw 'unexpected value for lpop first result: ' + res } }) @@ -862,9 +1019,7 @@ func TestClientRPop(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.rpop("existing_list") .then(res => { if (res !== "second") { throw 'unexpected value for rpop result: ' + res }}) @@ -927,9 +1082,7 @@ func TestClientLRange(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.lrange("existing_list", 0, 0) .then(res => { if (res.length !== 1 || res[0] !== "first") { throw 'unexpected value for lrange result: ' + res }}) @@ -993,9 +1146,7 @@ func TestClientLIndex(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.lindex("existing_list", 0) .then(res => { if (res !== "first") { throw 'unexpected value for lindex result: ' + res } }) @@ -1053,9 +1204,7 @@ func TestClientClientLSet(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.lset("existing_list", 0, "new_first") .then(res => { if (res !== "OK") { throw 'unexpected value for lset result: ' + res }}) @@ -1101,9 +1250,7 @@ func TestClientLrem(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.lrem("existing_list", 1, "first") .then(() => redis.lrem("existing_list", 0, "second")) @@ -1150,9 +1297,7 @@ func TestClientLlen(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.llen("existing_list") .then(res => { if (res !== 3) { throw 'unexpected value for llen result: ' + res } }) @@ -1198,9 +1343,7 @@ func TestClientHSet(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hset("existing_hash", "key", "value") .then(res => { if (res !== 1) { throw 'unexpected value for hset result: ' + res } }) @@ -1254,9 +1397,7 @@ func TestClientHsetnx(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hsetnx("existing_hash", "key", "value") .then(res => { if (res !== true) { throw 'unexpected value for hsetnx result: ' + res } }) @@ -1300,9 +1441,7 @@ func TestClientHget(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hget("existing_hash", "foo") .then(res => { if (res !== "bar") { throw 'unexpected value for hget result: ' + res } }) @@ -1346,9 +1485,7 @@ func TestClientHdel(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hdel("existing_hash", "foo") .then(res => { if (res !== 1) { throw 'unexpected value for hdel result: ' + res } }) @@ -1392,9 +1529,7 @@ func TestClientHgetall(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hgetall("existing_hash") .then(res => { if (typeof res !== "object" || res['foo'] !== 'bar') { throw 'unexpected value for hgetall result: ' + res } }) @@ -1438,9 +1573,7 @@ func TestClientHkeys(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hkeys("existing_hash") .then(res => { if (res.length !== 1 || res[0] !== 'foo') { throw 'unexpected value for hkeys result: ' + res } }) @@ -1484,9 +1617,7 @@ func TestClientHvals(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hvals("existing_hash") .then(res => { if (res.length !== 1 || res[0] !== 'bar') { throw 'unexpected value for hvals result: ' + res } }) @@ -1530,9 +1661,7 @@ func TestClientHlen(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hlen("existing_hash") .then(res => { if (res !== 1) { throw 'unexpected value for hlen result: ' + res } }) @@ -1585,9 +1714,7 @@ func TestClientHincrby(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.hincrby("existing_hash", "foo", 1) .then(res => { if (res !== 2) { throw 'unexpected value for hincrby result: ' + res } }) @@ -1638,9 +1765,7 @@ func TestClientSadd(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.sadd("existing_set", "bar") .then(res => { if (res !== 1) { throw 'unexpected value for sadd result: ' + res } }) @@ -1691,9 +1816,7 @@ func TestClientSrem(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.srem("existing_set", "foo") .then(res => { if (res !== 1) { throw 'unexpected value for srem result: ' + res } }) @@ -1745,9 +1868,7 @@ func TestClientSismember(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.sismember("existing_set", "foo") .then(res => { if (res !== true) { throw 'unexpected value for sismember result: ' + res } }) @@ -1791,9 +1912,7 @@ func TestClientSmembers(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.smembers("existing_set") .then(res => { if (res.length !== 2 || 'foo' in res || 'bar' in res) { throw 'unexpected value for smembers result: ' + res } }) @@ -1834,9 +1953,7 @@ func TestClientSrandmember(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.srandmember("existing_set") .then(res => { if (res !== 'foo' && res !== 'bar') { throw 'unexpected value for srandmember result: ' + res} }) @@ -1880,9 +1997,7 @@ func TestClientSpop(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.spop("existing_set") .then(res => { if (res !== 'foo' && res !== 'bar') { throw 'unexpected value for spop result: ' + res} }) @@ -1928,9 +2043,7 @@ func TestClientSendCommand(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("%s"), - }); + const redis = new Client('redis://%s'); redis.sendCommand("sadd", "existing_set", "foo") .then(res => { if (res !== 1) { throw 'unexpected value for sadd result: ' + res } }) @@ -2134,9 +2247,7 @@ func TestClientCommandsInInitContext(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("unreachable:42424"), - }); + const redis = new Client('redis://unreachable:42424'); %s.then(res => { throw 'expected to fail when called in the init context' }) `, tc.statement)) @@ -2332,9 +2443,7 @@ func TestClientCommandsAgainstUnreachableServer(t *testing.T) { gotScriptErr := ts.ev.Start(func() error { _, err := ts.rt.RunString(fmt.Sprintf(` - const redis = new Client({ - addrs: new Array("unreachable:42424"), - }); + const redis = new Client('redis://unreachable:42424'); %s.then(res => { throw 'expected to fail when server is unreachable' }) `, tc.statement)) @@ -2451,6 +2560,7 @@ type testSetup struct { state *lib.State samples chan metrics.SampleContainer ev *eventloop.EventLoop + tb *httpmultibin.HTTPMultiBin } // newTestSetup initializes a new test setup. @@ -2468,6 +2578,12 @@ func newTestSetup(t testing.TB) testSetup { samples := make(chan metrics.SampleContainer, 1000) + // We use self-signed TLS certificates for some tests, and need to disable + // strict verification. Since we don't use the k6 js.Runner, we can't set + // the k6 option InsecureSkipTLSVerify for this, and must override it in the + // TLS config we use from HTTPMultiBin. + tb.TLSClientConfig.InsecureSkipVerify = true + state := &lib.State{ Group: root, Dialer: tb.Dialer, @@ -2503,6 +2619,7 @@ func newTestSetup(t testing.TB) testSetup { state: state, samples: samples, ev: ev, + tb: tb, } } @@ -2538,3 +2655,131 @@ func newInitContextTestSetup(t testing.TB) testSetup { ev: ev, } } + +func TestClientTLS(t *testing.T) { + t.Parallel() + + ts := newTestSetup(t) + rs := RunTSecure(t, nil) + + err := ts.rt.Set("caCert", string(rs.TLSCertificate())) + require.NoError(t, err) + + gotScriptErr := ts.ev.Start(func() error { + _, err := ts.rt.RunString(fmt.Sprintf(` + const redis = new Client({ + socket: { + host: '%s', + port: %d, + tls: { + ca: [caCert], + } + } + }); + + redis.sendCommand("PING"); + `, rs.Addr().IP.String(), rs.Addr().Port)) + + return err + }) + + require.NoError(t, gotScriptErr) + assert.Equal(t, 1, rs.HandledCommandsCount()) + assert.Equal(t, [][]string{ + {"HELLO", "2"}, + {"PING"}, + }, rs.GotCommands()) +} + +func TestClientTLSAuth(t *testing.T) { + t.Parallel() + + clientCert, clientPKey, err := generateTLSCert() + require.NoError(t, err) + + ts := newTestSetup(t) + rs := RunTSecure(t, clientCert) + + err = ts.rt.Set("caCert", string(rs.TLSCertificate())) + require.NoError(t, err) + err = ts.rt.Set("clientCert", string(clientCert)) + require.NoError(t, err) + err = ts.rt.Set("clientPKey", string(clientPKey)) + require.NoError(t, err) + + gotScriptErr := ts.ev.Start(func() error { + _, err := ts.rt.RunString(fmt.Sprintf(` + const redis = new Client({ + socket: { + host: '%s', + port: %d, + tls: { + ca: [caCert], + cert: clientCert, + key: clientPKey + } + } + }); + + redis.sendCommand("PING"); + `, rs.Addr().IP.String(), rs.Addr().Port)) + + return err + }) + + require.NoError(t, gotScriptErr) + assert.Equal(t, 1, rs.HandledCommandsCount()) + assert.Equal(t, [][]string{ + {"HELLO", "2"}, + {"PING"}, + }, rs.GotCommands()) +} + +func TestClientTLSRespectsNetworkOPtions(t *testing.T) { + t.Parallel() + + clientCert, clientPKey, err := generateTLSCert() + require.NoError(t, err) + + ts := newTestSetup(t) + rs := RunTSecure(t, clientCert) + + err = ts.rt.Set("caCert", string(rs.TLSCertificate())) + require.NoError(t, err) + err = ts.rt.Set("clientCert", string(clientCert)) + require.NoError(t, err) + err = ts.rt.Set("clientPKey", string(clientPKey)) + require.NoError(t, err) + + // Set the redis server's IP to be blacklisted. + net, err := lib.ParseCIDR(rs.Addr().IP.String() + "/32") + require.NoError(t, err) + ts.tb.Dialer.Blacklist = []*lib.IPNet{net} + + gotScriptErr := ts.ev.Start(func() error { + _, err := ts.rt.RunString(fmt.Sprintf(` + const redis = new Client({ + socket: { + host: '%s', + port: %d, + tls: { + ca: [caCert], + cert: clientCert, + key: clientPKey + } + } + }); + + // This operation triggers a connection to the redis + // server under the hood, and should therefore fail, since + // the server's IP is blacklisted by k6. + redis.sendCommand("PING") + `, rs.Addr().IP.String(), rs.Addr().Port)) + + return err + }) + + assert.Error(t, gotScriptErr) + assert.ErrorContains(t, gotScriptErr, "IP ("+rs.Addr().IP.String()+") is in a blacklisted range") + assert.Equal(t, 0, rs.HandledCommandsCount()) +} diff --git a/redis/module.go b/redis/module.go index 7e210fb..7c44ffb 100644 --- a/redis/module.go +++ b/redis/module.go @@ -2,14 +2,9 @@ package redis import ( - "bytes" - "encoding/json" "errors" - "fmt" - "time" "github.com/dop251/goja" - "github.com/redis/go-redis/v9" "go.k6.io/k6/js/common" "go.k6.io/k6/js/modules" ) @@ -59,128 +54,35 @@ func (mi *ModuleInstance) Exports() modules.Exports { // the internal universal client instance will be one of those. // // The type of the underlying client depends on the following conditions: -// 1. If the MasterName option is specified, a sentinel-backed FailoverClient is used. -// 2. if the number of Addrs is two or more, a ClusterClient is used. +// If the first argument is a string, it's parsed as a Redis URL, and a +// single-node Client is used. +// Otherwise, an object is expected, and depending on its properties: +// 1. If the masterName property is defined, a sentinel-backed FailoverClient is used. +// 2. If the cluster property is defined, a ClusterClient is used. // 3. Otherwise, a single-node Client is used. // // To support being instantiated in the init context, while not // producing any IO, as it is the convention in k6, the produced -// Client is initially configured, but in a disconnected state. In -// order to connect to the configured target instance(s), the `.Connect` -// should be called. +// Client is initially configured, but in a disconnected state. +// The connection is automatically established when using any of the Redis +// commands exposed by the Client. func (mi *ModuleInstance) NewClient(call goja.ConstructorCall) *goja.Object { rt := mi.vu.Runtime() - var optionsArg map[string]interface{} - err := rt.ExportTo(call.Arguments[0], &optionsArg) - if err != nil { - common.Throw(rt, errors.New("unable to parse options object")) + if len(call.Arguments) != 1 { + common.Throw(rt, errors.New("must specify one argument")) } - opts, err := newOptionsFrom(optionsArg) + opts, err := readOptions(call.Arguments[0].Export()) if err != nil { - common.Throw(rt, fmt.Errorf("invalid options; reason: %w", err)) - } - - redisOptions := &redis.UniversalOptions{ - Protocol: 2, - Addrs: opts.Addrs, - DB: opts.DB, - Username: opts.Username, - Password: opts.Password, - SentinelUsername: opts.SentinelUsername, - SentinelPassword: opts.SentinelPassword, - MasterName: opts.MasterName, - MaxRetries: opts.MaxRetries, - MinRetryBackoff: time.Duration(opts.MinRetryBackoff) * time.Millisecond, - MaxRetryBackoff: time.Duration(opts.MaxRetryBackoff) * time.Millisecond, - DialTimeout: time.Duration(opts.DialTimeout) * time.Millisecond, - ReadTimeout: time.Duration(opts.ReadTimeout) * time.Millisecond, - WriteTimeout: time.Duration(opts.WriteTimeout) * time.Millisecond, - PoolSize: opts.PoolSize, - MinIdleConns: opts.MinIdleConns, - ConnMaxLifetime: time.Duration(opts.MaxConnAge) * time.Millisecond, - PoolTimeout: time.Duration(opts.PoolTimeout) * time.Millisecond, - ConnMaxIdleTime: time.Duration(opts.IdleTimeout) * time.Millisecond, - MaxRedirects: opts.MaxRedirects, - ReadOnly: opts.ReadOnly, - RouteByLatency: opts.RouteByLatency, - RouteRandomly: opts.RouteRandomly, + common.Throw(rt, err) } client := &Client{ vu: mi.vu, - redisOptions: redisOptions, + redisOptions: opts, redisClient: nil, } return rt.ToValue(client).ToObject(rt) } - -type options struct { - // Either a single address or a seed list of host:port addresses - // of cluster/sentinel nodes. - Addrs []string `json:"addrs,omitempty"` - - // Database to be selected after connecting to the server. - // Only used in single-node and failover modes. - DB int `json:"db,omitempty"` - - // Use the specified Username to authenticate the current connection - // with one of the connections defined in the ACL list when connecting - // to a Redis 6.0 instance, or greater, that is using the Redis ACL system. - Username string `json:"username,omitempty"` - - // Optional password. Must match the password specified in the - // requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), - // or the User Password when connecting to a Redis 6.0 instance, or greater, - // that is using the Redis ACL system. - Password string `json:"password,omitempty"` - - SentinelUsername string `json:"sentinelUsername,omitempty"` - SentinelPassword string `json:"sentinelPassword,omitempty"` - - MasterName string `json:"masterName,omitempty"` - - MaxRetries int `json:"maxRetries,omitempty"` - MinRetryBackoff int64 `json:"minRetryBackoff,omitempty"` - MaxRetryBackoff int64 `json:"maxRetryBackoff,omitempty"` - - DialTimeout int64 `json:"dialTimeout,omitempty"` - ReadTimeout int64 `json:"readTimeout,omitempty"` - WriteTimeout int64 `json:"writeTimeout,omitempty"` - - PoolSize int `json:"poolSize,omitempty"` - MinIdleConns int `json:"minIdleConns,omitempty"` - MaxConnAge int64 `json:"maxConnAge,omitempty"` - PoolTimeout int64 `json:"poolTimeout,omitempty"` - IdleTimeout int64 `json:"idleTimeout,omitempty"` - - MaxRedirects int `json:"maxRedirects,omitempty"` - ReadOnly bool `json:"readOnly,omitempty"` - RouteByLatency bool `json:"routeByLatency,omitempty"` - RouteRandomly bool `json:"routeRandomly,omitempty"` -} - -// newOptionsFrom validates and instantiates an options struct from its map representation -// as obtained by calling a Goja's Runtime.ExportTo. -func newOptionsFrom(argument map[string]interface{}) (*options, error) { - jsonStr, err := json.Marshal(argument) - if err != nil { - return nil, fmt.Errorf("unable to serialize options to JSON %w", err) - } - - // Instantiate a JSON decoder which will error on unknown - // fields. As a result, if the input map contains an unknown - // option, this function will produce an error. - decoder := json.NewDecoder(bytes.NewReader(jsonStr)) - decoder.DisallowUnknownFields() - - var opts options - err = decoder.Decode(&opts) - if err != nil { - return nil, fmt.Errorf("unable to decode options %w", err) - } - - return &opts, nil -} diff --git a/redis/options.go b/redis/options.go new file mode 100644 index 0000000..983b5f7 --- /dev/null +++ b/redis/options.go @@ -0,0 +1,362 @@ +package redis + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +type singleNodeOptions struct { + Socket *socketOptions `json:"socket,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + ClientName string `json:"clientName,omitempty"` + Database int `json:"database,omitempty"` + MaxRetries int `json:"maxRetries,omitempty"` + MinRetryBackoff int64 `json:"minRetryBackoff,omitempty"` + MaxRetryBackoff int64 `json:"maxRetryBackoff,omitempty"` +} + +func (opts singleNodeOptions) toRedisOptions() (*redis.Options, error) { + ropts := &redis.Options{} + sopts := opts.Socket + if err := setSocketOptions(ropts, sopts); err != nil { + return nil, err + } + + ropts.DB = opts.Database + ropts.Username = opts.Username + ropts.Password = opts.Password + ropts.MaxRetries = opts.MaxRetries + ropts.MinRetryBackoff = time.Duration(opts.MinRetryBackoff) + ropts.MaxRetryBackoff = time.Duration(opts.MaxRetryBackoff) + + return ropts, nil +} + +type socketOptions struct { + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + TLS *tlsOptions `json:"tls,omitempty"` + DialTimeout int64 `json:"dialTimeout,omitempty"` + ReadTimeout int64 `json:"readTimeout,omitempty"` + WriteTimeout int64 `json:"writeTimeout,omitempty"` + PoolSize int `json:"poolSize,omitempty"` + MinIdleConns int `json:"minIdleConns,omitempty"` + MaxConnAge int64 `json:"maxConnAge,omitempty"` + PoolTimeout int64 `json:"poolTimeout,omitempty"` + IdleTimeout int64 `json:"idleTimeout,omitempty"` + IdleCheckFrequency int64 `json:"idleCheckFrequency,omitempty"` +} + +type tlsOptions struct { + // TODO: Handle binary data (ArrayBuffer) for all these as well. + CA []string `json:"ca,omitempty"` + Cert string `json:"cert,omitempty"` + Key string `json:"key,omitempty"` +} + +type commonClusterOptions struct { + MaxRedirects int `json:"maxRedirects,omitempty"` + ReadOnly bool `json:"readOnly,omitempty"` + RouteByLatency bool `json:"routeByLatency,omitempty"` + RouteRandomly bool `json:"routeRandomly,omitempty"` +} + +type clusterNodesMapOptions struct { + commonClusterOptions + Nodes []*singleNodeOptions `json:"nodes,omitempty"` +} + +type clusterNodesStringOptions struct { + commonClusterOptions + Nodes []string `json:"nodes,omitempty"` +} + +type sentinelOptions struct { + singleNodeOptions + MasterName string `json:"masterName,omitempty"` + SentinelUsername string `json:"sentinelUsername,omitempty"` + SentinelPassword string `json:"sentinelPassword,omitempty"` +} + +// newOptionsFromObject validates and instantiates an options struct from its +// map representation as exported from goja.Runtime. +func newOptionsFromObject(obj map[string]interface{}) (*redis.UniversalOptions, error) { + var options interface{} + if cluster, ok := obj["cluster"].(map[string]interface{}); ok { + obj = cluster + nodes, ok := cluster["nodes"].([]interface{}) + if !ok { + return nil, fmt.Errorf("cluster nodes property must be an array; got %T", cluster["nodes"]) + } + if len(nodes) == 0 { + return nil, errors.New("cluster nodes property cannot be empty") + } + switch nodes[0].(type) { + case map[string]interface{}: + options = &clusterNodesMapOptions{} + case string: + options = &clusterNodesStringOptions{} + default: + return nil, fmt.Errorf("cluster nodes array must contain string or object elements; got %T", nodes[0]) + } + } else if _, ok := obj["masterName"]; ok { + options = &sentinelOptions{} + } else { + options = &singleNodeOptions{} + } + + jsonStr, err := json.Marshal(obj) + if err != nil { + return nil, fmt.Errorf("unable to serialize options to JSON %w", err) + } + + // Instantiate a JSON decoder which will error on unknown + // fields. As a result, if the input map contains an unknown + // option, this function will produce an error. + decoder := json.NewDecoder(bytes.NewReader(jsonStr)) + decoder.DisallowUnknownFields() + + err = decoder.Decode(&options) + if err != nil { + return nil, err + } + + return toUniversalOptions(options) +} + +// newOptionsFromString parses the expected URL into redis.UniversalOptions. +func newOptionsFromString(url string) (*redis.UniversalOptions, error) { + opts, err := redis.ParseURL(url) + if err != nil { + return nil, err + } + + return toUniversalOptions(opts) +} + +func readOptions(options interface{}) (*redis.UniversalOptions, error) { + var ( + opts *redis.UniversalOptions + err error + ) + switch val := options.(type) { + case string: + opts, err = newOptionsFromString(val) + case map[string]interface{}: + opts, err = newOptionsFromObject(val) + default: + return nil, fmt.Errorf("invalid options type: %T; expected string or object", val) + } + + if err != nil { + return nil, fmt.Errorf("invalid options; reason: %w", err) + } + + return opts, nil +} + +func toUniversalOptions(options interface{}) (*redis.UniversalOptions, error) { + uopts := &redis.UniversalOptions{Protocol: 2} + + switch o := options.(type) { + case *clusterNodesMapOptions: + setClusterOptions(uopts, &o.commonClusterOptions) + + for _, n := range o.Nodes { + ropts, err := n.toRedisOptions() + if err != nil { + return nil, err + } + if err = setConsistentOptions(uopts, ropts); err != nil { + return nil, err + } + } + case *clusterNodesStringOptions: + setClusterOptions(uopts, &o.commonClusterOptions) + + for _, n := range o.Nodes { + ropts, err := redis.ParseURL(n) + if err != nil { + return nil, err + } + if err = setConsistentOptions(uopts, ropts); err != nil { + return nil, err + } + } + case *sentinelOptions: + uopts.MasterName = o.MasterName + uopts.SentinelUsername = o.SentinelUsername + uopts.SentinelPassword = o.SentinelPassword + + ropts, err := o.toRedisOptions() + if err != nil { + return nil, err + } + if err = setConsistentOptions(uopts, ropts); err != nil { + return nil, err + } + case *singleNodeOptions: + ropts, err := o.toRedisOptions() + if err != nil { + return nil, err + } + if err = setConsistentOptions(uopts, ropts); err != nil { + return nil, err + } + case *redis.Options: + if err := setConsistentOptions(uopts, o); err != nil { + return nil, err + } + default: + panic(fmt.Sprintf("unexpected options type %T", options)) + } + + return uopts, nil +} + +// Set UniversalOptions values from single-node options, ensuring that any +// previously set values are consistent with the new values. This validates that +// multiple node options set when using cluster mode are consistent with each other. +// TODO: Break apart, simplify? +//nolint: gocognit,cyclop,funlen,gofmt,gofumpt,goimports +func setConsistentOptions(uopts *redis.UniversalOptions, opts *redis.Options) error { + uopts.Addrs = append(uopts.Addrs, opts.Addr) + + // Only set the TLS config once. Note that this assumes the same config is + // used in other single-node options, since doing the consistency check we + // use for the other options would be tedious. + if uopts.TLSConfig == nil && opts.TLSConfig != nil { + uopts.TLSConfig = opts.TLSConfig + } + + if uopts.DB != 0 && opts.DB != 0 && uopts.DB != opts.DB { + return fmt.Errorf("inconsistent db option: %d != %d", uopts.DB, opts.DB) + } + uopts.DB = opts.DB + + if uopts.Username != "" && opts.Username != "" && uopts.Username != opts.Username { + return fmt.Errorf("inconsistent username option: %s != %s", uopts.Username, opts.Username) + } + uopts.Username = opts.Username + + if uopts.Password != "" && opts.Password != "" && uopts.Password != opts.Password { + return fmt.Errorf("inconsistent password option") + } + uopts.Password = opts.Password + + if uopts.ClientName != "" && opts.ClientName != "" && uopts.ClientName != opts.ClientName { + return fmt.Errorf("inconsistent clientName option: %s != %s", uopts.ClientName, opts.ClientName) + } + uopts.ClientName = opts.ClientName + + if uopts.MaxRetries != 0 && opts.MaxRetries != 0 && uopts.MaxRetries != opts.MaxRetries { + return fmt.Errorf("inconsistent maxRetries option: %d != %d", uopts.MaxRetries, opts.MaxRetries) + } + uopts.MaxRetries = opts.MaxRetries + + if uopts.MinRetryBackoff != 0 && opts.MinRetryBackoff != 0 && uopts.MinRetryBackoff != opts.MinRetryBackoff { + return fmt.Errorf("inconsistent minRetryBackoff option: %d != %d", uopts.MinRetryBackoff, opts.MinRetryBackoff) + } + uopts.MinRetryBackoff = opts.MinRetryBackoff + + if uopts.MaxRetryBackoff != 0 && opts.MaxRetryBackoff != 0 && uopts.MaxRetryBackoff != opts.MaxRetryBackoff { + return fmt.Errorf("inconsistent maxRetryBackoff option: %d != %d", uopts.MaxRetryBackoff, opts.MaxRetryBackoff) + } + uopts.MaxRetryBackoff = opts.MaxRetryBackoff + + if uopts.DialTimeout != 0 && opts.DialTimeout != 0 && uopts.DialTimeout != opts.DialTimeout { + return fmt.Errorf("inconsistent dialTimeout option: %d != %d", uopts.DialTimeout, opts.DialTimeout) + } + uopts.DialTimeout = opts.DialTimeout + + if uopts.ReadTimeout != 0 && opts.ReadTimeout != 0 && uopts.ReadTimeout != opts.ReadTimeout { + return fmt.Errorf("inconsistent readTimeout option: %d != %d", uopts.ReadTimeout, opts.ReadTimeout) + } + uopts.ReadTimeout = opts.ReadTimeout + + if uopts.WriteTimeout != 0 && opts.WriteTimeout != 0 && uopts.WriteTimeout != opts.WriteTimeout { + return fmt.Errorf("inconsistent writeTimeout option: %d != %d", uopts.WriteTimeout, opts.WriteTimeout) + } + uopts.WriteTimeout = opts.WriteTimeout + + if uopts.PoolSize != 0 && opts.PoolSize != 0 && uopts.PoolSize != opts.PoolSize { + return fmt.Errorf("inconsistent poolSize option: %d != %d", uopts.PoolSize, opts.PoolSize) + } + uopts.PoolSize = opts.PoolSize + + if uopts.MinIdleConns != 0 && opts.MinIdleConns != 0 && uopts.MinIdleConns != opts.MinIdleConns { + return fmt.Errorf("inconsistent minIdleConns option: %d != %d", uopts.MinIdleConns, opts.MinIdleConns) + } + uopts.MinIdleConns = opts.MinIdleConns + + if uopts.ConnMaxLifetime != 0 && opts.ConnMaxLifetime != 0 && uopts.ConnMaxLifetime != opts.ConnMaxLifetime { + return fmt.Errorf("inconsistent maxConnAge option: %d != %d", uopts.ConnMaxLifetime, opts.ConnMaxLifetime) + } + uopts.ConnMaxLifetime = opts.ConnMaxLifetime + + if uopts.PoolTimeout != 0 && opts.PoolTimeout != 0 && uopts.PoolTimeout != opts.PoolTimeout { + return fmt.Errorf("inconsistent poolTimeout option: %d != %d", uopts.PoolTimeout, opts.PoolTimeout) + } + uopts.PoolTimeout = opts.PoolTimeout + + if uopts.ConnMaxIdleTime != 0 && opts.ConnMaxIdleTime != 0 && uopts.ConnMaxIdleTime != opts.ConnMaxIdleTime { + return fmt.Errorf("inconsistent idleTimeout option: %d != %d", uopts.ConnMaxIdleTime, opts.ConnMaxIdleTime) + } + uopts.ConnMaxIdleTime = opts.ConnMaxIdleTime + + return nil +} + +func setClusterOptions(uopts *redis.UniversalOptions, opts *commonClusterOptions) { + uopts.MaxRedirects = opts.MaxRedirects + uopts.ReadOnly = opts.ReadOnly + uopts.RouteByLatency = opts.RouteByLatency + uopts.RouteRandomly = opts.RouteRandomly +} + +func setSocketOptions(opts *redis.Options, sopts *socketOptions) error { + if sopts == nil { + return fmt.Errorf("empty socket options") + } + opts.Addr = fmt.Sprintf("%s:%d", sopts.Host, sopts.Port) + opts.DialTimeout = time.Duration(sopts.DialTimeout) * time.Millisecond + opts.ReadTimeout = time.Duration(sopts.ReadTimeout) * time.Millisecond + opts.WriteTimeout = time.Duration(sopts.WriteTimeout) * time.Millisecond + opts.PoolSize = sopts.PoolSize + opts.MinIdleConns = sopts.MinIdleConns + opts.ConnMaxLifetime = time.Duration(sopts.MaxConnAge) * time.Millisecond + opts.PoolTimeout = time.Duration(sopts.PoolTimeout) * time.Millisecond + opts.ConnMaxIdleTime = time.Duration(sopts.IdleTimeout) * time.Millisecond + + if sopts.TLS != nil { + //nolint: gosec // ignore G402: TLS MinVersion too low + tlsCfg := &tls.Config{} + if len(sopts.TLS.CA) > 0 { + caCertPool := x509.NewCertPool() + for _, cert := range sopts.TLS.CA { + caCertPool.AppendCertsFromPEM([]byte(cert)) + } + tlsCfg.RootCAs = caCertPool + } + + if sopts.TLS.Cert != "" && sopts.TLS.Key != "" { + clientCertPair, err := tls.X509KeyPair([]byte(sopts.TLS.Cert), []byte(sopts.TLS.Key)) + if err != nil { + return err + } + tlsCfg.Certificates = []tls.Certificate{clientCertPair} + } + + opts.TLSConfig = tlsCfg + } + + return nil +} diff --git a/redis/stub_test.go b/redis/stub_test.go index 8fa787b..1788be3 100644 --- a/redis/stub_test.go +++ b/redis/stub_test.go @@ -2,6 +2,8 @@ package redis import ( "bufio" + "crypto/tls" + "crypto/x509" "errors" "fmt" "net" @@ -12,11 +14,26 @@ import ( "unicode" ) -// RunT starts a new redis stub for a given test context. +// RunT starts a new redis stub TCP server for a given test context. // It registers the test cleanup after your test is done. func RunT(t testing.TB) *StubServer { s := NewStubServer() - if err := s.Start(); err != nil { + if err := s.Start(false, nil); err != nil { + t.Fatalf("could not start RedisStub; reason: %s", err) + } + + t.Cleanup(s.Close) + + return s +} + +// RunTSecure starts a new redis stub TLS server for a given test context. +// It registers the test cleanup after your test is done. +// Optionally, a client certificate in PEM format can be passed to enable TLS +// client authentication (mTLS). +func RunTSecure(t testing.TB, clientCert []byte) *StubServer { + s := NewStubServer() + if err := s.Start(true, clientCert); err != nil { t.Fatalf("could not start RedisStub; reason: %s", err) } @@ -50,6 +67,8 @@ type StubServer struct { handlers map[string]func(*Connection, []string) processedCommands int commandsHistory [][]string + + tlsCert []byte } // NewStubServer instantiates a new RedisStub server. @@ -62,19 +81,54 @@ func NewStubServer() *StubServer { } } -// Start starts the RedisStub server. -func (rs *StubServer) Start() error { - listener, err := net.Listen("tcp", net.JoinHostPort("localhost", "0")) - if err != nil { - return err +// Start the RedisStub server. If secure is true, a TLS server with a +// self-signed certificate will be started. Otherwise, an unencrypted TCP +// server will start. +// Optionally, a client certificate in PEM format can be passed to enable TLS +// client authentication (mTLS). +func (rs *StubServer) Start(secure bool, clientCert []byte) error { + var ( + addr = net.JoinHostPort("localhost", "0") + listener net.Listener + ) + if secure { //nolint: nestif + // TODO: Generate the cert only once per test run and reuse it, instead + // of once per StubServer start? + cert, pkey, err := generateTLSCert() + if err != nil { + return err + } + rs.tlsCert = cert + certPair, err := tls.X509KeyPair(cert, pkey) + if err != nil { + return err + } + config := &tls.Config{ + MinVersion: tls.VersionTLS13, + PreferServerCipherSuites: true, + Certificates: []tls.Certificate{certPair}, + } + if clientCert != nil { + clientCertPool := x509.NewCertPool() + clientCertPool.AppendCertsFromPEM(clientCert) + config.ClientCAs = clientCertPool + config.ClientAuth = tls.RequireAndVerifyClientCert + } + if listener, err = tls.Listen("tcp", addr, config); err != nil { + return err + } + } else { + var err error + if listener, err = net.Listen("tcp", addr); err != nil { + return err + } } rs.listener = listener // the provided addr string binds to port zero, // which leads to automatic port selection by the OS. - // To catter for this, we parse the listener address - // to get the actual port, the OS bound us to. + // We need to get the actual port the OS bound us to. boundAddr, ok := listener.Addr().(*net.TCPAddr) if !ok { return errors.New("could not get TCP address") @@ -140,8 +194,8 @@ func (rs *StubServer) RegisterCommandHandler(command string, handler func(*Conne } // Addr returns the address of the RedisStub server. -func (rs *StubServer) Addr() string { - return rs.boundAddr.String() +func (rs *StubServer) Addr() *net.TCPAddr { + return rs.boundAddr } // HandledCommandsCount returns the total number of commands @@ -168,6 +222,11 @@ func (rs *StubServer) GotCommands() [][]string { return rs.commandsHistory } +// TLSCertificate returns the TLS certificate used by the server. +func (rs *StubServer) TLSCertificate() []byte { + return rs.tlsCert +} + // listenAndServe listens on the redis server's listener, // and handles client connections. func (rs *StubServer) listenAndServe(l net.Listener) { diff --git a/redis/tls_test.go b/redis/tls_test.go new file mode 100644 index 0000000..d3d1a2c --- /dev/null +++ b/redis/tls_test.go @@ -0,0 +1,64 @@ +package redis + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math/big" + "time" +) + +// Generate self-signed TLS certificate and private key for testing purposes, +// and return them as PEM encoded data. +// Source: https://eli.thegreenplace.net/2021/go-https-servers-with-tls/ +func generateTLSCert() (certPEM, privateKeyPEM []byte, err error) { + pkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Grafana Labs"}, + }, + DNSNames: []string{"localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(3 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &pkey.PublicKey, pkey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + if certPEM == nil { + return nil, nil, errors.New("failed to encode certificate to PEM") + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(pkey) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal private key: %w", err) + } + privateKeyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) + if privateKeyPEM == nil { + return nil, nil, errors.New("failed to encode private key to PEM") + } + + return certPEM, privateKeyPEM, nil +}