// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pubsub

// TODO(jba): test keepalive
// TODO(jba): test that expired messages are not kept alive
// TODO(jba): test that when all messages expire, Stop returns.

import (
	"context"
	"io"
	"strconv"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"cloud.google.com/go/internal/testutil"
	tspb "github.com/golang/protobuf/ptypes/timestamp"
	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"google.golang.org/api/option"
	pb "google.golang.org/genproto/googleapis/pubsub/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

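// timestamp and testMessages are the fixtures used throughout these tests: a
// zero-valued publish time and the three messages the fake server delivers.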
var (
	timestamp    = &tspb.Timestamp{}
	testMessages = []*pb.ReceivedMessage{
		{AckId: "0", Message: &pb.PubsubMessage{Data: []byte{1}, PublishTime: timestamp}},
		{AckId: "1", Message: &pb.PubsubMessage{Data: []byte{2}, PublishTime: timestamp}},
		{AckId: "2", Message: &pb.PubsubMessage{Data: []byte{3}, PublishTime: timestamp}},
	}
)

func TestStreamingPullBasic(t *testing.T) {
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	testStreamingPullIteration(t, client, server, testMessages)
}

func TestStreamingPullMultipleFetches(t *testing.T) {
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages[:1])
	server.addStreamingPullMessages(testMessages[1:])
	testStreamingPullIteration(t, client, server, testMessages)
}

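// testStreamingPullIteration pulls len(msgs) messages from the fake server,
// acking messages with even ack IDs and nacking the rest, then checks both the
// received message contents and the acks/nacks recorded by the server.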
func testStreamingPullIteration(t *testing.T, client *Client, server *mockServer, msgs []*pb.ReceivedMessage) {
	sub := client.Subscription("S")
	gotMsgs, err := pullN(context.Background(), sub, len(msgs), func(_ context.Context, m *Message) {
		id, err := strconv.Atoi(m.ackID)
		if err != nil {
			panic(err)
		}
		// ack evens, nack odds
		if id%2 == 0 {
			m.Ack()
		} else {
			m.Nack()
		}
	})
	if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
		t.Fatalf("Pull: %v", err)
	}
	gotMap := map[string]*Message{}
	for _, m := range gotMsgs {
		gotMap[m.ackID] = m
	}
	for i, msg := range msgs {
		want, err := toMessage(msg)
		if err != nil {
			t.Fatal(err)
		}
		want.calledDone = true
		got := gotMap[want.ackID]
		if got == nil {
			t.Errorf("%d: no message for ackID %q", i, want.ackID)
			continue
		}
		if !testutil.Equal(got, want, cmp.AllowUnexported(Message{}), cmpopts.IgnoreTypes(time.Time{}, func(string, bool, time.Time) {})) {
			t.Errorf("%d: got\n%#v\nwant\n%#v", i, got, want)
		}
	}
	server.wait()
	for i := 0; i < len(msgs); i++ {
		id := msgs[i].AckId
		if i%2 == 0 {
			if !server.Acked[id] {
				t.Errorf("msg %q should have been acked but wasn't", id)
			}
		} else {
			if dl, ok := server.Deadlines[id]; !ok || dl != 0 {
				t.Errorf("msg %q should have been nacked but wasn't", id)
			}
		}
	}
}

func TestStreamingPullError(t *testing.T) {
	// If an RPC to the service returns a non-retryable error, Pull should
	// return after all callbacks return, without waiting for messages to be
	// acked.
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages[:1])
	server.addStreamingPullError(status.Errorf(codes.Unknown, ""))
	sub := client.Subscription("S")
	// Use only one goroutine, since the fake server is configured to
	// return only one error.
	sub.ReceiveSettings.NumGoroutines = 1
	callbackDone := make(chan struct{})
	ctx, _ := context.WithTimeout(context.Background(), time.Second)
	err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
		defer close(callbackDone)
		<-ctx.Done()
	})
	select {
	case <-callbackDone:
	default:
		t.Fatal("Receive returned but callback was not done")
	}
	if want := codes.Unknown; grpc.Code(err) != want {
		t.Fatalf("got <%v>, want code %v", err, want)
	}
}

func TestStreamingPullCancel(t *testing.T) {
	// If Receive's context is canceled, it should return after all callbacks
	// return and all messages have been acked.
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	sub := client.Subscription("S")
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	var n int32
	err := sub.Receive(ctx, func(ctx2 context.Context, m *Message) {
		atomic.AddInt32(&n, 1)
		defer atomic.AddInt32(&n, -1)
		cancel()
		m.Ack()
	})
	if got := atomic.LoadInt32(&n); got != 0 {
		t.Errorf("Receive returned with %d callbacks still running", got)
	}
	if err != nil {
		t.Fatalf("Receive got <%v>, want nil", err)
	}
}

func TestStreamingPullRetry(t *testing.T) {
	// Check that we retry on io.EOF or Unavailable.
	t.Parallel()
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages[:1])
	server.addStreamingPullError(io.EOF)
	server.addStreamingPullError(io.EOF)
	server.addStreamingPullMessages(testMessages[1:2])
	server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
	server.addStreamingPullError(status.Errorf(codes.Unavailable, ""))
	server.addStreamingPullMessages(testMessages[2:])
	testStreamingPullIteration(t, client, server, testMessages)
}

func TestStreamingPullOneActive(t *testing.T) {
	// Only one call to Pull can be active at a time.
	client, srv := newMock(t)
	defer client.Close()
	defer srv.srv.Close()
	srv.addStreamingPullMessages(testMessages[:1])
	sub := client.Subscription("S")
	ctx, cancel := context.WithCancel(context.Background())
	err := sub.Receive(ctx, func(ctx context.Context, m *Message) {
		m.Ack()
		err := sub.Receive(ctx, func(context.Context, *Message) {})
		if err != errReceiveInProgress {
			t.Errorf("got <%v>, want <%v>", err, errReceiveInProgress)
		}
		cancel()
	})
	if err != nil {
		t.Fatalf("got <%v>, want nil", err)
	}
}

func TestStreamingPullConcurrent(t *testing.T) {
	newMsg := func(i int) *pb.ReceivedMessage {
		return &pb.ReceivedMessage{
			AckId:   strconv.Itoa(i),
			Message: &pb.PubsubMessage{Data: []byte{byte(i)}, PublishTime: timestamp},
		}
	}

	// Multiple goroutines should be able to read from the same iterator.
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	// Add a lot of messages, a few at a time, to make sure both threads get a chance.
	nMessages := 100
	for i := 0; i < nMessages; i += 2 {
		server.addStreamingPullMessages([]*pb.ReceivedMessage{newMsg(i), newMsg(i + 1)})
	}
	sub := client.Subscription("S")
	ctx, _ := context.WithTimeout(context.Background(), time.Second)
	gotMsgs, err := pullN(ctx, sub, nMessages, func(ctx context.Context, m *Message) {
		m.Ack()
	})
	if c := status.Convert(err); err != nil && c.Code() != codes.Canceled {
		t.Fatalf("Pull: %v", err)
	}
	seen := map[string]bool{}
	for _, gm := range gotMsgs {
		if seen[gm.ackID] {
			t.Fatalf("duplicate ID %q", gm.ackID)
		}
		seen[gm.ackID] = true
	}
	if len(seen) != nMessages {
		t.Fatalf("got %d messages, want %d", len(seen), nMessages)
	}
}

func TestStreamingPullFlowControl(t *testing.T) {
	// Callback invocations should not occur if flow control limits are exceeded.
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	sub := client.Subscription("S")
	sub.ReceiveSettings.MaxOutstandingMessages = 2
	ctx, cancel := context.WithCancel(context.Background())
	activec := make(chan int)
	waitc := make(chan int)
	errc := make(chan error)
	go func() {
		errc <- sub.Receive(ctx, func(_ context.Context, m *Message) {
			activec <- 1
			<-waitc
			m.Ack()
		})
	}()
	// Here, two callbacks are active. Receive should be blocked in the flow
	// control acquire method on the third message.
	<-activec
	<-activec
	select {
	case <-activec:
		t.Fatal("third callback in progress")
	case <-time.After(100 * time.Millisecond):
	}
	cancel()
	// Receive still has not returned, because both callbacks are still blocked on waitc.
	select {
	case err := <-errc:
		t.Fatalf("Receive returned early with error %v", err)
	case <-time.After(100 * time.Millisecond):
	}
	// Let both callbacks proceed.
	waitc <- 1
	waitc <- 1
	// The third callback will never run, because acquire returned a non-nil
	// error, causing Receive to return. So now Receive should end.
	if err := <-errc; err != nil {
		t.Fatalf("got %v from Receive, want nil", err)
	}
}

func TestStreamingPull_ClosedClient(t *testing.T) {
	ctx := context.Background()
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	sub := client.Subscription("S")
	sub.ReceiveSettings.MaxOutstandingBytes = 1
	recvFinished := make(chan error)
	go func() {
		err := sub.Receive(ctx, func(_ context.Context, m *Message) {
			m.Ack()
		})
		recvFinished <- err
	}()

	// Wait for receives to happen.
	time.Sleep(100 * time.Millisecond)
	err := client.Close()
	if err != nil {
		t.Fatal(err)
	}
	// Wait for things to close.
	time.Sleep(100 * time.Millisecond)

	select {
	case recvErr := <-recvFinished:
		s, ok := status.FromError(recvErr)
		if !ok {
			t.Fatalf("Expected a gRPC failure, got %v", recvErr)
		}
		if s.Code() != codes.Canceled {
			t.Fatalf("Expected canceled, got %v", s.Code())
		}
	case <-time.After(time.Second):
		t.Fatal("Receive should have exited immediately after the client was closed, but it did not")
	}
}

func TestStreamingPull_RetriesAfterUnavailable(t *testing.T) {
	ctx := context.Background()
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	unavail := status.Error(codes.Unavailable, "There is no connection available")
	server.addStreamingPullMessages(testMessages)
	server.addStreamingPullError(unavail)
	server.addAckResponse(unavail)
	server.addModAckResponse(unavail)
	server.addStreamingPullMessages(testMessages)
	server.addStreamingPullError(unavail)
	sub := client.Subscription("S")
	sub.ReceiveSettings.MaxOutstandingBytes = 1
	recvErr := make(chan error, 1)
	recvdMsgs := make(chan *Message, len(testMessages)*2)
	go func() {
		recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
			m.Ack()
			recvdMsgs <- m
		})
	}()

	// Wait for all of the messages to arrive.
	var n int
	for {
		select {
		case <-time.After(10 * time.Second):
			t.Fatalf("timed out waiting for all messages to arrive. got %d messages total", n)
		case err := <-recvErr:
			t.Fatal(err)
		case <-recvdMsgs:
			n++
			if n == len(testMessages)*2 {
				return
			}
		}
	}
}

func TestStreamingPull_ReconnectsAfterServerDies(t *testing.T) {
	ctx := context.Background()
	client, server := newMock(t)
	defer server.srv.Close()
	defer client.Close()
	server.addStreamingPullMessages(testMessages)
	sub := client.Subscription("S")
	sub.ReceiveSettings.MaxOutstandingBytes = 1
	recvErr := make(chan error, 1)
	recvdMsgs := make(chan interface{}, len(testMessages)*2)
	go func() {
		recvErr <- sub.Receive(ctx, func(_ context.Context, m *Message) {
			m.Ack()
			recvdMsgs <- struct{}{}
		})
	}()

	// Wait for all of the messages to arrive.
	var n int
	for {
		select {
		case <-time.After(5 * time.Second):
			t.Fatalf("timed out waiting for all messages to arrive. got %d messages total", n)
		case err := <-recvErr:
			t.Fatal(err)
		case <-recvdMsgs:
			n++
			if n == len(testMessages) {
				// Restart the server on the same port and have it serve
				// the test messages again.
				server.srv.Close()
				server2, err := newMockServer(server.srv.Port)
				if err != nil {
					t.Fatal(err)
				}
				defer server2.srv.Close()
				server2.addStreamingPullMessages(testMessages)
			}
			if n == len(testMessages)*2 {
				return
			}
		}
	}
}

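// newMock starts an in-process fake Pub/Sub server and returns a client
// connected to it over an insecure gRPC connection. Callers must close both
// the client and the server.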
func newMock(t *testing.T) (*Client, *mockServer) {
	srv, err := newMockServer(0)
	if err != nil {
		t.Fatal(err)
	}
	conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
	if err != nil {
		t.Fatal(err)
	}
	client, err := NewClient(context.Background(), "P", option.WithGRPCConn(conn))
	if err != nil {
		t.Fatal(err)
	}
	return client, srv
}

// pullN calls sub.Receive until at least n messages are received.
func pullN(ctx context.Context, sub *Subscription, n int, f func(context.Context, *Message)) ([]*Message, error) {
	var (
		mu   sync.Mutex
		msgs []*Message
	)

	cctx, cancel := context.WithCancel(ctx)
	err := sub.Receive(cctx, func(ctx context.Context, m *Message) {
		mu.Lock()
		msgs = append(msgs, m)
		nSeen := len(msgs)
		mu.Unlock()
		f(ctx, m)
		if nSeen >= n {
			cancel()
		}
	})
	if err != nil {
		return msgs, err
	}
	return msgs, nil
}