diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 000000000..0ea16dee3 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,74 @@ +name: ci +on: + push: + branches: + - main + pull_request: + release: + types: [published] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: 'stable' + + - name: Install deps + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + go get -t ./... + go install honnef.co/go/tools/cmd/staticcheck@latest + go install github.com/client9/misspell/cmd/misspell@latest + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + + - name: Run linters + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + $(exit $(go fmt -modfile=go_test.mod ./... | wc -l)) + go vet -modfile=go_test.mod ./... + GOFLAGS="-mod=mod -modfile=go_test.mod" staticcheck ./... + find . -type f -name "*.go" | xargs misspell -error -locale US + golangci-lint run --timeout 5m0s ./jetstream/... + + test: + runs-on: ubuntu-latest + + strategy: + matrix: + go: [ "1.22", "1.23" ] + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go }} + + - name: Install deps + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + go install github.com/mattn/goveralls@latest + go install github.com/wadey/gocovmerge@latest + + - name: Test and coverage + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast -vet=off + if [ "${{ matrix.go }}" = "1.23" ]; then + ./scripts/cov.sh CI + else + go test -modfile=go_test.mod -race -v -p=1 ./... --failfast -vet=off -tags=internal_testing + fi + + - name: Coveralls + if: matrix.go == '1.23' + uses: coverallsapp/github-action@v2 + with: + file: acc.out \ No newline at end of file diff --git a/.github/workflows/dependencies.yaml b/.github/workflows/dependencies.yaml new file mode 100644 index 000000000..5d9a7dcb2 --- /dev/null +++ b/.github/workflows/dependencies.yaml @@ -0,0 +1,61 @@ +name: License Check + +on: + push: + paths: + - 'go.mod' + branches: + - main + +jobs: + license-check: + runs-on: ubuntu-latest + + env: + BRANCH_NAME: update-report-branch-${{ github.run_id }} + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 # Fetch all history for all branches and tags + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.22' + + - name: Install go-licenses + run: go install github.com/google/go-licenses@latest + # We need this step because of test dependencies and how they are handled in nats.go + - name: Run go mod tidy + run: go mod tidy + - name: Run license check + run: go-licenses report ./... 
--template dependencies.tpl > dependencies.md + + - name: Configure git + run: | + git config user.name 'github-actions[bot]' + git config user.email 'github-actions[bot]@users.noreply.github.com' + + - name: Check for changes + id: git_diff + run: | + git fetch + git diff --exit-code dependencies.md || echo "has_changes=true" >> $GITHUB_ENV + + - name: Commit changes + if: env.has_changes == 'true' + run: | + git checkout -b "$BRANCH_NAME" + git add dependencies.md + git commit -m "Update dependencies.md" + git push -u origin "$BRANCH_NAME" + + - name: Create Pull Request + if: env.has_changes == 'true' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh pr create --title "Update dependencies.md" --body "This PR updates the dependencies report" --head "$BRANCH_NAME" --base main + diff --git a/.github/workflows/latest-server.yaml b/.github/workflows/latest-server.yaml new file mode 100644 index 000000000..c44523b1a --- /dev/null +++ b/.github/workflows/latest-server.yaml @@ -0,0 +1,27 @@ +name: Test nats-server@main +on: + schedule: + - cron: "30 8 * * *" + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: 'stable' + + - name: Get latest server + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + go get -modfile go_test.mod github.com/nats-io/nats-server/v2@main + + - name: Test + shell: bash --noprofile --norc -x -eo pipefail {0} + run: | + go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast -vet=off + go test -modfile=go_test.mod -race -v -p=1 ./... --failfast -vet=off -tags=internal_testing \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 9a6b4a89c..000000000 --- a/.travis.yml +++ /dev/null @@ -1,36 +0,0 @@ -language: go -go: -- "1.22.x" -- "1.21.x" -go_import_path: github.com/nats-io/nats.go -install: -- go get -t ./... -- curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin -- if [[ "$TRAVIS_GO_VERSION" =~ 1.22 ]]; then - go install github.com/mattn/goveralls@latest; - go install github.com/wadey/gocovmerge@latest; - go install honnef.co/go/tools/cmd/staticcheck@latest; - go install github.com/client9/misspell/cmd/misspell@latest; - fi -before_script: -- $(exit $(go fmt ./... | wc -l)) -- go vet -modfile=go_test.mod ./... -- if [[ "$TRAVIS_GO_VERSION" =~ 1.22 ]]; then - find . -type f -name "*.go" | xargs misspell -error -locale US; - GOFLAGS="-mod=mod -modfile=go_test.mod" staticcheck ./...; - fi -- golangci-lint run ./jetstream/... -script: -- go test -modfile=go_test.mod -v -run=TestNoRace -p=1 ./... --failfast -vet=off -- if [[ "$TRAVIS_GO_VERSION" =~ 1.22 ]]; then ./scripts/cov.sh TRAVIS; else go test -modfile=go_test.mod -race -v -p=1 ./... 
--failfast -vet=off -tags=internal_testing; fi -after_success: -- if [[ "$TRAVIS_GO_VERSION" =~ 1.22 ]]; then $HOME/gopath/bin/goveralls -coverprofile=acc.out -service travis-ci; fi - -jobs: - include: - - name: "Go: 1.22.x (nats-server@main)" - go: "1.22.x" - before_script: - - go get -modfile go_test.mod github.com/nats-io/nats-server/v2@main - allow_failures: - - name: "Go: 1.22.x (nats-server@main)" diff --git a/README.md b/README.md index 237c03b59..7980eccbb 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ A [Go](http://golang.org) client for the [NATS messaging system](https://nats.io [License-Image]: https://img.shields.io/badge/License-Apache2-blue.svg [ReportCard-Url]: https://goreportcard.com/report/github.com/nats-io/nats.go [ReportCard-Image]: https://goreportcard.com/badge/github.com/nats-io/nats.go -[Build-Status-Url]: https://travis-ci.com/github/nats-io/nats.go -[Build-Status-Image]: https://travis-ci.com/nats-io/nats.go.svg?branch=main +[Build-Status-Url]: https://github.com/nats-io/nats.go/actions +[Build-Status-Image]: https://github.com/nats-io/nats.go/actions/workflows/ci.yaml/badge.svg?branch=main [GoDoc-Url]: https://pkg.go.dev/github.com/nats-io/nats.go [GoDoc-Image]: https://img.shields.io/badge/GoDoc-reference-007d9c [Coverage-Url]: https://coveralls.io/r/nats-io/nats.go?branch=main @@ -19,25 +19,14 @@ A [Go](http://golang.org) client for the [NATS messaging system](https://nats.io ## Installation ```bash -# Go client -go get github.com/nats-io/nats.go/ +# To get the latest released Go client: +go get github.com/nats-io/nats.go@latest -# Server -go get github.com/nats-io/nats-server -``` - -When using or transitioning to Go modules support: - -```bash -# Go client latest or explicit version -go get github.com/nats-io/nats.go/@latest -go get github.com/nats-io/nats.go/@v1.33.0 +# To get a specific version: +go get github.com/nats-io/nats.go@v1.37.0 -# For latest NATS Server, add /v2 at the end -go get github.com/nats-io/nats-server/v2 - -# NATS Server v1 is installed otherwise -# go get github.com/nats-io/nats-server +# Note that the latest major version for NATS Server is v2: +go get github.com/nats-io/nats-server/v2@latest ``` ## Basic Usage @@ -93,11 +82,13 @@ nc.Close() ``` ## JetStream +[![JetStream API Reference](https://pkg.go.dev/badge/github.com/nats-io/nats.go/jetstream.svg)](https://pkg.go.dev/github.com/nats-io/nats.go/jetstream) JetStream is the built-in NATS persistence system. `nats.go` provides a built-in API enabling both managing JetStream assets as well as publishing/consuming persistent messages. + ### Basic usage ```go @@ -134,60 +125,6 @@ To find more information on `nats.go` JetStream API, visit The service API (`micro`) allows you to [easily build NATS services](micro/README.md) The services API is currently in beta release. 
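
For code that used the deprecated `EncodedConn` JSON publish/subscribe shown in the hunk below, a minimal migration sketch (not part of this patch; it assumes the same `person` struct and `hello` subject as the removed example) is to marshal and unmarshal explicitly with `encoding/json` over a bare `*nats.Conn`:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/nats-io/nats.go"
)

type person struct {
	Name    string
	Address string
	Age     int
}

func main() {
	nc, _ := nats.Connect(nats.DefaultURL)
	defer nc.Close()

	done := make(chan struct{})

	// Subscriber: decode the raw payload explicitly instead of relying on
	// an EncodedConn to do it.
	nc.Subscribe("hello", func(m *nats.Msg) {
		var p person
		if err := json.Unmarshal(m.Data, &p); err == nil {
			fmt.Printf("Received a person: %+v\n", p)
		}
		close(done)
	})

	// Publisher: encode the payload explicitly.
	me := &person{Name: "derek", Age: 22, Address: "140 New Montgomery Street, San Francisco, CA"}
	data, _ := json.Marshal(me)
	nc.Publish("hello", data)

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		fmt.Println("no message received")
	}
}
```
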
-## Encoded Connections - -```go - -nc, _ := nats.Connect(nats.DefaultURL) -c, _ := nats.NewEncodedConn(nc, nats.JSON_ENCODER) -defer c.Close() - -// Simple Publisher -c.Publish("foo", "Hello World") - -// Simple Async Subscriber -c.Subscribe("foo", func(s string) { - fmt.Printf("Received a message: %s\n", s) -}) - -// EncodedConn can Publish any raw Go type using the registered Encoder -type person struct { - Name string - Address string - Age int -} - -// Go type Subscriber -c.Subscribe("hello", func(p *person) { - fmt.Printf("Received a person: %+v\n", p) -}) - -me := &person{Name: "derek", Age: 22, Address: "140 New Montgomery Street, San Francisco, CA"} - -// Go type Publisher -c.Publish("hello", me) - -// Unsubscribe -sub, err := c.Subscribe("foo", nil) -// ... -sub.Unsubscribe() - -// Requests -var response string -err = c.Request("help", "help me", &response, 10*time.Millisecond) -if err != nil { - fmt.Printf("Request failed: %v\n", err) -} - -// Replying -c.Subscribe("help", func(subj, reply string, msg string) { - c.Publish(reply, "I can help!") -}) - -// Close connection -c.Close(); -``` - ## New Authentication (Nkeys and User Credentials) This requires server with version >= 2.0.0 @@ -267,34 +204,6 @@ if err != nil { ``` -## Using Go Channels (netchan) - -```go -nc, _ := nats.Connect(nats.DefaultURL) -ec, _ := nats.NewEncodedConn(nc, nats.JSON_ENCODER) -defer ec.Close() - -type person struct { - Name string - Address string - Age int -} - -recvCh := make(chan *person) -ec.BindRecvChan("hello", recvCh) - -sendCh := make(chan *person) -ec.BindSendChan("hello", sendCh) - -me := &person{Name: "derek", Age: 22, Address: "140 New Montgomery Street"} - -// Send via Go channels -sendCh <- me - -// Receive via Go channels -who := <- recvCh -``` - ## Wildcard Subscriptions ```go @@ -461,19 +370,21 @@ msg, err := nc.RequestWithContext(ctx, "foo", []byte("bar")) sub, err := nc.SubscribeSync("foo") msg, err := sub.NextMsgWithContext(ctx) -// Encoded Request with context -c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) -type request struct { - Message string `json:"message"` -} -type response struct { - Code int `json:"code"` -} -req := &request{Message: "Hello"} -resp := &response{} -err := c.RequestWithContext(ctx, "foo", req, resp) ``` +## Backwards compatibility + +In the development of nats.go, we are committed to maintaining backward compatibility and ensuring a stable and reliable experience for all users. In general, we follow the standard go compatibility guidelines. +However, it's important to clarify our stance on certain types of changes: + +- **Expanding structures:** +Adding new fields to structs is not considered a breaking change. + +- **Adding methods to exported interfaces:** +Extending public interfaces with new methods is also not viewed as a breaking change within the context of this project. It is important to note that no unexported methods will be added to interfaces allowing users to implement them. + +Additionally, this library always supports at least 2 latest minor Go versions. For example, if the latest Go version is 1.22, the library will support Go 1.21 and 1.22. 
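
For the Go channels (netchan) example removed earlier in this README, a possible replacement sketch (not part of this patch) pairs `Conn.ChanSubscribe`, which remains supported, with explicit `encoding/json` marshaling; the `person` struct and `hello` subject mirror the removed example:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/nats-io/nats.go"
)

type person struct {
	Name    string
	Address string
	Age     int
}

func main() {
	nc, _ := nats.Connect(nats.DefaultURL)
	defer nc.Close()

	// Receive side: ChanSubscribe delivers raw *nats.Msg values on a channel,
	// taking the place of BindRecvChan.
	msgCh := make(chan *nats.Msg, 64)
	sub, _ := nc.ChanSubscribe("hello", msgCh)
	defer sub.Unsubscribe()

	// Send side: marshal explicitly, taking the place of BindSendChan.
	me := &person{Name: "derek", Age: 22, Address: "140 New Montgomery Street"}
	data, _ := json.Marshal(me)
	nc.Publish("hello", data)

	select {
	case m := <-msgCh:
		var who person
		if err := json.Unmarshal(m.Data, &who); err == nil {
			fmt.Printf("%v says hello!\n", who)
		}
	case <-time.After(2 * time.Second):
		fmt.Println("no message received")
	}
}
```
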
+ ## License Unless otherwise noted, the NATS source files are distributed diff --git a/bench/bench.go b/bench/bench.go index cb724737f..4ea4b3056 100644 --- a/bench/bench.go +++ b/bench/bench.go @@ -36,7 +36,7 @@ type Sample struct { End time.Time } -// SampleGroup for a number of samples, the group is a Sample itself agregating the values the Samples +// SampleGroup for a number of samples, the group is a Sample itself aggregating the values the Samples type SampleGroup struct { Sample Samples []*Sample @@ -156,7 +156,7 @@ func (s *Sample) Throughput() float64 { return float64(s.MsgBytes) / s.Duration().Seconds() } -// Rate of meessages in the job per second +// Rate of messages in the job per second func (s *Sample) Rate() int64 { return int64(float64(s.JobMsgCnt) / s.Duration().Seconds()) } diff --git a/bench/benchlib_test.go b/bench/benchlib_test.go index d3b515c0f..020bae45c 100644 --- a/bench/benchlib_test.go +++ b/bench/benchlib_test.go @@ -116,7 +116,7 @@ func TestGroupThoughput(t *testing.T) { sg.AddSample(millionMessagesSecondSample(2)) sg.AddSample(millionMessagesSecondSample(3)) if sg.Throughput() != 2*Million*MsgSize { - t.Fatalf("Expected througput at %d million bytes/sec", 2*MsgSize) + t.Fatalf("Expected throughput at %d million bytes/sec", 2*MsgSize) } } diff --git a/context.go b/context.go index 20f1782ac..c19673c18 100644 --- a/context.go +++ b/context.go @@ -217,6 +217,8 @@ func (nc *Conn) FlushWithContext(ctx context.Context) error { // RequestWithContext will create an Inbox and perform a Request // using the provided cancellation context with the Inbox reply // for the data v. A response will be decoded into the vPtr last parameter. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) RequestWithContext(ctx context.Context, subject string, v any, vPtr any) error { if ctx == nil { return ErrInvalidContext diff --git a/dependencies.tpl b/dependencies.tpl new file mode 100644 index 000000000..5d01380f9 --- /dev/null +++ b/dependencies.tpl @@ -0,0 +1,8 @@ +# External Dependencies + +This file lists the dependencies used in this repository. + +| Dependency | License | +|--------------------------------------------------|-----------------------------------------| +{{ range . }}| {{.Name}} | {{.LicenseName}} | +{{ end }} diff --git a/enc.go b/enc.go index 4550f618d..78bcc219f 100644 --- a/enc.go +++ b/enc.go @@ -24,7 +24,11 @@ import ( "github.com/nats-io/nats.go/encoders/builtin" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + // Encoder interface is for all register encoders +// +// Deprecated: Encoded connections are no longer supported. type Encoder interface { Encode(subject string, v any) ([]byte, error) Decode(subject string, data []byte, vPtr any) error @@ -51,6 +55,8 @@ func init() { // EncodedConn are the preferred way to interface with NATS. They wrap a bare connection to // a nats server and have an extendable encoder system that will encode and decode messages // from raw Go types. +// +// Deprecated: Encoded connections are no longer supported. type EncodedConn struct { Conn *Conn Enc Encoder @@ -58,6 +64,8 @@ type EncodedConn struct { // NewEncodedConn will wrap an existing Connection and utilize the appropriate registered // encoder. +// +// Deprecated: Encoded connections are no longer supported. 
func NewEncodedConn(c *Conn, encType string) (*EncodedConn, error) { if c == nil { return nil, errors.New("nats: Nil Connection") @@ -73,6 +81,8 @@ func NewEncodedConn(c *Conn, encType string) (*EncodedConn, error) { } // RegisterEncoder will register the encType with the given Encoder. Useful for customization. +// +// Deprecated: Encoded connections are no longer supported. func RegisterEncoder(encType string, enc Encoder) { encLock.Lock() defer encLock.Unlock() @@ -80,6 +90,8 @@ func RegisterEncoder(encType string, enc Encoder) { } // EncoderForType will return the registered Encoder for the encType. +// +// Deprecated: Encoded connections are no longer supported. func EncoderForType(encType string) Encoder { encLock.Lock() defer encLock.Unlock() @@ -88,6 +100,8 @@ func EncoderForType(encType string) Encoder { // Publish publishes the data argument to the given subject. The data argument // will be encoded using the associated encoder. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) Publish(subject string, v any) error { b, err := c.Enc.Encode(subject, v) if err != nil { @@ -99,6 +113,8 @@ func (c *EncodedConn) Publish(subject string, v any) error { // PublishRequest will perform a Publish() expecting a response on the // reply subject. Use Request() for automatically waiting for a response // inline. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) PublishRequest(subject, reply string, v any) error { b, err := c.Enc.Encode(subject, v) if err != nil { @@ -110,6 +126,8 @@ func (c *EncodedConn) PublishRequest(subject, reply string, v any) error { // Request will create an Inbox and perform a Request() call // with the Inbox reply for the data v. A response will be // decoded into the vPtr Response. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) Request(subject string, v any, vPtr any, timeout time.Duration) error { b, err := c.Enc.Encode(subject, v) if err != nil { @@ -150,6 +168,8 @@ func (c *EncodedConn) Request(subject string, v any, vPtr any, timeout time.Dura // and demarshal it into the given struct, e.g. person. // There are also variants where the callback wants either the subject, or the // subject and the reply subject. +// +// Deprecated: Encoded connections are no longer supported. type Handler any // Dissect the cb Handler's signature @@ -170,6 +190,8 @@ var emptyMsgType = reflect.TypeOf(&Msg{}) // Subscribe will create a subscription on the given subject and process incoming // messages using the specified Handler. The Handler should be a func that matches // a signature from the description of Handler from above. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) Subscribe(subject string, cb Handler) (*Subscription, error) { return c.subscribe(subject, _EMPTY_, cb) } @@ -177,6 +199,8 @@ func (c *EncodedConn) Subscribe(subject string, cb Handler) (*Subscription, erro // QueueSubscribe will create a queue subscription on the given subject and process // incoming messages using the specified Handler. The Handler should be a func that // matches a signature from the description of Handler from above. +// +// Deprecated: Encoded connections are no longer supported. 
func (c *EncodedConn) QueueSubscribe(subject, queue string, cb Handler) (*Subscription, error) { return c.subscribe(subject, queue, cb) } @@ -238,18 +262,24 @@ func (c *EncodedConn) subscribe(subject, queue string, cb Handler) (*Subscriptio } // FlushTimeout allows a Flush operation to have an associated timeout. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) FlushTimeout(timeout time.Duration) (err error) { return c.Conn.FlushTimeout(timeout) } // Flush will perform a round trip to the server and return when it // receives the internal reply. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) Flush() error { return c.Conn.Flush() } // Close will close the connection to the server. This call will release // all blocking calls, such as Flush(), etc. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) Close() { c.Conn.Close() } @@ -259,11 +289,15 @@ func (c *EncodedConn) Close() { // will be drained and can not publish any additional messages. Upon draining // of the publishers, the connection will be closed. Use the ClosedCB() // option to know when the connection has moved from draining to closed. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) Drain() error { return c.Conn.Drain() } // LastError reports the last error encountered via the Connection. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) LastError() error { return c.Conn.LastError() } diff --git a/encoders/builtin/default_enc.go b/encoders/builtin/default_enc.go index c1d0f6f0b..e73113da8 100644 --- a/encoders/builtin/default_enc.go +++ b/encoders/builtin/default_enc.go @@ -26,6 +26,8 @@ import ( // turn numbers into appropriate strings that can be decoded. It will also // properly encoded and decode bools. If will encode a struct, but if you want // to properly handle structures you should use JsonEncoder. +// +// Deprecated: Encoded connections are no longer supported. type DefaultEncoder struct { // Empty } @@ -35,6 +37,8 @@ var falseB = []byte("false") var nilB = []byte("") // Encode +// +// Deprecated: Encoded connections are no longer supported. func (je *DefaultEncoder) Encode(subject string, v any) ([]byte, error) { switch arg := v.(type) { case string: @@ -58,6 +62,8 @@ func (je *DefaultEncoder) Encode(subject string, v any) ([]byte, error) { } // Decode +// +// Deprecated: Encoded connections are no longer supported. func (je *DefaultEncoder) Decode(subject string, data []byte, vPtr any) error { // Figure out what it's pointing to... sData := *(*string)(unsafe.Pointer(&data)) diff --git a/encoders/builtin/gob_enc.go b/encoders/builtin/gob_enc.go index 7ecf85e4d..e2e8c3202 100644 --- a/encoders/builtin/gob_enc.go +++ b/encoders/builtin/gob_enc.go @@ -21,6 +21,8 @@ import ( // GobEncoder is a Go specific GOB Encoder implementation for EncodedConn. // This encoder will use the builtin encoding/gob to Marshal // and Unmarshal most types, including structs. +// +// Deprecated: Encoded connections are no longer supported. type GobEncoder struct { // Empty } @@ -28,6 +30,8 @@ type GobEncoder struct { // FIXME(dlc) - This could probably be more efficient. // Encode +// +// Deprecated: Encoded connections are no longer supported. 
func (ge *GobEncoder) Encode(subject string, v any) ([]byte, error) { b := new(bytes.Buffer) enc := gob.NewEncoder(b) @@ -38,6 +42,8 @@ func (ge *GobEncoder) Encode(subject string, v any) ([]byte, error) { } // Decode +// +// Deprecated: Encoded connections are no longer supported. func (ge *GobEncoder) Decode(subject string, data []byte, vPtr any) (err error) { dec := gob.NewDecoder(bytes.NewBuffer(data)) err = dec.Decode(vPtr) diff --git a/encoders/builtin/json_enc.go b/encoders/builtin/json_enc.go index 0540d9850..8e4c852a4 100644 --- a/encoders/builtin/json_enc.go +++ b/encoders/builtin/json_enc.go @@ -21,11 +21,15 @@ import ( // JsonEncoder is a JSON Encoder implementation for EncodedConn. // This encoder will use the builtin encoding/json to Marshal // and Unmarshal most types, including structs. +// +// Deprecated: Encoded connections are no longer supported. type JsonEncoder struct { // Empty } // Encode +// +// Deprecated: Encoded connections are no longer supported. func (je *JsonEncoder) Encode(subject string, v any) ([]byte, error) { b, err := json.Marshal(v) if err != nil { @@ -35,6 +39,8 @@ func (je *JsonEncoder) Encode(subject string, v any) ([]byte, error) { } // Decode +// +// Deprecated: Encoded connections are no longer supported. func (je *JsonEncoder) Decode(subject string, data []byte, vPtr any) (err error) { switch arg := vPtr.(type) { case *string: diff --git a/encoders/protobuf/protobuf_enc.go b/encoders/protobuf/protobuf_enc.go index 017ffc035..805657767 100644 --- a/encoders/protobuf/protobuf_enc.go +++ b/encoders/protobuf/protobuf_enc.go @@ -20,6 +20,8 @@ import ( "google.golang.org/protobuf/proto" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + // Additional index for registered Encoders. const ( PROTOBUF_ENCODER = "protobuf" @@ -33,6 +35,8 @@ func init() { // ProtobufEncoder is a protobuf implementation for EncodedConn // This encoder will use the builtin protobuf lib to Marshal // and Unmarshal structs. +// +// Deprecated: Encoded connections are no longer supported. type ProtobufEncoder struct { // Empty } @@ -43,6 +47,8 @@ var ( ) // Encode +// +// Deprecated: Encoded connections are no longer supported. func (pb *ProtobufEncoder) Encode(subject string, v any) ([]byte, error) { if v == nil { return nil, nil @@ -60,6 +66,8 @@ func (pb *ProtobufEncoder) Encode(subject string, v any) ([]byte, error) { } // Decode +// +// Deprecated: Encoded connections are no longer supported. func (pb *ProtobufEncoder) Decode(subject string, data []byte, vPtr any) error { if _, ok := vPtr.(*any); ok { return nil diff --git a/example_test.go b/example_test.go index 6aa93636c..782adc414 100644 --- a/example_test.go +++ b/example_test.go @@ -89,6 +89,19 @@ func ExampleConn_Subscribe() { }) } +func ExampleConn_ForceReconnect() { + nc, _ := nats.Connect(nats.DefaultURL) + defer nc.Close() + + nc.Subscribe("foo", func(m *nats.Msg) { + fmt.Printf("Received a message: %s\n", string(m.Data)) + }) + + // Reconnect to the server. + // the subscription will be recreated after the reconnect. + nc.ForceReconnect() +} + // This Example shows a synchronous subscriber. func ExampleConn_SubscribeSync() { nc, _ := nats.Connect(nats.DefaultURL) @@ -214,106 +227,6 @@ func ExampleConn_Close() { nc.Close() } -// Shows how to wrap a Conn into an EncodedConn -func ExampleNewEncodedConn() { - nc, _ := nats.Connect(nats.DefaultURL) - c, _ := nats.NewEncodedConn(nc, "json") - c.Close() -} - -// EncodedConn can publish virtually anything just -// by passing it in. 
The encoder will be used to properly -// encode the raw Go type -func ExampleEncodedConn_Publish() { - nc, _ := nats.Connect(nats.DefaultURL) - c, _ := nats.NewEncodedConn(nc, "json") - defer c.Close() - - type person struct { - Name string - Address string - Age int - } - - me := &person{Name: "derek", Age: 22, Address: "85 Second St"} - c.Publish("hello", me) -} - -// EncodedConn's subscribers will automatically decode the -// wire data into the requested Go type using the Decode() -// method of the registered Encoder. The callback signature -// can also vary to include additional data, such as subject -// and reply subjects. -func ExampleEncodedConn_Subscribe() { - nc, _ := nats.Connect(nats.DefaultURL) - c, _ := nats.NewEncodedConn(nc, "json") - defer c.Close() - - type person struct { - Name string - Address string - Age int - } - - c.Subscribe("hello", func(p *person) { - fmt.Printf("Received a person! %+v\n", p) - }) - - c.Subscribe("hello", func(subj, reply string, p *person) { - fmt.Printf("Received a person on subject %s! %+v\n", subj, p) - }) - - me := &person{Name: "derek", Age: 22, Address: "85 Second St"} - c.Publish("hello", me) -} - -// BindSendChan() allows binding of a Go channel to a nats -// subject for publish operations. The Encoder attached to the -// EncodedConn will be used for marshaling. -func ExampleEncodedConn_BindSendChan() { - nc, _ := nats.Connect(nats.DefaultURL) - c, _ := nats.NewEncodedConn(nc, "json") - defer c.Close() - - type person struct { - Name string - Address string - Age int - } - - ch := make(chan *person) - c.BindSendChan("hello", ch) - - me := &person{Name: "derek", Age: 22, Address: "85 Second St"} - ch <- me -} - -// BindRecvChan() allows binding of a Go channel to a nats -// subject for subscribe operations. The Encoder attached to the -// EncodedConn will be used for un-marshaling. 
-func ExampleEncodedConn_BindRecvChan() { - nc, _ := nats.Connect(nats.DefaultURL) - c, _ := nats.NewEncodedConn(nc, "json") - defer c.Close() - - type person struct { - Name string - Address string - Age int - } - - ch := make(chan *person) - c.BindRecvChan("hello", ch) - - me := &person{Name: "derek", Age: 22, Address: "85 Second St"} - c.Publish("hello", me) - - // Receive the publish directly on a channel - who := <-ch - - fmt.Printf("%v says hello!\n", who) -} - func ExampleJetStream() { nc, err := nats.Connect("localhost") if err != nil { diff --git a/go_test.mod b/go_test.mod index 319a78434..ad05e0fa9 100644 --- a/go_test.mod +++ b/go_test.mod @@ -1,23 +1,26 @@ module github.com/nats-io/nats.go -go 1.20 +go 1.21 + +toolchain go1.22.5 require ( github.com/golang/protobuf v1.4.2 - github.com/klauspost/compress v1.17.7 - github.com/nats-io/nats-server/v2 v2.11.0-preview.1 + github.com/klauspost/compress v1.17.9 + github.com/nats-io/jwt v1.2.2 + github.com/nats-io/nats-server/v2 v2.11.0-preview.2 github.com/nats-io/nkeys v0.4.7 github.com/nats-io/nuid v1.0.1 go.uber.org/goleak v1.3.0 - golang.org/x/text v0.14.0 + golang.org/x/text v0.16.0 google.golang.org/protobuf v1.23.0 ) require ( + github.com/google/go-tpm v0.9.0 // indirect github.com/minio/highwayhash v1.0.2 // indirect - github.com/nats-io/jwt/v2 v2.5.5 // indirect - go.uber.org/automaxprocs v1.5.3 // indirect - golang.org/x/crypto v0.20.0 // indirect - golang.org/x/sys v0.17.0 // indirect + github.com/nats-io/jwt/v2 v2.5.7 // indirect + golang.org/x/crypto v0.24.0 // indirect + golang.org/x/sys v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect ) diff --git a/go_test.sum b/go_test.sum index ead7d3e28..5839875ea 100644 --- a/go_test.sum +++ b/go_test.sum @@ -1,4 +1,5 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= @@ -8,40 +9,47 @@ github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0 github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-tpm v0.9.0 h1:sQF6YqWMi+SCXpsmS3fd21oPy/vSddwZry4JnmltHVk= +github.com/google/go-tpm v0.9.0/go.mod h1:FkNVkc6C+IsvDI9Jw1OveJmxGZUUaKxtrpOS47QWKfU= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/minio/highwayhash v1.0.2 
h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/nats-io/jwt/v2 v2.5.5 h1:ROfXb50elFq5c9+1ztaUbdlrArNFl2+fQWP6B8HGEq4= -github.com/nats-io/jwt/v2 v2.5.5/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= -github.com/nats-io/nats-server/v2 v2.11.0-dev.0.20240227164423-1d1d982f0538 h1:z0+WWP9+JS43uYJXUJQApG1HyJLOXzFdXks1eZMDlE0= -github.com/nats-io/nats-server/v2 v2.11.0-dev.0.20240227164423-1d1d982f0538/go.mod h1:J0sPAPoyG5tzqLha88PgAnG4dib7rxHVT/Fka8H6JBQ= -github.com/nats-io/nats-server/v2 v2.11.0-preview.1 h1:pvN3VGtehpjipubmLkWESb5MQASG6qbK7HGD4eslfPM= -github.com/nats-io/nats-server/v2 v2.11.0-preview.1/go.mod h1:J0sPAPoyG5tzqLha88PgAnG4dib7rxHVT/Fka8H6JBQ= +github.com/nats-io/jwt v1.2.2 h1:w3GMTO969dFg+UOKTmmyuu7IGdusK+7Ytlt//OYH/uU= +github.com/nats-io/jwt v1.2.2/go.mod h1:/xX356yQA6LuXI9xWW7mZNpxgF2mBmGecH+Fj34sP5Q= +github.com/nats-io/jwt/v2 v2.5.7 h1:j5lH1fUXCnJnY8SsQeB/a/z9Azgu2bYIDvtPVNdxe2c= +github.com/nats-io/jwt/v2 v2.5.7/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= +github.com/nats-io/nats-server/v2 v2.11.0-preview.2 h1:tT/UeBbFzHRzwy77T/+/Rbw58XP9F3CY3VmtcDltZ68= +github.com/nats-io/nats-server/v2 v2.11.0-preview.2/go.mod h1:ILDVzrTqMco4rQMOgEZimBjJHb1oZDlz1J+qhJtZlRM= +github.com/nats-io/nkeys v0.2.0/go.mod h1:XdZpAbhgyyODYqjTawOnIOI7VlbKSarI9Gfy1tqEu/s= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= -go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.20.0 h1:jmAMJJZXr5KiCw05dfYK9QnqaqKLYXijU23lsEdcQqg= -golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.14.0 
h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -51,3 +59,4 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/syncx/map.go b/internal/syncx/map.go new file mode 100644 index 000000000..d2278e62a --- /dev/null +++ b/internal/syncx/map.go @@ -0,0 +1,73 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncx + +import "sync" + +// Map is a type-safe wrapper around sync.Map. +// It is safe for concurrent use. +// The zero value of Map is an empty map ready to use. 
+type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Load(key K) (V, bool) { + v, ok := m.m.Load(key) + if !ok { + var empty V + return empty, false + } + return v.(V), true +} + +func (m *Map[K, V]) Store(key K, value V) { + m.m.Store(key, value) +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { + return f(key.(K), value.(V)) + }) +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (V, bool) { + v, loaded := m.m.LoadOrStore(key, value) + return v.(V), loaded +} + +func (m *Map[K, V]) LoadAndDelete(key K) (V, bool) { + v, ok := m.m.LoadAndDelete(key) + if !ok { + var empty V + return empty, false + } + return v.(V), true +} + +func (m *Map[K, V]) CompareAndSwap(key K, old, new V) bool { + return m.m.CompareAndSwap(key, old, new) +} + +func (m *Map[K, V]) CompareAndDelete(key K, value V) bool { + return m.m.CompareAndDelete(key, value) +} + +func (m *Map[K, V]) Swap(key K, value V) (V, bool) { + previous, loaded := m.m.Swap(key, value) + return previous.(V), loaded +} diff --git a/internal/syncx/map_test.go b/internal/syncx/map_test.go new file mode 100644 index 000000000..df34b2f2f --- /dev/null +++ b/internal/syncx/map_test.go @@ -0,0 +1,152 @@ +// Copyright 2024 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package syncx + +import ( + "testing" +) + +func TestMapLoad(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, ok := m.Load(1) + if !ok || v != "one" { + t.Errorf("Load(1) = %v, %v; want 'one', true", v, ok) + } + + v, ok = m.Load(2) + if ok || v != "" { + t.Errorf("Load(2) = %v, %v; want '', false", v, ok) + } +} + +func TestMapStore(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, ok := m.Load(1) + if !ok || v != "one" { + t.Errorf("Load(1) after Store(1, 'one') = %v, %v; want 'one', true", v, ok) + } +} + +func TestMapDelete(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + m.Delete(1) + + v, ok := m.Load(1) + if ok || v != "" { + t.Errorf("Load(1) after Delete(1) = %v, %v; want '', false", v, ok) + } +} + +func TestMapRange(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + m.Store(2, "two") + + var keys []int + var values []string + m.Range(func(key int, value string) bool { + keys = append(keys, key) + values = append(values, value) + return true + }) + + if len(keys) != 2 || len(values) != 2 { + t.Errorf("Range() keys = %v, values = %v; want 2 keys and 2 values", keys, values) + } +} + +func TestMapLoadOrStore(t *testing.T) { + var m Map[int, string] + + v, loaded := m.LoadOrStore(1, "one") + if loaded || v != "one" { + t.Errorf("LoadOrStore(1, 'one') = %v, %v; want 'one', false", v, loaded) + } + + v, loaded = m.LoadOrStore(1, "uno") + if !loaded || v != "one" { + t.Errorf("LoadOrStore(1, 'uno') = %v, %v; want 'one', true", v, loaded) + } +} + +func TestMapLoadAndDelete(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, ok := m.LoadAndDelete(1) + if !ok || v != "one" { + t.Errorf("LoadAndDelete(1) = %v, %v; want 'one', true", v, ok) + } + + v, ok = m.Load(1) + if ok || v != "" { + t.Errorf("Load(1) after LoadAndDelete(1) = %v, %v; want '', false", v, ok) + } + + // Test that LoadAndDelete on a missing key returns the zero value. + v, ok = m.LoadAndDelete(2) + if ok || v != "" { + t.Errorf("LoadAndDelete(2) = %v, %v; want '', false", v, ok) + } +} + +func TestMapCompareAndSwap(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + ok := m.CompareAndSwap(1, "one", "uno") + if !ok { + t.Errorf("CompareAndSwap(1, 'one', 'uno') = false; want true") + } + + v, _ := m.Load(1) + if v != "uno" { + t.Errorf("Load(1) after CompareAndSwap = %v; want 'uno'", v) + } +} + +func TestMapCompareAndDelete(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + ok := m.CompareAndDelete(1, "one") + if !ok { + t.Errorf("CompareAndDelete(1, 'one') = false; want true") + } + + v, _ := m.Load(1) + if v != "" { + t.Errorf("Load(1) after CompareAndDelete = %v; want ''", v) + } +} + +func TestMapSwap(t *testing.T) { + var m Map[int, string] + m.Store(1, "one") + + v, loaded := m.Swap(1, "uno") + if !loaded || v != "one" { + t.Errorf("Swap(1, 'uno') = %v, %v; want 'one', true", v, loaded) + } + + v, _ = m.Load(1) + if v != "uno" { + t.Errorf("Load(1) after Swap = %v; want 'uno'", v) + } +} diff --git a/jetstream/README.md b/jetstream/README.md index e2ee01ed4..4889a1b48 100644 --- a/jetstream/README.md +++ b/jetstream/README.md @@ -1,4 +1,5 @@ -# JetStream Simplified Client + +# JetStream Simplified Client [![JetStream API Reference](https://pkg.go.dev/badge/github.com/nats-io/nats.go/jetstream.svg)](https://pkg.go.dev/github.com/nats-io/nats.go/jetstream) This doc covers the basic usage of the `jetstream` package in `nats.go` client. 
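
The `internal/syncx.Map` wrapper and its tests introduced above are what the pull consumer now uses to track its subscriptions (see the `subs` field later in this diff). A small usage sketch follows; it compiles only from inside the nats.go module since the package is internal, and the string/int types and `consumerdemo` package name are illustrative only:

```go
package consumerdemo // hypothetical package inside the nats.go module

import (
	"fmt"

	"github.com/nats-io/nats.go/internal/syncx"
)

func demo() {
	// The zero value is ready to use, mirroring sync.Map.
	var subs syncx.Map[string, int]

	subs.Store("sub-1", 1)
	subs.Store("sub-2", 2)

	if v, ok := subs.Load("sub-1"); ok {
		fmt.Println("loaded:", v) // loaded: 1
	}

	// Range stops iterating when the callback returns false.
	subs.Range(func(id string, v int) bool {
		fmt.Printf("%s => %d\n", id, v)
		return true
	})

	// Swap returns the previous value and whether the key was present.
	if prev, loaded := subs.Swap("sub-1", 10); loaded {
		fmt.Println("replaced:", prev) // replaced: 1
	}

	if _, ok := subs.LoadAndDelete("sub-2"); ok {
		fmt.Println("deleted sub-2")
	}
}
```
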
@@ -29,6 +30,10 @@ This doc covers the basic usage of the `jetstream` package in `nats.go` client. - [Basic usage of KV bucket](#basic-usage-of-kv-bucket) - [Watching for changes on a bucket](#watching-for-changes-on-a-bucket) - [Additional operations on a bucket](#additional-operations-on-a-bucket) + - [Object Store](#object-store) + - [Basic usage of Object Store](#basic-usage-of-object-store) + - [Watching for changes on a store](#watching-for-changes-on-a-store) + - [Additional operations on a store](#additional-operations-on-a-store) - [Examples](#examples) ## Overview @@ -109,13 +114,19 @@ func main() { // Get 10 messages from the consumer messageCounter := 0 - msgs, _ := c.Fetch(10) + msgs, err := c.Fetch(10) + if err != nil { + // handle error + } + for msg := range msgs.Messages() { msg.Ack() fmt.Printf("Received a JetStream message via fetch: %s\n", string(msg.Data())) messageCounter++ } + fmt.Printf("received %d messages\n", messageCounter) + if msgs.Error() != nil { fmt.Println("Error during Fetch(): ", msgs.Error()) } @@ -387,19 +398,29 @@ of messages/bytes. By default, `Fetch()` will wait 30 seconds before timing out ```go // receive up to 10 messages from the stream -msgs, _ := c.Fetch(10) +msgs, err := c.Fetch(10) +if err != nil { + // handle error +} + for msg := range msgs.Messages() { fmt.Printf("Received a JetStream message: %s\n", string(msg.Data())) } + if msgs.Error() != nil { // handle error } // receive up to 1024 B of data -msgs, _ := c.FetchBytes(1024) +msgs, err := c.FetchBytes(1024) +if err != nil { +// handle error +} + for msg := range msgs.Messages() { fmt.Printf("Received a JetStream message: %s\n", string(msg.Data())) } + if msgs.Error() != nil { // handle error } @@ -410,10 +431,15 @@ stream available at the time of sending request: ```go // FetchNoWait will not wait for new messages if the whole batch is not available at the time of sending request. 
-msgs, _ := c.FetchNoWait(10) +msgs, err := c.FetchNoWait(10) +if err != nil { +// handle error +} + for msg := range msgs.Messages() { fmt.Printf("Received a JetStream message: %s\n", string(msg.Data())) } + if msgs.Error() != nil { // handle error } diff --git a/jetstream/consumer.go b/jetstream/consumer.go index a9ab8247b..d7a8e7739 100644 --- a/jetstream/consumer.go +++ b/jetstream/consumer.go @@ -21,6 +21,7 @@ import ( "strings" "time" + "github.com/nats-io/nats.go/internal/syncx" "github.com/nats-io/nuid" ) @@ -209,6 +210,9 @@ func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg Consu var ccSubj string if cfg.FilterSubject != "" && len(cfg.FilterSubjects) == 0 { + if err := validateSubject(cfg.FilterSubject); err != nil { + return nil, err + } ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateWithFilterSubjectT, stream, consumerName, cfg.FilterSubject)) } else { ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateT, stream, consumerName)) @@ -231,12 +235,12 @@ func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg Consu } return &pullConsumer{ - jetStream: js, - stream: stream, - name: resp.Name, - durable: cfg.Durable != "", - info: resp.ConsumerInfo, - subscriptions: make(map[string]*pullSubscription), + jetStream: js, + stream: stream, + name: resp.Name, + durable: cfg.Durable != "", + info: resp.ConsumerInfo, + subs: syncx.Map[string, *pullSubscription]{}, }, nil } @@ -283,12 +287,12 @@ func getConsumer(ctx context.Context, js *jetStream, stream, name string) (Consu } cons := &pullConsumer{ - jetStream: js, - stream: stream, - name: name, - durable: resp.Config.Durable != "", - info: resp.ConsumerInfo, - subscriptions: make(map[string]*pullSubscription, 0), + jetStream: js, + stream: stream, + name: name, + durable: resp.Config.Durable != "", + info: resp.ConsumerInfo, + subs: syncx.Map[string, *pullSubscription]{}, } return cons, nil @@ -356,8 +360,11 @@ func resumeConsumer(ctx context.Context, js *jetStream, stream, consumer string) } func validateConsumerName(dur string) error { - if strings.Contains(dur, ".") { - return fmt.Errorf("%w: %q", ErrInvalidConsumerName, dur) + if dur == "" { + return fmt.Errorf("%w: '%s'", ErrInvalidConsumerName, "name is required") + } + if strings.ContainsAny(dur, ">*. /\\") { + return fmt.Errorf("%w: '%s'", ErrInvalidConsumerName, dur) } return nil } diff --git a/jetstream/consumer_config.go b/jetstream/consumer_config.go index 495942fa3..5d419cdfa 100644 --- a/jetstream/consumer_config.go +++ b/jetstream/consumer_config.go @@ -172,8 +172,8 @@ type ( // MaxAckPending is a maximum number of outstanding unacknowledged // messages. Once this limit is reached, the server will suspend sending - // messages to the consumer. If not set, server default is 1000 - // seconds. Set to -1 for unlimited. + // messages to the consumer. If not set, server default is 1000. + // Set to -1 for unlimited. MaxAckPending int `json:"max_ack_pending,omitempty"` // HeadersOnly indicates whether only headers of messages should be sent diff --git a/jetstream/errors.go b/jetstream/errors.go index 5ed5176a1..8d2fec642 100644 --- a/jetstream/errors.go +++ b/jetstream/errors.go @@ -71,6 +71,10 @@ var ( // ErrJetStreamNotEnabled is an error returned when JetStream is not // enabled. + // + // Note: This error will not be returned in clustered mode, even if each + // server in the cluster does not have JetStream enabled. In clustered mode, + // requests will time out instead. 
ErrJetStreamNotEnabled JetStreamError = &jsError{apiErr: &APIError{ErrorCode: JSErrCodeJetStreamNotEnabled, Description: "jetstream not enabled", Code: 503}} // ErrJetStreamNotEnabledForAccount is an error returned when JetStream is @@ -266,6 +270,9 @@ var ( // of an ordered consumer which was not yet created. ErrOrderedConsumerNotCreated JetStreamError = &jsError{message: "consumer instance not yet created"} + // ErrJetStreamPublisherClosed is returned for each unfinished ack future when JetStream.Cleanup is called. + ErrJetStreamPublisherClosed JetStreamError = &jsError{message: "jetstream context closed"} + // KeyValue Errors // ErrKeyExists is returned when attempting to create a key that already diff --git a/jetstream/jetstream.go b/jetstream/jetstream.go index 369d5b833..6863a9b73 100644 --- a/jetstream/jetstream.go +++ b/jetstream/jetstream.go @@ -101,6 +101,19 @@ type ( // outstanding asynchronously published messages are acknowledged by the // server. PublishAsyncComplete() <-chan struct{} + + // CleanupPublisher will cleanup the publishing side of JetStreamContext. + // + // This will unsubscribe from the internal reply subject if needed. + // All pending async publishes will fail with ErrJetStreamContextClosed. + // + // If an error handler was provided, it will be called for each pending async + // publish and PublishAsyncComplete will be closed. + // + // After completing JetStreamContext is still usable - internal subscription + // will be recreated on next publish, but the acks from previous publishes will + // be lost. + CleanupPublisher() } // StreamManager provides CRUD API for managing streams. It is available as @@ -109,7 +122,8 @@ type ( // to operate on a stream. StreamManager interface { // CreateStream creates a new stream with given config and returns an - // interface to operate on it. If stream with given name already exists, + // interface to operate on it. If stream with given name already exists + // and its configuration differs from the provided one, // ErrStreamNameAlreadyInUse is returned. CreateStream(ctx context.Context, cfg StreamConfig) (Stream, error) @@ -428,7 +442,7 @@ func NewWithAPIPrefix(nc *nats.Conn, apiPrefix string, opts ...JetStreamOpt) (Je } } if apiPrefix == "" { - return nil, fmt.Errorf("API prefix cannot be empty") + return nil, errors.New("API prefix cannot be empty") } if !strings.HasSuffix(apiPrefix, ".") { jsOpts.apiPrefix = fmt.Sprintf("%s.", apiPrefix) @@ -745,13 +759,12 @@ func (js *jetStream) OrderedConsumer(ctx context.Context, stream string, cfg Ord namePrefix: nuid.Next(), doReset: make(chan struct{}, 1), } - if cfg.OptStartSeq != 0 { - oc.cursor.streamSeq = cfg.OptStartSeq - 1 - } - err := oc.reset() + consCfg := oc.getConsumerConfig() + cons, err := js.CreateOrUpdateConsumer(ctx, stream, *consCfg) if err != nil { return nil, err } + oc.currentConsumer = cons.(*pullConsumer) return oc, nil } @@ -793,7 +806,7 @@ func validateStreamName(stream string) error { if stream == "" { return ErrStreamNameRequired } - if strings.Contains(stream, ".") { + if strings.ContainsAny(stream, ">*. /\\") { return fmt.Errorf("%w: '%s'", ErrInvalidStreamName, stream) } return nil @@ -803,7 +816,7 @@ func validateSubject(subject string) error { if subject == "" { return fmt.Errorf("%w: %s", ErrInvalidSubject, "subject cannot be empty") } - if !subjectRegexp.MatchString(subject) { + if subject[0] == '.' || subject[len(subject)-1] == '.' 
|| !subjectRegexp.MatchString(subject) { return fmt.Errorf("%w: %s", ErrInvalidSubject, subject) } return nil @@ -811,9 +824,11 @@ func validateSubject(subject string) error { // AccountInfo fetches account information from the server, containing details // about the account associated with this JetStream connection. If account is -// not enabled for JetStream, ErrJetStreamNotEnabledForAccount is returned. If -// the server does not have JetStream enabled, ErrJetStreamNotEnabled is -// returned. +// not enabled for JetStream, ErrJetStreamNotEnabledForAccount is returned. +// +// If the server does not have JetStream enabled, ErrJetStreamNotEnabled is +// returned (for a single server setup). For clustered topologies, AccountInfo +// will time out. func (js *jetStream) AccountInfo(ctx context.Context) (*AccountInfo, error) { ctx, cancel := wrapContextWithoutDeadline(ctx) if cancel != nil { @@ -1051,6 +1066,39 @@ func wrapContextWithoutDeadline(ctx context.Context) (context.Context, context.C return context.WithTimeout(ctx, defaultAPITimeout) } +// CleanupPublisher will cleanup the publishing side of JetStreamContext. +// +// This will unsubscribe from the internal reply subject if needed. +// All pending async publishes will fail with ErrJetStreamContextClosed. +// +// If an error handler was provided, it will be called for each pending async +// publish and PublishAsyncComplete will be closed. +// +// After completing JetStreamContext is still usable - internal subscription +// will be recreated on next publish, but the acks from previous publishes will +// be lost. +func (js *jetStream) CleanupPublisher() { + js.cleanupReplySub() + js.publisher.Lock() + errCb := js.publisher.aecb + for id, paf := range js.publisher.acks { + paf.err = ErrJetStreamPublisherClosed + if paf.errCh != nil { + paf.errCh <- paf.err + } + if errCb != nil { + // call error handler after releasing the mutex to avoid contention + defer errCb(js, paf.msg, ErrJetStreamPublisherClosed) + } + delete(js.publisher.acks, id) + } + if js.publisher.doneCh != nil { + close(js.publisher.doneCh) + js.publisher.doneCh = nil + } + js.publisher.Unlock() +} + func (js *jetStream) cleanupReplySub() { if js.publisher == nil { return diff --git a/jetstream/jetstream_test.go b/jetstream/jetstream_test.go index 62af4ef02..9d429ad33 100644 --- a/jetstream/jetstream_test.go +++ b/jetstream/jetstream_test.go @@ -276,9 +276,11 @@ func TestRetryWithBackoff(t *testing.T) { } func TestPullConsumer_checkPending(t *testing.T) { + tests := []struct { name string givenSub *pullSubscription + fetchInProgress bool shouldSend bool expectedPullRequest *pullRequest }{ @@ -292,7 +294,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdMessages: 5, MaxMessages: 10, }, - fetchInProgress: 0, }, shouldSend: false, }, @@ -307,7 +308,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdMessages: 5, MaxMessages: 10, }, - fetchInProgress: 0, }, shouldSend: true, expectedPullRequest: &pullRequest{ @@ -325,9 +325,9 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdMessages: 5, MaxMessages: 10, }, - fetchInProgress: 1, }, - shouldSend: false, + fetchInProgress: true, + shouldSend: false, }, { name: "pending bytes below threshold, send pull request", @@ -341,7 +341,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdBytes: 500, MaxBytes: 1000, }, - fetchInProgress: 0, }, shouldSend: true, expectedPullRequest: &pullRequest{ @@ -359,7 +358,6 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdBytes: 500, 
MaxBytes: 1000, }, - fetchInProgress: 0, }, shouldSend: false, }, @@ -373,9 +371,9 @@ func TestPullConsumer_checkPending(t *testing.T) { ThresholdBytes: 500, MaxBytes: 1000, }, - fetchInProgress: 1, }, - shouldSend: false, + fetchInProgress: true, + shouldSend: false, }, { name: "StopAfter set, pending msgs below StopAfter, send pull request", @@ -388,8 +386,7 @@ func TestPullConsumer_checkPending(t *testing.T) { MaxMessages: 10, StopAfter: 8, }, - fetchInProgress: 0, - delivered: 2, + delivered: 2, }, shouldSend: true, expectedPullRequest: &pullRequest{ @@ -408,8 +405,7 @@ func TestPullConsumer_checkPending(t *testing.T) { MaxMessages: 10, StopAfter: 6, }, - fetchInProgress: 0, - delivered: 0, + delivered: 0, }, shouldSend: false, }, @@ -419,6 +415,9 @@ func TestPullConsumer_checkPending(t *testing.T) { t.Run(test.name, func(t *testing.T) { prChan := make(chan *pullRequest, 1) test.givenSub.fetchNext = prChan + if test.fetchInProgress { + test.givenSub.fetchInProgress.Store(1) + } errs := make(chan error, 1) ok := make(chan struct{}, 1) go func() { @@ -431,13 +430,13 @@ func TestPullConsumer_checkPending(t *testing.T) { } ok <- struct{}{} case <-time.After(1 * time.Second): - errs <- fmt.Errorf("Timeout") + errs <- errors.New("Timeout") return } } else { select { case <-prChan: - errs <- fmt.Errorf("Unexpected pull request") + errs <- errors.New("Unexpected pull request") case <-time.After(100 * time.Millisecond): ok <- struct{}{} return @@ -456,3 +455,96 @@ func TestPullConsumer_checkPending(t *testing.T) { }) } } + +func TestKV_keyValid(t *testing.T) { + tests := []struct { + key string + ok bool + }{ + {key: "foo123", ok: true}, + {key: "foo.bar", ok: true}, + {key: "Foo.123=bar_baz-abc", ok: true}, + {key: "foo.*.bar", ok: false}, + {key: "foo.>", ok: false}, + {key: ">", ok: false}, + {key: "*", ok: false}, + {key: "foo!", ok: false}, + {key: "foo bar", ok: false}, + {key: "", ok: false}, + {key: " ", ok: false}, + {key: ".", ok: false}, + {key: ".foo", ok: false}, + {key: "foo.", ok: false}, + } + + for _, test := range tests { + t.Run(test.key, func(t *testing.T) { + res := keyValid(test.key) + if res != test.ok { + t.Fatalf("Invalid result; want: %v; got: %v", test.ok, res) + } + }) + } +} + +func TestKV_searchKeyValid(t *testing.T) { + tests := []struct { + key string + ok bool + }{ + {key: "foo123", ok: true}, + {key: "foo.bar", ok: true}, + {key: "Foo.123=bar_baz-abc", ok: true}, + {key: "foo.*.bar", ok: true}, + {key: "foo.>", ok: true}, + {key: ">", ok: true}, + {key: "*", ok: true}, + {key: "foo!", ok: false}, + {key: "foo bar", ok: false}, + {key: "", ok: false}, + {key: " ", ok: false}, + {key: ".", ok: false}, + {key: ".foo", ok: false}, + {key: "foo.", ok: false}, + } + + for _, test := range tests { + t.Run(test.key, func(t *testing.T) { + res := searchKeyValid(test.key) + if res != test.ok { + t.Fatalf("Invalid result; want: %v; got: %v", test.ok, res) + } + }) + } +} + +func TestKV_bucketValid(t *testing.T) { + tests := []struct { + key string + ok bool + }{ + {key: "foo123", ok: true}, + {key: "Foo123-bar_baz", ok: true}, + {key: "foo.bar", ok: false}, + {key: "foo.*.bar", ok: false}, + {key: "foo.>", ok: false}, + {key: ">", ok: false}, + {key: "*", ok: false}, + {key: "foo!", ok: false}, + {key: "foo bar", ok: false}, + {key: "", ok: false}, + {key: " ", ok: false}, + {key: ".", ok: false}, + {key: ".foo", ok: false}, + {key: "foo.", ok: false}, + } + + for _, test := range tests { + t.Run(test.key, func(t *testing.T) { + res := bucketValid(test.key) + if res != 
test.ok { + t.Fatalf("Invalid result; want: %v; got: %v", test.ok, res) + } + }) + } +} diff --git a/jetstream/kv.go b/jetstream/kv.go index 42a86c51e..38acbdc61 100644 --- a/jetstream/kv.go +++ b/jetstream/kv.go @@ -17,6 +17,7 @@ import ( "context" "errors" "fmt" + "reflect" "regexp" "strconv" "strings" @@ -122,6 +123,7 @@ type ( // Update will update the value if the latest revision matches. // If the provided revision is not the latest, Update will return an error. + // Update also resets the TTL associated with the key (if any). Update(ctx context.Context, key string, value []byte, revision uint64) (uint64, error) // Delete will place a delete marker and leave all revisions. A history @@ -164,8 +166,12 @@ type ( // with the same options as Watch. WatchAll(ctx context.Context, opts ...WatchOpt) (KeyWatcher, error) - // Keys will return all keys. DEPRECATED: Use ListKeys instead to avoid - // memory issues. + // WatchFiltered will watch for any updates to keys that match the keys + // argument. It can be configured with the same options as Watch. + WatchFiltered(ctx context.Context, keys []string, opts ...WatchOpt) (KeyWatcher, error) + + // Keys will return all keys. + // Deprecated: Use ListKeys instead to avoid memory issues. Keys(ctx context.Context, opts ...WatchOpt) ([]string, error) // ListKeys will return KeyLister, allowing to retrieve all keys from @@ -196,52 +202,52 @@ type ( // Bucket is the name of the KeyValue store. Bucket name has to be // unique and can only contain alphanumeric characters, dashes, and // underscores. - Bucket string + Bucket string `json:"bucket"` // Description is an optional description for the KeyValue store. - Description string + Description string `json:"description,omitempty"` // MaxValueSize is the maximum size of a value in bytes. If not // specified, the default is -1 (unlimited). - MaxValueSize int32 + MaxValueSize int32 `json:"max_value_size,omitempty"` // History is the number of historical values to keep per key. If not // specified, the default is 1. Max is 64. - History uint8 + History uint8 `json:"history,omitempty"` // TTL is the expiry time for keys. By default, keys do not expire. - TTL time.Duration + TTL time.Duration `json:"ttl,omitempty"` // MaxBytes is the maximum size in bytes of the KeyValue store. If not // specified, the default is -1 (unlimited). - MaxBytes int64 + MaxBytes int64 `json:"max_bytes,omitempty"` // Storage is the type of storage to use for the KeyValue store. If not // specified, the default is FileStorage. - Storage StorageType + Storage StorageType `json:"storage,omitempty"` // Replicas is the number of replicas to keep for the KeyValue store in // clustered jetstream. Defaults to 1, maximum is 5. - Replicas int + Replicas int `json:"num_replicas,omitempty"` // Placement is used to declare where the stream should be placed via // tags and/or an explicit cluster name. - Placement *Placement + Placement *Placement `json:"placement,omitempty"` // RePublish allows immediate republishing a message to the configured // subject after it's stored. - RePublish *RePublish + RePublish *RePublish `json:"republish,omitempty"` // Mirror defines the consiguration for mirroring another KeyValue // store. - Mirror *StreamSource + Mirror *StreamSource `json:"mirror,omitempty"` // Sources defines the configuration for sources of a KeyValue store. - Sources []*StreamSource + Sources []*StreamSource `json:"sources,omitempty"` // Compression sets the underlying stream compression. 
// NOTE: Compression is supported for nats-server 2.10.0+ - Compression bool + Compression bool `json:"compression,omitempty"` } // KeyLister is used to retrieve a list of key value store keys. It returns @@ -447,12 +453,13 @@ const ( // Regex for valid keys and buckets. var ( - validBucketRe = regexp.MustCompile(`\A[a-zA-Z0-9_-]+\z`) - validKeyRe = regexp.MustCompile(`\A[-/_=\.a-zA-Z0-9]+\z`) + validBucketRe = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + validKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9]+$`) + validSearchKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9*]*[>]?$`) ) func (js *jetStream) KeyValue(ctx context.Context, bucket string) (KeyValue, error) { - if !validBucketRe.MatchString(bucket) { + if !bucketValid(bucket) { return nil, ErrInvalidBucketName } streamName := fmt.Sprintf(kvBucketNameTmpl, bucket) @@ -488,8 +495,26 @@ func (js *jetStream) CreateKeyValue(ctx context.Context, cfg KeyValueConfig) (Ke // errors are joined so that backwards compatibility is retained // and previous checks for ErrStreamNameAlreadyInUse will still work. err = errors.Join(fmt.Errorf("%w: %s", ErrBucketExists, cfg.Bucket), err) + + // If we have a failure to add, it could be because we have + // a config change if the KV was created against before a bug fix + // that changed the value of discard policy. + // We will check if the stream exists and if the only difference + // is the discard policy, we will update the stream. + // The same logic applies for KVs created pre 2.9.x and + // the AllowDirect setting. + if stream, _ = js.Stream(ctx, scfg.Name); stream != nil { + cfg := stream.CachedInfo().Config + cfg.Discard = scfg.Discard + cfg.AllowDirect = scfg.AllowDirect + if reflect.DeepEqual(cfg, scfg) { + stream, err = js.UpdateStream(ctx, scfg) + } + } + } + if err != nil { + return nil, err } - return nil, err } pushJS, err := js.legacyJetStream() if err != nil { @@ -539,7 +564,7 @@ func (js *jetStream) CreateOrUpdateKeyValue(ctx context.Context, cfg KeyValueCon } func (js *jetStream) prepareKeyValueConfig(ctx context.Context, cfg KeyValueConfig) (StreamConfig, error) { - if !validBucketRe.MatchString(cfg.Bucket) { + if !bucketValid(cfg.Bucket) { return StreamConfig{}, ErrInvalidBucketName } if _, err := js.AccountInfo(ctx); err != nil { @@ -601,6 +626,7 @@ func (js *jetStream) prepareKeyValueConfig(ctx context.Context, cfg KeyValueConf AllowDirect: true, RePublish: cfg.RePublish, Compression: compression, + Discard: DiscardNew, } if cfg.Mirror != nil { // Copy in case we need to make changes so we do not change caller's version. @@ -636,7 +662,7 @@ func (js *jetStream) prepareKeyValueConfig(ctx context.Context, cfg KeyValueConf // DeleteKeyValue will delete this KeyValue store (JetStream stream). func (js *jetStream) DeleteKeyValue(ctx context.Context, bucket string) error { - if !validBucketRe.MatchString(bucket) { + if !bucketValid(bucket) { return ErrInvalidBucketName } stream := fmt.Sprintf(kvBucketNameTmpl, bucket) @@ -773,6 +799,13 @@ func (js *jetStream) legacyJetStream() (nats.JetStreamContext, error) { return js.conn.JetStream(opts...) } +func bucketValid(bucket string) bool { + if len(bucket) == 0 { + return false + } + return validBucketRe.MatchString(bucket) +} + func keyValid(key string) bool { if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' { return false @@ -780,6 +813,13 @@ func keyValid(key string) bool { return validKeyRe.MatchString(key) } +func searchKeyValid(key string) bool { + if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' 
{ + return false + } + return validSearchKeyRe.MatchString(key) +} + func (kv *kvs) get(ctx context.Context, key string, revision uint64) (KeyValueEntry, error) { if !keyValid(key) { return nil, ErrInvalidKey @@ -1033,9 +1073,12 @@ func (w *watcher) Stop() error { return w.sub.Unsubscribe() } -// Watch for any updates to keys that match the keys argument which could include wildcards. -// Watch will send a nil entry when it has received all initial values. -func (kv *kvs) Watch(ctx context.Context, keys string, opts ...WatchOpt) (KeyWatcher, error) { +func (kv *kvs) WatchFiltered(ctx context.Context, keys []string, opts ...WatchOpt) (KeyWatcher, error) { + for _, key := range keys { + if !searchKeyValid(key) { + return nil, fmt.Errorf("%w: %s", ErrInvalidKey, "key cannot be empty and must be a valid NATS subject") + } + } var o watchOpts for _, opt := range opts { if opt != nil { @@ -1046,10 +1089,20 @@ func (kv *kvs) Watch(ctx context.Context, keys string, opts ...WatchOpt) (KeyWat } // Could be a pattern so don't check for validity as we normally do. - var b strings.Builder - b.WriteString(kv.pre) - b.WriteString(keys) - keys = b.String() + for i, key := range keys { + var b strings.Builder + b.WriteString(kv.pre) + b.WriteString(key) + keys[i] = b.String() + } + + // if no keys are provided, watch all keys + if len(keys) == 0 { + var b strings.Builder + b.WriteString(kv.pre) + b.WriteString(AllKeys) + keys = []string{b.String()} + } // We will block below on placing items on the chan. That is by design. w := &watcher{updates: make(chan KeyValueEntry, 256)} @@ -1122,7 +1175,14 @@ func (kv *kvs) Watch(ctx context.Context, keys string, opts ...WatchOpt) (KeyWat // update() callback. w.mu.Lock() defer w.mu.Unlock() - sub, err := kv.pushJS.Subscribe(keys, update, subOpts...) + var sub *nats.Subscription + var err error + if len(keys) == 1 { + sub, err = kv.pushJS.Subscribe(keys[0], update, subOpts...) + } else { + subOpts = append(subOpts, nats.ConsumerFilterSubjects(keys...)) + sub, err = kv.pushJS.Subscribe("", update, subOpts...) + } if err != nil { return nil, err } @@ -1146,6 +1206,12 @@ func (kv *kvs) Watch(ctx context.Context, keys string, opts ...WatchOpt) (KeyWat return w, nil } +// Watch for any updates to keys that match the keys argument which could include wildcards. +// Watch will send a nil entry when it has received all initial values. +func (kv *kvs) Watch(ctx context.Context, keys string, opts ...WatchOpt) (KeyWatcher, error) { + return kv.WatchFiltered(ctx, []string{keys}, opts...) +} + // WatchAll will invoke the callback for all updates. func (kv *kvs) WatchAll(ctx context.Context, opts ...WatchOpt) (KeyWatcher, error) { return kv.Watch(ctx, AllKeys, opts...) 
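For illustration, a minimal usage sketch of the WatchFiltered API added above, assuming a locally running NATS server with JetStream enabled. The bucket name "profiles" and the key filters are hypothetical placeholders, not names used by the patch, and error handling is reduced to log.Fatal for brevity.

package main

import (
	"context"
	"log"
	"time"

	"github.com/nats-io/nats.go"
	"github.com/nats-io/nats.go/jetstream"
)

func main() {
	nc, err := nats.Connect(nats.DefaultURL)
	if err != nil {
		log.Fatal(err)
	}
	defer nc.Close()

	js, err := jetstream.New(nc)
	if err != nil {
		log.Fatal(err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	kv, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{Bucket: "profiles"})
	if err != nil {
		log.Fatal(err)
	}

	// Watch only the listed keys; passing an empty slice falls back to
	// watching all keys, mirroring WatchAll.
	w, err := kv.WatchFiltered(ctx, []string{"sue.color", "bob.color"})
	if err != nil {
		log.Fatal(err)
	}
	defer w.Stop()

	for entry := range w.Updates() {
		// A nil entry marks the end of the initial values.
		if entry == nil {
			log.Println("initial values delivered")
			continue
		}
		log.Printf("%s @ revision %d: %q", entry.Key(), entry.Revision(), entry.Value())
	}
}
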
diff --git a/jetstream/message.go b/jetstream/message.go index 81e151268..095f13968 100644 --- a/jetstream/message.go +++ b/jetstream/message.go @@ -16,6 +16,7 @@ package jetstream import ( "bytes" "context" + "errors" "fmt" "strconv" "strings" @@ -434,7 +435,7 @@ func parsePending(msg *nats.Msg) (int, int, error) { if msgsLeftStr != "" { msgsLeft, err = strconv.Atoi(msgsLeftStr) if err != nil { - return 0, 0, fmt.Errorf("nats: invalid format of Nats-Pending-Messages") + return 0, 0, errors.New("nats: invalid format of Nats-Pending-Messages") } } bytesLeftStr := msg.Header.Get("Nats-Pending-Bytes") @@ -442,7 +443,7 @@ func parsePending(msg *nats.Msg) (int, int, error) { if bytesLeftStr != "" { bytesLeft, err = strconv.Atoi(bytesLeftStr) if err != nil { - return 0, 0, fmt.Errorf("nats: invalid format of Nats-Pending-Bytes") + return 0, 0, errors.New("nats: invalid format of Nats-Pending-Bytes") } } return msgsLeft, bytesLeft, nil diff --git a/jetstream/object.go b/jetstream/object.go index 271cc2235..a0eecff33 100644 --- a/jetstream/object.go +++ b/jetstream/object.go @@ -918,7 +918,13 @@ func (obs *obs) Get(ctx context.Context, name string, opts ...GetObjectOpt) (Obj } chunkSubj := fmt.Sprintf(objChunksPreTmpl, obs.name, info.NUID) - _, err = obs.pushJS.Subscribe(chunkSubj, processChunk, nats.OrderedConsumer(), nats.Context(ctx)) + streamName := fmt.Sprintf(objNameTmpl, obs.name) + subscribeOpts := []nats.SubOpt{ + nats.OrderedConsumer(), + nats.Context(ctx), + nats.BindStream(streamName), + } + _, err = obs.pushJS.Subscribe(chunkSubj, processChunk, subscribeOpts...) if err != nil { return nil, err } @@ -1302,7 +1308,8 @@ func (obs *obs) Watch(ctx context.Context, opts ...WatchOpt) (ObjectWatcher, err } // Used ordered consumer to deliver results. - subOpts := []nats.SubOpt{nats.OrderedConsumer()} + streamName := fmt.Sprintf(objNameTmpl, obs.name) + subOpts := []nats.SubOpt{nats.OrderedConsumer(), nats.BindStream(streamName)} if !o.includeHistory { subOpts = append(subOpts, nats.DeliverLastPerSubject()) } diff --git a/jetstream/ordered.go b/jetstream/ordered.go index fd7fe2f50..5fe656e9b 100644 --- a/jetstream/ordered.go +++ b/jetstream/ordered.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -31,17 +32,19 @@ type ( cfg *OrderedConsumerConfig stream string currentConsumer *pullConsumer + currentSub *pullSubscription cursor cursor namePrefix string serial int consumerType consumerType doReset chan struct{} - resetInProgress uint32 + resetInProgress atomic.Uint32 userErrHandler ConsumeErrHandlerFunc stopAfter int stopAfterMsgsLeft chan int withStopAfter bool runningFetch *fetchResult + subscription *orderedSubscription sync.Mutex } @@ -49,7 +52,7 @@ type ( consumer *orderedConsumer opts []PullMessagesOpt done chan struct{} - closed uint32 + closed atomic.Uint32 } cursor struct { @@ -66,7 +69,10 @@ const ( consumerTypeFetch ) -var errOrderedSequenceMismatch = errors.New("sequence mismatch") +var ( + errOrderedSequenceMismatch = errors.New("sequence mismatch") + errOrderedConsumerClosed = errors.New("ordered consumer closed") +) // Consume can be used to continuously receive messages and handle them // with the provided callback function. 
Consume cannot be used concurrently @@ -91,7 +97,8 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) } c.userErrHandler = consumeOpts.ErrHandler - opts = append(opts, ConsumeErrHandler(c.errHandler(c.serial))) + opts = append(opts, consumeReconnectNotify(), + ConsumeErrHandler(c.errHandler(c.serial))) if consumeOpts.StopAfter > 0 { c.withStopAfter = true c.stopAfter = consumeOpts.StopAfter @@ -104,6 +111,7 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt consumer: c, done: make(chan struct{}, 1), } + c.subscription = sub internalHandler := func(serial int) func(msg Msg) { return func(msg Msg) { // handler is a noop if message was delivered for a consumer with different serial @@ -112,19 +120,11 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt } meta, err := msg.Metadata() if err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } - c.errHandler(serial)(sub, err) + c.errHandler(serial)(c.currentSub, err) return } dseq := meta.Sequence.Consumer if dseq != c.cursor.deliverSeq+1 { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } c.errHandler(serial)(sub, errOrderedSequenceMismatch) return } @@ -134,21 +134,21 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt } } - _, err = c.currentConsumer.Consume(internalHandler(c.serial), opts...) + cc, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...) if err != nil { return nil, err } + c.currentSub = cc.(*pullSubscription) go func() { for { select { case <-c.doReset: if err := c.reset(); err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return + if errors.Is(err, errOrderedConsumerClosed) { + continue } - c.errHandler(c.serial)(sub, err) + c.errHandler(c.serial)(c.currentSub, err) } if c.withStopAfter { select { @@ -171,14 +171,20 @@ func (c *orderedConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt if c.withStopAfter { opts = append(opts, consumeStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft)) } - if _, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil { - sub, ok := c.currentConsumer.getSubscription("") - if !ok { - return - } - c.errHandler(c.serial)(sub, err) + if cc, err := c.currentConsumer.Consume(internalHandler(c.serial), opts...); err != nil { + c.errHandler(c.serial)(cc, err) + } else { + c.Lock() + c.currentSub = cc.(*pullSubscription) + c.Unlock() } case <-sub.done: + s := sub.consumer.currentSub + if s != nil { + sub.consumer.Lock() + s.Stop() + sub.consumer.Unlock() + } return case msgsLeft, ok := <-c.stopAfterMsgsLeft: if !ok { @@ -196,16 +202,16 @@ func (c *orderedConsumer) errHandler(serial int) func(cc ConsumeContext, err err return func(cc ConsumeContext, err error) { c.Lock() defer c.Unlock() - if c.userErrHandler != nil && !errors.Is(err, errOrderedSequenceMismatch) { + if c.userErrHandler != nil && !errors.Is(err, errOrderedSequenceMismatch) && !errors.Is(err, errConnected) { c.userErrHandler(cc, err) } if errors.Is(err, ErrNoHeartbeat) || errors.Is(err, errOrderedSequenceMismatch) || errors.Is(err, ErrConsumerDeleted) || - errors.Is(err, ErrConsumerNotFound) { + errors.Is(err, errConnected) { // only reset if serial matches the current consumer serial and there is no reset in progress - if serial == c.serial && atomic.LoadUint32(&c.resetInProgress) == 0 { - 
atomic.StoreUint32(&c.resetInProgress, 1) + if serial == c.serial && c.resetInProgress.Load() == 0 { + c.resetInProgress.Store(1) c.doReset <- struct{}{} } } @@ -234,7 +240,9 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er if err != nil { return nil, fmt.Errorf("%w: %s", ErrInvalidOption, err) } - opts = append(opts, WithMessagesErrOnMissingHeartbeat(true)) + opts = append(opts, + WithMessagesErrOnMissingHeartbeat(true), + messagesReconnectNotify()) c.stopAfterMsgsLeft = make(chan int, 1) if consumeOpts.StopAfter > 0 { c.withStopAfter = true @@ -244,28 +252,25 @@ func (c *orderedConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, er if c.stopAfter > 0 { opts = append(opts, messagesStopAfterNotify(c.stopAfter, c.stopAfterMsgsLeft)) } - _, err = c.currentConsumer.Messages(opts...) + cc, err := c.currentConsumer.Messages(opts...) if err != nil { return nil, err } + c.currentSub = cc.(*pullSubscription) sub := &orderedSubscription{ consumer: c, opts: opts, done: make(chan struct{}, 1), } + c.subscription = sub return sub, nil } func (s *orderedSubscription) Next() (Msg, error) { for { - currentConsumer := s.consumer.currentConsumer - sub, ok := currentConsumer.getSubscription("") - if !ok { - return nil, ErrMsgIteratorClosed - } - msg, err := sub.Next() + msg, err := s.consumer.currentSub.Next() if err != nil { if errors.Is(err, ErrMsgIteratorClosed) { s.Stop() @@ -283,23 +288,40 @@ func (s *orderedSubscription) Next() (Msg, error) { s.opts[len(s.opts)-1] = StopAfter(s.consumer.stopAfter) } if err := s.consumer.reset(); err != nil { + if errors.Is(err, errOrderedConsumerClosed) { + return nil, ErrMsgIteratorClosed + } return nil, err } - _, err := s.consumer.currentConsumer.Messages(s.opts...) + cc, err := s.consumer.currentConsumer.Messages(s.opts...) if err != nil { return nil, err } + s.consumer.currentSub = cc.(*pullSubscription) continue } + meta, err := msg.Metadata() if err != nil { - s.consumer.errHandler(s.consumer.serial)(sub, err) - continue + return nil, err } serial := serialNumberFromConsumer(meta.Consumer) + if serial != s.consumer.serial { + continue + } dseq := meta.Sequence.Consumer if dseq != s.consumer.cursor.deliverSeq+1 { - s.consumer.errHandler(serial)(sub, errOrderedSequenceMismatch) + if err := s.consumer.reset(); err != nil { + if errors.Is(err, errOrderedConsumerClosed) { + return nil, ErrMsgIteratorClosed + } + return nil, err + } + cc, err := s.consumer.currentConsumer.Messages(s.opts...) 
+ if err != nil { + return nil, err + } + s.consumer.currentSub = cc.(*pullSubscription) continue } s.consumer.cursor.deliverSeq = dseq @@ -309,33 +331,60 @@ func (s *orderedSubscription) Next() (Msg, error) { } func (s *orderedSubscription) Stop() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } - sub, ok := s.consumer.currentConsumer.getSubscription("") - if !ok { - return + s.consumer.Lock() + defer s.consumer.Unlock() + if s.consumer.currentSub != nil { + s.consumer.currentSub.Stop() } - s.consumer.currentConsumer.Lock() - defer s.consumer.currentConsumer.Unlock() - sub.Stop() close(s.done) } func (s *orderedSubscription) Drain() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } - sub, ok := s.consumer.currentConsumer.getSubscription("") - if !ok { - return + if s.consumer.currentSub != nil { + s.consumer.currentConsumer.Lock() + s.consumer.currentSub.Drain() + s.consumer.currentConsumer.Unlock() } - s.consumer.currentConsumer.Lock() - defer s.consumer.currentConsumer.Unlock() - sub.Drain() close(s.done) } +// Closed returns a channel that is closed when the consuming is +// fully stopped/drained. When the channel is closed, no more messages +// will be received and processing is complete. +func (s *orderedSubscription) Closed() <-chan struct{} { + s.consumer.Lock() + defer s.consumer.Unlock() + closedCh := make(chan struct{}) + + go func() { + for { + s.consumer.Lock() + if s.consumer.currentSub == nil { + return + } + + closed := s.consumer.currentSub.Closed() + s.consumer.Unlock() + + // wait until the underlying pull consumer is closed + <-closed + // if the subscription is closed and ordered consumer is closed as well, + // send a signal that the Consume() is fully stopped + if s.closed.Load() == 1 { + close(closedCh) + return + } + } + }() + return closedCh +} + // Fetch is used to retrieve up to a provided number of messages from a // stream. This method will always send a single request and wait until // either all messages are retrieved or request times out. 
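A minimal sketch of how the new Closed channel might be used together with Consume and Drain on an ordered consumer. js is assumed to be an existing jetstream.JetStream, the "ORDERS" stream name and the timings are hypothetical, and imports of context, errors, log, time and the jetstream package are assumed.

// drainOrdered consumes from a hypothetical "ORDERS" stream and blocks
// until the consume loop reports that it is fully closed.
func drainOrdered(ctx context.Context, js jetstream.JetStream) error {
	cons, err := js.OrderedConsumer(ctx, "ORDERS", jetstream.OrderedConsumerConfig{})
	if err != nil {
		return err
	}

	cc, err := cons.Consume(func(msg jetstream.Msg) {
		log.Printf("received message on %q", msg.Subject())
	})
	if err != nil {
		return err
	}

	// Let the consumer run for a while, then drain it.
	time.Sleep(5 * time.Second)
	cc.Drain()

	// Closed is only signaled once the underlying subscription has been
	// torn down and no more messages will be delivered.
	select {
	case <-cc.Closed():
		return nil
	case <-time.After(10 * time.Second):
		return errors.New("timed out waiting for consumer to close")
	}
}
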
@@ -353,10 +402,17 @@ func (c *orderedConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, erro c.currentConsumer.Unlock() return nil, ErrOrderedConsumerConcurrentRequests } - c.cursor.streamSeq = c.runningFetch.sseq + if c.runningFetch.sseq != 0 { + c.cursor.streamSeq = c.runningFetch.sseq + } } c.currentConsumer.Unlock() c.consumerType = consumerTypeFetch + sub := orderedSubscription{ + consumer: c, + done: make(chan struct{}), + } + c.subscription = &sub err := c.reset() if err != nil { return nil, err @@ -384,9 +440,16 @@ func (c *orderedConsumer) FetchBytes(maxBytes int, opts ...FetchOpt) (MessageBat if !c.runningFetch.done { return nil, ErrOrderedConsumerConcurrentRequests } - c.cursor.streamSeq = c.runningFetch.sseq + if c.runningFetch.sseq != 0 { + c.cursor.streamSeq = c.runningFetch.sseq + } } c.consumerType = consumerTypeFetch + sub := orderedSubscription{ + consumer: c, + done: make(chan struct{}), + } + c.subscription = &sub err := c.reset() if err != nil { return nil, err @@ -415,6 +478,11 @@ func (c *orderedConsumer) FetchNoWait(batch int) (MessageBatch, error) { return nil, ErrOrderedConsumerConcurrentRequests } c.consumerType = consumerTypeFetch + sub := orderedSubscription{ + consumer: c, + done: make(chan struct{}), + } + c.subscription = &sub err := c.reset() if err != nil { return nil, err @@ -448,7 +516,11 @@ func serialNumberFromConsumer(name string) int { if len(name) == 0 { return 0 } - serial, err := strconv.Atoi(name[len(name)-1:]) + parts := strings.Split(name, "_") + if len(parts) < 2 { + return 0 + } + serial, err := strconv.Atoi(parts[len(parts)-1]) if err != nil { return 0 } @@ -458,73 +530,79 @@ func serialNumberFromConsumer(name string) int { func (c *orderedConsumer) reset() error { c.Lock() defer c.Unlock() - defer atomic.StoreUint32(&c.resetInProgress, 0) + defer c.resetInProgress.Store(0) if c.currentConsumer != nil { - sub, ok := c.currentConsumer.getSubscription("") c.currentConsumer.Lock() - if ok { - sub.Stop() + if c.currentSub != nil { + c.currentSub.Stop() } consName := c.currentConsumer.CachedInfo().Name c.currentConsumer.Unlock() - var err error - for i := 0; ; i++ { - if c.cfg.MaxResetAttempts > 0 && i == c.cfg.MaxResetAttempts { - return fmt.Errorf("%w: maximum number of delete attempts reached: %s", ErrOrderedConsumerReset, err) - } + go func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - err = c.jetStream.DeleteConsumer(ctx, c.stream, consName) + _ = c.jetStream.DeleteConsumer(ctx, c.stream, consName) cancel() - if err != nil { - if errors.Is(err, ErrConsumerNotFound) { - break - } - if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - continue - } - return err - } - break - } + }() } - seq := c.cursor.streamSeq + 1 + c.cursor.deliverSeq = 0 - consumerConfig := c.getConsumerConfigForSeq(seq) + consumerConfig := c.getConsumerConfig() var err error var cons Consumer - for i := 0; ; i++ { - if c.cfg.MaxResetAttempts > 0 && i == c.cfg.MaxResetAttempts { - return fmt.Errorf("%w: maximum number of create consumer attempts reached: %s", ErrOrderedConsumerReset, err) + + backoffOpts := backoffOpts{ + attempts: c.cfg.MaxResetAttempts, + initialInterval: time.Second, + factor: 2, + maxInterval: 10 * time.Second, + cancel: c.subscription.done, + } + err = retryWithBackoff(func(attempt int) (bool, error) { + isClosed := c.subscription.closed.Load() == 1 + if isClosed { + return false, errOrderedConsumerClosed } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + 
defer cancel() cons, err = c.jetStream.CreateOrUpdateConsumer(ctx, c.stream, *consumerConfig) if err != nil { - if errors.Is(err, ErrConsumerNotFound) { - cancel() - break - } - if errors.Is(err, nats.ErrTimeout) || errors.Is(err, context.DeadlineExceeded) { - cancel() - continue - } - cancel() - return err + return true, err } - cancel() - break + return false, nil + }, backoffOpts) + if err != nil { + return err } c.currentConsumer = cons.(*pullConsumer) return nil } -func (c *orderedConsumer) getConsumerConfigForSeq(seq uint64) *ConsumerConfig { +func (c *orderedConsumer) getConsumerConfig() *ConsumerConfig { c.serial++ + var nextSeq uint64 + + // if stream sequence is not initialized, no message was consumed yet + // therefore, start from the beginning (either from 1 or from the provided sequence) + if c.cursor.streamSeq == 0 { + if c.cfg.OptStartSeq != 0 { + nextSeq = c.cfg.OptStartSeq + } else { + nextSeq = 1 + } + } else { + // otherwise, start from the next sequence + nextSeq = c.cursor.streamSeq + 1 + } + + if c.cfg.MaxResetAttempts == 0 { + c.cfg.MaxResetAttempts = -1 + } name := fmt.Sprintf("%s_%d", c.namePrefix, c.serial) cfg := &ConsumerConfig{ Name: name, DeliverPolicy: DeliverByStartSequencePolicy, - OptStartSeq: seq, + OptStartSeq: nextSeq, AckPolicy: AckNonePolicy, InactiveThreshold: 5 * time.Minute, Replicas: 1, @@ -536,8 +614,12 @@ func (c *orderedConsumer) getConsumerConfigForSeq(seq uint64) *ConsumerConfig { } else { cfg.FilterSubjects = c.cfg.FilterSubjects } + if c.cfg.InactiveThreshold != 0 { + cfg.InactiveThreshold = c.cfg.InactiveThreshold + } - if seq != c.cfg.OptStartSeq+1 { + // if the cursor is not yet set, use the provided deliver policy + if c.cursor.streamSeq != 0 { return cfg } @@ -549,19 +631,16 @@ func (c *orderedConsumer) getConsumerConfigForSeq(seq uint64) *ConsumerConfig { c.cfg.DeliverPolicy == DeliverAllPolicy { cfg.OptStartSeq = 0 + } else if c.cfg.DeliverPolicy == DeliverByStartTimePolicy { + cfg.OptStartSeq = 0 + cfg.OptStartTime = c.cfg.OptStartTime + } else { + cfg.OptStartSeq = c.cfg.OptStartSeq } if cfg.DeliverPolicy == DeliverLastPerSubjectPolicy && len(c.cfg.FilterSubjects) == 0 { cfg.FilterSubjects = []string{">"} } - if c.cfg.OptStartTime != nil { - cfg.OptStartSeq = 0 - cfg.DeliverPolicy = DeliverByStartTimePolicy - cfg.OptStartTime = c.cfg.OptStartTime - } - if c.cfg.InactiveThreshold != 0 { - cfg.InactiveThreshold = c.cfg.InactiveThreshold - } return cfg } @@ -582,6 +661,20 @@ func messagesStopAfterNotify(numMsgs int, msgsLeftAfterStop chan int) PullMessag }) } +func consumeReconnectNotify() PullConsumeOpt { + return pullOptFunc(func(opts *consumeOpts) error { + opts.notifyOnReconnect = true + return nil + }) +} + +func messagesReconnectNotify() PullMessagesOpt { + return pullOptFunc(func(opts *consumeOpts) error { + opts.notifyOnReconnect = true + return nil + }) +} + // Info returns information about the ordered consumer. // Note that this method will fetch the latest instance of the // consumer from the server, which can be deleted by the library at any time. 
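A short sketch of the ordered consumer options referenced by getConsumerConfig above (OptStartSeq, InactiveThreshold, MaxResetAttempts). The "EVENTS" stream name and the concrete values are hypothetical; imports of context, time and the jetstream package are assumed.

// newOrderedFromSeq creates an ordered consumer that starts at a fixed
// stream sequence and limits how many times the library will try to
// recreate the consumer after a failure.
func newOrderedFromSeq(ctx context.Context, js jetstream.JetStream) (jetstream.Consumer, error) {
	return js.OrderedConsumer(ctx, "EVENTS", jetstream.OrderedConsumerConfig{
		DeliverPolicy:     jetstream.DeliverByStartSequencePolicy,
		OptStartSeq:       42,               // first delivery starts here; resets continue from the cursor
		InactiveThreshold: 10 * time.Minute, // overrides the 5 minute default
		MaxResetAttempts:  5,                // stop retrying consumer recreation after 5 attempts
	})
}
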
@@ -622,3 +715,91 @@ func (c *orderedConsumer) CachedInfo() *ConsumerInfo { } return c.currentConsumer.info } + +type backoffOpts struct { + // total retry attempts + // -1 for unlimited + attempts int + // initial interval after which first retry will be performed + // defaults to 1s + initialInterval time.Duration + // determines whether first function execution should be performed immediately + disableInitialExecution bool + // multiplier on each attempt + // defaults to 2 + factor float64 + // max interval between retries + // after reaching this value, all subsequent + // retries will be performed with this interval + // defaults to 1 minute + maxInterval time.Duration + // custom backoff intervals + // if set, overrides all other options except attempts + // if attempts are set, then the last interval will be used + // for all subsequent retries after reaching the limit + customBackoff []time.Duration + // cancel channel + // if set, retry will be canceled when this channel is closed + cancel <-chan struct{} +} + +func retryWithBackoff(f func(int) (bool, error), opts backoffOpts) error { + var err error + var shouldContinue bool + // if custom backoff is set, use it instead of other options + if len(opts.customBackoff) > 0 { + if opts.attempts != 0 { + return errors.New("cannot use custom backoff intervals when attempts are set") + } + for i, interval := range opts.customBackoff { + select { + case <-opts.cancel: + return nil + case <-time.After(interval): + } + shouldContinue, err = f(i) + if !shouldContinue { + return err + } + } + return err + } + + // set default options + if opts.initialInterval == 0 { + opts.initialInterval = 1 * time.Second + } + if opts.factor == 0 { + opts.factor = 2 + } + if opts.maxInterval == 0 { + opts.maxInterval = 1 * time.Minute + } + if opts.attempts == 0 { + return errors.New("retry attempts have to be set when not using custom backoff intervals") + } + interval := opts.initialInterval + for i := 0; ; i++ { + if i == 0 && opts.disableInitialExecution { + time.Sleep(interval) + continue + } + shouldContinue, err = f(i) + if !shouldContinue { + return err + } + if opts.attempts > 0 && i >= opts.attempts-1 { + break + } + select { + case <-opts.cancel: + return nil + case <-time.After(interval): + } + interval = time.Duration(float64(interval) * opts.factor) + if interval >= opts.maxInterval { + interval = opts.maxInterval + } + } + return err +} diff --git a/jetstream/publish.go b/jetstream/publish.go index f41b06fd1..70e219ac4 100644 --- a/jetstream/publish.go +++ b/jetstream/publish.go @@ -1,4 +1,4 @@ -// Copyright 2022-2023 The NATS Authors +// Copyright 2022-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -81,6 +81,7 @@ type ( err error errCh chan error doneCh chan *PubAck + reply string } jetStreamClient struct { @@ -280,17 +281,17 @@ func (js *jetStream) PublishMsgAsync(m *nats.Msg, opts ...PublishOpt) (PubAckFut } var id string + var reply string // register new paf if not retrying if paf == nil { var err error - m.Reply, err = js.newAsyncReply() - defer func() { m.Reply = "" }() + reply, err = js.newAsyncReply() if err != nil { return nil, fmt.Errorf("nats: error creating async reply handler: %s", err) } - id = m.Reply[js.replyPrefixLen:] - paf = &pubAckFuture{msg: m, jsClient: js.publisher, maxRetries: o.retryAttempts, retryWait: o.retryWait} + id = reply[js.replyPrefixLen:] + paf = &pubAckFuture{msg: m, jsClient: js.publisher, maxRetries: o.retryAttempts, retryWait: o.retryWait, reply: reply} numPending, maxPending := js.registerPAF(id, paf) if maxPending > 0 && numPending > maxPending { @@ -303,10 +304,17 @@ func (js *jetStream) PublishMsgAsync(m *nats.Msg, opts ...PublishOpt) (PubAckFut } } else { // when retrying, get the ID from existing reply subject - id = m.Reply[js.replyPrefixLen:] + reply = paf.reply + id = reply[js.replyPrefixLen:] } - if err := js.conn.PublishMsg(m); err != nil { + pubMsg := &nats.Msg{ + Subject: m.Subject, + Reply: reply, + Data: m.Data, + Header: m.Header, + } + if err := js.conn.PublishMsg(pubMsg); err != nil { js.clearPAF(id) return nil, err } @@ -370,6 +378,31 @@ func (js *jetStream) handleAsyncReply(m *nats.Msg) { return } + closeStc := func() { + // Check on anyone stalled and waiting. + if js.publisher.stallCh != nil && len(js.publisher.acks) < js.publisher.maxpa { + close(js.publisher.stallCh) + js.publisher.stallCh = nil + } + } + + closeDchFn := func() func() { + var dch chan struct{} + // Check on anyone one waiting on done status. + if js.publisher.doneCh != nil && len(js.publisher.acks) == 0 { + dch = js.publisher.doneCh + js.publisher.doneCh = nil + } + // Return function to close done channel which + // should be deferred so that error is processed and + // can be checked. + return func() { + if dch != nil { + close(dch) + } + } + } + doErr := func(err error) { paf.err = err if paf.errCh != nil { @@ -386,8 +419,13 @@ func (js *jetStream) handleAsyncReply(m *nats.Msg) { if len(m.Data) == 0 && m.Header.Get(statusHdr) == noResponders { if paf.retries < paf.maxRetries { paf.retries++ - paf.msg.Reply = m.Subject time.AfterFunc(paf.retryWait, func() { + js.publisher.Lock() + paf := js.getPAF(id) + js.publisher.Unlock() + if paf == nil { + return + } _, err := js.PublishMsgAsync(paf.msg, func(po *pubOpts) error { po.pafRetry = paf return nil @@ -401,25 +439,16 @@ func (js *jetStream) handleAsyncReply(m *nats.Msg) { return } delete(js.publisher.acks, id) + closeStc() + defer closeDchFn()() doErr(ErrNoStreamResponse) return } // Remove delete(js.publisher.acks, id) - - // Check on anyone stalled and waiting. - if js.publisher.stallCh != nil && len(js.publisher.acks) < js.publisher.asyncPublisherOpts.maxpa { - close(js.publisher.stallCh) - js.publisher.stallCh = nil - } - // Check on anyone waiting on done status. - if js.publisher.doneCh != nil && len(js.publisher.acks) == 0 { - dch := js.publisher.doneCh - js.publisher.doneCh = nil - // Defer here so error is processed and can be checked. 
- defer close(dch) - } + closeStc() + defer closeDchFn()() var pa pubAckResponse if err := json.Unmarshal(m.Data, &pa); err != nil { @@ -453,10 +482,17 @@ func (js *jetStream) resetPendingAcksOnReconnect() { return } js.publisher.Lock() - for _, paf := range js.publisher.acks { + errCb := js.publisher.asyncPublisherOpts.aecb + for id, paf := range js.publisher.acks { paf.err = nats.ErrDisconnected + if paf.errCh != nil { + paf.errCh <- paf.err + } + if errCb != nil { + defer errCb(js, paf.msg, nats.ErrDisconnected) + } + delete(js.publisher.acks, id) } - js.publisher.acks = nil if js.publisher.doneCh != nil { close(js.publisher.doneCh) js.publisher.doneCh = nil diff --git a/jetstream/pull.go b/jetstream/pull.go index bb5479aa0..386968108 100644 --- a/jetstream/pull.go +++ b/jetstream/pull.go @@ -14,7 +14,6 @@ package jetstream import ( - "context" "encoding/json" "errors" "fmt" @@ -24,6 +23,7 @@ import ( "time" "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/internal/syncx" "github.com/nats-io/nuid" ) @@ -32,7 +32,7 @@ type ( // It is returned by [Consumer.Messages] method. MessagesContext interface { // Next retrieves next message on a stream. It will block until the next - // message is available. If the context is cancelled, Next will return + // message is available. If the context is canceled, Next will return // ErrMsgIteratorClosed error. Next() (Msg, error) @@ -59,6 +59,11 @@ type ( // Drain unsubscribes from the stream and cancels subscription. // All messages that are already in the buffer will be processed in callback function. Drain() + + // Closed returns a channel that is closed when the consuming is + // fully stopped/drained. When the channel is closed, no more messages + // will be received and processing is complete. + Closed() <-chan struct{} } // MessageHandler is a handler function used as callback in [Consume]. 
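A hedged sketch of the MessagesContext flow documented above, including the new Closed channel. cons is assumed to be an existing jetstream.Consumer, the timings are arbitrary, and imports of errors, log, time and the jetstream package are assumed.

// consumeUntilDrained iterates messages with MessagesContext, then drains
// and waits for the iterator to be fully closed.
func consumeUntilDrained(cons jetstream.Consumer) error {
	iter, err := cons.Messages()
	if err != nil {
		return err
	}

	go func() {
		for {
			msg, err := iter.Next()
			if err != nil {
				// Next returns ErrMsgIteratorClosed once Stop or Drain has
				// completed and the buffer is empty.
				if errors.Is(err, jetstream.ErrMsgIteratorClosed) {
					return
				}
				log.Printf("next: %v", err)
				return
			}
			if err := msg.Ack(); err != nil {
				log.Printf("ack: %v", err)
			}
		}
	}()

	// After some work, drain buffered messages and wait for full shutdown.
	time.Sleep(5 * time.Second)
	iter.Drain()
	<-iter.Closed()
	return nil
}
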
@@ -76,12 +81,12 @@ type ( pullConsumer struct { sync.Mutex - jetStream *jetStream - stream string - durable bool - name string - info *ConsumerInfo - subscriptions map[string]*pullSubscription + jetStream *jetStream + stream string + durable bool + name string + info *ConsumerInfo + subs syncx.Map[string, *pullSubscription] } pullRequest struct { @@ -103,6 +108,7 @@ type ( ThresholdBytes int StopAfter int stopAfterMsgsLeft chan int + notifyOnReconnect bool } ConsumeErrHandlerFunc func(consumeCtx ConsumeContext, err error) @@ -116,14 +122,15 @@ type ( errs chan error pending pendingMsgs hbMonitor *hbMonitor - fetchInProgress uint32 - closed uint32 - draining uint32 + fetchInProgress atomic.Uint32 + closed atomic.Uint32 + draining atomic.Uint32 done chan struct{} connStatusChanged chan nats.Status fetchNext chan *pullRequest consumeOpts *consumeOpts delivered int + closedCh chan struct{} } pendingMsgs struct { @@ -181,12 +188,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( subject := apiSubj(p.jetStream.apiPrefix, fmt.Sprintf(apiRequestNextT, p.stream, p.name)) - // for single consume, use empty string as id - // this is useful for ordered consumer, where only a single subscription is valid - var consumeID string - if len(p.subscriptions) > 0 { - consumeID = nuid.Next() - } + consumeID := nuid.Next() sub := &pullSubscription{ id: consumeID, consumer: p, @@ -199,7 +201,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub.hbMonitor = sub.scheduleHeartbeatCheck(consumeOpts.Heartbeat) - p.subscriptions[sub.id] = sub + p.subs.Store(sub.id, sub) p.Unlock() internalHandler := func(msg *nats.Msg) { @@ -232,7 +234,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub.Unlock() if err != nil { - if atomic.LoadUint32(&sub.closed) == 1 { + if sub.closed.Load() == 1 { return } if sub.consumeOpts.ErrHandler != nil { @@ -259,10 +261,14 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( } sub.subscription.SetClosedHandler(func(sid string) func(string) { return func(subject string) { - p.Lock() - defer p.Unlock() - delete(p.subscriptions, sid) - atomic.CompareAndSwapUint32(&sub.draining, 1, 0) + p.subs.Delete(sid) + sub.draining.CompareAndSwap(1, 0) + sub.Lock() + if sub.closedCh != nil { + close(sub.closedCh) + sub.closedCh = nil + } + sub.Unlock() } }(sub.id)) @@ -286,7 +292,7 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( go func() { isConnected := true for { - if atomic.LoadUint32(&sub.closed) == 1 { + if sub.closed.Load() == 1 { return } select { @@ -304,42 +310,8 @@ func (p *pullConsumer) Consume(handler MessageHandler, opts ...PullConsumeOpt) ( sub.Lock() if !isConnected { isConnected = true - // try fetching consumer info several times to make sure consumer is available after reconnect - backoffOpts := backoffOpts{ - attempts: 10, - initialInterval: 1 * time.Second, - disableInitialExecution: true, - factor: 2, - maxInterval: 10 * time.Second, - cancel: sub.done, - } - err = retryWithBackoff(func(attempt int) (bool, error) { - isClosed := atomic.LoadUint32(&sub.closed) == 1 - if isClosed { - return false, nil - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _, err := p.Info(ctx) - if err != nil { - if sub.consumeOpts.ErrHandler != nil { - err = fmt.Errorf("[%d] attempting to fetch consumer info after reconnect: %w", attempt, err) - if attempt == backoffOpts.attempts-1 { - err 
= errors.Join(err, fmt.Errorf("maximum retry attempts reached")) - } - sub.consumeOpts.ErrHandler(sub, err) - } - return true, err - } - return false, nil - }, backoffOpts) - if err != nil { - if sub.consumeOpts.ErrHandler != nil { - sub.consumeOpts.ErrHandler(sub, err) - } - sub.Unlock() - sub.cleanup() - return + if sub.consumeOpts.notifyOnReconnect { + sub.errs <- errConnected } sub.fetchNext <- &pullRequest{ @@ -417,7 +389,7 @@ func (s *pullSubscription) incrementDeliveredMsgs() { func (s *pullSubscription) checkPending() { if (s.pending.msgCount < s.consumeOpts.ThresholdMessages || (s.pending.byteCount < s.consumeOpts.ThresholdBytes && s.consumeOpts.MaxBytes != 0)) && - atomic.LoadUint32(&s.fetchInProgress) == 0 { + s.fetchInProgress.Load() == 0 { var batchSize, maxBytes int if s.consumeOpts.MaxBytes == 0 { @@ -461,12 +433,7 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error msgs := make(chan *nats.Msg, consumeOpts.MaxMessages) - // for single consume, use empty string as id - // this is useful for ordered consumer, where only a single subscription is valid - var consumeID string - if len(p.subscriptions) > 0 { - consumeID = nuid.Next() - } + consumeID := nuid.Next() sub := &pullSubscription{ id: consumeID, consumer: p, @@ -485,20 +452,18 @@ func (p *pullConsumer) Messages(opts ...PullMessagesOpt) (MessagesContext, error } sub.subscription.SetClosedHandler(func(sid string) func(string) { return func(subject string) { - p.Lock() - defer p.Unlock() - if atomic.LoadUint32(&sub.draining) != 1 { + if sub.draining.Load() != 1 { // if we're not draining, subscription can be closed as soon // as closed handler is called // otherwise, we need to wait until all messages are drained // in Next - delete(p.subscriptions, sid) + p.subs.Delete(sid) } close(msgs) } }(sub.id)) - p.subscriptions[sub.id] = sub + p.subs.Store(sub.id, sub) p.Unlock() go sub.pullMessages(subject) @@ -531,13 +496,13 @@ var ( ) // Next retrieves next message on a stream. It will block until the next -// message is available. If the context is cancelled, Next will return +// message is available. If the context is canceled, Next will return // ErrMsgIteratorClosed error. 
func (s *pullSubscription) Next() (Msg, error) { s.Lock() defer s.Unlock() - drainMode := atomic.LoadUint32(&s.draining) == 1 - closed := atomic.LoadUint32(&s.closed) == 1 + drainMode := s.draining.Load() == 1 + closed := s.closed.Load() == 1 if closed && !drainMode { return nil, ErrMsgIteratorClosed } @@ -560,8 +525,8 @@ func (s *pullSubscription) Next() (Msg, error) { case msg, ok := <-s.msgs: if !ok { // if msgs channel is closed, it means that subscription was either drained or stopped - delete(s.consumer.subscriptions, s.id) - atomic.CompareAndSwapUint32(&s.draining, 1, 0) + s.consumer.subs.Delete(s.id) + s.draining.CompareAndSwap(1, 0) return nil, ErrMsgIteratorClosed } if hbMonitor != nil { @@ -596,39 +561,10 @@ func (s *pullSubscription) Next() (Msg, error) { if errors.Is(err, errConnected) { if !isConnected { isConnected = true - // try fetching consumer info several times to make sure consumer is available after reconnect - backoffOpts := backoffOpts{ - attempts: 10, - initialInterval: 1 * time.Second, - disableInitialExecution: true, - factor: 2, - maxInterval: 10 * time.Second, - cancel: s.done, - } - err = retryWithBackoff(func(attempt int) (bool, error) { - isClosed := atomic.LoadUint32(&s.closed) == 1 - if isClosed { - return false, nil - } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _, err := s.consumer.Info(ctx) - if err != nil { - if errors.Is(err, ErrConsumerNotFound) { - return false, err - } - if attempt == backoffOpts.attempts-1 { - return true, fmt.Errorf("could not get consumer info after server reconnect: %w", err) - } - return true, err - } - return false, nil - }, backoffOpts) - if err != nil { - s.Stop() - return nil, err - } + if s.consumeOpts.notifyOnReconnect { + return nil, errConnected + } s.pending.msgCount = 0 s.pending.byteCount = 0 if hbMonitor != nil { @@ -638,7 +574,7 @@ func (s *pullSubscription) Next() (Msg, error) { } if errors.Is(err, errDisconnected) { if hbMonitor != nil { - hbMonitor.Reset(2 * s.consumeOpts.Heartbeat) + hbMonitor.Stop() } isConnected = false } @@ -693,7 +629,7 @@ func (hb *hbMonitor) Reset(dur time.Duration) { // Next after calling Stop will return ErrMsgIteratorClosed error. // All messages that are already in the buffer are discarded. func (s *pullSubscription) Stop() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } close(s.done) @@ -711,10 +647,10 @@ func (s *pullSubscription) Stop() { // subsequent calls to Next. After the buffer is drained, Next will // return ErrMsgIteratorClosed error. func (s *pullSubscription) Drain() { - if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { + if !s.closed.CompareAndSwap(0, 1) { return } - atomic.StoreUint32(&s.draining, 1) + s.draining.Store(1) close(s.done) if s.consumeOpts.stopAfterMsgsLeft != nil { if s.delivered >= s.consumeOpts.StopAfter { @@ -725,6 +661,24 @@ func (s *pullSubscription) Drain() { } } +// Closed returns a channel that is closed when consuming is +// fully stopped/drained. When the channel is closed, no more messages +// will be received and processing is complete. +func (s *pullSubscription) Closed() <-chan struct{} { + s.Lock() + defer s.Unlock() + closedCh := s.closedCh + if closedCh == nil { + closedCh = make(chan struct{}) + s.closedCh = closedCh + } + if !s.subscription.IsValid() { + close(s.closedCh) + s.closedCh = nil + } + return closedCh +} + // Fetch sends a single request to retrieve given number of messages. 
// It will wait up to provided expiry time if not all messages are available. func (p *pullConsumer) Fetch(batch int, opts ...FetchOpt) (MessageBatch, error) { @@ -903,7 +857,7 @@ func (s *pullSubscription) pullMessages(subject string) { for { select { case req := <-s.fetchNext: - atomic.StoreUint32(&s.fetchInProgress, 1) + s.fetchInProgress.Store(1) if err := s.pull(req, subject); err != nil { if errors.Is(err, ErrMsgIteratorClosed) { @@ -912,7 +866,7 @@ func (s *pullSubscription) pullMessages(subject string) { } s.errs <- err } - atomic.StoreUint32(&s.fetchInProgress, 0) + s.fetchInProgress.Store(0) case <-s.done: s.cleanup() return @@ -943,13 +897,13 @@ func (s *pullSubscription) cleanup() { if s.hbMonitor != nil { s.hbMonitor.Stop() } - drainMode := atomic.LoadUint32(&s.draining) == 1 + drainMode := s.draining.Load() == 1 if drainMode { s.subscription.Drain() } else { s.subscription.Unsubscribe() } - atomic.StoreUint32(&s.closed, 1) + s.closed.Store(1) } // pull sends a pull request to the server and waits for messages using a subscription from [pullSubscription]. @@ -957,7 +911,7 @@ func (s *pullSubscription) cleanup() { func (s *pullSubscription) pull(req *pullRequest, subject string) error { s.consumer.Lock() defer s.consumer.Unlock() - if atomic.LoadUint32(&s.closed) == 1 { + if s.closed.Load() == 1 { return ErrMsgIteratorClosed } if req.Batch < 1 { @@ -1017,7 +971,7 @@ func parseMessagesOpts(ordered bool, opts ...PullMessagesOpt) (*consumeOpts, err func (consumeOpts *consumeOpts) setDefaults(ordered bool) error { if consumeOpts.MaxBytes != unset && consumeOpts.MaxMessages != unset { - return fmt.Errorf("only one of MaxMessages and MaxBytes can be specified") + return errors.New("only one of MaxMessages and MaxBytes can be specified") } if consumeOpts.MaxBytes != unset { // when max_bytes is used, set batch size to a very large number @@ -1053,102 +1007,7 @@ func (consumeOpts *consumeOpts) setDefaults(ordered bool) error { } } if consumeOpts.Heartbeat > consumeOpts.Expires/2 { - return fmt.Errorf("the value of Heartbeat must be less than 50%% of expiry") + return errors.New("the value of Heartbeat must be less than 50%% of expiry") } return nil } - -type backoffOpts struct { - // total retry attempts - // -1 for unlimited - attempts int - // initial interval after which first retry will be performed - // defaults to 1s - initialInterval time.Duration - // determines whether first function execution should be performed immediately - disableInitialExecution bool - // multiplier on each attempt - // defaults to 2 - factor float64 - // max interval between retries - // after reaching this value, all subsequent - // retries will be performed with this interval - // defaults to 1 minute - maxInterval time.Duration - // custom backoff intervals - // if set, overrides all other options except attempts - // if attempts are set, then the last interval will be used - // for all subsequent retries after reaching the limit - customBackoff []time.Duration - // cancel channel - // if set, retry will be cancelled when this channel is closed - cancel <-chan struct{} -} - -func retryWithBackoff(f func(int) (bool, error), opts backoffOpts) error { - var err error - var shouldContinue bool - // if custom backoff is set, use it instead of other options - if len(opts.customBackoff) > 0 { - if opts.attempts != 0 { - return fmt.Errorf("cannot use custom backoff intervals when attempts are set") - } - for i, interval := range opts.customBackoff { - select { - case <-opts.cancel: - return nil - case 
<-time.After(interval): - } - shouldContinue, err = f(i) - if !shouldContinue { - return err - } - } - return err - } - - // set default options - if opts.initialInterval == 0 { - opts.initialInterval = 1 * time.Second - } - if opts.factor == 0 { - opts.factor = 2 - } - if opts.maxInterval == 0 { - opts.maxInterval = 1 * time.Minute - } - if opts.attempts == 0 { - return fmt.Errorf("retry attempts have to be set when not using custom backoff intervals") - } - interval := opts.initialInterval - for i := 0; ; i++ { - if i == 0 && opts.disableInitialExecution { - time.Sleep(interval) - continue - } - shouldContinue, err = f(i) - if !shouldContinue { - return err - } - if opts.attempts > 0 && i >= opts.attempts-1 { - break - } - select { - case <-opts.cancel: - return nil - case <-time.After(interval): - } - interval = time.Duration(float64(interval) * opts.factor) - if interval >= opts.maxInterval { - interval = opts.maxInterval - } - } - return err -} - -func (c *pullConsumer) getSubscription(id string) (*pullSubscription, bool) { - c.Lock() - defer c.Unlock() - sub, ok := c.subscriptions[id] - return sub, ok -} diff --git a/jetstream/stream.go b/jetstream/stream.go index 397c618c6..0a7beb5d8 100644 --- a/jetstream/stream.go +++ b/jetstream/stream.go @@ -297,13 +297,12 @@ func (s *stream) OrderedConsumer(ctx context.Context, cfg OrderedConsumerConfig) namePrefix: nuid.Next(), doReset: make(chan struct{}, 1), } - if cfg.OptStartSeq != 0 { - oc.cursor.streamSeq = cfg.OptStartSeq - 1 - } - err := oc.reset() + consCfg := oc.getConsumerConfig() + cons, err := s.CreateOrUpdateConsumer(ctx, *consCfg) if err != nil { return nil, err } + oc.currentConsumer = cons.(*pullConsumer) return oc, nil } @@ -539,16 +538,16 @@ func convertDirectGetMsgResponseToMsg(name string, r *nats.Msg) (*RawStreamMsg, // Check for headers that give us the required information to // reconstruct the message. if len(r.Header) == 0 { - return nil, fmt.Errorf("nats: response should have headers") + return nil, errors.New("nats: response should have headers") } stream := r.Header.Get(StreamHeader) if stream == "" { - return nil, fmt.Errorf("nats: missing stream header") + return nil, errors.New("nats: missing stream header") } seqStr := r.Header.Get(SequenceHeader) if seqStr == "" { - return nil, fmt.Errorf("nats: missing sequence header") + return nil, errors.New("nats: missing sequence header") } seq, err := strconv.ParseUint(seqStr, 10, 64) if err != nil { @@ -556,7 +555,7 @@ func convertDirectGetMsgResponseToMsg(name string, r *nats.Msg) (*RawStreamMsg, } timeStr := r.Header.Get(TimeStampHeaer) if timeStr == "" { - return nil, fmt.Errorf("nats: missing timestamp header") + return nil, errors.New("nats: missing timestamp header") } tm, err := time.Parse(time.RFC3339Nano, timeStr) @@ -565,7 +564,7 @@ func convertDirectGetMsgResponseToMsg(name string, r *nats.Msg) (*RawStreamMsg, } subj := r.Header.Get(SubjectHeader) if subj == "" { - return nil, fmt.Errorf("nats: missing subject header") + return nil, errors.New("nats: missing subject header") } return &RawStreamMsg{ Subject: subj, diff --git a/jetstream/stream_config.go b/jetstream/stream_config.go index dd1f9d941..304203bc5 100644 --- a/jetstream/stream_config.go +++ b/jetstream/stream_config.go @@ -15,6 +15,7 @@ package jetstream import ( "encoding/json" + "errors" "fmt" "strings" "time" @@ -192,8 +193,8 @@ type ( // v2.10.0 or later. Metadata map[string]string `json:"metadata,omitempty"` - // Template identifies the template that manages the Stream. 
DEPRECATED: - // This feature is no longer supported. + // Template identifies the template that manages the Stream. + // Deprecated: This feature is no longer supported. Template string `json:"template_owner,omitempty"` } @@ -584,7 +585,7 @@ func (alg StoreCompression) MarshalJSON() ([]byte, error) { case NoCompression: str = "none" default: - return nil, fmt.Errorf("unknown compression algorithm") + return nil, errors.New("unknown compression algorithm") } return json.Marshal(str) } @@ -600,7 +601,7 @@ func (alg *StoreCompression) UnmarshalJSON(b []byte) error { case "none": *alg = NoCompression default: - return fmt.Errorf("unknown compression algorithm") + return errors.New("unknown compression algorithm") } return nil } diff --git a/jetstream/test/helper_test.go b/jetstream/test/helper_test.go index a9dbae222..9b7c1b76a 100644 --- a/jetstream/test/helper_test.go +++ b/jetstream/test/helper_test.go @@ -82,15 +82,6 @@ func NewConnection(t *testing.T, port int) *nats.Conn { return nc } -// NewEConn -func NewEConn(t *testing.T) *nats.EncodedConn { - ec, err := nats.NewEncodedConn(NewDefaultConnection(t), nats.DEFAULT_ENCODER) - if err != nil { - t.Fatalf("Failed to create an encoded connection: %v\n", err) - } - return ec -} - //////////////////////////////////////////////////////////////////////////////// // Running nats server in separate Go routines //////////////////////////////////////////////////////////////////////////////// diff --git a/jetstream/test/jetstream_test.go b/jetstream/test/jetstream_test.go index f5c9c8ee1..3b2530940 100644 --- a/jetstream/test/jetstream_test.go +++ b/jetstream/test/jetstream_test.go @@ -1,4 +1,4 @@ -// Copyright 2022-2023 The NATS Authors +// Copyright 2022-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -1963,3 +1963,125 @@ func TestConsumerConfigMatches(t *testing.T) { t.Fatalf("ConsumerConfig doesn't match") } } + +func TestJetStreamCleanupPublisher(t *testing.T) { + + t.Run("cleanup js publisher", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + // Create a stream + if _, err := js.CreateStream(context.Background(), jetstream.StreamConfig{Name: "TEST", Subjects: []string{"FOO"}}); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + numSubs := nc.NumSubscriptions() + if _, err := js.PublishAsync("FOO", []byte("hello")); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + select { + case <-js.PublishAsyncComplete(): + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + if numSubs+1 != nc.NumSubscriptions() { + t.Fatalf("Expected an additional subscription after publish, got %d", nc.NumSubscriptions()) + } + + js.CleanupPublisher() + + if numSubs != nc.NumSubscriptions() { + t.Fatalf("Expected subscriptions to be back to original count") + } + }) + + t.Run("cleanup js publisher, cancel pending acks", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + cbErr := make(chan error, 10) + js, err := jetstream.New(nc, jetstream.WithPublishAsyncErrHandler(func(js jetstream.JetStream, m *nats.Msg, err error) { + cbErr <- err + })) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Create a stream with NoAck so that we can test that we cancel ack futures. + if _, err := js.CreateStream(context.Background(), jetstream.StreamConfig{Name: "TEST", Subjects: []string{"FOO"}, NoAck: true}); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + numSubs := nc.NumSubscriptions() + + var acks []jetstream.PubAckFuture + for i := 0; i < 10; i++ { + ack, err := js.PublishAsync("FOO", []byte("hello")) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + acks = append(acks, ack) + } + + asyncComplete := js.PublishAsyncComplete() + select { + case <-asyncComplete: + t.Fatalf("Should not complete, NoAck is set") + case <-time.After(200 * time.Millisecond): + } + + if numSubs+1 != nc.NumSubscriptions() { + t.Fatalf("Expected an additional subscription after publish, got %d", nc.NumSubscriptions()) + } + + js.CleanupPublisher() + + if numSubs != nc.NumSubscriptions() { + t.Fatalf("Expected subscriptions to be back to original count") + } + + // check that PublishAsyncComplete channel is closed + select { + case <-asyncComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + // check that all ack futures are canceled + for _, ack := range acks { + select { + case err := <-ack.Err(): + if !errors.Is(err, jetstream.ErrJetStreamPublisherClosed) { + t.Fatalf("Expected JetStreamContextClosed error, got %v", err) + } + case <-ack.Ok(): + t.Fatalf("Expected error on the ack future") + case <-time.After(200 * time.Millisecond): + t.Fatalf("Expected an error on the ack future") + } + } + + // check that async error handler is called for each pending ack + for i := 0; i < 10; i++ { + select { + case err := <-cbErr: + if !errors.Is(err, jetstream.ErrJetStreamPublisherClosed) { + t.Fatalf("Expected JetStreamContextClosed error, got %v", err) + } + case <-time.After(200 * 
time.Millisecond): + t.Fatalf("Expected errors to be passed from the async handler") + } + } + }) + +} diff --git a/jetstream/test/kv_test.go b/jetstream/test/kv_test.go index d85231663..42c0d28c3 100644 --- a/jetstream/test/kv_test.go +++ b/jetstream/test/kv_test.go @@ -246,6 +246,22 @@ func TestKeyValueWatch(t *testing.T) { } } } + expectPurgeF := func(t *testing.T, watcher jetstream.KeyWatcher) func(key string, revision uint64) { + return func(key string, revision uint64) { + t.Helper() + select { + case v := <-watcher.Updates(): + if v.Operation() != jetstream.KeyValuePurge { + t.Fatalf("Expected a delete operation but got %+v", v) + } + if v.Revision() != revision { + t.Fatalf("Did not get expected revision: %d vs %d", revision, v.Revision()) + } + case <-time.After(time.Second): + t.Fatalf("Did not receive an update like expected") + } + } + } expectInitDoneF := func(t *testing.T, watcher jetstream.KeyWatcher) func() { return func() { t.Helper() @@ -315,13 +331,27 @@ func TestKeyValueWatch(t *testing.T) { watcher, err = kv.Watch(ctx, "t.*") expectOk(t, err) - defer watcher.Stop() expectInitDone = expectInitDoneF(t, watcher) expectUpdate = expectUpdateF(t, watcher) expectUpdate("t.name", "ik", 8) expectUpdate("t.age", "44", 10) expectInitDone() + watcher.Stop() + + // test watcher with multiple filters + watcher, err = kv.WatchFiltered(ctx, []string{"t.name", "name"}) + expectOk(t, err) + expectInitDone = expectInitDoneF(t, watcher) + expectUpdate = expectUpdateF(t, watcher) + expectPurge := expectPurgeF(t, watcher) + expectUpdate("name", "ik", 3) + expectUpdate("t.name", "ik", 8) + expectInitDone() + err = kv.Purge(ctx, "name") + expectOk(t, err) + expectPurge("name", 11) + defer watcher.Stop() }) t.Run("watcher with history included", func(t *testing.T) { @@ -514,6 +544,76 @@ func TestKeyValueWatch(t *testing.T) { expectUpdate("age", "22", 3) expectUpdate("name2", "ik", 4) }) + + t.Run("invalid watchers", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kv, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{Bucket: "WATCH"}) + expectOk(t, err) + + // empty keys + _, err = kv.Watch(ctx, "") + expectErr(t, err, jetstream.ErrInvalidKey) + + // invalid key + _, err = kv.Watch(ctx, "a.>.b") + expectErr(t, err, jetstream.ErrInvalidKey) + + _, err = kv.Watch(ctx, "foo.") + expectErr(t, err, jetstream.ErrInvalidKey) + + // conflicting options + _, err = kv.Watch(ctx, "foo", jetstream.IncludeHistory(), jetstream.UpdatesOnly()) + expectErr(t, err, jetstream.ErrInvalidOption) + }) + + t.Run("filtered watch with no filters", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kv, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{Bucket: "WATCH"}) + expectOk(t, err) + + // this should behave like WatchAll + watcher, err := kv.WatchFiltered(ctx, []string{}) + expectOk(t, err) + defer watcher.Stop() + + expectInitDone := expectInitDoneF(t, watcher) + expectUpdate := expectUpdateF(t, watcher) + expectDelete := expectDeleteF(t, watcher) + // Make sure we already got an initial value marker. 
+ expectInitDone() + + _, err = kv.Create(ctx, "name", []byte("derek")) + expectOk(t, err) + expectUpdate("name", "derek", 1) + _, err = kv.Put(ctx, "name", []byte("rip")) + expectOk(t, err) + expectUpdate("name", "rip", 2) + _, err = kv.Put(ctx, "name", []byte("ik")) + expectOk(t, err) + expectUpdate("name", "ik", 3) + _, err = kv.Put(ctx, "age", []byte("22")) + expectOk(t, err) + expectUpdate("age", "22", 4) + _, err = kv.Put(ctx, "age", []byte("33")) + expectOk(t, err) + expectUpdate("age", "33", 5) + expectOk(t, kv.Delete(ctx, "age")) + expectDelete("age", 6) + }) } func TestKeyValueWatchContext(t *testing.T) { @@ -939,6 +1039,7 @@ func TestKeyValueListKeys(t *testing.T) { func TestKeyValueCrossAccounts(t *testing.T) { conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 jetstream: enabled accounts: { A: { @@ -1234,7 +1335,7 @@ func TestKeyValueMirrorCrossDomains(t *testing.T) { checkFor(t, 10*time.Second, 10*time.Millisecond, func() error { _, err := kv.Get(context.Background(), key) if err == nil { - return fmt.Errorf("Expected key to be gone") + return errors.New("Expected key to be gone") } if !errors.Is(err, jetstream.ErrKeyNotFound) { return err @@ -1481,11 +1582,46 @@ func TestKeyValueCreate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - kv, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{Bucket: "TEST"}) + kv, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "TEST", + Description: "Test KV", + MaxValueSize: 128, + History: 10, + TTL: 1 * time.Hour, + MaxBytes: 1024, + Storage: jetstream.FileStorage, + }) if err != nil { t.Fatalf("Error creating kv: %v", err) } + expectedStreamConfig := jetstream.StreamConfig{ + Name: "KV_TEST", + Description: "Test KV", + Subjects: []string{"$KV.TEST.>"}, + MaxMsgs: -1, + MaxBytes: 1024, + Discard: jetstream.DiscardNew, + MaxAge: 1 * time.Hour, + MaxMsgsPerSubject: 10, + MaxMsgSize: 128, + Storage: jetstream.FileStorage, + DenyDelete: true, + AllowRollup: true, + AllowDirect: true, + MaxConsumers: -1, + Replicas: 1, + Duplicates: 2 * time.Minute, + } + + stream, err := js.Stream(ctx, "KV_TEST") + if err != nil { + t.Fatalf("Error getting stream: %v", err) + } + if !reflect.DeepEqual(stream.CachedInfo().Config, expectedStreamConfig) { + t.Fatalf("Expected stream config to be %+v, got %+v", expectedStreamConfig, stream.CachedInfo().Config) + } + _, err = kv.Create(ctx, "key", []byte("1")) if err != nil { t.Fatalf("Error creating key: %v", err) @@ -1599,3 +1735,58 @@ func TestKeyValueCompression(t *testing.T) { t.Fatalf("Expected stream to be compressed with S2") } } + +func TestKeyValueCreateRepairOldKV(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + ctx := context.Background() + + // create a standard kv + _, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "A", + }) + if err != nil { + t.Fatalf("Error creating kv: %v", err) + } + + // get stream config and set discard policy to old and AllowDirect to false + stream, err := js.Stream(ctx, "KV_A") + if err != nil { + t.Fatalf("Error getting stream info: %v", err) + } + streamCfg := stream.CachedInfo().Config + streamCfg.Discard = jetstream.DiscardOld + streamCfg.AllowDirect = false + + // create a new kv with the same name - client should fix the config + _, err = js.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "A", + }) + if err != nil { + t.Fatalf("Error creating kv: %v", err) + } + + 
// get stream config again and check if the discard policy is set to new + stream, err = js.Stream(ctx, "KV_A") + if err != nil { + t.Fatalf("Error getting stream info: %v", err) + } + if stream.CachedInfo().Config.Discard != jetstream.DiscardNew { + t.Fatalf("Expected stream to have discard policy set to new") + } + if !stream.CachedInfo().Config.AllowDirect { + t.Fatalf("Expected stream to have AllowDirect set to true") + } + + // attempting to create a new kv with the same name and different settings should fail + _, err = js.CreateKeyValue(ctx, jetstream.KeyValueConfig{ + Bucket: "A", + Description: "New KV", + }) + if !errors.Is(err, jetstream.ErrBucketExists) { + t.Fatalf("Expected error to be ErrBucketExists, got: %v", err) + } +} diff --git a/jetstream/test/object_test.go b/jetstream/test/object_test.go index 70e2b7096..8f421c51f 100644 --- a/jetstream/test/object_test.go +++ b/jetstream/test/object_test.go @@ -1197,3 +1197,72 @@ func TestObjectStoreCompression(t *testing.T) { t.Fatalf("Expected stream to be compressed with S2") } } + +func TestObjectStoreMirror(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + bucketName := "test-bucket" + + ctx := context.Background() + obs, err := js.CreateObjectStore(ctx, jetstream.ObjectStoreConfig{Bucket: bucketName, Description: "testing"}) + expectOk(t, err) + + mirrorBucketName := "mirror-test-bucket" + + _, err = js.CreateStream(ctx, jetstream.StreamConfig{ + Name: fmt.Sprintf("OBJ_%s", mirrorBucketName), + Mirror: &jetstream.StreamSource{ + Name: fmt.Sprintf("OBJ_%s", bucketName), + SubjectTransforms: []jetstream.SubjectTransformConfig{ + { + Source: fmt.Sprintf("$O.%s.>", bucketName), + Destination: fmt.Sprintf("$O.%s.>", mirrorBucketName), + }, + }, + }, + AllowRollup: true, // meta messages are always rollups + }) + if err != nil { + t.Fatalf("Error creating object store bucket mirror: %v", err) + } + + _, err = obs.PutString(ctx, "A", "abc") + expectOk(t, err) + + mirrorObs, err := js.ObjectStore(ctx, mirrorBucketName) + expectOk(t, err) + + // Make sure we sync. 
+ checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + mirrorValue, err := mirrorObs.GetString(ctx, "A") + if err != nil { + return err + } + if mirrorValue != "abc" { + t.Fatalf("Expected mirrored object store value to be the same as original") + } + return nil + }) + + watcher, err := mirrorObs.Watch(ctx) + if err != nil { + t.Fatalf("Error creating watcher: %v", err) + } + defer watcher.Stop() + + // expect to get one value and nil + for { + select { + case info := <-watcher.Updates(): + if info == nil { + return + } + case <-time.After(2 * time.Second): + t.Fatalf("Expected to receive an update") + } + } +} diff --git a/jetstream/test/ordered_test.go b/jetstream/test/ordered_test.go index c8b529f16..5a6231b2d 100644 --- a/jetstream/test/ordered_test.go +++ b/jetstream/test/ordered_test.go @@ -17,6 +17,7 @@ import ( "context" "errors" "fmt" + "reflect" "sync" "testing" "time" @@ -28,9 +29,9 @@ import ( func TestOrderedConsumerConsume(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -71,7 +72,7 @@ func TestOrderedConsumerConsume(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() name := c.CachedInfo().Name @@ -79,12 +80,110 @@ func TestOrderedConsumerConsume(t *testing.T) { t.Fatal(err) } wg.Add(len(testMsgs)) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() l.Stop() }) + t.Run("reset consumer before receiving any messages", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + wg := &sync.WaitGroup{} + l, err := c.Consume(func(msg jetstream.Msg) { + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(500 * time.Millisecond) + + name := c.CachedInfo().Name + if err := s.DeleteConsumer(ctx, name); err != nil { + t.Fatal(err) + } + wg.Add(len(testMsgs)) + publishTestMsgs(t, js) + wg.Wait() + + l.Stop() + }) + + t.Run("with custom start seq", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + 
publishTestMsgs(t, js) + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{DeliverPolicy: jetstream.DeliverByStartSequencePolicy, OptStartSeq: 3}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + wg := &sync.WaitGroup{} + wg.Add(len(testMsgs) - 2) + l, err := c.Consume(func(msg jetstream.Msg) { + wg.Done() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer l.Stop() + + wg.Wait() + + time.Sleep(500 * time.Millisecond) + // now delete consumer again and publish some more messages, all should be received normally + info, err := c.Info(ctx) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if err := s.DeleteConsumer(ctx, info.Config.Name); err != nil { + t.Fatal(err) + } + wg.Add(len(testMsgs)) + publishTestMsgs(t, js) + wg.Wait() + }) + t.Run("base usage, server shutdown", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) @@ -126,21 +225,13 @@ func TestOrderedConsumerConsume(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() srv = restartBasicJSServer(t, srv) defer shutdownJSServerAndRemoveStorage(t, srv) - select { - case err := <-errs: - if !errors.Is(err, jetstream.ErrConsumerNotFound) { - t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerNotFound, err) - } - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for error") - } wg.Add(len(testMsgs)) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() l.Stop() @@ -190,7 +281,7 @@ func TestOrderedConsumerConsume(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) select { case err := <-errs: if !errors.Is(err, jetstream.ErrNoHeartbeat) { @@ -202,7 +293,7 @@ func TestOrderedConsumerConsume(t *testing.T) { wg.Wait() wg.Add(len(testMsgs)) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() l.Stop() }) @@ -232,7 +323,7 @@ func TestOrderedConsumerConsume(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs, err := c.Fetch(5) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -435,7 +526,7 @@ func TestOrderedConsumerConsume(t *testing.T) { } wg := &sync.WaitGroup{} wg.Add(5) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) cc, err := c.Consume(func(msg jetstream.Msg) { time.Sleep(50 * time.Millisecond) msg.Ack() @@ -448,14 +539,182 @@ func TestOrderedConsumerConsume(t *testing.T) { cc.Drain() wg.Wait() }) + + t.Run("stop consume during reset", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + for i := 0; i < 10; i++ { + c, err := s.OrderedConsumer(context.Background(), jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + cc, err := c.Consume(func(msg jetstream.Msg) { + msg.Ack() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + 
t.Fatalf("Unexpected error: %v", err) + } + cc.Stop() + time.Sleep(50 * time.Millisecond) + } + }) + + t.Run("wait for closed after drain", func(t *testing.T) { + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + closed := cc.Closed() + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, js) + + // wait for the consumer to be recreated before calling drain + for i := 0; i < 5; i++ { + _, err = c.Info(ctx) + if err != nil { + if errors.Is(err, jetstream.ErrConsumerNotFound) { + time.Sleep(100 * time.Millisecond) + continue + } + t.Fatalf("Unexpected error: %v", err) + } + break + } + + cc.Drain() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + + if len(msgs) != 2*len(testMsgs) { + t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", 2*len(testMsgs), len(msgs)) + } + }) + } + }) + + t.Run("wait for closed on already closed consume", func(t *testing.T) { + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("run %d", i), func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + if err := s.DeleteConsumer(context.Background(), c.CachedInfo().Name); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + cc.Stop() + + time.Sleep(100 * time.Millisecond) + + select { + case <-cc.Closed(): + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be 
closed") + } + }) + } + }) } func TestOrderedConsumerMessages(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -492,7 +751,7 @@ func TestOrderedConsumerMessages(t *testing.T) { } defer it.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < 5; i++ { msg, err := it.Next() if err != nil { @@ -504,7 +763,7 @@ func TestOrderedConsumerMessages(t *testing.T) { if err := s.DeleteConsumer(ctx, name); err != nil { t.Fatal(err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < 5; i++ { msg, err := it.Next() if err != nil { @@ -549,7 +808,7 @@ func TestOrderedConsumerMessages(t *testing.T) { } defer it.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < 5; i++ { msg, err := it.Next() if err != nil { @@ -559,7 +818,7 @@ func TestOrderedConsumerMessages(t *testing.T) { } srv = restartBasicJSServer(t, srv) defer shutdownJSServerAndRemoveStorage(t, srv) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < 5; i++ { msg, err := it.Next() if err != nil { @@ -608,7 +867,7 @@ func TestOrderedConsumerMessages(t *testing.T) { } defer it.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < 5; i++ { msg, err := it.Next() if err != nil { @@ -616,7 +875,7 @@ func TestOrderedConsumerMessages(t *testing.T) { } msgs = append(msgs, msg) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < 5; i++ { msg, err := it.Next() if err != nil { @@ -816,7 +1075,7 @@ func TestOrderedConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs, err := c.Fetch(5) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -894,7 +1153,7 @@ func TestOrderedConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) go func() { time.Sleep(100 * time.Millisecond) it.Drain() @@ -922,9 +1181,9 @@ func TestOrderedConsumerMessages(t *testing.T) { func TestOrderedConsumerFetch(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -956,7 +1215,7 @@ func TestOrderedConsumerFetch(t *testing.T) { msgs := make([]jetstream.Msg, 0) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err := c.Fetch(5) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -972,7 +1231,7 @@ func TestOrderedConsumerFetch(t *testing.T) { if err := s.DeleteConsumer(ctx, name); err != nil { t.Fatal(err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err = c.Fetch(5) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -989,6 +1248,71 @@ func TestOrderedConsumerFetch(t *testing.T) { } }) + t.Run("with custom deliver policy", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) 
+ nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + + for i := 0; i < 5; i++ { + if _, err := js.Publish(context.Background(), "FOO.A", []byte("msg")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + for i := 0; i < 5; i++ { + if _, err := js.Publish(context.Background(), "FOO.B", []byte("msg")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + } + + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{ + DeliverPolicy: jetstream.DeliverLastPerSubjectPolicy, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + res, err := c.Fetch(int(c.CachedInfo().NumPending), jetstream.FetchMaxWait(1*time.Second)) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + for msg := range res.Messages() { + msgs = append(msgs, msg) + } + + if res.Error() != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if len(msgs) != 2 { + t.Fatalf("Expected %d messages; got: %d", 2, len(msgs)) + } + expectedSubjects := []string{"FOO.A", "FOO.B"} + + for i := range msgs { + if msgs[i].Subject() != expectedSubjects[i] { + t.Fatalf("Expected subject: %s; got: %s", expectedSubjects[i], msgs[i].Subject()) + } + } + }) + t.Run("consumer used as consume", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) @@ -1050,7 +1374,7 @@ func TestOrderedConsumerFetch(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err := c.Fetch(1, jetstream.FetchMaxWait(100*time.Millisecond)) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1068,9 +1392,9 @@ func TestOrderedConsumerFetch(t *testing.T) { func TestOrderedConsumerFetchBytes(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -1102,7 +1426,7 @@ func TestOrderedConsumerFetchBytes(t *testing.T) { msgs := make([]jetstream.Msg, 0) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err := c.FetchBytes(500, jetstream.FetchMaxWait(100*time.Millisecond)) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1118,7 +1442,7 @@ func TestOrderedConsumerFetchBytes(t *testing.T) { if err := s.DeleteConsumer(ctx, name); err != nil { t.Fatal(err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err = c.Fetch(500, jetstream.FetchMaxWait(100*time.Millisecond)) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1196,7 +1520,7 @@ func TestOrderedConsumerFetchBytes(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err := c.FetchBytes(500, jetstream.FetchMaxWait(100*time.Millisecond)) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1214,9 +1538,9 @@ func 
TestOrderedConsumerFetchBytes(t *testing.T) { func TestOrderedConsumerNext(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -1246,22 +1570,20 @@ func TestOrderedConsumerNext(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) - msg, err := c.Next() + publishTestMsgs(t, js) + _, err = c.Next() if err != nil { t.Fatalf("Unexpected error: %s", err) } - msg.Ack() name := c.CachedInfo().Name if err := s.DeleteConsumer(ctx, name); err != nil { t.Fatal(err) } - msg, err = c.Next() + _, err = c.Next() if err != nil { t.Fatalf("Unexpected error: %s", err) } - msg.Ack() }) t.Run("consumer used as consume", func(t *testing.T) { @@ -1299,14 +1621,78 @@ func TestOrderedConsumerNext(t *testing.T) { t.Fatalf("Expected error: %s; got: %s", jetstream.ErrOrderConsumerUsedAsConsume, err) } }) + + t.Run("preserve sequence after fetch error", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.OrderedConsumer(ctx, jetstream.OrderedConsumerConfig{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + msg, err := c.Next() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + meta, err := msg.Metadata() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if meta.Sequence.Stream != 1 { + t.Fatalf("Expected sequence: %d; got: %d", 1, meta.Sequence.Stream) + } + + // get next message, it should time out (no more messages on stream) + _, err = c.Next(jetstream.FetchMaxWait(100 * time.Millisecond)) + if !errors.Is(err, nats.ErrTimeout) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrTimeout, err) + } + + if _, err := js.Publish(ctx, "FOO.A", []byte("msg")); err != nil { + t.Fatalf("Unexpected error during publish: %s", err) + } + + // get next message, it should have stream sequence 2 + msg, err = c.Next() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + meta, err = msg.Metadata() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if meta.Sequence.Stream != 2 { + t.Fatalf("Expected sequence: %d; got: %d", 2, meta.Sequence.Stream) + } + }) } func TestOrderedConsumerFetchNoWait(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { 
t.Fatalf("Unexpected error during publish: %s", err) } } @@ -1338,7 +1724,7 @@ func TestOrderedConsumerFetchNoWait(t *testing.T) { msgs := make([]jetstream.Msg, 0) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err := c.FetchNoWait(5) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1354,7 +1740,7 @@ func TestOrderedConsumerFetchNoWait(t *testing.T) { if err := s.DeleteConsumer(ctx, name); err != nil { t.Fatal(err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) res, err = c.FetchNoWait(5) if err != nil { t.Fatalf("Unexpected error: %s", err) @@ -1548,3 +1934,188 @@ func TestOrderedConsumerNextOrder(t *testing.T) { } } } + +func TestOrderedConsumerConfig(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + s, err := js.CreateStream(context.Background(), jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + tests := []struct { + name string + config jetstream.OrderedConsumerConfig + expected jetstream.ConsumerConfig + }{ + { + name: "default config", + config: jetstream.OrderedConsumerConfig{}, + expected: jetstream.ConsumerConfig{ + DeliverPolicy: jetstream.DeliverAllPolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 5 * time.Minute, + Replicas: 1, + MemoryStorage: true, + }, + }, + { + name: "custom inactive threshold", + config: jetstream.OrderedConsumerConfig{ + InactiveThreshold: 10 * time.Second, + }, + expected: jetstream.ConsumerConfig{ + DeliverPolicy: jetstream.DeliverAllPolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 10 * time.Second, + Replicas: 1, + MemoryStorage: true, + }, + }, + { + name: "custom opt start seq and inactive threshold", + config: jetstream.OrderedConsumerConfig{ + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + OptStartSeq: 10, + InactiveThreshold: 10 * time.Second, + }, + expected: jetstream.ConsumerConfig{ + OptStartSeq: 10, + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 10 * time.Second, + Replicas: 1, + MemoryStorage: true, + }, + }, + { + name: "all fields customized, start with custom seq", + config: jetstream.OrderedConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + OptStartSeq: 10, + ReplayPolicy: jetstream.ReplayOriginalPolicy, + InactiveThreshold: 10 * time.Second, + HeadersOnly: true, + }, + expected: jetstream.ConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + OptStartSeq: 10, + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 10 * time.Second, + Replicas: 1, + MemoryStorage: true, + HeadersOnly: true, + }, + }, + { + name: "all fields customized, start with custom time", + config: jetstream.OrderedConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + DeliverPolicy: jetstream.DeliverByStartTimePolicy, + OptStartTime: &time.Time{}, + ReplayPolicy: jetstream.ReplayOriginalPolicy, + InactiveThreshold: 10 * time.Second, + HeadersOnly: true, + }, + expected: 
jetstream.ConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + OptStartTime: &time.Time{}, + DeliverPolicy: jetstream.DeliverByStartTimePolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 10 * time.Second, + Replicas: 1, + MemoryStorage: true, + HeadersOnly: true, + }, + }, + { + name: "both start seq and time set, deliver policy start seq", + config: jetstream.OrderedConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + OptStartSeq: 10, + OptStartTime: &time.Time{}, + ReplayPolicy: jetstream.ReplayOriginalPolicy, + InactiveThreshold: 10 * time.Second, + HeadersOnly: true, + }, + expected: jetstream.ConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + OptStartSeq: 10, + OptStartTime: nil, + DeliverPolicy: jetstream.DeliverByStartSequencePolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 10 * time.Second, + Replicas: 1, + MemoryStorage: true, + HeadersOnly: true, + }, + }, + { + name: "both start seq and time set, deliver policy start time", + config: jetstream.OrderedConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + DeliverPolicy: jetstream.DeliverByStartTimePolicy, + OptStartSeq: 10, + OptStartTime: &time.Time{}, + ReplayPolicy: jetstream.ReplayOriginalPolicy, + InactiveThreshold: 10 * time.Second, + HeadersOnly: true, + }, + expected: jetstream.ConsumerConfig{ + FilterSubjects: []string{"foo.a", "foo.b"}, + OptStartSeq: 0, + OptStartTime: &time.Time{}, + DeliverPolicy: jetstream.DeliverByStartTimePolicy, + AckPolicy: jetstream.AckNonePolicy, + MaxDeliver: -1, + MaxWaiting: 512, + InactiveThreshold: 10 * time.Second, + Replicas: 1, + MemoryStorage: true, + HeadersOnly: true, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c, err := s.OrderedConsumer(context.Background(), test.config) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + cfg := c.CachedInfo().Config + test.expected.Name = cfg.Name + + if !reflect.DeepEqual(test.expected, cfg) { + t.Fatalf("Expected config %+v, got %+v", test.expected, cfg) + } + }) + } +} diff --git a/jetstream/test/publish_test.go b/jetstream/test/publish_test.go index a7ba8c31a..f79ad19fb 100644 --- a/jetstream/test/publish_test.go +++ b/jetstream/test/publish_test.go @@ -16,8 +16,10 @@ package test import ( "context" "errors" + "fmt" "os" "reflect" + "sync" "testing" "time" @@ -1330,7 +1332,6 @@ func TestPublishMsgAsyncWithPendingMsgs(t *testing.T) { func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { s := RunBasicJetStreamServer() - defer shutdownJSServerAndRemoveStorage(t, s) nc, err := nats.Connect(s.ClientURL()) if err != nil { @@ -1352,6 +1353,7 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { errs := make(chan error, 1) done := make(chan struct{}, 1) acks := make(chan jetstream.PubAckFuture, 100) + wg := sync.WaitGroup{} go func() { for i := 0; i < 100; i++ { if ack, err := js.PublishAsync("FOO.A", []byte("hello")); err != nil { @@ -1360,6 +1362,7 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { } else { acks <- ack } + wg.Add(1) } close(acks) done <- struct{}{} @@ -1371,28 +1374,32 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { case <-time.After(5 * time.Second): t.Fatalf("Did not receive completion signal") } - s.Shutdown() - time.Sleep(100 * time.Millisecond) - if pending := js.PublishAsyncPending(); pending != 0 { - 
t.Fatalf("Expected no pending messages after server shutdown; got: %d", pending) + for ack := range acks { + go func(paf jetstream.PubAckFuture) { + select { + case <-paf.Ok(): + case err := <-paf.Err(): + if !errors.Is(err, nats.ErrDisconnected) && !errors.Is(err, nats.ErrNoResponders) { + errs <- fmt.Errorf("Expected error: %v or %v; got: %v", nats.ErrDisconnected, nats.ErrNoResponders, err) + } + case <-time.After(5 * time.Second): + errs <- errors.New("Did not receive completion signal") + } + wg.Done() + }(ack) } - s = RunBasicJetStreamServer() + s = restartBasicJSServer(t, s) defer shutdownJSServerAndRemoveStorage(t, s) - for ack := range acks { - select { - case <-ack.Ok(): - case err := <-ack.Err(): - if !errors.Is(err, nats.ErrDisconnected) && !errors.Is(err, nats.ErrNoResponders) { - t.Fatalf("Expected error: %v or %v; got: %v", nats.ErrDisconnected, nats.ErrNoResponders, err) - } - case <-time.After(5 * time.Second): - t.Fatalf("Did not receive completion signal") - } + wg.Wait() + select { + case err := <-errs: + t.Fatalf("Unexpected error: %v", err) + default: } } -func TestAsyncPublishRetry(t *testing.T) { +func TestPublishAsyncRetry(t *testing.T) { tests := []struct { name string pubOpts []jetstream.PublishOpt @@ -1446,6 +1453,7 @@ func TestAsyncPublishRetry(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } + publishComplete := js.PublishAsyncComplete() errs := make(chan error, 1) go func() { // create stream with delay so that publish will receive no responders @@ -1469,6 +1477,78 @@ func TestAsyncPublishRetry(t *testing.T) { case <-time.After(5 * time.Second): t.Fatalf("Timeout waiting for ack") } + + select { + case <-publishComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } }) } } + +func TestPublishAsyncRetryInErrHandler(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + streamCreated := make(chan struct{}) + errCB := func(js jetstream.JetStream, m *nats.Msg, e error) { + <-streamCreated + _, err := js.PublishMsgAsync(m) + if err != nil { + t.Fatalf("Unexpected error when republishing: %v", err) + } + } + + js, err := jetstream.New(nc, jetstream.WithPublishAsyncErrHandler(errCB)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + errs := make(chan error, 1) + done := make(chan struct{}, 1) + go func() { + for i := 0; i < 10; i++ { + if _, err := js.PublishAsync("FOO.A", []byte("hello"), jetstream.WithRetryAttempts(0)); err != nil { + errs <- err + return + } + } + done <- struct{}{} + }() + select { + case <-done: + case err := <-errs: + t.Fatalf("Unexpected error during publish: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + stream, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + close(streamCreated) + select { + case <-js.PublishAsyncComplete(): + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + info, err := stream.Info(context.Background()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if info.State.Msgs != 10 { + t.Fatalf("Expected 10 messages in the stream; got: %d", info.State.Msgs) + } +} diff 
--git a/jetstream/test/pull_test.go b/jetstream/test/pull_test.go index b46697e7a..4042e52f5 100644 --- a/jetstream/test/pull_test.go +++ b/jetstream/test/pull_test.go @@ -28,9 +28,9 @@ import ( func TestPullConsumerFetch(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -61,7 +61,7 @@ func TestPullConsumerFetch(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs, err := c.Fetch(5) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -107,7 +107,7 @@ func TestPullConsumerFetch(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs, err := c.Fetch(10) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -184,7 +184,7 @@ func TestPullConsumerFetch(t *testing.T) { }() time.Sleep(10 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) select { case err := <-errs: t.Fatalf("Unexpected error: %v", err) @@ -230,7 +230,7 @@ func TestPullConsumerFetch(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } time.Sleep(100 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msg := <-msgs.Messages() if msg != nil { @@ -263,14 +263,14 @@ func TestPullConsumerFetch(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) time.Sleep(50 * time.Millisecond) msgs, err := c.FetchNoWait(10) if err != nil { t.Fatalf("Unexpected error: %v", err) } time.Sleep(100 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) var msgsNum int for range msgs.Messages() { @@ -376,7 +376,7 @@ func TestPullConsumerFetch(t *testing.T) { if err != nil { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) // fetch 5 messages, should return normally msgs, err := c.Fetch(5, jetstream.FetchHeartbeat(50*time.Millisecond)) if err != nil { @@ -480,14 +480,13 @@ func TestPullConsumerFetch(t *testing.T) { func TestPullConsumerFetchBytes(t *testing.T) { testSubject := "FOO.123" msg := [10]byte{} - publishTestMsgs := func(t *testing.T, nc *nats.Conn, count int) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream, count int) { for i := 0; i < count; i++ { - if err := nc.Publish(testSubject, msg[:]); err != nil { + if _, err := js.Publish(context.Background(), testSubject, msg[:]); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } } - t.Run("no options, exact byte count received", func(t *testing.T) { srv := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, srv) @@ -513,7 +512,7 @@ func TestPullConsumerFetchBytes(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc, 5) + publishTestMsgs(t, js, 5) // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) msgs, err := c.FetchBytes(300) if err != nil { @@ -558,7 +557,7 @@ func TestPullConsumerFetchBytes(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc, 5) + publishTestMsgs(t, js, 5) // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) msgs, err := c.FetchBytes(250) if err != nil { @@ -602,7 
+601,7 @@ func TestPullConsumerFetchBytes(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc, 5) + publishTestMsgs(t, js, 5) // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) msgs, err := c.FetchBytes(30) if err != nil { @@ -647,7 +646,7 @@ func TestPullConsumerFetchBytes(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc, 5) + publishTestMsgs(t, js, 5) // actual received msg size will be 60 (payload=10 + Subject=7 + Reply=43) msgs, err := c.FetchBytes(1000, jetstream.FetchMaxWait(50*time.Millisecond)) if err != nil { @@ -779,9 +778,9 @@ func TestPullConsumerFetchBytes(t *testing.T) { func TestPullConsumerFetch_WithCluster(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -819,7 +818,7 @@ func TestPullConsumerFetch_WithCluster(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs, err := c.Fetch(5) if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -867,7 +866,7 @@ func TestPullConsumerFetch_WithCluster(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } time.Sleep(100 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msg := <-msgs.Messages() if msg != nil { @@ -880,9 +879,9 @@ func TestPullConsumerFetch_WithCluster(t *testing.T) { func TestPullConsumerMessages(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -919,7 +918,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -981,7 +980,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -1042,7 +1041,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -1110,7 +1109,7 @@ func TestPullConsumerMessages(t *testing.T) { } defer it.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -1133,7 +1132,7 @@ func TestPullConsumerMessages(t *testing.T) { if !errors.Is(err, jetstream.ErrConsumerDeleted) { t.Fatalf("Expected error: %v; got: %v", jetstream.ErrConsumerDeleted, err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) time.Sleep(50 * time.Millisecond) _, err = it.Next() if !errors.Is(err, jetstream.ErrMsgIteratorClosed) { @@ -1180,7 +1179,7 @@ func TestPullConsumerMessages(t *testing.T) { 
t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -1249,7 +1248,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -1490,7 +1489,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) for i := 0; i < len(testMsgs); i++ { msg, err := it.Next() if err != nil { @@ -1506,7 +1505,7 @@ func TestPullConsumerMessages(t *testing.T) { it.Stop() time.Sleep(10 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) it, err = c.Messages() if err != nil { t.Fatalf("Unexpected error: %v", err) @@ -1599,7 +1598,7 @@ func TestPullConsumerMessages(t *testing.T) { done := make(chan struct{}) errs := make(chan error) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) go func() { for i := 0; i < 2*len(testMsgs); i++ { msg, err := it.Next() @@ -1617,7 +1616,7 @@ func TestPullConsumerMessages(t *testing.T) { srv = restartBasicJSServer(t, srv) defer shutdownJSServerAndRemoveStorage(t, srv) time.Sleep(10 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) select { case <-done: @@ -1667,7 +1666,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) errs := make(chan error) msgs := make([]jetstream.Msg, 0) @@ -1737,7 +1736,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) go func() { time.Sleep(100 * time.Millisecond) it.Stop() @@ -1792,7 +1791,7 @@ func TestPullConsumerMessages(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) go func() { time.Sleep(100 * time.Millisecond) it.Drain() @@ -1820,9 +1819,9 @@ func TestPullConsumerMessages(t *testing.T) { func TestPullConsumerConsume(t *testing.T) { testSubject := "FOO.123" testMsgs := []string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -1865,7 +1864,7 @@ func TestPullConsumerConsume(t *testing.T) { } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) @@ -1924,7 +1923,7 @@ func TestPullConsumerConsume(t *testing.T) { defer l2.Stop() wg.Add(len(testMsgs)) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs1)+len(msgs2) != len(testMsgs) { @@ -1974,7 +1973,7 @@ func TestPullConsumerConsume(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() l.Stop() @@ -1991,7 +1990,7 @@ func TestPullConsumerConsume(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != 2*len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) @@ -2041,7 +2040,7 @@ func 
TestPullConsumerConsume(t *testing.T) { } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != len(testMsgs) { @@ -2091,7 +2090,7 @@ func TestPullConsumerConsume(t *testing.T) { } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != len(testMsgs) { @@ -2144,7 +2143,7 @@ func TestPullConsumerConsume(t *testing.T) { } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) @@ -2160,7 +2159,7 @@ func TestPullConsumerConsume(t *testing.T) { case <-time.After(5 * time.Second): t.Fatalf("Timeout waiting for %v", jetstream.ErrConsumerDeleted) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) time.Sleep(50 * time.Millisecond) if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) @@ -2197,7 +2196,7 @@ func TestPullConsumerConsume(t *testing.T) { t.Fatalf("Error on subscribe: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs := make([]jetstream.Msg, 0) wg := &sync.WaitGroup{} wg.Add(len(testMsgs)) @@ -2359,7 +2358,7 @@ func TestPullConsumerConsume(t *testing.T) { } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != len(testMsgs) { @@ -2441,7 +2440,7 @@ func TestPullConsumerConsume(t *testing.T) { } defer l.Stop() - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() if len(msgs) != len(testMsgs) { t.Fatalf("Unexpected received message count; want %d; got %d", len(testMsgs), len(msgs)) @@ -2480,7 +2479,7 @@ func TestPullConsumerConsume(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2 * len(testMsgs)) msgs := make([]jetstream.Msg, 0) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) l, err := c.Consume(func(msg jetstream.Msg) { msgs = append(msgs, msg) wg.Done() @@ -2494,7 +2493,7 @@ func TestPullConsumerConsume(t *testing.T) { srv = restartBasicJSServer(t, srv) defer shutdownJSServerAndRemoveStorage(t, srv) time.Sleep(10 * time.Millisecond) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) wg.Wait() }) @@ -2524,7 +2523,7 @@ func TestPullConsumerConsume(t *testing.T) { } wg := &sync.WaitGroup{} wg.Add(2) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs := make([]jetstream.Msg, 0) cc, err := c.Consume(func(msg jetstream.Msg) { time.Sleep(80 * time.Millisecond) @@ -2572,7 +2571,7 @@ func TestPullConsumerConsume(t *testing.T) { } wg := &sync.WaitGroup{} wg.Add(5) - publishTestMsgs(t, nc) + publishTestMsgs(t, js) cc, err := c.Consume(func(msg jetstream.Msg) { time.Sleep(50 * time.Millisecond) msg.Ack() @@ -2585,6 +2584,154 @@ func TestPullConsumerConsume(t *testing.T) { cc.Drain() wg.Wait() }) + + t.Run("wait for closed after drain", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", 
err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + closed := cc.Closed() + time.Sleep(100 * time.Millisecond) + + cc.Drain() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + + if len(msgs) != len(testMsgs) { + t.Fatalf("Unexpected received message count after consume closed; want %d; got %d", len(testMsgs), len(msgs)) + } + }) + + t.Run("wait for closed after stop", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + msgs := make([]jetstream.Msg, 0) + lock := sync.Mutex{} + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + lock.Lock() + msgs = append(msgs, msg) + lock.Unlock() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + closed := cc.Closed() + + cc.Stop() + + select { + case <-closed: + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + + if len(msgs) < 1 || len(msgs) > 3 { + t.Fatalf("Unexpected received message count after consume closed; want 1-3; got %d", len(msgs)) + } + }) + + t.Run("wait for closed on already closed consume", func(t *testing.T) { + srv := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, srv) + nc, err := nats.Connect(srv.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + js, err := jetstream.New(nc) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + s, err := js.CreateStream(ctx, jetstream.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + c, err := s.CreateOrUpdateConsumer(ctx, jetstream.ConsumerConfig{AckPolicy: jetstream.AckExplicitPolicy}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + publishTestMsgs(t, js) + cc, err := c.Consume(func(msg jetstream.Msg) { + time.Sleep(50 * time.Millisecond) + msg.Ack() + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + time.Sleep(100 * time.Millisecond) + + cc.Stop() + + time.Sleep(100 * time.Millisecond) + + select { + case <-cc.Closed(): + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for consume to be closed") + } + }) } func TestPullConsumerConsume_WithCluster(t *testing.T) { @@ -2809,9 +2956,9 @@ func TestPullConsumerConsume_WithCluster(t *testing.T) { func TestPullConsumerNext(t *testing.T) { testSubject := "FOO.123" testMsgs := 
[]string{"m1", "m2", "m3", "m4", "m5"} - publishTestMsgs := func(t *testing.T, nc *nats.Conn) { + publishTestMsgs := func(t *testing.T, js jetstream.JetStream) { for _, msg := range testMsgs { - if err := nc.Publish(testSubject, []byte(msg)); err != nil { + if _, err := js.Publish(context.Background(), testSubject, []byte(msg)); err != nil { t.Fatalf("Unexpected error during publish: %s", err) } } @@ -2842,7 +2989,7 @@ func TestPullConsumerNext(t *testing.T) { t.Fatalf("Unexpected error: %v", err) } - publishTestMsgs(t, nc) + publishTestMsgs(t, js) msgs := make([]jetstream.Msg, 0) var i int diff --git a/jetstream/test/stream_test.go b/jetstream/test/stream_test.go index bac8ac277..632e9b3d2 100644 --- a/jetstream/test/stream_test.go +++ b/jetstream/test/stream_test.go @@ -193,6 +193,16 @@ func TestCreateConsumer(t *testing.T) { consumerConfig: jetstream.ConsumerConfig{FilterSubjects: []string{"FOO.A", ""}}, withError: jetstream.ErrEmptyFilter, }, + { + name: "with invalid filter subject, leading dot", + consumerConfig: jetstream.ConsumerConfig{FilterSubject: ".foo"}, + withError: jetstream.ErrInvalidSubject, + }, + { + name: "with invalid filter subject, trailing dot", + consumerConfig: jetstream.ConsumerConfig{FilterSubject: "foo."}, + withError: jetstream.ErrInvalidSubject, + }, { name: "consumer already exists, error", consumerConfig: jetstream.ConsumerConfig{Durable: "dur", Description: "test consumer"}, diff --git a/js.go b/js.go index 462fea17e..e024fae0a 100644 --- a/js.go +++ b/js.go @@ -1,4 +1,4 @@ -// Copyright 2020-2023 The NATS Authors +// Copyright 2020-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -58,6 +58,19 @@ type JetStream interface { // PublishAsyncComplete returns a channel that will be closed when all outstanding messages are ack'd. PublishAsyncComplete() <-chan struct{} + // CleanupPublisher will cleanup the publishing side of JetStreamContext. + // + // This will unsubscribe from the internal reply subject if needed. + // All pending async publishes will fail with ErrJetStreamPublisherClosed. + // + // If an error handler was provided, it will be called for each pending async + // publish and PublishAsyncComplete will be closed. + // + // After completing JetStreamContext is still usable - internal subscription + // will be recreated on next publish, but the acks from previous publishes will + // be lost. + CleanupPublisher() + // Subscribe creates an async Subscription for JetStream. // The stream and consumer names can be provided with the nats.Bind() option. // For creating an ephemeral (where the consumer name is picked by the server), @@ -466,6 +479,9 @@ type pubOpts struct { // stallWait is the max wait of a async pub ack. stallWait time.Duration + + // internal option to re-use existing paf in case of retry. + pafRetry *pubAckFuture } // pubAckResponse is the ack response from the JetStream API when publishing a message. 
@@ -531,7 +547,7 @@ func (js *js) PublishMsg(m *Msg, opts ...PubOpt) (*PubAck, error) { o.ttl = js.opts.wait } if o.stallWait > 0 { - return nil, fmt.Errorf("nats: stall wait cannot be set to sync publish") + return nil, errors.New("nats: stall wait cannot be set to sync publish") } if o.id != _EMPTY_ { @@ -620,13 +636,17 @@ type PubAckFuture interface { } type pubAckFuture struct { - js *js - msg *Msg - pa *PubAck - st time.Time - err error - errCh chan error - doneCh chan *PubAck + js *js + msg *Msg + pa *PubAck + st time.Time + err error + errCh chan error + doneCh chan *PubAck + retries int + maxRetries int + retryWait time.Duration + reply string } func (paf *pubAckFuture) Ok() <-chan *PubAck { @@ -712,10 +732,17 @@ func (js *js) resetPendingAcksOnReconnect() { return } js.mu.Lock() - for _, paf := range js.pafs { + errCb := js.opts.aecb + for id, paf := range js.pafs { paf.err = ErrDisconnected + if paf.errCh != nil { + paf.errCh <- paf.err + } + if errCb != nil { + defer errCb(js, paf.msg, ErrDisconnected) + } + delete(js.pafs, id) } - js.pafs = nil if js.dch != nil { close(js.dch) js.dch = nil @@ -724,6 +751,38 @@ func (js *js) resetPendingAcksOnReconnect() { } } +// CleanupPublisher will cleanup the publishing side of JetStreamContext. +// +// This will unsubscribe from the internal reply subject if needed. +// All pending async publishes will fail with ErrJetStreamPublisherClosed. +// +// If an error handler was provided, it will be called for each pending async +// publish and PublishAsyncComplete will be closed. +// +// After completing JetStreamContext is still usable - internal subscription +// will be recreated on next publish, but the acks from previous publishes will +// be lost. +func (js *js) CleanupPublisher() { + js.cleanupReplySub() + js.mu.Lock() + errCb := js.opts.aecb + for id, paf := range js.pafs { + paf.err = ErrJetStreamPublisherClosed + if paf.errCh != nil { + paf.errCh <- paf.err + } + if errCb != nil { + defer errCb(js, paf.msg, ErrJetStreamPublisherClosed) + } + delete(js.pafs, id) + } + if js.dch != nil { + close(js.dch) + js.dch = nil + } + js.mu.Unlock() +} + func (js *js) cleanupReplySub() { js.mu.Lock() if js.rsub != nil { @@ -796,20 +855,30 @@ func (js *js) handleAsyncReply(m *Msg) { js.mu.Unlock() return } - // Remove - delete(js.pafs, id) - // Check on anyone stalled and waiting. - if js.stc != nil && len(js.pafs) < js.opts.maxpa { - close(js.stc) - js.stc = nil + closeStc := func() { + // Check on anyone stalled and waiting. + if js.stc != nil && len(js.pafs) < js.opts.maxpa { + close(js.stc) + js.stc = nil + } } - // Check on anyone one waiting on done status. - if js.dch != nil && len(js.pafs) == 0 { - dch := js.dch - js.dch = nil - // Defer here so error is processed and can be checked. - defer close(dch) + + closeDchFn := func() func() { + var dch chan struct{} + // Check on anyone waiting on done status. + if js.dch != nil && len(js.pafs) == 0 { + dch = js.dch + js.dch = nil + } + // Return function to close done channel which + // should be deferred so that error is processed and + // can be checked. + return func() { + if dch != nil { + close(dch) + } + } } doErr := func(err error) { @@ -826,10 +895,39 @@ func (js *js) handleAsyncReply(m *Msg) { // Process no responders etc.
if len(m.Data) == 0 && m.Header.Get(statusHdr) == noResponders { + if paf.retries < paf.maxRetries { + paf.retries++ + time.AfterFunc(paf.retryWait, func() { + js.mu.Lock() + paf := js.getPAF(id) + js.mu.Unlock() + if paf == nil { + return + } + _, err := js.PublishMsgAsync(paf.msg, pubOptFn(func(po *pubOpts) error { + po.pafRetry = paf + return nil + })) + if err != nil { + js.mu.Lock() + doErr(err) + } + }) + js.mu.Unlock() + return + } + delete(js.pafs, id) + closeStc() + defer closeDchFn()() doErr(ErrNoResponders) return } + //remove + delete(js.pafs, id) + closeStc() + defer closeDchFn()() + var pa pubAckResponse if err := json.Unmarshal(m.Data, &pa); err != nil { doErr(ErrInvalidJSAck) @@ -896,6 +994,10 @@ func (js *js) PublishMsgAsync(m *Msg, opts ...PubOpt) (PubAckFuture, error) { } } + if o.rnum < 0 { + return nil, fmt.Errorf("%w: retry attempts cannot be negative", ErrInvalidArg) + } + // Timeouts and contexts do not make sense for these. if o.ttl != 0 || o.ctx != nil { return nil, ErrContextAndTimeout @@ -923,30 +1025,42 @@ func (js *js) PublishMsgAsync(m *Msg, opts ...PubOpt) (PubAckFuture, error) { } // Reply - if m.Reply != _EMPTY_ { + paf := o.pafRetry + if paf == nil && m.Reply != _EMPTY_ { return nil, errors.New("nats: reply subject should be empty") } - reply := m.Reply - m.Reply = js.newAsyncReply() - defer func() { m.Reply = reply }() + var id string + var reply string - if m.Reply == _EMPTY_ { - return nil, errors.New("nats: error creating async reply handler") - } + // register new paf if not retrying + if paf == nil { + reply = js.newAsyncReply() - id := m.Reply[js.replyPrefixLen:] - paf := &pubAckFuture{msg: m, st: time.Now()} - numPending, maxPending := js.registerPAF(id, paf) + if reply == _EMPTY_ { + return nil, errors.New("nats: error creating async reply handler") + } + + id = reply[js.replyPrefixLen:] + paf = &pubAckFuture{msg: m, st: time.Now(), maxRetries: o.rnum, retryWait: o.rwait, reply: reply} + numPending, maxPending := js.registerPAF(id, paf) - if maxPending > 0 && numPending >= maxPending { - select { - case <-js.asyncStall(): - case <-time.After(stallWait): - js.clearPAF(id) - return nil, errors.New("nats: stalled with too many outstanding async published messages") + if maxPending > 0 && numPending > maxPending { + select { + case <-js.asyncStall(): + case <-time.After(stallWait): + js.clearPAF(id) + return nil, errors.New("nats: stalled with too many outstanding async published messages") + } } + } else { + reply = paf.reply + id = reply[js.replyPrefixLen:] + } + hdr, err := m.headerBytes() + if err != nil { + return nil, err } - if err := js.nc.PublishMsg(m); err != nil { + if err := js.nc.publish(m.Subject, reply, hdr, m.Data); err != nil { js.clearPAF(id) return nil, err } @@ -1029,7 +1143,7 @@ func RetryAttempts(num int) PubOpt { func StallWait(ttl time.Duration) PubOpt { return pubOptFn(func(opts *pubOpts) error { if ttl <= 0 { - return fmt.Errorf("nats: stall wait should be more than 0") + return errors.New("nats: stall wait should be more than 0") } opts.stallWait = ttl return nil @@ -1387,11 +1501,11 @@ func processConsInfo(info *ConsumerInfo, userCfg *ConsumerConfig, isPullMode boo // Prevent an user from attempting to create a queue subscription on // a JS consumer that was not created with a deliver group. 
if queue != _EMPTY_ { - return _EMPTY_, fmt.Errorf("cannot create a queue subscription for a consumer without a deliver group") + return _EMPTY_, errors.New("cannot create a queue subscription for a consumer without a deliver group") } else if info.PushBound { // Need to reject a non queue subscription to a non queue consumer // if the consumer is already bound. - return _EMPTY_, fmt.Errorf("consumer is already bound to a subscription") + return _EMPTY_, errors.New("consumer is already bound to a subscription") } } else { // If the JS consumer has a deliver group, we need to fail a non queue @@ -1493,7 +1607,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, // If no stream name is specified, the subject cannot be empty. if subj == _EMPTY_ && o.stream == _EMPTY_ { - return nil, fmt.Errorf("nats: subject required") + return nil, errors.New("nats: subject required") } // Note that these may change based on the consumer info response we may get. @@ -1515,7 +1629,7 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, // would subscribe to and server would send on. if o.cfg.Heartbeat > 0 || o.cfg.FlowControl { // Not making this a public ErrXXX in case we allow in the future. - return nil, fmt.Errorf("nats: queue subscription doesn't support idle heartbeat nor flow control") + return nil, errors.New("nats: queue subscription doesn't support idle heartbeat nor flow control") } // If this is a queue subscription and no consumer nor durable name was specified, @@ -1553,31 +1667,31 @@ func (js *js) subscribe(subj, queue string, cb MsgHandler, ch chan *Msg, isSync, if o.ordered { // Make sure we are not durable. if isDurable { - return nil, fmt.Errorf("nats: durable can not be set for an ordered consumer") + return nil, errors.New("nats: durable can not be set for an ordered consumer") } // Check ack policy. if o.cfg.AckPolicy != ackPolicyNotSet { - return nil, fmt.Errorf("nats: ack policy can not be set for an ordered consumer") + return nil, errors.New("nats: ack policy can not be set for an ordered consumer") } // Check max deliver. if o.cfg.MaxDeliver != 1 && o.cfg.MaxDeliver != 0 { - return nil, fmt.Errorf("nats: max deliver can not be set for an ordered consumer") + return nil, errors.New("nats: max deliver can not be set for an ordered consumer") } // No deliver subject, we pick our own. if o.cfg.DeliverSubject != _EMPTY_ { - return nil, fmt.Errorf("nats: deliver subject can not be set for an ordered consumer") + return nil, errors.New("nats: deliver subject can not be set for an ordered consumer") } // Queue groups not allowed. if queue != _EMPTY_ { - return nil, fmt.Errorf("nats: queues not be set for an ordered consumer") + return nil, errors.New("nats: queues not be set for an ordered consumer") } // Check for bound consumers. if consumer != _EMPTY_ { - return nil, fmt.Errorf("nats: can not bind existing consumer for an ordered consumer") + return nil, errors.New("nats: can not bind existing consumer for an ordered consumer") } // Check for pull mode. if isPullMode { - return nil, fmt.Errorf("nats: can not use pull mode for an ordered consumer") + return nil, errors.New("nats: can not use pull mode for an ordered consumer") } // Setup how we need it to be here. 
o.cfg.FlowControl = true @@ -2311,7 +2425,7 @@ func Description(description string) SubOpt { func Durable(consumer string) SubOpt { return subOptFn(func(opts *subOpts) error { if opts.cfg.Durable != _EMPTY_ { - return fmt.Errorf("nats: option Durable set more than once") + return errors.New("nats: option Durable set more than once") } if opts.consumer != _EMPTY_ && opts.consumer != consumer { return fmt.Errorf("nats: duplicate consumer names (%s and %s)", opts.consumer, consumer) @@ -2861,7 +2975,14 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } var hbTimer *time.Timer var hbErr error - if err == nil && len(msgs) < batch { + sub.mu.Lock() + subClosed := sub.closed || sub.draining + sub.mu.Unlock() + if subClosed { + err = errors.Join(ErrBadSubscription, ErrSubscriptionClosed) + } + hbLock := sync.Mutex{} + if err == nil && len(msgs) < batch && !subClosed { // For batch real size of 1, it does not make sense to set no_wait in // the request. noWait := batch-len(msgs) > 1 @@ -2882,10 +3003,11 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } // Make our request expiration a bit shorter than the current timeout. - expires := ttl - if ttl >= 20*time.Millisecond { - expires = ttl - 10*time.Millisecond + expiresDiff := time.Duration(float64(ttl) * 0.1) + if expiresDiff > 5*time.Second { + expiresDiff = 5 * time.Second } + expires := ttl - expiresDiff nr.Batch = batch - len(msgs) nr.Expires = expires @@ -2903,7 +3025,9 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { if o.hb > 0 { if hbTimer == nil { hbTimer = time.AfterFunc(2*o.hb, func() { + hbLock.Lock() hbErr = ErrNoHeartbeat + hbLock.Unlock() cancel() }) } else { @@ -2945,6 +3069,8 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) { } // If there is at least a message added to msgs, then need to return OK and no error if err != nil && len(msgs) == 0 { + hbLock.Lock() + defer hbLock.Unlock() if hbErr != nil { return nil, hbErr } @@ -3129,8 +3255,14 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e result.msgs <- msg } } - if len(result.msgs) == batch || result.err != nil { + sub.mu.Lock() + subClosed := sub.closed || sub.draining + sub.mu.Unlock() + if len(result.msgs) == batch || result.err != nil || subClosed { close(result.msgs) + if subClosed && len(result.msgs) == 0 { + return nil, errors.Join(ErrBadSubscription, ErrSubscriptionClosed) + } result.done <- struct{}{} return result, nil } @@ -3139,10 +3271,11 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e ttl = time.Until(deadline) // Make our request expiration a bit shorter than the current timeout. 
- expires := ttl - if ttl >= 20*time.Millisecond { - expires = ttl - 10*time.Millisecond + expiresDiff := time.Duration(float64(ttl) * 0.1) + if expiresDiff > 5*time.Second { + expiresDiff = 5 * time.Second } + expires := ttl - expiresDiff requestBatch := batch - len(result.msgs) req := nextRequest{ @@ -3169,9 +3302,12 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e } var hbTimer *time.Timer var hbErr error + hbLock := sync.Mutex{} if o.hb > 0 { hbTimer = time.AfterFunc(2*o.hb, func() { + hbLock.Lock() hbErr = ErrNoHeartbeat + hbLock.Unlock() cancel() }) } @@ -3207,11 +3343,13 @@ func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, e } } if err != nil { + hbLock.Lock() if hbErr != nil { result.err = hbErr } else { result.err = o.checkCtxErr(err) } + hbLock.Unlock() } close(result.msgs) result.done <- struct{}{} @@ -3812,7 +3950,7 @@ func (alg StoreCompression) MarshalJSON() ([]byte, error) { case NoCompression: str = "none" default: - return nil, fmt.Errorf("unknown compression algorithm") + return nil, errors.New("unknown compression algorithm") } return json.Marshal(str) } @@ -3828,7 +3966,7 @@ func (alg *StoreCompression) UnmarshalJSON(b []byte) error { case "none": *alg = NoCompression default: - return fmt.Errorf("unknown compression algorithm") + return errors.New("unknown compression algorithm") } return nil } diff --git a/jserrors.go b/jserrors.go index b5c968465..1c22d812b 100644 --- a/jserrors.go +++ b/jserrors.go @@ -22,6 +22,10 @@ var ( // API errors // ErrJetStreamNotEnabled is an error returned when JetStream is not enabled for an account. + // + // Note: This error will not be returned in clustered mode, even if each + // server in the cluster does not have JetStream enabled. In clustered mode, + // requests will time out instead. ErrJetStreamNotEnabled JetStreamError = &jsError{apiErr: &APIError{ErrorCode: JSErrCodeJetStreamNotEnabled, Description: "jetstream not enabled", Code: 503}} // ErrJetStreamNotEnabledForAccount is an error returned when JetStream is not enabled for an account. @@ -120,6 +124,9 @@ var ( // ErrInvalidConsumerName is returned when the provided consumer name is invalid (contains '.' or ' '). ErrInvalidConsumerName JetStreamError = &jsError{message: "invalid consumer name"} + // ErrInvalidFilterSubject is returned when the provided filter subject is invalid. + ErrInvalidFilterSubject JetStreamError = &jsError{message: "invalid filter subject"} + // ErrNoMatchingStream is returned when stream lookup by subject is unsuccessful. ErrNoMatchingStream JetStreamError = &jsError{message: "no stream matches subject"} @@ -141,7 +148,13 @@ var ( // ErrNoHeartbeat is returned when no heartbeat is received from server when sending requests with pull consumer. ErrNoHeartbeat JetStreamError = &jsError{message: "no heartbeat received"} - // DEPRECATED: ErrInvalidDurableName is no longer returned and will be removed in future releases. + // ErrSubscriptionClosed is returned when attempting to send pull request to a closed subscription + ErrSubscriptionClosed JetStreamError = &jsError{message: "subscription closed"} + + // ErrJetStreamPublisherClosed is returned for each unfinished ack future when JetStream.Cleanup is called. + ErrJetStreamPublisherClosed JetStreamError = &jsError{message: "jetstream context closed"} + + // Deprecated: ErrInvalidDurableName is no longer returned and will be removed in future releases. // Use ErrInvalidConsumerName instead. 
ErrInvalidDurableName = errors.New("nats: invalid durable name") ) diff --git a/jsm.go b/jsm.go index 94fa86c32..2ae19c7a3 100644 --- a/jsm.go +++ b/jsm.go @@ -41,7 +41,7 @@ type JetStreamManager interface { PurgeStream(name string, opts ...JSOpt) error // StreamsInfo can be used to retrieve a list of StreamInfo objects. - // DEPRECATED: Use Streams() instead. + // Deprecated: Use Streams() instead. StreamsInfo(opts ...JSOpt) <-chan *StreamInfo // Streams can be used to retrieve a list of StreamInfo objects. @@ -86,7 +86,7 @@ type JetStreamManager interface { ConsumerInfo(stream, name string, opts ...JSOpt) (*ConsumerInfo, error) // ConsumersInfo is used to retrieve a list of ConsumerInfo objects. - // DEPRECATED: Use Consumers() instead. + // Deprecated: Use Consumers() instead. ConsumersInfo(stream string, opts ...JSOpt) <-chan *ConsumerInfo // Consumers is used to retrieve a list of ConsumerInfo objects. @@ -106,51 +106,143 @@ type JetStreamManager interface { // There are sensible defaults for most. If no subjects are // given the name will be used as the only subject. type StreamConfig struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Subjects []string `json:"subjects,omitempty"` - Retention RetentionPolicy `json:"retention"` - MaxConsumers int `json:"max_consumers"` - MaxMsgs int64 `json:"max_msgs"` - MaxBytes int64 `json:"max_bytes"` - Discard DiscardPolicy `json:"discard"` - DiscardNewPerSubject bool `json:"discard_new_per_subject,omitempty"` - MaxAge time.Duration `json:"max_age"` - MaxMsgsPerSubject int64 `json:"max_msgs_per_subject"` - MaxMsgSize int32 `json:"max_msg_size,omitempty"` - Storage StorageType `json:"storage"` - Replicas int `json:"num_replicas"` - NoAck bool `json:"no_ack,omitempty"` - Template string `json:"template_owner,omitempty"` - Duplicates time.Duration `json:"duplicate_window,omitempty"` - Placement *Placement `json:"placement,omitempty"` - Mirror *StreamSource `json:"mirror,omitempty"` - Sources []*StreamSource `json:"sources,omitempty"` - Sealed bool `json:"sealed,omitempty"` - DenyDelete bool `json:"deny_delete,omitempty"` - DenyPurge bool `json:"deny_purge,omitempty"` - AllowRollup bool `json:"allow_rollup_hdrs,omitempty"` - Compression StoreCompression `json:"compression"` - FirstSeq uint64 `json:"first_seq,omitempty"` - - // Allow applying a subject transform to incoming messages before doing anything else. + // Name is the name of the stream. It is required and must be unique + // across the JetStream account. + // + // Names cannot contain whitespace, ., *, >, path separators + // (forward or backwards slash), and non-printable characters. + Name string `json:"name"` + + // Description is an optional description of the stream. + Description string `json:"description,omitempty"` + + // Subjects is a list of subjects that the stream is listening on. + // Wildcards are supported. Subjects cannot be set if the stream is + // created as a mirror. + Subjects []string `json:"subjects,omitempty"` + + // Retention defines the message retention policy for the stream. + // Defaults to LimitsPolicy. + Retention RetentionPolicy `json:"retention"` + + // MaxConsumers specifies the maximum number of consumers allowed for + // the stream. + MaxConsumers int `json:"max_consumers"` + + // MaxMsgs is the maximum number of messages the stream will store. + // After reaching the limit, stream adheres to the discard policy. + // If not set, server default is -1 (unlimited).
+ MaxMsgs int64 `json:"max_msgs"` + + // MaxBytes is the maximum total size of messages the stream will store. + // After reaching the limit, stream adheres to the discard policy. + // If not set, server default is -1 (unlimited). + MaxBytes int64 `json:"max_bytes"` + + // Discard defines the policy for handling messages when the stream + // reaches its limits in terms of number of messages or total bytes. + Discard DiscardPolicy `json:"discard"` + + // DiscardNewPerSubject is a flag to enable discarding new messages per + // subject when limits are reached. Requires DiscardPolicy to be + // DiscardNew and the MaxMsgsPerSubject to be set. + DiscardNewPerSubject bool `json:"discard_new_per_subject,omitempty"` + + // MaxAge is the maximum age of messages that the stream will retain. + MaxAge time.Duration `json:"max_age"` + + // MaxMsgsPerSubject is the maximum number of messages per subject that + // the stream will retain. + MaxMsgsPerSubject int64 `json:"max_msgs_per_subject"` + + // MaxMsgSize is the maximum size of any single message in the stream. + MaxMsgSize int32 `json:"max_msg_size,omitempty"` + + // Storage specifies the type of storage backend used for the stream + // (file or memory). + Storage StorageType `json:"storage"` + + // Replicas is the number of stream replicas in clustered JetStream. + // Defaults to 1, maximum is 5. + Replicas int `json:"num_replicas"` + + // NoAck is a flag to disable acknowledging messages received by this + // stream. + // + // If set to true, publish methods from the JetStream client will not + // work as expected, since they rely on acknowledgements. Core NATS + // publish methods should be used instead. Note that this will make + // message delivery less reliable. + NoAck bool `json:"no_ack,omitempty"` + + // Duplicates is the window within which to track duplicate messages. + // If not set, server default is 2 minutes. + Duplicates time.Duration `json:"duplicate_window,omitempty"` + + // Placement is used to declare where the stream should be placed via + // tags and/or an explicit cluster name. + Placement *Placement `json:"placement,omitempty"` + + // Mirror defines the configuration for mirroring another stream. + Mirror *StreamSource `json:"mirror,omitempty"` + + // Sources is a list of other streams this stream sources messages from. + Sources []*StreamSource `json:"sources,omitempty"` + + // Sealed streams do not allow messages to be published or deleted via limits or API, + // sealed streams can not be unsealed via configuration update. Can only + // be set on already created streams via the Update API. + Sealed bool `json:"sealed,omitempty"` + + // DenyDelete restricts the ability to delete messages from a stream via + // the API. Defaults to false. + DenyDelete bool `json:"deny_delete,omitempty"` + + // DenyPurge restricts the ability to purge messages from a stream via + // the API. Defaults to false. + DenyPurge bool `json:"deny_purge,omitempty"` + + // AllowRollup allows the use of the Nats-Rollup header to replace all + // contents of a stream, or subject in a stream, with a single new + // message. + AllowRollup bool `json:"allow_rollup_hdrs,omitempty"` + + // Compression specifies the message storage compression algorithm. + // Defaults to NoCompression. + Compression StoreCompression `json:"compression"` + + // FirstSeq is the initial sequence number of the first message in the + // stream. 
+ FirstSeq uint64 `json:"first_seq,omitempty"` + + // SubjectTransform allows applying a transformation to matching + // messages' subjects. SubjectTransform *SubjectTransformConfig `json:"subject_transform,omitempty"` - // Allow republish of the message after being sequenced and stored. + // RePublish allows immediate republishing a message to the configured + // subject after it's stored. RePublish *RePublish `json:"republish,omitempty"` - // Allow higher performance, direct access to get individual messages. E.g. KeyValue + // AllowDirect enables direct access to individual messages using direct + // get API. Defaults to false. AllowDirect bool `json:"allow_direct"` - // Allow higher performance and unified direct access for mirrors as well. + + // MirrorDirect enables direct access to individual messages from the + // origin stream using direct get API. Defaults to false. MirrorDirect bool `json:"mirror_direct"` - // Limits for consumers on this stream. + // ConsumerLimits defines limits of certain values that consumers can + // set, defaults for those who don't set these settings ConsumerLimits StreamConsumerLimits `json:"consumer_limits,omitempty"` - // Metadata is additional metadata for the Stream. - // Keys starting with `_nats` are reserved. - // NOTE: Metadata requires nats-server v2.10.0+ + // Metadata is a set of application-defined key-value pairs for + // associating metadata on the stream. This feature requires nats-server + // v2.10.0 or later. Metadata map[string]string `json:"metadata,omitempty"` + + // Template identifies the template that manages the Stream. Deprecated: + // This feature is no longer supported. + Template string `json:"template_owner,omitempty"` } // SubjectTransformConfig is for applying a subject transform (to matching messages) before doing anything else when a new message is received. @@ -288,9 +380,13 @@ type accountInfoResponse struct { AccountInfo } -// AccountInfo retrieves info about the JetStream usage from the current account. -// If JetStream is not enabled, this will return ErrJetStreamNotEnabled -// Other errors can happen but are generally considered retryable +// AccountInfo fetches account information from the server, containing details +// about the account associated with this JetStream connection. If account is +// not enabled for JetStream, ErrJetStreamNotEnabledForAccount is returned. +// +// If the server does not have JetStream enabled, ErrJetStreamNotEnabled is +// returned (for a single server setup). For clustered topologies, AccountInfo +// will time out. func (js *js) AccountInfo(opts ...JSOpt) (*AccountInfo, error) { o, cancel, err := getJSContextOpts(js.opts, opts...) if err != nil { @@ -410,6 +506,10 @@ func (js *js) upsertConsumer(stream, consumerName string, cfg *ConsumerConfig, o // if filter subject is empty or ">", use the endpoint without filter subject ccSubj = fmt.Sprintf(apiConsumerCreateT, stream, consumerName) } else { + // safeguard against passing invalid filter subject in request subject + if cfg.FilterSubject[0] == '.' || cfg.FilterSubject[len(cfg.FilterSubject)-1] == '.' { + return nil, fmt.Errorf("%w: %q", ErrInvalidFilterSubject, cfg.FilterSubject) + } // if filter subject is not empty, use the endpoint with filter subject ccSubj = fmt.Sprintf(apiConsumerCreateWithFilterSubjectT, stream, consumerName, cfg.FilterSubject) } @@ -647,7 +747,7 @@ func (jsc *js) Consumers(stream string, opts ...JSOpt) <-chan *ConsumerInfo { } // ConsumersInfo is used to retrieve a list of ConsumerInfo objects. 
-// DEPRECATED: Use Consumers() instead. +// Deprecated: Use Consumers() instead. func (jsc *js) ConsumersInfo(stream string, opts ...JSOpt) <-chan *ConsumerInfo { return jsc.Consumers(stream, opts...) } @@ -1230,11 +1330,11 @@ func convertDirectGetMsgResponseToMsg(name string, r *Msg) (*RawStreamMsg, error // Check for headers that give us the required information to // reconstruct the message. if len(r.Header) == 0 { - return nil, fmt.Errorf("nats: response should have headers") + return nil, errors.New("nats: response should have headers") } stream := r.Header.Get(JSStream) if stream == _EMPTY_ { - return nil, fmt.Errorf("nats: missing stream header") + return nil, errors.New("nats: missing stream header") } // Mirrors can now answer direct gets, so removing check for name equality. @@ -1242,7 +1342,7 @@ func convertDirectGetMsgResponseToMsg(name string, r *Msg) (*RawStreamMsg, error seqStr := r.Header.Get(JSSequence) if seqStr == _EMPTY_ { - return nil, fmt.Errorf("nats: missing sequence header") + return nil, errors.New("nats: missing sequence header") } seq, err := strconv.ParseUint(seqStr, 10, 64) if err != nil { @@ -1250,7 +1350,7 @@ func convertDirectGetMsgResponseToMsg(name string, r *Msg) (*RawStreamMsg, error } timeStr := r.Header.Get(JSTimeStamp) if timeStr == _EMPTY_ { - return nil, fmt.Errorf("nats: missing timestamp header") + return nil, errors.New("nats: missing timestamp header") } // Temporary code: the server in main branch is sending with format // "2006-01-02 15:04:05.999999999 +0000 UTC", but will be changed @@ -1265,7 +1365,7 @@ func convertDirectGetMsgResponseToMsg(name string, r *Msg) (*RawStreamMsg, error } subj := r.Header.Get(JSSubject) if subj == _EMPTY_ { - return nil, fmt.Errorf("nats: missing subject header") + return nil, errors.New("nats: missing subject header") } return &RawStreamMsg{ Subject: subj, @@ -1517,7 +1617,7 @@ func (jsc *js) Streams(opts ...JSOpt) <-chan *StreamInfo { } // StreamsInfo can be used to retrieve a list of StreamInfo objects. -// DEPRECATED: Use Streams() instead. +// Deprecated: Use Streams() instead. func (jsc *js) StreamsInfo(opts ...JSOpt) <-chan *StreamInfo { return jsc.Streams(opts...) } diff --git a/kv.go b/kv.go index 0864f30cc..bcb283ff8 100644 --- a/kv.go +++ b/kv.go @@ -54,6 +54,7 @@ type KeyValue interface { // Create will add the key/value pair iff it does not exist. Create(key string, value []byte) (revision uint64, err error) // Update will update the value iff the latest revision matches. + // Update also resets the TTL associated with the key (if any). Update(key string, value []byte, last uint64) (revision uint64, err error) // Delete will place a delete marker and leave all revisions. Delete(key string, opts ...DeleteOpt) error @@ -64,8 +65,11 @@ type KeyValue interface { Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) // WatchAll will invoke the callback for all updates. WatchAll(opts ...WatchOpt) (KeyWatcher, error) + // WatchFiltered will watch for any updates to keys that match the keys + // argument. It can be configured with the same options as Watch. + WatchFiltered(keys []string, opts ...WatchOpt) (KeyWatcher, error) // Keys will return all keys. - // DEPRECATED: Use ListKeys instead to avoid memory issues. + // Deprecated: Use ListKeys instead to avoid memory issues. Keys(opts ...WatchOpt) ([]string, error) // ListKeys will return all keys in a channel. 
ListKeys(opts ...WatchOpt) (KeyLister, error) @@ -249,22 +253,22 @@ func purge() DeleteOpt { // KeyValueConfig is for configuring a KeyValue store. type KeyValueConfig struct { - Bucket string - Description string - MaxValueSize int32 - History uint8 - TTL time.Duration - MaxBytes int64 - Storage StorageType - Replicas int - Placement *Placement - RePublish *RePublish - Mirror *StreamSource - Sources []*StreamSource + Bucket string `json:"bucket"` + Description string `json:"description,omitempty"` + MaxValueSize int32 `json:"max_value_size,omitempty"` + History uint8 `json:"history,omitempty"` + TTL time.Duration `json:"ttl,omitempty"` + MaxBytes int64 `json:"max_bytes,omitempty"` + Storage StorageType `json:"storage,omitempty"` + Replicas int `json:"num_replicas,omitempty"` + Placement *Placement `json:"placement,omitempty"` + RePublish *RePublish `json:"republish,omitempty"` + Mirror *StreamSource `json:"mirror,omitempty"` + Sources []*StreamSource `json:"sources,omitempty"` // Enable underlying stream compression. // NOTE: Compression is supported for nats-server 2.10.0+ - Compression bool + Compression bool `json:"compression,omitempty"` } // Used to watch all keys. @@ -344,8 +348,9 @@ const ( // Regex for valid keys and buckets. var ( - validBucketRe = regexp.MustCompile(`\A[a-zA-Z0-9_-]+\z`) - validKeyRe = regexp.MustCompile(`\A[-/_=\.a-zA-Z0-9]+\z`) + validBucketRe = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + validKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9]+$`) + validSearchKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9*]*[>]?$`) ) // KeyValue will lookup and bind to an existing KeyValue store. @@ -353,7 +358,7 @@ func (js *js) KeyValue(bucket string) (KeyValue, error) { if !js.nc.serverMinVersion(2, 6, 2) { return nil, errors.New("nats: key-value requires at least server version 2.6.2") } - if !validBucketRe.MatchString(bucket) { + if !bucketValid(bucket) { return nil, ErrInvalidBucketName } stream := fmt.Sprintf(kvBucketNameTmpl, bucket) @@ -381,7 +386,7 @@ func (js *js) CreateKeyValue(cfg *KeyValueConfig) (KeyValue, error) { if cfg == nil { return nil, ErrKeyValueConfigRequired } - if !validBucketRe.MatchString(cfg.Bucket) { + if !bucketValid(cfg.Bucket) { return nil, ErrInvalidBucketName } if _, err := js.AccountInfo(); err != nil { @@ -507,7 +512,7 @@ func (js *js) CreateKeyValue(cfg *KeyValueConfig) (KeyValue, error) { // DeleteKeyValue will delete this KeyValue store (JetStream stream). func (js *js) DeleteKeyValue(bucket string) error { - if !validBucketRe.MatchString(bucket) { + if !bucketValid(bucket) { return ErrInvalidBucketName } stream := fmt.Sprintf(kvBucketNameTmpl, bucket) @@ -547,6 +552,13 @@ func (e *kve) Created() time.Time { return e.created } func (e *kve) Delta() uint64 { return e.delta } func (e *kve) Operation() KeyValueOp { return e.op } +func bucketValid(bucket string) bool { + if len(bucket) == 0 { + return false + } + return validBucketRe.MatchString(bucket) +} + func keyValid(key string) bool { if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' { return false @@ -554,6 +566,13 @@ func keyValid(key string) bool { return validKeyRe.MatchString(key) } +func searchKeyValid(key string) bool { + if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' { + return false + } + return validSearchKeyRe.MatchString(key) +} + // Get returns the latest value for the key. 
func (kv *kvs) Get(key string) (KeyValueEntry, error) { e, err := kv.get(key, kvLatestRevision) @@ -948,9 +967,12 @@ func (kv *kvs) WatchAll(opts ...WatchOpt) (KeyWatcher, error) { return kv.Watch(AllKeys, opts...) } -// Watch will fire the callback when a key that matches the keys pattern is updated. -// keys needs to be a valid NATS subject. -func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { +func (kv *kvs) WatchFiltered(keys []string, opts ...WatchOpt) (KeyWatcher, error) { + for _, key := range keys { + if !searchKeyValid(key) { + return nil, fmt.Errorf("%w: %s", ErrInvalidKey, "key cannot be empty and must be a valid NATS subject") + } + } var o watchOpts for _, opt := range opts { if opt != nil { @@ -961,10 +983,20 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { } // Could be a pattern so don't check for validity as we normally do. - var b strings.Builder - b.WriteString(kv.pre) - b.WriteString(keys) - keys = b.String() + for i, key := range keys { + var b strings.Builder + b.WriteString(kv.pre) + b.WriteString(key) + keys[i] = b.String() + } + + // if no keys are provided, watch all keys + if len(keys) == 0 { + var b strings.Builder + b.WriteString(kv.pre) + b.WriteString(AllKeys) + keys = []string{b.String()} + } // We will block below on placing items on the chan. That is by design. w := &watcher{updates: make(chan KeyValueEntry, 256), ctx: o.ctx} @@ -1037,7 +1069,14 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { // update() callback. w.mu.Lock() defer w.mu.Unlock() - sub, err := kv.js.Subscribe(keys, update, subOpts...) + var sub *Subscription + var err error + if len(keys) == 1 { + sub, err = kv.js.Subscribe(keys[0], update, subOpts...) + } else { + subOpts = append(subOpts, ConsumerFilterSubjects(keys...)) + sub, err = kv.js.Subscribe("", update, subOpts...) + } if err != nil { return nil, err } @@ -1064,6 +1103,12 @@ func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { return w, nil } +// Watch will fire the callback when a key that matches the keys pattern is updated. +// keys needs to be a valid NATS subject. +func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) { + return kv.WatchFiltered([]string{keys}, opts...) +} + // Bucket returns the current bucket name (JetStream stream). 
func (kv *kvs) Bucket() string { return kv.name diff --git a/micro/README.md b/micro/README.md index e4ffe0977..99949eb8b 100644 --- a/micro/README.md +++ b/micro/README.md @@ -206,35 +206,69 @@ Service IDs can be discovered by: ```sh nats req '$SRV.PING.EchoService' '' --replies=3 -8:59:41 Sending request on "$SRV.PING.EchoService" -18:59:41 Received with rtt 688.042µs -{"name":"EchoService","id":"tNoopzL5Sp1M4qJZdhdxqC","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.ping_response"} +13:03:04 Sending request on "$SRV.PING.EchoService" +13:03:04 Received with rtt 1.302208ms +{"name":"EchoService","id":"x3Yuiq7g7MoxhXdxk7i4K7","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.ping_response"} -18:59:41 Received with rtt 704.167µs -{"name":"EchoService","id":"tNoopzL5Sp1M4qJZdhdxvO","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.ping_response"} +13:03:04 Received with rtt 1.317ms +{"name":"EchoService","id":"x3Yuiq7g7MoxhXdxk7i4Kt","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.ping_response"} -18:59:41 Received with rtt 707.875µs -{"name":"EchoService","id":"tNoopzL5Sp1M4qJZdhdy0a","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.ping_response"} +13:03:04 Received with rtt 1.320291ms +{"name":"EchoService","id":"x3Yuiq7g7MoxhXdxk7i4Lf","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.ping_response"} ``` A specific service instance info can be retrieved: ```sh -nats req '$SRV.INFO.EchoService.tNoopzL5Sp1M4qJZdhdxqC' '' - -19:40:06 Sending request on "$SRV.INFO.EchoService.tNoopzL5Sp1M4qJZdhdxqC" -19:40:06 Received with rtt 282.375µs -{"name":"EchoService","id":"tNoopzL5Sp1M4qJZdhdxqC","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.info_response","description":"","subjects":["svc.echo"]} +nats req '$SRV.INFO.EchoService.x3Yuiq7g7MoxhXdxk7i4K7' '' | jq + +13:04:19 Sending request on "$SRV.INFO.EchoService.x3Yuiq7g7MoxhXdxk7i4K7" +13:04:19 Received with rtt 318.875µs +{ + "name": "EchoService", + "id": "x3Yuiq7g7MoxhXdxk7i4K7", + "version": "1.0.0", + "metadata": {}, + "type": "io.nats.micro.v1.info_response", + "description": "", + "endpoints": [ + { + "name": "default", + "subject": "svc.echo", + "queue_group": "q", + "metadata": null + } + ] +} ``` To get statistics for this service: ```sh -nats req '$SRV.STATS.EchoService.tNoopzL5Sp1M4qJZdhdxqC' '' - -19:40:47 Sending request on "$SRV.STATS.EchoService.tNoopzL5Sp1M4qJZdhdxqC" -19:40:47 Received with rtt 421.666µs -{"name":"EchoService","id":"tNoopzL5Sp1M4qJZdhdxqC","version":"1.0.0","metadata":{},"type":"io.nats.micro.v1.stats_response","started":"2023-05-22T16:59:39.938514Z","endpoints":[{"name":"default","subject":"svc.echo","metadata":null,"num_requests":0,"num_errors":0,"last_error":"","processing_time":0,"average_processing_time":0}]} +nats req '$SRV.STATS.EchoService.x3Yuiq7g7MoxhXdxk7i4K7' '' | jq + +13:04:46 Sending request on "$SRV.STATS.EchoService.x3Yuiq7g7MoxhXdxk7i4K7" +13:04:46 Received with rtt 678.25µs +{ + "name": "EchoService", + "id": "x3Yuiq7g7MoxhXdxk7i4K7", + "version": "1.0.0", + "metadata": {}, + "type": "io.nats.micro.v1.stats_response", + "started": "2024-09-24T11:02:55.564771Z", + "endpoints": [ + { + "name": "default", + "subject": "svc.echo", + "queue_group": "q", + "num_requests": 0, + "num_errors": 0, + "last_error": "", + "processing_time": 0, + "average_processing_time": 0 + } + ] +} ``` ## Examples diff --git a/micro/request.go b/micro/request.go index 380f4945c..e282d4576 100644 --- a/micro/request.go +++ b/micro/request.go @@ -58,6 +58,9 @@ 
type ( // Subject returns underlying NATS message subject. Subject() string + + // Reply returns underlying NATS message reply subject. + Reply() string } // Headers is a wrapper around [*nats.Header] @@ -186,6 +189,11 @@ func (r *request) Subject() string { return r.msg.Subject } +// Reply returns underlying NATS message reply subject. +func (r *request) Reply() string { + return r.msg.Reply +} + // Get gets the first value associated with the given key. // It is case-sensitive. func (h Headers) Get(key string) string { diff --git a/micro/service.go b/micro/service.go index d6a8c82f4..0f98e0edd 100644 --- a/micro/service.go +++ b/micro/service.go @@ -632,7 +632,7 @@ func (s *service) addInternalHandler(nc *nats.Conn, verb Verb, kind, id, name st handler(&request{msg: msg}) }) if err != nil { - if stopErr := s.Stop(); err != nil { + if stopErr := s.Stop(); stopErr != nil { return errors.Join(err, fmt.Errorf("stopping service: %w", stopErr)) } return err diff --git a/micro/test/service_test.go b/micro/test/service_test.go index b9e004946..1dfafa072 100644 --- a/micro/test/service_test.go +++ b/micro/test/service_test.go @@ -504,7 +504,7 @@ func TestAddService(t *testing.T) { } if test.givenConfig.ErrorHandler != nil { - go nc.Opts.AsyncErrorCB(nc, &nats.Subscription{Subject: test.asyncErrorSubject}, fmt.Errorf("oops")) + go nc.Opts.AsyncErrorCB(nc, &nats.Subscription{Subject: test.asyncErrorSubject}, errors.New("oops")) select { case <-errService: case <-time.After(1 * time.Second): @@ -536,7 +536,7 @@ func TestAddService(t *testing.T) { } } if test.natsErrorHandler != nil { - go nc.Opts.AsyncErrorCB(nc, &nats.Subscription{Subject: test.asyncErrorSubject}, fmt.Errorf("oops")) + go nc.Opts.AsyncErrorCB(nc, &nats.Subscription{Subject: test.asyncErrorSubject}, errors.New("oops")) select { case <-errService: t.Fatalf("Expected to restore nats error handler") @@ -634,7 +634,7 @@ func TestErrHandlerSubjectMatch(t *testing.T) { } defer svc.Stop() - go nc.Opts.AsyncErrorCB(nc, &nats.Subscription{Subject: test.errSubject}, fmt.Errorf("oops")) + go nc.Opts.AsyncErrorCB(nc, &nats.Subscription{Subject: test.errSubject}, errors.New("oops")) if test.expectServiceErr { select { case <-errChan: diff --git a/nats.go b/nats.go index 0be428932..3f12d61e2 100644 --- a/nats.go +++ b/nats.go @@ -1,4 +1,4 @@ -// Copyright 2012-2023 The NATS Authors +// Copyright 2012-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -47,7 +47,7 @@ import ( // Default Constants const ( - Version = "1.33.0" + Version = "1.37.0" DefaultURL = "nats://127.0.0.1:4222" DefaultPort = 4222 DefaultMaxReconnect = 60 @@ -86,6 +86,9 @@ const ( // MAX_CONNECTIONS_ERR is for when nats server denies the connection due to server max_connections limit MAX_CONNECTIONS_ERR = "maximum connections exceeded" + + // MAX_SUBSCRIPTIONS_ERR is for when nats server denies the connection due to server subscriptions limit + MAX_SUBSCRIPTIONS_ERR = "maximum subscriptions exceeded" ) // Errors @@ -131,6 +134,7 @@ var ( ErrNkeysNotSupported = errors.New("nats: nkeys not supported by the server") ErrStaleConnection = errors.New("nats: " + STALE_CONNECTION) ErrTokenAlreadySet = errors.New("nats: token and token handler both set") + ErrUserInfoAlreadySet = errors.New("nats: cannot set user info callback and user/pass") ErrMsgNotBound = errors.New("nats: message is not bound to subscription/connection") ErrMsgNoReply = errors.New("nats: message does not have a reply") ErrClientIPNotSupported = errors.New("nats: client IP not supported by this server") @@ -140,6 +144,7 @@ var ( ErrNoResponders = errors.New("nats: no responders available for request") ErrMaxConnectionsExceeded = errors.New("nats: server maximum connections exceeded") ErrConnectionNotTLS = errors.New("nats: connection is not tls") + ErrMaxSubscriptionsExceeded = errors.New("nats: server maximum subscriptions exceeded") ) // GetDefaultOptions returns default configuration options for the client. @@ -160,7 +165,7 @@ func GetDefaultOptions() Options { } } -// DEPRECATED: Use GetDefaultOptions() instead. +// Deprecated: Use GetDefaultOptions() instead. // DefaultOptions is not safe for use by multiple clients. // For details see #308. var DefaultOptions = GetDefaultOptions() @@ -230,6 +235,9 @@ type SignatureHandler func([]byte) ([]byte, error) // AuthTokenHandler is used to generate a new token. type AuthTokenHandler func() string +// UserInfoCB is used to pass the username and password when establishing connection. +type UserInfoCB func() (string, string) + // ReconnectDelayHandler is used to get from the user the desired // delay the library should pause before attempting to reconnect // again. Note that this is invoked after the library tried the @@ -386,7 +394,7 @@ type Options struct { // DisconnectedCB sets the disconnected handler that is called // whenever the connection is disconnected. // Will not be called if DisconnectedErrCB is set - // DEPRECATED. Use DisconnectedErrCB which passes error that caused + // Deprecated. Use DisconnectedErrCB which passes error that caused // the disconnect event. DisconnectedCB ConnHandler @@ -443,6 +451,9 @@ type Options struct { // Password sets the password to be used when connecting to a server. Password string + // UserInfo sets the callback handler that will fetch the username and password. + UserInfo UserInfoCB + // Token sets the token to be used when connecting to a server. Token string @@ -450,7 +461,7 @@ type Options struct { TokenHandler AuthTokenHandler // Dialer allows a custom net.Dialer when forming connections. - // DEPRECATED: should use CustomDialer instead. + // Deprecated: should use CustomDialer instead. Dialer *net.Dialer // CustomDialer allows to specify a custom dialer (not necessarily @@ -566,7 +577,6 @@ type Conn struct { respSub string // The wildcard subject respSubPrefix string // the wildcard prefix including trailing . 
respSubLen int // the length of the wildcard prefix excluding trailing . - respScanf string // The scanf template to extract mux token respMux *Subscription // A single response subscription respMap map[string]chan *Msg // Request map for the response msg channels respRand *rand.Rand // Used for generating suffix @@ -608,14 +618,17 @@ type Subscription struct { // For holding information about a JetStream consumer. jsi *jsSub - delivered uint64 - max uint64 - conn *Conn - mcb MsgHandler - mch chan *Msg - closed bool - sc bool - connClosed bool + delivered uint64 + max uint64 + conn *Conn + mcb MsgHandler + mch chan *Msg + closed bool + sc bool + connClosed bool + draining bool + status SubStatus + statListeners map[chan SubStatus][]SubStatus // Type of Subscription typ SubscriptionType @@ -636,6 +649,30 @@ type Subscription struct { dropped int } +// SubStatus represents the state of a subscription. +type SubStatus int + +const ( + SubscriptionActive = SubStatus(iota) + SubscriptionDraining + SubscriptionClosed + SubscriptionSlowConsumer +) + +func (s SubStatus) String() string { + switch s { + case SubscriptionActive: + return "Active" + case SubscriptionDraining: + return "Draining" + case SubscriptionClosed: + return "Closed" + case SubscriptionSlowConsumer: + return "SlowConsumer" + } + return "unknown status" +} + // Msg represents a message delivered by NATS. This structure is used // by Subscribers and PublishMsg(). // @@ -1082,7 +1119,7 @@ func DisconnectErrHandler(cb ConnErrHandler) Option { } // DisconnectHandler is an Option to set the disconnected handler. -// DEPRECATED: Use DisconnectErrHandler. +// Deprecated: Use DisconnectErrHandler. func DisconnectHandler(cb ConnHandler) Option { return func(o *Options) error { o.DisconnectedCB = cb @@ -1140,6 +1177,13 @@ func UserInfo(user, password string) Option { } } +func UserInfoHandler(cb UserInfoCB) Option { + return func(o *Options) error { + o.UserInfo = cb + return nil + } +} + // Token is an Option to set the token to use // when a token is not included directly in the URLs // and when a token handler is not provided. @@ -1254,7 +1298,7 @@ func SyncQueueLen(max int) Option { // Dialer is an Option to set the dialer which will be used when // attempting to establish a connection. -// DEPRECATED: Should use CustomDialer instead. +// Deprecated: Should use CustomDialer instead. func Dialer(dialer *net.Dialer) Option { return func(o *Options) error { o.Dialer = dialer @@ -1333,7 +1377,7 @@ func ProxyPath(path string) Option { func CustomInboxPrefix(p string) Option { return func(o *Options) error { if p == "" || strings.Contains(p, ">") || strings.Contains(p, "*") || strings.HasSuffix(p, ".") { - return fmt.Errorf("nats: invalid custom prefix") + return errors.New("nats: invalid custom prefix") } o.InboxPrefix = p return nil @@ -1371,7 +1415,7 @@ func TLSHandshakeFirst() Option { // Handler processing // SetDisconnectHandler will set the disconnect event handler.
-// DEPRECATED: Use SetDisconnectErrHandler +// Deprecated: Use SetDisconnectErrHandler func (nc *Conn) SetDisconnectHandler(dcb ConnHandler) { if nc == nil { return @@ -1487,7 +1531,7 @@ func processUrlString(url string) []string { urls := strings.Split(url, ",") var j int for _, s := range urls { - u := strings.TrimSpace(s) + u := strings.TrimSuffix(strings.TrimSpace(s), "/") if len(u) > 0 { urls[j] = u j++ @@ -1788,7 +1832,7 @@ func (nc *Conn) addURLToPool(sURL string, implicit, saveTLSName bool) error { if len(nc.srvPool) == 0 { nc.ws = isWS } else if isWS && !nc.ws || !isWS && nc.ws { - return fmt.Errorf("mixing of websocket and non websocket URLs is not allowed") + return errors.New("mixing of websocket and non websocket URLs is not allowed") } var tlsName string @@ -2135,6 +2179,47 @@ func (nc *Conn) waitForExits() { nc.wg.Wait() } +// ForceReconnect forces a reconnect attempt to the server. +// This is a non-blocking call and will start the reconnect +// process without waiting for it to complete. +// +// If the connection is already in the process of reconnecting, +// this call will force an immediate reconnect attempt (bypassing +// the current reconnect delay). +func (nc *Conn) ForceReconnect() error { + nc.mu.Lock() + defer nc.mu.Unlock() + + if nc.isClosed() { + return ErrConnectionClosed + } + if nc.isReconnecting() { + // if we're already reconnecting, force a reconnect attempt + // even if we're in the middle of a backoff + if nc.rqch != nil { + close(nc.rqch) + } + return nil + } + + // Clear any queued pongs + nc.clearPendingFlushCalls() + + // Clear any queued and blocking requests. + nc.clearPendingRequestCalls() + + // Stop ping timer if set. + nc.stopPingTimer() + + // Go ahead and make sure we have flushed the outbound + nc.bw.flush() + nc.conn.Close() + + nc.changeConnStatus(RECONNECTING) + go nc.doReconnect(nil, true) + return nil +} + // ConnectedUrl reports the connected server's URL func (nc *Conn) ConnectedUrl() string { if nc == nil { @@ -2394,7 +2479,7 @@ func (nc *Conn) connect() (bool, error) { nc.setup() nc.changeConnStatus(RECONNECTING) nc.bw.switchToPending() - go nc.doReconnect(ErrNoServers) + go nc.doReconnect(ErrNoServers, false) err = nil } else { nc.current = nil @@ -2496,6 +2581,13 @@ func (nc *Conn) connectProto() (string, error) { pass = o.Password token = o.Token nkey = o.Nkey + + if nc.Opts.UserInfo != nil { + if user != _EMPTY_ || pass != _EMPTY_ { + return _EMPTY_, ErrUserInfoAlreadySet + } + user, pass = nc.Opts.UserInfo() + } } // Look for user jwt. @@ -2694,7 +2786,7 @@ func (nc *Conn) stopPingTimer() { // Try to reconnect using the option parameters. // This function assumes we are allowed to reconnect. -func (nc *Conn) doReconnect(err error) { +func (nc *Conn) doReconnect(err error, forceReconnect bool) { // We want to make sure we have the other watchers shutdown properly // here before we proceed past this point. 
nc.waitForExits() @@ -2750,7 +2842,8 @@ func (nc *Conn) doReconnect(err error) { break } - doSleep := i+1 >= len(nc.srvPool) + doSleep := i+1 >= len(nc.srvPool) && !forceReconnect + forceReconnect = false nc.mu.Unlock() if !doSleep { @@ -2777,6 +2870,12 @@ func (nc *Conn) doReconnect(err error) { select { case <-rqch: rt.Stop() + + // we need to reset the rqch channel to avoid + // closing a closed channel in the next iteration + nc.mu.Lock() + nc.rqch = make(chan struct{}) + nc.mu.Unlock() case <-rt.C: } } @@ -2846,18 +2945,19 @@ func (nc *Conn) doReconnect(err error) { // Done with the pending buffer nc.bw.doneWithPending() - // This is where we are truly connected. - nc.status = CONNECTED + // Queue up the correct callback. If we are in initial connect state + // (using retry on failed connect), we will call the ConnectedCB, + // otherwise the ReconnectedCB. + if nc.Opts.ReconnectedCB != nil && !nc.initc { + nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) }) + } else if nc.Opts.ConnectedCB != nil && nc.initc { + nc.ach.push(func() { nc.Opts.ConnectedCB(nc) }) + } // If we are here with a retry on failed connect, indicate that the // initial connect is now complete. nc.initc = false - // Queue up the reconnect callback. - if nc.Opts.ReconnectedCB != nil { - nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) }) - } - // Release lock here, we will return below. nc.mu.Unlock() @@ -2877,11 +2977,11 @@ func (nc *Conn) doReconnect(err error) { // processOpErr handles errors from reading or parsing the protocol. // The lock should not be held entering this function. -func (nc *Conn) processOpErr(err error) { +func (nc *Conn) processOpErr(err error) bool { nc.mu.Lock() + defer nc.mu.Unlock() if nc.isConnecting() || nc.isClosed() || nc.isReconnecting() { - nc.mu.Unlock() - return + return false } if nc.Opts.AllowReconnect && nc.status == CONNECTED { @@ -2900,15 +3000,13 @@ func (nc *Conn) processOpErr(err error) { // Clear any queued pongs, e.g. pending flush calls. nc.clearPendingFlushCalls() - go nc.doReconnect(err) - nc.mu.Unlock() - return + go nc.doReconnect(err, false) + return false } nc.changeConnStatus(DISCONNECTED) nc.err = err - nc.mu.Unlock() - nc.close(CLOSED, true, nil) + return true } // dispatch is responsible for calling any async callbacks @@ -3005,7 +3103,9 @@ func (nc *Conn) readLoop() { err = nc.parse(buf) } if err != nil { - nc.processOpErr(err) + if shouldClose := nc.processOpErr(err); shouldClose { + nc.close(CLOSED, true, nil) + } break } } @@ -3292,6 +3392,9 @@ func (nc *Conn) processMsg(data []byte) { } // Clear any SlowConsumer status. + if sub.sc { + sub.changeSubStatus(SubscriptionActive) + } sub.sc = false sub.mu.Unlock() @@ -3315,8 +3418,9 @@ slowConsumer: sub.pMsgs-- sub.pBytes -= len(m.Data) } - sub.mu.Unlock() if sc { + sub.changeSubStatus(SubscriptionSlowConsumer) + sub.mu.Unlock() // Now we need connection's lock and we may end-up in the situation // that we were trying to avoid, except that in this case, the client // is already experiencing client-side slow consumer situation. @@ -3326,18 +3430,22 @@ slowConsumer: nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) }) } nc.mu.Unlock() + } else { + sub.mu.Unlock() } } -// processPermissionsViolation is called when the server signals a subject -// permissions violation on either publish or subscribe. 
-func (nc *Conn) processPermissionsViolation(err string) { +// processTransientError is called when the server signals a non terminal error +// which does not close the connection or trigger a reconnect. +// This will trigger the async error callback if set. +// These errors include the following: +// - permissions violation on publish or subscribe +// - maximum subscriptions exceeded +func (nc *Conn) processTransientError(err error) { nc.mu.Lock() - // create error here so we can pass it as a closure to the async cb dispatcher. - e := errors.New("nats: " + err) - nc.err = e + nc.err = err if nc.Opts.AsyncErrorCB != nil { - nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, e) }) + nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, err) }) } nc.mu.Unlock() } @@ -3569,15 +3677,17 @@ func (nc *Conn) processErr(ie string) { // convert to lower case. e := strings.ToLower(ne) - close := false + var close bool // FIXME(dlc) - process Slow Consumer signals special. if e == STALE_CONNECTION { - nc.processOpErr(ErrStaleConnection) + close = nc.processOpErr(ErrStaleConnection) } else if e == MAX_CONNECTIONS_ERR { - nc.processOpErr(ErrMaxConnectionsExceeded) + close = nc.processOpErr(ErrMaxConnectionsExceeded) } else if strings.HasPrefix(e, PERMISSIONS_ERR) { - nc.processPermissionsViolation(ne) + nc.processTransientError(fmt.Errorf("nats: %s", ne)) + } else if strings.HasPrefix(e, MAX_SUBSCRIPTIONS_ERR) { + nc.processTransientError(ErrMaxSubscriptionsExceeded) } else if authErr := checkAuthError(e); authErr != nil { nc.mu.Lock() close = nc.processAuthError(authErr) @@ -3721,7 +3831,7 @@ func readMIMEHeader(tp *textproto.Reader) (textproto.MIMEHeader, error) { } // Process key fetching original case. - i := bytes.IndexByte([]byte(kv), ':') + i := strings.IndexByte(kv, ':') if i < 0 { return nil, ErrBadHeaderMsg } @@ -3734,8 +3844,7 @@ func readMIMEHeader(tp *textproto.Reader) (textproto.MIMEHeader, error) { for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { i++ } - value := string(kv[i:]) - m[key] = append(m[key], value) + m[key] = append(m[key], kv[i:]) if err != nil { return m, err } @@ -3938,7 +4047,6 @@ func (nc *Conn) createNewRequestAndSend(subj string, hdr, data []byte) (chan *Ms nc.mu.Unlock() return nil, token, err } - nc.respScanf = strings.Replace(nc.respSub, "*", "%s", -1) nc.respMux = s } nc.mu.Unlock() @@ -4119,16 +4227,14 @@ func (nc *Conn) NewRespInbox() string { } // respToken will return the last token of a literal response inbox -// which we use for the message channel lookup. This needs to do a -// scan to protect itself against the server changing the subject. +// which we use for the message channel lookup. This needs to verify the subject +// prefix matches to protect itself against the server changing the subject. // Lock should be held. func (nc *Conn) respToken(respInbox string) string { - var token string - n, err := fmt.Sscanf(respInbox, nc.respScanf, &token) - if err != nil || n != 1 { - return "" + if token, found := strings.CutPrefix(respInbox, nc.respSubPrefix); found { + return token } - return token + return "" } // Subscribe will express interest in the given subject. 
The subject @@ -4298,6 +4404,7 @@ func (nc *Conn) subscribeLocked(subj, queue string, cb MsgHandler, ch chan *Msg, nc.kickFlusher() } + sub.changeSubStatus(SubscriptionActive) return sub, nil } @@ -4341,6 +4448,7 @@ func (nc *Conn) removeSub(s *Subscription) { } // Mark as invalid s.closed = true + s.changeSubStatus(SubscriptionClosed) if s.pCond != nil { s.pCond.Broadcast() } @@ -4410,6 +4518,91 @@ func (s *Subscription) Drain() error { return conn.unsubscribe(s, 0, true) } +// IsDraining returns a boolean indicating whether the subscription +// is being drained. +// This will return false if the subscription has already been closed. +func (s *Subscription) IsDraining() bool { + if s == nil { + return false + } + s.mu.Lock() + defer s.mu.Unlock() + return s.draining +} + +// StatusChanged returns a channel on which given list of subscription status +// changes will be sent. If no status is provided, all status changes will be sent. +// Available statuses are SubscriptionActive, SubscriptionDraining, SubscriptionClosed, +// and SubscriptionSlowConsumer. +// The returned channel will be closed when the subscription is closed. +func (s *Subscription) StatusChanged(statuses ...SubStatus) <-chan SubStatus { + if len(statuses) == 0 { + statuses = []SubStatus{SubscriptionActive, SubscriptionDraining, SubscriptionClosed, SubscriptionSlowConsumer} + } + ch := make(chan SubStatus, 10) + for _, status := range statuses { + s.registerStatusChangeListener(status, ch) + // initial status + if status == s.status { + ch <- status + } + } + return ch +} + +// registerStatusChangeListener registers a channel waiting for a specific status change event. +// Status change events are non-blocking - if no receiver is waiting for the status change, +// it will not be sent on the channel. Closed channels are ignored. +func (s *Subscription) registerStatusChangeListener(status SubStatus, ch chan SubStatus) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statListeners == nil { + s.statListeners = make(map[chan SubStatus][]SubStatus) + } + if _, ok := s.statListeners[ch]; !ok { + s.statListeners[ch] = make([]SubStatus, 0) + } + s.statListeners[ch] = append(s.statListeners[ch], status) +} + +// sendStatusEvent sends subscription status event to all channels. +// If there is no listener, sendStatusEvent +// will not block. Lock should be held entering. +func (s *Subscription) sendStatusEvent(status SubStatus) { + for ch, statuses := range s.statListeners { + if !containsStatus(statuses, status) { + continue + } + // only send event if someone's listening + select { + case ch <- status: + default: + } + if status == SubscriptionClosed { + close(ch) + } + } +} + +func containsStatus(statuses []SubStatus, status SubStatus) bool { + for _, s := range statuses { + if s == status { + return true + } + } + return false +} + +// changeSubStatus changes subscription status and sends events +// to all listeners. Lock should be held entering. +func (s *Subscription) changeSubStatus(status SubStatus) { + if s == nil { + return + } + s.sendStatusEvent(status) + s.status = status +} + // Unsubscribe will remove interest in the given subject. // // For a JetStream subscription, if the library has created the JetStream @@ -4448,6 +4641,11 @@ func (s *Subscription) Unsubscribe() error { // checkDrained will watch for a subscription to be fully drained // and then remove it. 
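// A short sketch of the new Subscription.StatusChanged API and the SubStatus
// values it reports; the subject name and logging choices are placeholders:
//
//	sub, err := nc.Subscribe("updates", func(m *nats.Msg) {})
//	if err != nil {
//		log.Fatal(err)
//	}
//	statusCh := sub.StatusChanged(nats.SubscriptionDraining, nats.SubscriptionClosed)
//	go func() {
//		// The channel is closed once the subscription reaches
//		// SubscriptionClosed, which ends this loop.
//		for status := range statusCh {
//			log.Printf("subscription status: %v", status)
//		}
//	}()
//
//	// Drain moves the subscription to SubscriptionDraining; once all pending
//	// messages have been processed it becomes SubscriptionClosed.
//	if err := sub.Drain(); err != nil {
//		log.Fatal(err)
//	}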
func (nc *Conn) checkDrained(sub *Subscription) { + defer func() { + sub.mu.Lock() + defer sub.mu.Unlock() + sub.draining = false + }() if nc == nil || sub == nil { return } @@ -4557,6 +4755,10 @@ func (nc *Conn) unsubscribe(sub *Subscription, max int, drainMode bool) error { } if drainMode { + s.mu.Lock() + s.draining = true + sub.changeSubStatus(SubscriptionDraining) + s.mu.Unlock() go nc.checkDrained(sub) } @@ -4659,6 +4861,7 @@ func (s *Subscription) validateNextMsgState(pullSubInternal bool) error { return ErrSyncSubRequired } if s.sc { + s.changeSubStatus(SubscriptionActive) s.sc = false return ErrSlowConsumer } @@ -4728,7 +4931,8 @@ func (s *Subscription) processNextMsgDelivered(msg *Msg) error { } // Queued returns the number of queued messages in the client for this subscription. -// DEPRECATED: Use Pending() +// +// Deprecated: Use Pending() func (s *Subscription) QueuedMsgs() (int, error) { m, _, err := s.Pending() return int(m), err @@ -4932,7 +5136,9 @@ func (nc *Conn) processPingTimer() { nc.pout++ if nc.pout > nc.Opts.MaxPingsOut { nc.mu.Unlock() - nc.processOpErr(ErrStaleConnection) + if shouldClose := nc.processOpErr(ErrStaleConnection); shouldClose { + nc.close(CLOSED, true, nil) + } return } @@ -5309,7 +5515,7 @@ func (nc *Conn) drainConnection() { // Drain will put a connection into a drain state. All subscriptions will // immediately be put into a drain state. Upon completion, the publishers // will be drained and can not publish any additional messages. Upon draining -// of the publishers, the connection will be closed. Use the ClosedCB() +// of the publishers, the connection will be closed. Use the ClosedCB // option to know when the connection has moved from draining to closed. // // See note in Subscription.Drain for JetStream subscriptions. @@ -5617,7 +5823,7 @@ func NkeyOptionFromSeed(seedFile string) (Option, error) { return nil, err } if !nkeys.IsValidPublicUserKey(pub) { - return nil, fmt.Errorf("nats: Not a valid nkey user seed") + return nil, errors.New("nats: Not a valid nkey user seed") } sigCB := func(nonce []byte) ([]byte, error) { return sigHandler(nonce, seedFile) diff --git a/nats_test.go b/nats_test.go index 7dcdce2cf..8f4546a77 100644 --- a/nats_test.go +++ b/nats_test.go @@ -94,10 +94,11 @@ func checkErrChannel(t *testing.T, errCh chan error) { } func TestVersionMatchesTag(t *testing.T) { - tag := os.Getenv("TRAVIS_TAG") - if tag == "" { + refType := os.Getenv("GITHUB_REF_TYPE") + if refType != "tag" { t.SkipNow() } + tag := os.Getenv("GITHUB_REF_NAME") // We expect a tag of the form vX.Y.Z. If that's not the case, // we need someone to have a look. 
So fail if first letter is not // a `v` @@ -229,6 +230,7 @@ func TestSimplifiedURLs(t *testing.T) { { "nats", []string{ + "nats://host1:1234/", "nats://host1:1234", "nats://host2:", "nats://host3", @@ -242,6 +244,7 @@ func TestSimplifiedURLs(t *testing.T) { "[17:18:19:20]:1234", }, []string{ + "nats://host1:1234/", "nats://host1:1234", "nats://host2:4222", "nats://host3:4222", @@ -434,6 +437,7 @@ func TestUrlArgument(t *testing.T) { check("nats://localhost:1222 ", oneExpected) check(" nats://localhost:1222", oneExpected) check(" nats://localhost:1222 ", oneExpected) + check("nats://localhost:1222/", oneExpected) var multiExpected = []string{ "nats://localhost:1222", @@ -445,6 +449,7 @@ func TestUrlArgument(t *testing.T) { check("nats://localhost:1222, nats://localhost:1223, nats://localhost:1224", multiExpected) check(" nats://localhost:1222, nats://localhost:1223, nats://localhost:1224 ", multiExpected) check("nats://localhost:1222, nats://localhost:1223 ,nats://localhost:1224", multiExpected) + check("nats://localhost:1222/,nats://localhost:1223/,nats://localhost:1224/", multiExpected) } func TestParserPing(t *testing.T) { diff --git a/netchan.go b/netchan.go index 6b13690b4..3722d9f1b 100644 --- a/netchan.go +++ b/netchan.go @@ -23,6 +23,8 @@ import ( // Data will be encoded and decoded via the EncodedConn and its associated encoders. // BindSendChan binds a channel for send operations to NATS. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) BindSendChan(subject string, channel any) error { chVal := reflect.ValueOf(channel) if chVal.Kind() != reflect.Chan { @@ -61,11 +63,15 @@ func chPublish(c *EncodedConn, chVal reflect.Value, subject string) { } // BindRecvChan binds a channel for receive operations from NATS. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) BindRecvChan(subject string, channel any) (*Subscription, error) { return c.bindRecvChan(subject, _EMPTY_, channel) } // BindRecvQueueChan binds a channel for queue-based receive operations from NATS. +// +// Deprecated: Encoded connections are no longer supported. func (c *EncodedConn) BindRecvQueueChan(subject, queue string, channel any) (*Subscription, error) { return c.bindRecvChan(subject, queue, channel) } diff --git a/object.go b/object.go index 2b818ac86..75ceaa8e9 100644 --- a/object.go +++ b/object.go @@ -694,7 +694,12 @@ func (obs *obs) Get(name string, opts ...GetObjectOpt) (ObjectResult, error) { } chunkSubj := fmt.Sprintf(objChunksPreTmpl, obs.name, info.NUID) - _, err = obs.js.Subscribe(chunkSubj, processChunk, OrderedConsumer()) + streamName := fmt.Sprintf(objNameTmpl, obs.name) + subscribeOpts := []SubOpt{ + OrderedConsumer(), + BindStream(streamName), + } + _, err = obs.js.Subscribe(chunkSubj, processChunk, subscribeOpts...) if err != nil { return nil, err } @@ -1110,7 +1115,8 @@ func (obs *obs) Watch(opts ...WatchOpt) (ObjectWatcher, error) { } // Used ordered consumer to deliver results. - subOpts := []SubOpt{OrderedConsumer()} + streamName := fmt.Sprintf(objNameTmpl, obs.name) + subOpts := []SubOpt{OrderedConsumer(), BindStream(streamName)} if !o.includeHistory { subOpts = append(subOpts, DeliverLastPerSubject()) } diff --git a/scripts/cov.sh b/scripts/cov.sh index 80828cb16..fa0fdc19d 100755 --- a/scripts/cov.sh +++ b/scripts/cov.sh @@ -5,14 +5,15 @@ rm -rf ./cov mkdir cov go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/nats.out . 
-tags=skip_no_race_tests go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/test.out -coverpkg=github.com/nats-io/nats.go ./test -tags=skip_no_race_tests,internal_testing -go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/jetstream.out -coverpkg=github.com/nats-io/nats.go/jetstream ./jetstream/test -tags=skip_no_race_tests +go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/jetstream.out -coverpkg=github.com/nats-io/nats.go/jetstream ./jetstream/... +go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/service.out -coverpkg=github.com/nats-io/nats.go/micro ./micro/... go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/builtin.out -coverpkg=github.com/nats-io/nats.go/encoders/builtin ./test -run EncBuiltin -tags=skip_no_race_tests go test -modfile=go_test.mod --failfast -vet=off -v -covermode=atomic -coverprofile=./cov/protobuf.out -coverpkg=github.com/nats-io/nats.go/encoders/protobuf ./test -run EncProto -tags=skip_no_race_tests gocovmerge ./cov/*.out > acc.out rm -rf ./cov # Without argument, launch browser results. We are going to push to coveralls only -# from Travis.yml and after success of the build (and result of pushing will not affect +# from ci.yml and after success of the build (and result of pushing will not affect # build result). if [[ $1 == "" ]]; then go tool cover -html=acc.out diff --git a/test/auth_test.go b/test/auth_test.go index a55b51217..8fd0982c2 100644 --- a/test/auth_test.go +++ b/test/auth_test.go @@ -17,6 +17,8 @@ import ( "errors" "fmt" "io/fs" + "net" + "os" "strings" "sync/atomic" "testing" @@ -377,3 +379,75 @@ func TestConnectMissingCreds(t *testing.T) { t.Fatalf("Expected not exists error, got: %v", err) } } + +func TestUserInfoHandler(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + accounts: { + A { + users: [{ user: "pp", password: "foo" }] + } + } +`)) + defer os.Remove(conf) + + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + user, pass := "pp", "foo" + userInfoCB := func() (string, string) { + return user, pass + } + + // check that we cannot set the user info twice + _, err := nats.Connect(s.ClientURL(), nats.UserInfo("pp", "foo"), nats.UserInfoHandler(userInfoCB)) + if !errors.Is(err, nats.ErrUserInfoAlreadySet) { + t.Fatalf("Expected ErrUserInfoAlreadySet, got: %v", err) + } + + addr, ok := s.Addr().(*net.TCPAddr) + if !ok { + t.Fatalf("Expected a TCP address, got %T", addr) + } + + // check that user/pass from url takes precedence + _, err = nats.Connect(fmt.Sprintf("nats://bad:bad@localhost:%d", addr.Port), + nats.UserInfoHandler(userInfoCB)) + if !errors.Is(err, nats.ErrAuthorization) { + t.Fatalf("Expected ErrAuthorization, got: %v", err) + } + + // connect using the handler + nc, err := nats.Connect(s.ClientURL(), + nats.ReconnectWait(100*time.Millisecond), + nats.UserInfoHandler(userInfoCB)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + // now change the password and reload the server + newConfig := []byte(` + listen: 127.0.0.1:-1 + accounts: { + A { + users: [{ user: "dd", password: "bar" }] + } + } +`) + if err := os.WriteFile(conf, newConfig, 0666); err != nil { + t.Fatalf("Error writing conf file: %v", err) + } + + // update the user info used by the callback + user, pass = "dd", "bar" + + status := nc.StatusChanged(nats.CONNECTED) + + if err := s.Reload(); err != nil { 
+ t.Fatalf("Error on reload: %v", err) + } + + // we should get a reconnected event meaning the new credentials were used + WaitOnChannel(t, status, nats.CONNECTED) +} diff --git a/test/basic_test.go b/test/basic_test.go index 75a187d05..3161930a6 100644 --- a/test/basic_test.go +++ b/test/basic_test.go @@ -464,8 +464,8 @@ func TestQueueSubscriber(t *testing.T) { omsg := []byte("Hello World") nc.Publish("foo", omsg) nc.Flush() - r1, _ := s1.QueuedMsgs() - r2, _ := s2.QueuedMsgs() + r1, _, _ := s1.Pending() + r2, _, _ := s2.Pending() if (r1 + r2) != 1 { t.Fatal("Received too many messages for multiple queue subscribers") } @@ -479,8 +479,8 @@ func TestQueueSubscriber(t *testing.T) { } nc.Flush() v := uint(float32(total) * 0.15) - r1, _ = s1.QueuedMsgs() - r2, _ = s2.QueuedMsgs() + r1, _, _ = s1.Pending() + r2, _, _ = s2.Pending() if r1+r2 != total { t.Fatalf("Incorrect number of messages: %d vs %d", (r1 + r2), total) } @@ -1032,7 +1032,7 @@ func TestNilConnection(t *testing.T) { if _, err := sub.NextMsg(time.Millisecond); err == nil || err != nats.ErrBadSubscription { t.Fatalf("Expected ErrBadSubscription error, got %v\n", err) } - if _, err := sub.QueuedMsgs(); err == nil || err != nats.ErrBadSubscription { + if _, _, err := sub.Pending(); err == nil || err != nats.ErrBadSubscription { t.Fatalf("Expected ErrBadSubscription error, got %v\n", err) } if _, _, err := sub.Pending(); err == nil || err != nats.ErrBadSubscription { diff --git a/test/cluster_test.go b/test/cluster_test.go index a075296f8..05fab49ef 100644 --- a/test/cluster_test.go +++ b/test/cluster_test.go @@ -682,7 +682,7 @@ func (d *checkPoolUpdatedDialer) Dial(network, address string) (net.Conn, error) doReal = true } else if d.final { d.ra++ - return nil, fmt.Errorf("On purpose") + return nil, errors.New("On purpose") } else { d.ra++ if d.ra == 15 { @@ -698,12 +698,12 @@ func (d *checkPoolUpdatedDialer) Dial(network, address string) (net.Conn, error) d.conn = c return c, nil } - return nil, fmt.Errorf("On purpose") + return nil, errors.New("On purpose") } func TestServerPoolUpdatedWhenRouteGoesAway(t *testing.T) { if err := serverVersionAtLeast(1, 0, 7); err != nil { - t.Skipf(err.Error()) + t.Skip(err.Error()) } s1Opts := test.DefaultTestOptions s1Opts.Host = "127.0.0.1" diff --git a/test/compat_test.go b/test/compat_test.go index 3c2751480..d7f3cb7fb 100644 --- a/test/compat_test.go +++ b/test/compat_test.go @@ -45,10 +45,7 @@ type objectStepConfig[T any] struct { func TestCompatibilityObjectStoreDefaultBucket(t *testing.T) { t.Parallel() - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -85,10 +82,7 @@ func TestCompatibilityObjectStoreDefaultBucket(t *testing.T) { func TestCompatibilityObjectStoreCustomBucket(t *testing.T) { t.Parallel() - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -131,10 +125,7 @@ func TestCompatibilityObjectStoreGetObject(t *testing.T) { Object string `json:"object"` } - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to 
NATS: %v", err) @@ -186,10 +177,7 @@ func TestCompatibilityObjectStoreGetObject(t *testing.T) { func TestCompatibilityObjectStorePutObject(t *testing.T) { t.Parallel() - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -239,10 +227,7 @@ func TestCompatibilityObjectStorePutObject(t *testing.T) { func TestCompatibilityObjectStoreUpdateMetadata(t *testing.T) { t.Parallel() - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -287,10 +272,7 @@ func TestCompatibilityObjectStoreWatch(t *testing.T) { Object string `json:"object"` } - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -365,10 +347,7 @@ func TestCompatibilityObjectStoreWatchUpdates(t *testing.T) { Object string `json:"object"` } - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -421,10 +400,7 @@ func TestCompatibilityObjectStoreGetLink(t *testing.T) { Object string `json:"object"` } - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) @@ -481,10 +457,7 @@ func TestCompatibilityObjectStorePutLink(t *testing.T) { LinkName string `json:"link_name"` } - nc, err := nats.Connect(nats.DefaultURL, nats.Timeout(1*time.Hour)) - if err != nil { - t.Fatalf("Error connecting to NATS: %v", err) - } + nc := connect(t) js, err := jetstream.New(nc) if err != nil { t.Fatalf("Error connecting to NATS: %v", err) diff --git a/test/configs/docker/Dockerfile b/test/configs/docker/Dockerfile index 037430d83..96c69c7f4 100644 --- a/test/configs/docker/Dockerfile +++ b/test/configs/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.20 +FROM golang:1.22 WORKDIR /usr/src/nats.go COPY . 
/usr/src/nats.go RUN go mod tidy -modfile go_test.mod diff --git a/test/conn_test.go b/test/conn_test.go index 36d602c96..1a4705092 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -277,7 +277,7 @@ func TestClientTLSConfig(t *testing.T) { pool := x509.NewCertPool() ok := pool.AppendCertsFromPEM(rootCAs) if !ok { - return nil, fmt.Errorf("nats: failed to parse root certificate from") + return nil, errors.New("nats: failed to parse root certificate from") } return pool, nil } @@ -614,7 +614,7 @@ func TestErrOnConnectAndDeadlock(t *testing.T) { nc, err := nats.Connect(natsURL) if err == nil { nc.Close() - errCh <- fmt.Errorf("expected bad INFO err, got none") + errCh <- errors.New("expected bad INFO err, got none") return } errCh <- nil @@ -1094,16 +1094,21 @@ func TestCallbacksOrder(t *testing.T) { } func TestConnectHandler(t *testing.T) { + handler := func(ch chan bool) func(*nats.Conn) { + return func(*nats.Conn) { + ch <- true + } + } t.Run("with RetryOnFailedConnect, connection established", func(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() connected := make(chan bool) - connHandler := func(*nats.Conn) { - connected <- true - } + reconnected := make(chan bool) + nc, err := nats.Connect(nats.DefaultURL, - nats.ConnectHandler(connHandler), + nats.ConnectHandler(handler(connected)), + nats.ReconnectHandler(handler(reconnected)), nats.RetryOnFailedConnect(true)) if err != nil { @@ -1113,24 +1118,28 @@ func TestConnectHandler(t *testing.T) { if err = Wait(connected); err != nil { t.Fatal("Timeout waiting for connect handler") } + if err = WaitTime(reconnected, 100*time.Millisecond); err == nil { + t.Fatal("Reconnect handler should not have been invoked") + } }) t.Run("with RetryOnFailedConnect, connection failed", func(t *testing.T) { connected := make(chan bool) - connHandler := func(*nats.Conn) { - connected <- true - } + reconnected := make(chan bool) + nc, err := nats.Connect(nats.DefaultURL, - nats.ConnectHandler(connHandler), + nats.ConnectHandler(handler(connected)), + nats.ReconnectHandler(handler(reconnected)), nats.RetryOnFailedConnect(true)) if err != nil { t.Fatalf("Unexpected error: %v", err) } defer nc.Close() - select { - case <-connected: - t.Fatalf("ConnectedCB invoked when no connection established") - case <-time.After(100 * time.Millisecond): + if err = WaitTime(connected, 100*time.Millisecond); err == nil { + t.Fatal("Connected handler should not have been invoked") + } + if err = WaitTime(reconnected, 100*time.Millisecond); err == nil { + t.Fatal("Reconnect handler should not have been invoked") } }) t.Run("no RetryOnFailedConnect, connection established", func(t *testing.T) { @@ -1138,11 +1147,11 @@ func TestConnectHandler(t *testing.T) { defer s.Shutdown() connected := make(chan bool) - connHandler := func(*nats.Conn) { - connected <- true - } + reconnected := make(chan bool) nc, err := nats.Connect(nats.DefaultURL, - nats.ConnectHandler(connHandler)) + nats.ConnectHandler(handler(connected)), + nats.ReconnectHandler(handler(reconnected))) + if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -1150,22 +1159,94 @@ func TestConnectHandler(t *testing.T) { if err = Wait(connected); err != nil { t.Fatal("Timeout waiting for connect handler") } + if err = WaitTime(reconnected, 100*time.Millisecond); err == nil { + t.Fatal("Reconnect handler should not have been invoked") + } }) t.Run("no RetryOnFailedConnect, connection failed", func(t *testing.T) { connected := make(chan bool) - connHandler := func(*nats.Conn) { - connected <- true - } + 
reconnected := make(chan bool) _, err := nats.Connect(nats.DefaultURL, - nats.ConnectHandler(connHandler)) + nats.ConnectHandler(handler(connected)), + nats.ReconnectHandler(handler(reconnected))) if err == nil { t.Fatalf("Expected error on connect, got nil") } - select { - case <-connected: - t.Fatalf("ConnectedCB invoked when no connection established") - case <-time.After(100 * time.Millisecond): + if err = WaitTime(connected, 100*time.Millisecond); err == nil { + t.Fatal("Connected handler should not have been invoked") + } + if err = WaitTime(reconnected, 100*time.Millisecond); err == nil { + t.Fatal("Reconnect handler should not have been invoked") + } + }) + t.Run("with RetryOnFailedConnect, initial connection failed, reconnect successful", func(t *testing.T) { + connected := make(chan bool) + reconnected := make(chan bool) + + nc, err := nats.Connect(nats.DefaultURL, + nats.ConnectHandler(handler(connected)), + nats.ReconnectHandler(handler(reconnected)), + nats.RetryOnFailedConnect(true), + nats.ReconnectWait(100*time.Millisecond)) + + if err != nil { + t.Fatalf("Expected error on connect, got nil") + } + + defer nc.Close() + + s := RunDefaultServer() + defer s.Shutdown() + + if err != nil { + t.Fatalf("Expected error on connect, got nil") + } + if err = Wait(connected); err != nil { + t.Fatal("Timeout waiting for reconnect handler") + } + if err = WaitTime(reconnected, 100*time.Millisecond); err == nil { + t.Fatal("Reconnect handler should not have been invoked") + } + }) + t.Run("with RetryOnFailedConnect, initial connection successful, server restart", func(t *testing.T) { + connected := make(chan bool) + reconnected := make(chan bool) + + s := RunDefaultServer() + defer s.Shutdown() + + nc, err := nats.Connect(nats.DefaultURL, + nats.ConnectHandler(handler(connected)), + nats.ReconnectHandler(handler(reconnected)), + nats.RetryOnFailedConnect(true), + nats.ReconnectWait(100*time.Millisecond)) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if err = Wait(connected); err != nil { + t.Fatal("Timeout waiting for connect handler") + } + if err = WaitTime(reconnected, 100*time.Millisecond); err == nil { + t.Fatal("Reconnect handler should not have been invoked") + } + + s.Shutdown() + + s = RunDefaultServer() + defer s.Shutdown() + + if err = Wait(reconnected); err != nil { + t.Fatal("Timeout waiting for reconnect handler") + } + if err = WaitTime(connected, 100*time.Millisecond); err == nil { + t.Fatal("Connected handler should not have been invoked") } }) } @@ -1668,7 +1749,7 @@ type customDialer struct { func (cd *customDialer) Dial(network, address string) (net.Conn, error) { cd.ch <- true - return nil, fmt.Errorf("on purpose") + return nil, errors.New("on purpose") } func TestUseCustomDialer(t *testing.T) { @@ -1766,8 +1847,8 @@ func TestDefaultOptionsDialer(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() - opts1 := nats.DefaultOptions - opts2 := nats.DefaultOptions + opts1 := nats.GetDefaultOptions() + opts2 := nats.GetDefaultOptions() nc1, err := opts1.Connect() if err != nil { @@ -2709,7 +2790,8 @@ func TestRetryOnFailedConnect(t *testing.T) { nc.Close() t.Fatal("Expected error, did not get one") } - ch := make(chan bool, 1) + reconnectedCh := make(chan bool, 1) + connectedCh := make(chan bool, 1) dch := make(chan bool, 1) nc, err = nats.Connect(nats.DefaultURL, nats.RetryOnFailedConnect(true), @@ -2718,8 +2800,11 @@ func TestRetryOnFailedConnect(t *testing.T) { 
nats.DisconnectErrHandler(func(_ *nats.Conn, _ error) { dch <- true }), + nats.ConnectHandler(func(_ *nats.Conn) { + connectedCh <- true + }), nats.ReconnectHandler(func(_ *nats.Conn) { - ch <- true + reconnectedCh <- true }), nats.NoCallbacksAfterClientClose()) if err != nil { @@ -2737,19 +2822,19 @@ func TestRetryOnFailedConnect(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() - var action string switch i { case 0: - action = "connected" + select { + case <-connectedCh: + case <-time.After(2 * time.Second): + t.Fatal("Should have connected") + } case 1: - action = "reconnected" - } - - // Wait for the reconnect CB which in this context means that we connected ok - select { - case <-ch: - case <-time.After(2 * time.Second): - t.Fatalf("Should have %s", action) + select { + case <-reconnectedCh: + case <-time.After(2 * time.Second): + t.Fatal("Should have reconnected") + } } // Now make sure that the pub worked and sub worked. @@ -2782,7 +2867,7 @@ func TestRetryOnFailedConnect(t *testing.T) { nats.MaxReconnects(-1), nats.ReconnectWait(15*time.Millisecond), nats.ReconnectHandler(func(_ *nats.Conn) { - ch <- true + reconnectedCh <- true }), nats.ClosedHandler(func(_ *nats.Conn) { closedCh <- true @@ -2807,7 +2892,7 @@ func TestRetryOnFailedConnect(t *testing.T) { } // Make sure that we did not get the (re)connected CB select { - case <-ch: + case <-reconnectedCh: t.Fatal("(re)connected callback should not have been invoked") default: } @@ -2830,14 +2915,14 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) { s := RunServerWithOptions(&opts) defer s.Shutdown() - ch := make(chan bool, 1) + connectedCh := make(chan bool, 1) nc, err := nats.Connect(nats.DefaultURL, nats.Secure(&tls.Config{InsecureSkipVerify: true}), nats.RetryOnFailedConnect(true), nats.MaxReconnects(-1), nats.ReconnectWait(15*time.Millisecond), - nats.ReconnectHandler(func(_ *nats.Conn) { - ch <- true + nats.ConnectHandler(func(_ *nats.Conn) { + connectedCh <- true }), nats.NoCallbacksAfterClientClose()) if err != nil { @@ -2854,23 +2939,13 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) { defer s.Shutdown() select { - case <-ch: + case <-connectedCh: case <-time.After(time.Second): t.Fatal("Should have connected") } } func TestConnStatusChangedEvents(t *testing.T) { - waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) { - select { - case s := <-ch: - if s != expected { - t.Fatalf("Expected status: %s; got: %s", expected, s) - } - case <-time.After(5 * time.Second): - t.Fatalf("Timeout waiting for status %q", expected) - } - } t.Run("default events", func(t *testing.T) { s := RunDefaultServer() nc, err := nats.Connect(s.ClientURL()) @@ -2893,15 +2968,15 @@ func TestConnStatusChangedEvents(t *testing.T) { time.Sleep(50 * time.Millisecond) s.Shutdown() - waitForStatus(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.RECONNECTING) s = RunDefaultServer() defer s.Shutdown() - waitForStatus(t, newStatus, nats.CONNECTED) + WaitOnChannel(t, newStatus, nats.CONNECTED) nc.Close() - waitForStatus(t, newStatus, nats.CLOSED) + WaitOnChannel(t, newStatus, nats.CLOSED) select { case s := <-newStatus: @@ -2934,7 +3009,7 @@ func TestConnStatusChangedEvents(t *testing.T) { s = RunDefaultServer() defer s.Shutdown() nc.Close() - waitForStatus(t, newStatus, nats.CLOSED) + WaitOnChannel(t, newStatus, nats.CLOSED) select { case s := <-newStatus: diff --git a/test/context_test.go b/test/context_test.go index b9c2f24f6..f2df307a6 100644 --- a/test/context_test.go +++ 
b/test/context_test.go @@ -654,324 +654,6 @@ func TestContextSubNextMsgWithDeadline(t *testing.T) { } } -func TestContextEncodedRequestWithTimeout(t *testing.T) { - s := RunDefaultServer() - defer s.Shutdown() - - nc := NewDefaultConnection(t) - c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) - if err != nil { - t.Fatalf("Unable to create encoded connection: %v", err) - } - defer c.Close() - - deadline := time.Now().Add(100 * time.Millisecond) - ctx, cancelCB := context.WithDeadline(context.Background(), deadline) - defer cancelCB() // should always be called, not discarded, to prevent context leak - - type request struct { - Message string `json:"message"` - } - type response struct { - Code int `json:"code"` - } - c.Subscribe("slow", func(_, reply string, req *request) { - got := req.Message - expected := "Hello" - if got != expected { - t.Errorf("Expected to receive request with %q, got %q", got, expected) - } - - // simulates latency into the client so that timeout is hit. - time.Sleep(40 * time.Millisecond) - c.Publish(reply, &response{Code: 200}) - }) - - for i := 0; i < 2; i++ { - req := &request{Message: "Hello"} - resp := &response{} - err := c.RequestWithContext(ctx, "slow", req, resp) - if err != nil { - t.Fatalf("Expected encoded request with context to not fail: %s", err) - } - got := resp.Code - expected := 200 - if got != expected { - t.Errorf("Expected to receive %v, got: %v", expected, got) - } - } - - // A third request with latency would make the context - // reach the deadline. - req := &request{Message: "Hello"} - resp := &response{} - err = c.RequestWithContext(ctx, "slow", req, resp) - if err == nil { - t.Fatal("Expected request with context to reach deadline") - } - - // Reported error is "context deadline exceeded" from Context package, - // which implements net.Error Timeout interface. - type timeoutError interface { - Timeout() bool - } - timeoutErr, ok := err.(timeoutError) - if !ok || !timeoutErr.Timeout() { - t.Errorf("Expected to have a timeout error") - } - expected := `context deadline exceeded` - if !strings.Contains(err.Error(), expected) { - t.Errorf("Expected %q error, got: %q", expected, err.Error()) - } -} - -func TestContextEncodedRequestWithTimeoutCanceled(t *testing.T) { - s := RunDefaultServer() - defer s.Shutdown() - - nc := NewDefaultConnection(t) - c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) - if err != nil { - t.Fatalf("Unable to create encoded connection: %v", err) - } - defer c.Close() - - ctx, cancelCB := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancelCB() // should always be called, not discarded, to prevent context leak - - type request struct { - Message string `json:"message"` - } - type response struct { - Code int `json:"code"` - } - - c.Subscribe("fast", func(_, reply string, req *request) { - got := req.Message - expected := "Hello" - if got != expected { - t.Errorf("Expected to receive request with %q, got %q", got, expected) - } - - // simulates latency into the client so that timeout is hit. - time.Sleep(40 * time.Millisecond) - - c.Publish(reply, &response{Code: 200}) - }) - - // Fast request should not fail - req := &request{Message: "Hello"} - resp := &response{} - c.RequestWithContext(ctx, "fast", req, resp) - expectedCode := 200 - if resp.Code != expectedCode { - t.Errorf("Expected to receive %d, got: %d", expectedCode, resp.Code) - } - - // Cancel the context already so that rest of requests fail. 
- cancelCB() - - err = c.RequestWithContext(ctx, "fast", req, resp) - if err == nil { - t.Fatal("Expected request with timeout context to fail") - } - - // Reported error is "context canceled" from Context package, - // which is not a timeout error. - type timeoutError interface { - Timeout() bool - } - if _, ok := err.(timeoutError); ok { - t.Errorf("Expected to not have a timeout error") - } - expected := `context canceled` - if !strings.Contains(err.Error(), expected) { - t.Errorf("Expected %q error, got: %q", expected, err.Error()) - } - - // 2nd request should fail again even if fast because context has already been canceled - err = c.RequestWithContext(ctx, "fast", req, resp) - if err == nil { - t.Fatal("Expected request with timeout context to fail") - } -} - -func TestContextEncodedRequestWithCancel(t *testing.T) { - s := RunDefaultServer() - defer s.Shutdown() - - nc := NewDefaultConnection(t) - c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) - if err != nil { - t.Fatalf("Unable to create encoded connection: %v", err) - } - defer c.Close() - - ctx, cancelCB := context.WithCancel(context.Background()) - defer cancelCB() // should always be called, not discarded, to prevent context leak - - // timer which cancels the context though can also be arbitrarily extended - expirationTimer := time.AfterFunc(100*time.Millisecond, func() { - cancelCB() - }) - - type request struct { - Message string `json:"message"` - } - type response struct { - Code int `json:"code"` - } - c.Subscribe("slow", func(_, reply string, req *request) { - got := req.Message - expected := "Hello" - if got != expected { - t.Errorf("Expected to receive request with %q, got %q", got, expected) - } - - // simulates latency into the client so that timeout is hit. - time.Sleep(40 * time.Millisecond) - c.Publish(reply, &response{Code: 200}) - }) - c.Subscribe("slower", func(_, reply string, req *request) { - got := req.Message - expected := "World" - if got != expected { - t.Errorf("Expected to receive request with %q, got %q", got, expected) - } - - // we know this request will take longer so extend the timeout - expirationTimer.Reset(100 * time.Millisecond) - - // slower reply which would have hit original timeout - time.Sleep(90 * time.Millisecond) - c.Publish(reply, &response{Code: 200}) - }) - - for i := 0; i < 2; i++ { - req := &request{Message: "Hello"} - resp := &response{} - err := c.RequestWithContext(ctx, "slow", req, resp) - if err != nil { - t.Fatalf("Expected encoded request with context to not fail: %s", err) - } - got := resp.Code - expected := 200 - if got != expected { - t.Errorf("Expected to receive %v, got: %v", expected, got) - } - } - - // A third request with latency would make the context - // get canceled, but these reset the timer so deadline - // gets extended: - for i := 0; i < 10; i++ { - req := &request{Message: "World"} - resp := &response{} - err := c.RequestWithContext(ctx, "slower", req, resp) - if err != nil { - t.Fatalf("Expected request with context to not fail: %s", err) - } - got := resp.Code - expected := 200 - if got != expected { - t.Errorf("Expected to receive %d, got: %d", expected, got) - } - } - - req := &request{Message: "Hello"} - resp := &response{} - - // One more slow request will expire the timer and cause an error... - err = c.RequestWithContext(ctx, "slow", req, resp) - if err == nil { - t.Fatal("Expected request with cancellation context to fail") - } - - // ...though reported error is "context canceled" from Context package, - // which is not a timeout error. 
- type timeoutError interface { - Timeout() bool - } - if _, ok := err.(timeoutError); ok { - t.Errorf("Expected to not have a timeout error") - } - expected := `context canceled` - if !strings.Contains(err.Error(), expected) { - t.Errorf("Expected %q error, got: %q", expected, err.Error()) - } -} - -func TestContextEncodedRequestWithDeadline(t *testing.T) { - s := RunDefaultServer() - defer s.Shutdown() - - nc := NewDefaultConnection(t) - c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) - if err != nil { - t.Fatalf("Unable to create encoded connection: %v", err) - } - defer c.Close() - - deadline := time.Now().Add(100 * time.Millisecond) - ctx, cancelCB := context.WithDeadline(context.Background(), deadline) - defer cancelCB() // should always be called, not discarded, to prevent context leak - - type request struct { - Message string `json:"message"` - } - type response struct { - Code int `json:"code"` - } - c.Subscribe("slow", func(_, reply string, req *request) { - got := req.Message - expected := "Hello" - if got != expected { - t.Errorf("Expected to receive request with %q, got %q", got, expected) - } - - // simulates latency into the client so that timeout is hit. - time.Sleep(40 * time.Millisecond) - c.Publish(reply, &response{Code: 200}) - }) - - for i := 0; i < 2; i++ { - req := &request{Message: "Hello"} - resp := &response{} - err := c.RequestWithContext(ctx, "slow", req, resp) - if err != nil { - t.Fatalf("Expected encoded request with context to not fail: %s", err) - } - got := resp.Code - expected := 200 - if got != expected { - t.Errorf("Expected to receive %v, got: %v", expected, got) - } - } - - // A third request with latency would make the context - // reach the deadline. - req := &request{Message: "Hello"} - resp := &response{} - err = c.RequestWithContext(ctx, "slow", req, resp) - if err == nil { - t.Fatal("Expected request with context to reach deadline") - } - - // Reported error is "context deadline exceeded" from Context package, - // which implements net.Error Timeout interface. 
- type timeoutError interface { - Timeout() bool - } - timeoutErr, ok := err.(timeoutError) - if !ok || !timeoutErr.Timeout() { - t.Errorf("Expected to have a timeout error") - } - expected := `context deadline exceeded` - if !strings.Contains(err.Error(), expected) { - t.Errorf("Expected %q error, got: %q", expected, err.Error()) - } -} - func TestContextRequestConnClosed(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() @@ -1026,58 +708,6 @@ func TestContextBadSubscription(t *testing.T) { } } -func TestContextInvalid(t *testing.T) { - s := RunDefaultServer() - defer s.Shutdown() - - nc := NewDefaultConnection(t) - c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) - if err != nil { - t.Fatalf("Unable to create encoded connection: %v", err) - } - defer c.Close() - - //lint:ignore SA1012 testing that passing nil fails - _, err = nc.RequestWithContext(nil, "foo", []byte("")) - if err == nil { - t.Fatal("Expected request to fail with error") - } - if err != nats.ErrInvalidContext { - t.Errorf("Expected request to fail with connection closed error: %s", err) - } - - sub, err := nc.Subscribe("foo", func(_ *nats.Msg) {}) - if err != nil { - t.Fatalf("Expected to be able to subscribe: %s", err) - } - - //lint:ignore SA1012 testing that passing nil fails - _, err = sub.NextMsgWithContext(nil) - if err == nil { - t.Fatal("Expected request to fail with error") - } - if err != nats.ErrInvalidContext { - t.Errorf("Expected request to fail with connection closed error: %s", err) - } - - type request struct { - Message string `json:"message"` - } - type response struct { - Code int `json:"code"` - } - req := &request{Message: "Hello"} - resp := &response{} - //lint:ignore SA1012 testing that passing nil fails - err = c.RequestWithContext(nil, "slow", req, resp) - if err == nil { - t.Fatal("Expected request to fail with error") - } - if err != nats.ErrInvalidContext { - t.Errorf("Expected request to fail with invalid context: %s", err) - } -} - func TestFlushWithContext(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() @@ -1148,3 +778,34 @@ func TestUnsubscribeAndNextMsgWithContext(t *testing.T) { } wg.Wait() } + +func TestContextInvalid(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + //lint:ignore SA1012 testing that passing nil fails + _, err := nc.RequestWithContext(nil, "foo", []byte("")) + if err == nil { + t.Fatal("Expected request to fail with error") + } + if err != nats.ErrInvalidContext { + t.Errorf("Expected request to fail with connection closed error: %s", err) + } + + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) {}) + if err != nil { + t.Fatalf("Expected to be able to subscribe: %s", err) + } + + //lint:ignore SA1012 testing that passing nil fails + _, err = sub.NextMsgWithContext(nil) + if err == nil { + t.Fatal("Expected request to fail with error") + } + if err != nats.ErrInvalidContext { + t.Errorf("Expected request to fail with connection closed error: %s", err) + } +} diff --git a/test/drain_test.go b/test/drain_test.go index 1168f617f..da07c8967 100644 --- a/test/drain_test.go +++ b/test/drain_test.go @@ -1,4 +1,4 @@ -// Copyright 2018-2023 The NATS Authors +// Copyright 2018-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -14,6 +14,7 @@ package test import ( + "errors" "fmt" "sync" "sync/atomic" @@ -55,6 +56,9 @@ func TestDrain(t *testing.T) { // Drain it and make sure we receive all messages. sub.Drain() + if !sub.IsDraining() { + t.Fatalf("Expected to be draining") + } select { case <-done: break @@ -64,6 +68,10 @@ func TestDrain(t *testing.T) { t.Fatalf("Did not receive all messages: %d of %d", r, expected) } } + time.Sleep(100 * time.Millisecond) + if sub.IsDraining() { + t.Fatalf("Expected to be done draining") + } } func TestDrainQueueSub(t *testing.T) { @@ -218,7 +226,7 @@ func TestDrainSlowSubscriber(t *testing.T) { // Wait for it to become invalid. Once drained it is unsubscribed. _, _, err := sub.Pending() if err != nats.ErrBadSubscription { - return fmt.Errorf("Still valid") + return errors.New("Still valid") } r := int(atomic.LoadInt32(&received)) if r != total { @@ -471,7 +479,7 @@ func TestDrainConnDuringReconnect(t *testing.T) { if nc.IsReconnecting() { return nil } - return fmt.Errorf("Not reconnecting yet") + return errors.New("Not reconnecting yet") }) // This should work correctly. diff --git a/test/enc_test.go b/test/enc_test.go index e40abbf06..c8109e7af 100644 --- a/test/enc_test.go +++ b/test/enc_test.go @@ -15,7 +15,9 @@ package test import ( "bytes" + "context" "fmt" + "strings" "testing" "time" @@ -25,6 +27,8 @@ import ( "github.com/nats-io/nats.go/encoders/protobuf/testdata" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + func NewDefaultEConn(t *testing.T) *nats.EncodedConn { ec, err := nats.NewEncodedConn(NewConnection(t, TEST_PORT), nats.DEFAULT_ENCODER) if err != nil { @@ -753,3 +757,350 @@ func TestRequestGOB(t *testing.T) { t.Fatalf("Did not receive proper response, %+v", reply) } } + +func TestContextEncodedRequestWithTimeout(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + deadline := time.Now().Add(100 * time.Millisecond) + ctx, cancelCB := context.WithDeadline(context.Background(), deadline) + defer cancelCB() // should always be called, not discarded, to prevent context leak + + type request struct { + Message string `json:"message"` + } + type response struct { + Code int `json:"code"` + } + c.Subscribe("slow", func(_, reply string, req *request) { + got := req.Message + expected := "Hello" + if got != expected { + t.Errorf("Expected to receive request with %q, got %q", got, expected) + } + + // simulates latency into the client so that timeout is hit. + time.Sleep(40 * time.Millisecond) + c.Publish(reply, &response{Code: 200}) + }) + + for i := 0; i < 2; i++ { + req := &request{Message: "Hello"} + resp := &response{} + err := c.RequestWithContext(ctx, "slow", req, resp) + if err != nil { + t.Fatalf("Expected encoded request with context to not fail: %s", err) + } + got := resp.Code + expected := 200 + if got != expected { + t.Errorf("Expected to receive %v, got: %v", expected, got) + } + } + + // A third request with latency would make the context + // reach the deadline. + req := &request{Message: "Hello"} + resp := &response{} + err = c.RequestWithContext(ctx, "slow", req, resp) + if err == nil { + t.Fatal("Expected request with context to reach deadline") + } + + // Reported error is "context deadline exceeded" from Context package, + // which implements net.Error Timeout interface. 
+ type timeoutError interface { + Timeout() bool + } + timeoutErr, ok := err.(timeoutError) + if !ok || !timeoutErr.Timeout() { + t.Errorf("Expected to have a timeout error") + } + expected := `context deadline exceeded` + if !strings.Contains(err.Error(), expected) { + t.Errorf("Expected %q error, got: %q", expected, err.Error()) + } +} + +func TestContextEncodedRequestWithTimeoutCanceled(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + ctx, cancelCB := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelCB() // should always be called, not discarded, to prevent context leak + + type request struct { + Message string `json:"message"` + } + type response struct { + Code int `json:"code"` + } + + c.Subscribe("fast", func(_, reply string, req *request) { + got := req.Message + expected := "Hello" + if got != expected { + t.Errorf("Expected to receive request with %q, got %q", got, expected) + } + + // simulates latency into the client so that timeout is hit. + time.Sleep(40 * time.Millisecond) + + c.Publish(reply, &response{Code: 200}) + }) + + // Fast request should not fail + req := &request{Message: "Hello"} + resp := &response{} + c.RequestWithContext(ctx, "fast", req, resp) + expectedCode := 200 + if resp.Code != expectedCode { + t.Errorf("Expected to receive %d, got: %d", expectedCode, resp.Code) + } + + // Cancel the context already so that rest of requests fail. + cancelCB() + + err = c.RequestWithContext(ctx, "fast", req, resp) + if err == nil { + t.Fatal("Expected request with timeout context to fail") + } + + // Reported error is "context canceled" from Context package, + // which is not a timeout error. + type timeoutError interface { + Timeout() bool + } + if _, ok := err.(timeoutError); ok { + t.Errorf("Expected to not have a timeout error") + } + expected := `context canceled` + if !strings.Contains(err.Error(), expected) { + t.Errorf("Expected %q error, got: %q", expected, err.Error()) + } + + // 2nd request should fail again even if fast because context has already been canceled + err = c.RequestWithContext(ctx, "fast", req, resp) + if err == nil { + t.Fatal("Expected request with timeout context to fail") + } +} + +func TestContextEncodedRequestWithCancel(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + ctx, cancelCB := context.WithCancel(context.Background()) + defer cancelCB() // should always be called, not discarded, to prevent context leak + + // timer which cancels the context though can also be arbitrarily extended + expirationTimer := time.AfterFunc(100*time.Millisecond, func() { + cancelCB() + }) + + type request struct { + Message string `json:"message"` + } + type response struct { + Code int `json:"code"` + } + c.Subscribe("slow", func(_, reply string, req *request) { + got := req.Message + expected := "Hello" + if got != expected { + t.Errorf("Expected to receive request with %q, got %q", got, expected) + } + + // simulates latency into the client so that timeout is hit. 
+ time.Sleep(40 * time.Millisecond) + c.Publish(reply, &response{Code: 200}) + }) + c.Subscribe("slower", func(_, reply string, req *request) { + got := req.Message + expected := "World" + if got != expected { + t.Errorf("Expected to receive request with %q, got %q", got, expected) + } + + // we know this request will take longer so extend the timeout + expirationTimer.Reset(100 * time.Millisecond) + + // slower reply which would have hit original timeout + time.Sleep(90 * time.Millisecond) + c.Publish(reply, &response{Code: 200}) + }) + + for i := 0; i < 2; i++ { + req := &request{Message: "Hello"} + resp := &response{} + err := c.RequestWithContext(ctx, "slow", req, resp) + if err != nil { + t.Fatalf("Expected encoded request with context to not fail: %s", err) + } + got := resp.Code + expected := 200 + if got != expected { + t.Errorf("Expected to receive %v, got: %v", expected, got) + } + } + + // A third request with latency would make the context + // get canceled, but these reset the timer so deadline + // gets extended: + for i := 0; i < 10; i++ { + req := &request{Message: "World"} + resp := &response{} + err := c.RequestWithContext(ctx, "slower", req, resp) + if err != nil { + t.Fatalf("Expected request with context to not fail: %s", err) + } + got := resp.Code + expected := 200 + if got != expected { + t.Errorf("Expected to receive %d, got: %d", expected, got) + } + } + + req := &request{Message: "Hello"} + resp := &response{} + + // One more slow request will expire the timer and cause an error... + err = c.RequestWithContext(ctx, "slow", req, resp) + if err == nil { + t.Fatal("Expected request with cancellation context to fail") + } + + // ...though reported error is "context canceled" from Context package, + // which is not a timeout error. + type timeoutError interface { + Timeout() bool + } + if _, ok := err.(timeoutError); ok { + t.Errorf("Expected to not have a timeout error") + } + expected := `context canceled` + if !strings.Contains(err.Error(), expected) { + t.Errorf("Expected %q error, got: %q", expected, err.Error()) + } +} + +func TestContextEncodedRequestWithDeadline(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + deadline := time.Now().Add(100 * time.Millisecond) + ctx, cancelCB := context.WithDeadline(context.Background(), deadline) + defer cancelCB() // should always be called, not discarded, to prevent context leak + + type request struct { + Message string `json:"message"` + } + type response struct { + Code int `json:"code"` + } + c.Subscribe("slow", func(_, reply string, req *request) { + got := req.Message + expected := "Hello" + if got != expected { + t.Errorf("Expected to receive request with %q, got %q", got, expected) + } + + // simulates latency into the client so that timeout is hit. + time.Sleep(40 * time.Millisecond) + c.Publish(reply, &response{Code: 200}) + }) + + for i := 0; i < 2; i++ { + req := &request{Message: "Hello"} + resp := &response{} + err := c.RequestWithContext(ctx, "slow", req, resp) + if err != nil { + t.Fatalf("Expected encoded request with context to not fail: %s", err) + } + got := resp.Code + expected := 200 + if got != expected { + t.Errorf("Expected to receive %v, got: %v", expected, got) + } + } + + // A third request with latency would make the context + // reach the deadline. 
+ req := &request{Message: "Hello"} + resp := &response{} + err = c.RequestWithContext(ctx, "slow", req, resp) + if err == nil { + t.Fatal("Expected request with context to reach deadline") + } + + // Reported error is "context deadline exceeded" from Context package, + // which implements net.Error Timeout interface. + type timeoutError interface { + Timeout() bool + } + timeoutErr, ok := err.(timeoutError) + if !ok || !timeoutErr.Timeout() { + t.Errorf("Expected to have a timeout error") + } + expected := `context deadline exceeded` + if !strings.Contains(err.Error(), expected) { + t.Errorf("Expected %q error, got: %q", expected, err.Error()) + } +} + +func TestEncodedContextInvalid(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + c, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) + if err != nil { + t.Fatalf("Unable to create encoded connection: %v", err) + } + defer c.Close() + + type request struct { + Message string `json:"message"` + } + type response struct { + Code int `json:"code"` + } + req := &request{Message: "Hello"} + resp := &response{} + //lint:ignore SA1012 testing that passing nil fails + err = c.RequestWithContext(nil, "slow", req, resp) + if err == nil { + t.Fatal("Expected request to fail with error") + } + if err != nats.ErrInvalidContext { + t.Errorf("Expected request to fail with invalid context: %s", err) + } +} diff --git a/test/gob_test.go b/test/gob_test.go index c772e1074..f326f8c9f 100644 --- a/test/gob_test.go +++ b/test/gob_test.go @@ -20,6 +20,8 @@ import ( "github.com/nats-io/nats.go" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + func NewGobEncodedConn(tl TestLogger) *nats.EncodedConn { ec, err := nats.NewEncodedConn(NewConnection(tl, TEST_PORT), nats.GOB_ENCODER) if err != nil { diff --git a/test/helper_test.go b/test/helper_test.go index 9c04a40f9..36d47e000 100644 --- a/test/helper_test.go +++ b/test/helper_test.go @@ -54,6 +54,18 @@ func WaitTime(ch chan bool, timeout time.Duration) error { return errors.New("timeout") } +func WaitOnChannel[T comparable](t *testing.T, ch <-chan T, expected T) { + t.Helper() + select { + case s := <-ch: + if s != expected { + t.Fatalf("Expected result: %v; got: %v", expected, s) + } + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for result %v", expected) + } +} + func stackFatalf(t tLogger, f string, args ...any) { lines := make([]string, 0, 32) msg := fmt.Sprintf(f, args...) @@ -91,15 +103,6 @@ func NewConnection(t tLogger, port int) *nats.Conn { return nc } -// NewEConn -func NewEConn(t tLogger) *nats.EncodedConn { - ec, err := nats.NewEncodedConn(NewDefaultConnection(t), nats.DEFAULT_ENCODER) - if err != nil { - t.Fatalf("Failed to create an encoded connection: %v\n", err) - } - return ec -} - //////////////////////////////////////////////////////////////////////////////// // Running nats server in separate Go routines //////////////////////////////////////////////////////////////////////////////// diff --git a/test/js_test.go b/test/js_test.go index 900792a34..db791eb50 100644 --- a/test/js_test.go +++ b/test/js_test.go @@ -1,4 +1,4 @@ -// Copyright 2020-2023 The NATS Authors +// Copyright 2020-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at @@ -1239,6 +1239,64 @@ func TestPullSubscribeFetchWithHeartbeat(t *testing.T) { } } +func TestPullSubscribeFetchDrain(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + _, err := js.AddStream(&nats.StreamConfig{ + Name: "TEST", + Subjects: []string{"foo"}, + }) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + defer js.PurgeStream("TEST") + sub, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for i := 0; i < 100; i++ { + if _, err := js.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + } + // fill buffer with messages + cinfo, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name) + replySubject := strings.Replace(sub.Subject, "*", "abc", 1) + payload := `{"batch":10,"no_wait":true}` + if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + time.Sleep(100 * time.Millisecond) + + // now drain the subscription, messages should be in the buffer + sub.Drain() + msgs, err := sub.Fetch(100) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for _, msg := range msgs { + msg.Ack() + } + if len(msgs) != 10 { + t.Fatalf("Expected %d messages; got: %d", 10, len(msgs)) + } + + // subsequent fetch should return error, subscription is already drained + _, err = sub.Fetch(10, nats.MaxWait(100*time.Millisecond)) + if !errors.Is(err, nats.ErrSubscriptionClosed) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err) + } +} + func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) { s := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, s) @@ -1299,7 +1357,7 @@ func TestPullSubscribeFetchBatchWithHeartbeat(t *testing.T) { if msgs.Error() != nil { t.Fatalf("Unexpected error: %s", msgs.Error()) } - if elapsed < 290*time.Millisecond { + if elapsed < 250*time.Millisecond { t.Fatalf("Expected timeout after 300ms; got: %v", elapsed) } @@ -1761,6 +1819,55 @@ func TestPullSubscribeFetchBatch(t *testing.T) { t.Errorf("Expected error: %s; got: %s", nats.ErrNoDeadlineContext, err) } }) + + t.Run("close subscription", func(t *testing.T) { + defer js.PurgeStream("TEST") + sub, err := js.PullSubscribe("foo", "") + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + for i := 0; i < 100; i++ { + if _, err := js.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + } + // fill buffer with messages + cinfo, err := sub.ConsumerInfo() + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + nextSubject := fmt.Sprintf("$JS.API.CONSUMER.MSG.NEXT.TEST.%s", cinfo.Name) + replySubject := strings.Replace(sub.Subject, "*", "abc", 1) + payload := `{"batch":10,"no_wait":true}` + if err := nc.PublishRequest(nextSubject, replySubject, []byte(payload)); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + time.Sleep(100 * time.Millisecond) + + // now drain the subscription, messages should be in the buffer + sub.Drain() + res, err := sub.FetchBatch(100) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + msgs := make([]*nats.Msg, 0) + for msg := range res.Messages() { + msgs = append(msgs, msg) + msg.Ack() + } + if res.Error() != nil { + t.Fatalf("Unexpected error: %s", res.Error()) + } + 
if len(msgs) != 10 { + t.Fatalf("Expected %d messages; got: %d", 10, len(msgs)) + } + + // subsequent fetch should return error, subscription is already drained + _, err = sub.FetchBatch(10, nats.MaxWait(100*time.Millisecond)) + if !errors.Is(err, nats.ErrSubscriptionClosed) { + t.Fatalf("Expected error: %s; got: %s", nats.ErrSubscriptionClosed, err) + } + }) } func TestPullSubscribeConsumerDeleted(t *testing.T) { @@ -2497,6 +2604,15 @@ func TestJetStreamManagement(t *testing.T) { } }) + t.Run("with invalid filter subject", func(t *testing.T) { + if _, err = js.AddConsumer("foo", &nats.ConsumerConfig{Name: "tc", FilterSubject: ".foo"}); !errors.Is(err, nats.ErrInvalidFilterSubject) { + t.Fatalf("Expected: %v; got: %v", nats.ErrInvalidFilterSubject, err) + } + if _, err = js.AddConsumer("foo", &nats.ConsumerConfig{Name: "tc", FilterSubject: "foo."}); !errors.Is(err, nats.ErrInvalidFilterSubject) { + t.Fatalf("Expected: %v; got: %v", nats.ErrInvalidFilterSubject, err) + } + }) + t.Run("with invalid consumer name", func(t *testing.T) { if _, err = js.AddConsumer("foo", &nats.ConsumerConfig{Durable: "test.durable"}); err != nats.ErrInvalidConsumerName { t.Fatalf("Expected: %v; got: %v", nats.ErrInvalidConsumerName, err) @@ -7646,7 +7762,7 @@ func testJetStreamFetchOptions(t *testing.T, srvs ...*jsServer) { if err == nil { t.Fatal("Unexpected success") } - if err != nats.ErrBadSubscription { + if !errors.Is(err, nats.ErrBadSubscription) { t.Fatalf("Unexpected error: %v", err) } }) @@ -7969,6 +8085,70 @@ func TestPublishAsyncResetPendingOnReconnect(t *testing.T) { } } +func TestPublishAsyncRetryInErrHandler(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + streamCreated := make(chan struct{}) + errCB := func(js nats.JetStream, m *nats.Msg, e error) { + <-streamCreated + _, err := js.PublishMsgAsync(m) + if err != nil { + t.Fatalf("Unexpected error when republishing: %v", err) + } + } + + js, err := nc.JetStream(nats.PublishAsyncErrHandler(errCB)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + errs := make(chan error, 1) + done := make(chan struct{}, 1) + go func() { + for i := 0; i < 10; i++ { + if _, err := js.PublishAsync("FOO.A", []byte("hello")); err != nil { + errs <- err + return + } + } + done <- struct{}{} + }() + select { + case <-done: + case err := <-errs: + t.Fatalf("Unexpected error during publish: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + _, err = js.AddStream(&nats.StreamConfig{Name: "foo", Subjects: []string{"FOO.*"}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + close(streamCreated) + select { + case <-js.PublishAsyncComplete(): + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + info, err := js.StreamInfo("foo") + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if info.State.Msgs != 10 { + t.Fatalf("Expected 10 messages in the stream; got: %d", info.State.Msgs) + } +} + func TestJetStreamPublishAsyncPerf(t *testing.T) { // Comment out below to run this benchmark. 
t.SkipNow() @@ -8027,6 +8207,215 @@ func TestJetStreamPublishAsyncPerf(t *testing.T) { fmt.Printf("%.0f msgs/sec\n\n", float64(toSend)/tt.Seconds()) } +func TestPublishAsyncRetry(t *testing.T) { + tests := []struct { + name string + pubOpts []nats.PubOpt + ackError error + pubErr error + }{ + { + name: "retry until stream is ready", + pubOpts: []nats.PubOpt{ + nats.RetryAttempts(10), + nats.RetryWait(100 * time.Millisecond), + }, + }, + { + name: "fail after max retries", + pubOpts: []nats.PubOpt{ + nats.RetryAttempts(2), + nats.RetryWait(50 * time.Millisecond), + }, + ackError: nats.ErrNoResponders, + }, + { + name: "no retries", + pubOpts: nil, + ackError: nats.ErrNoResponders, + }, + { + name: "invalid retry attempts", + pubOpts: []nats.PubOpt{ + nats.RetryAttempts(-1), + }, + pubErr: nats.ErrInvalidArg, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // set max pending to 1 so that we can test if retries don't cause stall + js, err := nc.JetStream(nats.PublishAsyncMaxPending(1)) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + test.pubOpts = append(test.pubOpts, nats.StallWait(1*time.Nanosecond)) + ack, err := js.PublishAsync("foo", []byte("hello"), test.pubOpts...) + if !errors.Is(err, test.pubErr) { + t.Fatalf("Expected error: %v; got: %v", test.pubErr, err) + } + if err != nil { + return + } + errs := make(chan error, 1) + go func() { + // create stream with delay so that publish will receive no responders + time.Sleep(300 * time.Millisecond) + if _, err := js.AddStream(&nats.StreamConfig{Name: "TEST", Subjects: []string{"foo"}}); err != nil { + errs <- err + } + }() + select { + case <-ack.Ok(): + case err := <-ack.Err(): + if test.ackError != nil { + if !errors.Is(err, test.ackError) { + t.Fatalf("Expected error: %v; got: %v", test.ackError, err) + } + } else { + t.Fatalf("Unexpected ack error: %v", err) + } + case err := <-errs: + t.Fatalf("Error creating stream: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for ack") + } + }) + } +} +func TestJetStreamCleanupPublisher(t *testing.T) { + + t.Run("cleanup js publisher", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + // Create a stream. 
+ if _, err := js.AddStream(&nats.StreamConfig{Name: "TEST", Subjects: []string{"FOO"}}); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + numSubs := nc.NumSubscriptions() + if _, err := js.PublishAsync("FOO", []byte("hello")); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + select { + case <-js.PublishAsyncComplete(): + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + if numSubs+1 != nc.NumSubscriptions() { + t.Fatalf("Expected an additional subscription after publish, got %d", nc.NumSubscriptions()) + } + + js.CleanupPublisher() + + if numSubs != nc.NumSubscriptions() { + t.Fatalf("Expected subscriptions to be back to original count") + } + }) + + t.Run("cleanup js publisher, cancel pending acks", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, err := nats.Connect(s.ClientURL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + cbErr := make(chan error, 10) + js, err := nc.JetStream(nats.PublishAsyncErrHandler(func(js nats.JetStream, m *nats.Msg, err error) { + cbErr <- err + })) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Create a stream with NoAck so that we can test that we cancel ack futures. + if _, err := js.AddStream(&nats.StreamConfig{Name: "TEST", Subjects: []string{"FOO"}, NoAck: true}); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + numSubs := nc.NumSubscriptions() + + var acks []nats.PubAckFuture + for i := 0; i < 10; i++ { + ack, err := js.PublishAsync("FOO", []byte("hello")) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + acks = append(acks, ack) + } + + asyncComplete := js.PublishAsyncComplete() + select { + case <-asyncComplete: + t.Fatalf("Should not complete, NoAck is set") + case <-time.After(200 * time.Millisecond): + } + + if numSubs+1 != nc.NumSubscriptions() { + t.Fatalf("Expected an additional subscription after publish, got %d", nc.NumSubscriptions()) + } + + js.CleanupPublisher() + + if numSubs != nc.NumSubscriptions() { + t.Fatalf("Expected subscriptions to be back to original count") + } + + // check that PublishAsyncComplete channel is closed + select { + case <-asyncComplete: + case <-time.After(5 * time.Second): + t.Fatalf("Did not receive completion signal") + } + + // check that all ack futures are canceled + for _, ack := range acks { + select { + case err := <-ack.Err(): + if !errors.Is(err, nats.ErrJetStreamPublisherClosed) { + t.Fatalf("Expected JetStreamContextClosed error, got %v", err) + } + case <-ack.Ok(): + t.Fatalf("Expected error on the ack future") + case <-time.After(200 * time.Millisecond): + t.Fatalf("Expected an error on the ack future") + } + } + + // check that async error handler is called for each pending ack + for i := 0; i < 10; i++ { + select { + case err := <-cbErr: + if !errors.Is(err, nats.ErrJetStreamPublisherClosed) { + t.Fatalf("Expected JetStreamContextClosed error, got %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("Expected errors to be passed from the async handler") + } + } + }) + +} + func TestJetStreamPublishExpectZero(t *testing.T) { s := RunBasicJetStreamServer() defer shutdownJSServerAndRemoveStorage(t, s) @@ -9037,7 +9426,7 @@ func TestJetStreamClusterStreamLeaderChangeClientErr(t *testing.T) { return err } if si.Cluster.Leader == "" { - return fmt.Errorf("No leader yet") + return errors.New("No leader yet") } return nil }) diff --git a/test/json_test.go b/test/json_test.go index 
4ef6f42c5..cebbf31ec 100644 --- a/test/json_test.go +++ b/test/json_test.go @@ -22,6 +22,8 @@ import ( "github.com/nats-io/nats.go/encoders/builtin" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + func NewJsonEncodedConn(tl TestLogger) *nats.EncodedConn { ec, err := nats.NewEncodedConn(NewConnection(tl, TEST_PORT), nats.JSON_ENCODER) if err != nil { diff --git a/test/kv_test.go b/test/kv_test.go index 4f5d81edc..94703bc43 100644 --- a/test/kv_test.go +++ b/test/kv_test.go @@ -179,6 +179,22 @@ func TestKeyValueWatch(t *testing.T) { } } } + expectPurgeF := func(t *testing.T, watcher nats.KeyWatcher) func(key string, revision uint64) { + return func(key string, revision uint64) { + t.Helper() + select { + case v := <-watcher.Updates(): + if v.Operation() != nats.KeyValuePurge { + t.Fatalf("Expected a delete operation but got %+v", v) + } + if v.Revision() != revision { + t.Fatalf("Did not get expected revision: %d vs %d", revision, v.Revision()) + } + case <-time.After(time.Second): + t.Fatalf("Did not receive an update like expected") + } + } + } expectInitDoneF := func(t *testing.T, watcher nats.KeyWatcher) func() { return func() { t.Helper() @@ -237,13 +253,27 @@ func TestKeyValueWatch(t *testing.T) { watcher, err = kv.Watch("t.*") expectOk(t, err) - defer watcher.Stop() expectInitDone = expectInitDoneF(t, watcher) expectUpdate = expectUpdateF(t, watcher) expectUpdate("t.name", "ik", 8) expectUpdate("t.age", "44", 10) expectInitDone() + watcher.Stop() + + // test watcher with multiple filters + watcher, err = kv.WatchFiltered([]string{"t.name", "name"}) + expectOk(t, err) + expectInitDone = expectInitDoneF(t, watcher) + expectUpdate = expectUpdateF(t, watcher) + expectPurge := expectPurgeF(t, watcher) + expectUpdate("name", "ik", 3) + expectUpdate("t.name", "ik", 8) + expectInitDone() + err = kv.Purge("name") + expectOk(t, err) + expectPurge("name", 11) + defer watcher.Stop() }) t.Run("watcher with history included", func(t *testing.T) { @@ -362,6 +392,68 @@ func TestKeyValueWatch(t *testing.T) { kv.Put("t.age", []byte("66")) expectUpdate("t.age", "66", 12) }) + + t.Run("invalid watchers", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + kv, err := js.CreateKeyValue(&nats.KeyValueConfig{Bucket: "WATCH"}) + expectOk(t, err) + + // empty keys + _, err = kv.Watch("") + expectErr(t, err, nats.ErrInvalidKey) + + // invalid key + _, err = kv.Watch("a.>.b") + expectErr(t, err, nats.ErrInvalidKey) + + _, err = kv.Watch("foo.") + expectErr(t, err, nats.ErrInvalidKey) + }) + + t.Run("filtered watch with no filters", func(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + kv, err := js.CreateKeyValue(&nats.KeyValueConfig{Bucket: "WATCH"}) + expectOk(t, err) + + // this should behave like WatchAll + watcher, err := kv.WatchFiltered([]string{}) + expectOk(t, err) + defer watcher.Stop() + + expectInitDone := expectInitDoneF(t, watcher) + expectUpdate := expectUpdateF(t, watcher) + expectDelete := expectDeleteF(t, watcher) + // Make sure we already got an initial value marker. 
+ expectInitDone() + + _, err = kv.Create("name", []byte("derek")) + expectOk(t, err) + expectUpdate("name", "derek", 1) + _, err = kv.Put("name", []byte("rip")) + expectOk(t, err) + expectUpdate("name", "rip", 2) + _, err = kv.Put("name", []byte("ik")) + expectOk(t, err) + expectUpdate("name", "ik", 3) + _, err = kv.Put("age", []byte("22")) + expectOk(t, err) + expectUpdate("age", "22", 4) + _, err = kv.Put("age", []byte("33")) + expectOk(t, err) + expectUpdate("age", "33", 5) + expectOk(t, kv.Delete("age")) + expectDelete("age", 6) + }) } func TestKeyValueWatchContext(t *testing.T) { @@ -1010,7 +1102,7 @@ func expectErr(t *testing.T, err error, expected ...error) { return } for _, e := range expected { - if err == e || strings.Contains(e.Error(), err.Error()) { + if errors.Is(err, e) { return } } @@ -1091,7 +1183,7 @@ func TestKeyValueMirrorCrossDomains(t *testing.T) { checkFor(t, 10*time.Second, 10*time.Millisecond, func() error { _, err := kv.Get(key) if err == nil { - return fmt.Errorf("Expected key to be gone") + return errors.New("Expected key to be gone") } if !errors.Is(err, nats.ErrKeyNotFound) { return err @@ -1376,11 +1468,46 @@ func TestKeyValueCreate(t *testing.T) { nc, js := jsClient(t, s) defer nc.Close() - kv, err := js.CreateKeyValue(&nats.KeyValueConfig{Bucket: "TEST"}) + kv, err := js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: "TEST", + Description: "Test KV", + MaxValueSize: 128, + History: 10, + TTL: 1 * time.Hour, + MaxBytes: 1024, + Storage: nats.FileStorage, + }) if err != nil { t.Fatalf("Error creating kv: %v", err) } + expectedStreamConfig := nats.StreamConfig{ + Name: "KV_TEST", + Description: "Test KV", + Subjects: []string{"$KV.TEST.>"}, + MaxMsgs: -1, + MaxBytes: 1024, + Discard: nats.DiscardNew, + MaxAge: 1 * time.Hour, + MaxMsgsPerSubject: 10, + MaxMsgSize: 128, + Storage: nats.FileStorage, + DenyDelete: true, + AllowRollup: true, + AllowDirect: true, + MaxConsumers: -1, + Replicas: 1, + Duplicates: 2 * time.Minute, + } + + si, err := js.StreamInfo("KV_TEST") + if err != nil { + t.Fatalf("Error getting stream info: %v", err) + } + if !reflect.DeepEqual(si.Config, expectedStreamConfig) { + t.Fatalf("Expected stream config to be %+v, got %+v", expectedStreamConfig, si.Config) + } + _, err = kv.Create("key", []byte("1")) if err != nil { t.Fatalf("Error creating key: %v", err) @@ -1452,7 +1579,6 @@ func TestKeyValueSourcing(t *testing.T) { t.Fatalf("Error creating kv: %v", err) } - // Wait half a second to make sure it has time to populate the stream from it's sources i := 0 for { status, err := kvC.Status() @@ -1463,11 +1589,11 @@ func TestKeyValueSourcing(t *testing.T) { break } else { i++ - if i > 3 { + if i > 10 { t.Fatalf("Error sourcing bucket does not contain the expected number of values") } } - time.Sleep(20 * time.Millisecond) + time.Sleep(100 * time.Millisecond) } if _, err := kvC.Get("keyA"); err != nil { diff --git a/test/netchan_test.go b/test/netchan_test.go index b7272a909..f21d32921 100644 --- a/test/netchan_test.go +++ b/test/netchan_test.go @@ -20,6 +20,17 @@ import ( "github.com/nats-io/nats.go" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + +// NewEConn +func NewEConn(t tLogger) *nats.EncodedConn { + ec, err := nats.NewEncodedConn(NewDefaultConnection(t), nats.DEFAULT_ENCODER) + if err != nil { + t.Fatalf("Failed to create an encoded connection: %v\n", err) + } + return ec +} + func TestBadChan(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() diff --git a/test/object_test.go b/test/object_test.go 
index e4a0171ab..f6ecb57a2 100644 --- a/test/object_test.go +++ b/test/object_test.go @@ -1114,3 +1114,71 @@ func TestObjectStoreCompression(t *testing.T) { t.Fatalf("Expected stream to be compressed with S2") } } + +func TestObjectStoreMirror(t *testing.T) { + s := RunBasicJetStreamServer() + defer shutdownJSServerAndRemoveStorage(t, s) + + nc, js := jsClient(t, s) + defer nc.Close() + + bucketName := "test-bucket" + + obs, err := js.CreateObjectStore(&nats.ObjectStoreConfig{Bucket: bucketName, Description: "testing"}) + expectOk(t, err) + + mirrorBucketName := "mirror-test-bucket" + + _, err = js.AddStream(&nats.StreamConfig{ + Name: fmt.Sprintf("OBJ_%s", mirrorBucketName), + Mirror: &nats.StreamSource{ + Name: fmt.Sprintf("OBJ_%s", bucketName), + SubjectTransforms: []nats.SubjectTransformConfig{ + { + Source: fmt.Sprintf("$O.%s.>", bucketName), + Destination: fmt.Sprintf("$O.%s.>", mirrorBucketName), + }, + }, + }, + AllowRollup: true, // meta messages are always rollups + }) + if err != nil { + t.Fatalf("Error creating object store bucket mirror: %v", err) + } + + _, err = obs.PutString("A", "abc") + expectOk(t, err) + + mirrorObs, err := js.ObjectStore(mirrorBucketName) + expectOk(t, err) + + // Make sure we sync. + checkFor(t, 2*time.Second, 15*time.Millisecond, func() error { + mirrorValue, err := mirrorObs.GetString("A") + if err != nil { + return err + } + if mirrorValue != "abc" { + t.Fatalf("Expected mirrored object store value to be the same as original") + } + return nil + }) + + watcher, err := mirrorObs.Watch() + if err != nil { + t.Fatalf("Error creating watcher: %v", err) + } + defer watcher.Stop() + + // expect to get one value and nil + for { + select { + case info := <-watcher.Updates(): + if info == nil { + return + } + case <-time.After(2 * time.Second): + t.Fatalf("Expected to receive an update") + } + } +} diff --git a/test/protobuf_test.go b/test/protobuf_test.go index 08fdf6771..d4bee8c85 100644 --- a/test/protobuf_test.go +++ b/test/protobuf_test.go @@ -25,6 +25,8 @@ import ( pb "github.com/nats-io/nats.go/encoders/protobuf/testdata" ) +//lint:file-ignore SA1019 Ignore deprecation warnings for EncodedConn + func NewProtoEncodedConn(tl TestLogger) *nats.EncodedConn { ec, err := nats.NewEncodedConn(NewConnection(tl, TEST_PORT), protobuf.PROTOBUF_ENCODER) if err != nil { diff --git a/test/reconnect_test.go b/test/reconnect_test.go index b1b398442..9fc1b2311 100644 --- a/test/reconnect_test.go +++ b/test/reconnect_test.go @@ -14,16 +14,20 @@ package test import ( + "errors" "fmt" "net" "net/url" + "strconv" "sync" "sync/atomic" "testing" "time" + "github.com/nats-io/jwt" "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" + "github.com/nats-io/nkeys" ) func startReconnectServer(t *testing.T) *server.Server { @@ -175,19 +179,15 @@ func TestBasicReconnectFunctionality(t *testing.T) { t.Fatalf("Should have connected ok: %v\n", err) } defer nc.Close() - ec, err := nats.NewEncodedConn(nc, nats.DEFAULT_ENCODER) - if err != nil { - t.Fatalf("Failed to create an encoded connection: %v\n", err) - } testString := "bar" - ec.Subscribe("foo", func(s string) { - if s != testString { + nc.Subscribe("foo", func(m *nats.Msg) { + if string(m.Data) != testString { t.Fatal("String doesn't match") } ch <- true }) - ec.Flush() + nc.Flush() ts.Shutdown() // server is stopped here... 
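The hunks above and below migrate these reconnect tests off the deprecated nats.EncodedConn onto plain *nats.Conn subscriptions, so handlers now take a *nats.Msg instead of a decoded value. A minimal, self-contained sketch of the two styles side by side, assuming a local nats-server at the default URL; the "greet" subject, payload, and sleep are illustrative and not taken from the patch:

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/nats-io/nats.go"
)

func main() {
	nc, err := nats.Connect(nats.DefaultURL)
	if err != nil {
		log.Fatal(err)
	}
	defer nc.Close()

	// Deprecated style being removed: EncodedConn decodes the payload into the
	// handler's argument type (a string here, via the default encoder).
	ec, err := nats.NewEncodedConn(nc, nats.DEFAULT_ENCODER)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := ec.Subscribe("greet", func(s string) {
		fmt.Println("encoded handler got:", s)
	}); err != nil {
		log.Fatal(err)
	}

	// Replacement style used by the updated tests: a plain *nats.Conn handler
	// receives *nats.Msg and works with the raw bytes directly.
	if _, err := nc.Subscribe("greet", func(m *nats.Msg) {
		fmt.Println("plain handler got:", string(m.Data))
	}); err != nil {
		log.Fatal(err)
	}

	// Both subscriptions share the underlying connection, so one publish
	// reaches both handlers.
	if err := ec.Publish("greet", "hello"); err != nil {
		log.Fatal(err)
	}
	nc.Flush()

	// Crude wait so the async handlers get a chance to run in this tiny demo.
	time.Sleep(100 * time.Millisecond)
}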
@@ -196,14 +196,14 @@ func TestBasicReconnectFunctionality(t *testing.T) { t.Fatalf("Did not get the disconnected callback on time\n") } - if err := ec.Publish("foo", testString); err != nil { + if err := nc.Publish("foo", []byte("bar")); err != nil { t.Fatalf("Failed to publish message: %v\n", err) } ts = startReconnectServer(t) defer ts.Shutdown() - if err := ec.FlushTimeout(5 * time.Second); err != nil { + if err := nc.FlushTimeout(5 * time.Second); err != nil { t.Fatalf("Error on Flush: %v", err) } @@ -212,7 +212,7 @@ func TestBasicReconnectFunctionality(t *testing.T) { } expectedReconnectCount := uint64(1) - reconnectCount := ec.Conn.Stats().Reconnects + reconnectCount := nc.Stats().Reconnects if reconnectCount != expectedReconnectCount { t.Fatalf("Reconnect count incorrect: %d vs %d\n", @@ -238,23 +238,20 @@ func TestExtendedReconnectFunctionality(t *testing.T) { t.Fatalf("Should have connected ok: %v", err) } defer nc.Close() - ec, err := nats.NewEncodedConn(nc, nats.DEFAULT_ENCODER) - if err != nil { - t.Fatalf("Failed to create an encoded connection: %v\n", err) - } + testString := "bar" received := int32(0) - ec.Subscribe("foo", func(s string) { + nc.Subscribe("foo", func(*nats.Msg) { atomic.AddInt32(&received, 1) }) - sub, _ := ec.Subscribe("foobar", func(s string) { + sub, _ := nc.Subscribe("foobar", func(*nats.Msg) { atomic.AddInt32(&received, 1) }) - ec.Publish("foo", testString) - ec.Flush() + nc.Publish("foo", []byte(testString)) + nc.Flush() ts.Shutdown() // server is stopped here.. @@ -265,18 +262,18 @@ func TestExtendedReconnectFunctionality(t *testing.T) { } // Sub while disconnected - ec.Subscribe("bar", func(s string) { + nc.Subscribe("bar", func(*nats.Msg) { atomic.AddInt32(&received, 1) }) // Unsub foobar while disconnected sub.Unsubscribe() - if err = ec.Publish("foo", testString); err != nil { + if err = nc.Publish("foo", []byte(testString)); err != nil { t.Fatalf("Received an error after disconnect: %v\n", err) } - if err = ec.Publish("bar", testString); err != nil { + if err = nc.Publish("bar", []byte(testString)); err != nil { t.Fatalf("Received an error after disconnect: %v\n", err) } @@ -289,19 +286,19 @@ func TestExtendedReconnectFunctionality(t *testing.T) { t.Fatal("Did not receive a reconnect callback message") } - if err = ec.Publish("foobar", testString); err != nil { + if err = nc.Publish("foobar", []byte(testString)); err != nil { t.Fatalf("Received an error after server restarted: %v\n", err) } - if err = ec.Publish("foo", testString); err != nil { + if err = nc.Publish("foo", []byte(testString)); err != nil { t.Fatalf("Received an error after server restarted: %v\n", err) } ch := make(chan bool) - ec.Subscribe("done", func(b bool) { + nc.Subscribe("done", func(*nats.Msg) { ch <- true }) - ec.Publish("done", true) + nc.Publish("done", nil) if e := Wait(ch); e != nil { t.Fatal("Did not receive our message") @@ -334,11 +331,6 @@ func TestQueueSubsOnReconnect(t *testing.T) { } defer nc.Close() - ec, err := nats.NewEncodedConn(nc, nats.JSON_ENCODER) - if err != nil { - t.Fatalf("Failed to create an encoded connection: %v\n", err) - } - // To hold results. 
results := make(map[int]int) var mu sync.Mutex @@ -361,25 +353,29 @@ func TestQueueSubsOnReconnect(t *testing.T) { subj := "foo.bar" qgroup := "workers" - cb := func(seqno int) { + cb := func(m *nats.Msg) { mu.Lock() defer mu.Unlock() + seqno, err := strconv.Atoi(string(m.Data)) + if err != nil { + t.Fatalf("Received an invalid sequence number: %v\n", err) + } results[seqno] = results[seqno] + 1 } // Create Queue Subscribers - ec.QueueSubscribe(subj, qgroup, cb) - ec.QueueSubscribe(subj, qgroup, cb) + nc.QueueSubscribe(subj, qgroup, cb) + nc.QueueSubscribe(subj, qgroup, cb) - ec.Flush() + nc.Flush() // Helper function to send messages and check results. sendAndCheckMsgs := func(numToSend int) { for i := 0; i < numToSend; i++ { - ec.Publish(subj, i) + nc.Publish(subj, []byte(fmt.Sprint(i))) } // Wait for processing. - ec.Flush() + nc.Flush() time.Sleep(50 * time.Millisecond) // Check Results @@ -826,3 +822,273 @@ func TestReconnectBufSizeDisable(t *testing.T) { t.Errorf("Unexpected buffered bytes: %v", got) } } + +func TestAuthExpiredReconnect(t *testing.T) { + ts := runTrustServer() + defer ts.Shutdown() + + _, err := nats.Connect(ts.ClientURL()) + if err == nil { + t.Fatalf("Expecting an error on connect") + } + ukp, err := nkeys.FromSeed(uSeed) + if err != nil { + t.Fatalf("Error creating user key pair: %v", err) + } + upub, err := ukp.PublicKey() + if err != nil { + t.Fatalf("Error getting user public key: %v", err) + } + akp, err := nkeys.FromSeed(aSeed) + if err != nil { + t.Fatalf("Error creating account key pair: %v", err) + } + + jwtCB := func() (string, error) { + claims := jwt.NewUserClaims("test") + claims.Expires = time.Now().Add(time.Second).Unix() + claims.Subject = upub + jwt, err := claims.Encode(akp) + if err != nil { + return "", err + } + return jwt, nil + } + sigCB := func(nonce []byte) ([]byte, error) { + kp, _ := nkeys.FromSeed(uSeed) + sig, _ := kp.Sign(nonce) + return sig, nil + } + + errCh := make(chan error, 1) + nc, err := nats.Connect(ts.ClientURL(), nats.UserJWT(jwtCB, sigCB), nats.ReconnectWait(100*time.Millisecond), + nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + errCh <- err + })) + if err != nil { + t.Fatalf("Expected to connect, got %v", err) + } + stasusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + select { + case err := <-errCh: + if !errors.Is(err, nats.ErrAuthExpired) { + t.Fatalf("Expected auth expired error, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Did not get the auth expired error") + } + WaitOnChannel(t, stasusCh, nats.RECONNECTING) + WaitOnChannel(t, stasusCh, nats.CONNECTED) + nc.Close() +} + +func TestForceReconnect(t *testing.T) { + s := RunDefaultServer() + + nc, err := nats.Connect(s.ClientURL(), nats.ReconnectWait(10*time.Second)) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // Force a reconnect + err = nc.ForceReconnect() + if err != nil { + 
t.Fatalf("Unexpected error on reconnect: %v", err) + } + + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // shutdown server and then force a reconnect + s.Shutdown() + WaitOnChannel(t, newStatus, nats.RECONNECTING) + _, err = sub.NextMsg(100 * time.Millisecond) + if err == nil { + t.Fatal("Expected error getting message") + } + + // restart server + s = RunDefaultServer() + defer s.Shutdown() + + if err := nc.ForceReconnect(); err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + // wait for the reconnect + // because the connection has long ReconnectWait, + // if force reconnect does not work, the test will timeout + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + nc.Close() +} + +func TestForceReconnectDisallowReconnect(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc, err := nats.Connect(s.ClientURL(), nats.NoReconnect()) + if err != nil { + t.Fatalf("Unexpected error on connect: %v", err) + } + defer nc.Close() + + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + + // Force a reconnect + err = nc.ForceReconnect() + if err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) + + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + _, err = sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting message: %v", err) + } + +} + +func TestAuthExpiredForceReconnect(t *testing.T) { + ts := runTrustServer() + defer ts.Shutdown() + + _, err := nats.Connect(ts.ClientURL()) + if err == nil { + t.Fatalf("Expecting an error on connect") + } + ukp, err := nkeys.FromSeed(uSeed) + if err != nil { + t.Fatalf("Error creating user key pair: %v", err) + } + upub, err := ukp.PublicKey() + if err != nil { + t.Fatalf("Error getting user public key: %v", err) + } + akp, err := nkeys.FromSeed(aSeed) + if err != nil { + t.Fatalf("Error creating account key pair: %v", err) + } + + jwtCB := func() (string, error) { + claims := jwt.NewUserClaims("test") + claims.Expires = time.Now().Add(time.Second).Unix() + claims.Subject = upub + jwt, err := claims.Encode(akp) + if err != nil { + return "", err + } + return jwt, nil + } + sigCB := func(nonce []byte) ([]byte, error) { + kp, _ := nkeys.FromSeed(uSeed) + sig, _ := kp.Sign(nonce) + return sig, nil + } + + errCh := make(chan error, 1) + nc, err := nats.Connect(ts.ClientURL(), nats.UserJWT(jwtCB, sigCB), nats.ReconnectWait(10*time.Second), + 
nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + errCh <- err + })) + if err != nil { + t.Fatalf("Expected to connect, got %v", err) + } + defer nc.Close() + statusCh := nc.StatusChanged(nats.RECONNECTING, nats.CONNECTED) + defer close(statusCh) + newStatus := make(chan nats.Status, 10) + // non-blocking channel, so we need to be constantly listening + go func() { + for { + s, ok := <-statusCh + if !ok { + return + } + newStatus <- s + } + }() + time.Sleep(100 * time.Millisecond) + select { + case err := <-errCh: + if !errors.Is(err, nats.ErrAuthExpired) { + t.Fatalf("Expected auth expired error, got %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Did not get the auth expired error") + } + if err := nc.ForceReconnect(); err != nil { + t.Fatalf("Unexpected error on reconnect: %v", err) + } + WaitOnChannel(t, newStatus, nats.RECONNECTING) + WaitOnChannel(t, newStatus, nats.CONNECTED) +} diff --git a/test/sub_test.go b/test/sub_test.go index c359639df..559efc50c 100644 --- a/test/sub_test.go +++ b/test/sub_test.go @@ -1,4 +1,4 @@ -// Copyright 2013-2023 The NATS Authors +// Copyright 2013-2024 The NATS Authors // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -14,7 +14,9 @@ package test import ( + "errors" "fmt" + "os" "sync" "sync/atomic" "testing" @@ -568,7 +570,7 @@ func TestAsyncErrHandler(t *testing.T) { if s != sub { t.Fatal("Did not receive proper subscription") } - if e != nats.ErrSlowConsumer { + if !errors.Is(e, nats.ErrSlowConsumer) { t.Fatalf("Did not receive proper error: %v vs %v", e, nats.ErrSlowConsumer) } // Suppress additional calls @@ -636,7 +638,7 @@ func TestAsyncErrHandlerChanSubscription(t *testing.T) { nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) { atomic.AddInt64(&aeCalled, 1) - if e != nats.ErrSlowConsumer { + if !errors.Is(e, nats.ErrSlowConsumer) { t.Fatalf("Did not receive proper error: %v vs %v", e, nats.ErrSlowConsumer) } @@ -1119,7 +1121,7 @@ func TestAsyncSubscriptionPending(t *testing.T) { } // Test old way - q, _ := sub.QueuedMsgs() + q, _, _ := sub.Pending() if q != total && q != total-1 { t.Fatalf("Expected %d or %d, got %d", total, total-1, q) } @@ -1270,7 +1272,7 @@ func TestSyncSubscriptionPending(t *testing.T) { nc.Flush() // Test old way - q, _ := sub.QueuedMsgs() + q, _, _ := sub.Pending() if q != total && q != total-1 { t.Fatalf("Expected %d or %d, got %d", total, total-1, q) } @@ -1318,10 +1320,10 @@ func TestSetPendingLimits(t *testing.T) { // Check for invalid values invalid := func() error { if err := sub.SetPendingLimits(0, 1); err == nil { - return fmt.Errorf("Setting limit with 0 should fail") + return errors.New("Setting limit with 0 should fail") } if err := sub.SetPendingLimits(1, 0); err == nil { - return fmt.Errorf("Setting limit with 0 should fail") + return errors.New("Setting limit with 0 should fail") } return nil } @@ -1614,3 +1616,157 @@ func TestSubscribe_ClosedHandler(t *testing.T) { t.Fatal("Did not receive closed callback") } } + +func TestSubscriptionEvents(t *testing.T) { + t.Run("default events", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + // disable slow consumer prints + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) {}) + defer nc.Close() + + blockChan := make(chan struct{}) + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) { + // block in 
subscription callback + // to force slow consumer + <-blockChan + }) + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + sub.SetPendingLimits(10, 1024) + status := sub.StatusChanged() + + // initial status + WaitOnChannel(t, status, nats.SubscriptionActive) + + for i := 0; i < 11; i++ { + nc.Publish("foo", []byte("Hello")) + } + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) + close(blockChan) + + sub.Drain() + + WaitOnChannel(t, status, nats.SubscriptionDraining) + + WaitOnChannel(t, status, nats.SubscriptionClosed) + }) + + t.Run("slow consumer event only", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + defer nc.Close() + + blockChan := make(chan struct{}) + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) { + // block in subscription callback + // to force slow consumer + <-blockChan + }) + // disable slow consumer prints + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) {}) + defer sub.Unsubscribe() + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + sub.SetPendingLimits(10, 1024) + status := sub.StatusChanged(nats.SubscriptionSlowConsumer) + + for i := 0; i < 20; i++ { + nc.Publish("foo", []byte("Hello")) + } + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) + close(blockChan) + + // now try with sync sub + sub, err = nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer sub.Unsubscribe() + sub.SetPendingLimits(10, 1024) + status = sub.StatusChanged(nats.SubscriptionSlowConsumer) + + for i := 0; i < 20; i++ { + nc.Publish("foo", []byte("Hello")) + } + WaitOnChannel(t, status, nats.SubscriptionSlowConsumer) + }) + + t.Run("do not block channel if it's not read", func(t *testing.T) { + s := RunDefaultServer() + defer s.Shutdown() + + nc := NewDefaultConnection(t) + // disable slow consumer prints + nc.SetErrorHandler(func(c *nats.Conn, s *nats.Subscription, e error) {}) + defer nc.Close() + + blockChan := make(chan struct{}) + sub, err := nc.Subscribe("foo", func(_ *nats.Msg) { + // block in subscription callback + // to force slow consumer + <-blockChan + }) + defer sub.Unsubscribe() + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + sub.SetPendingLimits(10, 1024) + status := sub.StatusChanged() + WaitOnChannel(t, status, nats.SubscriptionActive) + + // chan length is 10, so make sure we switch state more times + for i := 0; i < 20; i++ { + // subscription will enter slow consumer state + for i := 0; i < 11; i++ { + nc.Publish("foo", []byte("Hello")) + } + + // messages flow normally, status flips to active + for i := 0; i < 10; i++ { + nc.Publish("foo", []byte("Hello")) + blockChan <- struct{}{} + } + } + // do not read from subscription + close(blockChan) + }) +} + +func TestMaxSubscriptionsExceeded(t *testing.T) { + conf := createConfFile(t, []byte(` + listen: 127.0.0.1:-1 + max_subscriptions: 5 + `)) + defer os.Remove(conf) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + ch := make(chan error) + nc, err := nats.Connect(s.ClientURL(), nats.ErrorHandler(func(c *nats.Conn, s *nats.Subscription, err error) { + ch <- err + })) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + for i := 0; i < 6; i++ { + s, err := nc.Subscribe("foo", func(_ *nats.Msg) {}) + if err != nil { + t.Fatalf("Error subscribing: %v", err) + } + defer s.Unsubscribe() + } + + WaitOnChannel(t, ch, nats.ErrMaxSubscriptionsExceeded) + + // wait for the server to process the SUBs + time.Sleep(100 * 
time.Millisecond) +} diff --git a/test/ws_test.go b/test/ws_test.go index 15707d8b3..ee13b336f 100644 --- a/test/ws_test.go +++ b/test/ws_test.go @@ -17,6 +17,7 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "errors" "fmt" "math/rand" "net" @@ -524,7 +525,7 @@ func TestWSStress(t *testing.T) { return } if !bytes.Equal(m.Data[4:4+ps], mainPayload[:ps]) { - pushErr(fmt.Errorf("invalid content")) + pushErr(errors.New("invalid content")) return } if atomic.AddInt64(&count, 1) == totalRecv { diff --git a/ws.go b/ws.go index 2c2d421a8..fbc568845 100644 --- a/ws.go +++ b/ws.go @@ -237,8 +237,8 @@ func (r *websocketReader) Read(p []byte) (int, error) { case wsPingMessage, wsPongMessage, wsCloseMessage: if rem > wsMaxControlPayloadSize { return 0, fmt.Errorf( - fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", - wsMaxControlPayloadSize)) + "control frame length bigger than maximum allowed of %v bytes", + wsMaxControlPayloadSize) } if compressed { return 0, errors.New("control frame should not be compressed") @@ -622,7 +622,7 @@ func (nc *Conn) wsInitHandshake(u *url.URL) error { !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) { - err = fmt.Errorf("invalid websocket connection") + err = errors.New("invalid websocket connection") } // Check compression extension... if err == nil && compress { @@ -634,7 +634,7 @@ func (nc *Conn) wsInitHandshake(u *url.URL) error { if !srvCompress { compress = false } else if !noCtxTakeover { - err = fmt.Errorf("compression negotiation error") + err = errors.New("compression negotiation error") } } if resp != nil {