|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459 |
- // Copyright 2017 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package pubsub
-
- // TODO(jba): test keepalive
- // TODO(jba): test that expired messages are not kept alive
- // TODO(jba): test that when all messages expire, Stop returns.
-
- import (
- "context"
- "io"
- "strconv"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "cloud.google.com/go/internal/testutil"
- tspb "github.com/golang/protobuf/ptypes/timestamp"
- "github.com/google/go-cmp/cmp"
- "github.com/google/go-cmp/cmp/cmpopts"
- "google.golang.org/api/option"
- pb "google.golang.org/genproto/googleapis/pubsub/v1"
- "google.golang.org/grpc"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
- )
-
- var (
- timestamp = &tspb.Timestamp{}
- testMessages = []*pb.ReceivedMessage{
- {AckId: "0", Message: &pb.PubsubMessage{Data: []byte{1}, PublishTime: timestamp}},
- {AckId: "1", Message: &pb.PubsubMessage{Data: []byte{2}, PublishTime: timestamp}},
- {AckId: "2", Message: &pb.PubsubMessage{Data: []byte{3}, PublishTime: timestamp}},
- }
- )
-
- func TestStreamingPullBasic(t *testing.T) {
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages)
- testStreamingPullIteration(t, client, server, testMessages)
- }
-
- func TestStreamingPullMultipleFetches(t *testing.T) {
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages[:1])
- server.addStreamingPullMessages(testMessages[1:])
- testStreamingPullIteration(t, client, server, testMessages)
- }
-
- func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) {
- sub := client.Subscription("S")
- gotMsgs, err := pullN(context.Background(), sub, len(msgs), func(_ context.Context, m *Message) {
- id, err := strconv.Atoi(m.ackID)
- if err != nil {
- panic(err)
- }
- // ack evens, nack odds
- if id%2 == 0 {
- m.Ack()
- } else {
- m.Nack()
- }
- })
- if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
- t.Fatalf("Pull: %v", err)
- }
- gotMap := map[string]*Message{}
- for _, m := range gotMsgs {
- gotMap[m.ackID] = m
- }
- for i, msg := range msgs {
- want, err := toMessage(msg)
- if err != nil {
- t.Fatal(err)
- }
- want.calledDone = true
- got := gotMap[want.ackID]
- if got == nil {
- t.Errorf("%d: no message for ackID %q", i, want.ackID)
- continue
- }
- if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) {
- t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
- }
- }
- server.wait()
- for i := 0; i < len(msgs); i++ {
- id := msgs[i].AckId
- if i%2 == 0 {
- if !server.Acked[id] {
- t.Errorf("msg %q should have been acked but wasn't", id)
- }
- } else {
- if dl, ok := server.Deadlines[id]; !ok || dl != 0 {
- t.Errorf("msg %q should have been nacked but wasn't", id)
- }
- }
- }
- }
-
- func TestStreamingPullError(t *testing.T) {
- // If an RPC to the service returns a non-retryable error, Pull should
- // return after all callbacks return, without waiting for messages to be
- // acked.
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages[:1])
- server.addStreamingPullError(status.Errorf(codes.Unknown, ""))
- sub := client.Subscription("S")
- // Use only one goroutine, since the fake server is configured to
- // return only one error.
- sub.ReceiveSettings.NumGoroutines = 1
- callbackDone := make(chan struct{})
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
- defer close(callbackDone)
- <-ctx.Done()
- })
- select {
- case <-callbackDone:
- default:
- t.Fatal("Receive returned but callback was not done")
- }
- if want := codes.Unknown; grpc.Code(err) != want {
- t.Fatalf("got <%v>, want code %v", err, want)
- }
- }
-
- func TestStreamingPullCancel(t *testing.T) {
- // If Receive's context is canceled, it should return after all callbacks
- // return and all messages have been acked.
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages)
- sub := client.Subscription("S")
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- var n int32
- err := sub.Receive(ctx, func(ctx2 context.Context, m *Message) {
- atomic.AddInt32(&n, 1)
- defer atomic.AddInt32(&n, -1)
- cancel()
- m.Ack()
- })
- if got := atomic.LoadInt32(&n); got != 0 {
- t.Errorf("Receive returned with %d callbacks still running", got)
- }
- if err != nil {
- t.Fatalf("Receive got <%v>, want nil", err)
- }
- }
-
- func TestStreamingPullRetry(t *testing.T) {
- // Check that we retry on io.EOF or Unavailable.
- t.Parallel()
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages[:1])
- server.addStreamingPullError(io.EOF)
- server.addStreamingPullError(io.EOF)
- server.addStreamingPullMessages(testMessages[1:2])
- server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
- server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
- server.addStreamingPullMessages(testMessages[2:])
-
- testStreamingPullIteration(t, client, server, testMessages)
- }
-
- func TestStreamingPullOneActive(t *testing.T) {
- // Only one call to Pull can be active at a time.
- client, srv := newMock(t)
- defer client.Close()
- defer srv.srv.Close()
- srv.addStreamingPullMessages(testMessages[:1])
- sub := client.Subscription("S")
- ctx, cancel := context.WithCancel(context.Background())
- err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
- m.Ack()
- err := sub.Receive(ctx, func(context.Context, *Message) {})
- if err != errReceiveInProgress {
- t.Errorf("got <%v>, want <%v>", err, errReceiveInProgress)
- }
- cancel()
- })
- if err != nil {
- t.Fatalf("got <%v>, want nil", err)
- }
- }
-
- func TestStreamingPullConcurrent(t *testing.T) {
- newMsg := func(i int) *pb.ReceivedMessage {
- return &pb.ReceivedMessage{
- AckId: strconv.Itoa(i),
- Message: &pb.PubsubMessage{Data: []byte{byte(i)}, PublishTime: timestamp},
- }
- }
-
- // Multiple goroutines should be able to read from the same iterator.
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- // Add a lot of messages, a few at a time, to make sure both threads get a chance.
- nMessages := 100
- for i := 0; i < nMessages; i += 2 {
- server.addStreamingPullMessages([]*pb.ReceivedMessage{newMsg(i), newMsg(i + 1)})
- }
- sub := client.Subscription("S")
- ctx, _ := context.WithTimeout(context.Background(), time.Second)
- gotMsgs, err := pullN(ctx, sub, nMessages, func(ctx context.Context, m *Message) {
- m.Ack()
- })
- if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
- t.Fatalf("Pull: %v", err)
- }
- seen := map[string]bool{}
- for _, gm := range gotMsgs {
- if seen[gm.ackID] {
- t.Fatalf("duplicate ID %q", gm.ackID)
- }
- seen[gm.ackID] = true
- }
- if len(seen) != nMessages {
- t.Fatalf("got %d messages, want %d", len(seen), nMessages)
- }
- }
-
- func TestStreamingPullFlowControl(t *testing.T) {
- // Callback invocations should not occur if flow control limits are exceeded.
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages)
- sub := client.Subscription("S")
- sub.ReceiveSettings.MaxOutstandingMessages = 2
- ctx, cancel := context.WithCancel(context.Background())
- activec := make(chan int)
- waitc := make(chan int)
- errc := make(chan error)
- go func() {
- errc <- sub.Receive(ctx, func(_ context.Context, m *Message) {
- activec <- 1
- <-waitc
- m.Ack()
- })
- }()
- // Here, two callbacks are active. Receive should be blocked in the flow
- // control acquire method on the third message.
- <-activec
- <-activec
- select {
- case <-activec:
- t.Fatal("third callback in progress")
- case <-time.After(100 * time.Millisecond):
- }
- cancel()
- // Receive still has not returned, because both callbacks are still blocked on waitc.
- select {
- case err := <-errc:
- t.Fatalf("Receive returned early with error %v", err)
- case <-time.After(100 * time.Millisecond):
- }
- // Let both callbacks proceed.
- waitc <- 1
- waitc <- 1
- // The third callback will never run, because acquire returned a non-nil
- // error, causing Receive to return. So now Receive should end.
- if err := <-errc; err != nil {
- t.Fatalf("got %v from Receive, want nil", err)
- }
- }
-
- func TestStreamingPull_ClosedClient(t *testing.T) {
- ctx := context.Background()
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages)
- sub := client.Subscription("S")
- sub.ReceiveSettings.MaxOutstandingBytes = 1
- recvFinished := make(chan error)
-
- go func() {
- err := sub.Receive(ctx, func(_ context.Context, m *Message) {
- m.Ack()
- })
- recvFinished <- err
- }()
-
- // wait for receives to happen
- time.Sleep(100 * time.Millisecond)
-
- err := client.Close()
- if err != nil {
- t.Fatal(err)
- }
-
- // wait for things to close
- time.Sleep(100 * time.Millisecond)
-
- select {
- case recvErr := <-recvFinished:
- s, ok := status.FromError(recvErr)
- if !ok {
- t.Fatalf("Expected a gRPC failure, got %v", err)
- }
- if s.Code() != codes.Canceled {
- t.Fatalf("Expected canceled, got %v", s.Code())
- }
- case <-time.After(time.Second):
- t.Fatal("Receive should have exited immediately after the client was closed, but it did not")
- }
- }
-
- func TestStreamingPull_RetriesAfterUnavailable(t *testing.T) {
- ctx := context.Background()
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
-
- unavail := status.Error(codes.Unavailable, "There is no connection available")
- server.addStreamingPullMessages(testMessages)
- server.addStreamingPullError(unavail)
- server.addAckResponse(unavail)
- server.addModAckResponse(unavail)
- server.addStreamingPullMessages(testMessages)
- server.addStreamingPullError(unavail)
-
- sub := client.Subscription("S")
- sub.ReceiveSettings.MaxOutstandingBytes = 1
- recvErr := make(chan error, 1)
- recvdMsgs := make(chan *Message, len(testMessages)*2)
-
- go func() {
- recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
- m.Ack()
- recvdMsgs <- m
- })
- }()
-
- // wait for receive to happen
- var n int
- for {
- select {
- case <-time.After(10 * time.Second):
- t.Fatalf("timed out waiting for all message to arrive. got %d messages total", n)
- case err := <-recvErr:
- t.Fatal(err)
- case <-recvdMsgs:
- n++
- if n == len(testMessages)*2 {
- return
- }
- }
- }
- }
-
- func TestStreamingPull_ReconnectsAfterServerDies(t *testing.T) {
- ctx := context.Background()
- client, server := newMock(t)
- defer server.srv.Close()
- defer client.Close()
- server.addStreamingPullMessages(testMessages)
- sub := client.Subscription("S")
- sub.ReceiveSettings.MaxOutstandingBytes = 1
- recvErr := make(chan error, 1)
- recvdMsgs := make(chan interface{}, len(testMessages)*2)
-
- go func() {
- recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
- m.Ack()
- recvdMsgs <- struct{}{}
- })
- }()
-
- // wait for receive to happen
- var n int
- for {
- select {
- case <-time.After(5 * time.Second):
- t.Fatalf("timed out waiting for all message to arrive. got %d messages total", n)
- case err := <-recvErr:
- t.Fatal(err)
- case <-recvdMsgs:
- n++
- if n == len(testMessages) {
- // Restart the server
- server.srv.Close()
- server2, err := newMockServer(server.srv.Port)
- if err != nil {
- t.Fatal(err)
- }
- defer server2.srv.Close()
- server2.addStreamingPullMessages(testMessages)
- }
-
- if n == len(testMessages)*2 {
- return
- }
- }
- }
- }
-
- func newMock(t *testing.T) (*Client, *mockServer) {
- srv, err := newMockServer(0)
- if err != nil {
- t.Fatal(err)
- }
- conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
- if err != nil {
- t.Fatal(err)
- }
- client, err := NewClient(context.Background(), "P", option.WithGRPCConn(conn))
- if err != nil {
- t.Fatal(err)
- }
- return client, srv
- }
-
- // pullN calls sub.Receive until at least n messages are received.
- func pullN(ctx context.Context, sub *Subscription, n int, f func(context.Context, *Message)) ([]*Message, error) {
- var (
- mu sync.Mutex
- msgs []*Message
- )
- cctx, cancel := context.WithCancel(ctx)
- err := sub.Receive(cctx, func(ctx context.Context, m *Message) {
- mu.Lock()
- msgs = append(msgs, m)
- nSeen := len(msgs)
- mu.Unlock()
- f(ctx, m)
- if nSeen >= n {
- cancel()
- }
- })
- if err != nil {
- return msgs, err
- }
- return msgs, nil
- }
|